如何加载和使用预先保留的 PyTorch InceptionV3 模型对图像进行分类 [英] How to load and use a pretained PyTorch InceptionV3 model to classify an image
问题描述
我和 有同样的问题我可以加载和使用 PyTorch (.pth.tar) 模型吗,该模型没有被接受的答案,或者我可以弄清楚如何遵循给出的建议.
I have the same problem as How can I load and use a PyTorch (.pth.tar) model which does not have an accepted answer or one I can figure out how to follow the advice given.
我是 PyTorch 的新手.我正在尝试加载此处引用的预训练 PyTorch 模型:https://github.com/macaodha/inat_comp_2018
I'm new to PyTorch. I am trying to load the pretrained PyTorch model referenced here: https://github.com/macaodha/inat_comp_2018
我很确定我缺少一些胶水.
I'm pretty sure I am missing some glue.
# load the model
import torch
model=torch.load("iNat_2018_InceptionV3.pth.tar",map_location='cpu')
# try to get it to classify an image
imsize = 256
loader = transforms.Compose([transforms.Scale(imsize), transforms.ToTensor()])
def image_loader(image_name):
"""load image, returns cuda tensor"""
image = Image.open(image_name)
image = loader(image).float()
image = Variable(image, requires_grad=True)
image = image.unsqueeze(0)
return image.cpu() #assumes that you're using CPU
image = image_loader("test-image.jpg")
产生错误:
在 ()----> 1 model.predict(image)
in () ----> 1 model.predict(image)
AttributeError: 'dict' 对象没有属性 'predict
AttributeError: 'dict' object has no attribute 'predict
推荐答案
问题
您的 model
实际上不是模型.保存时,它不仅包含参数,还包含有关模型的其他信息,形式有点类似于字典.
Problem
Your model
isn't actually a model. When it is saved, it contains not only the parameters, but also other information about the model as a form somewhat similar to a dict.
因此,torch.load("iNat_2018_InceptionV3.pth.tar")
只是返回dict
,当然它没有一个叫做predict的属性代码>.
Therefore, torch.load("iNat_2018_InceptionV3.pth.tar")
simply returns dict
, which of course does not have an attribute called predict
.
model=torch.load("iNat_2018_InceptionV3.pth.tar",map_location='cpu')
type(model)
# dict
解决方案
在这种情况下,在一般情况下,您首先需要做的是根据官方指南实例化所需的模型类"加载模型".
# First try
from torchvision.models import Inception3
v3 = Inception3()
v3.load_state_dict(model['state_dict']) # model that was imported in your code.
然而,直接输入model['state_dict']
会引起一些关于Inception3
参数形状不匹配的错误.
However, directly inputing the model['state_dict']
will raise some errors regarding mismatching shapes of Inception3
's parameters.
了解在实例化后 Inception3
发生了什么变化很重要.幸运的是,您可以在原作者的train_inat.py
中找到.
It is important to know what was changed to the Inception3
after its instantiation. Luckily, you can find that in the original author's train_inat.py
.
# What the author has done
model = inception_v3(pretrained=True)
model.fc = nn.Linear(2048, args.num_classes) #where args.num_classes = 8142
model.aux_logits = False
既然我们知道要更改什么,让我们对第一次尝试进行一些修改.
Now that we know what to change, lets make some modification to our first try.
# Second try
from torchvision.models import Inception3
v3 = Inception3()
v3.fc = nn.Linear(2048, 8142)
v3.aux_logits = False
v3.load_state_dict(model['state_dict']) # model that was imported in your code.
然后就可以成功加载模型了!
And there you go with successfully loaded model!
这篇关于如何加载和使用预先保留的 PyTorch InceptionV3 模型对图像进行分类的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!