仅扩充K折交叉验证中的训练集 [英] Augmenting only the training set in K-folds cross validation

查看:141
本文介绍了仅扩充K折交叉验证中的训练集的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试为不平衡的数据集(类别0 = 4000张图像,类别1 =大约250张图像)创建一个二进制CNN分类器,我想对其执行5倍交叉验证。目前,我正在将训练集加载到ImageLoader中,该图像集将应用转换/增强(?)并将其加载到DataLoader中。但是,这会导致我的训练部分和验证部分都包含增强的数据。

I am trying to create a binary CNN classifier for an unbalanced dataset (class 0 = 4000 images, class 1 = around 250 images), which I want to perform 5-fold cross validation on. Currently I am loading my training set into an ImageLoader that applies my transformations/augmentations(?) and loads it into a DataLoader. However, this results in both my training splits and validation splits containing the augmented data.

我最初是离线使用转换(离线增强?)来平衡我的数据集,但是从这个线程( https://stats.stackexchange.com/questions/175504/如何进行数据扩展和训练验证拆分),似乎仅增加训练集是理想的选择。我也更愿意在完全扩充的训练数据上训练我的模型,然后通过5倍交叉验证在非增强数据上对其进行验证

I originally applied transformations offline (offline augmentation?) to balance my dataset, but from this thread (https://stats.stackexchange.com/questions/175504/how-to-do-data-augmentation-and-train-validate-split), it seems it would be ideal to only augment the training set. I would also prefer to train my model on solely augmented training data and then validate it on non-augmented data in a 5-fold cross validation

我的数据被组织为根/标签/图像,其中有2个标签文件夹(0和1),并且图像按各自的标签分类。

My data is organized as root/label/images, where there are 2 label folders (0 and 1) and images sorted into the respective labels.

total_set = datasets.ImageFolder(ROOT, transform = data_transforms['my_transforms'])

//Eventually I plan to run cross-validation as such:
splits = KFold(cv = 5, shuffle = True, random_state = 42)

for train_idx, valid_idx in splits.split(total_set):
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(total_set, batch_size=32, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(total_set, batch_size=32, sampler=valid_sampler)

model.train()
//Model train/eval works but may be overpredict 

我确定我在代码中做的不是很理想或者做错了,但是我似乎无法可以找到有关仅专门增强交叉验证中的训练内容的任何文档!

I'm sure I'm doing something sub-optimally or wrong in this code, but I can't seem to find any documentation on specifically augmenting only the training splits in cross-validation!

任何帮助将不胜感激!

推荐答案

一种方法是实现包装器Dataset类,该类将转换应用于ImageFolder数据集的输出。例如

One approach is to implement a wrapper Dataset class that applies transforms to the output of your ImageFolder dataset. For example

class WrapperDataset:
    def __init__(self, dataset, transform=None, target_transform=None):
        self.dataset = dataset
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        image, label = self.dataset[index]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

    def __len__(self):
        return len(self.dataset)

然后您可以在代码中使用不同的转换包装较大的数据集。

Then you could use this in your code by wrapping the larger dataset with different transforms.

total_set = datasets.ImageFolder(ROOT)

# Eventually I plan to run cross-validation as such:
splits = KFold(cv = 5, shuffle = True, random_state = 42)

for train_idx, valid_idx in splits.split(total_set):
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        WrapperDataset(total_set, transform=data_transforms['train_transforms']),
        batch_size=32, sampler=train_sampler)
    valid_loader = torch.utils.data.DataLoader(
        WrapperDataset(total_set, transform=data_transforms['valid_transforms']),
        batch_size=32, sampler=valid_sampler)

    # train/validate now

我没有测试此代码,因为我没有完整的代码/模型,但概念应该很清楚。

I haven't tested this code since I don't have your full code/models but the concept should be clear.

这篇关于仅扩充K折交叉验证中的训练集的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆