How do I turn a Pytorch Dataloader into a numpy array to display image data with matplotlib?
Problem description
I am new to Pytorch. I have been trying to learn how to view my input images before I begin training on my CNN. I am having a very hard time changing the images into a form that can be used with matplotlib.
So far I have tried this:
```python
from multiprocessing import freeze_support
import torch
from torch import nn
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader, Sampler
from torchvision import datasets
from torchvision.transforms import transforms
from torch.optim import Adam
import matplotlib.pyplot as plt
import numpy as np
import PIL

num_classes = 5
batch_size = 100
num_of_workers = 5

DATA_PATH_TRAIN = r'C:\Users\Aeryes\PycharmProjects\simplecnn\images\train'
DATA_PATH_TEST = r'C:\Users\Aeryes\PycharmProjects\simplecnn\images\test'

trans = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = datasets.ImageFolder(root=DATA_PATH_TRAIN, transform=trans)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_of_workers)
classes = train_dataset.classes  # class names that ImageFolder takes from the sub-folder names

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    print(npimg)
    plt.imshow(np.transpose(npimg, (1, 2, 0, 1)))

def main():
    # get some random training images
    dataiter = iter(train_loader)
    images, labels = next(dataiter)
    # show images
    imshow(images)
    # print labels
    print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

if __name__ == "__main__":
    main()
```
However, this prints the array below and then throws an error:
```
[[0.27058825 0.18431371 0.31764707 ... 0.18823528 0.3882353
0.27450982]
[0.23137254 0.11372548 0.24313724 ... 0.16862744 0.14117646
0.40784314]
[0.25490198 0.19607842 0.30588236 ... 0.27450982 0.25882354
0.34509805]
...
[0.2784314 0.21960783 0.2352941 ... 0.5803922 0.46666667
0.25882354]
[0.26666668 0.16862744 0.23137254 ... 0.2901961 0.29803923
0.2509804 ]
[0.30980393 0.39607844 0.28627452 ... 0.1490196 0.10588235
0.19607842]]
[[0.2352941 0.06274509 0.15686274 ... 0.09411764 0.3019608
0.19215685]
[0.22745097 0.07843137 0.12549019 ... 0.07843137 0.10588235
0.3019608 ]
[0.20392156 0.13333333 0.1607843 ... 0.16862744 0.2117647
0.22745097]
...
[0.18039215 0.16862744 0.1490196 ... 0.45882353 0.36078432
0.16470587]
[0.1607843 0.10588235 0.14117646 ... 0.2117647 0.18039215
0.10980392]
[0.18039215 0.3019608 0.2117647 ... 0.11372548 0.06274509
0.04705882]]]
...
[[[0.8980392 0.8784314 0.8509804 ... 0.627451 0.627451
0.627451 ]
[0.8509804 0.8235294 0.7921569 ... 0.54901963 0.5568628
0.56078434]
[0.7921569 0.7529412 0.7176471 ... 0.47058824 0.48235294
0.49411765]
...
[0.3764706 0.38431373 0.3764706 ... 0.4509804 0.43137255
0.39607844]
[0.38431373 0.39607844 0.3882353 ... 0.4509804 0.43137255
0.39607844]
[0.3882353 0.4 0.39607844 ... 0.44313726 0.42352942
0.39215687]]
[[0.9254902 0.90588236 0.88235295 ... 0.60784316 0.6
0.5921569 ]
[0.88235295 0.85490197 0.8235294 ... 0.5411765 0.5372549
0.53333336]
[0.8235294 0.7882353 0.75686276 ... 0.47058824 0.47058824
0.47058824]
...
[0.50980395 0.5176471 0.5137255 ... 0.58431375 0.5647059
0.53333336]
[0.5137255 0.53333336 0.5254902 ... 0.58431375 0.5686275
0.53333336]
[0.5176471 0.53333336 0.5294118 ... 0.5764706 0.56078434
0.5294118 ]]
[[0.95686275 0.9372549 0.90588236 ... 0.18823528 0.19999999
0.20784312]
[0.9098039 0.8784314 0.8352941 ... 0.1607843 0.17254901
0.18039215]
[0.84313726 0.7921569 0.7490196 ... 0.1372549 0.14509803
0.15294117]
...
[0.03921568 0.05490196 0.05098039 ... 0.11764705 0.09411764
0.02745098]
[0.04705882 0.07843137 0.06666666 ... 0.12156862 0.10196078
0.03529412]
[0.05098039 0.0745098 0.07843137 ... 0.12549019 0.10196078
0.04705882]]]
[[[0.30588236 0.28627452 0.24313724 ... 0.2901961 0.26666668
0.21568626]
[0.8156863 0.6666667 0.5921569 ... 0.18039215 0.23921567
0.21568626]
[0.9019608 0.83137256 0.85490197 ... 0.21960783 0.36862746
0.23921567]
...
[0.7058824 0.83137256 0.85490197 ... 0.2627451 0.24313724
0.20784312]
[0.7137255 0.84313726 0.84705883 ... 0.26666668 0.29803923
0.21568626]
[0.7254902 0.8235294 0.8392157 ... 0.2509804 0.27058825
0.2352941 ]]
[[0.24705881 0.22745097 0.19215685 ... 0.2784314 0.25490198
0.19607842]
[0.59607846 0.37254903 0.29803923 ... 0.16470587 0.22745097
0.20392156]
[0.5921569 0.4509804 0.49803922 ... 0.20784312 0.3764706
0.2352941 ]
...
[0.42352942 0.4627451 0.42352942 ... 0.23921567 0.23137254
0.19999999]
[0.45882353 0.5176471 0.35686275 ... 0.23921567 0.26666668
0.19607842]
[0.41568628 0.44313726 0.34901962 ... 0.21960783 0.23921567
0.21568626]]
[[0.23137254 0.20784312 0.1490196 ... 0.30588236 0.28627452
0.19607842]
[0.61960787 0.3764706 0.26666668 ... 0.16470587 0.24313724
0.21568626]
[0.57254905 0.43137255 0.48235294 ... 0.2235294 0.40392157
0.25882354]
...
[0.4 0.42352942 0.37254903 ... 0.25490198 0.24705881
0.21568626]
[0.43137255 0.4509804 0.29411766 ... 0.25882354 0.28235295
0.20392156]
[0.38431373 0.3529412 0.25490198 ... 0.2352941 0.25490198
0.23137254]]]
[[[0.06274509 0.09019607 0.11372548 ... 0.5803922 0.5176471
0.59607846]
[0.09411764 0.14509803 0.1372549 ... 0.5294118 0.49803922
0.5058824 ]
[0.04705882 0.09411764 0.10196078 ... 0.45882353 0.42352942
0.38431373]
...
[0.15294117 0.12941176 0.1607843 ... 0.85882354 0.8509804
0.80784315]
[0.14509803 0.10588235 0.1607843 ... 0.8666667 0.85882354
0.8 ]
[0.1490196 0.10588235 0.16470587 ... 0.827451 0.8156863
0.7921569 ]]
[[0.06666666 0.12156862 0.17647058 ... 0.59607846 0.5529412
0.6039216 ]
[0.07058823 0.10588235 0.11764705 ... 0.56078434 0.5254902
0.5372549 ]
[0.03921568 0.0745098 0.09803921 ... 0.48235294 0.4392157
0.4117647 ]
...
[0.2117647 0.14509803 0.2784314 ... 0.43137255 0.3529412
0.34117648]
[0.2235294 0.11372548 0.2509804 ... 0.4509804 0.39607844
0.2509804 ]
[0.25490198 0.12156862 0.24705881 ... 0.38039216 0.36078432
0.3254902 ]]
[[0.05490196 0.09803921 0.12549019 ... 0.46666667 0.38039216
0.45490196]
[0.06274509 0.09803921 0.10196078 ... 0.44705883 0.41568628
0.3882353 ]
[0.03921568 0.06666666 0.0862745 ... 0.3764706 0.33333334
0.28235295]
...
[0.12156862 0.14509803 0.16862744 ... 0.15686274 0.0745098
0.09411764]
[0.10588235 0.11372548 0.16862744 ... 0.25882354 0.18431371
0.05490196]
[0.12156862 0.11372548 0.17254901 ... 0.2352941 0.17254901
0.14117646]]]]
Traceback (most recent call last):
  File "image_loader.py", line 51, in <module>
    main()
  File "image_loader.py", line 46, in main
    imshow(images)
  File "image_loader.py", line 38, in imshow
    plt.imshow(np.transpose(npimg, (1, 2, 0, 1)))
  File "C:\Users\Aeryes\AppData\Local\Programs\Python\Python36\lib\site-packages\numpy\core\fromnumeric.py", line 598, in transpose
    return _wrapfunc(a, 'transpose', axes)
  File "C:\Users\Aeryes\AppData\Local\Programs\Python\Python36\lib\site-packages\numpy\core\fromnumeric.py", line 51, in _wrapfunc
    return getattr(obj, method)(*args, **kwds)
ValueError: repeated axis in transpose
```
I tried to print out the arrays to get the dimensions but I do not know what to make of this. It is very confusing.
Here is my direct question: How do I view the input images before training using the tensors in my DataLoader object?
Recommended answer
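First of all, the dataloader outputs a 4-dimensional tensor: `[batch, channel, height, width]`. Matplotlib and other image-processing libraries often require `[height, width, channel]`. You are right about using the transpose, just not in the right way: the axes argument of `np.transpose` must list every axis of the array exactly once, so `(1, 2, 0, 1)`, which repeats axis 1, raises `ValueError: repeated axis in transpose`.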
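Printing `.shape` instead of every value makes the problem visible at a glance. A minimal NumPy-only sketch, assuming a random array as a hypothetical stand-in for `images.numpy()` with the batch shape the question's settings would produce (`batch_size=100`, 3 channels, 32×32 after `Resize`/`CenterCrop`):

```python
import numpy as np

# Hypothetical stand-in for images.numpy(): a batch shaped like the question's
# DataLoader output (batch_size=100, 3 channels, 32x32 after Resize/CenterCrop).
batch = np.random.rand(100, 3, 32, 32).astype(np.float32)
print(batch.shape)  # (100, 3, 32, 32), i.e. [batch, channel, height, width]

# The question's axes tuple lists axis 1 twice, so it is not a permutation.
try:
    np.transpose(batch, (1, 2, 0, 1))
    error_message = None
except ValueError as err:
    error_message = str(err)
print(error_message)

# A valid permutation of all four axes moves the channel axis last for the
# whole batch: [batch, height, width, channel].
channels_last = np.transpose(batch, (0, 2, 3, 1))
print(channels_last.shape)  # (100, 32, 32, 3)
```

Even with a valid permutation the result is still a 4-D batch, and `plt.imshow` expects a single 2-D or 3-D image, so one image still has to be selected from it.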
There will be a lot of images in your `images` batch, so first you need to pick one (or write a for loop to save all of them). This is simply `images[i]`; typically I use `i=0`.
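A minimal sketch of both options, picking one image and looping over a few to save them. It assumes a random array as a hypothetical stand-in for `images.numpy()` with values already in [0, 1] (so no un-normalization is shown), and the file names and temporary directory are made up for illustration:

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # render to files; no display window needed
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-in for images.numpy(), values already in [0, 1].
batch = np.random.rand(100, 3, 32, 32).astype(np.float32)

# Integer indexing drops the batch axis: one [channel, height, width] image.
first = batch[0]
print(first.shape)  # (3, 32, 32)

# Loop over the first few images and save each one as a PNG file.
out_dir = tempfile.mkdtemp()
saved = []
for i in range(4):
    hwc = np.transpose(batch[i], (1, 2, 0))  # [height, width, channel]
    path = os.path.join(out_dir, f"image_{i}.png")
    plt.imsave(path, hwc)
    saved.append(path)
print(saved)
```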
Then, your transpose should convert the now `[channel, height, width]` tensor to a `[height, width, channel]` one. To do this, use `np.transpose(image.numpy(), (1, 2, 0))`, where `image` is `images[i]`, very much like yours.
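To check that `(1, 2, 0)` only reorders the axes and does not scramble pixels, here is a small NumPy sketch using a made-up 3×4×5 array whose values are all distinct:

```python
import numpy as np

# Hypothetical single image in PyTorch layout: [channel, height, width].
chw = np.arange(3 * 4 * 5, dtype=np.float32).reshape(3, 4, 5)

# (1, 2, 0): new axis 0 <- old axis 1 (height), new axis 1 <- old axis 2 (width),
# new axis 2 <- old axis 0 (channel).
hwc = np.transpose(chw, (1, 2, 0))
print(hwc.shape)  # (4, 5, 3)

# The pixel at channel 1, row 2, column 3 is found at row 2, column 3, channel 1.
print(hwc[2, 3, 1] == chw[1, 2, 3])  # True

# np.moveaxis spells the same reordering as "move axis 0 to the end".
same = np.moveaxis(chw, 0, -1)
print(np.array_equal(hwc, same))  # True
```

`np.moveaxis(img, 0, -1)` is an equivalent spelling that some find easier to read than an explicit permutation tuple.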
Putting them together, you should have
```python
plt.imshow(np.transpose(images[0].numpy(), (1, 2, 0)))
```
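That one-liner skips the un-normalization that the question's own `imshow` performs. A fuller sketch, assuming the question's `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` and a random array as a hypothetical stand-in for `images.numpy()`; `to_displayable` is an illustrative helper name, and the figure is written to a temporary file so the sketch also runs without a display:

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # off-screen rendering; interactively, use plt.show() instead
import matplotlib.pyplot as plt
import numpy as np


def to_displayable(chw):
    """Turn one normalized [channel, height, width] image into [height, width, channel] in [0, 1]."""
    # Undo Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), which computed (x - 0.5) / 0.5.
    img = chw / 2 + 0.5
    # imshow clips float RGB data to [0, 1] anyway (with a warning); clip explicitly.
    img = np.clip(img, 0.0, 1.0)
    return np.transpose(img, (1, 2, 0))


# Hypothetical stand-in for images.numpy() after Normalize: values in [-1, 1].
batch = np.random.uniform(-1.0, 1.0, size=(100, 3, 32, 32)).astype(np.float32)

hwc = to_displayable(batch[0])
plt.imshow(hwc)
plt.axis("off")
out_path = os.path.join(tempfile.mkdtemp(), "first_image.png")
plt.savefig(out_path)
```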
Sometimes you need to call `.detach()` (to detach this part from the computational graph) and `.cpu()` (to transfer the data from the GPU to the CPU), depending on the use case. That would be:
```python
plt.imshow(np.transpose(images[0].cpu().detach().numpy(), (1, 2, 0)))
```
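A short sketch of where each call matters, assuming PyTorch is installed. `tensor_to_hwc` is a hypothetical helper name, and the tensor is created with `requires_grad=True` (and on the GPU when one is available) only to trigger both cases; tensors straight from a `DataLoader` usually need neither call. It uses `Tensor.permute(1, 2, 0)`, PyTorch's counterpart to `np.transpose(..., (1, 2, 0))`, before converting:

```python
import torch


def tensor_to_hwc(img):
    """Return a NumPy [height, width, channel] array from a [channel, height, width] tensor."""
    # detach(): .numpy() raises an error on tensors that require grad.
    # cpu(): .numpy() only works on tensors stored in CPU memory.
    return img.detach().cpu().permute(1, 2, 0).numpy()


device = "cuda" if torch.cuda.is_available() else "cpu"
# Hypothetical image that lives on the GPU (if present) and tracks gradients.
img = torch.rand(3, 32, 32, device=device, requires_grad=True)

hwc = tensor_to_hwc(img)
print(type(hwc), hwc.shape)
```

`.detach().cpu()` and `.cpu().detach()` give the same array here; detaching first simply avoids recording the device copy in the autograd graph.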