使用 Keras 的生成器 model.fit_generator [英] Use a generator for Keras model.fit_generator

查看:44
本文介绍了使用 Keras 的生成器 model.fit_generator的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

在编写用于训练 Keras 模型的自定义生成器时,我最初尝试使用 generator 语法.所以我yield来自__next__.但是,当我尝试使用 model.fit_generator 训练我的模式时,我会收到一个错误,指出我的生成器不是迭代器.修复方法是将 yield 更改为 return,这也需要重新调整 __next__ 的逻辑以跟踪状态.与让 yield 为我完成工作相比,这相当麻烦.

I originally tried to use generator syntax when writing a custom generator for training a Keras model. So I yielded from __next__. However, when I would try to train my mode with model.fit_generator I would get an error that my generator was not an iterator. The fix was to change yield to return which also necessitated rejiggering the logic of __next__ to track state. It's quite cumbersome compared to letting yield do the work for me.

有什么方法可以让我用 yield 完成这项工作?如果我必须使用 return 语句,我将需要编写更多的迭代器,这些迭代器必须具有非常笨拙的逻辑.

Is there a way I can make this work with yield? I will need to write several more iterators that will have to have very clunky logic if I have to use a return statement.

推荐答案

我无法帮助调试您的代码,因为您没有发布它,但我缩写了我为语义分割项目编写的自定义数据生成器供您使用用作模板:

I can't help debug your code since you didn't post it, but I abbreviated a custom data generator I wrote for a semantic segmentation project for you to use as a template:

def generate_data(directory, batch_size):
    """Replaces Keras' native ImageDataGenerator."""
    i = 0
    file_list = os.listdir(directory)
    while True:
        image_batch = []
        for b in range(batch_size):
            if i == len(file_list):
                i = 0
                random.shuffle(file_list)
            sample = file_list[i]
            i += 1
            image = cv2.resize(cv2.imread(sample[0]), INPUT_SHAPE)
            image_batch.append((image.astype(float) - 128) / 128)

        yield np.array(image_batch)

用法:

model.fit_generator(
    generate_data('~/my_data', batch_size),
    steps_per_epoch=len(os.listdir('~/my_data')) // batch_size)

这篇关于使用 Keras 的生成器 model.fit_generator的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆