Loading metagraph and checkpoints in tensorflow


Question

I have been working on this for a while now and can't seem to crack it. In other questions I have seen code samples used to save and restore a model with the metagraph and checkpoint files, but when I do something similar it says that w1 is undefined when I have the save-model and restore-model code in separate Python files. It works fine when the restore comes right at the end of the saving script, but that defeats the purpose: I'd have to hand-define everything all over again in a separate file. I have looked into the checkpoint file, and it seems bizarre that it only has two lines and doesn't seem to reference any variables or hold any values; it is only 1 KB. I have tried putting 'w1' into the print call as a string instead, and that returns None rather than the values I am looking for. Does this work for anyone else? If so, what do your checkpoint files look like?

#Saving — this and the restore below live in separate .py files
import tensorflow as tf

w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver([w1, w2])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model', global_step=1000)

#restoring
import tensorflow as tf

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('my_test_model-1000.meta', clear_devices=True)
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    print(sess.run(w1))  # fails here: w1 is never defined in this file

Answer

Your graph is saved correctly, but restoring it does not recreate the Python variables that point to the nodes of the graph. w1 is a Python variable that you never declared in the 'restoring' part of the code. To get back a handle on your weights:

  • You can look them up by their names in the TF graph: w1 = tf.get_variable(name='w1'). The problem is that you'll have to pay close attention to your name scopes and make sure you don't have multiple variables of the same name (in which case TF adds '_1' to one of their names, so you might get the wrong one). If you go that way, TensorBoard can be of great help in finding the exact name of each variable. See the first sketch after this list.

  • You can use collections: save the interesting nodes in collections, and get them back from the graph after restoring. When building the graph, before saving it, do for instance tf.add_to_collection('weights', w1) and tf.add_to_collection('weights', w2), and in your restoring code: [w1, w2] = tf.get_collection('weights'). Then you'll be able to use w1 and w2 normally. A second sketch of this approach appears after the comparison below.
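As an illustration of the first option, here is a minimal sketch of the name-based lookup. It uses graph.get_tensor_by_name, a close relative of the tf.get_variable call named above, and assumes the checkpoint file names from the question:

#restore_by_name.py — hypothetical file name; a sketch, not the answer's exact code
import tensorflow as tf

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('my_test_model-1000.meta', clear_devices=True)
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    # ':0' refers to the first output of the op named 'w1'
    w1 = tf.get_default_graph().get_tensor_by_name('w1:0')
    print(sess.run(w1))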

I think the latter, though more verbose, is probably better with regard to future changes in your architecture. I know all of this looks quite verbose, but remember that you usually don't have to get back handles on all your variables, only on a few of them: the inputs, outputs, and train step are usually enough.
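For the collections option, here is a minimal end-to-end sketch under the two-file setup from the question (the file names are hypothetical):

#save_with_collections.py — tag the nodes we want back later
import tensorflow as tf

w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
tf.add_to_collection('weights', w1)
tf.add_to_collection('weights', w2)

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'my_test_model', global_step=1000)

#restore_with_collections.py — a separate file; w1/w2 are never declared here
import tensorflow as tf

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('my_test_model-1000.meta', clear_devices=True)
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    # collections travel inside the metagraph, so this lookup works across files
    w1, w2 = tf.get_collection('weights')
    print(sess.run([w1, w2]))

The same pattern extends naturally to the inputs, outputs, and train step mentioned above.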

