Batches of points with the same label on Pytorch

Question

I want to train a neural network using gradient descent on batches that contain N training points each. I would like these batches to only contain points with the same label, instead of being randomly sampled from the training set.

For example, if I'm training using MNIST, I would like to have batches that look like the following:

batch_1 = {0,0,0,0,0,0,0,0}

batch_2 = {3,3,3,3,3,3,3,3}

batch_3 = {7,7,7,7,7,7,7,7}

...

and so on.

How can I do it using pytorch?

Answer

One way to do it is to create a subset and a dataloader for each class, and then iterate by randomly switching between those dataloaders:

import torch
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import MNIST
from torchvision import transforms
import numpy as np

dataset = MNIST('path/to/mnist_root/',
                transform=transforms.ToTensor(),
                download=True)

# Indices of the samples belonging to each class
class_inds = [torch.where(dataset.targets == class_idx)[0]
              for class_idx in dataset.class_to_idx.values()]

# One dataloader per class, each drawing only from that class's samples
dataloaders = [
    DataLoader(
        dataset=Subset(dataset, inds),
        batch_size=8,
        shuffle=True,
        drop_last=False)
    for inds in class_inds]

epochs = 1

for epoch in range(epochs):
    iterators = list(map(iter, dataloaders))
    while iterators:
        # Pick one of the remaining per-class iterators at random
        iterator = np.random.choice(iterators)
        try:
            images, labels = next(iterator)
            print(labels)
            # do_more_stuff()

        except StopIteration:
            # This class is exhausted for the epoch; stop drawing from it
            iterators.remove(iterator)
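
The do_more_stuff() placeholder is where the actual gradient-descent step would go. As a rough sketch (not part of the original answer), a hypothetical training step on a single-label batch could look like the following; the model, criterion and optimizer here are illustrative stand-ins for whatever classifier is actually being trained:

# Hypothetical model/criterion/optimizer, for illustration only
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

def train_step(images, labels):
    # One gradient-descent update on a batch that shares a single label
    optimizer.zero_grad()
    loss = criterion(model(images), labels)
    loss.backward()
    optimizer.step()
    return loss.item()

Calling train_step(images, labels) in place of the do_more_stuff() placeholder inside the try block performs one update per single-label batch.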

This will work with any dataset (not just MNIST). Here's the result of printing the labels at each iteration:

tensor([6, 6, 6, 6, 6, 6, 6, 6])
tensor([3, 3, 3, 3, 3, 3, 3, 3])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([5, 5, 5, 5, 5, 5, 5, 5])
tensor([8, 8, 8, 8, 8, 8, 8, 8])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
...
tensor([1, 1, 1, 1, 1, 1, 1, 1])
tensor([1, 1, 1, 1, 1, 1])

Note that with drop_last=False, the last batch drawn from each class may contain fewer than batch_size elements. Setting it to True makes every batch exactly batch_size, but drops the leftover data points.
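
The code above relies on the targets and class_to_idx attributes that torchvision's MNIST exposes. For a dataset that does not provide them, the per-class index lists can be built with a single pass over the labels instead. Below is a rough sketch under that assumption; per_class_loaders is a hypothetical helper (not part of the original answer) that assumes dataset[i] returns an (input, label) pair:

from collections import defaultdict

from torch.utils.data import DataLoader, Subset

def per_class_loaders(dataset, batch_size=8):
    # Group sample indices by label with one pass over the dataset
    label_to_inds = defaultdict(list)
    for i in range(len(dataset)):
        _, label = dataset[i]
        label_to_inds[int(label)].append(i)

    # One dataloader per label, mirroring the construction above
    return [
        DataLoader(Subset(dataset, inds), batch_size=batch_size,
                   shuffle=True, drop_last=False)
        for inds in label_to_inds.values()]

The iteration loop stays exactly the same; only the way the dataloaders list is built changes.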
