TensorFlow:按名称获取变量 [英] TensorFlow: getting variable by name
问题描述
在使用 TensorFlow Python API 时,我创建了一个变量(在构造函数中没有指定它的 name
),它的 name
属性的值为 "Variable_23:0"
.当我尝试使用 tf.get_variable("Variable23")
选择此变量时,会创建一个名为 "Variable_23_1:0"
的新变量.如何正确选择 "Variable_23"
而不是创建一个新的?
我想要做的是按名称选择变量,然后重新初始化它,以便我可以微调权重.
get_variable()
函数创建一个新变量或返回一个之前由 get_variable()
创建的变量.它不会返回使用 tf.Variable()
创建的变量.这是一个简单的例子:
如果您没有使用 tf.get_variable()
创建变量,您有几个选择.首先,您可以使用 tf.global_variables()
(如@mrry 建议的那样):
或者你可以像这样使用 tf.get_collection()
:
编辑
你也可以使用get_tensor_by_name()
:
回想一下张量是操作的输出.它与操作同名,加上 :0
.如果操作有多个输出,则它们具有与操作相同的名称加上 :0
、:1
、:2
等.
When using the TensorFlow Python API, I created a variable (without specifying its name
in the constructor), and its name
property had the value "Variable_23:0"
. When I try to select this variable using tf.get_variable("Variable23")
, a new variable called "Variable_23_1:0"
is created instead. How do I correctly select "Variable_23"
instead of creating a new one?
What I want to do is select the variable by name, and reinitialize it so I can finetune weights.
The get_variable()
function creates a new variable or returns one created earlier by get_variable()
. It won't return a variable created using tf.Variable()
. Here's a quick example:
>>> with tf.variable_scope("foo"):
... bar1 = tf.get_variable("bar", (2,3)) # create
...
>>> with tf.variable_scope("foo", reuse=True):
... bar2 = tf.get_variable("bar") # reuse
...
>>> with tf.variable_scope("", reuse=True): # root variable scope
... bar3 = tf.get_variable("foo/bar") # reuse (equivalent to the above)
...
>>> (bar1 is bar2) and (bar2 is bar3)
True
If you did not create the variable using tf.get_variable()
, you have a couple options. First, you can use tf.global_variables()
(as @mrry suggests):
>>> bar1 = tf.Variable(0.0, name="bar")
>>> bar2 = [var for var in tf.global_variables() if var.op.name=="bar"][0]
>>> bar1 is bar2
True
Or you can use tf.get_collection()
like so:
>>> bar1 = tf.Variable(0.0, name="bar")
>>> bar2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="bar")[0]
>>> bar1 is bar2
True
Edit
You can also use get_tensor_by_name()
:
>>> bar1 = tf.Variable(0.0, name="bar")
>>> graph = tf.get_default_graph()
>>> bar2 = graph.get_tensor_by_name("bar:0")
>>> bar1 is bar2
False, bar2 is a Tensor througn convert_to_tensor on bar1. but bar1 equal
bar2 in value.
Recall that a tensor is the output of an operation. It has the same name as the operation, plus :0
. If the operation has multiple outputs, they have the same name as the operation plus :0
, :1
, :2
, and so on.
这篇关于TensorFlow:按名称获取变量的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!