tensorflow dataset shuffle then batch or batch then shuffle


Question

I recently began learning tensorflow.

I am not sure whether there is a difference between

import numpy as np
import tensorflow as tf

x = np.array([[1],[2],[3],[4],[5]])
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(buffer_size=4)
dataset = dataset.batch(4)

and

x = np.array([[1],[2],[3],[4],[5]])
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.batch(4)
dataset = dataset.shuffle(buffer_size=4)

Also, I am not sure why I cannot use

dataset = dataset.shuffle_batch(buffer_size=2,batch_size=BATCH_SIZE)

because it gives the error

dataset = dataset.shuffle_batch(buffer_size=2,batch_size=BATCH_SIZE)
AttributeError: 'TensorSliceDataset' object has no attribute 'shuffle_batch'

Thanks!

Answer

TL;DR: Yes, there is a difference. Almost always, you will want to call Dataset.shuffle() before Dataset.batch(). There is no shuffle_batch() method on the tf.data.Dataset class, and you must call the two methods separately to shuffle and batch a dataset.
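
For instance, the two calls are usually chained, as in this minimal sketch (the buffer and batch sizes are illustrative). Note that every tf.data.Dataset transformation returns a new dataset rather than modifying the one it is called on, so the result must be reassigned or chained:

import numpy as np
import tensorflow as tf

tf.enable_eager_execution()  # TF 1.x; in TF 2.x datasets are iterable by default.

x = np.array([[1], [2], [3], [4], [5]])

# shuffle() before batch(), with each transformation's result chained
# into the next rather than discarded.
dataset = (tf.data.Dataset.from_tensor_slices(x)
           .shuffle(buffer_size=5)  # buffer covers the whole dataset here
           .batch(4))

for batch in dataset:
    print(batch)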

The transformations of a tf.data.Dataset are applied in the same sequence that they are called. Dataset.batch() combines consecutive elements of its input into a single, batched element in the output. We can see the effect of the order of operations by considering the following two datasets:

import tensorflow as tf

tf.enable_eager_execution()  # To simplify the example code.

# Batch before shuffle.
dataset = tf.data.Dataset.from_tensor_slices([0, 0, 0, 1, 1, 1, 2, 2, 2])
dataset = dataset.batch(3)
dataset = dataset.shuffle(9)

for elem in dataset:
  print(elem)

# Prints:
# tf.Tensor([1 1 1], shape=(3,), dtype=int32)
# tf.Tensor([2 2 2], shape=(3,), dtype=int32)
# tf.Tensor([0 0 0], shape=(3,), dtype=int32)

# Shuffle before batch.
dataset = tf.data.Dataset.from_tensor_slices([0, 0, 0, 1, 1, 1, 2, 2, 2])
dataset = dataset.shuffle(9)
dataset = dataset.batch(3)

for elem in dataset:
  print(elem)

# Prints:
# tf.Tensor([2 0 2], shape=(3,), dtype=int32)
# tf.Tensor([2 1 0], shape=(3,), dtype=int32)
# tf.Tensor([0 1 1], shape=(3,), dtype=int32)

In the first version (batch before shuffle), the elements of each batch are 3 consecutive elements from the input; whereas in the second version (shuffle before batch), they are randomly sampled from the input. Typically, when training by (some variant of) mini-batch stochastic gradient descent, the elements of each batch should be sampled as uniformly as possible from the total input. Otherwise, it is possible that the network will overfit to whatever structure was in the input data, and the resulting network will not achieve as high an accuracy.
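
To make this concrete, shuffle-before-batch is typically combined with the rest of a training input pipeline. The following is a minimal sketch, not taken from the question; the element values, buffer size, batch size, and epoch count are all illustrative, and shuffle()'s buffer should generally be as large as memory allows, up to the full dataset size:

import numpy as np
import tensorflow as tf

tf.enable_eager_execution()

features = np.arange(100, dtype=np.int32)

dataset = (tf.data.Dataset.from_tensor_slices(features)
           .shuffle(buffer_size=100)  # shuffle individual elements first...
           .batch(16)                 # ...then group them into mini-batches
           .repeat(2)                 # two passes over the data; reshuffled each pass
           .prefetch(1))              # overlap input preparation with consumption

for step, batch in enumerate(dataset):
    # A real training step would consume `batch` here.
    print(step, batch.shape)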

