Pyspark Dataframe group by filtering
Problem description
I have a dataframe like the following:
cust_id req req_met
------- --- -------
1 r1 1
1 r2 0
1 r2 1
2 r1 1
3 r1 1
3 r2 1
4 r1 0
5 r1 1
5 r2 0
5 r1 1
I have to look at each customer, see how many requirements they have, and check whether each requirement has been met at least once. There can be multiple records with the same customer and requirement, some met and some not. In the above case my output should be:
cust_id
-------
1
2
3
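Put another way, a customer qualifies when the set of their distinct requirements equals the set of requirements they met at least once. A minimal plain-Python sketch of that rule on the toy data (not PySpark; the helper name `customers_meeting_all` is hypothetical) reproduces the expected output:

```python
from collections import defaultdict

# Toy data copied from the table above: (cust_id, req, req_met)
rows = [
    (1, 'r1', 1), (1, 'r2', 0), (1, 'r2', 1),
    (2, 'r1', 1),
    (3, 'r1', 1), (3, 'r2', 1),
    (4, 'r1', 0),
    (5, 'r1', 1), (5, 'r2', 0), (5, 'r1', 1),
]

def customers_meeting_all(rows):
    """Return customers whose every distinct requirement has a row with req_met == 1."""
    required = defaultdict(set)  # cust_id -> all requirements seen
    met = defaultdict(set)       # cust_id -> requirements met at least once
    for cust_id, req, req_met in rows:
        required[cust_id].add(req)
        if req_met == 1:
            met[cust_id].add(req)
    return sorted(c for c in required if required[c] == met[c])

print(customers_meeting_all(rows))  # [1, 2, 3]
```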
What I have done:
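# say initial dataframe is df
import pyspark.sql.functions as fn

# count the distinct requirements and total up req_met per customer
df1 = df.groupby('cust_id').agg(
    fn.countDistinct('req').alias('num_of_req'),
    fn.sum('req_met').alias('sum_req_met'))
# keep customers whose met count equals their number of requirements
df2 = df1.filter(df1.num_of_req == df1.sum_req_met)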
But in a few cases it does not give correct results.
How can I do this?
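The comparison in the question breaks when a customer has several met rows for the same requirement: sum(req_met) counts every met row, not every met requirement. Duplicates can then make up for an unmet requirement (customer 5 has r1 met twice and r2 never met, so 2 == 2) or push the sum above the requirement count and wrongly exclude a customer. The plain-Python sketch below imitates that check on the toy data (the helper name `flawed_check` is hypothetical) and shows customer 5 slipping through:

```python
from collections import defaultdict

# Toy data copied from the table above: (cust_id, req, req_met)
rows = [
    (1, 'r1', 1), (1, 'r2', 0), (1, 'r2', 1),
    (2, 'r1', 1),
    (3, 'r1', 1), (3, 'r2', 1),
    (4, 'r1', 0),
    (5, 'r1', 1), (5, 'r2', 0), (5, 'r1', 1),
]

def flawed_check(rows):
    """Mimic the question's logic: number of distinct requirements == sum of req_met."""
    reqs = defaultdict(set)       # cust_id -> distinct requirements
    total_met = defaultdict(int)  # cust_id -> sum of req_met over all rows
    for cust_id, req, req_met in rows:
        reqs[cust_id].add(req)
        total_met[cust_id] += req_met
    return sorted(c for c in reqs if len(reqs[c]) == total_met[c])

print(flawed_check(rows))  # [1, 2, 3, 5]  (5 is a false positive)
```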
Recommended answer
First, I'll prepare the toy dataset given above:
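from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import IntegerType
import pyspark.sql.functions as fn

# reuse the shell's session if one exists, otherwise start one
spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([[1, 'r1', 1],
                            [1, 'r2', 0],
                            [1, 'r2', 1],
                            [2, 'r1', 1],
                            [3, 'r1', 1],
                            [3, 'r2', 1],
                            [4, 'r1', 0],
                            [5, 'r1', 1],
                            [5, 'r2', 0],
                            [5, 'r1', 1]], schema=['cust_id', 'req', 'req_met'])
# Python ints are inferred as LongType; cast them to IntegerType
df = df.withColumn('req_met', col("req_met").cast(IntegerType()))
df = df.withColumn('cust_id', col("cust_id").cast(IntegerType()))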
I do the same thing by grouping by cust_id and req, then summing req_met (i.e. counting how many times each requirement was met). After that, I create a function to cap those per-requirement counts so that each becomes just 0 or 1:
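from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

# cap a per-requirement count of met rows: 1 if met at least once, else 0
def floor_req(r):
    if r >= 1:
        return 1
    else:
        return 0

udf_floor_req = udf(floor_req, IntegerType())

# one row per (customer, requirement) with the number of times it was met
gr = df.groupby(['cust_id', 'req'])
df_grouped = gr.agg(fn.sum(col('req_met')).alias('sum_req_met'))
df_grouped_floor = df_grouped.withColumn('sum_req_met', udf_floor_req('sum_req_met'))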
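Because req_met only takes the values 0 and 1, capping the per-requirement sum at 1 gives the same result as taking the per-requirement maximum. The plain-Python sketch below checks that equivalence on the toy data. It suggests, as an untested alternative rather than part of the answer, that `fn.max('req_met')` inside the first `agg` could replace the sum plus the Python UDF, which would also avoid the serialization overhead Python UDFs add.

```python
from collections import defaultdict

# Toy data copied from the question: (cust_id, req, req_met)
rows = [
    (1, 'r1', 1), (1, 'r2', 0), (1, 'r2', 1),
    (2, 'r1', 1),
    (3, 'r1', 1), (3, 'r2', 1),
    (4, 'r1', 0),
    (5, 'r1', 1), (5, 'r2', 0), (5, 'r1', 1),
]

def floor_req(r):
    # same capping rule as the answer's UDF
    return 1 if r >= 1 else 0

sums = defaultdict(int)   # (cust_id, req) -> number of met rows
maxes = defaultdict(int)  # (cust_id, req) -> largest req_met seen
for cust_id, req, req_met in rows:
    sums[(cust_id, req)] += req_met
    maxes[(cust_id, req)] = max(maxes[(cust_id, req)], req_met)

floored = {key: floor_req(s) for key, s in sums.items()}
print(floored == dict(maxes))                   # True
print(floored[(5, 'r1')], floored[(5, 'r2')])  # 1 0
```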
Now we can check whether each customer has met all requirements by comparing the number of distinct requirements with the number of requirements met:
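# sum_req: requirements met at least once; n_req: distinct requirements
# (each req appears once per customer after the previous groupby)
df_req = df_grouped_floor.groupby('cust_id').agg(fn.sum('sum_req_met').alias('sum_req'),
                                                 fn.count('req').alias('n_req'))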
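To make the intermediate table concrete, here is a plain-Python sketch of this second aggregation. The `floored` values are hand-derived from the toy table (one 0/1 flag per customer and requirement), not output captured from Spark:

```python
from collections import defaultdict

# Per-(cust_id, req) flags after capping, hand-derived from the toy data
floored = {
    (1, 'r1'): 1, (1, 'r2'): 1,
    (2, 'r1'): 1,
    (3, 'r1'): 1, (3, 'r2'): 1,
    (4, 'r1'): 0,
    (5, 'r1'): 1, (5, 'r2'): 0,
}

sum_req = defaultdict(int)  # requirements met at least once
n_req = defaultdict(int)    # distinct requirements (one entry per req)
for (cust_id, _req), flag in floored.items():
    sum_req[cust_id] += flag
    n_req[cust_id] += 1

summary = {c: (sum_req[c], n_req[c]) for c in sorted(n_req)}
print(summary)  # {1: (2, 2), 2: (1, 1), 3: (2, 2), 4: (0, 1), 5: (1, 2)}
```

Only customers 1, 2 and 3 have equal numbers in both positions.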
Finally, you just have to check whether the two columns are equal:
df_req.filter(df_req['sum_req'] == df_req['n_req'])[['cust_id']].orderBy('cust_id').show()
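An alternative formulation, offered as an untested suggestion rather than part of the answer above: take the maximum of req_met per (cust_id, req), then the minimum of those maxima per customer; a customer qualifies exactly when that minimum is 1. In PySpark this would correspond to `fn.max` in the first `agg` and `fn.min` in the second, replacing both the UDF and the sum/count comparison. The plain-Python sketch below checks the logic on the toy data:

```python
from collections import defaultdict

# Toy data copied from the question: (cust_id, req, req_met)
rows = [
    (1, 'r1', 1), (1, 'r2', 0), (1, 'r2', 1),
    (2, 'r1', 1),
    (3, 'r1', 1), (3, 'r2', 1),
    (4, 'r1', 0),
    (5, 'r1', 1), (5, 'r2', 0), (5, 'r1', 1),
]

# Step 1: per (cust_id, req), was it met at least once? (max of 0/1 flags)
met_once = defaultdict(int)
for cust_id, req, req_met in rows:
    met_once[(cust_id, req)] = max(met_once[(cust_id, req)], req_met)

# Step 2: per customer, the minimum over requirements is 1 only if all were met
all_met = {}
for (cust_id, _req), flag in met_once.items():
    all_met[cust_id] = min(all_met.get(cust_id, 1), flag)

result = sorted(c for c, flag in all_met.items() if flag == 1)
print(result)  # [1, 2, 3]
```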