Using tensorflow.keras model in pyspark UDF generates a pickle error


Problem description

I would like to use a tensorflow.keras model in a pyspark pandas_udf. However, I get a pickle error when the model is serialized before being sent to the workers. I am not sure I am using the best method to perform what I want, so I will show a minimal but complete example.

Packages:

  • tensorflow-2.2.0 (but all previous versions also trigger the error)
  • pyspark-2.4.5

The imports are:

import pandas as pd
import numpy as np

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

from pyspark.sql import SparkSession, functions as F, types as T

The Pyspark UDF is a pandas_udf:

def compute_output_pandas_udf(model):
    '''Spark pandas udf for model prediction.'''

    @F.pandas_udf(T.DoubleType(), F.PandasUDFType.SCALAR)
    def compute_output(inputs1, inputs2, inputs3):
        pdf = pd.DataFrame({
            'input1': inputs1,
            'input2': inputs2,
            'input3': inputs3
        })
        pdf['predicted_output'] = model.predict(pdf.values)
        return pdf['predicted_output']

    return compute_output

The main code:

# Model parameters
weights = np.array([[0.5], [0.4], [0.3]])
bias = np.array([1.25])
activation = 'linear'
input_dim, output_dim = weights.shape

# Initialize model
model = Sequential()
layer = Dense(output_dim, input_dim=input_dim, activation=activation)
model.add(layer)
layer.set_weights([weights, bias])

# Initialize Spark session
spark = SparkSession.builder.appName('test').getOrCreate()

# Create pandas df with inputs and run model
pdf = pd.DataFrame({
    'input1': np.random.randn(200),
    'input2': np.random.randn(200),
    'input3': np.random.randn(200)
})
pdf['predicted_output'] = model.predict(pdf[['input1', 'input2', 'input3']].values)

# Create spark df with inputs and run model using udf
sdf = spark.createDataFrame(pdf)
sdf = sdf.withColumn('predicted_output', compute_output_pandas_udf(model)('input1', 'input2', 'input3'))
sdf.limit(5).show()
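
Since the activation is linear, the model is just the affine map y = X·W + b, so the pandas-side prediction can be sanity-checked against plain numpy (a quick local check, using the names defined above):

# With a linear activation the model computes X @ W + b
expected = pdf[['input1', 'input2', 'input3']].values @ weights + bias
assert np.allclose(pdf['predicted_output'].values, expected.ravel())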

Calling compute_output_pandas_udf(model) triggers this error:

PicklingError: Could not serialize object: TypeError: can't pickle _thread.RLock objects
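
The failure is not Spark-specific: a tensorflow.keras model holds thread locks and other unpicklable state, so the same error can be reproduced locally with plain pickle (a minimal sketch, assuming the model defined above):

import pickle

# Fails with TypeError: can't pickle _thread.RLock objects
pickle.dumps(model)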

I found this page about pickling a keras model (http://zachmoshe.com/2017/04/03/pickling-keras-models.html) and tried it on tensorflow.keras, but I got the following error when the model's predict function is called in the UDF (so serialization worked but deserialization did not?):

AttributeError: 'Sequential' object has no attribute '_distribution_strategy'

Does anyone have an idea about how to proceed? Thank you in advance!

PS: Note that I did not use a model directly from the keras library, because another error appears periodically with it and seems more difficult to solve. The serialization of that model, however, does not generate an error the way the tensorflow.keras model does.

Answer

So it looks like, if we use the solution of extending the __getstate__ and __setstate__ methods directly in the tensorflow.keras.models.Model class, as in http://zachmoshe.com/2017/04/03/pickling-keras-models.html, then the workers are not able to deserialize the model, as they don't have this extension of the class.
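
For reference, that monkey-patching approach looks roughly like this when adapted to tensorflow.keras (a sketch based on the linked page; patch_model_pickling is a hypothetical name). __getstate__ and __setstate__ are patched onto the Model class itself, which is exactly the extension the workers lack:

import tempfile

import tensorflow

def patch_model_pickling():
    '''Sketch: monkey-patch tensorflow.keras.models.Model so instances
    pickle by round-tripping through the HDF5 save format.'''
    def __getstate__(self):
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            tensorflow.keras.models.save_model(self, fd.name, overwrite=True)
            return { 'model_str': fd.read() }

    def __setstate__(self, state):
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            fd.write(state['model_str'])
            fd.flush()
            model = tensorflow.keras.models.load_model(fd.name)
        self.__dict__ = model.__dict__

    tensorflow.keras.models.Model.__getstate__ = __getstate__
    tensorflow.keras.models.Model.__setstate__ = __setstate__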

Then, the solution is to use a wrapper class, as Erp12 suggested in this post: https://stackoverflow.com/questions/50007126/pickling-monkey-patched-keras-model-for-use-in-pyspark

import tempfile

import tensorflow


class ModelWrapperPickable:
    '''Wraps a tensorflow.keras model so it can be pickled, by
    round-tripping the model through the HDF5 save format.'''

    def __init__(self, model):
        self.model = model

    def __getstate__(self):
        # Serialize the model to a temporary HDF5 file and pickle its bytes.
        model_str = ''
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            tensorflow.keras.models.save_model(self.model, fd.name, overwrite=True)
            model_str = fd.read()
        d = { 'model_str': model_str }
        return d

    def __setstate__(self, state):
        # Write the pickled bytes back to a temporary file and reload the model.
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            fd.write(state['model_str'])
            fd.flush()
            self.model = tensorflow.keras.models.load_model(fd.name)
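
With the wrapper in place, a local pickle round trip succeeds, which is essentially what Spark does when it ships the UDF closure to the workers (a quick sanity check, assuming the model from the question):

import pickle

# The wrapper pickles its model as HDF5 bytes and reloads it on unpickling
restored = pickle.loads(pickle.dumps(ModelWrapperPickable(model)))
assert isinstance(restored.model, tensorflow.keras.models.Model)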

The UDF becomes:

def compute_output_pandas_udf(model_wrapper):
    '''Spark pandas udf for model prediction.'''

    @F.pandas_udf(T.DoubleType(), F.PandasUDFType.SCALAR)
    def compute_output(inputs1, inputs2, inputs3):
        pdf = pd.DataFrame({
            'input1': inputs1,
            'input2': inputs2,
            'input3': inputs3
        })
        pdf['predicted_output'] = model_wrapper.model.predict(pdf.values)
        return pdf['predicted_output']

    return compute_output

The main code:

# Model parameters
weights = np.array([[0.5], [0.4], [0.3]])
bias = np.array([1.25])
activation = 'linear'
input_dim, output_dim = weights.shape

# Initialize keras model
model = Sequential()
layer = Dense(output_dim, input_dim=input_dim, activation=activation)
model.add(layer)
layer.set_weights([weights, bias])
# Initialize model wrapper
model_wrapper = ModelWrapperPickable(model)

# Initialize Spark session
spark = SparkSession.builder.appName('test').getOrCreate()

# Create pandas df with inputs and run model
pdf = pd.DataFrame({
    'input1': np.random.randn(200),
    'input2': np.random.randn(200),
    'input3': np.random.randn(200)
})
pdf['predicted_output'] = model_wrapper.model.predict(pdf[['input1', 'input2', 'input3']].values)

# Create spark df with inputs and run model using udf
sdf = spark.createDataFrame(pdf)
sdf = sdf.withColumn('predicted_output', compute_output_pandas_udf(model_wrapper)('input1', 'input2', 'input3'))
sdf.limit(5).show()
