变量的backward() 方法中的参数retain_graph 是什么意思? [英] What does the parameter retain_graph mean in the Variable's backward() method?

本文介绍了变量的backward() 方法中的参数retain_graph 是什么意思?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在阅读

假设我们有一个上面显示的计算图.变量 de 是输出,a 是输入.例如,

导入火炬从 torch.autograd 导入变量a = Variable(torch.rand(1, 4), requires_grad=True)b = a**2c = b*2d = c.mean()e = c.sum()

当我们做 d.backward() 时,那很好.计算完成后,图的计算d的部分将默认被释放以节省内存.所以如果我们执行e.backward(),就会弹出错误信息.为了做e.backward(),我们必须在d.backward()<中将参数retain_graph设置为True/code>,即,

d.backward(retain_graph=True)

只要你在backward方法中使用了retain_graph=True,你就可以随时向后进行:

d.backward(retain_graph=True) # 没问题e.backward(retain_graph=True) # 很好d.backward() # 也可以e.backward() # 会出现错误!

可以找到更多有用的讨论此处.

一个真实的用例

现在,一个真正的用例是多任务学习,其中您可能有多个损失,这些损失可能位于不同的层.假设您有 2 个损失:loss1loss2,它们位于不同的层中.为了将 loss1loss2 w.r.t 的梯度反向传播到网络的可学习权重.您必须在第一个反向传播损失的 backward() 方法中使用 retain_graph=True.

# 假设你先反向传播 loss1,然后是 loss2(你也可以反向传播)loss1.backward(retain_graph=True)loss2.backward() # 现在图被释放了,下一个批量梯度下降过程准备好了optimizer.step() # 更新网络参数

I'm going through the neural transfer pytorch tutorial and am confused about the use of retain_variable(deprecated, now referred to as retain_graph). The code example show:

class ContentLoss(nn.Module):

    def __init__(self, target, weight):
        super(ContentLoss, self).__init__()
        self.target = target.detach() * weight
        self.weight = weight
        self.criterion = nn.MSELoss()

    def forward(self, input):
        self.loss = self.criterion(input * self.weight, self.target)
        self.output = input
        return self.output

    def backward(self, retain_variables=True):
        #Why is retain_variables True??
        self.loss.backward(retain_variables=retain_variables)
        return self.loss

From the documentation

retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.

So by setting retain_graph= True, we're not freeing the memory allocated for the graph on the backward pass. What is the advantage of keeping this memory around, why do we need it?

解决方案

@cleros is pretty on the point about the use of retain_graph=True. In essence, it will retain any necessary information to calculate a certain variable, so that we can do backward pass on it.

An illustrative example

Suppose that we have a computation graph shown above. The variable d and e is the output, and a is the input. For example,

import torch
from torch.autograd import Variable
a = Variable(torch.rand(1, 4), requires_grad=True)
b = a**2
c = b*2
d = c.mean()
e = c.sum()

when we do d.backward(), that is fine. After this computation, the part of graph that calculate d will be freed by default to save memory. So if we do e.backward(), the error message will pop up. In order to do e.backward(), we have to set the parameter retain_graph to True in d.backward(), i.e.,

d.backward(retain_graph=True)

As long as you use retain_graph=True in your backward method, you can do backward any time you want:

d.backward(retain_graph=True) # fine
e.backward(retain_graph=True) # fine
d.backward() # also fine
e.backward() # error will occur!

More useful discussion can be found here.

A real use case

Right now, a real use case is multi-task learning where you have multiple loss which maybe be at different layers. Suppose that you have 2 losses: loss1 and loss2 and they reside in different layers. In order to backprop the gradient of loss1 and loss2 w.r.t to the learnable weight of your network independently. You have to use retain_graph=True in backward() method in the first back-propagated loss.

# suppose you first back-propagate loss1, then loss2 (you can also do the reverse)
loss1.backward(retain_graph=True)
loss2.backward() # now the graph is freed, and next process of batch gradient descent is ready
optimizer.step() # update the network parameters

这篇关于变量的backward() 方法中的参数retain_graph 是什么意思?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆