Data loading with variable batch size?

Problem Description

I am currently working on patch-based super-resolution. Most papers divide an image into smaller patches and then use the patches as input to the model. I was able to create patches using a custom dataloader. The code is given below:

import torch.utils.data as data
from torchvision.transforms import CenterCrop, ToTensor, Compose, ToPILImage, Resize, RandomHorizontalFlip, RandomVerticalFlip
from os import listdir
from os.path import join
from PIL import Image
import random
import os
import numpy as np
import torch

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg", ".bmp"])

class TrainDatasetFromFolder(data.Dataset):
    def __init__(self, dataset_dir, patch_size, is_gray, stride):
        super(TrainDatasetFromFolder, self).__init__()
        self.imageHrfilenames = []
        self.imageHrfilenames.extend(join(dataset_dir, x)
                                     for x in sorted(listdir(dataset_dir)) if is_image_file(x))
        self.is_gray = is_gray
        self.patchSize = patch_size
        self.stride = stride

    def _load_file(self, index):
        filename = self.imageHrfilenames[index]
        hr = Image.open(self.imageHrfilenames[index])
        downsizes = (1, 0.7, 0.45)
        downsize = 2
        w_ = int(hr.width * downsizes[downsize])
        h_ = int(hr.height * downsizes[downsize])
        aug = Compose([Resize([h_, w_], interpolation=Image.BICUBIC),
                       RandomHorizontalFlip(),
                       RandomVerticalFlip()])

        hr = aug(hr)
        rv = random.randint(0, 4)
        hr = hr.rotate(90*rv, expand=1)
        filename = os.path.splitext(os.path.split(filename)[-1])[0]
        return hr, filename

    def _patching(self, img):

        img = ToTensor()(img)
        LR_ = Compose([ToPILImage(), Resize(self.patchSize//2, interpolation=Image.BICUBIC), ToTensor()])

        HR_p, LR_p = [], []
        for i in range(0, img.shape[1] - self.patchSize, self.stride):
            for j in range(0, img.shape[2] - self.patchSize, self.stride):
                temp = img[:, i:i + self.patchSize, j:j + self.patchSize]
                HR_p += [temp]
                LR_p += [LR_(temp)]

        return torch.stack(LR_p), torch.stack(HR_p)

    def __getitem__(self, index):
        HR_, filename = self._load_file(index)
        LR_p, HR_p = self._patching(HR_)
        return LR_p, HR_p

    def __len__(self):
        return len(self.imageHrfilenames)

Suppose the batch size is 1: it takes one image and gives an output of size [x, 3, patchsize, patchsize]. When the batch size is 2, I will have two different outputs of size [x, 3, patchsize, patchsize] (for example, image 1 may give [50, 3, patchsize, patchsize] and image 2 may give [75, 3, patchsize, patchsize]). To handle this, a custom collate function was required that stacks these two outputs along dimension 0. The collate function is given below:

def my_collate(batch):
    data = torch.cat([item[0] for item in batch],dim = 0)
    target = torch.cat([item[1] for item in batch],dim = 0)

    return [data, target]

This collate function concatenates along dimension 0 (from the above example, I finally get [125, 3, patchsize, patchsize]). For training purposes, I need to train the model using a minibatch size of, say, 25. Is there any method or function I can use to get an output of size [25, 3, patchsize, patchsize] directly from the dataloader, using the necessary number of images as input to the DataLoader?
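For reference, here is a minimal sketch of how such a collate function is plugged into a standard DataLoader (the dataset directory, patch size, stride and batch size below are placeholder values, not my actual settings); each batch it yields then has a different size along dimension 0:

from torch.utils.data import DataLoader

# Placeholder arguments -- substitute the real dataset directory and patch settings.
train_set = TrainDatasetFromFolder("path/to/hr_images", patch_size=64,
                                   is_gray=False, stride=32)
loader = DataLoader(train_set, batch_size=2, shuffle=True, collate_fn=my_collate)

for lr_batch, hr_batch in loader:
    # Dimension 0 varies from batch to batch, e.g. 125 in the example above.
    print(lr_batch.shape, hr_batch.shape)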

Recommended Answer

The following code snippet works for your purpose.

First, we define a ToyDataset which takes in a list of tensors (tensors) of variable length in dimension 0. This is similar to the samples returned by your dataset.

import torch
from torch.utils.data import Dataset
from torch.utils.data.sampler import RandomSampler

class ToyDataset(Dataset):
    def __init__(self, tensors):
        self.tensors = tensors

    def __getitem__(self, index):
        return self.tensors[index]

    def __len__(self):
        return len(self.tensors)

Secondly, we define a custom data loader. The usual PyTorch division of labour between datasets and data loaders is roughly the following: there is an indexed dataset, to which you can pass an index and which returns the associated sample. There is a sampler which yields indices; different strategies for drawing indices give rise to different samplers. The sampler is used by a batch_sampler to draw multiple indices at once (as many as specified by batch_size). Finally, there is a dataloader which combines sampler and dataset to let you iterate over the dataset; importantly, the data loader also owns a function (collate_fn) which specifies how the multiple samples retrieved with the indices from the batch_sampler should be combined. For your use case, this usual scheme does not work well, because instead of drawing a fixed number of indices, we need to keep drawing indices until the objects associated with them reach the cumulative size we desire. This means we need to inspect the objects immediately and use this knowledge to decide whether to return a batch or keep drawing indices. This is what the custom data loader below does:

class CustomLoader(object):

    def __init__(self, dataset, my_bsz, drop_last=True):
        self.ds = dataset
        self.my_bsz = my_bsz
        self.drop_last = drop_last
        self.sampler = RandomSampler(dataset)

    def __iter__(self):
        batch = torch.Tensor()
        for idx in self.sampler:
            # Append the drawn sample's rows to the rows kept from earlier draws.
            batch = torch.cat([batch, self.ds[idx]])
            # A single sample may exceed my_bsz, so keep slicing while possible.
            while batch.size(0) >= self.my_bsz:
                if batch.size(0) == self.my_bsz:
                    yield batch
                    batch = torch.Tensor()
                else:
                    # Yield exactly my_bsz rows and keep the remainder for the next batch.
                    return_batch, batch = batch.split([self.my_bsz, batch.size(0) - self.my_bsz])
                    yield return_batch
        if batch.size(0) > 0 and not self.drop_last:
            yield batch

Here we iterate over the dataset; after drawing an index and loading the associated object, we concatenate it to the tensor built from earlier draws (batch). We keep doing this until we reach the desired size, so that we can cut out a batch and yield it, retaining the rows in batch that we did not yield. Because a single instance may exceed the desired batch_size, we use a while loop.

You could modify this minimal CustomLoader to add more features in the style of PyTorch's DataLoader. There is also no need to use a RandomSampler to draw the indices; other samplers would work equally well. In case your data is large, it would also be possible to avoid the repeated concatenations, for example by collecting samples in a list and keeping track of the cumulative length of its tensors; a sketch of that variant follows.
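A minimal sketch of that list-based variant, with the same interface as above (the class name ListCustomLoader is made up for illustration); each yielded batch triggers only one concatenation instead of one per drawn sample:

class ListCustomLoader(object):

    def __init__(self, dataset, my_bsz, drop_last=True):
        self.ds = dataset
        self.my_bsz = my_bsz
        self.drop_last = drop_last
        self.sampler = RandomSampler(dataset)

    def __iter__(self):
        pieces, total = [], 0
        for idx in self.sampler:
            sample = self.ds[idx]
            pieces.append(sample)
            total += sample.size(0)
            # Concatenate only once enough rows have accumulated for a full batch.
            while total >= self.my_bsz:
                batch = torch.cat(pieces)
                out, rest = batch[:self.my_bsz], batch[self.my_bsz:]
                yield out
                pieces = [rest] if rest.size(0) > 0 else []
                total = rest.size(0)
        if total > 0 and not self.drop_last:
            yield torch.cat(pieces)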

Here is an example that demonstrates that the CustomLoader works:

patch_size = 5
channels = 3
dim0sizes = torch.LongTensor(100).random_(1, 100)
data = torch.randn(size=(dim0sizes.sum(), channels, patch_size, patch_size))
tensors = torch.split(data, dim0sizes.tolist())

ds = ToyDataset(tensors)
dl = CustomLoader(ds, my_bsz=250, drop_last=False)
for i in dl:
    print(i.size(0))
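With my_bsz=250 and drop_last=False, every printed size is 250 except possibly the last one, which holds whatever rows are left over.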
