Is fit_generator in Keras supposed to reset the generator after each epoch?

Problem description

I am trying to use fit_generator with a custom generator to read in data that is too big for memory. There are 1.25 million rows I want to train on, so I have the generator yield 50,000 rows at a time. fit_generator has 25 steps_per_epoch, which I thought would bring in those 1.25 million rows per epoch. I added a print statement so that I could see which offset the generator was reading from, and found that the offset went past the end of the data a few steps into epoch 2. The file contains 1.75 million records in total, so once epoch 2 passes 10 steps, read_csv returns no rows and the create_feature_matrix call raises an index error.

import gc

import pandas as pd

# file_loc and create_feature_matrix are defined elsewhere in the original script.

def get_next_data_batch():
    nrows = 50000
    skiprows = 0

    while True:
        # Keep the header (file line 0), skip data lines 1..skiprows, then read the next nrows rows.
        d = pd.read_csv(file_loc, skiprows=range(1, skiprows + 1), nrows=nrows, index_col=0)
        print(skiprows)
        x, y = create_feature_matrix(d)
        yield x, y
        skiprows = skiprows + nrows
        gc.collect()

get_data = get_next_data_batch()

# ... set up a Keras NN ...

model.fit_generator(get_next_data_batch(), epochs=100, steps_per_epoch=25, verbose=1, workers=4, callbacks=callbacks_list)
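The numbers in the description can be checked without Keras or the data file. The sketch below is a hypothetical stand-in for the generator above that yields only the skiprows offset; it assumes the generator object is created once and is never reset between epochs, and reports the first batch whose offset lies at or past the end of a 1.75-million-row file.

```python
def skiprows_offsets(nrows=50_000):
    """Mimic the question's generator, yielding only the skiprows offset of each batch."""
    skiprows = 0
    while True:
        yield skiprows
        skiprows += nrows


def first_empty_batch(total_rows, nrows, steps_per_epoch, epochs):
    """Return (epoch, step, offset) of the first batch starting past the end of the file.

    The generator is created once and never reset; that is the assumption being checked.
    """
    gen = skiprows_offsets(nrows)
    for epoch in range(1, epochs + 1):
        for step in range(1, steps_per_epoch + 1):
            offset = next(gen)
            if offset >= total_rows:
                return epoch, step, offset
    return None  # no empty batch within the given number of epochs


print(first_empty_batch(total_rows=1_750_000, nrows=50_000, steps_per_epoch=25, epochs=100))
# (2, 11, 1750000)
```

Under that assumption, epoch 2 starts at offset 1,250,000 and step 11 of epoch 2 is the first to find no rows, which matches "once it passes 10 steps". In a real run the printed offsets may appear slightly ahead of the progress bar, because fit_generator pre-fetches upcoming batches into a queue (its max_queue_size argument).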

Am I using fit_generator wrong, or does my custom generator need to change for this to work?

Recommended answer

No. fit_generator doesn't reset the generator; it simply keeps calling it, so each epoch continues from the offset where the previous one stopped. To get the behavior you want, you can make the generator wrap back to the start of the file itself:

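import gc

import pandas as pd


# Set nb_of_calls_before_reset to the same value as steps_per_epoch.
def get_next_data_batch(nb_of_calls_before_reset=25):
    nrows = 50000
    skiprows = 0
    nb_calls = 0

    while True:
        # Keep the header (file line 0), skip data lines 1..skiprows, then read the next nrows rows.
        d = pd.read_csv(file_loc, skiprows=range(1, skiprows + 1), nrows=nrows, index_col=0)
        print(skiprows)
        x, y = create_feature_matrix(d)
        yield x, y
        nb_calls += 1
        if nb_calls == nb_of_calls_before_reset:
            # One epoch's worth of batches has been served: start again at the top of the file.
            skiprows = 0
            nb_calls = 0  # reset the counter too, otherwise only the first epoch wraps around
        else:
            skiprows = skiprows + nrows
        gc.collect()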
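A different approach, not part of the original answer, is to give Keras an index-based loader instead of a generator. keras.utils.Sequence asks for batch idx through __getitem__ and learns the number of batches per epoch from __len__, so there is no generator state to reset; the Keras documentation also describes Sequence as the safer choice when using multiprocessing. The class name CsvBatchSequence and the parameters total_rows and to_xy below are hypothetical (to_xy stands in for the question's create_feature_matrix), and the fallback to object exists only so the sketch runs without TensorFlow installed.

```python
import math

import pandas as pd

try:
    # Keras base class for index-based data loaders (tf.keras in TensorFlow 2.x).
    from tensorflow.keras.utils import Sequence
except ImportError:
    # Fallback so the sketch can be run and inspected without TensorFlow installed.
    Sequence = object


class CsvBatchSequence(Sequence):
    """Serve batch idx by reading data rows [idx * batch_size, (idx + 1) * batch_size) of a CSV.

    Every batch is addressed by its index, so each epoch simply requests
    indices 0 .. len(self) - 1 again; nothing carries over between epochs.
    """

    def __init__(self, path, total_rows, batch_size, to_xy):
        super().__init__()
        self.path = path              # e.g. the question's file_loc
        self.total_rows = total_rows  # rows to train on, e.g. 1_250_000 of the file's 1.75M
        self.batch_size = batch_size
        self.to_xy = to_xy            # hypothetical stand-in for create_feature_matrix

    def __len__(self):
        # Number of batches per epoch; plays the role of steps_per_epoch.
        return math.ceil(self.total_rows / self.batch_size)

    def __getitem__(self, idx):
        start = idx * self.batch_size
        # Never read past total_rows, even if the file holds more records.
        nrows = min(self.batch_size, self.total_rows - start)
        # Keep the header (file line 0) and skip the data lines before this batch.
        d = pd.read_csv(self.path, skiprows=range(1, start + 1), nrows=nrows, index_col=0)
        return self.to_xy(d)
```

With TensorFlow 2.x, where fit_generator is deprecated and model.fit accepts Sequence objects directly, the training call would look something like model.fit(CsvBatchSequence(file_loc, 1_250_000, 50_000, create_feature_matrix), epochs=100, callbacks=callbacks_list), with no steps_per_epoch needed.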
