Pyspark Dataframe group by filtering

Question

My dataframe looks like this:

cust_id   req    req_met
-------   ---    -------
 1         r1      1
 1         r2      0
 1         r2      1
 2         r1      1
 3         r1      1
 3         r2      1
 4         r1      0
 5         r1      1
 5         r2      0
 5         r1      1

I have to look at customers, see how many requirements they have, and check whether each one was met at least once. There can be multiple records for the same customer and requirement, one met and one not met. In the above case my output should be

cust_id
-------
  1
  2
  3

What I have done is

# say the initial dataframe is df
import pyspark.sql.functions as fn

df1 = df.groupby('cust_id')\
        .agg(fn.countDistinct('req').alias('num_of_req'),
             fn.sum('req_met').alias('sum_req_met'))

df2 = df1.filter(df1.num_of_req == df1.sum_req_met)

But in a few cases it does not give correct results: when the same requirement is met more than once, sum('req_met') is inflated and can equal the distinct requirement count even though another requirement was never met.
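Customer 5 shows the failure: r1 is met twice, so the sum matches the distinct count even though r2 was never met. A minimal check (my illustration, assuming the df1 above has been computed):

df1.filter(df1.cust_id == 5).show()
# +-------+----------+-----------+
# |cust_id|num_of_req|sum_req_met|
# +-------+----------+-----------+
# |      5|         2|          2|
# +-------+----------+-----------+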

How can this be done?

Answer

First, I'll prepare a toy dataset from the one given above:

from pyspark.sql.functions import col
import pyspark.sql.functions as fn
from pyspark.sql.types import IntegerType

df = spark.createDataFrame([[1, 'r1', 1],
                            [1, 'r2', 0],
                            [1, 'r2', 1],
                            [2, 'r1', 1],
                            [3, 'r1', 1],
                            [3, 'r2', 1],
                            [4, 'r1', 0],
                            [5, 'r1', 1],
                            [5, 'r2', 0],
                            [5, 'r1', 1]], schema=['cust_id', 'req', 'req_met'])
df = df.withColumn('req_met', col('req_met').cast(IntegerType()))
df = df.withColumn('cust_id', col('cust_id').cast(IntegerType()))
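A quick sanity check that the casts took effect (optional; expected output reconstructed by hand):

df.printSchema()
# root
#  |-- cust_id: integer (nullable = true)
#  |-- req: string (nullable = true)
#  |-- req_met: integer (nullable = true)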

I do the same thing: group by cust_id and req, then sum req_met. After that, I create a function to floor those sums down to just 0 or 1:

from pyspark.sql.functions import udf

def floor_req(r):
    # cap the per-(cust_id, req) met-count at 1
    if r >= 1:
        return 1
    else:
        return 0

udf_floor_req = udf(floor_req, IntegerType())
gr = df.groupby(['cust_id', 'req'])
df_grouped = gr.agg(fn.sum(col('req_met')).alias('sum_req_met'))
df_grouped_floor = df_grouped.withColumn('sum_req_met', udf_floor_req('sum_req_met'))
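As an aside, since req_met is a 0/1 flag, the same flooring can be done without a Python udf by taking the max per group (my variant, not part of the original answer):

# max of a 0/1 flag is already the floored value
df_grouped_floor = df.groupby(['cust_id', 'req']) \
                     .agg(fn.max(col('req_met')).alias('sum_req_met'))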

Now, we can check whether each customer has met all requirements by comparing the number of distinct requirements with the total number of requirements met:

df_req = df_grouped_floor.groupby('cust_id').agg(fn.sum('sum_req_met').alias('sum_req'), 
                                                 fn.count('req').alias('n_req'))
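For the toy data, df_req should look like this (values worked out by hand from the table above; only customers 4 and 5 fall short):

df_req.orderBy('cust_id').show()
# +-------+-------+-----+
# |cust_id|sum_req|n_req|
# +-------+-------+-----+
# |      1|      2|    2|
# |      2|      1|    1|
# |      3|      2|    2|
# |      4|      0|    1|
# |      5|      1|    2|
# +-------+-------+-----+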

Finally, you just have to check whether the two columns are equal:

df_req.filter(df_req['sum_req'] == df_req['n_req'])[['cust_id']].orderBy('cust_id').show()
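(The [['cust_id']] indexing is PySpark's pandas-style shorthand for .select('cust_id').) For the toy data this should print exactly the expected customers:

# +-------+
# |cust_id|
# +-------+
# |      1|
# |      2|
# |      3|
# +-------+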
