PyTorch: Understanding how the nn.Module class works internally


Problem description

Generally, nn.Module is subclassed as below.

def init_weights(m):
    # Apply Xavier-uniform initialization to every nn.Linear layer
    if type(m) == nn.Linear:
        torch.nn.init.xavier_uniform_(m.weight)

class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.fc1 = nn.Linear(20, 1)
        self.apply(init_weights)

    def forward(self, x):
        x = self.fc1(x)
        return x

My first question is: why can I simply run the code below even though my __init__ has no positional argument for training_signals, while it looks like training_signals is passed to the forward() method? How does that work?

model = LinearRegression()
training_signals = torch.rand(1000,20)
model(training_signals)

The second question is: how does self.apply(init_weights) work internally? Is it executed before the forward method is called?

Recommended answer

Q1: Why can I simply run the code below even though my __init__ has no positional argument for training_signals, while it looks like training_signals is passed to the forward() method? How does that work?

First, __init__ is called when you run this line:

model = LinearRegression()

As you can see, you pass no parameters, and you shouldn't. The signature of your __init__ is the same as that of the base class (which you call when you run super(LinearRegression, self).__init__()). As you can see here, nn.Module's init signature is simply def __init__(self) (just like yours).
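As a minimal sketch of this point (a hypothetical variant, not part of the original question): if your subclass does need constructor arguments, they belong to your own __init__; the base class is still initialized with the argument-free super().__init__():

class LinearRegressionWithSize(nn.Module):
    def __init__(self, in_features):
        # nn.Module.__init__ takes no arguments, so it is called bare;
        # in_features is consumed by this subclass only.
        super().__init__()
        self.fc1 = nn.Linear(in_features, 1)

    def forward(self, x):
        return self.fc1(x)

model = LinearRegressionWithSize(20)  # arguments here go to __init__, not to forward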

Second, model is now an object. When you run the line below:

model(training_signals)

You are actually calling the __call__ method and passing training_signals as a positional argument. As you can see here, among many other things, the __call__ method calls the forward method:

result = self.forward(*input, **kwargs)

In other words, all parameters (positional and named) of __call__ are passed on to forward.
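As a quick sketch (reusing the LinearRegression model and training_signals tensor from the question), calling the model and calling forward directly give the same output for this simple module; the difference is that __call__ also runs PyTorch's hook machinery around forward:

out_call = model(training_signals)             # goes through nn.Module.__call__
out_forward = model.forward(training_signals)  # direct call, skips the hook machinery

print(torch.allclose(out_call, out_forward))   # True for this plain model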

Q2: How does self.apply(init_weights) work internally? Is it executed before the forward method is called?

PyTorch is open source, so you can simply go to the source code and check. As you can see here, the implementation is quite simple:

def apply(self, fn):
    for module in self.children():
        module.apply(fn)
    fn(self)
    return self

Quoting the function's documentation: it "applies fn recursively to every submodule (as returned by .children()) as well as self". Based on the implementation, you can also work out the requirements:

  • fn must be a callable;
  • fn receives only a Module object as input.
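Because self.apply(init_weights) is called from __init__, it runs while model = LinearRegression() is being constructed, well before forward is ever invoked. A small sketch (using a throwaway nn.Sequential, not the question's model) shows the traversal order implied by the implementation above, children first and the module itself last:

def report(m):
    # fn receives each module as its single argument
    print(type(m).__name__)

net = nn.Sequential(nn.Linear(20, 10), nn.ReLU(), nn.Linear(10, 1))
net.apply(report)
# Prints: Linear, ReLU, Linear, Sequential  (children first, then self)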

