Keras: How to expand validation_split to generate a third set, i.e. a test set?

Question

I am using Keras with a TensorFlow backend. I am using ImageDataGenerator with the validation_split argument to split my data into a training set and a validation set. To do so, I call flow_from_directory with subset set to "training" and "validation", like so:

from keras.preprocessing.image import ImageDataGenerator

# my_dir, input_size and BATCH_SIZE are assumed to be defined earlier in the script
data_generator = ImageDataGenerator(validation_split=0.3)

train_gen = data_generator.flow_from_directory(my_dir, target_size=(input_size, input_size), shuffle=False, seed=13,
                                                     class_mode='categorical', batch_size=BATCH_SIZE, subset="training")

valid_gen = data_generator.flow_from_directory(my_dir, target_size=(input_size, input_size), shuffle=False, seed=13,
                                                     class_mode='categorical', batch_size=32, subset="validation")

This is amazingly convenient, as it allows me to use only one directory instead of two (one for training and one for validation). Now I wonder: is it possible to expand this process to generate a third set, i.e. a test set?

Answer

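This is not possible out of the box. You should be able to do it with some minor modifications to the source code of ImageDataGenerator. Note that the block below is the subset handling in NumpyArrayIterator (the iterator behind flow()); flow_from_directory() goes through DirectoryIterator, which selects its subset with its own (start, stop) split and would need the analogous change: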

if subset is not None:
    if subset not in {'training', 'validation'}: # add a third subset here
        raise ValueError('Invalid subset name:', subset,
                         '; expected "training" or "validation".') # adjust message
    split_idx = int(len(x) * image_data_generator._validation_split) 
    # you'll need two split indices here
    if subset == 'validation':
        x = x[:split_idx]
        x_misc = [np.asarray(xx[:split_idx]) for xx in x_misc]
        if y is not None:
            y = y[:split_idx]
    elif subset == '...':  # add extra case here
        pass
    else:
        x = x[split_idx:]
        x_misc = [np.asarray(xx[split_idx:]) for xx in x_misc] # change slicing
        if y is not None:
            y = y[split_idx:] # change slicing

Edit: this is how you could modify the code:

if subset is not None:
    if subset not in {'training', 'validation', 'test'}:
        raise ValueError('Invalid subset name:', subset,
                         '; expected "training" or "validation" or "test".')
    # a list (not a generator expression) so the two indices can be accessed by position
    split_idxs = [int(len(x) * v) for v in image_data_generator._validation_split]
    if subset == 'validation':
        x = x[:split_idxs[0]]
        x_misc = [np.asarray(xx[:split_idxs[0]]) for xx in x_misc]
        if y is not None:
            y = y[:split_idxs[0]]
    elif subset == 'test':
        x = x[split_idxs[0]:split_idxs[1]]
        x_misc = [np.asarray(xx[split_idxs[0]:split_idxs[1]]) for xx in x_misc]
        if y is not None:
            y = y[split_idxs[0]:split_idxs[1]]
    else:
        x = x[split_idxs[1]:]
        x_misc = [np.asarray(xx[split_idxs[1]:]) for xx in x_misc]
        if y is not None:
            y = y[split_idxs[1]:]
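To check the slicing logic without patching Keras, here is a minimal standalone sketch that applies the same three-way split to NumPy arrays. The function name split_subset is hypothetical and not part of Keras; it assumes, as the answer does, that validation_split is a pair of fractions (v0, v1), and it leaves out x_misc for brevity.

```python
import numpy as np


def split_subset(x, y, validation_split, subset):
    """Return the (x, y) slice for one subset of a three-way split.

    Hypothetical helper mirroring the modified Keras block:
    validation = [0, v0), test = [v0, v1), training = [v1, end).
    """
    if subset not in {"training", "validation", "test"}:
        raise ValueError(
            "Invalid subset name: %r; expected 'training', 'validation' or 'test'."
            % (subset,)
        )
    # A list (not a generator) so the two indices can be read by position.
    split_idxs = [int(len(x) * v) for v in validation_split]
    if subset == "validation":
        sl = slice(None, split_idxs[0])
    elif subset == "test":
        sl = slice(split_idxs[0], split_idxs[1])
    else:  # training
        sl = slice(split_idxs[1], None)
    x_sub = x[sl]
    y_sub = y[sl] if y is not None else None
    return x_sub, y_sub


x = np.arange(20).reshape(10, 2)  # 10 samples with 2 features each
y = np.arange(10)                 # label i belongs to sample i
for name in ("validation", "test", "training"):
    _, y_part = split_subset(x, y, (0.1, 0.5), name)
    print(name, y_part.tolist())
# prints:
# validation [0]
# test [1, 2, 3, 4]
# training [5, 6, 7, 8, 9]
```

Using slice objects keeps each branch to a single line and applies the identical slice to x and y, so the samples and labels cannot drift out of step.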

Basically, validation_split is now expected to be a tuple of two floats instead of a single float. The validation data will be the fraction of the data between 0 and validation_split[0], the test data the fraction between validation_split[0] and validation_split[1], and the training data the fraction between validation_split[1] and 1. This is how you can use it:
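The fraction ranges described above can be written as a small lookup from subset name to a (start, stop) pair, which mirrors how the stock DirectoryIterator represents its two subsets as (0, validation_split) and (validation_split, 1). This is a hedged sketch: subset_fraction_range is a hypothetical name, and it only encodes the convention proposed in this answer (validation first, then test, then training); stock Keras accepts only a single float.

```python
def subset_fraction_range(subset, validation_split):
    """Map a subset name to its (start, stop) fraction of the data.

    Hypothetical helper for the tuple convention proposed in this answer:
    validation_split = (v0, v1) with 0 < v0 < v1 < 1.
    """
    v0, v1 = validation_split
    ranges = {
        "validation": (0.0, v0),
        "test": (v0, v1),
        "training": (v1, 1.0),
    }
    if subset not in ranges:
        raise ValueError("Invalid subset name: %r" % (subset,))
    return ranges[subset]


# Image index boundaries for 1000 images with validation_split=(0.1, 0.5)
n = 1000
for name in ("validation", "test", "training"):
    start, stop = subset_fraction_range(name, (0.1, 0.5))
    print(name, int(n * start), int(n * stop))
```

For 1000 images this gives 100 validation, 400 test and 500 training images, which matches the "first 10% / next 40% / rest" comment in the usage example.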

# keras_custom_preprocessing is the name I gave the directory holding the modified copy of the code
from keras_custom_preprocessing.image import ImageDataGenerator

generator = ImageDataGenerator(validation_split=(0.1, 0.5))
# First 10%: validation data - next 40% test data - rest: training data        
gen = generator.flow_from_directory(directory='./data/', subset='test')
# Finds 40% of the images in the dir

You will need to modify two or three additional lines in the file (for example, the type check on validation_split, which the stock code expects to be a single float strictly between 0 and 1), but that's it, and it should work. I have the modified file; let me know if you are interested and I can host it on my GitHub.
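The range check is a likely candidate for that type check: stock ImageDataGenerator.__init__ rejects any validation_split that is not strictly between 0 and 1, and with a tuple that chained comparison raises a TypeError in Python 3. A replacement that accepts either the original float or the new pair could look like the following; check_validation_split is a hypothetical helper written for illustration, not Keras code.

```python
def check_validation_split(validation_split):
    """Validate either a single float or a (v0, v1) pair of fractions.

    Hypothetical replacement for the range check in ImageDataGenerator.__init__.
    """
    if isinstance(validation_split, (tuple, list)):
        if len(validation_split) != 2:
            raise ValueError(
                "`validation_split` must have exactly two values, got %r"
                % (validation_split,)
            )
        v0, v1 = validation_split
        # The pair must be ordered so that all three subsets are non-empty ranges.
        if not 0 < v0 < v1 < 1:
            raise ValueError(
                "`validation_split` must satisfy 0 < v0 < v1 < 1, got %r"
                % (validation_split,)
            )
    elif validation_split and not 0 < validation_split < 1:
        # Original single-float behaviour: 0 (or None) means "no split".
        raise ValueError(
            "`validation_split` must be strictly between 0 and 1, got %r"
            % (validation_split,)
        )
    return validation_split
```

With this in place, check_validation_split((0.1, 0.5)) passes the pair through unchanged, a plain float such as 0.3 keeps working as before, and a misordered pair such as (0.5, 0.1) raises a ValueError instead of silently producing an empty test set.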