PySpark get related records from its array object values
Problem description
I have a Spark dataframe that has an ID column and, along with other columns, an array column that contains the IDs of its related records as its value.
A sample dataframe would be:
ID | NAME | RELATED_IDLIST
--------------------------
123 | mike | [345,456]
345 | alen | [789]
456 | sam | [789,999]
789 | marc | [111]
555 | dan | [333]
From the above, I need to append all the related child IDs to the array column of the parent ID. The resultant DF should look like:
ID | NAME | RELATED_IDLIST
--------------------------
123 | mike | [345,456,789,999,111]
345 | alen | [789,111]
456 | sam | [789,999,111]
789 | marc | [111]
555 | dan | [333]
Need help on how to do this. Thanks.
Recommended answer
One way to handle this task is with a self left-join: update RELATED_IDLIST, and repeat for several iterations until a stop condition is met (this works only when the max-depth of the whole hierarchy is small). For Spark 2.3, we can convert the ArrayType column into a comma-delimited StringType column, use the SQL built-in function find_in_set, and add a new column PROCESSED_IDLIST to set up the join conditions; see below for the main logic.
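For reference, the SQL built-in find_in_set(str, str_list) returns the 1-based position of str inside a comma-delimited string list, and 0 when it is absent; a quick illustration (assuming an active spark session):

spark.sql("SELECT find_in_set('789', '345,456,789') AS pos, find_in_set('999', '345,456,789') AS missing").show()
# +---+-------+
# |pos|missing|
# +---+-------+
# |  3|      0|
# +---+-------+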
Functions:
from pyspark.sql import functions as F
import pandas as pd

# define a function which takes a dataframe as input, does a self left-join and then returns another
# dataframe with exactly the same schema as the input dataframe. repeat until the stop condition is met
def recursive_join(d, max_iter=10):
    # function to find direct child-IDs and merge them into RELATED_IDLIST
    def find_child_idlist(_df):
        return _df.alias('d1').join(
            _df.alias('d2'),
            F.expr("find_in_set(d2.ID,d1.RELATED_IDLIST)>0 AND find_in_set(d2.ID,d1.PROCESSED_IDLIST)<1"),
            "left"
        ).groupby("d1.ID", "d1.NAME").agg(
            F.expr("""
                /* combine d1.RELATED_IDLIST with all matched entries from collect_list(d2.RELATED_IDLIST)
                 * and remove the trailing comma left over when all d2.RELATED_IDLIST are NULL */
                trim(TRAILING ',' FROM
                    concat_ws(",", first(d1.RELATED_IDLIST), concat_ws(",", collect_list(d2.RELATED_IDLIST)))
                ) as RELATED_IDLIST"""),
            F.expr("first(d1.RELATED_IDLIST) as PROCESSED_IDLIST")
        )
    # below is the main code logic
    d = find_child_idlist(d).persist()
    if (d.filter("RELATED_IDLIST!=PROCESSED_IDLIST").count() > 0) & (max_iter > 1):
        d = recursive_join(d, max_iter-1)
    return d

# define a pandas_udf to remove duplicates from an ArrayType column
get_uniq = F.pandas_udf(lambda s: pd.Series([ list(set(x)) for x in s ]), "array<int>")
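As a quick sanity check of get_uniq (a hypothetical one-row dataframe; pandas_udf requires pyarrow to be installed):

spark.createDataFrame([([1, 2, 2, 3],)], "v array<int>") \
    .select(get_uniq("v").alias("uniq")).show()
# +---------+
# |     uniq|
# +---------+
# |[1, 2, 3]|
# +---------+
# note: list(set(x)) does not guarantee element order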
Where:

In the function find_child_idlist(), the left-join must satisfy the following two conditions:
- d2.ID is in d1.RELATED_IDLIST: find_in_set(d2.ID,d1.RELATED_IDLIST)>0
- d2.ID is not in d1.PROCESSED_IDLIST: find_in_set(d2.ID,d1.PROCESSED_IDLIST)<1
Exit recursive_join when no row satisfies RELATED_IDLIST!=PROCESSED_IDLIST, or when max_iter has counted down to 1.
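If the hierarchy could be deeper, the same exit conditions can be expressed as a plain loop instead of recursion; a minimal sketch, assuming find_child_idlist() is factored out of recursive_join() as a top-level function:

def iterative_join(d, max_iter=10):
    for _ in range(max_iter):
        d = find_child_idlist(d).persist()
        # stop once RELATED_IDLIST is stable, i.e. no new child-IDs were merged in
        if d.filter("RELATED_IDLIST!=PROCESSED_IDLIST").count() == 0:
            break
    return d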
Processing:
Set up the dataframe:
df = spark.createDataFrame([
    (123, "mike", [345,456]), (345, "alen", [789]), (456, "sam", [789,999]),
    (789, "marc", [111]), (555, "dan", [333])
], ["ID", "NAME", "RELATED_IDLIST"])
Add a new column PROCESSED_IDLIST to save the RELATED_IDLIST from the previous join, and run recursive_join():
df1 = df.withColumn('RELATED_IDLIST', F.concat_ws(',','RELATED_IDLIST')) \
    .withColumn('PROCESSED_IDLIST', F.col('ID'))

df_new = recursive_join(df1, 5)
df_new.show(10,0)
+---+----+-----------------------+-----------------------+
|ID |NAME|RELATED_IDLIST |PROCESSED_IDLIST |
+---+----+-----------------------+-----------------------+
|555|dan |333 |333 |
|789|marc|111 |111 |
|345|alen|789,111 |789,111 |
|123|mike|345,456,789,789,999,111|345,456,789,789,999,111|
|456|sam |789,999,111 |789,999,111 |
+---+----+-----------------------+-----------------------+
Split RELATED_IDLIST into an array of integers and then use the pandas_udf function to drop duplicate array elements:
df_new.withColumn("RELATED_IDLIST", get_uniq(F.split('RELATED_IDLIST', ',').cast('array<int>'))).show(10,0)
+---+----+-------------------------+-----------------------+
|ID |NAME|RELATED_IDLIST |PROCESSED_IDLIST |
+---+----+-------------------------+-----------------------+
|555|dan |[333] |333 |
|789|marc|[111] |111 |
|345|alen|[789, 111] |789,111 |
|123|mike|[999, 456, 111, 789, 345]|345,456,789,789,999,111|
|456|sam |[111, 789, 999] |789,999,111 |
+---+----+-------------------------+-----------------------+
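To match the desired output in the question, the bookkeeping column can be dropped after the dedup step; a small sketch (the name df_final is hypothetical):

df_final = df_new.withColumn(
    "RELATED_IDLIST",
    get_uniq(F.split("RELATED_IDLIST", ",").cast("array<int>"))
).drop("PROCESSED_IDLIST")  # PROCESSED_IDLIST was only needed for the join-conditions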