pytorch 的 nn.Module 如何注册子模块? [英] How does pytorch's nn.Module register submodule?

查看:25
本文介绍了pytorch 的 nn.Module 如何注册子模块?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

当我阅读 torch.nn.Module 的源代码(python)时,我发现属性 self._modules 已在许多函数中使用,例如self.modules(), self.children() 等等 但是我没有找到任何函数更新它.那么,self._modules 将在哪里更新?另外pytorch的nn.Module是如何注册子模块的?

When I read the source code(python) of torch.nn.Module , I found the attribute self._modules has been used in many functions like self.modules(), self.children(), etc. However, I didn't find any functions updating it. So, where will the self._modules be updated? Furthermore, how does pytorch's nn.Module register submodule?

class Module(object):
    def __init__(self):
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True

    def named_modules(self, memo=None, prefix=''):
        if memo is None:
            memo = set()
        if self not in memo:
            memo.add(self)
            yield prefix, self
            for name, module in self._modules.items():
                if module is None:
                    continue
                submodule_prefix = prefix + ('.' if prefix else '') + name
                for m in module.named_modules(memo, submodule_prefix):
                    yield m

推荐答案

通常通过为 nn.module 的实例设置属性来注册模块和参数.特别是,这种行为是通过对__setattr__方法进行cuatomizing来实现的:

The modules and parameters are usually registered by setting an attribute for an instance of nn.module. Particularly, this kind of behavior is implemented by cuatomizing the __setattr__ method:

def __setattr__(self, name, value):
        def remove_from(*dicts):
            for d in dicts:
                if name in d:
                    del d[name]

        params = self.__dict__.get('_parameters')
        if isinstance(value, Parameter):
            if params is None:
                raise AttributeError(
                    "cannot assign parameters before Module.__init__() call")
            remove_from(self.__dict__, self._buffers, self._modules)
            self.register_parameter(name, value)
        elif params is not None and name in params:
            if value is not None:
                raise TypeError("cannot assign '{}' as parameter '{}' "
                                "(torch.nn.Parameter or None expected)"
                                .format(torch.typename(value), name))
            self.register_parameter(name, value)
        else:
            modules = self.__dict__.get('_modules')
            if isinstance(value, Module):
                if modules is None:
                    raise AttributeError(
                        "cannot assign module before Module.__init__() call")
                remove_from(self.__dict__, self._parameters, self._buffers)
                modules[name] = value
            elif modules is not None and name in modules:
                if value is not None:
                    raise TypeError("cannot assign '{}' as child module '{}' "
                                    "(torch.nn.Module or None expected)"
                                    .format(torch.typename(value), name))
                modules[name] = value
            else:
                buffers = self.__dict__.get('_buffers')
                if buffers is not None and name in buffers:
                    if value is not None and not isinstance(value, torch.Tensor):
                        raise TypeError("cannot assign '{}' as buffer '{}' "
                                        "(torch.Tensor or None expected)"
                                        .format(torch.typename(value), name))
                    buffers[name] = value
                else:
                    object.__setattr__(self, name, value)

参见 https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py 找到这个方法.

See https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py to find this method.

这篇关于pytorch 的 nn.Module 如何注册子模块?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆