Generic iterator over dataframe (Spark/Scala)


Question

I need to iterate over a data frame in a specific order and apply some complex logic to calculate a new column.

In the example below I use a simple expression where the current value of s is the product of all previous values, so it may seem like this could be done with a UDF or even an analytic function. In reality, however, the logic is much more complex.
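
Just for contrast, a minimal window-based sketch of the analytic-function approach for this toy multiplicative case (it uses the df created in the snippet below and the exp(sum(log)) trick for a cumulative product; it does not carry over to the real, more complex logic):

// Sketch for the toy example only: cumulative product of y per x, ordered by y,
// computed as exp(sum(log(y))) over a running window and rounded back to an int.
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val w = Window.partitionBy($"x").orderBy($"y")
         .rowsBetween(Window.unboundedPreceding, Window.currentRow)
df.withColumn("s", round(exp(sum(log($"y")).over(w))).cast("int")).show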

The code below does what is needed:

import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.encoders.RowEncoder

val q = """
select 10 x, 1 y
union all select 10, 2
union all select 10, 3
union all select 20, 6
union all select 20, 4
union all select 20, 5
"""
val df = spark.sql(q)
def f_row(iter: Iterator[Row]) : Iterator[Row] = {
  // Seed scanLeft with a dummy row carrying s = 1 and drop it from the result.
  iter.scanLeft(Row(0, 0, 1)) {
    case (r1, r2) => {
      val (x1, y1, s1) = r1 match { case Row(x: Int, y: Int, s: Int) => (x, y, s) }
      val (x2, y2)     = r2 match { case Row(x: Int, y: Int) => (x, y) }
      Row(x2, y2, s1 * y2)
    }
  }.drop(1)
}
val schema = new StructType().
             add(StructField("x", IntegerType, true)).
             add(StructField("y", IntegerType, true)).
             add(StructField("s", IntegerType, true))
val encoder = RowEncoder(schema)
df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show

Output

scala> df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show
+---+---+---+
|  x|  y|  s|
+---+---+---+
| 20|  4|  4|
| 20|  5| 20|
| 20|  6|120|
| 10|  1|  1|
| 10|  2|  2|
| 10|  3|  6|
+---+---+---+

What I do not like about this:

1) I explicitly define the schema even though Spark can infer the names and types for the data frame:

scala> df
res1: org.apache.spark.sql.DataFrame = [x: int, y: int]

2) If I add any new column to the data frame then I have to declare the schema again and, what is more annoying, re-define the function!

Assume there is a new column z in the data frame. In this case I have to change almost every line in f_row:

def f_row(iter: Iterator[Row]) : Iterator[Row] = {
  iter.scanLeft(Row(0,0,"",1)) {
    case (r1, r2) => {
      val (x1, y1, z1, s1) = r1 match {case Row(x: Int, y: Int, z: String, s: Int) => (x, y, z, s)}
      val (x2, y2, z2)     = r2 match {case Row(x: Int, y: Int, z: String) => (x, y, z)}
      Row(x2, y2, z2, s1 * y2)
    }
  }.drop(1)
}
val schema = new StructType().
             add(StructField("x", IntegerType, true)).
             add(StructField("y", IntegerType, true)).
             add(StructField("z", StringType, true)).
             add(StructField("s", IntegerType, true))
val encoder = RowEncoder(schema)
df.withColumn("z", lit("dummy")).repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show

Output

scala> df.withColumn("z", lit("dummy")).repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show
+---+---+-----+---+
|  x|  y|    z|  s|
+---+---+-----+---+
| 20|  4|dummy|  4|
| 20|  5|dummy| 20|
| 20|  6|dummy|120|
| 10|  1|dummy|  1|
| 10|  2|dummy|  2|
| 10|  3|dummy|  6|
+---+---+-----+---+

Is there a way to implement the logic in a more generic way, so that I do not need to create a separate function for every specific data frame? Or at least to avoid code changes when new columns that are not used in the calculation logic are added to the data frame?

Please see the updated question below.

UPDATE

Below are two options that iterate in a more generic way, but both still have drawbacks:

// option 1
def f_row(iter: Iterator[Row]): Iterator[Row] = {
  val r = Row.fromSeq(Row(0, 0).toSeq :+ 1)
  iter.scanLeft(r)((r1, r2) => 
    Row.fromSeq(r2.toSeq :+ r1.getInt(r1.size - 1) * r2.getInt(r2.fieldIndex("y")))
  ).drop(1)
}
df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show

// option 2
def f_row(iter: Iterator[Row]): Iterator[Row] = {
  iter.map{
    var s = 1
    r => {
      s = s * r.getInt(r.fieldIndex("y"))
      Row.fromSeq(r.toSeq :+ s)
    }
  }
}
df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show

If a new column is added to the data frame, the initial value for iter.scanLeft has to be changed in option 1. I also do not really like option 2 because it uses a mutable var.

Is there a way to improve the code so that it is purely functional and no changes are needed when a new column is added to the data frame?

Answer

Well, a sufficient solution is below:

def f_row(iter: Iterator[Row]): Iterator[Row] = {
  if (iter.hasNext) {
    // Seed scanLeft with the first row extended by s = y, so neither a dummy row
    // nor a hard-coded schema is needed; "y" is looked up by name and the running
    // product by position (last element).
    val head = iter.next
    val r = Row.fromSeq(head.toSeq :+ head.getInt(head.fieldIndex("y")))
    iter.scanLeft(r)((r1, r2) =>
      Row.fromSeq(r2.toSeq :+ r1.getInt(r1.size - 1) * r2.getInt(r2.fieldIndex("y"))))
  } else iter
}
val encoder = 
  RowEncoder(StructType(df.schema.fields :+ StructField("s", IntegerType, false)))
df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoder).show
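
Because this version refers to y only via fieldIndex and to the running product only by position, adding a column no longer touches f_row; only the encoder needs to be rebuilt from the new schema. A minimal sketch, reusing the dummy z column from the question (dfz and encoderZ are just illustrative names):

// Sketch: f_row above is reused unchanged; only the encoder is rebuilt for the wider schema.
val dfz = df.withColumn("z", lit("dummy"))
val encoderZ =
  RowEncoder(StructType(dfz.schema.fields :+ StructField("s", IntegerType, false)))
dfz.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row)(encoderZ).show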

UPDATE

Functions like getInt can be avoided in favor of the more generic getAs.

Also, in order to be able to access the fields of r1 by name, we can generate a GenericRowWithSchema, which is a subclass of Row.

An implicit parameter has been added to f_row so that the function can use the current schema of the data frame while still being usable as an argument to mapPartitions.

import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.catalyst.encoders.RowEncoder

implicit val schema = StructType(df.schema.fields :+ StructField("result", IntegerType))
implicit val encoder = RowEncoder(schema)

def mul(x1: Int, x2: Int) = x1 * x2

def f_row(iter: Iterator[Row])(implicit currentSchema : StructType) : Iterator[Row] = {
  if (iter.hasNext) {
    val head = iter.next
    val r =
      new GenericRowWithSchema((head.toSeq :+ (head.getAs("y"))).toArray, currentSchema)

    iter.scanLeft(r)((r1, r2) =>
      new GenericRowWithSchema((r2.toSeq :+ mul(r1.getAs("result"), r2.getAs("y"))).toArray, currentSchema))
  } else iter
}

df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row).show

Finally, the logic can be implemented in a tail-recursive manner:

import scala.annotation.tailrec

def f_row(iter: Iterator[Row]) = {
  @tailrec
  def f_row_(iter: Iterator[Row], tmp: Int, result: Iterator[Row]): Iterator[Row] = {
    if (iter.hasNext) {
      val r = iter.next
      f_row_(iter, mul(tmp, r.getAs("y")),
        result ++ Iterator(Row.fromSeq(r.toSeq :+ mul(tmp, r.getAs("y")))))
    } else result
  }
  f_row_(iter, 1, Iterator[Row]())
}

df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row).show
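
One possible refinement of the tail-recursive version (a sketch only, assuming the same df, mul and implicit encoder as above; f_row2 is just an illustrative name): compute the product once per row and accumulate the rows in a Vector instead of repeatedly concatenating iterators.

import scala.annotation.tailrec

def f_row2(iter: Iterator[Row]): Iterator[Row] = {
  // Sketch: product computed once per row, rows collected in a Vector.
  @tailrec
  def loop(it: Iterator[Row], tmp: Int, acc: Vector[Row]): Vector[Row] =
    if (it.hasNext) {
      val r = it.next
      val s = mul(tmp, r.getAs[Int]("y"))
      loop(it, s, acc :+ Row.fromSeq(r.toSeq :+ s))
    } else acc
  loop(iter, 1, Vector.empty).iterator
}

df.repartition($"x").sortWithinPartitions($"y").mapPartitions(f_row2).show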
