How to restore pretrained checkpoint for current model in Tensorflow?


Question

I have a pretrained checkpoint, and now I'm trying to restore this pretrained model into the current network. However, the variable names are different. The TensorFlow documentation says to use a dictionary, like:

v2 = tf.get_variable("v2", [5], initializer=tf.zeros_initializer)
saver = tf.train.Saver({"v2": v2})  # key: the variable's name in the checkpoint, value: the variable in the current graph

However, the variables in the current network are defined like:

with tf.variable_scope('a'):
    b=tf.get_variable(......)

So the variable name appears to be a/b. How can I make the dictionary map "v2" to a/b?

Answer

You can use tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) to get a list of all the variables in the current graph. You can also restrict it to a scope:

tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='a')
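
For example, a minimal sketch (assuming the graph from the question, where a/b was created inside scope 'a') that prints the names of the variables in that scope:

for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='a'):
    print(var.op.name)  # prints e.g. 'a/b'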

You can use tf.train.list_variables(ckpt_file) to get a list of all the variables in the checkpoint.
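
For example, a minimal sketch (assuming ckpt_file points at an existing checkpoint) that prints the name and shape of everything stored in it:

for name, shape in tf.train.list_variables(ckpt_file):
    print(name, shape)  # tf.train.list_variables returns (name, shape) pairs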

Suppose you have a variable b in your checkpoint, and you want to load it inside tf.variable_scope('a') under the name a/b. To do that, you just define it:

with tf.variable_scope('a'):
    b=tf.get_variable(......)

and load it with:

saver = tf.train.Saver({'v2': b})

with tf.Session() as sess:
    saver.restore(sess, ckpt_file)
    print(b)

This will output:

<tf.Variable 'a/b:0' shape dtype>

As mentioned before, you can use:

vars_dict = {}
for var_current in tf.global_variables():
    print(var_current)
    print(var_current.op.name) # this gets only name

for var_ckpt in tf.train.list_variables(ckpt):
    print(var_ckpt[0])  # this gets only the name

Once you know the exact names of all the variables, you can assign whatever values you need, provided the variables have the same shape and dtype. So, to build the dict:

vars_dict[var_ckpt[0]] = tf.get_variable(var_current.op.name, shape)  # remember to specify shape, you can always get it from var_current

You can construct this dictionary either explicitly or in any loop you see fit, and then pass it to the saver:

saver = tf.train.Saver(vars_dict)
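
Putting it all together, here is a rough end-to-end sketch. The name ckpt_file and the mapping rule (prepending the scope 'a/' to each checkpoint name) are assumptions for illustration only; adapt them to however the names in your checkpoint relate to the names in your current graph:

import tensorflow as tf

# Map the names of the variables in the current graph to the variable objects.
current_vars = {v.op.name: v for v in tf.global_variables()}

vars_dict = {}
for ckpt_name, ckpt_shape in tf.train.list_variables(ckpt_file):
    # Hypothetical mapping: the checkpoint stored 'b', the current graph calls it 'a/b'.
    graph_name = 'a/' + ckpt_name
    if graph_name in current_vars:
        vars_dict[ckpt_name] = current_vars[graph_name]

saver = tf.train.Saver(vars_dict)
with tf.Session() as sess:
    saver.restore(sess, ckpt_file)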
