How to speed up batch preparation when using Estimators API combined with tf.data.Dataset
Question
I'd like to speed up my training routine that uses the Estimator API with an input_fn written using tf.data.Dataset.
My implementation takes 2 seconds to prepare a batch of data, then runs training on the GPU for 1 second, and then starts over preparing the next batch, which is really inefficient.
I'm looking for a way to prepare the batches asynchronously and upload them to the GPU to speed up training. Alternatively, I'm looking for a way to cache datasets between invocations of input_fn (dataset.cache() doesn't seem to be a good choice, as the dataset has to be recreated on each input_fn invocation).
Here is a simplified version of my code:
def input_fn(filenames, labels, epochs):
    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    dataset = dataset.map(_read_wav, num_parallel_calls=num_map_threads)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(labels))
    dataset = dataset.map(_post_process, num_parallel_calls=num_map_threads)
    dataset = dataset.map(lambda wav, label: ({'wav': wav}, label))
    dataset = dataset.batch(128)
    dataset = dataset.repeat(epochs)  # to iterate over the training set forever
    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels
train_input_fn = lambda : input_fn(train_files, train_labels, None)
eval_input_fn = lambda : input_fn(eval_files, eval_labels, 1)
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=45000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
I've noticed that the Estimator API is under active development, and in the master branch of tensorflow input_fn can already return datasets, so maybe I'm asking too early and this feature isn't ready yet. If so, please provide a ticket where this implementation can be tracked.
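(For reference, in newer TensorFlow releases input_fn can indeed return the tf.data.Dataset directly, and the Estimator builds the iterator itself. A minimal sketch, using toy in-memory data in place of the real wav pipeline:)

```python
import tensorflow as tf

def input_fn():
    # Toy in-memory features/labels standing in for the real wav-reading pipeline
    features = {'wav': [[0.0], [1.0], [2.0], [3.0]]}
    labels = [0, 1, 0, 1]
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    dataset = dataset.batch(2)
    return dataset  # no make_one_shot_iterator()/get_next() needed
```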
Answer
Using tf.data.Dataset.cache() is indeed not a good choice, since it will cache the whole dataset into memory, which takes time and might overflow your memory.
The way to go is to use tf.data.Dataset.prefetch() at the end of your pipeline, which will always make sure that the data pipeline holds buffer_size elements. It is usually enough to have buffer_size = 1 at the end:
dataset = ...
dataset = dataset.batch(128)
dataset = dataset.prefetch(1) # prefetch one batch
As explained by @mrry in this answer, you can also try to increase the number of prefetched batches a bit:
Typically it is most useful to add a small prefetch buffer (with perhaps just a single element) at the very end of the pipeline, but more complex pipelines can benefit from additional prefetching, especially when the time to produce a single element can vary.
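For illustration, here is a pipeline with an extra prefetch stage after a (hypothetical) expensive map, in addition to the final one-batch prefetch. The transform and the buffer sizes are illustrative assumptions, not tuned values:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(1024)
# Stand-in for an expensive per-element transform (e.g. decoding audio)
dataset = dataset.map(lambda x: x * 2, num_parallel_calls=4)
dataset = dataset.prefetch(8)   # keep a few transformed elements ready
dataset = dataset.batch(128)
dataset = dataset.prefetch(1)   # one fully prepared batch at the very end
```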
If you still have a slow input pipeline compared to your GPU computations, you need to increase the number of threads working in parallel using the num_parallel_calls argument of tf.data.Dataset.map().
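A sketch of parallelizing the map stage (the transform and the parallelism value of 4 here are illustrative; the right value depends on your CPU):

```python
import tensorflow as tf

def _transform(x):
    # Stand-in for a CPU-heavy per-element function such as reading a wav file
    return tf.cast(x, tf.float32) / 255.0

dataset = tf.data.Dataset.range(1000)
dataset = dataset.map(_transform, num_parallel_calls=4)  # run 4 map calls in parallel
dataset = dataset.batch(128)
dataset = dataset.prefetch(1)
```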