How to write Pyspark UDAF on multiple columns?

Problem description

I have the following data in a pyspark dataframe called end_stats_df:

values     start    end    cat1   cat2
10          1        2      A      B
11          1        2      C      B
12          1        2      D      B
510         1        2      D      C
550         1        2      C      B
500         1        2      A      B
80          1        3      A      B
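
For readers who want to follow along locally, this frame can be recreated roughly as follows (a minimal sketch; the SparkSession setup is an assumption, and the rows are copied from the table above):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# reproduce end_stats_df from the sample table
end_stats_df = spark.createDataFrame(
    [(10, 1, 2, 'A', 'B'),
     (11, 1, 2, 'C', 'B'),
     (12, 1, 2, 'D', 'B'),
     (510, 1, 2, 'D', 'C'),
     (550, 1, 2, 'C', 'B'),
     (500, 1, 2, 'A', 'B'),
     (80, 1, 3, 'A', 'B')],
    ['values', 'start', 'end', 'cat1', 'cat2'])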

And I want to aggregate it in the following way:

  • I want to use the "start" and "end" columns as the aggregate keys
  • For each group of rows, I need to do the following:
    • Compute the unique number of values in both cat1 and cat2 for that group. e.g., for the group of start=1 and end=2, this number would be 4 because there's A, B, C, D. This number will be stored as n (n=4 in this example).
    • For the values field, for each group I need to sort the values, and then select every n-1 value, where n is the value stored from the first operation above (see the short illustration after this list).
    • At the end of the aggregation, I don't really care what is in cat1 and cat2 after the operations above.
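
One reading of "select every n-1 value" that reproduces the expected output below is to keep every (n-1)-th element of the sorted values. A plain-Python illustration for the start=1, end=2 group (this interpretation is inferred from the expected output, not stated explicitly in the original):

# start=1, end=2 group from the sample data
values = [10, 11, 12, 510, 550, 500]
cats = {'A', 'B', 'C', 'D'}            # distinct values across cat1 and cat2
n = len(cats)                          # n = 4

picked = sorted(values)[n - 2::n - 1]  # every (n-1)-th of the sorted values
print(picked)                          # [12, 550]
# the start=1, end=3 group has n = 2 and a single value, so it keeps 80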

An example output from the example above is:

values     start    end    cat1   cat2
12          1        2      D      B
550         1        2      C      B
80          1        3      A      B

How do I accomplish this using pyspark dataframes? I assume I need to use a custom UDAF, right?

Recommended answer

PySpark does not support UDAFs directly, so we have to do the aggregation manually.

from pyspark.sql import functions as f
from pyspark.sql.types import StringType

def func(values, cat1, cat2):
    # n = number of distinct categories across cat1 and cat2 for this group
    n = len(set(cat1 + cat2))
    # pick the (n-1)-th smallest value of the group
    return sorted(values)[n - 2]


df = spark.read.load('file:///home/zht/PycharmProjects/test/text_file.txt', format='csv', sep='\t', header=True)
# collect each group's values and categories so an ordinary UDF can work on plain Python lists
df = df.groupBy(df['start'], df['end']).agg(f.collect_list(df['values']).alias('values'),
                                            f.collect_set(df['cat1']).alias('cat1'),
                                            f.collect_set(df['cat2']).alias('cat2'))
df = df.select(df['start'], df['end'], f.UserDefinedFunction(func, StringType())(df['values'], df['cat1'], df['cat2']))
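
Not part of the original answer: if you also need every further (n-1)-th value (the expected output in the question keeps both 12 and 550 for the start=1, end=2 group), the same collect-then-UDF pattern can return an array and explode it. A sketch, assuming the end_stats_df built from the question's data above (integer values, hence LongType; with the string-typed frame read from CSV here, use StringType for the element type instead):

from pyspark.sql import functions as f
from pyspark.sql.types import ArrayType, LongType

def func_all(values, cat1, cat2):
    # number of distinct categories across cat1 and cat2 in this group
    n = len(set(cat1 + cat2))
    # every (n-1)-th of the sorted values; assumes n >= 2 so the slice step is positive
    return sorted(values)[n - 2::n - 1]

pick_all = f.udf(func_all, ArrayType(LongType()))

grouped = end_stats_df.groupBy('start', 'end').agg(
    f.collect_list('values').alias('values'),
    f.collect_set('cat1').alias('cat1'),
    f.collect_set('cat2').alias('cat2'))

result = grouped.select('start', 'end',
                        f.explode(pick_all('values', 'cat1', 'cat2')).alias('values'))
# result should contain 12 and 550 for (start=1, end=2) and 80 for (start=1, end=3)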
      
