Tensorflow:如何在另一个图上初始化变量? [英] Tensorflow: How to initialize variables on another graph?

查看:33
本文介绍了Tensorflow:如何在另一个图上初始化变量?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一个默认图表和一个新创建的图表 (G1).

I have a default graph and a newly created graph (G1).

在 G1 中,我有一个名为a"的变量.

In G1, I have a variable named "a".

我可以使用 tf.import_graph_def 将 G1 包含到主图中,并公开其a"变量.

I can use tf.import_graph_def to include G1 onto the main graph, and expose its "a" variable.

如何初始化这个变量并成功打印a"的值?

实际代码如下:

import tensorflow as tf

INT = tf.int32


def graph():
    g = tf.Graph()
    with g.as_default() as g:
        a = tf.get_variable('a', [], INT, tf.constant_initializer(10))
    return g


tf.reset_default_graph()

g = graph()
[g_a] = tf.import_graph_def(g.as_graph_def(), return_elements=['a:0'])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(g_a))

上述方法不起作用,会出现FailedPreconditionError: Attempting to use uninitialized value import/a 的错误.

The above won't work, it will error with FailedPreconditionError: Attempting to use uninitialized value import/a.

推荐答案

出现错误的原因是,在导入图形定义时,没有导入或恢复任何变量和值.

The reason you get errors is that when you import a graph def, no variables and values are imported or restored.

如果您执行以下操作,您可以在另一个图表中使用变量:

You can use variables in another graph if you do the following:

  • 在会话中声明变量,然后运行 ​​tf.global_variables_initalizer()
  • 保存变量
  • 导入 graph_def 后,恢复变量
  • 重要:当您导入图形 def 时,请使用 name='' 以使用与其他图形中相同的命名空间,否则会出现错误
  • declare your variable in a session, then run tf.global_variables_initalizer()
  • save your variable
  • after you import your graph_def, restore your variable
  • important: when you import the graph def use name='' to use the same namespace as in your other graph otherwise you get errors

如何做到这一点的最小示例:

A minimal example how to to this:

import tensorflow as tf

INT = tf.int32

def graph():
    g = tf.Graph()
    with tf.Session(graph=g) as sess:
        a = tf.get_variable("a", shape=[1], dtype=INT, initializer=tf.constant_initializer(10))
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.save(sess, './test_dir/test_save.ckpt')
        return g


g = graph()

tf.reset_default_graph()

g_a = tf.import_graph_def(g.as_graph_def(), return_elements=['a:0'], name='')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    second_saver = tf.train.Saver(var_list=g_a)
    second_saver.restore(sess, './test_dir/test_save.ckpt')
    a = sess.graph.get_tensor_by_name('a:0')
    print(sess.run(a))

这篇关于Tensorflow:如何在另一个图上初始化变量?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆