Understanding when to use python list in Pytorch


Question


Basically, as this thread discusses here, you cannot use a Python list to wrap your sub-modules (for example, your layers); otherwise, PyTorch will not update the parameters of the sub-modules inside the list. Instead, you should use nn.ModuleList to wrap your sub-modules, to make sure their parameters get updated. Now I have also seen code like the following, where the author uses a Python list to collect the losses and then calls loss.backward() to do the update (in the REINFORCE algorithm of RL). Here is the code:

    policy_loss = []
    for log_prob in self.controller.log_probability_slected_action_list:
        policy_loss.append(-log_prob * (average_reward - b))
    self.optimizer.zero_grad()
    final_policy_loss = (torch.cat(policy_loss).sum()) * gamma
    final_policy_loss.backward()
    self.optimizer.step()


Why does using a list in this format work for updating the parameters of the modules, while the first case does not? I am very confused now. If I change the previous code to policy_loss = nn.ModuleList([]), it throws an exception saying that a float tensor is not a sub-module.

Answer


You are misunderstanding what Modules are. A Module stores parameters and defines an implementation of the forward pass.
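This distinction is easy to see by counting registered parameters. In the minimal sketch below (class names are illustrative, not from the question), the same layers are stored once in a plain Python list and once in an nn.ModuleList; only the latter registers them, so only the latter exposes them to an optimizer:

```python
import torch
import torch.nn as nn

class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain Python list: the layers are NOT registered as sub-modules
        self.layers = [nn.Linear(4, 4) for _ in range(3)]

class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList: the layers ARE registered, so parameters() sees them
        self.layers = nn.ModuleList(nn.Linear(4, 4) for _ in range(3))

plain = PlainList()
wrapped = WithModuleList()
print(len(list(plain.parameters())))    # 0 -- nothing registered
print(len(list(wrapped.parameters())))  # 6 -- weight + bias for each of 3 layers
```

An optimizer built from plain.parameters() would receive an empty parameter list, which is exactly why training silently does nothing in the first case from the question.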


You're allowed to perform arbitrary computation with tensors and parameters, resulting in other new tensors. Modules need not be aware of those tensors. You're also allowed to store lists of tensors in Python lists. When calling backward, it needs to be on a scalar tensor, hence the sum of the concatenation. These tensors are losses, not parameters, so they should not be attributes of a Module nor wrapped in a ModuleList.
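A minimal sketch of why collecting loss tensors in a Python list is fine (simplified stand-in for the REINFORCE code, not the question's actual program): each list entry stays connected to the autograd graph, and summing them into a scalar lets backward() propagate gradients through every entry. Here torch.stack is used instead of torch.cat because the per-step losses are 0-d scalars:

```python
import torch

w = torch.tensor(1.0, requires_grad=True)

# Accumulate per-step loss tensors in an ordinary Python list
losses = []
for x in [1.0, 2.0, 3.0]:
    losses.append(w * x)  # each entry is a tensor in the autograd graph

# backward() requires a scalar, so combine the list into one tensor and sum
total = torch.stack(losses).sum()
total.backward()

print(w.grad)  # tensor(6.) -- d(total)/dw = 1 + 2 + 3
```

The list here is just a temporary container for intermediate tensors; autograd tracks the tensors themselves, not the container, so no ModuleList is needed.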

