TensorFlow: Restoring variables from multiple checkpoints
Question
I have the following situation:

I have 2 models written in 2 separate scripts:
Model A consists of variables a1, a2, and a3, and is written in A.py
Model B consists of variables b1, b2, and b3, and is written in B.py
In each of A.py and B.py, I have a tf.train.Saver that saves the checkpoint of all the local variables; let's call the checkpoint files ckptA and ckptB respectively.
I now want to make a model C that uses a1 and b1. I can make it so that the exact same variable name for a1 is used in both A and C by using the var_scope (and the same for b1).
The question is how might I load a1 and b1 from ckptA and ckptB into model C? For example, would the following work?
saver.restore(session, ckptA_location)
saver.restore(session, ckptB_location)
Would an error be raised if you try to restore into the same session twice? Would it complain that there are no allocated "slots" for the extra variables (b2, b3, a2, a3), or would it simply restore the variables it can, and only complain if there are some other variables in C that are uninitialized?
I'm trying to write some code to test this now, but I would love to see a canonical approach to this problem, because one encounters it often when trying to re-use some pre-trained weights.
Thanks!
Answer
You would get a tf.errors.NotFoundError if you tried to use a saver (by default representing all six variables) to restore from a checkpoint that does not contain all of the variables that the saver represents. (Note, however, that you are free to call Saver.restore() multiple times in the same session, for any subset of the variables, as long as all of the requested variables are present in the corresponding file.)
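As a quick self-contained check of that behavior (the graph and variable names here are made up for illustration, and the code is written against the tf.compat.v1 API so it also runs under TensorFlow 2; in the TF 1.x code of this answer you would use tf directly):

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # TF1-style API, as used in this answer
tf.disable_eager_execution()

ckpt_a = os.path.join(tempfile.mkdtemp(), "ckptA")

# Save a checkpoint containing only a1 (standing in for A.py).
with tf.Graph().as_default(), tf.Session() as sess:
    a1 = tf.get_variable("a1", initializer=tf.constant([1.0, 2.0]))
    sess.run(tf.global_variables_initializer())
    tf.train.Saver([a1]).save(sess, ckpt_a)

# A graph with an extra variable: a saver over ALL variables fails on ckptA...
with tf.Graph().as_default(), tf.Session() as sess:
    a1 = tf.get_variable("a1", initializer=tf.zeros([2]))
    b1 = tf.get_variable("b1", initializer=tf.zeros([2]))
    try:
        tf.train.Saver().restore(sess, ckpt_a)  # covers both a1 and b1
        raised = False
    except tf.errors.NotFoundError:
        raised = True  # b1 is not present in ckptA

    # ...but a saver over just the subset present in the checkpoint works.
    tf.train.Saver([a1]).restore(sess, ckpt_a)
    restored_a1 = sess.run(a1)
```

Note that after the partial restore, b1 is still uninitialized and would have to be initialized (or restored from ckptB) before it is read.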
The canonical approach is to define two separate tf.train.Saver instances covering each subset of variables that is entirely contained in a single checkpoint. For example:
saver_a = tf.train.Saver([a1])
saver_b = tf.train.Saver([b1])
saver_a.restore(session, ckptA_location)
saver_b.restore(session, ckptB_location)
Depending on how your code is built, if you have pointers to tf.Variable objects called a1 and b1 in the local scope, you can stop reading here.
On the other hand, if variables a1 and b1 are defined in separate files, you might need to do something creative to retrieve pointers to those variables. Although it's not ideal, what people typically do is use a common prefix, for example as follows (assuming the variable names are "a1:0" and "b1:0" respectively):
saver_a = tf.train.Saver([v for v in tf.all_variables() if v.name == "a1:0"])
saver_b = tf.train.Saver([v for v in tf.all_variables() if v.name == "b1:0"])
One final note: you don't have to make heroic efforts to ensure that the variables have the same names in A and C. You can pass a name-to-Variable dictionary as the first argument to the tf.train.Saver constructor, and thereby remap names in the checkpoint file to Variable objects in your code. This helps if A.py and B.py have similarly named variables, or if in C.py you want to organize the model code from those files in a tf.name_scope().
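A minimal sketch of that remapping (the scope and variable names here are invented for illustration, and the tf.compat.v1 API is used so the snippet runs under TensorFlow 2):

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # TF1-style API, matching this answer
tf.disable_eager_execution()

ckpt_a = os.path.join(tempfile.mkdtemp(), "ckptA")

# Stand-in for A.py: save a variable whose checkpoint key is "a1".
with tf.Graph().as_default(), tf.Session() as sess:
    a1 = tf.get_variable("a1", initializer=tf.constant([1.0, 2.0, 3.0]))
    sess.run(tf.global_variables_initializer())
    tf.train.Saver([a1]).save(sess, ckpt_a)

# Stand-in for C.py: the same weights live under a different name here.
with tf.Graph().as_default(), tf.Session() as sess:
    with tf.variable_scope("model_c"):
        c_a1 = tf.get_variable("weights", initializer=tf.zeros([3]))
    # Dict keys are names in the checkpoint; values are Variables in this graph.
    saver = tf.train.Saver({"a1": c_a1})
    saver.restore(sess, ckpt_a)
    restored = sess.run(c_a1)  # holds the values saved by the "A.py" graph
```

Because the saver is told explicitly which checkpoint key feeds which Variable, no name agreement between the two graphs is needed.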