使用Keras model.fit_generator生成器 [英] Use a generator for Keras model.fit_generator
问题描述
在编写用于训练Keras模型的自定义生成器时,我最初尝试使用 generator
语法。所以我从 __ next __
编辑 yield
。但是,当我尝试用 model.fit_generator
训练我的模式时,我会得到一个错误,我的生成器不是迭代器。解决方法是将 yield
更改为 return
,这也需要重新调整 __ next __ 的逻辑
跟踪状态。与让 yield
为我做的工作相比,这是非常麻烦的。
I originally tried to use generator
syntax when writing a custom generator for training a Keras model. So I yield
ed from __next__
. However, when I would try to train my mode with model.fit_generator
I would get an error that my generator was not an iterator. The fix was to change yield
to return
which also necessitated rejiggering the logic of __next__
to track state. It's quite cumbersome compared to letting yield
do the work for me.
我有没有办法让这项工作收益率
?如果我必须使用 return
语句,我将需要编写几个必须具有非常笨重逻辑的迭代器。
Is there a way I can make this work with yield
? I will need to write several more iterators that will have to have very clunky logic if I have to use a return
statement.
推荐答案
我无法调试您的代码,因为您没有发布它,但我缩写了我为语义分段项目编写的自定义数据生成器,供您用作模板:
I can't help debug your code since you didn't post it, but I abbreviated a custom data generator I wrote for a semantic segmentation project for you to use as a template:
def generate_data(directory, batch_size):
"""Replaces Keras' native ImageDataGenerator."""
i = 0
file_list = os.listdir(directory)
while True:
image_batch = []
for b in range(batch_size):
if i == len(file_list):
i = 0
random.shuffle(file_list)
sample = file_list[i]
i += 1
image = cv2.resize(cv2.imread(sample[0]), INPUT_SHAPE)
image_batch.append((image.astype(float) - 128) / 128)
yield np.array(image_batch)
用法:
model.fit_generator(
generate_data('~/my_data', batch_size),
steps_per_epoch=len(os.listdir('~/my_data')) // batch_size)
这篇关于使用Keras model.fit_generator生成器的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!