TensorFlow:“无法按值捕获有状态节点"在 tf.contrib.data API 中 [英] TensorFlow: "Cannot capture a stateful node by value" in tf.contrib.data API
问题描述
对于迁移学习,人们经常使用网络作为特征提取器来创建特征数据集,在该数据集上训练另一个分类器(例如 SVM).
For transfer learning, one often uses a network as a feature extractor to create a dataset of features, on which another classifier is trained (e.g. a SVM).
我想使用数据集 API (tf.contrib.data
) 和 dataset.map()
:
I want to implement this using the Dataset API (tf.contrib.data
) and dataset.map()
:
# feature_extractor will create a CNN on top of the given tensor
def features(feature_extractor, ...):
dataset = inputs(...) # This creates a dataset of (image, label) pairs
def map_example(image, label):
features = feature_extractor(image, trainable=False)
# Leaving out initialization from a checkpoint here...
return features, label
dataset = dataset.map(map_example)
return dataset
为数据集创建迭代器时,这样做会失败.
Doing this fails when creating an iterator for the dataset.
ValueError: Cannot capture a stateful node by value.
这是真的,网络的内核和偏差是变量,因此是有状态的.对于这个特定的例子,他们不必是.
This is true, the kernels and biases of the network are variables and thus stateful. For this particular example they don't have to be though.
有没有办法制作 Ops,特别是 tf.Variable代码>
对象无状态?
Is there a way to make Ops and specifically tf.Variable
objects stateless?
因为我使用的是 tf.layers
我不能简单地将它们创建为常量,并且设置 trainable=False
也不会创建常量,只是不会将变量添加到 GraphKeys.TRAINABLE_VARIABLES
集合中.
Since I'm using tf.layers
I cannot simply create them as constants, and setting trainable=False
won't create constants neither but just won't add the variables to the GraphKeys.TRAINABLE_VARIABLES
collection.
推荐答案
不幸的是,tf.Variable
本质上是有状态的.但是,只有在使用 Dataset.make_one_shot_iterator()
创建迭代器时才会出现此错误.* 为避免该问题,您可以改用 Dataset.make_initializable_iterator()
,并使用需要注意的是,在运行输入管道中使用的 tf.Variable
对象的初始化程序之后,您还必须在返回的迭代器上运行 iterator.initializer
.
Unfortunately, tf.Variable
is inherently stateful. However, this error only arises if you use Dataset.make_one_shot_iterator()
to create the iterator.* To avoid the problem, you can instead use Dataset.make_initializable_iterator()
, with the caveat that you must also run iterator.initializer
on the returned iterator after running the initializer for the tf.Variable
objects used in the input pipeline.
* 此限制的原因是 Dataset.make_one_shot_iterator()
的实现细节以及它用于封装数据集定义.由于使用查找表和变量等有状态资源比我们最初想象的更受欢迎,我们正在寻找放宽此限制的方法.
* The reason for this limitation is an implementation detail of Dataset.make_one_shot_iterator()
and the work-in-progress TensorFlow function (Defun
) support that it uses to encapsulate the dataset definition. Since using stateful resources like lookup tables and variables has been more popular than we initially imagined, we're looking into ways to relax this restriction.
这篇关于TensorFlow:“无法按值捕获有状态节点"在 tf.contrib.data API 中的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!