获取 TensorFlow 训练的模型中某些权重的值 [英] Get the value of some weights in a model trained by TensorFlow
问题描述
我已经用 TensorFlow 训练了一个 ConvNet 模型,我想在层中获得一个特定的权重.例如,在 torch7 中,我只需访问 model.modules[2].weights
.获取第 2 层的权重.我将如何在 TensorFlow 中做同样的事情?
I have trained a ConvNet model with TensorFlow, and I want to get a particular weight in layer. For example in torch7 I would simply access model.modules[2].weights
. to get the weights of layer 2. How would I do the same thing in TensorFlow?
推荐答案
在 TensorFlow 中,经过训练的权重由 tf.Variable
对象.如果你创建了一个 tf.Variable
—e.g.称为 v
—你自己,你可以通过调用 sess.run(v)
(其中 sess
是 tf.Session
).
In TensorFlow, trained weights are represented by tf.Variable
objects. If you created a tf.Variable
—e.g. called v
—yourself, you can get its value as a NumPy array by calling sess.run(v)
(where sess
is a tf.Session
).
如果您当前没有指向 tf.Variable
的指针,您可以通过调用 tf.trainable_variables()
.此函数返回当前图中所有可训练的 tf.Variable
对象的列表,您可以通过匹配 v.name
属性来选择所需的对象.例如:
If you do not currently have a pointer to the tf.Variable
, you can get a list of the trainable variables in the current graph by calling tf.trainable_variables()
. This function returns a list of all trainable tf.Variable
objects in the current graph, and you can select the one that you want by matching the v.name
property. For example:
# Desired variable is called "tower_2/filter:0".
var = [v for v in tf.trainable_variables() if v.name == "tower_2/filter:0"][0]
这篇关于获取 TensorFlow 训练的模型中某些权重的值的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!