Keras predict loop memory leak using tf.data.Dataset but not with a numpy array

Problem description

I encounter a memory leak and steadily degrading performance when calling a Keras model's predict function in a loop with a tf.data.Dataset as input, but not when feeding it a numpy array.

Does anyone understand what is causing this and/or how to resolve the issue?

Minimal reproducible code snippet (copy/paste runnable):

import tensorflow as tf
import numpy as np
import time

SIZE = 5000

inp = tf.keras.layers.Input(shape=(SIZE,), dtype='float32')
x = tf.keras.layers.Dense(units=SIZE)(inp)

model = tf.keras.Model(inputs=inp, outputs=x)

np_data = np.random.rand(1, SIZE)
ds = tf.data.Dataset.from_tensor_slices(np_data).batch(1).repeat()

debug_time = time.time()
while True:
    model.predict(x=ds, steps=1)
    print('Processing {:.2f}'.format(time.time() - debug_time))
    debug_time = time.time()

Result: Predict loop timing starts at around 0.04s per iteration and rises to about 0.5s within a minute or two, while process memory keeps growing from a few hundred MB to close to a GB.

Swap out the tf.data.Dataset for an equivalent numpy array and the runtime stays at ~0.01s per iteration.

Working case code snippet (copy/paste runnable):

import tensorflow as tf
import numpy as np
import time

SIZE = 5000

inp = tf.keras.layers.Input(shape=(SIZE,), dtype='float32')
x = tf.keras.layers.Dense(units=SIZE)(inp)

model = tf.keras.Model(inputs=inp, outputs=x)

np_data = np.random.rand(1, SIZE)

debug_time = time.time()
while True:
    model.predict(x=np_data)  # using numpy array directly
    print('Processing {:.2f}'.format(time.time() - debug_time))
    debug_time = time.time()
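
To compare the two snippets above with numbers rather than by eye, a small harness along these lines could record both the per-call time and the process memory. This is only a sketch using the Python standard library. `profile_loop` and `peak_rss_mb` are hypothetical helper names, not part of TensorFlow or Keras, and the `resource` module is available on Unix only.

```python
import resource
import sys
import time


def peak_rss_mb():
    """Peak resident set size of this process in MB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in bytes on macOS and in kilobytes on Linux.
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return peak / divisor


def profile_loop(step, iterations=100, report_every=10):
    """Call step() repeatedly and sample timing and memory.

    Returns a list of (iteration, seconds_for_that_call, peak_rss_mb) tuples,
    one for every `report_every`-th iteration.
    """
    samples = []
    for i in range(1, iterations + 1):
        start = time.perf_counter()
        step()
        elapsed = time.perf_counter() - start
        if i % report_every == 0:
            samples.append((i, elapsed, peak_rss_mb()))
    return samples
```

With the models defined above, passing `lambda: model.predict(x=ds, steps=1)` or `lambda: model.predict(x=np_data)` as `step` should make the difference visible: in the leaking case both the seconds and the peak RSS columns keep climbing. Note that peak RSS never decreases, so it can show growth but not memory being released.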


Related discussions:

  • Memory leak tf.data + Keras - Doesn't seem to address the core issue, but the question appears similar.
  • https://github.com/tensorflow/tensorflow/issues/22098 - Possibly an open issue in Keras/GitHub, but I can't confirm it. Changing inter_op_parallelism as suggested in that thread has no impact on the results posted here.

Other info:

  • I can reduce the rate of performance degradation by around 10x by passing in an iterator instead of a dataset object. I noticed that at training_utils.py:1314 the Keras code creates an iterator on each call to predict.

TF 1.14.0


Recommended answer

The root of the problem appears to be that Keras creates new dataset operations on every predict call. Notice that at training_utils.py:1314 a dataset iterator is created in each predict loop.

The problem can be reduced in severity by passing in an iterator, and is solved entirely by passing in the iterator's get_next() tensor.
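
As a rough mental model of why the three input types behave differently, the pure-Python toy below counts how many operations pile up in an append-only graph. It is not Keras or TensorFlow code. `ToyGraph` and the `predict_with_*` functions are hypothetical names, and the number of ops added per call is an arbitrary assumption chosen only to show the trend, not to reproduce the roughly 10x difference reported above.

```python
class ToyGraph:
    """Stand-in for a TF1-style static graph: ops are only ever appended."""

    def __init__(self):
        self.ops = []

    def add_op(self, name):
        self.ops.append(name)
        return len(self.ops) - 1  # the index acts as a handle to the op


def predict_with_dataset(graph):
    # Reported behaviour: every call builds a new iterator over the dataset.
    graph.add_op("make_iterator")
    return graph.add_op("get_next")


def predict_with_iterator(graph):
    # Assumption: the iterator is reused, but a new get_next op is still added.
    return graph.add_op("get_next")


def predict_with_tensor(graph, tensor_handle):
    # The fix: the get_next op already exists, so nothing new is added.
    return tensor_handle


def ops_after(calls, step):
    """Run `step` on a fresh graph `calls` times and return the op count."""
    graph = ToyGraph()
    for _ in range(calls):
        step(graph)
    return len(graph.ops)


fixed = ToyGraph()
handle = fixed.add_op("get_next")  # built once, outside the loop
for _ in range(1000):
    predict_with_tensor(fixed, handle)

print(ops_after(1000, predict_with_dataset))   # 2000
print(ops_after(1000, predict_with_iterator))  # 1000
print(len(fixed.ops))                          # 1
```

In this toy, only the variant that reuses a tensor created once outside the loop keeps the graph at a constant size, which matches the constant-time behaviour seen when passing the get_next() tensor.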

I have posted the issue on the TensorFlow GitHub page: https://github.com/tensorflow/tensorflow/issues/30448

Here is the solution. This example runs in constant time using the TF dataset; you just can't pass in the dataset object itself:

import tensorflow as tf
import numpy as np
import time

SIZE = 5000

inp = tf.keras.layers.Input(shape=(SIZE,), dtype='float32')
x = tf.keras.layers.Dense(units=SIZE)(inp)

model = tf.keras.Model(inputs=inp, outputs=x)

np_data = np.random.rand(1, SIZE)
ds = tf.data.Dataset.from_tensor_slices(np_data).batch(1).repeat()
it = tf.data.make_one_shot_iterator(ds)
tensor = it.get_next()

debug_time = time.time()
while True:
    model.predict(x=tensor, steps=1)
    print('Processing {:.2f}'.format(time.time() - debug_time))
    debug_time = time.time()
