PyTorch:如何正确创建 nn.Linear() 列表 [英] PyTorch : How to properly create a list of nn.Linear()

查看:75
本文介绍了PyTorch:如何正确创建 nn.Linear() 列表的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我创建了一个以 nn.Module 作为子类的类.

I have created a class that has nn.Module as subclass.

在我的课堂上,我必须创建 N 个线性变换,其中 N 作为类参数给出.

In my class, I have to create N number of linear transformation, where N is given as class parameters.

因此我按照以下步骤进行:

I therefore proceed as follow :

    self.list_1 = []

    for i in range(N):
        self.list_1.append(nn.Linear(self.x, 1, bias=mlp_bias))

在 forward 方法中,我调用这些矩阵(使用 list_1[i])并连接结果.

In the forward method, i call these matrices (with list_1[i]) and concat the results.

两件事:

1)

即使我使用model.cuda(),这些线性变换也用于cpu 并且我收到以下错误:

Even though I use model.cuda(), these Linear transform are used on cpu and i get the following error :

RuntimeError: 类型为 Variable[torch.cuda.FloatTensor] 的预期对象,但发现参数 #1 'mat2' 的类型为 Variable[torch.FloatTensor]

RuntimeError: Expected object of type Variable[torch.cuda.FloatTensor] but found type Variable[torch.FloatTensor] for argument #1 'mat2'

我必须这样做

self.list_1.append(nn.Linear(self.x, 1, bias=mlp_bias).cuda())

如果相反,我不需要:

self.nn = nn.Linear(self.x, 1, bias=mlp_bias)

然后直接使用self.nn.

and then use self.nn directly.

2)

出于更明显的原因,当我在 main 中打印(model) 时,我的列表中的线性矩阵没有打印出来.

For more obvious reason, when I print(model) in my main, the Linear matrices in my list arent printed.

还有什么办法吗?也许使用 bmm ?我发现它不太容易,而且我实际上想单独获得我的 N 个结果.

Is there any other way. maybe using bmm ? I find it less easy, and i actually want to have my N results separately.

先谢谢你,

M

推荐答案

您可以使用 nn.ModuleList 来包装线性层列表,如这里

You can use nn.ModuleList to wrap your list of linear layers as explained here

self.list_1 = nn.ModuleList(self.list_1)

这篇关于PyTorch:如何正确创建 nn.Linear() 列表的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
相关文章
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆