Pyspark: How to code a complicated dataframe algorithm problem (summing with condition)

Problem description

I have a dataframe that looks like this:

from pyspark.sql.types import StructType, StructField, StringType, FloatType, IntegerType
from pyspark.sql.functions import to_date

TEST_schema = StructType([StructField("date", StringType(), True),
                          StructField("Trigger", StringType(), True),
                          StructField("value", FloatType(), True),
                          StructField("col1", IntegerType(), True),
                          StructField("col2", IntegerType(), True),
                          StructField("want", FloatType(), True)])
TEST_data = [('2020-08-01','T',0.0,3,5,0.5),('2020-08-02','T',0.0,-1,4,0.0),('2020-08-03','T',0.0,-1,3,0.0),
             ('2020-08-04','F',0.2,3,3,0.7),('2020-08-05','T',0.3,1,4,0.9),('2020-08-06','F',0.2,-1,3,0.0),
             ('2020-08-07','T',0.2,-1,4,0.0),('2020-08-08','T',0.5,-1,5,0.0),('2020-08-09','T',0.0,-1,5,0.0)]
rdd3 = sc.parallelize(TEST_data)
TEST_df = sqlContext.createDataFrame(TEST_data, TEST_schema)
TEST_df = TEST_df.withColumn("date", to_date("date", 'yyyy-MM-dd'))
TEST_df.show()

+----------+-------+-----+----+----+----+
|      date|Trigger|value|col1|col2|want|
+----------+-------+-----+----+----+----+
|2020-08-01|      T|  0.0|   3|   5| 0.5|
|2020-08-02|      T|  0.0|  -1|   4| 0.0|
|2020-08-03|      T|  0.0|  -1|   3| 0.0|
|2020-08-04|      F|  0.2|   3|   3| 0.7|
|2020-08-05|      T|  0.3|   1|   4| 0.9|
|2020-08-06|      F|  0.2|  -1|   3| 0.0|
|2020-08-07|      T|  0.2|  -1|   4| 0.0|
|2020-08-08|      T|  0.5|  -1|   5| 0.0|
|2020-08-09|      T|  0.0|  -1|   5| 0.0|
+----------+-------+-----+----+----+----+

date : sorted nicely

Trigger : only T or F

value : any random decimal (float) value

col1 : represents number of days and cannot be lower than -1. **-1 <= col1 < infinity**

col2 : represents number of days and cannot be negative. **col2 >= 0**

want : the expected output column (what I am trying to compute)

**Calculation logic**

If col1 == -1, then return 0; otherwise, if Trigger == T, the colour-coded diagram in the original post (red, blue and green) helps to understand the logic.

If we look at the "red color", the +3 comes from col1, which is col1 == 3 at 2020-08-01. It means we jump 3 rows ahead (landing on 2020-08-04, value 0.2) and at the same time take the difference (col2 - col1) - 1 = (5 - 3) - 1 = 1 (at 2020-08-01); the 1 means summing one more following value, so we get 0.2 + 0.3 = 0.5. The same logic applies to the "blue color".

The "green color" is for when Trigger == "F": then we just take (col2 - 1) = 3 - 1 = 2 (at 2020-08-04), and the 2 means summing the current value plus the next two values, which is 0.2 + 0.3 + 0.2 = 0.7.
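Since the diagram itself is not reproduced here, the following is a minimal pure-Python sketch of the same rule (my own reference implementation, not part of the original post); it reproduces the want column above:

rows = [("T", 0.0, 3, 5), ("T", 0.0, -1, 4), ("T", 0.0, -1, 3), ("F", 0.2, 3, 3),
        ("T", 0.3, 1, 4), ("F", 0.2, -1, 3), ("T", 0.2, -1, 4), ("T", 0.5, -1, 5),
        ("T", 0.0, -1, 5)]  # (Trigger, value, col1, col2)
values = [r[1] for r in rows]

def want(i):
    trigger, _, col1, col2 = rows[i]
    if col1 == -1:                      # nothing to sum
        return 0.0
    if trigger == "F":                  # start at the current row, take col2 values
        return sum(values[i:i + col2])
    # jump col1 rows ahead, then take (col2 - col1) values
    return sum(values[i + col1:i + col1 + (col2 - col1)])

print([round(want(i), 1) for i in range(len(rows))])
# [0.5, 0.0, 0.0, 0.7, 0.9, 0.0, 0.0, 0.0, 0.0]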

What if I want no conditions at all? Let's say we have this df:

TEST_schema = StructType([StructField("date", StringType(), True),
                          StructField("value", FloatType(), True),
                          StructField("col2", IntegerType(), True)])
TEST_data = [('2020-08-01',0.0,5),('2020-08-02',0.0,4),('2020-08-03',0.0,3),('2020-08-04',0.2,3),('2020-08-05',0.3,4),
             ('2020-08-06',0.2,3),('2020-08-07',0.2,4),('2020-08-08',0.5,5),('2020-08-09',0.0,5)]
rdd3 = sc.parallelize(TEST_data)
TEST_df = sqlContext.createDataFrame(TEST_data, TEST_schema)
TEST_df = TEST_df.withColumn("date", to_date("date", 'yyyy-MM-dd'))
TEST_df.show()


+----------+-----+----+
|      date|value|col2|
+----------+-----+----+
|2020-08-01|  0.0|   5|
|2020-08-02|  0.0|   4|
|2020-08-03|  0.0|   3|
|2020-08-04|  0.2|   3|
|2020-08-05|  0.3|   4|
|2020-08-06|  0.2|   3|
|2020-08-07|  0.2|   4|
|2020-08-08|  0.5|   5|
|2020-08-09|  0.0|   5|
+----------+-----+----+

The same logic applies as when we had the Trigger == "F" condition, i.e. take the current value plus the next (col2 - 1) values, but with no condition at all in this case.
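In other words (my own illustrative sketch of the intended result, not from the original post), the expected output for this simplified df would just be a forward sum over a window of col2 rows starting at the current row:

values = [0.0, 0.0, 0.0, 0.2, 0.3, 0.2, 0.2, 0.5, 0.0]
col2   = [5, 4, 3, 3, 4, 3, 4, 5, 5]

# current value plus the next (col2 - 1) values, i.e. a window of col2 rows
want = [round(sum(values[i:i + col2[i]]), 1) for i in range(len(values))]
print(want)  # [0.5, 0.5, 0.5, 0.7, 1.2, 0.9, 0.7, 0.5, 0.0]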

Recommended answer

IIUC, we can use the Window function collect_list to get all related rows, sort the array of structs by date, and then do the aggregation based on a slice of this array. The start_idx and span of each slice can be defined as follows:

  1. If col1 = -1, then start_idx = 1 and span = 0, so nothing is aggregated.
  2. Else if Trigger = 'F', then start_idx = 1 and span = col2.
  3. Else start_idx = col1 + 1 and span = col2 - col1.

Notice that the index for the slice function is 1-based.
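For example (a quick illustration of the 1-based indexing, run against the same sqlContext):

sqlContext.sql("SELECT slice(array(10, 20, 30, 40), 2, 2) AS s").show()
# +--------+
# |       s|
# +--------+
# |[20, 30]|
# +--------+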

Code:

from pyspark.sql.functions import to_date, sort_array, collect_list, struct, expr
from pyspark.sql import Window

w1 = Window.orderBy('date').rowsBetween(0, Window.unboundedFollowing)

# columns used to do calculations, date must be the first field for sorting purpose
cols = ["date", "value", "start_idx", "span"]

df_new = (TEST_df
    .withColumn('start_idx', expr("IF(col1 = -1 OR Trigger = 'F', 1, col1+1)")) 
    .withColumn('span', expr("IF(col1 = -1, 0, IF(Trigger = 'F', col2, col2-col1))")) 
    .withColumn('dta', sort_array(collect_list(struct(*cols)).over(w1))) 
    .withColumn("want1", expr("aggregate(slice(dta,start_idx,span), 0D, (acc,x) -> acc+x.value)"))
)

Result:

df_new.show()
+----------+-------+-----+----+----+----+---------+----+--------------------+------------------+
|      date|Trigger|value|col1|col2|want|start_idx|span|                 dta|             want1|
+----------+-------+-----+----+----+----+---------+----+--------------------+------------------+
|2020-08-01|      T|  0.0|   3|   5| 0.5|        4|   2|[[2020-08-01, T, ...|0.5000000149011612|
|2020-08-02|      T|  0.0|  -1|   4| 0.0|        1|   0|[[2020-08-02, T, ...|               0.0|
|2020-08-03|      T|  0.0|  -1|   3| 0.0|        1|   0|[[2020-08-03, T, ...|               0.0|
|2020-08-04|      F|  0.2|   3|   3| 0.7|        1|   3|[[2020-08-04, F, ...|0.7000000178813934|
|2020-08-05|      T|  0.3|   1|   4| 0.9|        2|   3|[[2020-08-05, T, ...|0.9000000059604645|
|2020-08-06|      F|  0.2|  -1|   3| 0.0|        1|   0|[[2020-08-06, F, ...|               0.0|
|2020-08-07|      T|  0.2|  -1|   4| 0.0|        1|   0|[[2020-08-07, T, ...|               0.0|
|2020-08-08|      T|  0.5|  -1|   5| 0.0|        1|   0|[[2020-08-08, T, ...|               0.0|
|2020-08-09|      T|  0.0|  -1|   5| 0.0|        1|   0|[[2020-08-09, T, ...|               0.0|
+----------+-------+-----+----+----+----+---------+----+--------------------+------------------+

Some explanations:

  1. The slice function requires two parameters in addition to the target array. In our code, start_idx is the starting index and span is the length of the slice. In the code, I use IF statements to calculate start_idx and span based on the diagram specs in the original post.

The resulting arrays from collect_list + sort_array over the Window w1 cover the rows from the current row to the end of the Window (see the w1 assignment). We then use the slice function inside the aggregate function to retrieve only the necessary array items.

The Spark SQL builtin function aggregate takes the following form:

 aggregate(expr, start, merge, finish) 

where the 4th argument finish can be skipped. In our case, it can be reformatted as follows (you can copy the following to replace the expression inside .withColumn('want1', expr(""" .... """))):

 aggregate(
   /* targeting array, use slice function to take only part of the array `dta` */
   slice(dta,start_idx,span), 
   /* start, zero_value used for reduce */
   0D, 
   /* merge, similar to reduce function */
   (acc,x) -> acc+x.value,
   /* finish, skipped in the post, but you can do some post-processing here, for example, round-up the result from merge */
   acc -> round(acc, 2)
 )

The aggregate function works like the reduce function in Python. The 2nd argument is the zero value (0D is a shortcut for double(0), which sets the data type of the aggregation variable acc).
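As a rough Python analogy (illustrative only, not Spark code), the merge lambda plays the role of the reducing function and 0D the initial value:

from functools import reduce

values = [0.2, 0.3]  # the sliced x.value entries for 2020-08-01
print(reduce(lambda acc, v: acc + v, values, 0.0))  # 0.5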

As mentioned in the comments, if a row with col2 < col1 exists where Trigger = 'T' and col1 != -1, the current code will yield a negative span. In such a case, we should use a full-size Window spec:

 w1 = Window.orderBy('date').rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)        

and use array_position to find the position of the current row (refer to one of my recent posts), and then calculate start_idx based on this position.
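A minimal sketch of that variant (my own adaptation, not code from the answer; it assumes the same TEST_df and simply clamps a negative span to 0 with greatest so that malformed rows sum to nothing):

from pyspark.sql.functions import expr, sort_array, collect_list, struct
from pyspark.sql import Window

# full-size window: every row sees the complete, date-sorted array
w_full = Window.orderBy('date').rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

df_safe = (TEST_df
    .withColumn('dta', sort_array(collect_list(struct('date', 'value')).over(w_full)))
    # 1-based position of the current row inside the full array
    .withColumn('cur_pos', expr("CAST(array_position(transform(dta, x -> x.date), date) AS INT)"))
    # shift the original start_idx by the current position and clamp the span at 0
    .withColumn('start_idx', expr("cur_pos + IF(col1 = -1 OR Trigger = 'F', 1, col1 + 1) - 1"))
    .withColumn('span', expr("greatest(IF(col1 = -1, 0, IF(Trigger = 'F', col2, col2 - col1)), 0)"))
    .withColumn('want1', expr("aggregate(slice(dta, start_idx, span), 0D, (acc, x) -> acc + x.value)"))
)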
