How can I access layers in a pytorch module by index?


Problem Description

I am trying to write a pytorch module with multiple layers. Since I need the intermediate outputs, I cannot put them all in an nn.Sequential as usual. On the other hand, since there are many layers, what I have in mind is to put the layers in a list and access them by index in a loop. The code below describes what I am trying to achieve:

import torch
import torch.nn as nn
import torch.optim as optim

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

        self.layer_list = []

        self.layer_list.append(nn.Linear(2,3))
        self.layer_list.append(nn.Linear(3,4))
        self.layer_list.append(nn.Linear(4,5))

    def forward(self, x):
        res_list = [x]
        for i in range(len(self.layer_list)):
            res_list.append(self.layer_list[i](res_list[-1]))
        return res_list


model = MyModel()
x = torch.randn(4,2)
y = model(x)

print(y)

optimizer = optim.Adam(model.parameters())

The forward method works fine, but when I want to set up an optimizer, the program says:

ValueError: optimizer got an empty parameter list

It appears that the layers in the list are not registered here. What can I do?
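
A quick check, reusing the MyModel class defined above, shows why: the module exposes no registered parameters at all.

model = MyModel()
# The nn.Linear layers live in a plain Python attribute, so nothing was
# registered on the module -- the parameter list is empty, which is exactly
# what the ValueError above complains about.
print(list(model.parameters()))     # []
print(list(model.named_modules()))  # only MyModel itself, none of the nn.Linear layers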

Recommended Answer

If you put your layers in a plain Python list, pytorch does not register them correctly. You have to use nn.ModuleList instead (https://pytorch.org/docs/master/generated/torch.nn.ModuleList.html).

ModuleList can be indexed like a regular Python list, but modules it contains are properly registered, and will be visible by all Module methods.

Your code should look something like this:


import torch
import torch.nn as nn
import torch.optim as optim

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

        self.layer_list = nn.ModuleList()  # << the only changed line! <<

        self.layer_list.append(nn.Linear(2,3))
        self.layer_list.append(nn.Linear(3,4))
        self.layer_list.append(nn.Linear(4,5))

    def forward(self, x):
        res_list = [x]
        for i in range(len(self.layer_list)):
            res_list.append(self.layer_list[i](res_list[-1]))
        return res_list

By using ModuleList you make sure all layers are registered as submodules of the model, so their parameters show up in model.parameters() and can be passed to the optimizer.
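
With that change the original snippet runs end to end. A minimal check, assuming the ModuleList version of MyModel above:

model = MyModel()
x = torch.randn(4, 2)
y = model(x)   # a list of 4 tensors: the input plus the three layer outputs

# The three nn.Linear layers are now registered, so their parameters are
# visible and the optimizer can be constructed without the ValueError.
print(sum(p.numel() for p in model.parameters()))  # 9 + 16 + 25 = 50
optimizer = optim.Adam(model.parameters())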

There is also a ModuleDict that you can use if you want to index your layers by name. You can check pytorch's containers here: https://pytorch.org/docs/master/nn.html#containers
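
For completeness, here is a sketch of the same model built with nn.ModuleDict; the class name and keys below are just illustrative:

class MyDictModel(nn.Module):
    def __init__(self):
        super(MyDictModel, self).__init__()
        # ModuleDict registers its values as submodules, just like ModuleList
        # registers its items.
        self.layers = nn.ModuleDict({
            "fc1": nn.Linear(2, 3),
            "fc2": nn.Linear(3, 4),
            "fc3": nn.Linear(4, 5),
        })

    def forward(self, x):
        res_list = [x]
        for name, layer in self.layers.items():  # insertion order is preserved
            res_list.append(layer(res_list[-1]))
        return res_list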
