获取 TensorFlow 训练的模型中某些权重的值 [英] Get the value of some weights in a model trained by TensorFlow

查看:63
本文介绍了获取 TensorFlow 训练的模型中某些权重的值的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我已经用 TensorFlow 训练了一个 ConvNet 模型,我想在层中获得一个特定的权重.例如,在 torch7 中,我只需访问 model.modules[2].weights.获取第 2 层的权重.我将如何在 TensorFlow 中做同样的事情?

I have trained a ConvNet model with TensorFlow, and I want to get a particular weight in layer. For example in torch7 I would simply access model.modules[2].weights. to get the weights of layer 2. How would I do the same thing in TensorFlow?

推荐答案

在 TensorFlow 中,经过训练的权重由 tf.Variable 对象.如果你创建了一个 tf.Variable—e.g.称为 v—你自己,你可以通过调用 sess.run(v)(其中 sesstf.Session).

In TensorFlow, trained weights are represented by tf.Variable objects. If you created a tf.Variable—e.g. called v—yourself, you can get its value as a NumPy array by calling sess.run(v) (where sess is a tf.Session).

如果您当前没有指向 tf.Variable 的指针,您可以通过调用 tf.trainable_variables().此函数返回当前图中所有可训练的 tf.Variable 对象的列表,您可以通过匹配 v.name 属性来选择所需的对象.例如:

If you do not currently have a pointer to the tf.Variable, you can get a list of the trainable variables in the current graph by calling tf.trainable_variables(). This function returns a list of all trainable tf.Variable objects in the current graph, and you can select the one that you want by matching the v.name property. For example:

# Desired variable is called "tower_2/filter:0".
var = [v for v in tf.trainable_variables() if v.name == "tower_2/filter:0"][0]

这篇关于获取 TensorFlow 训练的模型中某些权重的值的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆