Cumulative product in Spark


Problem description

I am trying to implement a cumulative product in Spark Scala, but I really don't know how to do it. I have the following dataframe:

Input data:
+--+--+--------+----+
|A |B | date   | val|
+--+--+--------+----+
|rr|gg|20171103| 2  |
|hh|jj|20171103| 3  |
|rr|gg|20171104| 4  |
|hh|jj|20171104| 5  |
|rr|gg|20171105| 6  |
|hh|jj|20171105| 7  |
+--+--+--------+----+

I would like the following output:

Output data:
+--+--+--------+-----+
|A |B | date   | val |
+--+--+--------+-----+
|rr|gg|20171105| 48  | // 2 * 4 * 6
|hh|jj|20171105| 105 | // 3 * 5 * 7
+--+--+--------+-----+

Recommended answer

As long as the numbers are strictly positive, as in your example, the simplest solution is to compute the sum of logarithms and take the exponential (0 can be handled as well, if present, using coalesce):

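// round is imported here because it is used further down
import org.apache.spark.sql.functions.{exp, log, max, round, sum}
// toDF and $"..." need spark.implicits._, which spark-shell imports automatically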

val df = Seq(
  ("rr", "gg", "20171103", 2), ("hh", "jj", "20171103", 3), 
  ("rr", "gg", "20171104", 4), ("hh", "jj", "20171104", 5), 
  ("rr", "gg", "20171105", 6), ("hh", "jj", "20171105", 7)
).toDF("A", "B", "date", "val")

val result = df
  .groupBy("A", "B")
  .agg(
    max($"date").as("date"), 
    exp(sum(log($"val"))).as("val"))
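The remark about zeros is not spelled out in the answer, so the following is only one possible reading of it, not necessarily what the author had in mind. It relies on Spark SQL's log returning null for non-positive input, and uses coalesce to replace that null with negative infinity so that exp of the sum collapses to 0. The sample data, the session setup, and the names withZero, logOrNegInf and zeroAware are assumptions made for illustration. Note that this sketch also maps genuine nulls and negative values to a product of 0, which may not be what you want.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{coalesce, exp, lit, log, max, sum}

// Assumption: a local session; in spark-shell the existing session is reused.
val spark = SparkSession.builder().master("local[*]").appName("product-with-zeros").getOrCreate()
import spark.implicits._

// Hypothetical sample data: the (rr, gg) group contains a zero.
val withZero = Seq(
  ("rr", "gg", "20171103", 2), ("hh", "jj", "20171103", 3),
  ("rr", "gg", "20171104", 0), ("hh", "jj", "20171104", 5),
  ("rr", "gg", "20171105", 6), ("hh", "jj", "20171105", 7)
).toDF("A", "B", "date", "val")

// log(0) yields null in Spark SQL; substitute -Infinity so the sum becomes -Infinity.
val logOrNegInf = coalesce(log($"val"), lit(Double.NegativeInfinity))

val zeroAware = withZero
  .groupBy("A", "B")
  .agg(
    max($"date").as("date"),
    exp(sum(logOrNegInf)).as("val")) // exp(-Infinity) is 0.0

zeroAware.show()
```

Without the coalesce, sum would simply skip the null produced by log(0), and the zero would be silently ignored instead of forcing the product to 0.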

Since this uses floating-point arithmetic, the result won't be exact:

result.show

+---+---+--------+------------------+
|  A|  B|    date|               val|
+---+---+--------+------------------+
| hh| jj|20171105|104.99999999999997|
| rr| gg|20171105|47.999999999999986|
+---+---+--------+------------------+

but after rounding it should be good enough for the majority of applications:

result.withColumn("val", round($"val")).show

+---+---+--------+-----+
|  A|  B|    date|  val|
+---+---+--------+-----+
| hh| jj|20171105|105.0|
| rr| gg|20171105| 48.0|
+---+---+--------+-----+
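The inexactness is not specific to Spark. As a minimal sketch in plain Scala (the object name LogProductDemo is made up for this example), the same identity, product = exp(sum of logs), can be compared against exact integer multiplication:

```scala
object LogProductDemo {
  // Values of the (rr, gg) group from the question.
  val values: Seq[Long] = Seq(2L, 4L, 6L)

  // Floating-point route, mirroring exp(sum(log(val))).
  val viaLogs: Double = math.exp(values.map(v => math.log(v.toDouble)).sum)

  // Exact route using integer multiplication.
  val exact: Long = values.product

  // Rounding recovers the exact value while the accumulated error stays below 0.5.
  val rounded: Long = math.round(viaLogs)

  def main(args: Array[String]): Unit =
    println(s"viaLogs = $viaLogs, rounded = $rounded, exact = $exact")
}
```

Rounding only helps while the true product fits in the 53-bit mantissa of a Double; for larger products the floating-point route cannot recover every digit.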

If that's not enough, you can define a UserDefinedAggregateFunction or an Aggregator (see "How to define and use a User-Defined Aggregate Function in Spark SQL?") or use the functional API with reduceGroups:

import scala.math.Ordering

case class Record(A: String, B: String, date: String, value: Long)

df.withColumnRenamed("val", "value").as[Record]
  .groupByKey(x => (x.A, x.B))
  .reduceGroups((x, y) => x.copy(
    date = Ordering[String].max(x.date, y.date),
    value = x.value * y.value))
  .toDF("key", "value")
  .select($"value.*")
  .show

+---+---+--------+-----+
|  A|  B|    date|value|
+---+---+--------+-----+
| hh| jj|20171105|  105|
| rr| gg|20171105|   48|
+---+---+--------+-----+
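The Aggregator route mentioned above is not shown in the answer. A minimal sketch of what it could look like follows, assuming Spark 3.0 or later (where functions.udaf is available), no null values in the val column, and products that fit in a Long. The names LongProduct, longProduct and exactResult are made up for this example.

```scala
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{max, udaf}

// Hypothetical aggregator: multiplies Long inputs exactly.
// Overflow wraps silently, as with any Long arithmetic.
object LongProduct extends Aggregator[Long, Long, Long] {
  def zero: Long = 1L                              // neutral element of multiplication
  def reduce(acc: Long, x: Long): Long = acc * x   // fold one input into the buffer
  def merge(a: Long, b: Long): Long = a * b        // combine two partial products
  def finish(acc: Long): Long = acc
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

// Assumption: a local session; in spark-shell the existing session is reused.
val spark = SparkSession.builder().master("local[*]").appName("long-product").getOrCreate()
import spark.implicits._

val df = Seq(
  ("rr", "gg", "20171103", 2), ("hh", "jj", "20171103", 3),
  ("rr", "gg", "20171104", 4), ("hh", "jj", "20171104", 5),
  ("rr", "gg", "20171105", 6), ("hh", "jj", "20171105", 7)
).toDF("A", "B", "date", "val")

// Wrap the typed Aggregator so it can be used on untyped columns.
val longProduct = udaf(LongProduct)

val exactResult = df
  .groupBy("A", "B")
  .agg(
    max($"date").as("date"),
    longProduct($"val".cast("long")).as("val"))

exactResult.show()
```

Compared with reduceGroups, this keeps the untyped groupBy/agg style of the first solution while still multiplying integers exactly.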
