How to calculate rolling sum with varying window sizes in PySpark

Problem description

I have a Spark dataframe that contains sales prediction data for some products in some stores over a time period. How do I calculate the rolling sum of Prediction over a window of the next N values?

Input data

+-----------+---------+------------+------------+---+
| ProductId | StoreId |    Date    | Prediction | N |
+-----------+---------+------------+------------+---+
|         1 |     100 | 2019-07-01 | 0.92       | 2 |
|         1 |     100 | 2019-07-02 | 0.62       | 2 |
|         1 |     100 | 2019-07-03 | 0.89       | 2 |
|         1 |     100 | 2019-07-04 | 0.57       | 2 |
|         2 |     200 | 2019-07-01 | 1.39       | 3 |
|         2 |     200 | 2019-07-02 | 1.22       | 3 |
|         2 |     200 | 2019-07-03 | 1.33       | 3 |
|         2 |     200 | 2019-07-04 | 1.61       | 3 |
+-----------+---------+------------+------------+---+
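
For anyone who wants to reproduce this, the sample data above can be built with something like the following sketch (the dataframe name predictions matches the code later in the question; parsing Date with to_date is an assumption about its type):

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

# sample input from the question; Date is parsed to a date column (assumed type)
predictions = spark.createDataFrame(
    [
        (1, 100, "2019-07-01", 0.92, 2),
        (1, 100, "2019-07-02", 0.62, 2),
        (1, 100, "2019-07-03", 0.89, 2),
        (1, 100, "2019-07-04", 0.57, 2),
        (2, 200, "2019-07-01", 1.39, 3),
        (2, 200, "2019-07-02", 1.22, 3),
        (2, 200, "2019-07-03", 1.33, 3),
        (2, 200, "2019-07-04", 1.61, 3),
    ],
    ["ProductId", "StoreId", "Date", "Prediction", "N"],
).withColumn("Date", F.to_date("Date"))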

Expected output

+-----------+---------+------------+------------+---+------------------------+
| ProductId | StoreId |    Date    | Prediction | N |       RollingSum       |
+-----------+---------+------------+------------+---+------------------------+
|         1 |     100 | 2019-07-01 | 0.92       | 2 | sum(0.92, 0.62)        |
|         1 |     100 | 2019-07-02 | 0.62       | 2 | sum(0.62, 0.89)        |
|         1 |     100 | 2019-07-03 | 0.89       | 2 | sum(0.89, 0.57)        |
|         1 |     100 | 2019-07-04 | 0.57       | 2 | sum(0.57)              |
|         2 |     200 | 2019-07-01 | 1.39       | 3 | sum(1.39, 1.22, 1.33)  |
|         2 |     200 | 2019-07-02 | 1.22       | 3 | sum(1.22, 1.33, 1.61 ) |
|         2 |     200 | 2019-07-03 | 1.33       | 3 | sum(1.33, 1.61)        |
|         2 |     200 | 2019-07-04 | 1.61       | 3 | sum(1.61)              |
+-----------+---------+------------+------------+---+------------------------+

There are lots of questions and answers to this problem in Python but I couldn't find any in PySpark.

Similar Question 1
There is a similar question here, but in that one the frame size is fixed to 3. The provided answer uses the rangeBetween function, which only works with fixed-size frames, so I cannot use it for varying sizes.
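
For context, a fixed-size forward-looking rolling sum like the one in that answer can be written with constant frame bounds, for example with rowsBetween. The sketch below uses a fixed window of 3 rows and only illustrates why constant bounds cannot follow a per-row N (the column name FixedRollingSum is just for illustration):

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# the frame bounds (0, 2) are constants: the current row plus the next 2 rows,
# so the window size cannot vary with the N column
w_fixed = Window.partitionBy("ProductId", "StoreId").orderBy("Date").rowsBetween(0, 2)
fixed_rolling = predictions.withColumn("FixedRollingSum", F.sum("Prediction").over(w_fixed))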

类似问题2
此处,还有一个类似的问题.在这一本书中,建议为所有可能的大小编写案例,但不适用于我的案例,因为我不知道我需要计算多少个不同的帧大小.

Similar Question 2
There is also a similar question here. In that one, writing a case for every possible size is suggested, but that is not applicable to my case since I don't know how many distinct frame sizes I would need to cover.

Solution attempt 1
I've tried to solve the problem using a pandas UDF:

rolling_sum_predictions = predictions.groupBy('ProductId', 'StoreId').apply(calculate_rolling_sums)

calculate_rolling_sums is a pandas UDF in which I solve the problem in plain Python. This solution works with a small amount of test data. However, when the data gets bigger (in my case, the input df has around 1B rows), the calculations take very long.
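
The body of calculate_rolling_sums isn't shown here; a minimal sketch of such a grouped-map pandas UDF, assuming one N value per ProductId/StoreId group and that the schema types match your data, could look like this:

from pyspark.sql.functions import pandas_udf, PandasUDFType

# output schema: the input columns plus the rolling sum (types are assumptions)
schema = "ProductId int, StoreId int, Date date, Prediction double, N int, RollingSum double"

@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def calculate_rolling_sums(pdf):
    pdf = pdf.sort_values("Date")
    n = int(pdf["N"].iloc[0])
    # forward-looking rolling sum: reverse, take a backward rolling sum over n values, reverse back
    pdf["RollingSum"] = pdf["Prediction"][::-1].rolling(n, min_periods=1).sum()[::-1]
    return pdf

This plugs into the groupBy('ProductId', 'StoreId').apply(...) call above, but as noted it does not scale well to ~1B rows.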

Solution attempt 2
I used a workaround based on the answer to Similar Question 1 above. I calculated the biggest possible N, collected the list of upcoming predictions with it, and then calculated the sum of predictions by slicing the list.

import numpy as np
from pyspark.sql import functions as F
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType
from pyspark.sql.window import Window

predictions = predictions.withColumn('DayIndex', F.rank().over(Window.partitionBy('ProductId', 'StoreId').orderBy('Date')))

# find the biggest period
biggest_period = predictions.agg({"N": "max"}).collect()[0][0]

# collect the next `biggest_period` predictions starting from the DayIndex
w = (Window.partitionBy(F.col("ProductId"), F.col("StoreId")).orderBy(F.col('DayIndex')).rangeBetween(0, biggest_period - 1))
rolling_prediction_lists = predictions.withColumn("next_preds", F.collect_list("Prediction").over(w))

# calculate rolling forecast sums by summing the first N collected values
pred_sum_udf = udf(lambda preds, period: float(np.sum(preds[:period])), FloatType())
rolling_pred_sums = rolling_prediction_lists \
    .withColumn("RollingSum", pred_sum_udf("next_preds", "N"))

This solution also works with the test data. I haven't had a chance to test it with the original data yet, but whether it works or not, I don't like this solution. Is there a smarter way to solve this?

Recommended answer

If you're using Spark 2.4+, you can use the new higher-order array functions slice and aggregate to implement your requirement efficiently without any UDFs:

# collect the current and all following predictions per product/store, ordered by date,
# then sum the first N elements of that array with slice + aggregate
summed_predictions = predictions\
   .withColumn("summed", F.collect_list("Prediction").over(Window.partitionBy("ProductId", "StoreId").orderBy("Date").rowsBetween(Window.currentRow, Window.unboundedFollowing)))\
   .withColumn("summed", F.expr("aggregate(slice(summed,1,N), cast(0 as double), (acc,d) -> acc + d)"))

summed_predictions.show()
+---------+-------+-------------------+----------+---+------------------+
|ProductId|StoreId|               Date|Prediction|  N|            summed|
+---------+-------+-------------------+----------+---+------------------+
|        1|    100|2019-07-01 00:00:00|      0.92|  2|              1.54|
|        1|    100|2019-07-02 00:00:00|      0.62|  2|              1.51|
|        1|    100|2019-07-03 00:00:00|      0.89|  2|              1.46|
|        1|    100|2019-07-04 00:00:00|      0.57|  2|              0.57|
|        2|    200|2019-07-01 00:00:00|      1.39|  3|              3.94|
|        2|    200|2019-07-02 00:00:00|      1.22|  3|              4.16|
|        2|    200|2019-07-03 00:00:00|      1.33|  3|2.9400000000000004|
|        2|    200|2019-07-04 00:00:00|      1.61|  3|              1.61|
+---------+-------+-------------------+----------+---+------------------+
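
As an optional follow-up (an assumption about the desired presentation, not part of the original answer), the summed column can be renamed to RollingSum and rounded to tidy floating-point noise such as 2.9400000000000004:

# rename to match the expected output and round to 2 decimals (assumes two-decimal predictions)
rolling_pred_sums = summed_predictions \
    .withColumnRenamed("summed", "RollingSum") \
    .withColumn("RollingSum", F.round("RollingSum", 2))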
