How to use torchvision.transforms for data augmentation of a segmentation task in PyTorch?


Problem description

I am a little bit confused about the data augmentation performed in PyTorch.

Because we are dealing with a segmentation task, the image and its mask need the same data augmentation, but some transforms are random, such as random rotation.

Keras provides a random seed to guarantee that data and mask undergo the same operations, as shown in the following code:

data_gen_args = dict(featurewise_center=True,
                     featurewise_std_normalization=True,
                     rotation_range=25,
                     horizontal_flip=True,
                     vertical_flip=True)

image_datagen = ImageDataGenerator(**data_gen_args)
mask_datagen = ImageDataGenerator(**data_gen_args)

seed = 1
image_generator = image_datagen.flow(train_data, seed=seed, batch_size=1)
mask_generator = mask_datagen.flow(train_label, seed=seed, batch_size=1)

train_generator = zip(image_generator, mask_generator)

I didn't find a similar description in the official PyTorch documentation, so I don't know how to ensure that data and mask can be processed synchronously.

PyTorch does provide such a function, but I want to apply it to a custom Dataloader.

For example:

def __getitem__(self, index):
    img = np.zeros((self.im_ht, self.im_wd, channel_size))
    mask = np.zeros((self.im_ht, self.im_wd, channel_size))

    temp_img = np.load(Image_path + '{:0>4}'.format(self.patient_index[index]) + '.npy')
    temp_label = np.load(Label_path + '{:0>4}'.format(self.patient_index[index]) + '.npy')

    for i in range(channel_size):
        img[:,:,i] = temp_img[self.count[index] + i]
        mask[:,:,i] = temp_label[self.count[index] + i]

    if self.transforms:
        img = np.uint8(img)
        mask = np.uint8(mask)
        img = self.transforms(img)
        mask = self.transforms(mask)

    return img, mask

In this case, img and mask are transformed separately. Because operations such as random rotation draw their own random parameters on each call, the correspondence between mask and image may be broken; in other words, the image may have been rotated while the mask was not.

I used the method in augmentations.py, but I got an error:

Traceback (most recent call last):
  File "test_transform.py", line 87, in <module>
    for batch_idx, image, mask in enumerate(train_loader):
  File "/home/dirk/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 314, in __next__
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "/home/dirk/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 314, in <listcomp>
    batch = self.collate_fn([self.dataset[i] for i in indices])
  File "/home/dirk/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/utils/data/dataset.py", line 103, in __getitem__
    return self.dataset[self.indices[idx]]
  File "/home/dirk/home/data/dirk/segmentation_unet_pytorch/data.py", line 164, in __getitem__
    img, mask = self.transforms(img, mask)
  File "/home/dirk/home/data/dirk/segmentation_unet_pytorch/augmentations.py", line 17, in __call__
    img, mask = a(img, mask)
TypeError: __call__() takes 2 positional arguments but 3 were given

Here is my __getitem__() code:

data_transforms = {
    'train': Compose([
        RandomHorizontallyFlip(),
        RandomRotate(degree=25),
        transforms.ToTensor()
    ]),
}

train_set = DatasetUnetForTestTransform(fold=args.fold, random_index=args.random_index,transforms=data_transforms['train'])

# __getitem__ in class DatasetUnetForTestTransform
def __getitem__(self, index):
    img = np.zeros((self.im_ht, self.im_wd, channel_size))
    mask = np.zeros((self.im_ht, self.im_wd, channel_size))
    temp_img = np.load(Label_path + '{:0>4}'.format(self.patient_index[index]) + '.npy')
    temp_label = np.load(Label_path + '{:0>4}'.format(self.patient_index[index]) + '.npy')
    temp_img, temp_label = crop_data_label_from_0(temp_img, temp_label)
    for i in range(channel_size):
        img[:,:,i] = temp_img[self.count[index] + i]
        mask[:,:,i] = temp_label[self.count[index] + i]

    if self.transforms:
        img = T.ToPILImage()(np.uint8(img))
        mask = T.ToPILImage()(np.uint8(mask))
        img, mask = self.transforms(img, mask)

    img = T.ToTensor()(img).copy()
    mask = T.ToTensor()(mask).copy()
    return img, mask
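The TypeError above occurs because transforms.ToTensor() accepts a single image, while the Compose in augmentations.py calls every transform as a(img, mask). A minimal sketch of such a paired Compose (a hypothetical variant, not the actual augmentations.py) makes the convention explicit; every element must accept and return the pair:

```python
class PairedCompose(object):
    """Compose variant whose transforms all take and return (img, mask)."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, mask):
        # Thread the pair through every transform in order.
        for t in self.transforms:
            img, mask = t(img, mask)
        return img, mask
```

Mixing a single-argument transform such as ToTensor() into such a Compose triggers exactly the reported TypeError; tensor conversion has to happen outside it, as the __getitem__ above already does.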

EDIT 2

I found that after ToTensor, the Dice score between identical labels becomes 255 instead of 1. How can I fix it?

# Dice computation
def DSC_computation(label, pred):
    pred_sum = pred.sum()
    label_sum = label.sum()
    inter_sum = np.logical_and(pred, label).sum()
    return 2 * float(inter_sum) / (pred_sum + label_sum)
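A plausible explanation (an assumption, since the full pipeline is not shown): ToTensor() divides uint8 values by 255, so a binary mask with values of 1 becomes 1/255 per pixel. pred_sum and label_sum then shrink by a factor of 255 while np.logical_and still counts whole pixels, making the ratio come out as 255. One fix, sketched below assuming the mask holds integer class ids, is to convert the mask without ToTensor's rescaling (`mask_to_tensor` is a hypothetical helper name):

```python
import numpy as np
import torch

def mask_to_tensor(mask_pil):
    # Convert a PIL label mask to a LongTensor directly, avoiding
    # the /255 rescaling that transforms.ToTensor() applies to uint8.
    return torch.from_numpy(np.array(mask_pil, dtype=np.int64))
```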

Feel free to ask if more code is needed to explain the problem.

Answer

torchvision also provides similar functions [documentation].

Here is a simple example,

import torchvision
from torchvision import transforms

trans = transforms.Compose([transforms.CenterCrop((178, 178)),
                            transforms.Resize(128),
                            transforms.RandomRotation(20),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
dset = torchvision.datasets.MNIST(data_root, transform=trans)


EDIT

A short example of customizing your own CelebA dataset. Note that, to apply the transformations, you need to call the transform list in __getitem__.

import os
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset

class CelebADataset(Dataset):
    def __init__(self, root, transforms=None, num=None):
        super(CelebADataset, self).__init__()

        self.img_root = os.path.join(root, 'img_align_celeba')
        self.attr_root = os.path.join(root, 'Anno/list_attr_celeba.txt')
        self.transforms = transforms

        df = pd.read_csv(self.attr_root, sep='\s+', header=1, index_col=0)
        #print(df.columns.tolist())
        if num is None:
            self.labels = df.values
            self.img_name = df.index.values
        else:
            self.labels = df.values[:num]
            self.img_name = df.index.values[:num]

    def __getitem__(self, index):
        img = Image.open(os.path.join(self.img_root, self.img_name[index]))
        # only use blond_hair, eyeglass, male, smile
        indices = [9, 15, 20, 31]
        label = np.take(self.labels[index], indices)
        label[label==-1] = 0

        if self.transforms is not None:
            img = self.transforms(img)

        return np.asarray(img), label

    def __len__(self):
        return len(self.labels)


编辑 2

乍一看,我可能会错过一些东西.您的问题的重点是如何将相同"的数据预处理应用于 img 和标签.据我了解,没有可用的 Pytorch 内置函数.所以,我之前做的就是自己实现增强.


EDIT 2

I probably miss something at the first glance. The main point of your problem is how to apply "the same" data preprocessing to img and labels. To my understanding, there is no available Pytorch built-in function. So, what I did before is to implement the augmentation by myself.

import random
from PIL import Image

class RandomRotate(object):
    def __init__(self, degree):
        self.degree = degree

    def __call__(self, img, mask):
        # One random draw drives the identical rotation on both.
        rotate_degree = random.random() * 2 * self.degree - self.degree
        return (img.rotate(rotate_degree, Image.BILINEAR),
                mask.rotate(rotate_degree, Image.NEAREST))
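The RandomHorizontallyFlip used in the question's Compose can be written in the same style (one possible implementation, assuming PIL inputs; this is my sketch, not code from the question): a single coin flip drives the identical operation on both image and mask.

```python
import random
from PIL import Image

class RandomHorizontallyFlip(object):
    def __call__(self, img, mask):
        # One random draw decides for both, so they stay aligned.
        if random.random() < 0.5:
            return (img.transpose(Image.FLIP_LEFT_RIGHT),
                    mask.transpose(Image.FLIP_LEFT_RIGHT))
        return img, mask
```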

Note that the input should be in PIL format. See this for more information.
