Pickling monkey-patched Keras model for use in PySpark


Question


The overall goal of what I am trying to achieve is sending a Keras model to each spark worker so that I can use the model within a UDF applied to a column of a DataFrame. To do this, the Keras model will need to be picklable.


It seems like a lot of people have had success at pickling keras models by monkey patching the Model class as shown by the link below:

http://zachmoshe.com/2017/04/03/pickling-keras-models.html


However, I have not seen any example of how to do this in tandem with Spark. My first attempt just ran the make_keras_picklable() function in the driver, which allowed me to pickle and unpickle the model in the driver, but the model could not be unpickled inside UDFs.

def make_keras_picklable():
    "Source: https://zachmoshe.com/2017/04/03/pickling-keras-models.html"
    ...

make_keras_picklable()

model = Sequential() # etc etc

def score(case):
    ....
    score = model.predict(case)
    ...

scoreUDF = udf(score, ArrayType(FloatType()))


The error I get suggests that unpickling the model in the UDF is not using the monkey-patched Model class.

AttributeError: 'Sequential' object has no attribute '_built'


It looks like another user ran into similar errors in this SO post, and the answer was to "run make_keras_picklable() on each worker as well." No example of how to do this was given.
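That advice makes sense mechanically: pickle looks up __setstate__ on the class in whichever process does the unpickling, so the patch has to be applied in every worker interpreter, not just the driver. A minimal, Keras-free sketch of the failure mode (Model and make_picklable() here are hypothetical stand-ins for the real Keras class and the blog post's patch):

```python
import pickle

# Hypothetical stand-in for keras.models.Model and for the monkey patch
# from the linked blog post. The point: pickle resolves __setstate__ on
# the *class* at load time, in the process doing the loading.
class Model:
    def __init__(self):
        self.weights = [0.1, 0.2]

def make_picklable(cls):
    def __getstate__(self):
        # pretend this serializes the model to an HDF5 blob, as the real patch does
        return {"model_str": list(self.weights)}

    def __setstate__(self, state):
        self.weights = state["model_str"]

    cls.__getstate__ = __getstate__
    cls.__setstate__ = __setstate__

make_picklable(Model)                # patched in the "driver"
payload = pickle.dumps(Model())

# Simulate a worker process that never ran make_picklable():
del Model.__getstate__, Model.__setstate__
broken = pickle.loads(payload)       # default __setstate__ just updates __dict__
print(hasattr(broken, "weights"))    # False -> AttributeError on first use

# Patch "on the worker" as well and the round trip works:
make_picklable(Model)
ok = pickle.loads(payload)
print(ok.weights)                    # [0.1, 0.2]
```

The unpatched load does not fail outright; it produces a half-restored object, which is exactly the shape of the 'Sequential' object has no attribute '_built' error above.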


My question is: What is the appropriate way to call make_keras_picklable() on all workers?


I tried using broadcast() (see below) but got the same error as above.

def make_keras_picklable():
    "Source: https://zachmoshe.com/2017/04/03/pickling-keras-models.html"
    ...

make_keras_picklable()
spark.sparkContext.broadcast(make_keras_picklable())

model = Sequential() # etc etc

def score(case):
    ....
    score = model.predict(case)
    ...

scoreUDF = udf(score, ArrayType(FloatType()))

Answer

Khaled Zaouk over on the Spark user mailing list helped me out by suggesting that make_keras_picklable() be replaced with a wrapper class. This worked great!

import tempfile

import keras

class KerasModelWrapper():
    '''Source: https://zachmoshe.com/2017/04/03/pickling-keras-models.html'''

    def __init__(self, model):
        self.model = model

    def __getstate__(self):
        model_str = ""
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            keras.models.save_model(self.model, fd.name, overwrite=True)
            model_str = fd.read()
        d = {'model_str': model_str}
        return d

    def __setstate__(self, state):
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            fd.write(state['model_str'])
            fd.flush()
            self.model = keras.models.load_model(fd.name)


Of course this could probably be made a little bit more elegant by implementing this as a subclass of Keras's Model class or maybe a PySpark.ML transformer/estimator.
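Stripped of Keras and Spark, the wrapper pattern can be exercised on its own. In this sketch, ToyModel and its JSON "save format" are hypothetical stand-ins for a Keras model and HDF5; only the __getstate__/__setstate__ shape mirrors KerasModelWrapper:

```python
import json
import pickle
import tempfile

# Hypothetical stand-in for a Keras model: it has its own on-disk
# save/load format, just as Keras models have HDF5.
class ToyModel:
    def __init__(self, weights):
        self.weights = weights

    def save(self, path):
        with open(path, "w") as f:
            json.dump(self.weights, f)

    @classmethod
    def load(cls, path):
        with open(path) as f:
            return cls(json.load(f))

class ModelWrapper:
    def __init__(self, model):
        self.model = model

    def __getstate__(self):
        # Serialize through the model's own on-disk format, not pickle.
        with tempfile.NamedTemporaryFile(suffix=".json", delete=True) as fd:
            self.model.save(fd.name)
            model_str = fd.read()
        return {"model_str": model_str}

    def __setstate__(self, state):
        with tempfile.NamedTemporaryFile(suffix=".json", delete=True) as fd:
            fd.write(state["model_str"])
            fd.flush()
            self.model = ToyModel.load(fd.name)

wrapped = ModelWrapper(ToyModel([1.0, 2.0]))
clone = pickle.loads(pickle.dumps(wrapped))
print(clone.model.weights)   # [1.0, 2.0]
```

Because the pickle machinery only ever touches the wrapper, no class on the workers needs to be monkey patched; the UDF unwraps with wrapper.model and calls predict on that. As in the answer above, this relies on reopening a NamedTemporaryFile by name, which works on POSIX but not on Windows.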

