Keras Tensorflow - Exception while predicting from multiple threads
Problem description
I am using Keras 2.0.8 with the TensorFlow 1.3.0 backend.
I load a model in the class's __init__ and then use it to make predictions from multiple threads.
import tensorflow as tf
from keras import backend as K
from keras.models import load_model

class CNN:
    def __init__(self, model_path):
        self.cnn_model = load_model(model_path)
        self.session = K.get_session()
        self.graph = tf.get_default_graph()

    def query_cnn(self, data):
        X = self.preproccesing(data)
        with self.session.as_default():
            with self.graph.as_default():
                return self.cnn_model.predict(X)
I initialize the CNN once, and the query_cnn method is called from multiple threads.
The exception I get in my log is:
  File "/home/*/Similarity/CNN.py", line 43, in query_cnn
    return self.cnn_model.predict(X)
  File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 913, in predict
    return self.model.predict(x, batch_size=batch_size, verbose=verbose)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 1713, in predict
    verbose=verbose, steps=steps)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 1269, in _predict_loop
    batch_outs = f(ins_batch)
  File "/usr/local/lib/python3.5/dist-packages/keras/backend/tensorflow_backend.py", line 2273, in __call__
    **self.session_kwargs)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 895, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1124, in _run
    feed_dict_tensor, options, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1321, in _do_run
    options, run_metadata)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1340, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: PruneForTargets: Some target nodes not found: group_deps
The code works fine most of the time; it's probably some problem with the multithreading.
How can I fix this?
Recommended answer
Make sure you finish the graph creation before creating the other threads.
Calling finalize() on the graph may help you with that.
def __init__(self, model_path):
    self.cnn_model = load_model(model_path)
    self.session = K.get_session()
    self.graph = tf.get_default_graph()
    self.graph.finalize()
Update 1: finalize() will make your graph read-only, so it can be safely used in multiple threads. As a side effect, it helps you find unintentional behavior and sometimes memory leaks, because it throws an exception whenever you try to modify the graph.
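To make the read-only idea concrete, here is a toy model in plain Python. ToyGraph and its methods are hypothetical and are not the TensorFlow API; they only mimic the behavior described above: once finalize() has been called, any attempt to add a node raises a RuntimeError, while reads from several threads keep working.

```python
import threading


class ToyGraph:
    """Hypothetical stand-in for a computation graph; NOT the TensorFlow API."""

    def __init__(self):
        self._nodes = []
        self._finalized = False

    def add_node(self, name):
        # Mimics the idea behind finalize(): after it, modifications raise
        # instead of silently mutating a graph that other threads are using.
        if self._finalized:
            raise RuntimeError("Graph is finalized and cannot be modified.")
        self._nodes.append(name)

    def finalize(self):
        self._finalized = True

    def node_count(self):
        return len(self._nodes)


graph = ToyGraph()
graph.add_node("dense/MatMul")
graph.add_node("dense/BiasAdd")
graph.finalize()  # build everything first, then freeze

errors = []


def worker():
    _ = graph.node_count()  # reading is fine from any thread
    try:
        graph.add_node("one_hot")  # an accidental modification is caught
    except RuntimeError as exc:
        errors.append(str(exc))


threads = [threading.Thread(target=worker) for _ in range(4)]
for th in threads:
    th.start()
for th in threads:
    th.join()

print(graph.node_count(), len(errors))  # → 2 4
```

Every worker's attempted modification is rejected, so the graph still has exactly the two nodes it had when it was finalized.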
Imagine, for instance, that you have a thread that one-hot encodes your inputs (bad example):
def preprocessing(self, data):
    one_hot_data = tf.one_hot(data, depth=self.num_classes)
    return self.session.run(one_hot_data)
If you print the number of nodes in the graph, you will notice that it increases over time, because every call to preprocessing adds a new one_hot op:
# amount of nodes in tf graph
print(len(list(tf.get_default_graph().as_graph_def().node)))
But if you define the graph first, that won't be the case (slightly better code):
def preprocessing(self, data):
    # run the pre-created op, feeding data into the self.input placeholder
    return self.session.run(self.one_hot_data, feed_dict={self.input: data})
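The difference between the two versions can be sketched with a toy graph that only counts nodes. Everything here (ToyGraph, BadPreprocessor, GoodPreprocessor) is hypothetical plain Python standing in for TensorFlow; the point is only that building the op inside the method grows the graph on every call, while building it once in __init__ keeps the node count constant.

```python
class ToyGraph:
    """Hypothetical append-only graph used only to count nodes; NOT TensorFlow."""

    def __init__(self):
        self.nodes = []

    def add_op(self, name):
        self.nodes.append(name)
        return name


class BadPreprocessor:
    def __init__(self, graph):
        self.graph = graph

    def preprocessing(self, data):
        # Creates a brand-new op on every call, like calling tf.one_hot here.
        op = self.graph.add_op("one_hot_%d" % len(self.graph.nodes))
        return op, data


class GoodPreprocessor:
    def __init__(self, graph):
        self.graph = graph
        # Created once up front, like building the placeholder and one_hot op in __init__.
        self.one_hot_data = self.graph.add_op("one_hot")

    def preprocessing(self, data):
        # Only "runs" the existing op; the graph does not grow.
        return self.one_hot_data, data


bad_graph, good_graph = ToyGraph(), ToyGraph()
bad, good = BadPreprocessor(bad_graph), GoodPreprocessor(good_graph)
for batch in range(100):
    bad.preprocessing(batch)
    good.preprocessing(batch)

print(len(bad_graph.nodes), len(good_graph.nodes))  # → 100 1
```

In the bad version the node count tracks the number of calls; in the good version it stays at one no matter how many batches are processed.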
Update 2: According to this thread, you need to call model._make_predict_function() on a Keras model before doing multithreading.
Keras builds the GPU function the first time you call predict(). That way, if you never call predict(), you save some time and resources. However, the first call to predict() is slightly slower than every subsequent call.
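One way to picture why building the function up front helps is a check-then-act race on lazy initialization. The sketch below is a plain-Python toy, not Keras code: LazyModel and _ensure_predict_fn are hypothetical, and the sleep only widens the race window. Without a warm-up, several threads can find the function missing and build it at the same time; calling predict() once in the main thread before starting the workers means it is built exactly once.

```python
import threading
import time


class LazyModel:
    """Hypothetical toy that builds its predict function on first use."""

    def __init__(self):
        self._predict_fn = None
        self.build_count = 0

    def _ensure_predict_fn(self):
        # Check-then-act with no lock: two threads can both see None and both build.
        if self._predict_fn is None:
            time.sleep(0.01)  # widen the race window, standing in for slow graph building
            self.build_count += 1
            self._predict_fn = lambda x: x * 2

    def predict(self, x):
        self._ensure_predict_fn()
        return self._predict_fn(x)


def run_threads(model, n=8):
    """Call model.predict(i) from n threads and collect the results in order."""
    results = [None] * n

    def worker(i):
        results[i] = model.predict(i)

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(n)]
    for th in threads:
        th.start()
    for th in threads:
        th.join()
    return results


cold = LazyModel()
run_threads(cold)  # cold.build_count is typically greater than 1 here

warm = LazyModel()
warm.predict(0)  # warm-up in the main thread, before any worker starts
results = run_threads(warm)
print(warm.build_count, results)  # → 1 [0, 2, 4, 6, 8, 10, 12, 14]
```

The cold model's build count is timing-dependent, which matches the "works most of the time" symptom; the warmed-up model's count is always 1.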
Updated code:
def __init__(self, model_path):
    self.cnn_model = load_model(model_path)
    self.cnn_model._make_predict_function()  # have to initialize before threading
    self.session = K.get_session()
    self.graph = tf.get_default_graph()
    self.graph.finalize()  # make graph read-only
Update 3: I made a proof of concept of a warm-up, because _make_predict_function() doesn't seem to work as expected. First, I created a dummy model:
from keras.layers import Dense
from keras.models import Sequential

# throwaway model, used only to exercise predict() from several threads
model = Sequential()
model.add(Dense(256, input_shape=(2,)))
model.add(Dense(1, activation='softmax'))
model.compile(loss='mean_squared_error', optimizer='adam')
model.save("dummymodel")
Then, in another script, I loaded that model and ran predictions on multiple threads:
import tensorflow as tf
from keras import backend as K
from keras.models import load_model
import threading as t
import numpy as np

K.clear_session()

class CNN:
    def __init__(self, model_path):
        self.cnn_model = load_model(model_path)
        self.cnn_model.predict(np.array([[0, 0]]))  # warm-up
        self.session = K.get_session()
        self.graph = tf.get_default_graph()
        self.graph.finalize()  # finalize

    def preproccesing(self, data):
        # dummy
        return data

    def query_cnn(self, data):
        X = self.preproccesing(data)
        with self.session.as_default():
            with self.graph.as_default():
                prediction = self.cnn_model.predict(X)
        print(prediction)
        return prediction

cnn = CNN("dummymodel")
threads = [
    t.Thread(target=cnn.query_cnn, kwargs={"data": np.random.random((500, 2))})
    for _ in range(5)
]
for th in threads:
    th.start()
for th in threads:
    th.join()
After commenting out the warm-up and finalize() lines, I was able to reproduce your original issue.
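Note that threading.Thread discards the return value of its target, so the predictions from query_cnn in the script above are only printed. If you need them back, concurrent.futures.ThreadPoolExecutor is a standard-library alternative to starting threads by hand. In this sketch, FakeCNN is a hypothetical stand-in whose query_cnn sums each row instead of calling Keras, so it runs without a model; with the real class you would still need the warm-up and finalize() steps in __init__.

```python
import random
from concurrent.futures import ThreadPoolExecutor


class FakeCNN:
    """Hypothetical stand-in for the CNN class; no Keras involved."""

    def query_cnn(self, data):
        # Placeholder "prediction": one score per input row.
        return [sum(row) for row in data]


cnn = FakeCNN()
batches = [[[random.random(), random.random()] for _ in range(500)] for _ in range(5)]

# map() runs query_cnn on worker threads and yields the results in input order.
with ThreadPoolExecutor(max_workers=5) as pool:
    predictions = list(pool.map(cnn.query_cnn, batches))

print(len(predictions), [len(p) for p in predictions])  # → 5 [500, 500, 500, 500, 500]
```

Leaving the with block joins all worker threads, so predictions is complete once it has been built.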