Keras: What if the size of data is not divisible by batch_size?


Problem description

I am new to Keras and just started working on some examples. I am dealing with the following problem: I have 4032 samples and use about 650 of them as for the fit or basically the training state and then use the rest for testing the model. The problem is that I keep getting the following error:

Exception: In a stateful network, you should only pass inputs with a number of samples that can be divided by the batch size.

I understand why I am getting this error. My question is: what if the size of my data is not divisible by batch_size? I used to work with Deeplearning4j LSTMs and did not have to deal with this problem. Is there any way to get around this?

Thanks

Recommended answer

The simplest solution is to use fit_generator instead of fit. I wrote a simple dataloader class that can be inherited to do more complex things. It would look something like this, with get_next_batch_data redefined to suit whatever your data is, including things like augmentation, etc.

import random

class BatchedLoader():
    def __init__(self, n_samples, batchsize):
        self.possible_indices = list(range(n_samples))  # (say n_samples = 33)
        self.batchsize = batchsize
        self.cur_it = 0
        self.cur_epoch = 0

    def get_batch_indices(self):
        batch_indices = self.possible_indices[self.cur_it : self.cur_it + self.batchsize]
        self.cur_it += self.batchsize
        # If len(batch_indices) < batchsize, you've reached the end.
        # In that case, reset cur_it, increase cur_epoch, shuffle
        # possible_indices if wanted, and add the remaining
        # K = batchsize - len(batch_indices) indices to batch_indices.
        if len(batch_indices) < self.batchsize:
            self.cur_epoch += 1
            random.shuffle(self.possible_indices)  # optional
            k = self.batchsize - len(batch_indices)
            batch_indices += self.possible_indices[:k]
            self.cur_it = k
        return batch_indices

    def get_next_batch_data(self):
        batch_indices = self.get_batch_indices()
        # The data points corresponding to those indices will be your
        # next batch of data (apply augmentation etc. here).
        return batch_indices
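To connect this to Keras, the same wrap-around indexing can be written as an endless Python generator, which is the shape fit_generator expects. This is a minimal, self-contained sketch; the names make_batch_generator, shuffle, and seed are illustrative, not from the original answer:

```python
import numpy as np

def make_batch_generator(x, y, batchsize, shuffle=False, seed=0):
    """Endless generator yielding (x_batch, y_batch) with exactly
    batchsize samples each, wrapping around the end of the data so
    the last partial batch is topped up from the next epoch."""
    rng = np.random.RandomState(seed)
    indices = np.arange(len(x))
    cur = 0
    while True:
        batch = indices[cur:cur + batchsize]
        cur += batchsize
        if len(batch) < batchsize:
            # End of an epoch: optionally reshuffle, then fill the
            # short batch with indices from the new pass.
            if shuffle:
                rng.shuffle(indices)
            k = batchsize - len(batch)
            batch = np.concatenate([batch, indices[:k]])
            cur = k
        yield x[batch], y[batch]

# With a compiled Keras model this could then be used roughly as:
# model.fit_generator(make_batch_generator(x_train, y_train, 32),
#                     steps_per_epoch=len(x_train) // 32, epochs=10)
```

Note that in recent tf.keras versions, fit accepts generators directly and fit_generator is deprecated, but the batching idea is the same.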
