Check that all the elements of an array are present in another array in PySpark
Question
I have a Spark dataframe df1:
id transactions
1 [1, 2, 3, 5]
2 [1, 2, 3, 6]
3 [1, 2, 9, 8]
4 [1, 2, 5, 6]
root
 |-- id: int (nullable = true)
 |-- transactions: array (nullable = false)
 |    |-- element: int (containsNull = true)
I have a Spark dataframe df2:
items cost
[1] 1.0
[2] 1.0
[2, 1] 2.0
[6, 1] 2.0
root
 |-- items: array (nullable = false)
 |    |-- element: int (containsNull = true)
 |-- cost: int (nullable = true)
I want to check whether all the array elements from the items column are in the transactions column.
The first row ([1, 2, 3, 5]) contains [1], [2] and [2, 1] from the items column. Hence I need to sum up their corresponding costs: 1.0 + 1.0 + 2.0 = 4.0
My desired output is:
id transactions score
1 [1, 2, 3, 5] 4.0
2 [1, 2, 3, 6] 6.0
3 [1, 2, 9, 8] 4.0
4 [1, 2, 5, 6] 6.0
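To make the rule concrete, this is the per-row logic I have in mind in plain Python (only a sketch for illustration; row_score and the literal lists are made up by me):

def row_score(transactions, items_with_cost):
    # Sum the cost of every item set that is fully contained in the transaction
    t = set(transactions)
    return sum(cost for items, cost in items_with_cost if set(items).issubset(t))

items_with_cost = [([1], 1.0), ([2], 1.0), ([2, 1], 2.0), ([6, 1], 2.0)]
print(row_score([1, 2, 3, 5], items_with_cost))  # 4.0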
I tried using a loop with collect()/toLocalIterator, but it does not seem to be efficient, and my data will be large.
I think creating a UDF like this will solve it, but it throws an error.
from pyspark.sql.functions import udf

def containsAll(x, y):
    result = all(elem in x for elem in y)
    if result:
        print("Yes, transactions contains all items")
    else:
        print("No")

contains_udf = udf(containsAll)
dataFrame.withColumn("result",
                     contains_udf(df2.items, df1.transactions)).show()
Or is there any other way to do this?
Answer
A valid udf before 2.4 (note that it has to return something):
from pyspark.sql.functions import udf

@udf("boolean")
def contains_all(x, y):
    # True when every element of y is contained in x
    if x is not None and y is not None:
        return set(y).issubset(set(x))
In 2.4 or later no udf is required:
from pyspark.sql.functions import array_intersect, size

def contains_all(x, y):
    return size(array_intersect(x, y)) == size(y)
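As a side note (my own addition, not part of the original answer): on 2.4+ the same check can also be expressed with array_except; the helper name contains_all_v2 below is only illustrative.

from pyspark.sql.functions import array_except, size

def contains_all_v2(x, y):
    # Every element of y is in x exactly when removing x's elements from y leaves nothing
    return size(array_except(y, x)) == 0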
Usage:
from pyspark.sql.functions import col, sum as sum_, when

df1 = spark.createDataFrame(
    [(1, [1, 2, 3, 5]), (2, [1, 2, 3, 6]), (3, [1, 2, 9, 8]), (4, [1, 2, 5, 6])],
    ("id", "transactions")
)

df2 = spark.createDataFrame(
    [([1], 1.0), ([2], 1.0), ([2, 1], 2.0), ([6, 1], 2.0)],
    ("items", "cost")
)

(df1
    .crossJoin(df2).groupBy("id", "transactions")
    .agg(sum_(when(
        contains_all("transactions", "items"), col("cost")
    )).alias("score"))
    .show())
The result:
+---+------------+-----+
| id|transactions|score|
+---+------------+-----+
| 1|[1, 2, 3, 5]| 4.0|
| 4|[1, 2, 5, 6]| 6.0|
| 2|[1, 2, 3, 6]| 6.0|
| 3|[1, 2, 9, 8]| 4.0|
+---+------------+-----+
If df2 is small, it could be preferable to use it as a local variable:
# Broadcast the (item set, cost) pairs so every executor has a local copy
items = sc.broadcast([
    (set(items), cost) for items, cost in df2.select("items", "cost").collect()
])

def score(y):
    @udf("double")
    def _(x):
        # Sum the cost of every broadcast item set contained in the transaction
        if x is not None:
            transactions = set(x)
            return sum(
                cost for items, cost in y.value
                if items.issubset(transactions))
    return _

df1.withColumn("score", score(items)("transactions")).show()
+---+------------+-----+
| id|transactions|score|
+---+------------+-----+
| 1|[1, 2, 3, 5]| 4.0|
| 2|[1, 2, 3, 6]| 6.0|
| 3|[1, 2, 9, 8]| 4.0|
| 4|[1, 2, 5, 6]| 6.0|
+---+------------+-----+
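Note (my own observation, not from the original answer): this variant only pays off when the collected df2 fits comfortably in memory, since the broadcast list is scanned once per input row inside the Python UDF; it avoids the shuffle of the crossJoin at the price of that per-row Python loop.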
Finally it is possible to explode and join:
from pyspark.sql.functions import explode, monotonically_increasing_id, size

costs = (df1
    # Explode transactions
    .select("id", explode("transactions").alias("item"))
    .join(
        df2
            # Add id so we can later use it to identify source
            .withColumn("_id", monotonically_increasing_id().alias("_id"))
            # Explode items
            .select(
                "_id", explode("items").alias("item"),
                # We'll need size of the original items later
                size("items").alias("size"), "cost"),
        ["item"])
    # Count matches in groups id, items
    .groupBy("_id", "id", "size", "cost")
    .count()
    # Compute cost
    .groupBy("id")
    .agg(sum_(when(col("size") == col("count"), col("cost"))).alias("score")))

costs.show()
+---+-----+
| id|score|
+---+-----+
| 1| 4.0|
| 3| 4.0|
| 2| 6.0|
| 4| 6.0|
+---+-----+
and then join the result with the original df1:
df1.join(costs, ["id"])
but that's a much less straightforward solution and requires multiple shuffles. It might still be preferable to a Cartesian product (crossJoin), but that will depend on the actual data.
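For completeness, a sketch of that join back that mirrors the desired output columns (the select is my own addition, just to match the expected output shape):

df1.join(costs, ["id"]).select("id", "transactions", "score").show()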