Check whether all the elements of an array are present in another array (PySpark)
Question
I have a df1 Spark dataframe:
id transactions
1 [1, 2, 3, 5]
2 [1, 2, 3, 6]
3 [1, 2, 9, 8]
4 [1, 2, 5, 6]
root
 |-- id: int (nullable = true)
 |-- transactions: array (nullable = false)
 |    |-- element: int (containsNull = true)
I have a df2 Spark dataframe:
items cost
[1] 1.0
[2] 1.0
[2, 1] 2.0
[6, 1] 2.0
root
 |-- items: array (nullable = false)
 |    |-- element: int (containsNull = true)
 |-- cost: double (nullable = true)
I want to check whether all the array elements from the items column are present in the transactions column.
The first row ([1, 2, 3, 5]) contains [1], [2], and [2, 1] from the items column, so I need to sum up their corresponding costs: 1.0 + 1.0 + 2.0 = 4.0.
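In plain Python terms, the scoring rule I want is the following (a minimal driver-side sketch of the semantics only, not the Spark implementation I'm after):

from pyspark.sql import SparkSession  # not needed for this sketch, shown for context

# Item sets and their costs, mirroring df2 above.
items_cost = [({1}, 1.0), ({2}, 1.0), ({2, 1}, 2.0), ({6, 1}, 2.0)]

def score(transactions):
    # Sum the cost of every item set fully contained in the transaction.
    t = set(transactions)
    return sum(cost for items, cost in items_cost if items.issubset(t))

print(score([1, 2, 3, 5]))  # 4.0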
My desired output is:
id transactions score
1 [1, 2, 3, 5] 4.0
2 [1, 2, 3, 6] 6.0
3 [1, 2, 9, 8] 4.0
4 [1, 2, 5, 6] 6.0
I tried using a loop with collect()/toLocalIterator, but it does not seem to be efficient; I will have large data.
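Roughly, such a loop looks like this (a sketch of the shape of my attempt, not my exact code): every row is pulled to the driver and scored in plain Python, which is why it does not scale.

# Driver-side anti-pattern: all of df2 is collected, df1 is iterated locally.
item_rows = df2.select("items", "cost").collect()

scores = []
for row in df1.toLocalIterator():
    t = set(row.transactions)
    scores.append((row.id, sum(
        r.cost for r in item_rows if set(r.items).issubset(t))))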
I think creating a UDF like this would solve it, but it throws an error:
from pyspark.sql.functions import udf

def containsAll(x, y):
    result = all(elem in x for elem in y)
    if result:
        print("Yes, transactions contains all items")
    else:
        print("No")

contains_udf = udf(containsAll)
dataFrame.withColumn("result", contains_udf(df2.items, df1.transactions)).show()
Or is there any other way?
Solution
A valid udf before 2.4 (note that it has to return something):
from pyspark.sql.functions import udf

@udf("boolean")
def contains_all(x, y):
    if x is not None and y is not None:
        return set(y).issubset(set(x))
In 2.4 or later no udf is required:
from pyspark.sql.functions import array_intersect, size

def contains_all(x, y):
    return size(array_intersect(x, y)) == size(y)
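A drop-in 2.4+ alternative (my sketch, not from the original answer) uses array_except. Since array_intersect deduplicates, the size comparison above could misfire if items ever held duplicate values; this variant would not:

from pyspark.sql.functions import array_except, size

def contains_all(x, y):
    # y is fully contained in x exactly when nothing in y is missing from x.
    return size(array_except(y, x)) == 0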
Usage:
from pyspark.sql.functions import col, sum as sum_, when

df1 = spark.createDataFrame(
    [(1, [1, 2, 3, 5]), (2, [1, 2, 3, 6]), (3, [1, 2, 9, 8]), (4, [1, 2, 5, 6])],
    ("id", "transactions")
)

df2 = spark.createDataFrame(
    [([1], 1.0), ([2], 1.0), ([2, 1], 2.0), ([6, 1], 2.0)],
    ("items", "cost")
)

(df1
    .crossJoin(df2).groupBy("id", "transactions")
    .agg(sum_(when(
        contains_all("transactions", "items"), col("cost")
    )).alias("score"))
    .show())
Result:
+---+------------+-----+
| id|transactions|score|
+---+------------+-----+
| 1|[1, 2, 3, 5]| 4.0|
| 4|[1, 2, 5, 6]| 6.0|
| 2|[1, 2, 3, 6]| 6.0|
| 3|[1, 2, 9, 8]| 4.0|
+---+------------+-----+
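A note from me rather than the original answer: when without otherwise yields null, so an id whose transactions contain no item set at all would end up with a null score. If 0.0 is wanted instead, it can be filled, e.g.:

# Hypothetical follow-up; "scored" stands for the aggregated dataframe above.
scored = scored.na.fill({"score": 0.0})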
If df2 is small it could be preferable to use it as a local variable:
# Collect df2 to the driver and broadcast it as (item set, cost) pairs.
items = sc.broadcast([
    (set(items), cost) for items, cost in df2.select("items", "cost").collect()
])

def score(y):
    @udf("double")
    def _(x):
        if x is not None:
            transactions = set(x)
            # Sum the cost of every broadcast item set contained in this row.
            return sum(
                cost for items, cost in y.value
                if items.issubset(transactions))
    return _

df1.withColumn("score", score(items)("transactions")).show()
+---+------------+-----+
| id|transactions|score|
+---+------------+-----+
| 1|[1, 2, 3, 5]| 4.0|
| 2|[1, 2, 3, 6]| 6.0|
| 3|[1, 2, 9, 8]| 4.0|
| 4|[1, 2, 5, 6]| 6.0|
+---+------------+-----+
Finally, it is possible to explode and join:
from pyspark.sql.functions import explode, monotonically_increasing_id, size

costs = (df1
    # Explode transactions
    .select("id", explode("transactions").alias("item"))
    .join(
        df2
        # Add id so we can later use it to identify source
        .withColumn("_id", monotonically_increasing_id())
        # Explode items
        .select(
            "_id", explode("items").alias("item"),
            # We'll need size of the original items later
            size("items").alias("size"), "cost"),
        ["item"])
    # Count matches in groups id, items
    .groupBy("_id", "id", "size", "cost")
    .count()
    # Compute cost
    .groupBy("id")
    .agg(sum_(when(col("size") == col("count"), col("cost"))).alias("score")))
costs.show()
+---+-----+
| id|score|
+---+-----+
| 1| 4.0|
| 3| 4.0|
| 2| 6.0|
| 4| 6.0|
+---+-----+
and then join the result with the original df1:
df1.join(costs, ["id"])
but that's a much less straightforward solution, and it requires multiple shuffles. It might still be preferable to a Cartesian product (crossJoin), but that will depend on the actual data.