多线程预测时出现Keras错误 [英] Keras error when predicting on multithreading
问题描述
我正在尝试创建四个线程(每个线程都有自己的图形和模型),这些线程将同时运行并以相同的方式发出预测.
I'm trying to create four threads (each one with its own graph and model) that will run concurently and issue predictions in the same way.
我的线程代码类似于:
thread_locker.acquire()
thread_graph = Graph()
with thread_graph.as_default():
thread_session = Session()
with thread_session.as_default():
#Model Training
if (once_flag_raised == False):
try:
model = load_model('ten_step_forward_'+ timeframe +'.h5')
except OSError:
input_layer = Input(shape=(X_train.shape[1], 17,))
lstm = Bidirectional(
LSTM(250),
merge_mode='concat')(input_layer)
pred = Dense(10)(lstm)
model = Model(inputs=input_layer, outputs=pred)
model.compile(optimizer='adam', loss='mean_squared_error')
once_flag_raised = True
model.fit(X_train, y_train, epochs=10, batch_size=128)
thread_locker.acquire()
nn_info_dict['model'] = model
nn_info_dict['sc'] = sc
model.save('ten_step_forward_'+ timeframe +'.h5')
thread_locker.release()
thread_locker.release()
(....)
thread_locker.acquire()
thread_graph = Graph()
with thread_graph.as_default():
thread_session = Session()
with thread_session.as_default():
pred_data= model.predict(X_pred)
thread_locker.release()
在每个线程上.
当我阅读代码的预测部分时,我不断收到以下错误(线程-1次):
I keep getting the following error (threads - 1 times) when I read the predicting part of the code:
ValueError: Tensor Tensor("dense_1/BiasAdd:0", shape=(?, 10), dtype=float32) is not an element of this graph.
我的理解是,其中一个线程声明"了Tensorflow后端及其默认的Graph和Session.
My understanding is that one of the threads "claims" the Tensorflow backend and its default Graph and Session.
有什么办法可以解决这个问题?
Is there any way to work around that?
推荐答案
我发现我做错了什么. 我的想法是正确的,但我不应该在下面重新创建图表和会话. 代码的底部应该简单地是:
I have figured what I was doing wrong. My thinking was right but I shouldn't have recreated the Graph and Session below. The bottom part of the code should simply be:
thread_locker.acquire()
with thread_graph.as_default():
with thread_session.as_default():
pred_data= model.predict(X_pred)
thread_locker.release()
这篇关于多线程预测时出现Keras错误的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!