如何在numpy中收集特定索引的元素? [英] how to gather elements of specific indices in numpy?

查看:158
本文介绍了如何在numpy中收集特定索引的元素?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我想像下面那样在指定轴上收集指定索引的元素.

I want to gather elements of specified indices in specified axis like following.

x = [[1,2,3], [4,5,6]]
index = [[2,1], [0, 1]]
x[:, index] = [[3, 2], [4, 5]]

这本质上是pytorch中的collect操作,但是如您所知,这在numpy中无法通过这种方式实现.我想知道numpy中是否有这样的聚集"操作?

This is essentially gather operation in pytorch, but as you know, this is not achievable in numpy this way. I am wondering if there is such a "gather" operation in numpy?

推荐答案

我之前写过这篇文章,目的是在Numpy中复制PyTorch的gather.在这种情况下,self是您的x

I wrote this awhile ago to replicate PyTorch's gather in Numpy. In this case self is your x

def gather(self, dim, index):
    """
    Gathers values along an axis specified by ``dim``.

    For a 3-D tensor the output is specified by:
        out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
        out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
        out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

    Parameters
    ----------
    dim:
        The axis along which to index
    index:
        A tensor of indices of elements to gather

    Returns
    -------
    Output Tensor
    """
    idx_xsection_shape = index.shape[:dim] + \
        index.shape[dim + 1:]
    self_xsection_shape = self.shape[:dim] + self.shape[dim + 1:]
    if idx_xsection_shape != self_xsection_shape:
        raise ValueError("Except for dimension " + str(dim) +
                         ", all dimensions of index and self should be the same size")
    if index.dtype != np.dtype('int_'):
        raise TypeError("The values of index must be integers")
    data_swaped = np.swapaxes(self, 0, dim)
    index_swaped = np.swapaxes(index, 0, dim)
    gathered = np.choose(index_swaped, data_swaped)
    return np.swapaxes(gathered, 0, dim)

这些是测试用例:

# Test 1
    t = np.array([[65, 17], [14, 25], [76, 22]])
    idx = np.array([[0], [1], [0]])
    dim = 1
    result = gather(t, dim=dim, index=idx)
    expected = np.array([[65], [25], [76]])
    print(np.array_equal(result, expected))

# Test 2
    t = np.array([[47, 74, 44], [56, 9, 37]])
    idx = np.array([[0, 0, 1], [1, 1, 0], [0, 1, 0]])
    dim = 0
    result = gather(t, dim=dim, index=idx)
    expected = np.array([[47, 74, 37], [56, 9, 44.], [47, 9, 44]])
    print(np.array_equal(result, expected))

这篇关于如何在numpy中收集特定索引的元素?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆