Pytorch nn Module generalization


Problem description

Let us take a look at the simple class:

import torch.nn as nn
import torch.nn.functional as F

class Temp1(nn.Module):

    def __init__(self, stateSize, actionSize, layers=[10, 5], activations=[F.tanh, F.tanh]):

        super(Temp1, self).__init__()
        self.layer1 = nn.Linear(stateSize, layers[0])
        self.layer2 = nn.Linear(layers[0], layers[1])
        self.fcFinal = nn.Linear(layers[1], actionSize)
        return

This is a fairly straightforward pytorch module. It creates a simple sequential dense network. If we check its hidden parameters, we see the following:

t1 = Temp1(2, 2)
list(t1.parameters())

This is the expected result ...

[Parameter containing:
 tensor([[-0.0311, -0.5513],
         [-0.0634, -0.3783],
         [-0.2514,  0.6139],
         [ 0.4711, -0.0241],
         [-0.1739,  0.2208],
         [-0.1533,  0.3838],
         [-0.6490, -0.5784],
         [ 0.5312,  0.6703],
         [ 0.3506,  0.3652],
         [ 0.1768, -0.4158]], requires_grad=True), Parameter containing:
 tensor([-0.3199, -0.4154, -0.5530, -0.6738, -0.4411,  0.2641, -0.3576,  0.0447,
          0.0254,  0.0965], requires_grad=True), Parameter containing:
 tensor([[-2.8257e-01,  6.7583e-02,  9.0356e-02,  1.0868e-01,  4.0876e-02,
           4.0616e-02,  4.4419e-02, -8.1544e-02,  2.5244e-01,  3.8777e-03],
         [-8.0950e-03, -1.4175e-01, -2.9492e-01,  3.1439e-01, -2.3065e-01,
          -6.6631e-02,  3.0047e-01,  2.8353e-01,  2.3457e-01, -3.1399e-03],
         [-5.2522e-02, -2.2183e-01, -1.5485e-01,  2.6317e-01,  2.8273e-01,
          -7.4823e-02, -5.3704e-02,  9.3526e-02, -1.7916e-01, -3.1132e-04],
         [ 8.9063e-02,  2.9263e-01, -1.0052e-01,  8.7005e-02, -1.1246e-01,
          -2.7968e-01,  4.1411e-02, -1.6776e-01,  1.2363e-01, -2.2808e-01],
         [ 2.9244e-02,  5.8296e-02, -2.9729e-01, -3.1437e-01, -9.3182e-02,
          -7.5236e-03,  5.6159e-02, -2.2075e-02,  1.0337e-01,  8.1123e-02]],
        requires_grad=True), Parameter containing:
 tensor([ 0.2240,  0.0997, -0.0047, -0.1784, -0.0369], requires_grad=True), Parameter containing:
 tensor([[ 0.3546, -0.2180,  0.1723, -0.0463,  0.2572],
         [-0.1669, -0.1364, -0.0398,  0.2233, -0.1805]], requires_grad=True), Parameter containing:
 tensor([ 0.0871, -0.1698], requires_grad=True)]
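
As a quick sanity check (a minimal sketch using the t1 instance created above), printing the named parameters shows that each of the three Linear layers is registered, matching the shapes in the output:

for name, p in t1.named_parameters():
    print(name, tuple(p.shape))

# layer1.weight  (10, 2)
# layer1.bias    (10,)
# layer2.weight  (5, 10)
# layer2.bias    (5,)
# fcFinal.weight (2, 5)
# fcFinal.bias   (2,)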

Now, let us try to generalize this a bit:

class Temp(nn.Module):

    def __init__(self, stateSize, actionSize, layers=[10, 5], activations=[F.tanh, F.tanh] ):

        super(Temp, self).__init__()

        # Generate the fully connected layers
        self.fcLayers = []

        oldN = stateSize
        for i, layer in enumerate(layers):
            self.fcLayers.append( nn.Linear(oldN, layer) )
            oldN = layer
        self.fcFinal = nn.Linear( oldN, actionSize )
        return

It turns out that the number of parameters within this module is no longer the same ...

t = Temp(2, 3)
list(t.parameters())

[Parameter containing:
 tensor([[-0.3342,  0.4111,  0.0418,  0.4457,  0.0648],
         [ 0.4364, -0.0360, -0.2239,  0.4025,  0.1661],
         [ 0.1932, -0.0896,  0.3269, -0.2179,  0.1035]], requires_grad=True),
 Parameter containing:
 tensor([-0.2867, -0.1354, -0.0026], requires_grad=True)]

I believe I understand why this is happening. The bigger question is, how do we overcome this problem? The second, generalized version, for example, will not be sent to the GPU properly and will not be trained by an optimizer.
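
To make the failure concrete, here is a minimal sketch (assuming the Temp instance t created above): an optimizer built from t.parameters() only ever receives the two tensors belonging to fcFinal:

import torch

opt = torch.optim.Adam(t.parameters(), lr=1e-3)
# Only fcFinal.weight and fcFinal.bias were registered; the layers kept in the
# plain python list are invisible to the optimizer (and to .cuda(), .to(), etc.)
print(len(opt.param_groups[0]['params']))   # prints 2 instead of the expected 6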

Recommended answer

The problem is that most of the nn.Linear layers in the "generalized" version are stored in a regular pythonic list (self.fcLayers). pytorch does not know to look for nn.Parameter objects inside regular pythonic members of nn.Module.

Solution:
If you wish to store nn.Modules in a way that pytorch can manage them, you need to use specialized pytorch containers.
For instance, if you use nn.ModuleList instead of a regular pythonic list:

self.fcLayers = nn.ModuleList([])

your example should work fine.
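
Putting it together, here is a minimal sketch of the generalized class with that single change applied (same signature as in the question; the activations argument remains unused, exactly as in the original):

import torch.nn as nn
import torch.nn.functional as F

class Temp(nn.Module):

    def __init__(self, stateSize, actionSize, layers=[10, 5], activations=[F.tanh, F.tanh]):
        super(Temp, self).__init__()

        # nn.ModuleList registers each nn.Linear as a sub-module, so its
        # parameters become visible to pytorch
        self.fcLayers = nn.ModuleList([])

        oldN = stateSize
        for layer in layers:
            self.fcLayers.append(nn.Linear(oldN, layer))
            oldN = layer
        self.fcFinal = nn.Linear(oldN, actionSize)

With this change, list(Temp(2, 3).parameters()) returns all six parameter tensors (a weight and a bias for each of the three layers).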

BTW, you need pytorch to know that members of your nn.Module are modules themselves, not only to get their parameters, but also for other functions, such as moving them to gpu/cpu, setting their mode to eval/training, etc.
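
For illustration, a small sketch (assuming the nn.ModuleList version of Temp above) of what that registration buys you:

import torch

t = Temp(2, 3)
print(len(list(t.parameters())))        # 6 tensors: a weight and a bias per layer

t.train()                               # mode changes propagate to every sub-module
t.eval()
if torch.cuda.is_available():
    t = t.cuda()                        # all registered parameters move to the GPU

optimizer = torch.optim.SGD(t.parameters(), lr=0.01)   # the optimizer sees every layer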
