如何将 Pytorch Dataloader 转换为 numpy 数组以使用 matplotlib 显示图像数据? [英] How do I turn a Pytorch Dataloader into a numpy array to display image data with matplotlib?

查看:160
本文介绍了如何将 Pytorch Dataloader 转换为 numpy 数组以使用 matplotlib 显示图像数据?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我是 Pytorch 的新手.在开始使用 CNN 进行训练之前,我一直在尝试学习如何查看输入图像.我很难将图像更改为可与 matplotlib 一起使用的形式.

I am new to Pytorch. I have been trying to learn how to view my input images before I begin training on my CNN. I am having a very hard time changing the images into a form that can be used with matplotlib.

到目前为止,我已经尝试过:

So far I have tried this:

from multiprocessing import freeze_support

import torch
from torch import nn
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader, Sampler
from torchvision import datasets
from torchvision.transforms import transforms
from torch.optim import Adam

import matplotlib.pyplot as plt
import numpy as np
import PIL

num_classes = 5
batch_size = 100
num_of_workers = 5

DATA_PATH_TRAIN = 'C:\UsersAeryesPycharmProjectssimplecnnimages\train'
DATA_PATH_TEST = 'C:\UsersAeryesPycharmProjectssimplecnnimages\test'

trans = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToPImage(),
    transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))
    ])

train_dataset = datasets.ImageFolder(root=DATA_PATH_TRAIN, transform=trans)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_of_workers)

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    print(npimg)
    plt.imshow(np.transpose(npimg, (1, 2, 0, 1)))

def main():
    # get some random training images
    dataiter = iter(train_loader)
    images, labels = dataiter.next()

    # show images
    imshow(images)
    # print labels
    print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

if __name__ == "__main__":
    main()

然而,这会引发错误:

  [[0.27058825 0.18431371 0.31764707 ... 0.18823528 0.3882353
    0.27450982]
   [0.23137254 0.11372548 0.24313724 ... 0.16862744 0.14117646
    0.40784314]
   [0.25490198 0.19607842 0.30588236 ... 0.27450982 0.25882354
    0.34509805]
   ...
   [0.2784314  0.21960783 0.2352941  ... 0.5803922  0.46666667
    0.25882354]
   [0.26666668 0.16862744 0.23137254 ... 0.2901961  0.29803923
    0.2509804 ]
   [0.30980393 0.39607844 0.28627452 ... 0.1490196  0.10588235
    0.19607842]]

  [[0.2352941  0.06274509 0.15686274 ... 0.09411764 0.3019608
    0.19215685]
   [0.22745097 0.07843137 0.12549019 ... 0.07843137 0.10588235
    0.3019608 ]
   [0.20392156 0.13333333 0.1607843  ... 0.16862744 0.2117647
    0.22745097]
   ...
   [0.18039215 0.16862744 0.1490196  ... 0.45882353 0.36078432
    0.16470587]
   [0.1607843  0.10588235 0.14117646 ... 0.2117647  0.18039215
    0.10980392]
   [0.18039215 0.3019608  0.2117647  ... 0.11372548 0.06274509
    0.04705882]]]


 ...


 [[[0.8980392  0.8784314  0.8509804  ... 0.627451   0.627451
    0.627451  ]
   [0.8509804  0.8235294  0.7921569  ... 0.54901963 0.5568628
    0.56078434]
   [0.7921569  0.7529412  0.7176471  ... 0.47058824 0.48235294
    0.49411765]
   ...
   [0.3764706  0.38431373 0.3764706  ... 0.4509804  0.43137255
    0.39607844]
   [0.38431373 0.39607844 0.3882353  ... 0.4509804  0.43137255
    0.39607844]
   [0.3882353  0.4        0.39607844 ... 0.44313726 0.42352942
    0.39215687]]

  [[0.9254902  0.90588236 0.88235295 ... 0.60784316 0.6
    0.5921569 ]
   [0.88235295 0.85490197 0.8235294  ... 0.5411765  0.5372549
    0.53333336]
   [0.8235294  0.7882353  0.75686276 ... 0.47058824 0.47058824
    0.47058824]
   ...
   [0.50980395 0.5176471  0.5137255  ... 0.58431375 0.5647059
    0.53333336]
   [0.5137255  0.53333336 0.5254902  ... 0.58431375 0.5686275
    0.53333336]
   [0.5176471  0.53333336 0.5294118  ... 0.5764706  0.56078434
    0.5294118 ]]

  [[0.95686275 0.9372549  0.90588236 ... 0.18823528 0.19999999
    0.20784312]
   [0.9098039  0.8784314  0.8352941  ... 0.1607843  0.17254901
    0.18039215]
   [0.84313726 0.7921569  0.7490196  ... 0.1372549  0.14509803
    0.15294117]
   ...
   [0.03921568 0.05490196 0.05098039 ... 0.11764705 0.09411764
    0.02745098]
   [0.04705882 0.07843137 0.06666666 ... 0.12156862 0.10196078
    0.03529412]
   [0.05098039 0.0745098  0.07843137 ... 0.12549019 0.10196078
    0.04705882]]]


 [[[0.30588236 0.28627452 0.24313724 ... 0.2901961  0.26666668
    0.21568626]
   [0.8156863  0.6666667  0.5921569  ... 0.18039215 0.23921567
    0.21568626]
   [0.9019608  0.83137256 0.85490197 ... 0.21960783 0.36862746
    0.23921567]
   ...
   [0.7058824  0.83137256 0.85490197 ... 0.2627451  0.24313724
    0.20784312]
   [0.7137255  0.84313726 0.84705883 ... 0.26666668 0.29803923
    0.21568626]
   [0.7254902  0.8235294  0.8392157  ... 0.2509804  0.27058825
    0.2352941 ]]

  [[0.24705881 0.22745097 0.19215685 ... 0.2784314  0.25490198
    0.19607842]
   [0.59607846 0.37254903 0.29803923 ... 0.16470587 0.22745097
    0.20392156]
   [0.5921569  0.4509804  0.49803922 ... 0.20784312 0.3764706
    0.2352941 ]
   ...
   [0.42352942 0.4627451  0.42352942 ... 0.23921567 0.23137254
    0.19999999]
   [0.45882353 0.5176471  0.35686275 ... 0.23921567 0.26666668
    0.19607842]
   [0.41568628 0.44313726 0.34901962 ... 0.21960783 0.23921567
    0.21568626]]

  [[0.23137254 0.20784312 0.1490196  ... 0.30588236 0.28627452
    0.19607842]
   [0.61960787 0.3764706  0.26666668 ... 0.16470587 0.24313724
    0.21568626]
   [0.57254905 0.43137255 0.48235294 ... 0.2235294  0.40392157
    0.25882354]
   ...
   [0.4        0.42352942 0.37254903 ... 0.25490198 0.24705881
    0.21568626]
   [0.43137255 0.4509804  0.29411766 ... 0.25882354 0.28235295
    0.20392156]
   [0.38431373 0.3529412  0.25490198 ... 0.2352941  0.25490198
    0.23137254]]]


 [[[0.06274509 0.09019607 0.11372548 ... 0.5803922  0.5176471
    0.59607846]
   [0.09411764 0.14509803 0.1372549  ... 0.5294118  0.49803922
    0.5058824 ]
   [0.04705882 0.09411764 0.10196078 ... 0.45882353 0.42352942
    0.38431373]
   ...
   [0.15294117 0.12941176 0.1607843  ... 0.85882354 0.8509804
    0.80784315]
   [0.14509803 0.10588235 0.1607843  ... 0.8666667  0.85882354
    0.8       ]
   [0.1490196  0.10588235 0.16470587 ... 0.827451   0.8156863
    0.7921569 ]]

  [[0.06666666 0.12156862 0.17647058 ... 0.59607846 0.5529412
    0.6039216 ]
   [0.07058823 0.10588235 0.11764705 ... 0.56078434 0.5254902
    0.5372549 ]
   [0.03921568 0.0745098  0.09803921 ... 0.48235294 0.4392157
    0.4117647 ]
   ...
   [0.2117647  0.14509803 0.2784314  ... 0.43137255 0.3529412
    0.34117648]
   [0.2235294  0.11372548 0.2509804  ... 0.4509804  0.39607844
    0.2509804 ]
   [0.25490198 0.12156862 0.24705881 ... 0.38039216 0.36078432
    0.3254902 ]]

  [[0.05490196 0.09803921 0.12549019 ... 0.46666667 0.38039216
    0.45490196]
   [0.06274509 0.09803921 0.10196078 ... 0.44705883 0.41568628
    0.3882353 ]
   [0.03921568 0.06666666 0.0862745  ... 0.3764706  0.33333334
    0.28235295]
   ...
   [0.12156862 0.14509803 0.16862744 ... 0.15686274 0.0745098
    0.09411764]
   [0.10588235 0.11372548 0.16862744 ... 0.25882354 0.18431371
    0.05490196]
   [0.12156862 0.11372548 0.17254901 ... 0.2352941  0.17254901
    0.14117646]]]]
Traceback (most recent call last):
  File "image_loader.py", line 51, in <module>
    main()
  File "image_loader.py", line 46, in main
    imshow(images)
  File "image_loader.py", line 38, in imshow
    plt.imshow(np.transpose(npimg, (1, 2, 0, 1)))
  File "C:UsersAeryesAppDataLocalProgramsPythonPython36libsite-packages
umpycorefromnumeric.py", line 598, in transpose
    return _wrapfunc(a, 'transpose', axes)
  File "C:UsersAeryesAppDataLocalProgramsPythonPython36libsite-packages
umpycorefromnumeric.py", line 51, in _wrapfunc
    return getattr(obj, method)(*args, **kwds)
ValueError: repeated axis in transpose

我试图打印出数组以获取维度,但我不知道该怎么做.这很令人困惑.

I tried to print out the arrays to get the dimensions but I do not know what to make of this. It is very confusing.

这是我的直接问题:如何在使用 DataLoader 对象中的张量进行训练之前查看输入图像?

Here is my direct question: How do I view the input images before training using the tensors in my DataLoader object?

推荐答案

首先,dataloader输出4维张量——[batch, channel, height, width].Matplotlib 和其他图像处理库经常需要[height, width, channel].您使用转置是正确的,只是方式不对.

First of all, dataloader output 4 dimensional tensor - [batch, channel, height, width]. Matplotlib and other image processing libraries often requires [height, width, channel]. You are right about using the transpose, just not in the right way.

在您的 images 中会有很多图片,所以首先您需要选择一个(或编写一个 for 循环来保存所有图片).这将是简单的 images[i],通常我使用 i=0.

There will be a lot of images in your images so first you need to pick one (or write a for loop to save all of them). This will be simply images[i], typically I use i=0.

然后,您的转置应该将现在的 [channel, height, width] 张量转换为 [height, width, channel] 张量.为此,请使用 np.transpose(image.numpy(), (1, 2, 0)),与您的非常相似.

Then, your transpose should convert a now [channel, height, width] tensor to a [height, width, channel] one. To do this, use np.transpose(image.numpy(), (1, 2, 0)), very much like yours.

把它们放在一起,你应该有

Putting them together, you should have

plt.imshow(np.transpose(images[0].numpy(), (1, 2, 0)))

有时您需要调用 .detach()(将这部分从计算图中分离)和 .cpu()(将数据从 GPU 传输到 CPU),具体取决于用例,这将是

Sometimes you need to call .detach() (detach this part from the computational graph) and .cpu() (transfer data from GPU to CPU) depending on the use case, that will be

plt.imshow(np.transpose(images[0].cpu().detach().numpy(), (1, 2, 0)))

这篇关于如何将 Pytorch Dataloader 转换为 numpy 数组以使用 matplotlib 显示图像数据?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆