How to restore variables using CheckpointReader in Tensorflow
Question
I'm trying to restore some variables from a checkpoint file if the same variable name is in the current model. And I found that there is a way to do this, as in the Tensorflow GitHub.
So what I want to do is check the variable names in the checkpoint file using has_tensor("variable.name"), as below:
...
reader = tf.train.NewCheckpointReader(ckpt_path)
for v in tf.trainable_variables():
    print(v.name)
    if reader.has_tensor(v.name):
        print('has tensor')
...
But I found that v.name returns both the variable name and a colon+number. For example, if I have variables named W_o and b_o, then v.name returns W_o:0 and b_o:0.
However, reader.has_tensor() requires the name without the colon and number, i.e. W_o, b_o.
My question is: how do I remove the colon and number at the end of the variable name in order to read the variables? Is there a better way to restore such variables?
Recommended answer
You could use string.split() to get the tensor name:
...
reader = tf.train.NewCheckpointReader(ckpt_path)
for v in tf.trainable_variables():
    tensor_name = v.name.split(':')[0]
    print(tensor_name)
    if reader.has_tensor(tensor_name):
        print('has tensor')
...
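Note that the same split also works for variables created inside a name scope: the scope prefix sits before the colon and is part of the key stored in the checkpoint, so it must survive the split. A minimal standalone sketch (the scoped names below are hypothetical, and no TensorFlow is needed to show the string handling):

```python
# Hypothetical tensor names as TF would report them for variables
# created under a name scope called 'layer1'.
names = ['layer1/W_o:0', 'layer1/b_o:0']

# Splitting on the first ':' drops only the output index; the scope
# prefix is kept, matching the checkpoint key.
stripped = [n.split(':')[0] for n in names]
print(stripped)  # prints ['layer1/W_o', 'layer1/b_o']
```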
Next, let me use an example to show how I would restore every possible variable from a .ckpt file. First, let's save v2 and v3 in tmp.ckpt:
import tensorflow as tf

v1 = tf.Variable(tf.ones([1]), name='v1')
v2 = tf.Variable(2 * tf.ones([1]), name='v2')
v3 = tf.Variable(3 * tf.ones([1]), name='v3')
saver = tf.train.Saver({'v2': v2, 'v3': v3})

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    saver.save(sess, 'tmp.ckpt')
That's how I would restore every variable (belonging to a new graph) that shows up in tmp.ckpt:
with tf.Graph().as_default():
    assert len(tf.trainable_variables()) == 0
    v1 = tf.Variable(tf.zeros([1]), name='v1')
    v2 = tf.Variable(tf.zeros([1]), name='v2')
    reader = tf.train.NewCheckpointReader('tmp.ckpt')
    restore_dict = dict()
    for v in tf.trainable_variables():
        tensor_name = v.name.split(':')[0]
        if reader.has_tensor(tensor_name):
            print('has tensor ', tensor_name)
            restore_dict[tensor_name] = v
    saver = tf.train.Saver(restore_dict)
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        saver.restore(sess, 'tmp.ckpt')
        print(sess.run([v1, v2]))  # prints [array([ 0.], dtype=float32), array([ 2.], dtype=float32)]
Also, you may want to ensure that shapes and dtypes match.
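A CheckpointReader also exposes get_variable_to_shape_map(), so one way to guard against shape mismatches is to filter the restore dict through it before building the Saver. The filtering step itself needs nothing from TensorFlow, so it is sketched here as a plain function over name-to-shape dicts; the variable names and shapes below are made up for illustration:

```python
def restorable_names(model_shapes, ckpt_shapes):
    """Names that are safe to restore: present in the checkpoint
    and stored with an identical shape."""
    return sorted(name for name, shape in model_shapes.items()
                  if ckpt_shapes.get(name) == shape)

# e.g. shapes of the current model's variables vs. the dict returned
# by reader.get_variable_to_shape_map()
model = {'v1': [1], 'v2': [1], 'v4': [3]}
ckpt = {'v2': [1], 'v3': [1], 'v4': [5]}  # v4 exists but its shape differs
print(restorable_names(model, ckpt))  # prints ['v2']
```

In the restore loop above, this check would replace the bare has_tensor() test, so a renamed or reshaped variable is silently skipped instead of crashing the Saver.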