Keras Stateful LSTM fit_generator: how to use batch_size > 1


Problem description

I want to train a stateful LSTM network using the functional API in Keras.

The fit method is fit_generator.

I am able to train it using batch_size = 1.

My Input layer is:

Input(shape=(n_history, n_cols),batch_shape=(batch_size, n_history, n_cols), 
    dtype='float32', name='daily_input')

The generator is as follows:

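def training_data():
    while True:  # loop forever; fit_generator draws steps_per_epoch batches per epoch
        for i in range(0, pdf_daily_data.shape[0] - n_history, 1):
            x = f(i)()  # f(i) shape is (1, n_history, n_cols)
            y_i = y(i)  # separate name so the helper y() is not shadowed
            yield (x, y_i)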

Then the fit is:

model.fit_generator(training_data(),
                    steps_per_epoch=pdf_daily_data.shape[0]//batch_size,...

This works and trains well, but it is very slow, because it performs a gradient update at every time step since batch_size = 1.

How, within this configuration, can I set batch_size > 1? Remember: the LSTM layer has stateful = True.

Recommended answer

You will have to modify your generator to yield as many elements as you want each batch to contain.

Currently you are iterating over your data element by element (as set by the third parameter of range()), obtaining a single x and y, and then yielding that element. Because each yield returns a single element, you get batch_size=1: fit_generator trains element by element.

Say you want a batch size of 10: you will then have to slice your data into segments of 10 elements each and yield those slices instead of single elements. Just be sure to reflect this change in the shape of your Input layer by passing the corresponding batch_size.
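The slicing described above could be sketched roughly as follows. This is a minimal illustration, not the asker's actual code: the function name batched_training_data, the data and targets arrays (standing in for pdf_daily_data and whatever f(i) and y(i) compute), and the choice of the row right after each window as its target are all assumptions made for the example.

```python
import numpy as np

def batched_training_data(data, targets, n_history, batch_size):
    """Yield (x, y) batches forever, each holding batch_size windows.

    data:    array of shape (n_rows, n_cols), stand-in for pdf_daily_data
    targets: array of shape (n_rows,), hypothetical target per row
    """
    n_windows = data.shape[0] - n_history
    # Drop the incomplete final batch: batch_shape fixes the batch dimension,
    # so every yielded batch must contain exactly batch_size samples.
    n_batches = n_windows // batch_size
    while True:
        for b in range(n_batches):
            starts = range(b * batch_size, (b + 1) * batch_size)
            # x has shape (batch_size, n_history, n_cols)
            x = np.stack([data[i:i + n_history] for i in starts])
            # Assumed target: the value on the row following each window
            y = np.array([targets[i + n_history] for i in starts])
            yield x, y
```

With such a generator, steps_per_epoch would be (n_rows - n_history) // batch_size, and the batch_shape of the Input layer would use the same batch_size. Note that with stateful=True Keras carries the state of sample i in one batch over to sample i in the next batch, so the order in which windows are placed into batches determines which state each window inherits. The sketch simply keeps the original sequential order and does not try to align sequences across batches.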
