Memory leak tf.data + Keras


Problem description

I have a memory leak in my training pipeline and don't know how to fix it.

I use TensorFlow version 1.9.0 and Keras (tf) version 2.1.6-tf with Python 3.5.2.

This is how my training pipeline looks:

for i in range(num_epochs):

    # A fresh one-shot iterator is created for every epoch
    training_data = training_set.make_one_shot_iterator().get_next()
    hist = model.fit(training_data[0], [training_data[1], training_data[2], training_data[3]],
                     steps_per_epoch=steps_per_epoch_train, epochs=1, verbose=1,
                     callbacks=[history, MemoryCallback()])

    # custom validation
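
(MemoryCallback above is a custom callback whose code is not included in the question; a minimal sketch of such a memory-logging callback, assuming the psutil package is available, could look like this:)

import os
import psutil
import tensorflow as tf

class MemoryCallback(tf.keras.callbacks.Callback):
    # Print the resident memory of the process at the end of each epoch,
    # which makes per-epoch growth visible
    def on_epoch_end(self, epoch, logs=None):
        rss = psutil.Process(os.getpid()).memory_info().rss
        print('Resident memory: {:.1f} MiB'.format(rss / 2**20))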

It looks like the memory of the iterator is not freed after the iterator is exhausted. I have already tried del training_data after model.fit. It didn't work.

Can anybody give some hints?

This is how I create the dataset:

dataset = tf.data.TFRecordDataset(tfrecords_filename)
dataset = dataset.map(map_func=preprocess_fn, num_parallel_calls=8)
dataset = dataset.shuffle(100)
dataset = dataset.batch(batch_size=batch_size)
dataset = dataset.prefetch(1)
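
(preprocess_fn is not shown in the question; purely for illustration, a map function for a TFRecord pipeline like this one, using the TF 1.9 API and hypothetical feature names, might look like:)

def preprocess_fn(serialized_example):
    # Hypothetical feature spec; the fit call above indexes one input and three targets
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image': tf.FixedLenFeature([], tf.string),
            'target_1': tf.FixedLenFeature([], tf.int64),
            'target_2': tf.FixedLenFeature([], tf.int64),
            'target_3': tf.FixedLenFeature([], tf.int64),
        })
    image = tf.cast(tf.decode_raw(features['image'], tf.uint8), tf.float32) / 255.0
    return (image,
            tf.cast(features['target_1'], tf.int32),
            tf.cast(features['target_2'], tf.int32),
            tf.cast(features['target_3'], tf.int32))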

Answer

Including the repeat() method to reinitialize your iterator might solve your problem. You can take a look at the Input Pipeline Performance Guide to figure out a good, optimized order for your methods according to your requirements.

dataset = dataset.shuffle(100)
dataset = dataset.repeat() # Can specify num_epochs as input if needed
dataset = dataset.batch(batch_size=batch_size)
dataset = dataset.prefetch(1)

In case you can afford to do the validation as part of the fit method, you can use something like the code below and lose the loop altogether, which makes your life easier.

training_data = training_set.make_one_shot_iterator().get_next()
# val_data refers to your validation dataset and steps_per_epochs_val to the number of validation batches
hist = model.fit(training_data[0], [training_data[1], training_data[2], training_data[3]],
                 validation_data=val_data.make_one_shot_iterator().get_next(),
                 validation_steps=steps_per_epochs_val,
                 steps_per_epoch=steps_per_epoch_train, epochs=num_epochs, verbose=1,
                 callbacks=[history, MemoryCallback()])
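
Note that once repeat() is in the pipeline the dataset never signals an end, so Keras relies entirely on the step counts to delimit epochs. Those counts are not defined in the thread; assuming you know the record counts (hypothetical names num_train_examples and num_val_examples), they would typically be derived as:

import math

# Hypothetical example counts; ceil so the last partial batch is not dropped
steps_per_epoch_train = int(math.ceil(num_train_examples / batch_size))
steps_per_epochs_val = int(math.ceil(num_val_examples / batch_size))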

Reference: https://github.com/keras-team/keras/blob/master/examples/mnist_dataset_api.py

