tf.data.Dataset: how to get the dataset size (number of elements in an epoch)?

Question

Let's say I have defined a dataset in this way:

filename_dataset = tf.data.Dataset.list_files("{}/*.png".format(dataset))

How can I get the number of elements inside the dataset (hence, the number of single elements that compose one epoch)?

I know that tf.data.Dataset already knows the size of the dataset, because the repeat() method allows repeating the input pipeline for a specified number of epochs, so there must be a way to get this information.

Answer

tf.data.Dataset.list_files creates a tensor called MatchingFiles:0 (with the appropriate prefix if applicable).

You can evaluate

tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0'))[0]

to get the number of files.
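For instance, a minimal sketch of this evaluation, assuming TF 1.x graph mode (where list_files registers a MatchingFiles op in the default graph) and an illustrative path pattern:

import tensorflow as tf  # assumes TF 1.x graph mode

# list_files adds a MatchingFiles op to the default graph (path pattern is illustrative)
filename_dataset = tf.data.Dataset.list_files("/path/to/dataset/*.png")

# fetch the tensor produced by list_files and take its length
matching_files = tf.get_default_graph().get_tensor_by_name('MatchingFiles:0')
num_files = tf.shape(matching_files)[0]

with tf.Session() as sess:
  print(sess.run(num_files))  # number of matched files, i.e. elements per epoch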

Of course, this works only in simple cases, in particular if you have exactly one sample (or a known number of samples) per image.

In more complex situations, e.g. when you do not know the number of samples in each file, you can only observe the number of samples as an epoch ends.
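As a side note (not part of the original answer), a simple hedged sketch for an un-repeated dataset in TF 1.x is to exhaust one epoch and count the elements as they go by; the answer's own epoch-counter approach follows below:

import tensorflow as tf  # assumes TF 1.x; the dataset must not be repeat()-ed here

# pull elements one by one until the epoch is exhausted
iterator = filename_dataset.make_one_shot_iterator()
next_element = iterator.get_next()

count = 0
with tf.Session() as sess:
  try:
    while True:
      sess.run(next_element)
      count += 1
  except tf.errors.OutOfRangeError:
    pass

print(count)  # number of elements in one epoch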

To observe this, you can watch the number of epochs counted by your Dataset: repeat() creates a member called _count, which counts the number of epochs. By observing it during your iterations, you can spot when it changes and compute your dataset size from there.

This counter may be buried in the hierarchy of Datasets that is created when member functions are called successively, so we have to dig it out like this:

import warnings
import tensorflow as tf

d = my_dataset  # the (possibly wrapped) dataset pipeline you built

# RepeatDataset is not exposed publicly -- derive it from a throwaway dataset instead
RepeatDataset = type(tf.data.Dataset.range(1).repeat())

try:
  # follow the _input_dataset links back through the wrapped datasets
  # until the RepeatDataset is found
  while not isinstance(d, RepeatDataset):
    d = d._input_dataset
except AttributeError:
  warnings.warn('no epoch counter found')
  epoch_counter = None
else:
  epoch_counter = d._count

Note that with this technique the computation of your dataset size is not exact, because the batch during which epoch_counter is incremented typically mixes samples from two successive epochs, so the result is only precise up to your batch length.
