How to remove the last FC layer from a ResNet model in PyTorch?

Question

I am using a ResNet152 model from PyTorch. I'd like to strip off the last FC layer from the model. Here's my code:

from torchvision import datasets, transforms, models
model = models.resnet152(pretrained=True)
print(model)

When I print the model, the last few lines look like this:

    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
    )
  )
  (avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0)
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

I want to remove that last fc layer from the model.

I found an answer here on SO (How to convert pretrained FC layers to CONV layers in Pytorch), where mexmex seems to provide the answer I'm looking for:

list(model.modules()) # to inspect the modules of your model
my_model = nn.Sequential(*list(model.modules())[:-1]) # strips off last linear layer

So I added those lines to my code like this:

import torch.nn as nn

model = models.resnet152(pretrained=True)
list(model.modules()) # to inspect the modules of your model
my_model = nn.Sequential(*list(model.modules())[:-1]) # strips off last linear layer
print(my_model)

But this code doesn't work as advertised, at least not for me. The rest of this post explains in detail why that answer doesn't work, so that this question doesn't get closed as a duplicate.

First, the printed model is nearly 5x larger than before. I see the same model as before, followed by what appears to be a repeat of the model, perhaps flattened.

    (2): Bottleneck(
      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
    )
  )
  (avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0)
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)
(1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace)
(4): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(5): Sequential(
  . . . this goes on for ~1600 more lines . . .
  (415): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (416): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (417): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (418): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (419): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (420): ReLU(inplace)
  (421): AvgPool2d(kernel_size=7, stride=1, padding=0)
)

Second, the fc layer is still there, and the Conv2d layer after it looks just like the first layer of ResNet152.

Third, if I try to invoke my_model.forward(), PyTorch complains about a size mismatch: it expects size [1, 3, 224, 224], but the input was [1, 1000]. So it looks like a copy of the entire model (minus the fc layer) is getting appended to the original model.

Bottom line, the only answer I found on SO doesn't actually work.

Recommended answer

For ResNet models, you can use the children() method to access the layers, since a ResNet model in PyTorch is built from nn modules. (Tested on PyTorch 0.4.1.)

import torch
from torchvision import models

model = models.resnet152(pretrained=True)
# Keep every immediate child except the last one (fc)
newmodel = torch.nn.Sequential(*(list(model.children())[:-1]))
print(newmodel)

Update: Although there is no universal answer that works for every PyTorch model, this approach should work for all well-structured ones. The existing layers you add to a model (such as torch.nn.Linear, torch.nn.Conv2d, torch.nn.BatchNorm2d...) are all based on the torch.nn.Module class, and any custom layer you implement and add to your network should also inherit from torch.nn.Module. As the documentation states, the children() method lets you access the modules of your class/model/network:

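# Method of torch.nn.Module (PyTorch source)
def children(self):
    r"""Returns an iterator over immediate children modules."""
    for name, module in self.named_children():
        yield module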
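To illustrate what "well structured" means here, the sketch below uses a hypothetical `TinyNet` (made up for illustration, not from the original answer). children() sees only submodules assigned as attributes, in registration order; a functional step in forward() (here `torch.flatten`) is not a module, so it disappears when the network is rebuilt from children(). One alternative, swapped in here and not part of the original answer, is to keep the model intact and replace the head with `nn.Identity()` (available since PyTorch 1.1), which preserves the rest of forward().

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical example network: conv features, pooling, flatten, linear head."""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        x = self.pool(self.features(x))
        x = torch.flatten(x, 1)   # functional op: not a child module
        return self.fc(x)

net = TinyNet()
print([name for name, _ in net.named_children()])  # ['features', 'pool', 'fc']

x = torch.randn(4, 3, 16, 16)

# Approach 1: rebuild from children(); the flatten step is lost.
headless = nn.Sequential(*list(net.children())[:-1])
print(headless(x).shape)    # torch.Size([4, 8, 1, 1])

# Approach 2 (alternative): swap the head for a pass-through, keeping forward().
net.fc = nn.Identity()
print(net(x).shape)         # torch.Size([4, 8])
```

The trade-off: rebuilding from children() gives you a standalone Sequential but silently drops any logic written directly in forward(), while the Identity swap keeps that logic but modifies the original model object in place.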

Update: Note that children() returns only the "immediate" child modules. If the last module of your network is an nn.Sequential, list(model.children())[:-1] drops that entire Sequential, not just its last layer.
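A minimal sketch of that caveat, using a hypothetical model whose last child is an `nn.Sequential` classifier (the layers are made up for illustration; `nn.Flatten` assumes PyTorch 1.2 or later). Slicing the outer children() removes the whole classifier, while slicing the inner Sequential removes only its final layer.

```python
import torch
import torch.nn as nn

# Hypothetical model: the last immediate child is a Sequential "classifier".
model = nn.Sequential(
    nn.Flatten(),
    nn.Sequential(           # classifier block
        nn.Linear(12, 6),
        nn.ReLU(),
        nn.Linear(6, 3),     # the layer we actually want to drop
    ),
)

x = torch.randn(2, 3, 2, 2)  # flattens to 12 features

# children()[:-1] drops the entire classifier block.
too_much = nn.Sequential(*list(model.children())[:-1])
print(too_much(x).shape)     # torch.Size([2, 12])

# Slice inside the nested Sequential to drop only its last Linear.
classifier = model[-1]
model[-1] = nn.Sequential(*list(classifier.children())[:-1])
print(model(x).shape)        # torch.Size([2, 6])
```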
