Indexing in numpy (related to max/argmax)

Question

Suppose I have an N-dimensional numpy array x and an (N-1)-dimensional index array m (for example, m = x.argmax(axis=-1)). I'd like to construct an (N-1)-dimensional array y such that y[i_1, ..., i_N-1] = x[i_1, ..., i_N-1, m[i_1, ..., i_N-1]] (for the argmax example above, this would be equivalent to y = x.max(axis=-1)). For N = 3 I can achieve what I want with

y = x[np.arange(x.shape[0])[:, np.newaxis], np.arange(x.shape[1]), m]

The question is, how do I do this for an arbitrary N?
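For reference, the N = 3 expression from the question can be checked with a small sketch. The array shape and the random seed are arbitrary choices made for illustration; they are not part of the original post.

```python
import numpy as np

rng = np.random.default_rng(0)          # seed chosen arbitrarily for this sketch
x = rng.integers(0, 9, size=(4, 5, 3))  # a small 3-dimensional example
m = x.argmax(axis=-1)                   # shape (4, 5)

# The three index arrays have shapes (4, 1), (5,) and (4, 5);
# they broadcast together to (4, 5), which is the shape of the result.
rows = np.arange(x.shape[0])[:, np.newaxis]
cols = np.arange(x.shape[1])
y = x[rows, cols, m]

print(y.shape)
print(np.array_equal(y, x.max(axis=-1)))
```

Writing out one np.arange per leading axis by hand is what stops this form from scaling to arbitrary N.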

Recommended answer

Here's one approach that uses reshaping and linear indexing to handle arrays with an arbitrary number of dimensions:

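import numpy as np

shp = x.shape[:-1]      # every dimension except the last
n_ele = np.prod(shp)    # number of rows once the leading dimensions are flattened
y_out = x.reshape(n_ele, -1)[np.arange(n_ele), m.ravel()].reshape(shp)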
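To make the one-liner easier to follow, here is the same technique split into steps, with the intermediate shapes noted in comments. The small (2, 3, 4) array, the seed, and the names x2d, flat_idx and picked are illustrative assumptions, not part of the original answer.

```python
import numpy as np

rng = np.random.default_rng(0)          # arbitrary seed for this sketch
x = rng.integers(0, 9, size=(2, 3, 4))
m = x.argmax(axis=-1)                   # shape (2, 3)

shp = x.shape[:-1]                      # leading dimensions: (2, 3)
n_ele = int(np.prod(shp))               # 6 rows after flattening the leading dimensions

x2d = x.reshape(n_ele, -1)              # shape (6, 4): one row per leading index
flat_idx = m.ravel()                    # shape (6,): one column index per row
picked = x2d[np.arange(n_ele), flat_idx]  # shape (6,): pairs row i with flat_idx[i]
y_out = picked.reshape(shp)             # back to the leading shape (2, 3)

print(np.array_equal(y_out, x.max(axis=-1)))
```

Both reshape and ravel use row-major (C) order by default, so row i of x2d and entry i of flat_idx refer to the same leading index.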

Let's take a sample case with an ndarray of 6 dimensions, and say we are using m = x.argmax(axis=-1) to index into the last dimension. The expected output is then x.max(-1). Let's verify this for the proposed solution:

In [121]: x = np.random.randint(0,9,(4,5,3,3,2,4))

In [122]: m = x.argmax(axis=-1)

In [123]: shp = x.shape[:-1]
     ...: n_ele = np.prod(shp)
     ...: y_out = x.reshape(n_ele,-1)[np.arange(n_ele),m.ravel()].reshape(shp)
     ...: 

In [124]: np.allclose(x.max(-1),y_out)
Out[124]: True


I liked @B. M.'s solution for its elegance, so here's a runtime test to benchmark the two approaches:

def reshape_based(x,m):
    shp = x.shape[:-1]
    n_ele = np.prod(shp)
    return x.reshape(n_ele,-1)[np.arange(n_ele),m.ravel()].reshape(shp)

def indices_based(x, m):  # @B. M.'s solution
    firstdims = np.indices(x.shape[:-1])
    ind = tuple(firstdims) + (m,)
    return x[ind]

Timings:

In [152]: x = np.random.randint(0,9,(4,5,3,3,4,3,6,2,4,2,5))
     ...: m = x.argmax(axis=-1)
     ...: 

In [153]: %timeit indices_based(x,m)
10 loops, best of 3: 30.2 ms per loop

In [154]: %timeit reshape_based(x,m)
100 loops, best of 3: 5.14 ms per loop
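The timings above come from an IPython session. As a rough way to reproduce the comparison in a plain script, the sketch below uses the standard timeit module on a smaller array (an assumption, to keep the run short). It also includes a third variant based on np.take_along_axis, which is not part of the original answer and is added here only as a labeled alternative; it requires NumPy 1.15 or later. Absolute numbers will differ by machine and NumPy version.

```python
import timeit

import numpy as np


def reshape_based(x, m):
    shp = x.shape[:-1]
    n_ele = int(np.prod(shp))
    return x.reshape(n_ele, -1)[np.arange(n_ele), m.ravel()].reshape(shp)


def indices_based(x, m):  # @B. M.'s solution
    firstdims = np.indices(x.shape[:-1])
    return x[tuple(firstdims) + (m,)]


def take_along_based(x, m):  # hypothetical extra variant, not from the answer
    # Give m a trailing axis of length 1 so it matches x's dimensionality,
    # gather along the last axis, then drop that trailing axis again.
    return np.take_along_axis(x, m[..., np.newaxis], axis=-1)[..., 0]


rng = np.random.default_rng(0)  # arbitrary seed for this sketch
x = rng.integers(0, 9, size=(4, 5, 3, 3, 4, 3, 6, 2))  # smaller than the original test
m = x.argmax(axis=-1)
expected = x.max(axis=-1)

results = {}
for fn in (reshape_based, indices_based, take_along_based):
    # Check correctness first, then record the best of 3 runs of 10 calls each.
    assert np.array_equal(fn(x, m), expected)
    best = min(timeit.repeat(lambda: fn(x, m), number=10, repeat=3))
    results[fn.__name__] = best / 10

for name, seconds in results.items():
    print(f"{name}: {seconds * 1e3:.3f} ms per call")
```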
