Use Keras in multiprocessing


Question

This is basically a duplicate of "Keras + Tensorflow and Multiprocessing in Python", but my setup is a bit different, and the solution given there doesn't work for me.

I need to train a Keras model against predictions made by another model. The predictions are tied to some CPU-heavy code, so I would like to parallelize them and run that code in worker processes. Here is the code I would like to execute:

import numpy as np

from keras.layers import Input, Dense
from keras.models import Model
from keras.optimizers import Adam

def create_model():
    input_layer = Input((10,))
    dense = Dense(10)(input_layer)

    return Model(inputs=input_layer, outputs=dense)

model_outside = create_model()
model_outside.compile(Adam(1e-3), "mse")

def subprocess_routine(weights):
    model_inside = create_model()
    model_inside.set_weights(weights)

    while True:
        # lots of CPU
        batch = np.random.rand(10, 10)
        prediction = model_inside.predict(batch)

        yield batch, prediction

weights = model_outside.get_weights()

model_outside.fit_generator(subprocess_routine(weights),
                            epochs=10,
                            steps_per_epoch=100,
                            use_multiprocessing=True,
                            workers=1)

This produces the error:

E tensorflow/core/grappler/clusters/utils.cc:81] Failed to get device properties, error code: 3

I found the question linked above; its answer is to move the Keras imports into the subprocess. I moved all imports into subprocess_routine, but that doesn't change the error. It would probably be necessary to eliminate Keras imports from the main process altogether, but in my setup that would mean a huge refactoring.
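One way to structure the "imports inside the subprocess" idea is to give each worker its own initializer, so that all heavy setup happens once per worker process rather than in the parent. The sketch below uses only the standard library and a stand-in model; `StandInModel`, `init_worker`, `predict_batch` and `run` are hypothetical names, not part of Keras. It assumes the weights can be pickled, and nothing in the thread confirms that this removes the error above.

```python
import multiprocessing as mp

_model = None  # one instance per worker process, set by init_worker


class StandInModel:
    """Hypothetical placeholder for a Keras model: scales each input column."""

    def __init__(self, weights):
        self.weights = list(weights)

    def predict(self, batch):
        return [[w * x for w, x in zip(self.weights, row)] for row in batch]


def init_worker(weights):
    # Runs once in each worker. With real Keras, the keras/tensorflow imports,
    # create_model() and set_weights(weights) would go here (assumption).
    global _model
    _model = StandInModel(weights)


def predict_batch(batch):
    # Uses the model that init_worker built in this process.
    return batch, _model.predict(batch)


def run(weights, batches, processes=2):
    # "spawn" starts workers as fresh interpreters instead of forked copies
    # of the parent process.
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes, initializer=init_worker, initargs=(weights,)) as pool:
        return pool.map(predict_batch, batches)
```

With the spawn start method, `run(weights, batches)` has to be called from under an `if __name__ == "__main__":` guard, because each worker re-imports the main module on startup.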

Keras + multithreading seems to work; see the very last comment in this issue: https://github.com/keras-team/keras/issues/5640 In my code, it looks like this:

import tensorflow as tf

model_inside = create_model()
model_inside._make_predict_function()

graph = tf.get_default_graph()

def subprocess_routine(model_inside, graph):

    while True:
        batch = np.random.rand(10, 10)

        with graph.as_default():
            prediction = model_inside.predict(batch)

        yield batch, prediction

model_outside.fit_generator(subprocess_routine(model_inside, graph),
                            epochs=10,
                            steps_per_epoch=100,
                            use_multiprocessing=True,
                            workers=1)

But the error message is identical.

Since the problem is apparently related to the initialization of the subprocesses, I tried creating a new session in each subprocess:

def subprocess_routine(weights):

    import keras.backend as K
    import tensorflow as tf
    sess = tf.Session()
    K.set_session(sess)

    model_inside = create_model()
    model_inside.set_weights(weights)

    while True:
        batch = np.random.rand(10, 10)
        prediction = model_inside.predict(batch)

        yield batch, prediction

It produces a variation on the same error message:

E tensorflow/stream_executor/cuda/cuda_driver.cc:1300] could not retrieve CUDA device count: CUDA_ERROR_NOT_INITIALIZED

So again, the initialization seems to be broken.

How can I run Keras both in my main process and in subprocesses spawned by multiprocessing?

Recommended answer

The good news is that TensorFlow sessions are thread-safe; see "Is it thread-safe when using tf.Session in inference service?"

To use a Keras model in multiple processes, you have to do the following:

  • set up the model
  • call _make_predict_function()
  • set up a session and use it to get the TensorFlow graph
  • finalize this graph
  • every time you predict something, enter this graph with as_default()

Here is some example code:

# the usual imports
import numpy as np
import tensorflow as tf

from keras.models import *
from keras.layers import *

# set up the model
i = Input(shape=(10,))
b = Dense(1)(i)
model = Model(inputs=i, outputs=b)

# now to use it in multiprocessing, the following is necessary
model._make_predict_function()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
default_graph = tf.get_default_graph()
default_graph.finalize()

# now you share the model and graph between processes
# in each process you can call this:
def predict(something):
    with default_graph.as_default():
        return model.predict(something)
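If threads turn out to be enough, the per-call pattern above could be wrapped in a generator that prepares batches in a background thread and hands them to fit_generator. The sketch below is standard-library only; `threaded_batches`, `make_batch` and `predict_fn` are hypothetical names, and the two stand-in functions would be replaced by the real batch code and by a function that calls model.predict inside the graph context. Note that Python threads do not run pure-Python CPU-bound code in parallel, so this mainly helps when the heavy work releases the GIL.

```python
import queue
import random
import threading


def threaded_batches(predict_fn, make_batch, max_prefetch=4):
    """Yield (batch, prediction) pairs computed in a background thread."""
    results = queue.Queue(maxsize=max_prefetch)
    stop = threading.Event()

    def producer():
        while not stop.is_set():
            batch = make_batch()
            item = (batch, predict_fn(batch))
            # Retry with a timeout so the thread notices the stop flag.
            while not stop.is_set():
                try:
                    results.put(item, timeout=0.1)
                    break
                except queue.Full:
                    continue

    threading.Thread(target=producer, daemon=True).start()
    try:
        while True:
            yield results.get()
    finally:
        stop.set()  # runs when the generator is closed


# Stand-ins (assumptions): swap in the real CPU-heavy batch code and a
# function that wraps model.predict in `with graph.as_default():`.
def make_batch():
    return [[random.random() for _ in range(10)] for _ in range(10)]


def predict_fn(batch):
    return [sum(row) for row in batch]
```

The bounded queue keeps the producer from running arbitrarily far ahead of training, and closing the generator sets the stop flag so the thread can exit.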
