Tensorflow:如何在另一个图上初始化变量? [英] Tensorflow: How to initialize variables on another graph?
问题描述
我有一个默认图表和一个新创建的图表 (G1).
I have a default graph and a newly created graph (G1).
在 G1 中,我有一个名为a"的变量.
In G1, I have a variable named "a".
我可以使用 tf.import_graph_def
将 G1 包含到主图中,并公开其a"变量.
I can use tf.import_graph_def
to include G1 onto the main graph, and expose its "a" variable.
如何初始化这个变量并成功打印a"的值?
实际代码如下:
import tensorflow as tf
INT = tf.int32
def graph():
g = tf.Graph()
with g.as_default() as g:
a = tf.get_variable('a', [], INT, tf.constant_initializer(10))
return g
tf.reset_default_graph()
g = graph()
[g_a] = tf.import_graph_def(g.as_graph_def(), return_elements=['a:0'])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(g_a))
上述方法不起作用,会出现FailedPreconditionError: Attempting to use uninitialized value import/a
的错误.
The above won't work, it will error with FailedPreconditionError: Attempting to use uninitialized value import/a
.
推荐答案
出现错误的原因是,在导入图形定义时,没有导入或恢复任何变量和值.
The reason you get errors is that when you import a graph def, no variables and values are imported or restored.
如果您执行以下操作,您可以在另一个图表中使用变量:
You can use variables in another graph if you do the following:
- 在会话中声明变量,然后运行 tf.global_variables_initalizer()
- 保存变量
- 导入 graph_def 后,恢复变量
- 重要:当您导入图形 def 时,请使用 name='' 以使用与其他图形中相同的命名空间,否则会出现错误
- declare your variable in a session, then run tf.global_variables_initalizer()
- save your variable
- after you import your graph_def, restore your variable
- important: when you import the graph def use name='' to use the same namespace as in your other graph otherwise you get errors
如何做到这一点的最小示例:
A minimal example how to to this:
import tensorflow as tf
INT = tf.int32
def graph():
g = tf.Graph()
with tf.Session(graph=g) as sess:
a = tf.get_variable("a", shape=[1], dtype=INT, initializer=tf.constant_initializer(10))
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.save(sess, './test_dir/test_save.ckpt')
return g
g = graph()
tf.reset_default_graph()
g_a = tf.import_graph_def(g.as_graph_def(), return_elements=['a:0'], name='')
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
second_saver = tf.train.Saver(var_list=g_a)
second_saver.restore(sess, './test_dir/test_save.ckpt')
a = sess.graph.get_tensor_by_name('a:0')
print(sess.run(a))
这篇关于Tensorflow:如何在另一个图上初始化变量?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!