get_variable() does not work after session restoration
Question
I try to restore a session and call get_variable() to get an object of type tf.Variable (according to this answer). It fails to find the variable. A minimal example that reproduces the case follows.
First, create a variable and save the session.
import tensorflow as tf

var = tf.Variable(101)
with tf.Session() as sess:
    with tf.variable_scope(''):
        scoped_var = tf.get_variable('scoped_var', [])
    with tf.variable_scope('', reuse=True):
        new_scoped_var = tf.get_variable('scoped_var', [])
    assert scoped_var is new_scoped_var
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    print(sess.run(scoped_var))
    saver.save(sess, 'data/sess')
Here get_variable inside a scope with reuse=True works fine. Then, restore the session from the file and try to get the variable.
import tensorflow as tf

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('data/sess.meta')
    saver.restore(sess, 'data/sess')
    for v in tf.get_collection('variables'):
        print(v.name)
    print(tf.get_collection(("__variable_store",)))
    # Oops, it's empty!
    with tf.variable_scope('', reuse=True):
        # the next line fails
        new_scoped_var = tf.get_variable('scoped_var', [])
        print("new_scoped_var: ", new_scoped_var)
Output:
Variable:0
scoped_var:0
[]
Traceback (most recent call last):
...
ValueError: Variable scoped_var does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=None in VarScope?
As we can see, get_variable() cannot find the variable, and the ("__variable_store",) collection, which get_variable() uses internally, is empty.

Why does get_variable fail?
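The failure is easier to see with a plain-Python analogy (no TensorFlow involved; the names below are hypothetical stand-ins, not real TF internals). The key point is that get_variable() consults a process-local variable store, while the .meta file only serializes graph collections, so importing the meta graph brings the collections back but leaves the store empty.

```python
def save_meta(graph_collections):
    # Stand-in for what gets serialized into the .meta file:
    # graph collections only, never the variable store.
    return dict(graph_collections)

def import_meta(serialized):
    # Stand-in for a fresh process after import_meta_graph: the
    # collections come back from disk, but the variable store (the
    # ("__variable_store",) equivalent) starts out empty again.
    return dict(serialized), {}

def get_variable(store, name):
    # Stand-in for get_variable(reuse=True): it only consults the store.
    if name not in store:
        raise ValueError("Variable %s does not exist" % name)
    return store[name]

# Original process: creating a variable registers it in both places.
collections = {"variables": ["scoped_var:0"]}
store = {"scoped_var": "<tf.Variable scoped_var>"}
assert get_variable(store, "scoped_var") == "<tf.Variable scoped_var>"

# Restored process: the collection survives, the store does not.
new_collections, new_store = import_meta(save_meta(collections))
assert new_collections["variables"] == ["scoped_var:0"]
try:
    get_variable(new_store, "scoped_var")
except ValueError as e:
    print(e)  # Variable scoped_var does not exist
```

This mirrors the output above: the variables collection lists scoped_var:0, but the store lookup raises ValueError.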
Answer
Instead of dealing with the meta graph (which can be helpful if you want to modify the graph and control how it's loaded), you can try this.
import tensorflow as tf

with tf.Session() as sess:
    with tf.variable_scope(''):
        scoped_var = tf.get_variable('scoped_var', [])
    with tf.variable_scope('', reuse=True):
        new_scoped_var = tf.get_variable('scoped_var', [])
    assert scoped_var is new_scoped_var
    saver = tf.train.Saver()
    # get_checkpoint_state expects the checkpoint *directory*,
    # i.e. the 'data' dir that saver.save wrote into
    path = tf.train.get_checkpoint_state('data')
    if path is not None:
        saver.restore(sess, path.model_checkpoint_path)
    else:
        sess.run(tf.global_variables_initializer())
    print(sess.run(scoped_var))
    saver.save(sess, 'data/sess')
    # now continue to use as you normally would with a restored model
The main difference is that you set up your model in code before calling saver.restore, so the variable store that get_variable consults is already populated; the checkpoint only supplies the saved values.
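The restore-or-initialize branch in the answer is a general pattern: probe for a checkpoint first, restore if one exists, otherwise initialize from scratch. A minimal framework-free sketch of the same control flow (the checkpoint file name and the helper functions here are hypothetical, chosen only to mirror the TF calls):

```python
import json
import tempfile
from pathlib import Path

# Hypothetical checkpoint location; stands in for 'data/sess' in the answer.
ckpt = Path(tempfile.mkdtemp()) / "checkpoint.json"

def init_state():
    # Fresh start, analogous to sess.run(tf.global_variables_initializer()).
    return {"scoped_var": 0.0}

def restore_or_init(path):
    # Analogous to checking tf.train.get_checkpoint_state() and either
    # calling saver.restore() or falling back to initialization.
    if path.exists():
        return json.loads(path.read_text())
    return init_state()

def save_state(path, state):
    # Analogous to saver.save(sess, ...).
    path.write_text(json.dumps(state))

state = restore_or_init(ckpt)   # first run: no checkpoint, so initialize
state["scoped_var"] += 1.0      # stand-in for a training step
save_state(ckpt, state)

state = restore_or_init(ckpt)   # later run: checkpoint found, so restore
print(state["scoped_var"])      # → 1.0
```

Because the model (here, the state dict's shape) is defined in code on every run, lookups always succeed; the checkpoint decides only which values are loaded.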