如何加载和使用预先保留的 PyTorch InceptionV3 模型对图像进行分类 [英] How to load and use a pretained PyTorch InceptionV3 model to classify an image

查看:137
本文介绍了如何加载和使用预先保留的 PyTorch InceptionV3 模型对图像进行分类的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我和 有同样的问题我可以加载和使用 PyTorch (.pth.tar) 模型吗,该模型没有被接受的答案,或者我可以弄清楚如何遵循给出的建议.

I have the same problem as How can I load and use a PyTorch (.pth.tar) model which does not have an accepted answer or one I can figure out how to follow the advice given.

我是 PyTorch 的新手.我正在尝试加载此处引用的预训练 PyTorch 模型:https://github.com/macaodha/inat_comp_2018

I'm new to PyTorch. I am trying to load the pretrained PyTorch model referenced here: https://github.com/macaodha/inat_comp_2018

我很确定我缺少一些胶水.

I'm pretty sure I am missing some glue.

# load the model
import torch
model=torch.load("iNat_2018_InceptionV3.pth.tar",map_location='cpu')

# try to get it to classify an image
imsize = 256
loader = transforms.Compose([transforms.Scale(imsize), transforms.ToTensor()])

def image_loader(image_name):
    """load image, returns cuda tensor"""
    image = Image.open(image_name)
    image = loader(image).float()
    image = Variable(image, requires_grad=True)
    image = image.unsqueeze(0)  
    return image.cpu()  #assumes that you're using CPU

image = image_loader("test-image.jpg")

产生错误:

在 ()----> 1 model.predict(image)

in () ----> 1 model.predict(image)

AttributeError: 'dict' 对象没有属性 'predict

AttributeError: 'dict' object has no attribute 'predict

推荐答案

问题

您的 model 实际上不是模型.保存时,它不仅包含参数,还包含有关模型的其他信息,形式有点类似于字典.

Problem

Your model isn't actually a model. When it is saved, it contains not only the parameters, but also other information about the model as a form somewhat similar to a dict.

因此,torch.load("iNat_2018_InceptionV3.pth.tar") 只是返回dict,当然它没有一个叫做predict.

Therefore, torch.load("iNat_2018_InceptionV3.pth.tar") simply returns dict, which of course does not have an attribute called predict.

model=torch.load("iNat_2018_InceptionV3.pth.tar",map_location='cpu')
type(model)
# dict

解决方案

在这种情况下,在一般情况下,您首先需要做的是根据官方指南实例化所需的模型类"加载模型".

# First try
from torchvision.models import Inception3
v3 = Inception3()
v3.load_state_dict(model['state_dict']) # model that was imported in your code.

然而,直接输入model['state_dict']会引起一些关于Inception3参数形状不匹配的错误.

However, directly inputing the model['state_dict'] will raise some errors regarding mismatching shapes of Inception3's parameters.

了解在实例化后 Inception3 发生了什么变化很重要.幸运的是,您可以在原作者的train_inat.py 中找到.

It is important to know what was changed to the Inception3 after its instantiation. Luckily, you can find that in the original author's train_inat.py.

# What the author has done
model = inception_v3(pretrained=True)
model.fc = nn.Linear(2048, args.num_classes) #where args.num_classes = 8142
model.aux_logits = False

既然我们知道要更改什么,让我们对第一次尝试进行一些修改.

Now that we know what to change, lets make some modification to our first try.

# Second try
from torchvision.models import Inception3
v3 = Inception3()
v3.fc = nn.Linear(2048, 8142)
v3.aux_logits = False
v3.load_state_dict(model['state_dict']) # model that was imported in your code.

然后就可以成功加载模型了!

And there you go with successfully loaded model!

这篇关于如何加载和使用预先保留的 PyTorch InceptionV3 模型对图像进行分类的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆