Replicating models in Keras and Tensorflow for a multi-threaded setting


Problem description


I am trying to implement the asynchronous version of actor-critic in Keras and TensorFlow. I am using Keras just as a front-end for building my network layers (I am updating the parameters directly with TensorFlow). I have a global_model and one main TensorFlow session. But inside each thread I am creating a local_model which copies parameters from the global_model. My code looks something like this:

import threading
import tensorflow as tf
from keras import backend as K

def main(args):
    config = tf.ConfigProto(log_device_placement=False, allow_soft_placement=True)
    sess = tf.Session(config=config)
    K.set_session(sess)  # K is the Keras backend
    global_model = ConvNetA3C(84, 84, 4, num_actions=3)

    threads = [threading.Thread(target=a3c_thread, args=(i, sess, global_model))
               for i in range(NUM_THREADS)]

    for t in threads:
        t.start()

def a3c_thread(i, sess, global_model):
    K.set_session(sess)  # registering a session for each thread (don't know if it matters)
    local_model = ConvNetA3C(84, 84, 4, num_actions=3)
    sync = local_model.get_from(global_model)  # I get the error here

    # in the get_from function I do tf.assign(dest.params[i], src.params[i])
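The sync step above amounts to copying every global parameter into the corresponding local one. A framework-free sketch of that idea, with a hypothetical Model class standing in for ConvNetA3C and made-up parameter names and values (the real code builds one tf.assign op per parameter):

```python
class Model:
    """Toy stand-in for ConvNetA3C: just a dict of named parameters."""
    def __init__(self, params):
        self.params = dict(params)

    def get_from(self, src):
        # Mirrors tf.assign(dest.params[i], src.params[i]) for every parameter.
        for name, value in src.params.items():
            self.params[name] = value

global_model = Model({"conv1_W": [1.0, 2.0], "fc_b": [0.5]})
local_model = Model({"conv1_W": [0.0, 0.0], "fc_b": [0.0]})
local_model.get_from(global_model)
print(local_model.params["conv1_W"])  # [1.0, 2.0]
```

In TensorFlow, unlike in this sketch, the copy only works if both models' variables live in the same graph, which is exactly where the error below comes from.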


I get a user warning from Keras


UserWarning: The default TensorFlow graph is not the graph associated with the TensorFlow session currently registered with Keras, and as such Keras was not able to automatically initialize a variable. You should consider registering the proper session with Keras via K.set_session(sess)


followed by a TensorFlow error on the tf.assign operation saying operations must be on the same graph:


ValueError: Tensor("conv1_W:0", shape=(8, 8, 4, 16), dtype=float32_ref, device=/device:CPU:0) must be from the same graph as Tensor("conv1_W:0", shape=(8, 8, 4, 16), dtype=float32_ref)

I'm not sure what's going on.

Thanks

Answer


The error comes from Keras because tf.get_default_graph() is sess.graph returns False. From the TF docs, tf.get_default_graph() returns the default graph for the current thread: the moment I start a new thread and create ops, they are built into a separate graph specific to that thread. I can solve this issue by explicitly making the session's graph the default inside the thread:

with sess.graph.as_default():
    local_model = ConvNetA3C(84, 84, 4, 3)

