How to restore weights with different names but the same shapes in TensorFlow?

Question

I have multiple architectures in Tensorflow. Some of them share the design of certain parts.

I would like to train one of the networks and use the trained weights of the similar layers in another network.

At this point, I am able to save the weights I want and reload them into an architecture that uses exactly the same naming convention for the variables.

However, when the weights have different names in the two networks, restoring fails. I have this naming convention for the first network:

  • selector_network/c2w/var1

In the second network, I have this:

  • joint_network/c2w/var1

Apart from that, the variables have the same shapes. Is there a way to change the names when reloading, or to tell TensorFlow which variables the stored values should go into?

I found this script from @batzner that renames the variables of a TensorFlow checkpoint: tensorflow_rename_variables.

It is not working. I get the following error:

ValueError: Couldn't find 'checkpoint' file or checkpoints in given directory ./joint_pos_tagger_lemmatizer/fi/
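
That error most likely means the directory lacks the `checkpoint` state file. When given a directory, TensorFlow finds the latest checkpoint through that small text file, which `tf.train.Saver.save` writes next to the `<prefix>.index` and `<prefix>.data-*` files of the V2 checkpoint format. A quick way to see what the directory actually contains is sketched below; `inspect_checkpoint_dir` is a hypothetical helper that only looks at file names (not their contents) and does not need TensorFlow.

```python
import os


def inspect_checkpoint_dir(ckpt_dir):
    """Hypothetical helper: report what a TF1-style checkpoint directory holds.

    Only file names are inspected; file contents are not validated.
    """
    if not os.path.isdir(ckpt_dir):
        # Wrong path, or a path relative to a different working directory.
        return {"exists": False, "has_state_file": False, "prefixes": []}
    files = os.listdir(ckpt_dir)
    # The `checkpoint` state file records which checkpoint prefix is the latest.
    has_state_file = "checkpoint" in files
    # Every V2 checkpoint has a `<prefix>.index` file; collect those prefixes.
    prefixes = sorted(f[: -len(".index")] for f in files if f.endswith(".index"))
    return {"exists": True, "has_state_file": has_state_file, "prefixes": prefixes}
```

For example, `inspect_checkpoint_dir("./joint_pos_tagger_lemmatizer/fi/")`. If `has_state_file` is `False` while `prefixes` is non-empty, the checkpoint data is there but the pointer to it is missing; passing the full checkpoint prefix instead of the directory (where the tool accepts one) may get past the error.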

Recommended answer

tf.train.Saver has built-in support for this: pass a dictionary as the var_list argument. The dictionary maps the names of the variables stored in the checkpoint file to the variables in your current graph that should receive those values.

If you want to restore your "joint network" from a checkpoint of your "selector network", you can do it like this:

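import tensorflow as tf

# var1 is the joint-network variable you want to restore (joint_network/c2w/var1)
saver = tf.train.Saver(var_list={'selector_network/c2w/var1': var1})
saver.restore(sess, checkpoint_path)  # placeholders: your tf.Session and the selector checkpoint prefix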

If you want to restore more variables, you simply have to extend the dictionary.
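
When every variable under one scope should be loaded from the matching variable under another scope, the dictionary can be built from the graph instead of by hand. The sketch below assumes, as in the question, that only the top-level scope differs (`joint_network/...` in the graph vs. `selector_network/...` in the checkpoint). `build_var_list` is a hypothetical helper; it relies only on each variable exposing a `.name` such as `joint_network/c2w/var1:0` and a `.shape.as_list()`. The optional `ckpt_shapes` argument reports a shape mismatch before `restore` is even called.

```python
def build_var_list(variables, model_scope, ckpt_scope, ckpt_shapes=None):
    """Hypothetical helper: build a Saver var_list {checkpoint_name: variable}.

    variables:   iterable of variable-like objects whose `.name` looks like
                 'joint_network/c2w/var1:0' and whose `.shape.as_list()`
                 returns the shape (as TF1 tf.Variable objects do).
    model_scope: top-level scope in the current graph, e.g. 'joint_network'.
    ckpt_scope:  top-level scope in the checkpoint, e.g. 'selector_network'.
    ckpt_shapes: optional {checkpoint_name: shape}, e.g. built with
                 dict(tf.train.list_variables(checkpoint_path)).
    """
    prefix = model_scope.rstrip("/") + "/"
    var_list = {}
    for var in variables:
        name = var.name.split(":")[0]  # strip the ':0' output index
        if not name.startswith(prefix):
            continue  # not part of the shared sub-network; skip it
        ckpt_name = ckpt_scope.rstrip("/") + "/" + name[len(prefix):]
        if ckpt_shapes is not None:
            if ckpt_name not in ckpt_shapes:
                raise KeyError(f"{ckpt_name!r} is not in the checkpoint")
            ckpt_shape = list(ckpt_shapes[ckpt_name])
            model_shape = var.shape.as_list()
            if ckpt_shape != model_shape:
                raise ValueError(
                    f"shape mismatch for {ckpt_name!r}: "
                    f"checkpoint {ckpt_shape} vs. graph {model_shape}"
                )
        var_list[ckpt_name] = var
    return var_list
```

In a TF1 graph this could be used as `tf.train.Saver(var_list=build_var_list(tf.trainable_variables(), 'joint_network', 'selector_network'))`, followed by `saver.restore(sess, checkpoint_path)`. Passing `tf.trainable_variables()` rather than `tf.global_variables()` keeps non-trainable bookkeeping such as optimizer slot variables out of the mapping.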
