PySpark get related records from its array object values
Question
I have a Spark dataframe that has an ID column and, along with other columns, an array column that contains the IDs of its related records as its value.
A sample dataframe would be:
ID | NAME | RELATED_IDLIST
--------------------------
123 | mike | [345,456]
345 | alen | [789]
456 | sam | [789,999]
789 | marc | [111]
555 | dan | [333]
From the above, I need to append all the related child IDs to the array column of the parent ID. The resulting DF should be like:
ID | NAME | RELATED_IDLIST
--------------------------
123 | mike | [345,456,789,999,111]
345 | alen | [789,111]
456 | sam | [789,999,111]
789 | marc | [111]
555 | dan | [333]
Need help on how to do it. Thanks.
Answer
One way to handle this task is to do a self left-join, update RELATED_IDLIST, and repeat for several iterations until some condition is satisfied (this works only when the max depth of the whole hierarchy is small). For Spark 2.3, we can convert the ArrayType column into a comma-delimited StringType column, use the SQL builtin function find_in_set, and add a new column PROCESSED_IDLIST to set up the join conditions; see below for the main logic:
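As a quick sanity check of that join condition, here is a small pure-Python stand-in (an illustration only, not Spark code; `find_in_set_py` is a hypothetical helper) for the SQL find_in_set semantics used above:

```python
def find_in_set_py(value, csv_list):
    # mimic Spark SQL find_in_set: 1-based position of value in a
    # comma-delimited string, or 0 when the value is absent
    items = csv_list.split(",") if csv_list else []
    return items.index(str(value)) + 1 if str(value) in items else 0

print(find_in_set_py(456, "345,456"))  # 2 -> the ">0" join condition holds
print(find_in_set_py(789, "345,456"))  # 0 -> not in the list
```

A value already recorded in PROCESSED_IDLIST therefore fails the `<1` test and is not joined again.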
Code:
from pyspark.sql import functions as F
import pandas as pd

# define a function which takes a dataframe as input, does a self left-join and then returns another
# dataframe with exactly the same schema as the input dataframe. repeat until some condition is satisfied
def recursive_join(d, max_iter=10):
    # function to find direct child-IDs and merge them into RELATED_IDLIST
    def find_child_idlist(_df):
        return _df.alias('d1').join(
            _df.alias('d2'),
            F.expr("find_in_set(d2.ID,d1.RELATED_IDLIST)>0 AND find_in_set(d2.ID,d1.PROCESSED_IDLIST)<1"),
            "left"
        ).groupby("d1.ID", "d1.NAME").agg(
            F.expr("""
              /* combine d1.RELATED_IDLIST with all matched entries from collect_list(d2.RELATED_IDLIST)
               * and remove the trailing comma left over when all d2.RELATED_IDLIST are NULL */
              trim(TRAILING ',' FROM
                concat_ws(",", first(d1.RELATED_IDLIST), concat_ws(",", collect_list(d2.RELATED_IDLIST)))
              ) as RELATED_IDLIST"""),
            F.expr("first(d1.RELATED_IDLIST) as PROCESSED_IDLIST")
        )
    # below the main code logic
    d = find_child_idlist(d).persist()
    if (d.filter("RELATED_IDLIST!=PROCESSED_IDLIST").count() > 0) & (max_iter > 1):
        d = recursive_join(d, max_iter-1)
    return d

# define a pandas_udf to remove duplicates from an ArrayType column
get_uniq = F.pandas_udf(lambda s: pd.Series([list(set(x)) for x in s]), "array<int>")
Where:
- in the function find_child_idlist(), the left join must satisfy the following two conditions:
  - d2.ID is in d1.RELATED_IDLIST: find_in_set(d2.ID,d1.RELATED_IDLIST)>0
  - d2.ID is not in d1.PROCESSED_IDLIST: find_in_set(d2.ID,d1.PROCESSED_IDLIST)<1
- quit recursive_join when no row satisfies RELATED_IDLIST!=PROCESSED_IDLIST, or when max_iter drops to 1
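The fixed point those iterations converge to is simply the set of IDs reachable through the child lists. A plain-Python sketch of that closure (using a hypothetical `all_related` helper, not part of the Spark solution) may make the target output clearer:

```python
def all_related(related, start):
    # breadth-first walk over the child lists, collecting every reachable ID
    seen, queue = [], list(related.get(start, []))
    while queue:
        cid = queue.pop(0)
        if cid not in seen:
            seen.append(cid)
            queue.extend(related.get(cid, []))
    return seen

related = {123: [345, 456], 345: [789], 456: [789, 999], 789: [111], 555: [333]}
print(all_related(related, 123))  # [345, 456, 789, 999, 111]
```

The Spark version does the same expansion, but for all rows at once, one level of depth per self-join.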
Processing:
Set up the dataframe:
df = spark.createDataFrame([
(123, "mike", [345,456]), (345, "alen", [789]), (456, "sam", [789,999]),
(789, "marc", [111]), (555, "dan", [333])
],["ID", "NAME", "RELATED_IDLIST"])
Add a new column PROCESSED_IDLIST to save the RELATED_IDLIST from the previous join, and run recursive_join():
df1 = df.withColumn('RELATED_IDLIST', F.concat_ws(',','RELATED_IDLIST')) \
.withColumn('PROCESSED_IDLIST', F.col('ID'))
df_new = recursive_join(df1, 5)
df_new.show(10,0)
+---+----+-----------------------+-----------------------+
|ID |NAME|RELATED_IDLIST |PROCESSED_IDLIST |
+---+----+-----------------------+-----------------------+
|555|dan |333 |333 |
|789|marc|111 |111 |
|345|alen|789,111 |789,111 |
|123|mike|345,456,789,789,999,111|345,456,789,789,999,111|
|456|sam |789,999,111 |789,999,111 |
+---+----+-----------------------+-----------------------+
Split RELATED_IDLIST into an array of integers and then use the pandas_udf function to drop duplicate array elements:
df_new.withColumn("RELATED_IDLIST", get_uniq(F.split('RELATED_IDLIST', ',').cast('array<int>'))).show(10,0)
+---+----+-------------------------+-----------------------+
|ID |NAME|RELATED_IDLIST |PROCESSED_IDLIST |
+---+----+-------------------------+-----------------------+
|555|dan |[333] |333 |
|789|marc|[111] |111 |
|345|alen|[789, 111] |789,111 |
|123|mike|[999, 456, 111, 789, 345]|345,456,789,789,999,111|
|456|sam |[111, 789, 999] |789,999,111 |
+---+----+-------------------------+-----------------------+
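The get_uniq pandas_udf above only applies a per-row set() to the split array. The same dedup step in plain Python (shown here just to illustrate what the udf computes, with a hypothetical `drop_dups` helper) would be:

```python
def drop_dups(csv_list):
    # split the comma-delimited string and keep unique integers,
    # mirroring what the get_uniq pandas_udf does per row
    return list(set(int(x) for x in csv_list.split(",")))

row = "345,456,789,789,999,111"
print(sorted(drop_dups(row)))  # [111, 345, 456, 789, 999]
```

Note that set() does not preserve order, which is why row 123 in the final output shows its IDs reordered.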