tf model.fit() 中的 batch_size 与 tf.data.Dataset 中的 batch_size [英] batch_size in tf model.fit() vs. batch_size in tf.data.Dataset

查看:27
本文介绍了tf model.fit() 中的 batch_size 与 tf.data.Dataset 中的 batch_size的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一个可以放入主机内存的大型数据集.但是,当我使用 tf.keras 训练模型时,它会产生 GPU 内存不足问题.然后我查看 tf.data.Dataset 并希望使用它的 batch() 方法来批处理训练数据集,以便它可以在 GPU 中执行 model.fit().根据其文档,示例如下:

I have a large dataset that can fit in host memory. However, when I use tf.keras to train the model, it yields GPU out-of-memory problem. Then I look into tf.data.Dataset and want to use its batch() method to batch the training dataset so that it can execute the model.fit() in GPU. According to its documentation, an example is as follows:

train_dataset = tf.data.Dataset.from_tensor_slices((train_examples, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_examples, test_labels))

BATCH_SIZE = 64
SHUFFLE_BUFFER_SIZE = 100

train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)

dataset.from_tensor_slices().batch() 中的 BATCH_SIZE 与 tf.keras modelt.fit() 中的 batch_size 是否相同?

Is the BATCH_SIZE in dataset.from_tensor_slices().batch() the same as the batch_size in the tf.keras modelt.fit()?

我应该如何选择 BATCH_SIZE 才能让 GPU 有足够的数据来高效运行,同时其内存不会溢出?

How should I choose BATCH_SIZE so that GPU has sufficient data to run efficiently and yet its memory is not overflown?

推荐答案

在这种情况下,您不需要在 model.fit() 中传递 batch_size 参数.它将自动使用您在 tf.data.Dataset().batch() 中使用的 BATCH_SIZE.

You do not need to pass the batch_size parameter in model.fit() in this case. It will automatically use the BATCH_SIZE that you use in tf.data.Dataset().batch().

至于您的另一个问题:批量大小超参数确实需要仔细调整.另一方面,如果你看到 OOM 错误,你应该减少它,直到你没有得到 OOM(通常以这种方式 32 --> 16 --> 8 ...).

As for your other question : the batch size hyperparameter indeed needs to be carefully tuned. On the other hand, if you see OOM errors, you should decrease it until you do not get OOM (normally in this manner 32 --> 16 --> 8 ...).

在您的情况下,我会从 2 的 batch_size 开始,然后将其增加 2 的幂,然后检查是否仍然出现 OOM.

In your case I would start with a batch_size of 2 an increase it by a power of two and check if I still get OOM.

如果您使用 tf.data.Dataset().batch() 方法,则不需要提供 batch_size 参数.

You do not need to provide the batch_size parameter if you use the tf.data.Dataset().batch() method.

事实上,即使官方文档也说明了这一点:

In fact, even the official documentation states this:

batch_size : 整数或无.每次梯度更新的样本数.如果未指定,batch_size 将默认为 32.不要指定如果您的数据采用数据集、生成器或keras.utils.Sequence 实例(因为它们生成批次).

batch_size : Integer or None. Number of samples per gradient update. If unspecified, batch_size will default to 32. Do not specify the batch_size if your data is in the form of datasets, generators, or keras.utils.Sequence instances (since they generate batches).

这篇关于tf model.fit() 中的 batch_size 与 tf.data.Dataset 中的 batch_size的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
相关文章
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆