What does batch, repeat, and shuffle do with TensorFlow Dataset?


Question

I'm currently learning TensorFlow, but I'm confused by the following code snippet:

dataset = dataset.shuffle(buffer_size=10 * batch_size)
dataset = dataset.repeat(num_epochs).batch(batch_size)
return dataset.make_one_shot_iterator().get_next()

I know that the dataset will hold all the data, but what do shuffle(), repeat(), and batch() do to it? Please help me with an example and an explanation.

Answer

Update: Here is a small Colab notebook demonstrating this answer.

Imagine you have a dataset: [1, 2, 3, 4, 5, 6]. Then:

How ds.shuffle() works

dataset.shuffle(buffer_size=3) will allocate a buffer of size 3 for picking random entries. This buffer is connected to the source dataset. We can picture it like this:

Random buffer
   |
   |   Source dataset where all other elements live
   |         |
   ↓         ↓
[1,2,3] <= [4,5,6]

Let's assume that entry 2 was taken from the random buffer. The freed slot is filled by the next element from the source, which is 4:

2 <= [1,3,4] <= [5,6]

We continue reading until nothing is left:

1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6]   <= []
6 <= [4]     <= []
4 <= []      <= []
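The buffer mechanics traced above can be sketched in pure Python. This is a simplified model of tf.data's buffered shuffle (the function name and structure are my own, not TensorFlow's implementation):

```python
import random

def buffered_shuffle(source, buffer_size, rng):
    """Simplified model of dataset.shuffle(buffer_size)."""
    it = iter(source)
    buffer = []
    # Fill the buffer with the first `buffer_size` elements.
    for _ in range(buffer_size):
        try:
            buffer.append(next(it))
        except StopIteration:
            break
    while buffer:
        # Pick a random entry from the buffer...
        idx = rng.randrange(len(buffer))
        yield buffer[idx]
        # ...and refill the freed slot from the source, if anything is left.
        try:
            buffer[idx] = next(it)
        except StopIteration:
            del buffer[idx]

data = [1, 2, 3, 4, 5, 6]
shuffled = list(buffered_shuffle(data, buffer_size=3, rng=random.Random(0)))
print(shuffled)  # a permutation of data; the first element is always 1, 2 or 3
```

Because the buffer is first filled with [1, 2, 3], the very first element drawn can only be one of those three, which is exactly the "weak shuffling" effect of a small buffer.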

How ds.repeat() works

As soon as all the entries have been read from the dataset and you try to read the next element, the dataset will raise an end-of-sequence error. That's where ds.repeat() comes into play. It re-initializes the dataset, making it look like this again:

[1,2,3] <= [4,5,6]
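In the same pure-Python model (a hypothetical helper, not the tf.data API), repeat simply re-iterates the source the requested number of times:

```python
import itertools

def repeat(source, count=None):
    """Simplified model of dataset.repeat(count): re-iterate the source
    `count` times, or forever if count is None."""
    epochs = itertools.count() if count is None else range(count)
    for _ in epochs:
        yield from source

data = [1, 2, 3, 4, 5, 6]
print(list(repeat(data, 2)))  # [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6]
```

Note that in real tf.data, when shuffle comes before repeat as in the question's snippet, each repetition is reshuffled, because repeat re-initializes the whole upstream pipeline including the shuffle buffer.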

What ds.batch() produces

ds.batch() takes the first batch_size entries and makes a batch out of them. So a batch size of 3 on our example dataset will produce two batch records:

[2,1,5]
[3,6,4]
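Batching is just grouping consecutive elements. A minimal pure-Python sketch (not the tf.data implementation, which builds tensors with an extra batch dimension):

```python
def batch(source, batch_size):
    """Simplified model of dataset.batch(batch_size): group consecutive
    elements into lists of length batch_size; the last batch may be smaller."""
    current = []
    for item in source:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        # tf.data also keeps the smaller remainder batch
        # unless drop_remainder=True is passed.
        yield current

print(list(batch([2, 1, 5, 3, 6, 4], 3)))  # [[2, 1, 5], [3, 6, 4]]
```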

Since we have a ds.repeat() before the batch, the generation of data will continue past the first epoch. But the order of the elements will be different, due to ds.shuffle(). One thing to keep in mind: because of the size of the shuffle buffer, 6 can never appear in the first batch.
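That last claim can be checked with a small pure-Python simulation of the whole shuffle → repeat → batch pipeline (a sketch of the buffering behaviour under assumed helper names, not the actual tf.data implementation):

```python
import itertools
import random

_END = object()  # sentinel marking an exhausted source

def buffered_shuffle(source, buffer_size, rng):
    """Model of dataset.shuffle(): draw randomly from a small buffer."""
    it = iter(source)
    buf = list(itertools.islice(it, buffer_size))
    while buf:
        i = rng.randrange(len(buf))
        yield buf[i]
        nxt = next(it, _END)
        if nxt is _END:
            del buf[i]
        else:
            buf[i] = nxt

def pipeline(data, buffer_size, batch_size, num_epochs, rng):
    """Model of shuffle(...).repeat(num_epochs).batch(batch_size)."""
    # repeat: re-run a freshly shuffled epoch num_epochs times
    stream = itertools.chain.from_iterable(
        buffered_shuffle(data, buffer_size, rng) for _ in range(num_epochs))
    # batch: group consecutive elements
    while True:
        b = list(itertools.islice(stream, batch_size))
        if not b:
            return
        yield b

data = [1, 2, 3, 4, 5, 6]
for seed in range(100):
    batches = list(pipeline(data, 3, 3, 2, random.Random(seed)))
    assert 6 not in batches[0]  # 6 cannot enter the buffer before 3 draws
    assert 6 in batches[1]      # so it always lands in the second batch
print("6 never appeared in the first batch across 100 seeds")
```

The first batch consists of the first three draws, and the buffer can only contain elements 1 through 5 by the time the third draw happens, so 6 is always pushed into the second batch.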

