Serialize a custom transformer using python to be used within a Pyspark ML pipeline
Question
I found the same discussion in the comments section of Create a custom Transformer in PySpark ML, but there is no clear answer. There is also an unresolved JIRA corresponding to that: https://issues.apache.org/jira/browse/SPARK-17025.
Given that there is no option provided by Pyspark ML pipeline for saving a custom transformer written in python, what are the other options to get it done? How can I implement the `_to_java` method in my python class that returns a compatible java object?
Answer
As of Spark 2.3.0 there's a much, much better way to do this.
Simply extend `DefaultParamsWritable` and `DefaultParamsReadable` and your class will automatically have `write` and `read` methods that will save your params and will be used by the `PipelineModel` serialization system.
The docs were not really clear, and I had to do a bit of source reading to understand this was the way that deserialization worked.
- `PipelineModel.read` instantiates a `PipelineModelReader`.
- `PipelineModelReader` loads the metadata and checks if the language is `'Python'`. If it's not, then the typical `JavaMLReader` is used (what most of these answers are designed for). Otherwise, `PipelineSharedReadWrite` is used, which calls `DefaultParamsReader.loadParamsInstance`.
`loadParamsInstance` will find the `class` from the saved metadata. It will instantiate that class and call `.load(path)` on it. You can extend `DefaultParamsReader` and get the `DefaultParamsReader.load` method automatically. If you do have specialized deserialization logic you need to implement, I would look at that `load` method as a starting place.
On the flip side:
- `PipelineModel.write` will check if all stages are Java (implement `JavaMLWritable`). If so, the typical `JavaMLWriter` is used (what most of these answers are designed for). Otherwise, `PipelineWriter` is used, which checks that all stages implement `MLWritable` and calls `PipelineSharedReadWrite.saveImpl`.
- `PipelineSharedReadWrite.saveImpl` will call `.write().save(path)` on each stage.
You can extend `DefaultParamsWriter` to get the `DefaultParamsWritable.write` method that saves metadata for your class and params in the right format. If you have custom serialization logic you need to implement, I would look at that and `DefaultParamsWriter` as a starting point.
Ok, so finally, you have a pretty simple transformer that extends `Params`, and all your parameters are stored in the typical `Params` fashion:
```python
from pyspark import keyword_only
from pyspark.ml import Transformer
from pyspark.ml.param.shared import HasOutputCols, Param, Params
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql.functions import lit  # for the dummy _transform


class SetValueTransformer(
    Transformer, HasOutputCols, DefaultParamsReadable, DefaultParamsWritable,
):
    value = Param(
        Params._dummy(),
        "value",
        "value to fill",
    )

    @keyword_only
    def __init__(self, outputCols=None, value=0.0):
        super(SetValueTransformer, self).__init__()
        self._setDefault(value=0.0)
        kwargs = self._input_kwargs
        self._set(**kwargs)

    @keyword_only
    def setParams(self, outputCols=None, value=0.0):
        """
        setParams(self, outputCols=None, value=0.0)
        Sets params for this SetValueTransformer.
        """
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def setValue(self, value):
        """
        Sets the value of :py:attr:`value`.
        """
        return self._set(value=value)

    def getValue(self):
        """
        Gets the value of :py:attr:`value` or its default value.
        """
        return self.getOrDefault(self.value)

    def _transform(self, dataset):
        for col in self.getOutputCols():
            dataset = dataset.withColumn(col, lit(self.getValue()))
        return dataset
```
Now we can use it:
```python
from pyspark.ml import Pipeline, PipelineModel

svt = SetValueTransformer(outputCols=["a", "b"], value=123.0)

p = Pipeline(stages=[svt])
df = sc.parallelize([(1, None), (2, 1.0), (3, 0.5)]).toDF(["key", "value"])
pm = p.fit(df)
pm.transform(df).show()
pm.write().overwrite().save('/tmp/example_pyspark_pipeline')
pm2 = PipelineModel.load('/tmp/example_pyspark_pipeline')
print('matches?', pm2.stages[0].extractParamMap() == pm.stages[0].extractParamMap())
pm2.transform(df).show()
```
Result:
```
+---+-----+-----+-----+
|key|value|    a|    b|
+---+-----+-----+-----+
|  1| null|123.0|123.0|
|  2|  1.0|123.0|123.0|
|  3|  0.5|123.0|123.0|
+---+-----+-----+-----+

matches? True

+---+-----+-----+-----+
|key|value|    a|    b|
+---+-----+-----+-----+
|  1| null|123.0|123.0|
|  2|  1.0|123.0|123.0|
|  3|  0.5|123.0|123.0|
+---+-----+-----+-----+
```