Understanding when to use a Python list in PyTorch
Problem description
Basically, as this thread discusses here, you cannot use a Python list to wrap your sub-modules (for example, your layers); otherwise, PyTorch will not update the parameters of the sub-modules inside the list. Instead, you should use nn.ModuleList to wrap your sub-modules to make sure their parameters will be updated. Now I have also seen code like the following, where the author uses a Python list to collect the losses and then calls loss.backward() to do the update (in the REINFORCE algorithm of RL). Here is the code:
policy_loss = []
for log_prob in self.controller.log_probability_slected_action_list:
    policy_loss.append(-log_prob * (average_reward - b))
self.optimizer.zero_grad()
final_policy_loss = (torch.cat(policy_loss).sum()) * gamma
final_policy_loss.backward()
self.optimizer.step()
Why does using a list in this way work for updating the parameters of the modules, while the first case does not? I am very confused now. If I change policy_loss = [] to policy_loss = nn.ModuleList([]) in the code above, it throws an exception saying that a float tensor is not a sub-module.
Answer
You are misunderstanding what Modules are. A Module stores parameters and defines an implementation of the forward pass.
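The registration behavior the question describes can be seen directly by comparing the two wrappers. The following is a minimal sketch (with hypothetical toy models, not code from the question) showing that sub-modules hidden inside a plain Python list are invisible to `parameters()`, while `nn.ModuleList` registers them:

```python
import torch.nn as nn


class PlainListNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Layers stored in a plain Python list are NOT registered as
        # sub-modules, so their parameters are invisible to the optimizer.
        self.layers = [nn.Linear(4, 4), nn.Linear(4, 2)]


class ModuleListNet(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers each layer, exposing its parameters.
        self.layers = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 2)])


print(len(list(PlainListNet().parameters())))   # 0 -- optimizer sees nothing
print(len(list(ModuleListNet().parameters())))  # 4 -- two weights, two biases
```

Passing `PlainListNet().parameters()` to an optimizer would therefore train nothing, which is exactly the failure mode from the linked thread.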
You are allowed to perform arbitrary computation with tensors and parameters, producing other new tensors; Modules need not be aware of those tensors. You are also allowed to store tensors in Python lists. When calling backward, it must be called on a scalar tensor, hence the sum of the concatenation. These tensors are losses, not parameters, so they should not be attributes of a Module nor wrapped in a ModuleList.
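A minimal sketch of this point (with a hypothetical parameter and a made-up advantage of 1.0 standing in for `average_reward - b`): loss tensors collected in a plain Python list still backpropagate, because autograd tracks the tensors themselves, not the list holding them.

```python
import torch
import torch.nn as nn

# A single learnable parameter standing in for a policy network.
theta = nn.Parameter(torch.zeros(3))
log_probs = torch.log_softmax(theta, dim=0)  # stand-in for action log-probabilities

# Collect per-action losses in an ordinary Python list, as in the question.
losses = []
for lp in log_probs:
    losses.append(-lp * 1.0)  # pretend (average_reward - b) == 1.0

# backward() needs a scalar, so reduce the list to a single tensor first.
total = torch.stack(losses).sum()
total.backward()

print(theta.grad is not None)  # True -- gradients reached the parameter
```

Note `torch.stack` is used here because these losses are zero-dimensional tensors; the question's code uses `torch.cat`, which assumes each element is at least one-dimensional.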