TensorFlow:按名称获取变量 [英] TensorFlow: getting variable by name

查看:41
本文介绍了TensorFlow:按名称获取变量的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

在使用 TensorFlow Python API 时,我创建了一个变量(在构造函数中没有指定它的 name),它的 name 属性的值为 "Variable_23:0".当我尝试使用 tf.get_variable("Variable23") 选择此变量时,会创建一个名为 "Variable_23_1:0" 的新变量.如何正确选择 "Variable_23" 而不是创建一个新的?

我想要做的是按名称选择变量,然后重新初始化它,以便我可以微调权重.

解决方案

get_variable() 函数创建一个新变量或返回一个之前由 get_variable() 创建的变量.它不会返回使用 tf.Variable() 创建的变量.这是一个简单的例子:

<预><代码>>>>使用 tf.variable_scope("foo"):... bar1 = tf.get_variable("bar", (2,3)) # 创建...>>>使用 tf.variable_scope("foo", 重用 = True):... bar2 = tf.get_variable("bar") # 重用...>>>with tf.variable_scope("",useuse=True): # 根变量作用域... bar3 = tf.get_variable("foo/bar") # 重用(相当于上面的)...>>>(bar1 是 bar2) 和 (bar2 是 bar3)真的

如果您没有使用 tf.get_variable() 创建变量,您有几个选择.首先,您可以使用 tf.global_variables() (如@mrry 建议的那样):

<预><代码>>>>bar1 = tf.Variable(0.0, name="bar")>>>bar2 = [var for var in tf.global_variables() if var.op.name=="bar"][0]>>>bar1 是 bar2真的

或者你可以像这样使用 tf.get_collection() :

<预><代码>>>>bar1 = tf.Variable(0.0, name="bar")>>>bar2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="bar")[0]>>>bar1 是 bar2真的

编辑

你也可以使用get_tensor_by_name():

<预><代码>>>>bar1 = tf.Variable(0.0, name="bar")>>>图 = tf.get_default_graph()>>>bar2 = graph.get_tensor_by_name("bar:0")>>>bar1 是 bar2错误,bar2 是在 bar1 上通过 convert_to_tensor 生成的张量.但 bar1 相等bar2 的值.

回想一下张量是操作的输出.它与操作同名,加上 :0.如果操作有多个输出,则它们具有与操作相同的名称加上 :0:1:2 等.

When using the TensorFlow Python API, I created a variable (without specifying its name in the constructor), and its name property had the value "Variable_23:0". When I try to select this variable using tf.get_variable("Variable23"), a new variable called "Variable_23_1:0" is created instead. How do I correctly select "Variable_23" instead of creating a new one?

What I want to do is select the variable by name, and reinitialize it so I can finetune weights.

解决方案

The get_variable() function creates a new variable or returns one created earlier by get_variable(). It won't return a variable created using tf.Variable(). Here's a quick example:

>>> with tf.variable_scope("foo"):
...   bar1 = tf.get_variable("bar", (2,3)) # create
... 
>>> with tf.variable_scope("foo", reuse=True):
...   bar2 = tf.get_variable("bar")  # reuse
... 

>>> with tf.variable_scope("", reuse=True): # root variable scope
...   bar3 = tf.get_variable("foo/bar") # reuse (equivalent to the above)
... 
>>> (bar1 is bar2) and (bar2 is bar3)
True

If you did not create the variable using tf.get_variable(), you have a couple options. First, you can use tf.global_variables() (as @mrry suggests):

>>> bar1 = tf.Variable(0.0, name="bar")
>>> bar2 = [var for var in tf.global_variables() if var.op.name=="bar"][0]
>>> bar1 is bar2
True

Or you can use tf.get_collection() like so:

>>> bar1 = tf.Variable(0.0, name="bar")
>>> bar2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="bar")[0]
>>> bar1 is bar2
True

Edit

You can also use get_tensor_by_name():

>>> bar1 = tf.Variable(0.0, name="bar")
>>> graph = tf.get_default_graph()
>>> bar2 = graph.get_tensor_by_name("bar:0")
>>> bar1 is bar2
False, bar2 is a Tensor througn convert_to_tensor on bar1. but bar1 equal 
bar2 in value.

Recall that a tensor is the output of an operation. It has the same name as the operation, plus :0. If the operation has multiple outputs, they have the same name as the operation plus :0, :1, :2, and so on.

这篇关于TensorFlow:按名称获取变量的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆