TensorFlow: "Cannot capture a stateful node by value" in tf.contrib.data API


Problem description

For transfer learning, one often uses a network as a feature extractor to create a dataset of features, on which another classifier is trained (e.g. an SVM).

I want to implement this using the Dataset API (tf.contrib.data) and dataset.map():

# feature_extractor will create a CNN on top of the given tensor
def features(feature_extractor, ...):
    dataset = inputs(...)  # This creates a dataset of (image, label) pairs

    def map_example(image, label):
        features = feature_extractor(image, trainable=False)
        #  Leaving out initialization from a checkpoint here... 
        return features, label

    dataset = dataset.map(map_example)

    return dataset

Doing this fails when creating an iterator for the dataset.
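
The failing call itself is not shown in the snippet above; per the answer below, it is the creation of a one-shot iterator from the returned dataset. A rough sketch of that call site (the elided arguments to features() are kept as a placeholder):

# Sketch of the call site; "..." stands for the arguments elided in the question.
dataset = features(feature_extractor, ...)
# Requesting a one-shot iterator is what raises the error:
iterator = dataset.make_one_shot_iterator()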

ValueError: Cannot capture a stateful node by value.

This is true: the kernels and biases of the network are variables and thus stateful. For this particular example, though, they don't have to be.

Is there a way to make Ops, and specifically tf.Variable objects, stateless?

Since I'm using tf.layers I cannot simply create them as constants, and setting trainable=False won't create constants either; it just keeps the variables out of the GraphKeys.TRAINABLE_VARIABLES collection.
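
To illustrate that last point with a minimal sketch (a toy tf.layers call, not the actual feature extractor): the layer still creates variables, and trainable=False only controls which collections they end up in.

import tensorflow as tf

x = tf.zeros([1, 8, 8, 3])
y = tf.layers.conv2d(x, filters=2, kernel_size=3, trainable=False)

# The kernel and bias are still created as tf.Variable objects (stateful nodes):
print(tf.global_variables())     # contains the conv kernel and bias variables
# They are merely excluded from the trainable collection:
print(tf.trainable_variables())  # does not contain them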

Recommended answer

Unfortunately, tf.Variable is inherently stateful. However, this error only arises if you use Dataset.make_one_shot_iterator() to create the iterator.* To avoid the problem, you can instead use Dataset.make_initializable_iterator(), with the caveat that you must also run iterator.initializer on the returned iterator after running the initializer for the tf.Variable objects used in the input pipeline.
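
A minimal sketch of that pattern, with a toy in-memory dataset and a single tf.layers convolution standing in for the question's inputs() and feature_extractor() (TF 1.x graph mode is assumed; tf.data is the name that later replaced tf.contrib.data, and the same methods existed on tf.contrib.data.Dataset):

import numpy as np
import tensorflow as tf

# Toy stand-ins for the question's inputs() dataset of (image, label) pairs.
images = np.random.rand(4, 8, 8, 3).astype(np.float32)
labels = np.arange(4, dtype=np.int64)
dataset = tf.data.Dataset.from_tensor_slices((images, labels))

def map_example(image, label):
    # Creates kernel/bias variables, i.e. stateful nodes, inside the pipeline.
    feats = tf.layers.conv2d(tf.expand_dims(image, 0),
                             filters=2, kernel_size=3, trainable=False)
    return tf.reshape(feats, [-1]), label

dataset = dataset.map(map_example)

# make_initializable_iterator() instead of make_one_shot_iterator():
iterator = dataset.make_initializable_iterator()
features_op, label_op = iterator.get_next()

with tf.Session() as sess:
    # First run the initializers of the variables used in the input pipeline...
    sess.run(tf.global_variables_initializer())
    # ...then initialize the iterator itself.
    sess.run(iterator.initializer)
    print(sess.run([features_op, label_op]))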

* The reason for this limitation is an implementation detail of Dataset.make_one_shot_iterator() and the work-in-progress TensorFlow function (Defun) support that it uses to encapsulate the dataset definition. Since using stateful resources like lookup tables and variables has been more popular than we initially imagined, we're looking into ways to relax this restriction.
