TensorFlow-导入元图并使用其中的变量 [英] TensorFlow - import meta graph and use variables from it

查看:158
本文介绍了TensorFlow-导入元图并使用其中的变量的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在使用TensorFlow v0.12训练分类CNN,然后希望使用训练后的模型为新数据创建标签.

I'm training classification CNN using TensorFlow v0.12, and then want to create labels for new data using the trained model.

在培训脚本的末尾,我添加了以下代码行:

At the end of the training script, I added those lines of code:

saver = tf.train.Saver()
save_path = saver.save(sess,'/home/path/to/model/model.ckpt')

培训结束后,出现在文件夹中的文件为:1. checkpoint ; 2. model.ckpt.data-00000-of-00001 ; 3. model.ckpt.index ; 4. model.ckpt.meta

After the training completed, the files appearing in the folder are: 1. checkpoint ; 2. model.ckpt.data-00000-of-00001 ; 3. model.ckpt.index ; 4. model.ckpt.meta

然后,我尝试使用 .meta 文件还原模型.在本教程之后,我将以下行添加到了分类代码中:

Then I tried to restore the model using the .meta file. Following this tutorial, I added the following line into my classification code:

saver=tf.train.import_meta_graph(savepath+'model.ckpt.meta') #line1

然后:

saver.restore(sess, save_path=savepath+'model.ckpt') #line2

在进行此更改之前,我需要再次构建图形,然后编写(代替line1):

Before that change, I needed to build the graph again, and then write (instead of line1):

saver = tf.train.Saver()

但是,删除图形构造并使用line1进行还原会引发错误.错误是我在代码中使用了图表中的变量,而python无法识别它:

But, deleting the graph building, and using line1 in order to restore it, raised an error. The error was that I used a variable from the graph inside my code, and the python didn't recognize it:

predictions = sess.run(y_conv, feed_dict={x: patches,keep_prob: 1.0})

Python无法识别y_conv参数.有一种使用元图还原变量的方法吗?如果没有,如果我不能使用原始图形中的变量,此恢复有什么帮助?

The python didn't recognize the y_conv parameter. There is a way to restore the variables using the meta graph? if not, what os this restore helping, if I can't use variables from the original graph?

我知道这个问题不清楚,但是我很难用语言表达这个问题.抱歉...

I know this question isn't so clear, but it was hard for me to express the problem in words. Sorry about it...

感谢您的答复,感谢您的帮助!投资回报率.

Thanks for answering, appreciate your help! Roi.

推荐答案

有可能,不用担心.假设您不想再触摸该图,请执行以下操作:

it is possible, don't worry. Assuming you don't want to touch the graph anymore, do something like this:

saver = tf.train.import_meta_graph('model/export/{}.meta'.format(model_name))
saver.restore(sess, 'model/export/{}'.format(model_name))
graph = tf.get_default_graph()       
y_conv = graph.get_operation_by_name('y_conv').outputs[0]
predictions = sess.run(y_conv, feed_dict={x: patches,keep_prob: 1.0})

但是,一种首选方法是在构建图形并引用它们时将ops添加到集合中.因此,当您定义图形时,您将添加以下行:

A preferred way would however be adding the ops into collections when you build the graph and then referring to them. So when you define the graph, you would add the line:

tf.add_to_collection("y_conv", y_conv)

然后,在导入元图并将其还原后,您将调用:

And then after you import the metagraph and restore it, you would call:

y_conv = tf.get_collection("y_conv")[0]

它实际上在文档中进行了解释-您链接的确切页面-但也许您错过了它.

It is actually explained in the documentation - the exact page you linked - but perhaps you missed it.

顺便说一句,不需要.ckpt扩展名,这可能会造成一些混乱,因为这是保存模型的旧方法.

Btw, no need for the .ckpt extension, it might create some confusion as that is the old way of saving models.

这篇关于TensorFlow-导入元图并使用其中的变量的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆