Restoring a TensorFlow model and viewing variable values

Question

I declared some TensorFlow variables representing weights and biases and updated their values during training before saving them, as shown:

import tensorflow as tf  # TensorFlow 1.x API

n_classes = 2  # hypothetical value; the question does not show it

# 3 x 3 x 3 patches, 1 channel, 32 features to compute.
weights = {'W_conv1':tf.Variable(tf.random_normal([3,3,3,1,32]), name='w_conv1'),
           # 3 x 3 x 3 patches, 32 channels, 64 features to compute.
           'W_conv2':tf.Variable(tf.random_normal([3,3,3,32,64]), name='w_conv2'),
           # 64 features
           'W_fc':tf.Variable(tf.random_normal([32448,1024]), name='w_fc'), # 32448 = ceil(50/2/2) * ceil(50/2/2) * ceil(10/2/2) * 64
           #'W_fc':tf.Variable(tf.random_normal([54080,1024]), name='W_fc'), # 54080 = ceil(50/2/2) * ceil(50/2/2) * ceil(20/2/2) * 64
           'out':tf.Variable(tf.random_normal([1024, n_classes]), name='w_out')}

biases = {'b_conv1':tf.Variable(tf.random_normal([32]), name='b_conv1'),
           'b_conv2':tf.Variable(tf.random_normal([64]), name='b_conv2'),
           'b_fc':tf.Variable(tf.random_normal([1024]), name='b_fc'),
           'out':tf.Variable(tf.random_normal([n_classes]), name='b_out')}

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # some training code

    saver = tf.train.Saver()
    saver.save(sess, 'my-save-dir/my-model-10')
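The 32448 in the W_fc shape is the length of the flattened feature map that feeds the fully connected layer. The conv/pool code is not shown, so the following sketch only checks the arithmetic in the comment, assuming a 50 x 50 x 10 input volume and two max-pool layers with stride 2 and 'SAME' padding (each maps a dimension d to ceil(d / 2)). The helper name flattened_size is hypothetical:

```python
import math


def flattened_size(dims, channels=64, num_pools=2, stride=2):
    """Hypothetical helper: length of the flattened feature map after
    `num_pools` pooling layers with the given stride and 'SAME' padding."""
    size = channels
    for d in dims:
        for _ in range(num_pools):
            d = math.ceil(d / stride)  # 'SAME' padding: out = ceil(in / stride)
        size *= d
    return size


print(flattened_size([50, 50, 10]))  # 13 * 13 * 3 * 64 = 32448
print(flattened_size([50, 50, 20]))  # 13 * 13 * 5 * 64 = 54080
```

With a depth of 20 instead of 10, the same formula gives the 54080 of the commented-out line.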

Then, I tried restoring the model and accessing the variables as shown below:

import tensorflow as tf  # TensorFlow 1.x API

weights = {'W_conv1':tf.Variable(-1.0, validate_shape=False, name='w_conv1'),
           # 3 x 3 x 3 patches, 32 channels, 64 features to compute.
           'W_conv2':tf.Variable(-1.0, validate_shape=False, name='w_conv2'),
           # 64 features
           'W_fc':tf.Variable(-1.0, validate_shape=False, name='w_fc'), # 32448 = ceil(50/2/2) * ceil(50/2/2) * ceil(10/2/2) * 64
           #'W_fc':tf.Variable(tf.random_normal([54080,1024]), name='W_fc'), # 54080 = ceil(50/2/2) * ceil(50/2/2) * ceil(20/2/2) * 64
           'out':tf.Variable(-1.0, validate_shape=False, name='w_out')}

biases = {'b_conv1':tf.Variable(-1.0, validate_shape=False, name='b_conv1'),
           'b_conv2':tf.Variable(-1.0, validate_shape=False, name='b_conv2'),
           'b_fc':tf.Variable(-1.0, validate_shape=False, name='b_fc'),
           'out':tf.Variable(-1.0, validate_shape=False, name='b_out')}

with tf.Session() as sess:
    model_saver = tf.train.import_meta_graph('my-save-dir/my-model-10.meta')
    model_saver.restore(sess, "my-save-dir/my-model-10")
    print("Model restored.")
    print('Initialized')
    print(sess.run(weights['W_conv1']))

However, I got the error "FailedPreconditionError: Attempting to use uninitialized value w_conv1". How can I fix this?

Answer

Here's what happens in your second code snippet: you first create all the variables w_conv1 to b_out, so the default graph is populated with the corresponding nodes. Then you call import_meta_graph(..), which populates the default graph again with all the nodes from the model you saved in your first code snippet. However, for every node it tries to load, a node with the same name already exists (because you created it "by hand" just before). I don't know exactly what happens internally in this case, but looking at the output of tf.global_variables() after the call to import_meta_graph(..) reveals that every variable now exists twice with exactly the same name. So restoring is probably ill-defined: it may restore only one copy of each variable (i.e. half of the variables), leaving the other copy uninitialized, which is why you see this error.
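A minimal reproduction of the second snippet with a single variable, assuming TensorFlow 2.x with the tf.compat.v1 API (tf.disable_v2_behavior() gives TF1-style graphs, sessions and reference variables) and a temporary directory instead of my-save-dir. It records the variable names right after import_meta_graph(..) and, after the restore, asks tf.report_uninitialized_variables() which variables are still uninitialized; the hand-made w_conv1 should be among them, matching the error in the question:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs, sessions and reference variables

# Assumption: a temporary checkpoint location instead of my-save-dir.
ckpt_path = os.path.join(tempfile.mkdtemp(), "my-model-10")

# First snippet, reduced to one variable: build, initialize, save.
with tf.Graph().as_default():
    tf.Variable(tf.ones([2, 2]), name="w_conv1")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, ckpt_path)  # also writes ckpt_path + ".meta"

# Second snippet: create w_conv1 by hand, then import the meta graph on top of it.
with tf.Graph().as_default():
    tf.Variable(-1.0, validate_shape=False, name="w_conv1")
    model_saver = tf.train.import_meta_graph(ckpt_path + ".meta")
    var_names = [v.name for v in tf.global_variables()]
    with tf.Session() as sess:
        model_saver.restore(sess, ckpt_path)
        # Op names of the variables that the restore left uninitialized.
        uninitialized = [n.decode() for n in sess.run(tf.report_uninitialized_variables())]

print(var_names)
print(uninitialized)
```

The exact names printed may vary between TensorFlow versions; what matters is that there are two w_conv1 variables and the one created by hand is still uninitialized after the restore.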

So you have two ways to solve this:

1) Don't use import_meta_graph

import tensorflow as tf  # TensorFlow 1.x API

n_classes = 2  # hypothetical value; must match the one used when saving

weights = {'W_conv1':tf.Variable(tf.random_normal([3,3,3,1,32]), name='w_conv1'),
           # 3 x 3 x 3 patches, 32 channels, 64 features to compute.
           'W_conv2':tf.Variable(tf.random_normal([3,3,3,32,64]), name='w_conv2'),
           # 64 features
           'W_fc':tf.Variable(tf.random_normal([32448,1024]), name='w_fc'), # 32448 = ceil(50/2/2) * ceil(50/2/2) * ceil(10/2/2) * 64
           #'W_fc':tf.Variable(tf.random_normal([54080,1024]), name='W_fc'), # 54080 = ceil(50/2/2) * ceil(50/2/2) * ceil(20/2/2) * 64
           'out':tf.Variable(tf.random_normal([1024, n_classes]), name='w_out')}

biases = {'b_conv1':tf.Variable(tf.random_normal([32]), name='b_conv1'),
           'b_conv2':tf.Variable(tf.random_normal([64]), name='b_conv2'),
           'b_fc':tf.Variable(tf.random_normal([1024]), name='b_fc'),
           'out':tf.Variable(tf.random_normal([n_classes]), name='b_out')}

with tf.Session() as sess:
    # Saver.restore assigns the saved values, so no initializer is needed.
    model_saver = tf.train.Saver()
    model_saver.restore(sess, "my-save-dir/my-model-10")
    print("Model restored.")
    print('Initialized')
    print(sess.run(weights['W_conv1']))
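Option 1 as a self-contained round trip: a sketch assuming TensorFlow 2.x with the tf.compat.v1 API (tf.disable_v2_behavior() restores TF1-style graphs, sessions and reference variables), a temporary checkpoint directory instead of my-save-dir, and a hypothetical build_variables() helper standing in for the weights/biases dicts. Because both scripts build the variables with the same names, Saver.restore can match them up and no initializer is needed:

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs, sessions and reference variables

# Assumption: a temporary checkpoint location instead of my-save-dir.
ckpt_path = os.path.join(tempfile.mkdtemp(), "my-model-10")


def build_variables():
    # Hypothetical two-variable stand-in for the question's weights/biases dicts.
    return {
        "W_conv1": tf.Variable(tf.random_normal([3, 3, 3, 1, 32]), name="w_conv1"),
        "b_conv1": tf.Variable(tf.random_normal([32]), name="b_conv1"),
    }


# Training script: build, initialize, save.
with tf.Graph().as_default():
    weights = build_variables()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saved = sess.run(weights["W_conv1"])
        tf.train.Saver().save(sess, ckpt_path)

# Inspection script: rebuild the same variables by hand and restore them.
with tf.Graph().as_default():
    weights = build_variables()
    with tf.Session() as sess:
        tf.train.Saver().restore(sess, ckpt_path)  # assigns the saved values
        restored = sess.run(weights["W_conv1"])

print(np.array_equal(saved, restored))  # True
```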

2) Use import_meta_graph but don't recreate the graph by hand

So just do this:

import tensorflow as tf  # TensorFlow 1.x API

with tf.Session() as sess:
    model_saver = tf.train.import_meta_graph('my-save-dir/my-model-10.meta')
    model_saver.restore(sess, "my-save-dir/my-model-10")
    print("Model restored.")
    print('Initialized')
    print(sess.run(tf.get_default_graph().get_tensor_by_name('w_conv1:0')))
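Option 2 as a self-contained round trip, assuming TensorFlow 2.x with the tf.compat.v1 API, a temporary checkpoint path instead of my-save-dir, and a single stand-in variable. tf.disable_v2_behavior() is used so that tf.Variable creates TF1-style reference variables, for which the tensor w_conv1:0 evaluates to the variable's value, as the answer's last line relies on:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs, sessions and reference variables

# Assumption: a temporary checkpoint location instead of my-save-dir.
ckpt_path = os.path.join(tempfile.mkdtemp(), "my-model-10")

# Saving script, reduced to one variable.
with tf.Graph().as_default():
    tf.Variable(tf.ones([2, 3]), name="w_conv1")
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, ckpt_path)  # also writes ckpt_path + ".meta"

# Restoring script: nothing is created by hand; the graph comes from the .meta file.
with tf.Graph().as_default() as graph:
    with tf.Session() as sess:
        model_saver = tf.train.import_meta_graph(ckpt_path + ".meta")
        model_saver.restore(sess, ckpt_path)
        w_conv1 = sess.run(graph.get_tensor_by_name("w_conv1:0"))

print(w_conv1)  # the saved 2 x 3 array of ones
```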

Note that in this case you need to change how you retrieve the value of w_conv1 (last line). Instead of calling get_tensor_by_name(), you could also use tf.get_variable(), but for this to work, the variables must already have been created with tf.get_variable(). See this post for more details: TensorFlow: getting variable by name
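A sketch of the tf.get_variable() route, under the same assumptions as above (tf.compat.v1 with tf.disable_v2_behavior(), a temporary checkpoint path, one stand-in variable). It also assumes the graph is rebuilt by hand, as in option 1: the variable is created with tf.get_variable() in both the saving and the restoring graph, and is then looked up by name by reopening the current variable scope with reuse=True. The helpers build_variables() and lookup() are hypothetical:

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF1-style graphs, sessions and reference variables

# Assumption: a temporary checkpoint location instead of my-save-dir.
ckpt_path = os.path.join(tempfile.mkdtemp(), "my-model-10")


def build_variables():
    # Hypothetical builder: tf.get_variable registers w_conv1 in the graph's
    # variable store, so it can later be fetched by name.
    tf.get_variable("w_conv1", shape=[3, 3, 3, 1, 32],
                    initializer=tf.random_normal_initializer())


def lookup(name):
    # Hypothetical helper: reopen the current scope with reuse=True to fetch
    # an existing variable instead of creating a new one.
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        return tf.get_variable(name)


with tf.Graph().as_default():
    build_variables()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saved = sess.run(lookup("w_conv1"))
        tf.train.Saver().save(sess, ckpt_path)

with tf.Graph().as_default():
    build_variables()  # rebuilt by hand, as in option 1
    with tf.Session() as sess:
        tf.train.Saver().restore(sess, ckpt_path)
        restored = sess.run(lookup("w_conv1"))

print(np.array_equal(saved, restored))  # True
```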
