Loading two models from Saver in the same Tensorflow session


Question

I have two networks: a Model which generates output and an Adversary which grades the output.

Both have been trained separately but now I need to combine their outputs during a single session.

I've attempted to implement the solution proposed in this post: Run multiple pre-trained Tensorflow nets at the same time

My code

with tf.name_scope("model"):
    model = Model(args)
with tf.name_scope("adv"):
    adversary = Adversary(adv_args)

#...

with tf.Session() as sess:
    tf.global_variables_initializer().run()

    # Get the variables specific to the `Model`
    # Also strip out the superfluous ":0", which for some reason isn't saved in the checkpoint
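    # (Note: str.lstrip removes a leading *set* of characters rather than a
    # literal prefix, so lstrip("model/") can strip more than intended.)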
    model_varlist = {v.name.lstrip("model/")[:-2]: v 
                     for v in tf.global_variables() if v.name[:5] == "model"}
    model_saver = tf.train.Saver(var_list=model_varlist)
    model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
    model_saver.restore(sess, model_ckpt.model_checkpoint_path)

    # Get the variables specific to the `Adversary`
    adv_varlist = {v.name.lstrip("avd/")[:-2]: v 
                   for v in tf.global_variables() if v.name[:3] == "adv"}
    adv_saver = tf.train.Saver(var_list=adv_varlist)
    adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
    adv_saver.restore(sess, adv_ckpt.model_checkpoint_path)

Problem

The call to the function model_saver.restore() appears to be doing nothing. In another module I use a saver with tf.train.Saver(tf.global_variables()) and it restores the checkpoint fine.

The model has model.tvars = tf.trainable_variables(). To check what was happening I used sess.run() to extract the tvars before and after restore. Each time the initial randomly assigned values are still in place and the variables from the checkpoint are not being assigned.
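A minimal sketch of that check, assuming the model and saver set up as above (the numpy comparison is my addition, not from the original post):

import numpy as np

# Snapshot the trainable variables, restore, then snapshot again
before = sess.run(model.tvars)
model_saver.restore(sess, model_ckpt.model_checkpoint_path)
after = sess.run(model.tvars)

# If the restore worked, at least some variables should hold new values
changed = sum(not np.array_equal(b, a) for b, a in zip(before, after))
print("variables changed by restore: %d of %d" % (changed, len(before)))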

Any thoughts on why model_saver.restore() appears to be doing nothing?

Answer

Solving this problem took a long time so I'm posting my likely imperfect solution in case anyone else needs it.

To diagnose the problem I manually looped through each of the variables and assigned them one by one. Then I noticed that after assigning the variable the name would change. This is described here: TensorFlow checkpoint save and read
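A rough reconstruction of that diagnostic loop (tf.train.NewCheckpointReader is standard TF1 API; the prefix handling here is an assumption on my part):

reader = tf.train.NewCheckpointReader(model_ckpt.model_checkpoint_path)
for v in tf.global_variables():
    if not v.name.startswith("model/"):
        continue
    ckpt_name = v.name[len("model/"):-2]  # drop the scope prefix and ":0"
    sess.run(v.assign(reader.get_tensor(ckpt_name)))
    # Comparing v.name against the checkpoint keys here is what exposed
    # the renaming described in the linked post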

Based on the advice in that post I ran each of the models in its own graph. That also meant running each graph in its own session and handling the session management differently.

First I created the two graphs

model_graph = tf.Graph()
with model_graph.as_default():
    model = Model(args)

adv_graph = tf.Graph()
with adv_graph.as_default():
    adversary = Adversary(adv_args)

Then the two sessions

adv_sess = tf.Session(graph=adv_graph)
sess = tf.Session(graph=model_graph)

Then I initialised the variables in each session and restored each graph separately

with sess.as_default():
    with model_graph.as_default():
        tf.global_variables_initializer().run()
        model_saver = tf.train.Saver(tf.global_variables())
        model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
        model_saver.restore(sess, model_ckpt.model_checkpoint_path)

with adv_sess.as_default():
    with adv_graph.as_default():
        tf.global_variables_initializer().run()
        adv_saver = tf.train.Saver(tf.global_variables())
        adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
        adv_saver.restore(adv_sess, adv_ckpt.model_checkpoint_path)

From here, whenever one of the sessions was needed, I would wrap any tf calls in that session in a with sess.as_default(): block, as in the sketch below.
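For example, to score the Model's output with the Adversary, values have to cross between the sessions as plain numpy arrays (a sketch; model.input, model.output, adversary.input and adversary.score are assumed attribute names, not from the original post):

with sess.as_default():
    generated = sess.run(model.output, feed_dict={model.input: batch})

with adv_sess.as_default():
    score = adv_sess.run(adversary.score,
                         feed_dict={adversary.input: generated})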

At the end I manually close the sessions

sess.close()
adv_sess.close()
