Spark merge/combine arrays in groupBy/aggregate


Problem description


The following Spark code correctly demonstrates what I want to do and generates the correct output with a tiny demo data set.

When I run this same general type of code on a large volume of production data, I run into runtime problems. The Spark job runs on my cluster for ~12 hours and then fails.

Just glancing at the code below, it seems inefficient to explode every row just to merge it back down. In the given test data set, the fourth row, with three values in array_value_1 and three values in array_value_2, will explode into 3*3 = nine exploded rows.

So, in a larger data set, a row with five such array columns, and ten values in each column, would explode out to 10^5 exploded rows?

Looking at the provided Spark functions, there is nothing out of the box that does what I want. I could supply a user-defined function. Are there any speed drawbacks to that?

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.{col, collect_set, explode}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType}

import scala.collection.JavaConverters._

val sparkSession = SparkSession.builder
  .master("local")
  .appName("merge list test")
  .getOrCreate()

val schema = StructType(
  StructField("category", IntegerType) ::
    StructField("array_value_1", ArrayType(StringType)) ::
    StructField("array_value_2", ArrayType(StringType)) ::
    Nil)

val rows = List(
  Row(1, List("a", "b"), List("u", "v")),
  Row(1, List("b", "c"), List("v", "w")),
  Row(2, List("c", "d"), List("w")),
  Row(2, List("c", "d", "e"), List("x", "y", "z"))
)

val df = sparkSession.createDataFrame(rows.asJava, schema)

val dfExploded = df.
  withColumn("scalar_1", explode(col("array_value_1"))).
  withColumn("scalar_2", explode(col("array_value_2")))

// This will output 19. 2*2 + 2*2 + 2*1 + 3*3 = 19
logger.info(s"dfExploded.count()=${dfExploded.count()}")

val dfOutput = dfExploded.groupBy("category").agg(
  collect_set("scalar_1").alias("combined_values_1"),
  collect_set("scalar_2").alias("combined_values_2"))

dfOutput.show()
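For the tiny demo data set, the expected output is roughly the following (collect_set gives no ordering guarantee, so the element order inside each array may differ):

+--------+-----------------+-----------------+
|category|combined_values_1|combined_values_2|
+--------+-----------------+-----------------+
|       1|        [a, b, c]|        [u, v, w]|
|       2|        [c, d, e]|     [w, x, y, z]|
+--------+-----------------+-----------------+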

Solution

explode may be inefficient, but fundamentally the operation you are trying to implement is simply expensive. Effectively it is just another groupByKey, and there is not much you can do here to make it better. Since you use Spark > 2.0, you can collect_list directly and then flatten:

import org.apache.spark.sql.functions.{collect_list, udf}

val flatten_distinct = udf(
  (xs: Seq[Seq[String]]) => xs.flatten.distinct)

df
  .groupBy("category")
  .agg(
    flatten_distinct(collect_list("array_value_1")), 
    flatten_distinct(collect_list("array_value_2"))
  )
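
As a quick sanity check of what the UDF computes, here is the same flatten-then-distinct logic in plain Scala on hypothetical values:

// flatten concatenates the collected arrays; distinct drops duplicates
Seq(Seq("a", "b"), Seq("b", "c")).flatten.distinct  // Seq("a", "b", "c")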

In Spark >= 2.4 you can replace the udf with a composition of built-in functions:

import org.apache.spark.sql.functions.{array_distinct, flatten}

val flatten_distinct = (array_distinct _) compose (flatten _)
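
The composed function has the same Column => Column shape as the udf above, so it can be dropped into the same aggregation. A minimal sketch, assuming Spark >= 2.4 and the column names from the question:

df
  .groupBy("category")
  .agg(
    flatten_distinct(collect_list("array_value_1")).alias("combined_values_1"),
    flatten_distinct(collect_list("array_value_2")).alias("combined_values_2")
  )
  .show()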

It is also possible to use a custom Aggregator, but I doubt any of these will make a huge difference.

If the sets are relatively large and you expect a significant number of duplicates, you could try aggregateByKey with mutable sets:

import scala.collection.mutable.{Set => MSet}
import org.apache.spark.sql.functions.struct
import sparkSession.implicits._

val rdd = df
  .select($"category", struct($"array_value_1", $"array_value_2"))
  .as[(Int, (Seq[String], Seq[String]))]
  .rdd

val agg = rdd
  .aggregateByKey((MSet[String](), MSet[String]()))(
    // seqOp: fold one row's arrays into the per-partition accumulators
    { case ((accX, accY), (xs, ys)) => (accX ++= xs, accY ++= ys) },
    // combOp: merge the accumulators of two partitions
    { case ((accX1, accY1), (accX2, accY2)) => (accX1 ++= accX2, accY1 ++= accY2) }
  )
  .mapValues { case (xs, ys) => (xs.toArray, ys.toArray) }
  .toDF
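
.toDF on the aggregated pair RDD yields default column names (_1 for the key, _2 for the nested tuple of arrays), so you would probably want to rename and unpack them. A minimal sketch, with the combined_values_* names simply carried over from the question:

val dfAgg = agg
  .toDF("category", "values")
  .select(
    $"category",
    $"values._1".alias("combined_values_1"),
    $"values._2".alias("combined_values_2"))

dfAgg.show()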
