PyTorch: access weights of a specific module in nn.Sequential()


Question

When I use a pre-defined module in PyTorch, I can typically access its weights fairly easily. However, how do I access them if I wrapped the module in nn.Sequential() first? E.g.:

import torch.nn as nn

class My_Model_1(nn.Module):
    def __init__(self, D_in, D_out):
        super(My_Model_1, self).__init__()
        # A single Linear layer kept as a direct attribute.
        self.layer = nn.Linear(D_in, D_out)

    def forward(self, x):
        out = self.layer(x)
        return out

class My_Model_2(nn.Module):
    def __init__(self, D_in, D_out):
        super(My_Model_2, self).__init__()
        # The same Linear layer, but wrapped in nn.Sequential.
        self.layer = nn.Sequential(nn.Linear(D_in, D_out))

    def forward(self, x):
        out = self.layer(x)
        return out

model_1 = My_Model_1(10,10)
print(model_1.layer.weight)
model_2 = My_Model_2(10,10)

How do I print the weights now? model_2.layer.0.weight doesn't work.

Answer

From the PyTorch forum, this is the recommended way:

model_2.layer[0].weight
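
The indexing syntax works because nn.Sequential implements __getitem__, so model_2.layer[0] returns the wrapped nn.Linear module; model_2.layer.0.weight fails simply because a Python attribute name cannot start with a digit. As a minimal sketch reusing the My_Model_2 class from the question, the same parameter can also be reached through the state dict or named_parameters(), where children of a Sequential are keyed by their position:

model_2 = My_Model_2(10, 10)

# Index into the Sequential container to get the nn.Linear module back:
print(model_2.layer[0].weight)

# Equivalent: look the parameter up by its dotted name in the state dict.
# Submodules inside a Sequential are named by their index, so the key
# for the first layer's weight is "layer.0.weight".
print(model_2.state_dict()["layer.0.weight"])

# Or iterate over all parameters together with their dotted names:
for name, param in model_2.named_parameters():
    print(name, param.shape)  # e.g. layer.0.weight torch.Size([10, 10])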

