When will the computation graph be freed if I only do forward for some samples?


Problem Description

I have a use case where I run a forward pass for each sample in a batch and only accumulate the loss for some of the samples, based on a condition on the model output for that sample. Here is an illustrative snippet:

for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    total_loss = 0

    loss_count_local = 0
    for i in range(len(target)):
        im = Variable(data[i].unsqueeze(0).cuda())
        y = Variable(torch.FloatTensor([target[i]]).cuda())

        out = model(im)

        # if out satisfies some condition, we calculate the loss for
        # this sample; otherwise we proceed to the next sample
        if some_condition(out):
            loss = criterion(out, y)
        else:
            continue

        total_loss += loss
        loss_count_local += 1

        if loss_count_local == 32 or i == (len(target)-1):
            total_loss /= loss_count_local
            total_loss.backward()
            total_loss = 0
            loss_count_local = 0

    optimizer.step()

My question is: since I do the forward pass for all samples but only do backward for some of them, when will the graphs of the samples that do not contribute to the loss be freed? Will these graphs be freed only after the for loop has ended, or immediately after I do the forward pass for the next sample? I am a little confused here.

Also, for those samples that do contribute to total_loss, their graphs will be freed immediately after we call total_loss.backward(). Is that right?

Recommended Answer

Let's start with a general discussion of how PyTorch frees memory:

First, we should emphasize that PyTorch uses an implicitly declared graph that is stored in Python object attributes. (Remember, it's Python, so everything is an object.) More specifically, a torch.autograd.Variable has a .grad_fn attribute. The type of this attribute defines what kind of computation node we have (e.g. an addition) and what the inputs to that node are.
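For illustration, here is a minimal sketch of what .grad_fn looks like (assuming a recent PyTorch where Variable and plain tensors are merged; the exact printed class names may differ between versions):

import torch

x = torch.randn(3, requires_grad=True)
y = x + 2            # an addition node
z = (y * y).sum()    # a multiplication followed by a sum

print(y.grad_fn)                 # e.g. <AddBackward0 ...>
print(z.grad_fn)                 # e.g. <SumBackward0 ...>
print(z.grad_fn.next_functions)  # references to the upstream nodes of the graph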

This is important because PyTorch frees memory simply by using the standard Python garbage collector (fairly aggressively). In this context, it means that the (implicitly declared) computation graphs are kept alive as long as there are references to the objects holding them in the current scope!

This means that if you, for example, do some kind of batching over samples s_1 ... s_k, compute the loss for each, and add the losses at the end, that cumulative loss will hold references to each individual loss, which in turn holds references to each of the computation nodes that computed it.
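As a rough sketch of that point (model and criterion here are made-up stand-ins, not the questioner's code), accumulating losses keeps every per-sample graph alive until the accumulated tensor itself is consumed or released:

import torch

# Sketch only: `model` and `criterion` are placeholders for illustration.
model = torch.nn.Linear(10, 1)
criterion = torch.nn.MSELoss()

total_loss = 0
for _ in range(4):
    x = torch.randn(1, 10)
    y = torch.randn(1, 1)
    out = model(x)                               # builds a small graph, referenced by `out`
    total_loss = total_loss + criterion(out, y)  # `total_loss` now also references that graph

# backward() consumes the accumulated graph (its buffers are freed by default),
# and re-binding the name afterwards drops the remaining Python references.
total_loss.backward()
total_loss = 0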

So your question, applied to your code, is more about how Python (or, more specifically, its garbage collector) handles references than about PyTorch. Since you accumulate the losses in one object (total_loss), you keep those pointers alive and therefore do not free the memory until you re-initialize that object in the outer loop.

Applied to your example, this means that the computation graph you create in the forward pass (at out = model(im)) is referenced only by the out object and any later computations derived from it. So if you compute the loss and sum it, you keep references to out alive, and thereby to the computation graph. If you do not use it, however, the garbage collector should recursively collect out and its computation graph.
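To make that concrete, here is a hedged sketch of the questioner's pattern; some_condition, model, criterion, and the data are placeholders, not the original code. When the condition fails, nothing but the name out refers to that sample's graph, so re-binding out on the next iteration makes the old graph unreachable and collectible:

import torch

model = torch.nn.Linear(10, 1)
criterion = torch.nn.MSELoss()

def some_condition(o):
    # placeholder predicate, not part of the original code
    return o.item() > 0.0

data = torch.randn(8, 10)
target = torch.randn(8, 1)

total_loss = 0
for i in range(len(target)):
    out = model(data[i].unsqueeze(0))   # the previous `out` (and its graph) becomes unreachable here
    if some_condition(out):
        total_loss = total_loss + criterion(out, target[i].unsqueeze(0))
    # else: no other object references `out`'s graph, so it can be garbage-collected

if torch.is_tensor(total_loss):         # only backward if at least one sample contributed
    total_loss.backward()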

