Restore subset of variables in Tensorflow


Problem Description

I am training a Generative Adversarial Network (GAN) in TensorFlow, where basically we have two different networks, each one with its own optimizer.

self.G, self.layer = self.generator(self.inputCT,batch_size_tf)
self.D, self.D_logits = self.discriminator(self.GT_1hot)

...

self.g_optim = tf.train.MomentumOptimizer(self.learning_rate_tensor, 0.9) \
                      .minimize(self.g_loss, global_step=self.global_step)

self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5) \
                      .minimize(self.d_loss, var_list=self.d_vars)

The problem is that I train one of the networks (g) first, and then I want to train g and d together. However, when I call the load function:

self.sess.run(tf.initialize_all_variables())
self.sess.graph.finalize()

self.load(self.checkpoint_dir)

def load(self, checkpoint_dir):
    print(" [*] Reading checkpoints...")

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        self.saver.restore(self.sess, ckpt.model_checkpoint_path)
        return True
    else:
        return False

I get an error like this (with a lot more traceback):

Tensor name "beta2_power" not found in checkpoint files checkpoint/MR2CT.model-96000

I can restore the g network and keep training with that function, but when I want to start d from scratch and g from the stored model, I get that error.

Recommended Answer

To restore a subset of variables, you must create a new tf.train.Saver and pass it a specific list of variables to restore in the optional var_list argument.
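
For the asker's scenario, a minimal TF1-style sketch might look like the following; the 'g_' variable-name prefix is an assumption (adapt it to however your graph actually names the generator's variables), and the checkpoint path is taken from the error message in the question:

import tensorflow as tf

# Collect only the generator's variables. We assume (hypothetically) that
# they are identifiable by a 'g_' name prefix.
g_vars = [v for v in tf.global_variables() if v.name.startswith('g_')]

# This saver saves/restores *only* g_vars, so restore() never looks for the
# Adam bookkeeping variables (e.g. beta2_power) that the generator-only
# checkpoint does not contain.
g_saver = tf.train.Saver(var_list=g_vars)

with tf.Session() as sess:
    # Initialize everything (d and all optimizer state) from scratch first...
    sess.run(tf.global_variables_initializer())
    # ...then overwrite just the generator's weights from the checkpoint.
    g_saver.restore(sess, 'checkpoint/MR2CT.model-96000')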

By default, a tf.train.Saver will create ops that (i) save every variable in your graph when you call saver.save() and (ii) look up (by name) every variable in the given checkpoint when you call saver.restore(). While this works for most common scenarios, you have to provide more information to work with specific subsets of the variables:

  1. If you only want to restore a subset of the variables, you can get a list of these variables by calling tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=G_NETWORK_PREFIX), assuming that you put the "g" network in a common with tf.name_scope(G_NETWORK_PREFIX): or tf.variable_scope(G_NETWORK_PREFIX): block. You can then pass this list to the tf.train.Saver constructor (see the sketch after this list).

  2. If you want to restore a subset of the variables and/or the variables in the checkpoint have different names, you can pass a dictionary as the var_list argument. By default, each variable in a checkpoint is associated with a key, which is the value of its tf.Variable.name property. If the name is different in the target graph (e.g. because you added a scope prefix), you can specify a dictionary that maps string keys (in the checkpoint file) to tf.Variable objects (in the target graph).
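
A short sketch of both options follows; the scope name 'generator' and the extra 'new_model/' prefix are hypothetical placeholders:

import tensorflow as tf

# (1) Restore only a subset: collect the variables created inside a
#     `with tf.variable_scope('generator'):` block and build a saver that
#     only knows about them.
g_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='generator')
subset_saver = tf.train.Saver(var_list=g_vars)

# (2) Restore under different names: if the current graph wraps the old
#     variables in an extra 'new_model/' scope that the checkpoint predates,
#     map each checkpoint key (the old name) to the renamed Variable object.
renamed_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='new_model')
name_map = {v.op.name.replace('new_model/', '', 1): v for v in renamed_vars}
renaming_saver = tf.train.Saver(var_list=name_map)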

