TensorFlow: Restoring variables from multiple checkpoints


Problem Description

I have the following situation:

  • I have 2 models written in 2 separate scripts:

Model A consists of variables a1, a2, and a3, and is written in A.py

Model B consists of variables b1, b2, and b3, and is written in B.py

In each of A.py and B.py, I have a tf.train.Saver that saves the checkpoint of all the local variables, and let's call the checkpoint files ckptA and ckptB respectively.

I now want to make a model C that uses a1 and b1. I can make it so that the exact same variable name for a1 is used in both A and C by using the var_scope (and the same for b1).

The question is how might I load a1 and b1 from ckptA and ckptB into model C? For example, would the following work?

saver.restore(session, ckptA_location)
saver.restore(session, ckptB_location)

Would an error be raised if you try to restore the same session twice? Would it complain that there are no allocated "slots" for the extra variables (b2, b3, a2, a3), or would it simply restore the variables it can, and only complain if there are some other variables in C that are uninitialized?

I'm trying to write some code to test this now but I would love to see a canonical approach to this problem, because one encounters this often when trying to re-use some pre-trained weights.

Thanks!

Answer

You would get a tf.errors.NotFoundError if you tried to use a saver (by default representing all six variables) to restore from a checkpoint that does not contain all of the variables that the saver represents. (Note however that you are free to call Saver.restore() multiple times in the same session, for any subset of the variables, as long as all of the requested variables are present in the corresponding file.)
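A minimal runnable sketch of that behavior, using the TF1-style API via `tf.compat.v1` (the variable names and checkpoint path here are illustrative, not from the original question):

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

ckpt_a = os.path.join(tempfile.mkdtemp(), "ckptA")

# Stand-in for A.py: save a checkpoint that contains only a1.
with tf.Graph().as_default():
    a1 = tf.get_variable("a1", initializer=1.0)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, ckpt_a)

# Stand-in for C.py: a graph containing both a1 and b1.
with tf.Graph().as_default():
    a1 = tf.get_variable("a1", initializer=0.0)
    b1 = tf.get_variable("b1", initializer=0.0)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        try:
            # A saver over all variables looks for b1 in ckptA and fails.
            tf.train.Saver().restore(sess, ckpt_a)
        except tf.errors.NotFoundError:
            print("full saver: NotFoundError")
        # A saver over just a1 restores cleanly.
        tf.train.Saver([a1]).restore(sess, ckpt_a)
        print(sess.run(a1))  # the value saved by the first graph
```

The full saver fails because `b1` has no entry in `ckptA`; the subset saver over `[a1]` succeeds, which is exactly why the canonical approach below uses one saver per checkpoint.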

The canonical approach is to define two separate tf.train.Saver instances covering each subset of variables that is entirely contained in a single checkpoint. For example:

saver_a = tf.train.Saver([a1])
saver_b = tf.train.Saver([b1])

saver_a.restore(session, ckptA_location)
saver_b.restore(session, ckptB_location)

Depending on how your code is built, if you have pointers to tf.Variable objects called a1 and b1 in the local scope, you can stop reading here.

On the other hand, if variables a1 and b1 are defined in separate files, you might need to do something creative to retrieve pointers to those variables. Although it's not ideal, what people typically do is to use a common prefix, for example as follows (assuming the variable names are "a1:0" and "b1:0" respectively):

saver_a = tf.train.Saver([v for v in tf.all_variables() if v.name == "a1:0"])
saver_b = tf.train.Saver([v for v in tf.all_variables() if v.name == "b1:0"])

One final note: you don't have to make heroic efforts to ensure that the variables have the same names in A and C. You can pass a name-to-Variable dictionary as the first argument to the tf.train.Saver constructor, and thereby remap names in the checkpoint file to Variable objects in your code. This helps if A.py and B.py have similarly-named variables, or if in C.py you want to organize the model code from those files in a tf.name_scope().
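A hedged sketch of that remapping (the checkpoint entry name `"a1"` and the scope `"model_c"` are illustrative; the dictionary keys are names in the checkpoint, the values are `Variable` objects in the current graph):

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

ckpt_a = os.path.join(tempfile.mkdtemp(), "ckptA")

# Save a checkpoint whose entry is named "a1".
with tf.Graph().as_default():
    a1 = tf.get_variable("a1", initializer=tf.constant([1.0, 2.0]))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        tf.train.Saver().save(sess, ckpt_a)

# In "C.py" the variable lives under a different scope, so map the
# checkpoint name "a1" onto the differently named Variable object.
with tf.Graph().as_default():
    with tf.variable_scope("model_c"):
        c_a1 = tf.get_variable("a1", shape=[2])  # full name: "model_c/a1"
    saver_a = tf.train.Saver({"a1": c_a1})
    with tf.Session() as sess:
        saver_a.restore(sess, ckpt_a)
        print(sess.run(c_a1))  # the values saved under the name "a1"
```

The same dictionary form also disambiguates collisions: if `A.py` and `B.py` both define a variable named `w`, each saver's dictionary can map the checkpoint's `"w"` to the correct scoped variable in C's graph.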
