tensorflow.train.import_meta_graph does not work?


Problem description


I am trying to simply save and restore a graph, but even the simplest example does not work as expected. (This was done with TensorFlow 0.9.0 or 0.10.0 on 64-bit Linux, without CUDA, using Python 2.7 or 3.5.2.)

First I save the graph like this:

import tensorflow as tf
v1 = tf.placeholder('float32') 
v2 = tf.placeholder('float32')
v3 = tf.mul(v1,v2)
c1 = tf.constant(22.0)
v4 = tf.add(v3,c1)
sess = tf.Session()
result = sess.run(v4,feed_dict={v1:12.0, v2:3.3})
g1 = tf.train.export_meta_graph("file")
## alternately I also tried:
## g1 = tf.train.export_meta_graph("file",collection_list=["v4"])

This creates a file "file" that is non-empty and also sets g1 to something that looks like a proper graph definition.

Then I try to restore this graph:

import tensorflow as tf
g=tf.train.import_meta_graph("file")

This works without an error, but does not return anything at all.

Can anyone provide the code needed to simply save the graph for "v4" and restore it completely, so that running it in a new session produces the same result?

Solution

To reuse a MetaGraphDef, you will need to record the names of interesting tensors in your original graph. For example, in the first program, set an explicit name argument in the definition of v1, v2 and v4:

v1 = tf.placeholder(tf.float32, name="v1")
v2 = tf.placeholder(tf.float32, name="v2")
# ...
v4 = tf.add(v3, c1, name="v4")

Then, you can use the string names of the tensors in the original graph in your call to sess.run(). (A name like "v4:0" refers to the first output of the operation named "v4".) For example, the following snippet should work:

import tensorflow as tf
_ = tf.train.import_meta_graph("./file")

sess = tf.Session()
result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})

Alternatively, you can use tf.get_default_graph().get_tensor_by_name() to get tf.Tensor objects for the tensors of interest, which you can then pass to sess.run():

import tensorflow as tf
_ = tf.train.import_meta_graph("./file")
g = tf.get_default_graph()

v1 = g.get_tensor_by_name("v1:0")
v2 = g.get_tensor_by_name("v2:0")
v4 = g.get_tensor_by_name("v4:0")

sess = tf.Session()
result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})


UPDATE: Based on the discussion in the comments, here is a complete example of saving and loading, including saving the variable contents. To show that the variable's current value (and not just its initializer) is saved, the example doubles the variable vx in a separate operation before writing the checkpoint.

Saving:

import tensorflow as tf
v1 = tf.placeholder(tf.float32, name="v1")
v2 = tf.placeholder(tf.float32, name="v2")
v3 = tf.mul(v1, v2)
vx = tf.Variable(10.0, name="vx")
v4 = tf.add(v3, vx, name="v4")
saver = tf.train.Saver([vx])            # this saver will checkpoint the variable vx
sess = tf.Session()
sess.run(tf.initialize_all_variables())
sess.run(vx.assign(tf.add(vx, vx)))     # double vx, so the saved value (20.0) differs from its initializer (10.0)
result = sess.run(v4, feed_dict={v1: 12.0, v2: 3.3})
print(result)                           # 12.0 * 3.3 + 20.0 = 59.6
saver.save(sess, "./model_ex1")         # writes the checkpoint plus the MetaGraphDef "./model_ex1.meta"

Restoring:

import tensorflow as tf
saver = tf.train.import_meta_graph("./model_ex1.meta")   # rebuilds the graph and returns a Saver
sess = tf.Session()
saver.restore(sess, "./model_ex1")                        # restores the saved value of vx (20.0)
result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
print(result)                                             # 59.6, the same result as before saving
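
As a small supplementary check (not part of the original answer), the restored value of vx can also be fetched by name in a fresh session, which confirms that the doubled value was stored in the checkpoint rather than the initializer:

import tensorflow as tf
saver = tf.train.import_meta_graph("./model_ex1.meta")
sess = tf.Session()
saver.restore(sess, "./model_ex1")
# The variable op is named "vx", so "vx:0" is its value tensor.
print(sess.run("vx:0"))   # expected: 20.0, not the initializer 10.0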

The bottom line is that, in order to make use of a saved model, you must remember the names of at least some of the nodes (e.g. a training op, an input placeholder, an evaluation tensor, etc.). The MetaGraphDef stores the list of variables that are contained in the model, and helps to restore these from a checkpoint, but you are required to reconstruct the tensors/operations used in training/evaluating the model yourself.
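
One way to reduce that name bookkeeping, sketched below as a supplement to the answer above (the collection names "inputs" and "outputs" are arbitrary choices, not anything required by TensorFlow), is to put the interesting tensors into named collections before exporting: collections are serialized into the MetaGraphDef and come back when the graph is imported.

import tensorflow as tf

v1 = tf.placeholder(tf.float32, name="v1")
v2 = tf.placeholder(tf.float32, name="v2")
v4 = tf.add(tf.mul(v1, v2), tf.constant(22.0), name="v4")

# Record the tensors of interest; collection entries are serialized into the
# MetaGraphDef together with the graph definition.
tf.add_to_collection("inputs", v1)
tf.add_to_collection("inputs", v2)
tf.add_to_collection("outputs", v4)

tf.train.export_meta_graph("./file_with_collections")

When the graph is imported again, the same tensors can be pulled out of the collections instead of being looked up by string name:

import tensorflow as tf

tf.train.import_meta_graph("./file_with_collections")

# The collections are restored along with the graph, in insertion order.
v1, v2 = tf.get_collection("inputs")
v4 = tf.get_collection("outputs")[0]

sess = tf.Session()
print(sess.run(v4, feed_dict={v1: 12.0, v2: 3.3}))   # 12.0 * 3.3 + 22.0 = 61.6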
