Stratified sampling with pyspark


Question

I have a Spark DataFrame with one column that contains lots of zeros and very few ones (only 0.01% ones).

I'd like to take a random subsample, but a stratified one, so that the ratio of 1s to 0s in that column is preserved.

Is it possible to do this in pyspark?

I am looking for a non-Scala solution, one based on DataFrames rather than RDDs.

Answer

The solution I suggested in Stratified sampling in Spark is pretty straightforward to convert from Scala to Python (or even to Java; see What's the easiest way to stratify a Spark Dataset?).

Nevertheless, I'll rewrite it in Python. Let's start by creating a toy DataFrame:

from pyspark.sql.functions import lit

# Toy data: two distinct keys in x1 (10 rows and 2 rows respectively).
data = [(2147481832,23355149,1), (2147481832,973010692,1), (2147481832,2134870842,1),
        (2147481832,541023347,1), (2147481832,1682206630,1), (2147481832,1138211459,1),
        (2147481832,852202566,1), (2147481832,201375938,1), (2147481832,486538879,1),
        (2147481832,919187908,1), (214748183,919187908,1), (214748183,91187908,1)]
df = spark.createDataFrame(data, ["x1", "x2", "x3"])
df.show()
# +----------+----------+---+
# |        x1|        x2| x3|
# +----------+----------+---+
# |2147481832|  23355149|  1|
# |2147481832| 973010692|  1|
# |2147481832|2134870842|  1|
# |2147481832| 541023347|  1|
# |2147481832|1682206630|  1|
# |2147481832|1138211459|  1|
# |2147481832| 852202566|  1|
# |2147481832| 201375938|  1|
# |2147481832| 486538879|  1|
# |2147481832| 919187908|  1|
# | 214748183| 919187908|  1|
# | 214748183|  91187908|  1|
# +----------+----------+---+
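
Note, by the way, that df was built with a SparkSession named spark, which exists automatically in the pyspark shell. In a standalone script you would create it yourself first, for example:

from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()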

This DataFrame has 12 elements, as you can see:

df.count()
# 12

They are distributed as follows:

df.groupBy("x1").count().show()
# +----------+-----+
# |        x1|count|
# +----------+-----+
# |2147481832|   10|
# | 214748183|    2|
# +----------+-----+
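
As an aside: if, instead of taking the same fraction of every stratum, you wanted roughly the same number of rows from each, the fractions could be derived from these counts. A sketch, where target is a hypothetical desired per-stratum sample size:

target = 2  # hypothetical: desired number of rows per stratum
counts = df.groupBy("x1").count().collect()
eq_fractions = {row["x1"]: min(1.0, target / row["count"]) for row in counts}
# {2147481832: 0.2, 214748183: 1.0}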

Now let's sample:

First we'll set the seed:

seed = 12

Then find the keys to fraction on, and sample:

# Assign the same sampling fraction (0.8) to every distinct key in x1
fractions = df.select("x1").distinct().withColumn("fraction", lit(0.8)).rdd.collectAsMap()
print(fractions)                                                            
# {2147481832: 0.8, 214748183: 0.8}
sampled_df = df.stat.sampleBy("x1", fractions, seed)
sampled_df.show()
# +----------+---------+---+
# |        x1|       x2| x3|
# +----------+---------+---+
# |2147481832| 23355149|  1|
# |2147481832|973010692|  1|
# |2147481832|541023347|  1|
# |2147481832|852202566|  1|
# |2147481832|201375938|  1|
# |2147481832|486538879|  1|
# |2147481832|919187908|  1|
# | 214748183|919187908|  1|
# | 214748183| 91187908|  1|
# +----------+---------+---+

We can now check the content of our sample. Note that sampleBy performs an independent coin flip per row, so each stratum's size is only approximately fraction × stratum size (here 7 of the 10 rows for the first key, not exactly 8):

sampled_df.count()
# 9

sampled_df.groupBy("x1").count().show()
# +----------+-----+
# |        x1|count|
# +----------+-----+
# |2147481832|    7|
# | 214748183|    2|
# +----------+-----+
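
To map this back to the original question: sampleBy stratifies on whatever column you pass it, so preserving the 1-to-0 ratio is just a matter of keying on the binary column itself and giving both classes the same fraction. A minimal sketch, assuming the binary column is named label (hypothetical) and you want a 10% subsample:

# "label" is the hypothetical 0/1 column; 0.1 keeps ~10% of each class,
# so the 1:0 ratio is preserved (approximately).
fractions = {0: 0.1, 1: 0.1}
subsample = df.stat.sampleBy("label", fractions, seed)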

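Finally, if the approximate per-stratum counts produced by sampleBy are a problem, one alternative (not from the original answer) is to rank rows randomly within each stratum and cut at the desired fraction, which makes the counts deterministic up to rounding. A sketch:

from pyspark.sql import Window
from pyspark.sql.functions import col, percent_rank, rand

# Randomly order the rows inside each stratum, then keep the first ~80%.
w = Window.partitionBy("x1").orderBy(rand(seed))
exact_df = (df.withColumn("pr", percent_rank().over(w))
              .where(col("pr") < 0.8)
              .drop("pr"))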
