TensorFlow, how to reuse a variable scope name

Question

I have defined a class here:

class BasicNetwork(object):
    def __init__(self, scope, task_name, is_train=False, img_shape=(80, 80)):
        self.scope = scope
        self.is_train = is_train
        self.task_name = task_name
        self.__create_network(scope, img_shape=img_shape)

    def __create_network(self, scope, img_shape=(80, 80)):
        with tf.variable_scope(scope):
            with tf.variable_scope(self.task_name):
                with tf.variable_scope('input_data'):
                    self.inputs = tf.placeholder(shape=[None, *img_shape, cfg.HIST_LEN], dtype=tf.float32)
                with tf.variable_scope('networks'):
                    with tf.variable_scope('conv_1'):
                        self.conv_1 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.inputs, num_outputs=32,
                                                  kernel_size=[8, 8], stride=4, padding='SAME', trainable=self.is_train)
                    with tf.variable_scope('conv_2'):
                        self.conv_2 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.conv_1, num_outputs=64,
                                                  kernel_size=[4, 4], stride=2, padding='SAME', trainable=self.is_train)
                    with tf.variable_scope('conv_3'):
                        self.conv_3 = slim.conv2d(activation_fn=tf.nn.relu, inputs=self.conv_2, num_outputs=64,
                                                  kernel_size=[3, 3], stride=1, padding='SAME', trainable=self.is_train)
                    with tf.variable_scope('f_c'):
                        self.fc = slim.fully_connected(slim.flatten(self.conv_3), 512,
                                                       activation_fn=tf.nn.elu, trainable=self.is_train)

I want to define two instances of BasicNetwork with different task names, both under the scope 'global'. But when I check the output, I get:

ipdb> for i in net_1.layres: print(i)
Tensor("global/simple/networks/conv_1/Conv/Relu:0", shape=(?, 20, 20, 32), dtype=float32, device=/device:GPU:2)
Tensor("global/simple/networks/conv_2/Conv/Relu:0", shape=(?, 10, 10, 64), dtype=float32, device=/device:GPU:2)
Tensor("global/simple/networks/conv_3/Conv/Relu:0", shape=(?, 10, 10, 64), dtype=float32, device=/device:GPU:2)
Tensor("global/simple/networks/f_c/fully_connected/Elu:0", shape=(?, 512), dtype=float32, device=/device:GPU:2)

ipdb> for i in net_2.layres: print(i)
Tensor("global_1/supreme/networks/conv_1/Conv/Relu:0", shape=(?, 20, 20, 32), dtype=float32, device=/device:GPU:2)
Tensor("global_1/supreme/networks/conv_2/Conv/Relu:0", shape=(?, 10, 10, 64), dtype=float32, device=/device:GPU:2)
Tensor("global_1/supreme/networks/conv_3/Conv/Relu:0", shape=(?, 10, 10, 64), dtype=float32, device=/device:GPU:2)
Tensor("global_1/supreme/networks/f_c/fully_connected/Elu:0", shape=(?, 512), dtype=float32, device=/device:GPU:2)

As you can see in the output, a new scope global_1 has been created, but I want it to be global. I set reuse=True, but then found that reuse=True cannot be used when no scope named global exists yet. What should I do?

Recommended answer

With reuse=True you can get the existing variables. To reuse variables, they must already exist in the graph. If variables with the same name exist, you can reuse them for other operations.

class BasicNetwork(object):
    def __init__(self, scope, task_name, reuse, is_train=False, img_shape=(80, 80)):
        self.scope = scope
        self.is_train = is_train
        self.reuse = reuse
        self.task_name = task_name
        self.__create_network(scope, reuse=self.reuse, img_shape=img_shape)

    def __create_network(self, scope, reuse=None, img_shape=(80, 80)):
        with tf.variable_scope(scope, reuse=reuse):
            ...
            # Delete the line `with tf.variable_scope(self.task_name):`
            # or replace it with `with tf.name_scope(self.task_name):`

# The first instance creates the variables (reuse=None)
trainnet = BasicNetwork('global', taskname, None)
# Reuse the created variables
valnet = BasicNetwork('global', taskname, True)
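
The global_1 prefix in the question comes from the name scope, not from the variables. A minimal sketch of that difference follows, assuming the TensorFlow 1.x API is available through tensorflow.compat.v1. The variable name 'w' and the op name 'out' are made up for illustration. The last block uses auxiliary_name_scope=False together with original_name_scope, an option the answer above does not mention, to put new ops back under the original global/ prefix.

```python
import tensorflow.compat.v1 as tf  # assumption: TF 1.x API via compat.v1

graph = tf.Graph()
with graph.as_default():
    # First entry: creates the variable and opens the name scope "global/".
    with tf.variable_scope('global') as vs:
        w_train = tf.get_variable('w', shape=[1],
                                  initializer=tf.zeros_initializer())
        out_train = tf.identity(w_train, name='out')

    # Second entry by string with reuse=True: the variable is shared,
    # but new ops are placed in a fresh, uniquified name scope.
    with tf.variable_scope('global', reuse=True):
        w_val = tf.get_variable('w', shape=[1])
        out_val = tf.identity(w_val, name='out')

    # Re-entering through the captured scope object without opening an
    # auxiliary name scope, then restoring the original one by hand.
    with tf.variable_scope(vs, reuse=True, auxiliary_name_scope=False):
        with tf.name_scope(vs.original_name_scope):
            w_same = tf.get_variable('w', shape=[1])
            out_same = tf.identity(w_same, name='out')

# All three lookups resolve to the same variable name.
print(w_train.name, w_val.name, w_same.name)
# Op names differ only in their name-scope prefix.
print(out_train.name, out_val.name, out_same.name)
```

In this sketch the second block shares the variable but its op lands under a global_1/ prefix, which matches the tensor names printed in the question. So the layer outputs of the second network can show global_1 even when the weights themselves are shared.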
