会话恢复后 get_variable() 不起作用 [英] get_variable() does not work after session restoration

查看:25
本文介绍了会话恢复后 get_variable() 不起作用的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我尝试恢复会话并调用 get_variable() 来获取类型的对象tf.Variable(根据这个答案).它无法找到变量.重现案例的最小示例是如下.

I try to restore a session and call get_variable() to get an object of type tf.Variable (according to this answer). And it fails to find the variable. The minimal example to reproduce the case is as follows.

首先,创建一个变量并保存会话.

First, create a variable and save the session.

import tensorflow as tf

var = tf.Variable(101)

with tf.Session() as sess:
    with tf.variable_scope(''):
        scoped_var = tf.get_variable('scoped_var', [])

    with tf.variable_scope('', reuse=True):
        new_scoped_var = tf.get_variable('scoped_var', [])

    assert scoped_var is new_scoped_var
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    print(sess.run(scoped_var))
    saver.save(sess, 'data/sess')

这里 get_variables 在具有 reuse=True 的范围内工作正常.然后,从文件中恢复会话并尝试获取变量.

Here get_variables inside a scope with reuse=True works fine. Then, restore the session from a file and try to get the variable.

import tensorflow as tf

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('data/sess.meta')
    saver.restore(sess, 'data/sess')

    for v in tf.get_collection('variables'):
        print(v.name)

    print(tf.get_collection(("__variable_store",)))
    # Oops, it's empty!

    with tf.variable_scope('', reuse=True):
        # the next line fails
        new_scoped_var = tf.get_variable('scoped_var', [])

    print("new_scoped_var: ", new_scoped_var)

输出:

Variable:0
scoped_var:0
[]
Traceback (most recent call last):
...
ValueError: Variable scoped_var does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=None in VarScope?

正如我们所见,get_variable() 找不到变量.和("__variable_store",) 集合,由 get_variable() 内部使用,是空的.

As we can see, get_variable() can not find the variable. And ("__variable_store",) collection, that is used internally by get_variable(), is empty.

为什么 get_variable 会失败?

推荐答案

你可以试试这个,而不是处理元图(如果你想修改图以及它的加载方式,这会很有帮助).

Instead of dealing with the meta graph (which can be helpful if you want to modify the graph and how it's loaded etc) you can try this.

import tensorflow as tf

with tf.Session() as sess:
  with tf.variable_scope(''):
    scoped_var = tf.get_variable('scoped_var', [])

  with tf.variable_scope('', reuse=True):
    new_scoped_var = tf.get_variable('scoped_var', [])

  assert scoped_var is new_scoped_var
  saver = tf.train.Saver()
  path = tf.train.get_checkpoint_state('data/sess')
  if path is not None:
    saver.restore(sess, path.model_checkpoint_path)
  else:
    sess.run(tf.global_variables_initializer())

  print(sess.run(scoped_var))
  saver.save(sess, 'data/sess')

  #now continue to use as you normally would with a restored model

主要区别在于您在调用 saver.restore 之前设置了模型

The main difference is you've set up your model before calling saver.restore

这篇关于会话恢复后 get_variable() 不起作用的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆