Check all the elements of an array present in another array pyspark


Problem description

I have a Spark dataframe df1:

id     transactions
1      [1, 2, 3, 5]
2      [1, 2, 3, 6]
3      [1, 2, 9, 8]
4      [1, 2, 5, 6]

root
 |-- id: int (nullable = true)
 |-- transactions: array (nullable = false)
 |    |-- element: int (containsNull = true)

I have a Spark dataframe df2:

items    cost
[1]      1.0
[2]      1.0
[2, 1]   2.0
[6, 1]   2.0

root
 |-- items: array (nullable = false)
 |    |-- element: int (containsNull = true)
 |-- cost: double (nullable = true)

I want to check whether all the array elements from the items column are present in the transactions column.

The first row ([1, 2, 3, 5]) contains [1], [2] and [2, 1] from the items column, so I need to sum up their corresponding costs: 1.0 + 1.0 + 2.0 = 4.0. (Likewise, the second row [1, 2, 3, 6] contains [1], [2], [2, 1] and [6, 1], giving 1.0 + 1.0 + 2.0 + 2.0 = 6.0.)

My expected output is:

id     transactions    score
1      [1, 2, 3, 5]    4.0
2      [1, 2, 3, 6]    6.0
3      [1, 2, 9, 8]    4.0
4      [1, 2, 5, 6]    6.0

I tried using a loop with collect()/toLocalIterator, but it does not seem to be efficient, and I will have a lot of data.
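
For reference, a minimal sketch of what such a driver-side loop might look like (the variable names are illustrative, not my actual code); everything runs on the driver, so it loses Spark's parallelism:

item_costs = [(set(r.items), r.cost) for r in df2.collect()]

scores = []
for row in df1.toLocalIterator():
    transactions = set(row.transactions)
    # All of the scoring happens in a single Python process on the driver
    score = sum(cost for items, cost in item_costs if items.issubset(transactions))
    scores.append((row.id, score))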

I think creating a UDF like this would solve it, but it throws an error:

from pyspark.sql.functions import udf

def containsAll(x, y):
    result = all(elem in x for elem in y)
    if result:
        print("Yes, transactions contains all items")
    else:
        print("No")

contains_udf = udf(containsAll)
dataFrame.withColumn("result", contains_udf(df2.items, df1.transactions)).show()

Or is there any other way to do this?

Recommended answer

A valid udf before 2.4 (note that it has to return something; the function in the question only prints its result and implicitly returns None):

from pyspark.sql.functions import udf

@udf("boolean")
def contains_all(x, y):
    # True when every element of y is contained in x; returns None if either column is null
    if x is not None and y is not None:
        return set(y).issubset(set(x))

In 2.4 or later no udf is required:

from pyspark.sql.functions import array_intersect, size

def contains_all(x, y):
    return size(array_intersect(x, y)) == size(y)
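
A small aside, not part of the original answer: the size comparison above assumes items contains no duplicate values (which holds for the sample data), since array_intersect deduplicates its result. An equivalent subset test without that assumption can be sketched with array_except; the name contains_all_alt is just illustrative:

from pyspark.sql.functions import array_except, size

# items is a subset of transactions  <=>  items minus transactions is empty
def contains_all_alt(x, y):
    return size(array_except(y, x)) == 0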

Usage:

from pyspark.sql.functions import col, sum as sum_, when

df1 = spark.createDataFrame(
   [(1, [1, 2, 3, 5]), (2, [1, 2, 3, 6]), (3, [1, 2, 9, 8]), (4, [1, 2, 5, 6])],
   ("id", "transactions")
)

df2 = spark.createDataFrame(
    [([1], 1.0), ([2], 1.0), ([2, 1], 2.0), ([6, 1], 2.0)],
    ("items", "cost")
)


(df1
    .crossJoin(df2).groupBy("id", "transactions")
    .agg(sum_(when(
        contains_all("transactions", "items"), col("cost")
    )).alias("score"))
    .show())

Result:

+---+------------+-----+                                                        
| id|transactions|score|
+---+------------+-----+
|  1|[1, 2, 3, 5]|  4.0|
|  4|[1, 2, 5, 6]|  6.0|
|  2|[1, 2, 3, 6]|  6.0|
|  3|[1, 2, 9, 8]|  4.0|
+---+------------+-----+

If df2 is small, it may be preferable to use it as a local variable:

# Collect df2 to the driver once and broadcast it to the executors
items = sc.broadcast([
    (set(items), cost) for items, cost in df2.select("items", "cost").collect()
])

def score(y):
    @udf("double")
    def _(x):
        if x is not None:
            transactions = set(x)
            # Sum the cost of every broadcast item set fully contained in the transactions
            return sum(
                cost for items, cost in y.value
                if items.issubset(transactions))
    return _


df1.withColumn("score", score(items)("transactions")).show()

+---+------------+-----+
| id|transactions|score|
+---+------------+-----+
|  1|[1, 2, 3, 5]|  4.0|
|  2|[1, 2, 3, 6]|  6.0|
|  3|[1, 2, 9, 8]|  4.0|
|  4|[1, 2, 5, 6]|  6.0|
+---+------------+-----+

Finally, it is possible to explode and join:

from pyspark.sql.functions import explode, monotonically_increasing_id

costs = (df1
    # Explode transactions
    .select("id", explode("transactions").alias("item"))
    .join(
        df2
            # Add an id so we can later identify the source row
            .withColumn("_id", monotonically_increasing_id())
            # Explode items
            .select(
                "_id", explode("items").alias("item"),
                # We'll need the size of the original items later
                size("items").alias("size"), "cost"),
        ["item"])
    # Count how many items of each original item set matched within each transaction
    .groupBy("_id", "id", "size", "cost")
    .count()
    # Keep only fully matched item sets (count == size) and sum their cost per id
    .groupBy("id")
    .agg(sum_(when(col("size") == col("count"), col("cost"))).alias("score")))

costs.show()

+---+-----+                                                                      
| id|score|
+---+-----+
|  1|  4.0|
|  3|  4.0|
|  2|  6.0|
|  4|  6.0|
+---+-----+

and then join the result back with the original df1:

df1.join(costs, ["id"])
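
A small usage sketch, not in the original answer: with a plain inner join, any id whose transactions matched no item set at all would be dropped; a left join keeps it with a null score.

df1.join(costs, ["id"], "left").orderBy("id").show()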

But that is a much less straightforward solution and requires multiple shuffles. It might still be preferable to the Cartesian product (crossJoin), but that will depend on the actual data.

