Wrapping a Java function in pyspark
Problem description
I am trying to create a user-defined aggregate function (UDAF) that I can call from Python. I tried to follow the answer to this question (http://stackoverflow.com/questions/33233737/spark-how-to-map-python-with-scala-or-java-user-defined-functions). I basically implemented the following, adapted from MyDoubleSum.java in the Spark test sources (https://github.com/apache/spark/blob/master/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java):
package com.blu.bla;

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.Row;

public class MySum extends UserDefinedAggregateFunction {
    private StructType _inputDataType;
    private StructType _bufferSchema;
    private DataType _returnDataType;

    public MySum() {
        List<StructField> inputFields = new ArrayList<StructField>();
        inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true));
        _inputDataType = DataTypes.createStructType(inputFields);

        List<StructField> bufferFields = new ArrayList<StructField>();
        bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true));
        _bufferSchema = DataTypes.createStructType(bufferFields);

        _returnDataType = DataTypes.DoubleType;
    }

    @Override public StructType inputSchema() {
        return _inputDataType;
    }

    @Override public StructType bufferSchema() {
        return _bufferSchema;
    }

    @Override public DataType dataType() {
        return _returnDataType;
    }

    @Override public boolean deterministic() {
        return true;
    }

    @Override public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, null);
    }

    @Override public void update(MutableAggregationBuffer buffer, Row input) {
        if (!input.isNullAt(0)) {
            if (buffer.isNullAt(0)) {
                buffer.update(0, input.getDouble(0));
            } else {
                Double newValue = input.getDouble(0) + buffer.getDouble(0);
                buffer.update(0, newValue);
            }
        }
    }

    @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        if (!buffer2.isNullAt(0)) {
            if (buffer1.isNullAt(0)) {
                buffer1.update(0, buffer2.getDouble(0));
            } else {
                Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0);
                buffer1.update(0, newValue);
            }
        }
    }

    @Override public Object evaluate(Row buffer) {
        if (buffer.isNullAt(0)) {
            return null;
        } else {
            return buffer.getDouble(0);
        }
    }
}
I then compiled it with all dependencies and ran pyspark with --jars myjar.jar.
In pyspark I did:
df = sqlCtx.createDataFrame([(1.0, "a"), (2.0, "b"), (3.0, "C")], ["A", "B"])
from pyspark.sql.column import Column, _to_java_column, _to_seq
from pyspark.sql import Row

def myCol(col):
    _f = sc._jvm.com.blu.bla.MySum.apply
    return Column(_f(_to_seq(sc, [col], _to_java_column)))

b = df.agg(myCol("A"))
I got the following error:
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-24-f45b2a367e67> in <module>()
----> 1 b = df.agg(myCol("A"))

<ipython-input-22-afcb8884e1db> in myCol(col)
      4 def myCol(col):
      5     _f = sc._jvm.com.blu.bla.MySum.apply
----> 6     return Column(_f(_to_seq(sc,[col], _to_java_column)))

TypeError: 'JavaPackage' object is not callable
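This message means py4j did not find a class named com.blu.bla.MySum: when a dotted name under sc._jvm cannot be resolved to a class on the JVM classpath, py4j treats it as a package and hands back a JavaPackage, which cannot be called. A minimal sketch of a check that fails with a clearer message, assuming a hypothetical helper name check_java_class (it is not part of pyspark or py4j). It compares the type name instead of importing py4j so it has no dependencies:

```python
def check_java_class(jvm, fqcn):
    """Resolve a fully qualified Java class name on a py4j JVM view.

    Hypothetical helper: walks the dotted name attribute by attribute and
    raises ImportError if the result is still a py4j JavaPackage, which is
    what py4j returns when no class with that name is on the classpath.
    """
    obj = jvm
    for part in fqcn.split("."):
        obj = getattr(obj, part)
    # Compare by type name so this module does not need to import py4j.
    if type(obj).__name__ == "JavaPackage":
        raise ImportError(
            f"{fqcn} resolved to a py4j JavaPackage; the class is probably "
            "not on the driver classpath (check the --jars and "
            "--driver-class-path values)"
        )
    return obj
```

In the pyspark shell this would be called as check_java_class(sc._jvm, "com.blu.bla.MySum") before building the column.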
I also tried adding --driver-class-path to the pyspark call but got the same result.
I also tried to access the Java class through py4j's java_import:
from py4j.java_gateway import java_import

jvm = sc._gateway.jvm
java_import(jvm, "com.blu.bla.MySum")

def myCol2(col):
    _f = jvm.com.blu.bla.MySum.apply
    return Column(_f(_to_seq(sc, [col], _to_java_column)))
I also tried simply instantiating the class (as suggested here: http://stackoverflow.com/questions/33544105/running-custom-java-class-in-pyspark):
a = jvm.com.blu.bla.MySum()
All are getting the same error message.
I can't seem to figure out what the problem is.
Recommended answer
So it seems the main issue was that none of the options for adding the jar (--jars, --driver-class-path, SPARK_CLASSPATH) work properly when given a relative path. This is probably because the working directory inside IPython differs from the directory where I ran pyspark.
Once I changed it to an absolute path, it worked (I haven't tested it on a cluster yet, but at least it works on a local installation).
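Since the fix was to pass the jar by absolute path, one way to avoid the relative-path trap is to resolve the path before launching pyspark. A small sketch, assuming the jar is named myjar.jar as in the question; pyspark_command is a hypothetical helper, and passing the same path to both --jars and --driver-class-path is a cautious choice rather than something the answer states is required:

```python
import os
import shlex


def pyspark_command(jar_path):
    """Build a pyspark launch command that refers to the jar by absolute path.

    The relative path is resolved against the directory this helper runs in,
    so the JVM no longer depends on whatever working directory IPython uses.
    """
    abs_jar = os.path.abspath(os.path.expanduser(jar_path))
    return ["pyspark", "--jars", abs_jar, "--driver-class-path", abs_jar]


if __name__ == "__main__":
    # shlex.join (Python 3.8+) quotes the path in case it contains spaces.
    print(shlex.join(pyspark_command("myjar.jar")))
```

Run it from the directory that contains the jar; the printed command then works from any directory.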
Also, I am not sure whether this is also a bug in the answer here (http://stackoverflow.com/questions/33233737/spark-how-to-map-python-with-scala-or-java-user-defined-functions), since that answer uses a Scala implementation; however, in the Java implementation I needed to do:
def myCol(col):
    _f = sc._jvm.com.blu.bla.MySum().apply
    return Column(_f(_to_seq(sc, [col], _to_java_column)))
This is probably not very efficient, because it creates _f on every call; I should probably define _f outside the function instead (again, this would need testing on a cluster), but at least it now gives the correct result.
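Following the suggestion to define _f outside the function, one option is to create the Java object once and cache it. This is a sketch, untested on a cluster just like the answer's code; my_sum_apply and my_col are hypothetical names, and functools.lru_cache is simply one convenient way to memoize per SparkContext:

```python
from functools import lru_cache


@lru_cache(maxsize=None)
def my_sum_apply(sc):
    """Return the bound `apply` method of one MySum instance per SparkContext.

    The first call instantiates com.blu.bla.MySum through py4j (note the
    parentheses: `apply` is an instance method inherited from
    UserDefinedAggregateFunction). Later calls with the same SparkContext
    return the cached method instead of creating a new Java object.
    """
    return sc._jvm.com.blu.bla.MySum().apply


def my_col(sc, col):
    """Build a Column that applies the MySum UDAF to `col`."""
    # Imported lazily so the helper above can be loaded without pyspark.
    from pyspark.sql.column import Column, _to_java_column, _to_seq
    return Column(my_sum_apply(sc)(_to_seq(sc, [col], _to_java_column)))
```

In the shell this would be used as b = df.agg(my_col(sc, "A")). The cache keeps a reference to sc for the life of the Python process, which is acceptable for an interactive session.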