Pyspark: How to code Complicated Dataframe algorithm problem (summing with condition)


Problem Description

I have a dataframe that looks like this:

from pyspark.sql.types import StructType, StructField, StringType, FloatType, IntegerType
from pyspark.sql.functions import to_date

TEST_schema = StructType([StructField("date", StringType(), True),
                          StructField("Trigger", StringType(), True),
                          StructField("value", FloatType(), True),
                          StructField("col1", IntegerType(), True),
                          StructField("col2", IntegerType(), True),
                          StructField("want", FloatType(), True)])
TEST_data = [('2020-08-01','T',0.0,3,5,0.5),('2020-08-02','T',0.0,-1,4,0.0),('2020-08-03','T',0.0,-1,3,0.0),
             ('2020-08-04','F',0.2,3,3,0.7),('2020-08-05','T',0.3,1,4,0.9),('2020-08-06','F',0.2,-1,3,0.0),
             ('2020-08-07','T',0.2,-1,4,0.0),('2020-08-08','T',0.5,-1,5,0.0),('2020-08-09','T',0.0,-1,5,0.0)]

rdd3 = sc.parallelize(TEST_data)  # not used below; createDataFrame takes the list directly
TEST_df = sqlContext.createDataFrame(TEST_data, TEST_schema)
TEST_df = TEST_df.withColumn("date", to_date("date", 'yyyy-MM-dd'))
TEST_df.show()

+----------+-------+-----+----+----+----+
|      date|Trigger|value|col1|col2|want|
+----------+-------+-----+----+----+----+
|2020-08-01|      T|  0.0|   3|   5| 0.5|
|2020-08-02|      T|  0.0|  -1|   4| 0.0|
|2020-08-03|      T|  0.0|  -1|   3| 0.0|
|2020-08-04|      F|  0.2|   3|   3| 0.7|
|2020-08-05|      T|  0.3|   1|   4| 0.9|
|2020-08-06|      F|  0.2|  -1|   3| 0.0|
|2020-08-07|      T|  0.2|  -1|   4| 0.0|
|2020-08-08|      T|  0.5|  -1|   5| 0.0|
|2020-08-09|      T|  0.0|  -1|   5| 0.0|
+----------+-------+-----+----+----+----+

date : sorted nicely, in ascending order

Trigger : only T or F

value : any random decimal (float) value

col1 : represents a number of days and cannot be lower than -1. **-1 <= col1 < infinity**

col2 : represents a number of days and cannot be negative. **col2 >= 0**

**Calculation logic**

If col1 == -1, then return 0. Otherwise, if Trigger == 'T', the logic is as follows (the colour names below refer to a diagram in the original post).

Looking at the "red" case: the +3 comes from col1 (col1 == 3 on 2020-08-01), which means we jump 3 rows down (landing on 2020-08-04, value 0.2). At the same time we take the difference (col2 - col1) - 1 = (5 - 3) - 1 = 1 (at 2020-08-01); the 1 means summing one additional value, so the result is 0.2 + 0.3 = 0.5. The same logic applies to the "blue" case.

The "green" case is for when Trigger == 'F': we just take (col2 - 1) = 3 - 1 = 2 (at 2020-08-04); the 2 means summing the next two values together with the current row's value, i.e. 0.2 + 0.3 + 0.2 = 0.7.

What if I want no conditions at all? Say we have this df:

TEST_schema = StructType([StructField("date", StringType(), True),\
                              StructField("value", FloatType(), True),\
                             StructField("col2", IntegerType(), True)])
TEST_data = [('2020-08-01',0.0,5),('2020-08-02',0.0,4),('2020-08-03',0.0,3),('2020-08-04',0.2,3),('2020-08-05',0.3,4),\
                 ('2020-08-06',0.2,3),('2020-08-07',0.2,4),('2020-08-08',0.5,5),('2020-08-09',0.0,5)]
rdd3 = sc.parallelize(TEST_data)
TEST_df = sqlContext.createDataFrame(TEST_data, TEST_schema)
TEST_df = TEST_df.withColumn("date",to_date("date", 'yyyy-MM-dd'))
TEST_df.show() 


+----------+-----+----+
|      date|value|col2|
+----------+-----+----+
|2020-08-01|  0.0|   5|
|2020-08-02|  0.0|   4|
|2020-08-03|  0.0|   3|
|2020-08-04|  0.2|   3|
|2020-08-05|  0.3|   4|
|2020-08-06|  0.2|   3|
|2020-08-07|  0.2|   4|
|2020-08-08|  0.5|   5|
|2020-08-09|  0.0|   5|
+----------+-----+----+

The same logic as the Trigger == 'F' case above applies, i.e. take col2 - 1, but with no condition at all in this case.
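
A minimal pure-Python sketch of that reading (start at the current row and sum the current value plus the next col2 - 1 values); the printed numbers are simply what this reading produces for the sample data:

values = [0.0, 0.0, 0.0, 0.2, 0.3, 0.2, 0.2, 0.5, 0.0]
col2s  = [5, 4, 3, 3, 4, 3, 4, 5, 5]

want = [round(sum(values[i : i + c]), 1) for i, c in enumerate(col2s)]
print(want)  # [0.5, 0.5, 0.5, 0.7, 1.2, 0.9, 0.7, 0.5, 0.0]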

Answer

IIUC, we can use the window function collect_list to gather all related rows, sort the resulting array of structs by date, and then aggregate over a slice of that array. The start_idx and span of each slice can be defined as follows:

  1. If col1 = -1, then start_idx = 1 and span = 0, so nothing is aggregated
  2. else if Trigger = 'F', then start_idx = 1 and span = col2
  3. else start_idx = col1 + 1 and span = col2 - col1

Notice that indexing for the slice function is 1-based.
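
A quick check of the 1-based indexing, using the same sqlContext as above:

sqlContext.sql("SELECT slice(array(10, 20, 30), 2, 2) AS s").show()
# +--------+
# |       s|
# +--------+
# |[20, 30]|
# +--------+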

Code:

from pyspark.sql.functions import to_date, sort_array, collect_list, struct, expr
from pyspark.sql import Window

w1 = Window.orderBy('date').rowsBetween(0, Window.unboundedFollowing)

# columns used to do calculations, date must be the first field for sorting purpose
cols = ["date", "value", "start_idx", "span"]

df_new = (TEST_df
    .withColumn('start_idx', expr("IF(col1 = -1 OR Trigger = 'F', 1, col1+1)")) 
    .withColumn('span', expr("IF(col1 = -1, 0, IF(Trigger = 'F', col2, col2-col1))")) 
    .withColumn('dta', sort_array(collect_list(struct(*cols)).over(w1))) 
    .withColumn("want1", expr("aggregate(slice(dta,start_idx,span), 0D, (acc,x) -> acc+x.value)"))
)

Result:

df_new.show()
+----------+-------+-----+----+----+----+---------+----+--------------------+------------------+
|      date|Trigger|value|col1|col2|want|start_idx|span|                 dta|             want1|
+----------+-------+-----+----+----+----+---------+----+--------------------+------------------+
|2020-08-01|      T|  0.0|   3|   5| 0.5|        4|   2|[[2020-08-01, T, ...|0.5000000149011612|
|2020-08-02|      T|  0.0|  -1|   4| 0.0|        1|   0|[[2020-08-02, T, ...|               0.0|
|2020-08-03|      T|  0.0|  -1|   3| 0.0|        1|   0|[[2020-08-03, T, ...|               0.0|
|2020-08-04|      F|  0.2|   3|   3| 0.7|        1|   3|[[2020-08-04, F, ...|0.7000000178813934|
|2020-08-05|      T|  0.3|   1|   4| 0.9|        2|   3|[[2020-08-05, T, ...|0.9000000059604645|
|2020-08-06|      F|  0.2|  -1|   3| 0.0|        1|   0|[[2020-08-06, F, ...|               0.0|
|2020-08-07|      T|  0.2|  -1|   4| 0.0|        1|   0|[[2020-08-07, T, ...|               0.0|
|2020-08-08|      T|  0.5|  -1|   5| 0.0|        1|   0|[[2020-08-08, T, ...|               0.0|
|2020-08-09|      T|  0.0|  -1|   5| 0.0|        1|   0|[[2020-08-09, T, ...|               0.0|
+----------+-------+-----+----+----+----+---------+----+--------------------+------------------+

Some explanations:

  1. The slice function takes two parameters besides the target array: start_idx is the starting index and span is the length of the slice. In the code, I use IF statements to calculate start_idx and span based on the specification in the diagram from your original post.

The arrays produced by collect_list + sort_array over the window w1 cover the rows from the current row through the end of the window (see the w1 definition). We then use the slice function inside aggregate to retrieve only the necessary array items.

The Spark SQL builtin function aggregate takes the following form:

 aggregate(expr, start, merge, finish) 

where the 4th argument finish can be skipped. In our case, the expression can be rewritten as follows (you can copy it to replace the expression inside .withColumn('want1', expr(""" .... """))):

 aggregate(
   /* targeting array, use slice function to take only part of the array `dta` */
   slice(dta,start_idx,span), 
   /* start, zero_value used for reduce */
   0D, 
   /* merge, similar to reduce function */
   (acc,x) -> acc+x.value,
   /* finish, skipped in the post, but you can do some post-processing here, for example, round-up the result from merge */
   acc -> round(acc, 2)
 )
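
For reference, the same expression with finish spelled out can be dropped back into withColumn like this (behaviour is identical apart from the final rounding):

df_new = df_new.withColumn("want1", expr("""
    aggregate(
      slice(dta, start_idx, span),
      0D,
      (acc, x) -> acc + x.value,
      acc -> round(acc, 2)
    )
"""))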

The aggregate function works like Python's reduce function: the 2nd argument is the zero value (0D is shorthand for double(0), which sets the data type of the aggregation variable acc to double).
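
For intuition only, the same fold written with Python's functools.reduce:

from functools import reduce

sliced = [0.2, 0.3]                                # stand-in for slice(dta, start_idx, span), values only
acc = reduce(lambda acc, x: acc + x, sliced, 0.0)  # merge function plus the zero value (like 0D)
print(round(acc, 2))                               # 0.5; the rounding here plays the role of `finish`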

As mentioned in the comments, if a row with col2 < col1 exists where Trigger = 'T' and col1 != -1, the current code will yield a negative span. In that case, we should use a full-size window spec:

 w1 = Window.orderBy('date').rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)        

and use array_position to find the position of the current row (refer to one of my recent posts), then calculate start_idx based on that position.
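
A hedged, untested sketch of that variant (the names w_full, cols_full and cur_pos are mine, not from the post): with the full-size window, array_position locates the current row inside the sorted array, and start_idx would then be derived from that position instead of assuming the array starts at the current row.

w_full = Window.orderBy('date').rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
cols_full = ["date", "value", "col1", "col2"]

df_pos = (TEST_df
    .withColumn('dta', sort_array(collect_list(struct(*cols_full)).over(w_full)))
    # dta.date is the array of dates; array_position gives this row's 1-based position in it
    .withColumn('cur_pos', expr("array_position(dta.date, date)"))
)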

