Keras custom data generator for large hdf5 file which does not fit into memory


Question

I'm trying to use the pretrained InceptionV3 model to classify the food-101 dataset, which contains food images for 101 categories, 1000 per category. So far I've preprocessed this dataset into a single hdf5 file (I assumed this is beneficial compared to loading images on the go when training), which has tables such as train_img, train_labels, valid_img and valid_labels inside (as used in the code below).

The data split is the standard 70% train, 20% validation, 10% test, so for example valid_img has a size of 20200*299*299*3. The labels are one-hot encoded for Keras, so valid_labels has a size of 20200*101.

This hdf5 file has a size of 27.1 GB, so it will not fit into my memory. (I have 8 GB, although effectively only about 4-5 GB is usable while running Ubuntu. Also, my GPU is a GTX 960 with 2 GB of VRAM, and so far it looked like about 1.5 GB of it is available for Python when I try to start the training script.) I'm using the TensorFlow backend.

The first idea I had was to use model.train_on_batch() with a double nested for loop like this:

import h5py

# Loading InceptionV3, adding my fully connected layers, compiling model...

dataset = h5py.File('/home/uzoltan/PycharmProjects/food-101/food-101_299x299.hdf5', 'r')
epochs = 50
for epoch in range(epochs):
    for i in range(100):  # 1000 images can fit in the memory easily, this could probably be range(10) too
        train_images = dataset["train_img"][i * 706:(i + 1) * 706, ...]
        train_labels = dataset["train_labels"][i * 706:(i + 1) * 706, ...]
        val_images = dataset["valid_img"][i * 202:(i + 1) * 202, ...]
        val_labels = dataset["valid_labels"][i * 202:(i + 1) * 202, ...]
        model.train_on_batch(x=train_images, y=train_labels,
                             class_weight=None, sample_weight=None)

My problem with this approach is that train_on_batch provides no support for validation or batch shuffling, so there is no way to make sure that the batches are not in the same order every epoch.
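
As an aside, one workaround for the fixed batch order would be to permute the order in which the fixed-size chunks are read at the start of every epoch. A minimal sketch, reusing the dataset, model and 706-image chunks from the code above (the chunk count of 100 is taken from that loop); this still leaves the validation handling to be done by hand:

import numpy as np

for epoch in range(epochs):
    for i in np.random.permutation(100):  # visit the 706-image chunks in a new order each epoch
        train_images = dataset["train_img"][i * 706:(i + 1) * 706, ...]
        train_labels = dataset["train_labels"][i * 706:(i + 1) * 706, ...]
        model.train_on_batch(x=train_images, y=train_labels)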

So I looked towards model.fit_generator(), which has the nice property of providing all the same functionality as fit(), plus with the built-in ImageDataGenerator you can do image augmentations (rotations, horizontal flips, etc.) on the CPU at the same time, so that your model can be more robust. My problem here is that, if I understand it correctly, the ImageDataGenerator.flow(x, y) method needs all the samples and labels at once, but my training/validation data won't fit into my RAM.

Here is where I think custom data generators come into the picture, but after looking extensively at some examples I could find on the Keras GitHub/Issues pages, I still don't really get how I should implement a custom generator that would read batches of data from my hdf5 file. Can someone provide me with a good example or pointers? How could I couple the custom batch generator with the image augmentations? Or maybe it is easier to implement some kind of manual validation and batch shuffling for train_on_batch()? If so, I could use some pointers there too.
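
For illustration, here is a minimal sketch of the kind of hdf5-reading batch generator being asked about. The function name, the batch size, the per-epoch shuffling of batch order, and the endless loop that fit_generator expects are assumptions; the dataset names match the ones used in the code above:

import h5py
import numpy as np

def hdf5_batch_generator(path, img_key, label_key, batch_size=32):
    # Lazily reads (images, labels) batches from the hdf5 file,
    # shuffling the order of the batches at the start of every epoch.
    with h5py.File(path, 'r') as f:
        num_batches = f[img_key].shape[0] // batch_size
        while True:  # fit_generator expects the generator to yield forever
            for b in np.random.permutation(num_batches):
                start = b * batch_size
                stop = start + batch_size
                yield f[img_key][start:stop], f[label_key][start:stop]

train_gen = hdf5_batch_generator('/home/uzoltan/PycharmProjects/food-101/food-101_299x299.hdf5',
                                 'train_img', 'train_labels', batch_size=32)
valid_gen = hdf5_batch_generator('/home/uzoltan/PycharmProjects/food-101/food-101_299x299.hdf5',
                                 'valid_img', 'valid_labels', batch_size=32)
model.fit_generator(train_gen, steps_per_epoch=70700 // 32,   # 70% of 101000 images
                    validation_data=valid_gen, validation_steps=20200 // 32)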

Answer

For anyone still looking for an answer, I made the following "crude wrapper" around ImageDataGenerator's apply_transform method.

from numpy.random import uniform, randint
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np

class CustomImagesGenerator:
    def __init__(self, x, zoom_range, shear_range, rescale, horizontal_flip, batch_size):
        self.x = x
        self.zoom_range = zoom_range
        self.shear_range = shear_range
        self.rescale = rescale
        self.horizontal_flip = horizontal_flip
        self.batch_size = batch_size
        self.__img_gen = ImageDataGenerator()
        self.__batch_index = 0

    def __len__(self):
        # steps_per_epoch, if unspecified, will use len(generator) as the number of steps.
        # Keras expects an integer here, hence the int() cast around np.floor.
        return int(np.floor(self.x.shape[0] / self.batch_size))

    def next(self):
        return self.__next__()

    def __next__(self):
        start = self.__batch_index*self.batch_size
        stop = start + self.batch_size
        self.__batch_index += 1
        if stop > len(self.x):
            raise StopIteration
        transformed = np.array(self.x[start:stop])  # loads from hdf5
        for i in range(len(transformed)):
            zoom = uniform(self.zoom_range[0], self.zoom_range[1])
            transformations = {
                'zx': zoom,
                'zy': zoom,
                'shear': uniform(-self.shear_range, self.shear_range),
                'flip_horizontal': self.horizontal_flip and bool(randint(0,2))
            }
            transformed[i] = self.__img_gen.apply_transform(transformed[i], transformations)
        return transformed * self.rescale

It can be called like this:

import h5py
f = h5py.File("my_heavy_dataset_file.hdf5", 'r')
images = f['mydatasets/images']

my_gen = CustomImagesGenerator(
    images, 
    zoom_range=[0.8, 1], 
    shear_range=6, 
    rescale=1./255, 
    horizontal_flip=True, 
    batch_size=64
)

model.fit_generator(my_gen)
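
Note that this wrapper yields batches of images only, while fit_generator normally expects (inputs, targets) tuples, and the internal batch index is never reset, so as written it covers a single pass over the data. A minimal sketch of pairing the image batches with a matching label dataset (the label dataset name and the helper below are illustrative, not part of the original answer):

labels = f['mydatasets/labels']  # hypothetical one-hot labels aligned with the images

def batches_with_labels(image_gen, labels, batch_size):
    # Pairs each augmented image batch with the corresponding label slice.
    batch_index = 0
    while True:
        try:
            image_batch = next(image_gen)
        except StopIteration:
            return
        start = batch_index * batch_size
        yield image_batch, labels[start:start + batch_size]
        batch_index += 1

model.fit_generator(batches_with_labels(my_gen, labels, batch_size=64),
                    steps_per_epoch=len(my_gen))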
