Epoch counter with TensorFlow Dataset API


Problem Description

I'm changing my TensorFlow code from the old queue interface to the new Dataset API. In my old code I kept track of the epoch count by incrementing a tf.Variable every time a new input tensor is accessed and processed in the queue. I'd like to have this epoch count with the new Dataset API, but I'm having some trouble making it work.

Since I'm producing a variable amount of data items in the pre-processing stage, it is not a simple matter of incrementing a (Python) counter in the training loop - I need to compute the epoch count with respect to the input of the queues or Dataset.

I mimicked what I had before with the old queue system, and here is what I ended up with for the Dataset API (simplified example):

import tensorflow as tf

with tf.Graph().as_default():

    data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
    input_tensors = (data,)

    epoch_counter = tf.Variable(initial_value=0.0, dtype=tf.float32,
                                trainable=False)

    def pre_processing_func(data_):
        data_size = tf.constant(0.1, dtype=tf.float32)
        epoch_counter_op = tf.assign_add(epoch_counter, data_size)
        with tf.control_dependencies([epoch_counter_op]):
            # normally I would do data-augmentation here
            results = (tf.expand_dims(data_, axis=0),)
            return tf.data.Dataset.from_tensor_slices(results)

    dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
    dataset = dataset_source.flat_map(pre_processing_func)
    dataset = dataset.repeat()
    # ... do something with 'dataset' and print
    # the value of 'epoch_counter' every once a while

However, this doesn't work. It crashes with a cryptic error message:

 TypeError: In op 'AssignAdd', input types ([tf.float32, tf.float32])
 are not compatible with expected types ([tf.float32_ref, tf.float32])

Closer inspection shows that the epoch_counter variable might not be accessible within the pre_processing_func at all. Does it live in a different graph perhaps?

Any idea how to fix the above example? Or how to get the epoch counter (with decimal points, e.g. 0.4 or 2.9) through some other means?
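To pin down what such a fractional counter means: with 10 items per epoch, each consumed item advances the count by 0.1 (the `data_size` constant in the snippet above). A framework-free sketch of that arithmetic (the helper name is illustrative, not part of any TensorFlow API):

```python
def fractional_epoch(items_consumed: int, items_per_epoch: int) -> float:
    """Epoch count with decimal points, e.g. 4 of 10 items consumed -> 0.4."""
    return items_consumed / items_per_epoch

print(fractional_epoch(4, 10))   # -> 0.4
print(fractional_epoch(29, 10))  # -> 2.9
```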

Solution

TL;DR: Replace the definition of epoch_counter with the following:

epoch_counter = tf.get_variable("epoch_counter", initializer=0.0,
                                trainable=False, use_resource=True)


There are some limitations on using TensorFlow variables inside tf.data.Dataset transformations. The principal limitation is that all variables must be "resource variables" rather than the older "reference variables"; unfortunately, tf.Variable still creates reference variables for backwards-compatibility reasons.

Generally speaking, I wouldn't recommend using variables in a tf.data pipeline if it's possible to avoid it. For example, you might be able to use Dataset.range() to define an epoch counter, and then do something like:

epoch_counter = tf.data.Dataset.range(NUM_EPOCHS)
dataset = epoch_counter.flat_map(lambda i: tf.data.Dataset.zip(
    (pre_processing_func(data), tf.data.Dataset.from_tensors(i).repeat())))

The above snippet attaches an epoch counter to every value as a second component.
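A plain-Python sketch of the same pairing idea, with a generator standing in for `Dataset.range`, `flat_map`, and `zip` (the function name and toy data are illustrative only):

```python
def with_epoch_index(num_epochs, source):
    """Pair every pre-processed item with the epoch it belongs to,
    mirroring Dataset.range(NUM_EPOCHS).flat_map(zip(...)) above."""
    for epoch in range(num_epochs):
        for item in source:  # stand-in for pre_processing_func(data)
            yield item, epoch

pairs = list(with_epoch_index(2, ["a", "b"]))
# -> [('a', 0), ('b', 0), ('a', 1), ('b', 1)]
```

Because the epoch index rides along as a second component of each element, the training loop can read it off the fetched value instead of querying a separate variable.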
