如何加载和使用 PyTorch (.pth.tar) 模型 [英] How can I load and use a PyTorch (.pth.tar) model

查看:241
本文介绍了如何加载和使用 PyTorch (.pth.tar) 模型的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我对 Torch 不是很熟悉,我主要使用 Tensorflow.但是,我需要使用在 Torch 中重新训练的重新训练的初始模型.由于为我的特定应用程序重新训练初始模型需要大量计算资源,我想使用已经重新训练的模型.

I am not very familiar with Torch, and I primarily use Tensorflow. I, however, need to use a retrained inception model that was retrained in Torch. Due to the large amount of computing resources required to retrain an inception model for my particular application, I would like to use the model that was already retrained.

此模型保存为 .pth.tar 文件.

This model is saved as a .pth.tar file.

我希望能够首先加载这个模型.到目前为止,我已经能够弄清楚我必须使用以下内容:

I would like to be able to first load this model. So far, I have been able to figure out that I must use the following:

model = torch.load('iNat_2018_InceptionV3.pth.tar', map_location='cpu')

这似乎有效,因为 print(model) 打印出大量数字和其他值,我认为这些值是权重和偏差的值.

This seems to work, because print(model) prints out a large set of numbers and other values, which I presume are the values for the weights an biases.

在此之后,我需要能够用它对图像进行分类.我一直无法弄清楚这一点.我必须如何格式化图像?图像是否应该转换为数组?在此之后,我必须如何将输入数据传递给网络?

After this, I need to be able to classify an image with it. I haven't been able to figure this out. How must I format the image? Should the image be converted into an array? After this, how must I pass the input data to the network?

推荐答案

你基本上需要像在 tensorflow 中那样做.也就是说,当您存储网络时,只会存储参数(即网络中的可训练对象),而不是胶水",这是使用训练模型所需的全部逻辑.因此,如果您有一个 .pth.tar 文件,您可以加载它,从而覆盖已定义模型的参数值.

you basically need to do the same as in tensorflow. That is, when you store a network, only the parameters (i.e. the trainable objects in your network) will be stored, but not the "glue", that is all the logic you need to use a trained model. So if you have a .pth.tar file, you can load it, thereby overriding the parameter values of a model already defined.

这意味着保存/加载模型的一般过程如下:

That means that the general procedure of saving/loading a model is as follows:

  • 编写您的网络定义(即您的 nn.Module 对象)
  • 以您想要的方式训练或以其他方式更改网络的参数
  • 使用torch.save
  • 保存参数
  • 当你想使用那个网络时,使用 nn.Module 对象的相同定义来首先实例化一个 pytorch 网络
  • then 使用 torch.load
  • 覆盖网络参数的值
  • write your network definition (i.e. your nn.Module object)
  • train or otherwise change the network's parameters in a way you want
  • save the parameters using torch.save
  • when you want to use that network, use the same definition of an nn.Module object to first instantiate a pytorch network
  • then override the values of the network's parameters using torch.load

以下是有关如何执行此操作的一些参考资料的讨论:pytorch 论坛

Here's a discussion with some references on how to do this: pytorch forums

这是一个超短的 mwe:

And here's a super short mwe:

# to store
torch.save({
    'state_dict': model.state_dict(),
    'optimizer' : optimizer.state_dict(),
}, 'filename.pth.tar')

# to load
checkpoint = torch.load('filename.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

这篇关于如何加载和使用 PyTorch (.pth.tar) 模型的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆