What does the gather function do in pytorch in layman terms?

Question

I have been through the official doc and this but it is hard to understand what is going on.

I am trying to understand a DQN source code and it uses the gather function on line 197.

Could someone explain in simple terms what the gather function does? What is the purpose of that function?

Answer

The torch.gather function (or torch.Tensor.gather) is a multi-index selection method. Look at the following example from the official docs:

import torch

t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  1],
#        [ 4,  3]])

Let's start with the semantics of the arguments: the first argument, input, is the source tensor we want to select elements from. The second, dim, is the dimension (called axis in TensorFlow/NumPy) along which we collect. Finally, index holds the indices used to pick elements out of input along dim; the output has the same shape as index. As for the semantics of the operation, this is how the official docs explain it for a 3-D tensor:
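A quick way to see the role of each argument is to use an index whose shape differs from the input. The tensors below are made up for illustration; the sketch only relies on the documented rules that the output takes the shape of index and that index may not be larger than input on the dimensions other than dim.

```python
import torch

t = torch.tensor([[1, 2], [3, 4]])   # input: shape (2, 2)
idx = torch.tensor([[0, 1, 1]])      # index: shape (1, 3), column indices into row 0 of t
r = torch.gather(t, 1, idx)          # dim=1: the values in idx pick columns

print(r.tolist())  # [[1, 2, 2]]
print(r.shape)     # torch.Size([1, 3]), the same shape as idx
```

Along dim the index may be longer than the input (3 > 2 here) as long as every value is a valid position; on the other dimension it must not exceed the input's size.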

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
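To make the three formulas concrete, here is a naive loop version that follows them literally and is checked against torch.gather on random data. The helper gather_3d_reference is hypothetical and only meant for illustration; it is far slower than the real function.

```python
import torch

def gather_3d_reference(inp, dim, index):
    """Naive loop version of the documented gather rule for 3-D tensors (illustration only)."""
    out = torch.empty(index.shape, dtype=inp.dtype)
    I, J, K = index.shape
    for i in range(I):
        for j in range(J):
            for k in range(K):
                n = index[i, j, k].item()  # the position to read along `dim`
                if dim == 0:
                    out[i, j, k] = inp[n, j, k]
                elif dim == 1:
                    out[i, j, k] = inp[i, n, k]
                else:  # dim == 2
                    out[i, j, k] = inp[i, j, n]
    return out

torch.manual_seed(0)
inp = torch.randn(3, 4, 5)
results = {}
for dim in range(3):
    # Random indices that are valid along `dim`; here index has the same shape as inp.
    index = torch.randint(0, inp.size(dim), inp.shape)
    results[dim] = torch.equal(gather_3d_reference(inp, dim, index),
                               torch.gather(inp, dim, index))
print(results)  # {0: True, 1: True, 2: True}
```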

So let's go through the example.

The input tensor is [[1, 2], [3, 4]] and the dim argument is 1, i.e. we want to collect along the second dimension. The indices for the second dimension are given as [0, 0] and [1, 0].

As we "skip" the first dimension (we collect along dimension 1), the first dimension of the result is implicitly given by the first dimension of index. That means index holds only second-dimension indices, i.e. column indices, but no row indices; the row of each output element is simply its own row position in the index tensor. In the example, the first row of the output is a selection of elements from the first row of input, picked by the first row of index. Since those column indices are [0, 0], we select the first element of input's first row twice, giving [1, 1]. Similarly, the second row of the result comes from indexing the second row of input with the second row of index, [1, 0], giving [4, 3].
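For a 2-D input with dim=1, the "implicit row" described above can be made explicit with ordinary advanced indexing, pairing each entry of index with its own row number. This is only a cross-check of the semantics, not a claim about how gather is implemented internally.

```python
import torch

t = torch.tensor([[1, 2], [3, 4]])
idx = torch.tensor([[0, 0], [1, 0]])

# Row numbers 0..n-1 as a column vector, broadcast against idx.
rows = torch.arange(idx.size(0)).unsqueeze(1)   # [[0], [1]]
by_indexing = t[rows, idx]                      # out[i][j] = t[i][idx[i][j]]
by_gather = torch.gather(t, 1, idx)

print(by_indexing.tolist())                 # [[1, 1], [4, 3]]
print(torch.equal(by_indexing, by_gather))  # True
```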

To illustrate this even further, let's change the dimension in the example from 1 to 0:

t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 0, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  2],
#        [ 3,  2]])

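As you can see, the values in index are now used as row indices (the first dimension), while the column of each output element is its own column position in index. For example, the bottom-right entry of index is 0, so it reads row 0 of column 1, which is 2.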
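The dim=0 case can be cross-checked the same way with advanced indexing; this time index supplies the rows and the column numbers are broadcast along each row. Again, this is a sketch for understanding the semantics, not the library's implementation.

```python
import torch

t = torch.tensor([[1, 2], [3, 4]])
idx = torch.tensor([[0, 0], [1, 0]])

# With dim=0, idx supplies row numbers; the column is the element's own position in idx.
cols = torch.arange(idx.size(1)).unsqueeze(0)   # [[0, 1]]
by_indexing = t[idx, cols]                      # out[i][j] = t[idx[i][j]][j]
by_gather = torch.gather(t, 0, idx)

print(by_indexing.tolist())                 # [[1, 2], [3, 2]]
print(torch.equal(by_indexing, by_gather))  # True
```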

For the DQN example you mentioned:

current_Q_values = Q(obs_batch).gather(1, act_batch.unsqueeze(1))

Here gather indexes each row of the q-values (i.e. the q-values of one sample in the batch) with the action taken for that sample, so the result holds the q-value of the chosen action for every sample, with shape (batch_size, 1). The result is the same as if you had done the following, though gather is much faster than a Python loop:

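q_values = Q(obs_batch)  # shape: (batch_size, num_actions)
q_vals = []
for qv, ac in zip(q_values, act_batch):
    q_vals.append(qv[ac])  # 0-d tensor: q-value of the action taken for this sample
# torch.cat cannot join 0-d tensors, so stack them and add the column dimension
q_vals = torch.stack(q_vals).unsqueeze(1)  # shape: (batch_size, 1), same as gather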
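The snippet below is a self-contained sketch of that DQN line. It assumes act_batch is a 1-D tensor of action indices; the network Q (a plain nn.Linear), the sizes, and the random batch are hypothetical stand-ins, not the code from the linked repository.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, obs_dim, num_actions = 4, 3, 2

# Hypothetical stand-ins for the DQN code: a linear "Q-network" and a random batch.
Q = nn.Linear(obs_dim, num_actions)
obs_batch = torch.randn(batch_size, obs_dim)
act_batch = torch.randint(0, num_actions, (batch_size,))  # one action index per sample

with torch.no_grad():
    q_values = Q(obs_batch)  # shape: (batch_size, num_actions)

# gather needs an index with as many dims as q_values, hence unsqueeze(1).
current_Q_values = q_values.gather(1, act_batch.unsqueeze(1))  # shape: (batch_size, 1)

# Loop version for comparison.
looped = torch.stack([qv[ac] for qv, ac in zip(q_values, act_batch)]).unsqueeze(1)

print(current_Q_values.shape)                 # torch.Size([4, 1])
print(torch.equal(current_Q_values, looped))  # True
print(current_Q_values.squeeze(1).shape)      # torch.Size([4])
```

The trailing squeeze(1) is a common follow-up when the loss expects a flat (batch_size,) tensor of q-values.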
