Pyspark Dataframe group by filtering

Problem Description

I have a dataframe like the one below

cust_id   req    req_met
-------   ---    -------
 1         r1      1
 1         r2      0
 1         r2      1
 2         r1      1
 3         r1      1
 3         r2      1
 4         r1      0
 5         r1      1
 5         r2      0
 5         r1      1

I have to look at customers, see how many requirements each one has, and check whether each requirement has been met at least once. There can be multiple records for the same customer and requirement, some met and some not met. In the above case my output should be

cust_id
-------
  1
  2
  3

What I did is

# say initial dataframe is df
df1 = df\
    .groupby('cust_id')\
    .countdistinct('req')\
    .alias('num_of_req')\
    .sum('req_met')\
    .alias('sum_req_met')

df2 = df1.filter(df1.num_of_req == df1.sum_req_met)

But in a few cases it is not getting correct results.

How can this be done?

Answer

First, I'll just prepare a toy dataset from the data given above:

from pyspark.sql.functions import col, udf
import pyspark.sql.functions as fn
from pyspark.sql.types import IntegerType

df = spark.createDataFrame([[1, 'r1', 1],
 [1, 'r2', 0],
 [1, 'r2', 1],
 [2, 'r1', 1],
 [3, 'r1', 1],
 [3, 'r2', 1],
 [4, 'r1', 0],
 [5, 'r1', 1],
 [5, 'r2', 0],
 [5, 'r1', 1]], schema=['cust_id', 'req', 'req_met'])
df = df.withColumn('req_met', col("req_met").cast(IntegerType()))
df = df.withColumn('cust_id', col("cust_id").cast(IntegerType()))

I do the same thing by grouping by cust_id and req and then summing req_met. After that, I create a function to floor those sums down to just 0 or 1:

def floor_req(r):
    # cap the per-requirement sum at 1: "met at least once"
    if r >= 1:
        return 1
    else:
        return 0

udf_floor_req = udf(floor_req, IntegerType())

# sum req_met for each (cust_id, req) pair, then cap the sum at 1
gr = df.groupby(['cust_id', 'req'])
df_grouped = gr.agg(fn.sum(col('req_met')).alias('sum_req_met'))
df_grouped_floor = df_grouped.withColumn('sum_req_met', udf_floor_req('sum_req_met'))
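
At this point it can help to peek at the intermediate result (a quick check of my own, not part of the original answer). For the toy data above, the (5, r1) pair collapses from a sum of 2 down to 1, while every other pair keeps its 0 or 1:

df_grouped_floor.orderBy('cust_id', 'req').show()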

Now, we can check whether each customer has met all requirements by comparing the number of distinct requirements with the total number of requirements met:

df_req = df_grouped_floor.groupby('cust_id').agg(fn.sum('sum_req_met').alias('sum_req'), 
                                                 fn.count('req').alias('n_req'))
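
For the toy data, df_req should end up with (cust_id, sum_req, n_req) rows such as (1, 2, 2), (4, 0, 1) and (5, 1, 2), so customers 4 and 5 will be dropped by the next step. A quick look (my own check, not from the original answer):

df_req.orderBy('cust_id').show()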

Finally, you just have to check whether the two columns are equal:

df_req.filter(df_req['sum_req'] == df_req['n_req'])[['cust_id']].orderBy('cust_id').show()
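
For the toy dataset this prints customers 1, 2 and 3, matching the expected output. As a side note, a similar result can probably be obtained in a single aggregation without a UDF, because countDistinct skips NULLs; the snippet below is only a sketch along those lines, not part of the original answer:

# distinct requirements vs. distinct requirements met at least once;
# when() yields NULL for unmet rows, which countDistinct ignores
result = (df.groupby('cust_id')
            .agg(fn.countDistinct('req').alias('n_req'),
                 fn.countDistinct(fn.when(col('req_met') == 1, col('req'))).alias('n_req_met'))
            .filter(col('n_req') == col('n_req_met'))
            .select('cust_id')
            .orderBy('cust_id'))
result.show()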
