如何加载和使用PyTorch(.pth.tar)模型 [英] How can I load and use a PyTorch (.pth.tar) model
问题描述
我对Torch不太熟悉,主要使用Tensorflow.但是,我需要使用在Torch中重新训练的重新训练的初始模型.由于为我的特定应用重新训练初始模型需要大量的计算资源,因此我想使用已经重新训练的模型.
I am not very familiar with Torch, and I primarily use Tensorflow. I, however, need to use a retrained inception model that was retrained in Torch. Due to the large amount of computing resources required to retrain an inception model for my particular application, I would like to use the model that was already retrained.
此模型另存为.pth.tar
文件.
我希望能够首先加载此模型.到目前为止,我已经能够确定必须使用以下内容:
I would like to be able to first load this model. So far, I have been able to figure out that I must use the following:
model = torch.load('iNat_2018_InceptionV3.pth.tar', map_location='cpu')
这似乎行得通,因为print(model)
会打印出大量数字和其他值,我认为这是偏差的权重值.
This seems to work, because print(model)
prints out a large set of numbers and other values, which I presume are the values for the weights an biases.
在此之后,我需要能够使用它对图像进行分类.我还没弄清楚.我该如何格式化图像?是否应将图像转换为数组?之后,如何将输入数据传递到网络?
After this, I need to be able to classify an image with it. I haven't been able to figure this out. How must I format the image? Should the image be converted into an array? After this, how must I pass the input data to the network?
推荐答案
您基本上需要执行与tensorflow中相同的操作.也就是说,当您存储网络时,将仅存储参数(即网络中的可训练对象),而不是胶水",这就是使用训练模型所需的全部逻辑.
因此,如果您有一个.pth.tar
文件,则可以加载它,从而覆盖已经定义的模型的参数值.
you basically need to do the same as in tensorflow. That is, when you store a network, only the parameters (i.e. the trainable objects in your network) will be stored, but not the "glue", that is all the logic you need to use a trained model.
So if you have a .pth.tar
file, you can load it, thereby overriding the parameter values of a model already defined.
这意味着保存/加载模型的一般过程如下:
That means that the general procedure of saving/loading a model is as follows:
- 编写您的网络定义(即您的
nn.Module
对象) - 以您想要的方式训练或以其他方式更改网络参数
- 使用
torch.save
保存参数
- 当您要使用该网络时,请使用与
nn.Module
对象相同的定义来首先实例化pytorch网络 - 然后使用
torch.load
覆盖网络参数的值
- write your network definition (i.e. your
nn.Module
object) - train or otherwise change the network's parameters in a way you want
- save the parameters using
torch.save
- when you want to use that network, use the same definition of an
nn.Module
object to first instantiate a pytorch network - then override the values of the network's parameters using
torch.load
这里是有关如何执行此操作的讨论: pytorch论坛
Here's a discussion with some references on how to do this: pytorch forums
这是一个超短的mwe:
And here's a super short mwe:
# to store
torch.save({
'state_dict': model.state_dict(),
'optimizer' : optimizer.state_dict(),
}, 'filename.pth.tar')
# to load
checkpoint = torch.load('filename.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
这篇关于如何加载和使用PyTorch(.pth.tar)模型的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!