Spark aggregate on multiple columns within partition without shuffle

Question

I'm trying to aggregate a dataframe on multiple columns. I know that everything I need for the aggregation is within the partition; that is, there's no need for a shuffle because all of the data for the aggregation is local to the partition.

Taking an example, if I have

val sales = sc.parallelize(List(
  ("West",  "Apple",  2.0, 10),
  ("West",  "Apple",  3.0, 15),
  ("West",  "Orange", 5.0, 15),
  ("South", "Orange", 3.0, 9),
  ("South", "Orange", 6.0, 18),
  ("East",  "Milk",   5.0, 5))).repartition(2)

val tdf = sales.map { case (store, prod, amt, units) => ((store, prod), (amt, amt, amt, units)) }.
  reduceByKey((x, y) => (x._1 + y._1, math.min(x._2, y._2), math.max(x._3, y._3), x._4 + y._4))

println(tdf.toDebugString)

I get a result like

(2) ShuffledRDD[12] at reduceByKey at Test.scala:59 []
 +-(2) MapPartitionsRDD[11] at map at Test.scala:58 []
    |  MapPartitionsRDD[10] at repartition at Test.scala:57 []
    |  CoalescedRDD[9] at repartition at Test.scala:57 []
    |  ShuffledRDD[8] at repartition at Test.scala:57 []
    +-(1) MapPartitionsRDD[7] at repartition at Test.scala:57 []
       |  ParallelCollectionRDD[6] at parallelize at Test.scala:51 []

You can see the MapPartitionsRDD, which is good. But then there's the ShuffledRDD, which I want to prevent, because I want per-partition summarization, grouped by column values within the partition.

zero323's suggestion is tantalizingly close, but I need the "group by columns" functionality.

Referring to my sample above, I'm looking for the result that would be produced by

select store, prod, sum(amt), avg(units) from sales group by partition_id, store, prod

(I don't really need the partition id; that's just to illustrate that I want per-partition results.)
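For illustration, here is roughly what that query looks like in the DataFrame API (a sketch only, assuming spark-shell implicits for toDF and $; the name salesDF and the aggregate aliases are mine, spark_partition_id() stands in for the partition id column, and a plain groupBy still plans an exchange, so this does not avoid the shuffle I'm trying to eliminate):

import org.apache.spark.sql.functions._

// Sketch only: the DataFrame equivalent of the SQL above.
// spark_partition_id() plays the role of partition_id; this still shuffles.
val salesDF = sales.toDF("store", "prod", "amt", "units")
salesDF
  .groupBy(spark_partition_id().as("pid"), $"store", $"prod")
  .agg(sum($"amt").as("total_amt"), avg($"units").as("avg_units"))
  .show()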

I've looked at lots of examples, but every debug string I've produced includes a shuffle. I really want to get rid of the shuffle. I guess I'm essentially looking for a groupByKeysWithinPartitions function.

Answer

The only way to achieve that is by using mapPartitions and having custom code to group and compute your values while iterating over the partition. Since, as you mention, the data is already sorted by the grouping keys (store, prod), we can efficiently compute your aggregations in a pipelined fashion:

(1) Define helper classes:

:paste

// Input record
case class MyRec(store: String, prod: String, amt: Double, units: Int)

// Running aggregate for one (store, prod) group
case class MyResult(store: String, prod: String, total_amt: Double, min_amt: Double, max_amt: Double, total_units: Int)

object MyResult {
  // Seed an aggregate from a single record
  def apply(rec: MyRec): MyResult = new MyResult(rec.store, rec.prod, rec.amt, rec.amt, rec.amt, rec.units)

  // Fold one more record into an existing aggregate
  def aggregate(result: MyResult, rec: MyRec) = {
    new MyResult(result.store,
      result.prod,
      result.total_amt + rec.amt,
      math.min(result.min_amt, rec.amt),
      math.max(result.max_amt, rec.amt),
      result.total_units + rec.units
    )
  }
}
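A quick sanity check of these helpers in plain Scala, using the first two sample records (the expected value in the comment follows directly from the definitions above):

val r1 = MyResult(MyRec("West", "Apple", 2.0, 10))
val r2 = MyResult.aggregate(r1, MyRec("West", "Apple", 3.0, 15))
// r2 == MyResult("West", "Apple", 5.0, 2.0, 3.0, 25)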

(2) Define pipelined aggregator:

:paste

def pipelinedAggregator(iter: Iterator[MyRec]): Iterator[Seq[MyResult]] = {

  // Running aggregate for the current (store, prod) group
  var prev: MyResult = null

  for (crt <- iter) yield {
    // Group(s) completed by this record, if any
    var res: Seq[MyResult] = Nil

    if (prev == null) {
      // First record of the partition
      prev = MyResult(crt)
    } else if (prev.prod != crt.prod || prev.store != crt.store) {
      // Key changed: emit the finished group and start a new one
      res = Seq(prev)
      prev = MyResult(crt)
    } else {
      // Same key: fold the record into the running aggregate
      prev = MyResult.aggregate(prev, crt)
    }

    // Last record of the partition: also emit the in-progress group
    if (!iter.hasNext) {
      res = res ++ Seq(prev)
    }

    res
  }
}
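As a quick check without Spark, we can feed the aggregator the records that land in the first partition of the sample (per the partition dump shown in the output below) and flatten the result:

val part = Iterator(
  MyRec("West", "Apple",  2.0, 10),
  MyRec("West", "Apple",  3.0, 15),
  MyRec("West", "Orange", 5.0, 15))

pipelinedAggregator(part).flatten.foreach(println)
// MyResult(West,Apple,5.0,2.0,3.0,25)
// MyResult(West,Orange,5.0,5.0,5.0,15)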

(3) Run the aggregation:

:paste

val sales = sc.parallelize(
  List(MyRec("West", "Apple", 2.0, 10),
    MyRec("West", "Apple", 3.0, 15),
    MyRec("West", "Orange", 5.0, 15),
    MyRec("South", "Orange", 3.0, 9),
    MyRec("South", "Orange", 6.0, 18),
    MyRec("East", "Milk", 5.0, 5),
    MyRec("West", "Apple", 7.0, 11)), 2).toDS

// Show the contents of each partition
sales.mapPartitions(iter => Iterator(iter.toList)).show(false)

// Aggregate each partition independently, then flatten the per-record Seqs
val result = sales
  .mapPartitions(recIter => pipelinedAggregator(recIter))
  .flatMap(identity)

result.show
result.explain

Output:

    +-------------------------------------------------------------------------------------+
    |value                                                                                |
    +-------------------------------------------------------------------------------------+
    |[[West,Apple,2.0,10], [West,Apple,3.0,15], [West,Orange,5.0,15]]                     |
    |[[South,Orange,3.0,9], [South,Orange,6.0,18], [East,Milk,5.0,5], [West,Apple,7.0,11]]|
    +-------------------------------------------------------------------------------------+

    +-----+------+---------+-------+-------+-----------+
    |store|  prod|total_amt|min_amt|max_amt|total_units|
    +-----+------+---------+-------+-------+-----------+
    | West| Apple|      5.0|    2.0|    3.0|         25|
    | West|Orange|      5.0|    5.0|    5.0|         15|
    |South|Orange|      9.0|    3.0|    6.0|         27|
    | East|  Milk|      5.0|    5.0|    5.0|          5|
    | West| Apple|      7.0|    7.0|    7.0|         11|
    +-----+------+---------+-------+-------+-----------+

    == Physical Plan ==
    *SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).store, true) AS store#31, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).prod, true) AS prod#32, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).total_amt AS total_amt#33, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).min_amt AS min_amt#34, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).max_amt AS max_amt#35, assertnotnull(input[0, $line14.$read$$iw$$iw$MyResult, true]).total_units AS total_units#36]
    +- MapPartitions <function1>, obj#30: $line14.$read$$iw$$iw$MyResult
       +- MapPartitions <function1>, obj#20: scala.collection.Seq
          +- Scan ExternalRDDScan[obj#4]
    sales: org.apache.spark.sql.Dataset[MyRec] = [store: string, prod: string ... 2 more fields]
    result: org.apache.spark.sql.Dataset[MyResult] = [store: string, prod: string ... 4 more fields]    
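Note that the physical plan contains only MapPartitions operators over the external RDD scan, with no Exchange node, which confirms the aggregation ran without a shuffle.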
