Cumulative product in Spark
Question
I'm trying to implement a cumulative product in Spark Scala, but I really don't know how to do it. I have the following dataframe:
Input data:
+--+--+--------+----+
|A |B | date | val|
+--+--+--------+----+
|rr|gg|20171103| 2 |
|hh|jj|20171103| 3 |
|rr|gg|20171104| 4 |
|hh|jj|20171104| 5 |
|rr|gg|20171105| 6 |
|hh|jj|20171105| 7 |
+--+--+--------+----+
I would like the following output:
Output data:
+--+--+--------+-----+
|A |B | date | val |
+--+--+--------+-----+
|rr|gg|20171105| 48 | // 2 * 4 * 6
|hh|jj|20171105| 105 | // 3 * 5 * 7
+--+--+--------+-----+
Answer
As long as the numbers are strictly positive, as in your example, the simplest solution is to compute the sum of logarithms and take the exponential (a 0, if present, can be handled as well using coalesce, as sketched further down):
import org.apache.spark.sql.functions.{exp, log, max, round, sum}
import spark.implicits._  // assumes a SparkSession named spark (e.g. the spark-shell)

val df = Seq(
  ("rr", "gg", "20171103", 2), ("hh", "jj", "20171103", 3),
  ("rr", "gg", "20171104", 4), ("hh", "jj", "20171104", 5),
  ("rr", "gg", "20171105", 6), ("hh", "jj", "20171105", 7)
).toDF("A", "B", "date", "val")

// product(val) == exp(sum(log(val))) for strictly positive values
val result = df
  .groupBy("A", "B")
  .agg(
    max($"date").as("date"),
    exp(sum(log($"val"))).as("val"))
Since this uses floating-point arithmetic, the result won't be exact:
result.show
+---+---+--------+------------------+
| A| B| date| val|
+---+---+--------+------------------+
| hh| jj|20171105|104.99999999999997|
| rr| gg|20171105|47.999999999999986|
+---+---+--------+------------------+
but after rounding it should be good enough for the majority of applications:
result.withColumn("val", round($"val")).show
+---+---+--------+-----+
| A| B| date| val|
+---+---+--------+-----+
| hh| jj|20171105|105.0|
| rr| gg|20171105| 48.0|
+---+---+--------+-----+
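The coalesce handling of zeros mentioned above is not shown in the original answer; a minimal sketch, assuming non-negative input, relies on Spark's log returning null for non-positive values, so a zero can be coalesced to negative infinity before summing:
import org.apache.spark.sql.functions.{coalesce, exp, lit, log, max, sum}

// Sketch only: log(0) is null in Spark SQL; coalescing that null to -Infinity
// drives the whole sum to -Infinity, and exp(-Infinity) == 0.0, i.e. a zero product.
val resultWithZeros = df
  .groupBy("A", "B")
  .agg(
    max($"date").as("date"),
    exp(sum(coalesce(log($"val"), lit(Double.NegativeInfinity)))).as("val"))
Note that this maps any non-positive value to a zero product, so it only makes sense when the input is known to be non-negative.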
If that's not enough, you can define a UserDefinedAggregateFunction or an Aggregator (see How to define and use a User-Defined Aggregate Function in Spark SQL?) or use the functional API with reduceGroups:
import scala.math.Ordering

case class Record(A: String, B: String, date: String, value: Long)

df.withColumnRenamed("val", "value").as[Record]
  .groupByKey(x => (x.A, x.B))
  .reduceGroups((x, y) => x.copy(
    date = Ordering[String].max(x.date, y.date),  // keep the latest date
    value = x.value * y.value))                   // exact integer product
  .toDF("key", "value")
  .select($"value.*")
  .show
+---+---+--------+-----+
| A| B| date|value|
+---+---+--------+-----+
| hh| jj|20171105| 105|
| rr| gg|20171105| 48|
+---+---+--------+-----+
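If exact results matter on the untyped DataFrame API as well, an Aggregator along the lines of the linked question can compute the integer product directly. The following is a minimal sketch of my own (the udaf wrapper and Encoders.scalaLong assume Spark 3.x and are not part of the original answer):
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.functions.{max, udaf}

// Product of Long values: 1 is the neutral element of multiplication.
object Product extends Aggregator[Long, Long, Long] {
  def zero: Long = 1L
  def reduce(acc: Long, x: Long): Long = acc * x   // fold one value into the buffer
  def merge(a: Long, b: Long): Long = a * b        // combine partial products from partitions
  def finish(acc: Long): Long = acc
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

val product = udaf(Product, Encoders.scalaLong)

df.groupBy("A", "B")
  .agg(max($"date").as("date"), product($"val".cast("long")).as("val"))
  .show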