PySpark get related records from its array object values


Problem description

I have a Spark dataframe that has an ID column and, along with other columns, an array column that contains the IDs of its related records as its value.

A sample dataframe would be:

ID | NAME | RELATED_IDLIST
--------------------------
123 | mike | [345,456]
345 | alen | [789]
456 | sam  | [789,999]
789 | marc | [111]
555 | dan  | [333]

From the above, I need to append all the related child IDs to the array column of the parent ID. The resultant DF should be like:

ID | NAME | RELATED_IDLIST
--------------------------
123 | mike | [345,456,789,999,111]
345 | alen | [789,111]
456 | sam  | [789,999,111]
789 | marc | [111]
555 | dan  | [333]

Need help on how to do it. Thanks.

Answer

One way to handle this task is to do a self left-join and update RELATED_IDLIST, repeating for several iterations until some conditions are satisfied (this works only when the max-depth of the whole hierarchy is small). For Spark 2.3, we can convert the ArrayType column into a comma-delimited StringType column, and use the SQL builtin function find_in_set plus a new column PROCESSED_IDLIST to set up the join conditions; see below for the main logic:

Functions:

from pyspark.sql import functions as F
import pandas as pd

# define a function which takes a dataframe as input, does a self left-join and then returns another
# dataframe with the same schema as the input dataframe. repeat until the stop conditions are satisfied
def recursive_join(d, max_iter=10):
  # function to find direct child-IDs and merge into RELATED_IDLIST
  def find_child_idlist(_df):
    return _df.alias('d1').join(
        _df.alias('d2'), 
        F.expr("find_in_set(d2.ID,d1.RELATED_IDLIST)>0 AND find_in_set(d2.ID,d1.PROCESSED_IDLIST)<1"),
        "left"
      ).groupby("d1.ID", "d1.NAME").agg(
        F.expr("""
          /* combine d1.RELATED_IDLIST with all matched entries from collect_list(d2.RELATED_IDLIST)
           * and remove the trailing comma when all d2.RELATED_IDLIST are NULL */
          trim(TRAILING ',' FROM
              concat_ws(",", first(d1.RELATED_IDLIST), concat_ws(",", collect_list(d2.RELATED_IDLIST)))
          ) as RELATED_IDLIST"""),
        F.expr("first(d1.RELATED_IDLIST) as PROCESSED_IDLIST")
    )
  # below the main code logic
  d = find_child_idlist(d).persist()
  if (d.filter("RELATED_IDLIST!=PROCESSED_IDLIST").count() > 0) & (max_iter > 1):
    d = recursive_join(d, max_iter-1)
  return d

# define pandas_udf to remove duplicate from an ArrayType column
get_uniq = F.pandas_udf(lambda s: pd.Series([ list(set(x)) for x in s ]), "array<int>")

Where:

  1. in the function find_child_idlist(), the left-join must satisfy the following two conditions (see the short find_in_set sketch after this list):

  • d2.ID is in d1.RELATED_IDLIST: find_in_set(d2.ID,d1.RELATED_IDLIST)>0
  • d2.ID is not in d1.PROCESSED_IDLIST: find_in_set(d2.ID,d1.PROCESSED_IDLIST)<1

  2. exit recursive_join when no rows satisfy RELATED_IDLIST!=PROCESSED_IDLIST or when max_iter drops to 1
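
A minimal sketch (assuming an active SparkSession named spark) of how find_in_set drives these two conditions: it returns the 1-based position of a value inside a comma-delimited string, or 0 when the value is absent.

# find_in_set('345','345,456') -> 1, so the first join condition (>0) holds
# find_in_set('345','123')     -> 0, so the second join condition (<1) holds
spark.sql("""
  SELECT find_in_set('345', '345,456') AS in_related,
         find_in_set('345', '123')     AS in_processed
""").show()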

Processing:

  1. set up dataframe:

df = spark.createDataFrame([
  (123, "mike", [345,456]), (345, "alen", [789]), (456, "sam", [789,999]),
  (789, "marc", [111]), (555, "dan", [333])
],["ID", "NAME", "RELATED_IDLIST"])

  • add a new column PROCESSED_IDLIST to save the RELATED_IDLIST from the previous join, and run recursive_join():

    df1 = df.withColumn('RELATED_IDLIST', F.concat_ws(',','RELATED_IDLIST')) \
        .withColumn('PROCESSED_IDLIST', F.col('ID'))
    
    df_new = recursive_join(df1, 5)
    df_new.show(10,0)
    +---+----+-----------------------+-----------------------+
    |ID |NAME|RELATED_IDLIST         |PROCESSED_IDLIST       |
    +---+----+-----------------------+-----------------------+
    |555|dan |333                    |333                    |
    |789|marc|111                    |111                    |
    |345|alen|789,111                |789,111                |
    |123|mike|345,456,789,789,999,111|345,456,789,789,999,111|
    |456|sam |789,999,111            |789,999,111            |
    +---+----+-----------------------+-----------------------+
    

  • split RELATED_IDLIST into an array of integers and then use the pandas_udf function to drop duplicate array elements:

    df_new.withColumn("RELATED_IDLIST", get_uniq(F.split('RELATED_IDLIST', ',').cast('array<int>'))).show(10,0)
    +---+----+-------------------------+-----------------------+
    |ID |NAME|RELATED_IDLIST           |PROCESSED_IDLIST       |
    +---+----+-------------------------+-----------------------+
    |555|dan |[333]                    |333                    |
    |789|marc|[111]                    |111                    |
    |345|alen|[789, 111]               |789,111                |
    |123|mike|[999, 456, 111, 789, 345]|345,456,789,789,999,111|
    |456|sam |[111, 789, 999]          |789,999,111            |
    +---+----+-------------------------+-----------------------+
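
  • as a side note: on Spark 2.4+, the builtin array_distinct could replace the pandas_udf for deduplication; a minimal sketch under that assumption:

    # assumes Spark 2.4+ where F.array_distinct is available
    df_new.withColumn("RELATED_IDLIST", F.array_distinct(F.split('RELATED_IDLIST', ',').cast('array<int>'))).show(10,0)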
    
