Indexing in numpy (related to max/argmax)
Question
Suppose I have an N-dimensional numpy array x and an (N-1)-dimensional index array m (for example, m = x.argmax(axis=-1)). I'd like to construct the (N-1)-dimensional array y such that y[i_1, ..., i_N-1] = x[i_1, ..., i_N-1, m[i_1, ..., i_N-1]] (for the argmax example above, this would be equivalent to y = x.max(axis=-1)).
For N=3 I could achieve what I want with
y = x[np.arange(x.shape[0])[:, np.newaxis], np.arange(x.shape[1]), m]
The question is, how do I do this for an arbitrary N?
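For reference, NumPy (1.15 and later) also ships np.take_along_axis, which covers exactly this "pick one element per row along an axis" pattern for any N. This technique differs from the one used in the answer below. The following is a minimal sketch; the helper name pick_last_axis is hypothetical. The function requires the index array to have the same number of dimensions as x, so the sketch adds a trailing length-1 axis to m and removes it again afterwards.

```python
import numpy as np

def pick_last_axis(x, m):
    """Hypothetical helper: y[i_1, ..., i_N-1] = x[i_1, ..., i_N-1, m[i_1, ..., i_N-1]]."""
    # take_along_axis wants indices with the same ndim as x,
    # so give m a trailing axis of length 1, then drop it from the result.
    return np.take_along_axis(x, m[..., np.newaxis], axis=-1)[..., 0]

x = np.random.randint(0, 9, (4, 5, 3, 2))
m = x.argmax(axis=-1)
y = pick_last_axis(x, m)
print(np.array_equal(y, x.max(axis=-1)))  # True
```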
Recommended answer
Here's one approach that uses reshaping and linear indexing to handle arrays with an arbitrary number of dimensions:
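import numpy as np

shp = x.shape[:-1]          # shape of the output (all axes but the last)
n_ele = np.prod(shp)        # number of "rows" once the leading axes are collapsed
y_out = x.reshape(n_ele,-1)[np.arange(n_ele),m.ravel()].reshape(shp)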
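The reshape approach works because, in C order, row r of x.reshape(n_ele, -1) begins at flat offset r * x.shape[-1]. The following sketch makes that linear index explicit on the flattened array. The helper name linear_index_based is hypothetical, and the example assumes x.ravel() and m.ravel() use the same C ordering, which is NumPy's default.

```python
import numpy as np

def linear_index_based(x, m):
    # Hypothetical helper: compute flat C-order offsets directly.
    shp = x.shape[:-1]
    last = x.shape[-1]
    n_ele = int(np.prod(shp))
    # Row r starts at offset r * last; the wanted element is m.ravel()[r] further on.
    flat_idx = np.arange(n_ele) * last + m.ravel()
    return x.ravel()[flat_idx].reshape(shp)

x = np.arange(24).reshape(2, 3, 4)
m = np.array([[0, 1, 2], [3, 0, 1]])
print(linear_index_based(x, m).tolist())  # [[0, 5, 10], [15, 16, 21]]
```

Because x here is arange(24), each picked value equals its own flat offset r * 4 + m. This makes the arithmetic easy to check by hand.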
Let's take a sample case with an ndarray of 6 dimensions, and say we are using m = x.argmax(axis=-1) to index into the last dimension. The output should then be x.max(-1). Let's verify this for the proposed solution:
In [121]: x = np.random.randint(0,9,(4,5,3,3,2,4))
In [122]: m = x.argmax(axis=-1)
In [123]: shp = x.shape[:-1]
...: n_ele = np.prod(shp)
...: y_out = x.reshape(n_ele,-1)[np.arange(n_ele),m.ravel()].reshape(shp)
...:
In [124]: np.allclose(x.max(-1),y_out)
Out[124]: True
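The argmax check only confirms that the picked values equal the row maxima. A stricter check draws an arbitrary valid m and compares the result against an explicit element-by-element loop. The following sketch does that. It assumes NumPy 1.17 or later for np.random.default_rng, and the seed is arbitrary.

```python
import numpy as np

def reshape_based(x, m):
    shp = x.shape[:-1]
    n_ele = np.prod(shp)
    return x.reshape(n_ele, -1)[np.arange(n_ele), m.ravel()].reshape(shp)

rng = np.random.default_rng(0)
x = rng.integers(0, 9, (4, 5, 3, 3, 2, 4))
# An arbitrary valid index array, not tied to argmax
m = rng.integers(0, x.shape[-1], x.shape[:-1])

# Reference result built one element at a time
expected = np.empty(x.shape[:-1], dtype=x.dtype)
for idx in np.ndindex(*x.shape[:-1]):
    expected[idx] = x[idx + (m[idx],)]

print(np.array_equal(reshape_based(x, m), expected))  # True
```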
I liked @B. M.'s solution for its elegance, so here's a runtime test benchmarking the two approaches:
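import numpy as np

def reshape_based(x,m):
    # Collapse the leading axes into one, then pick one element per row
    shp = x.shape[:-1]
    n_ele = np.prod(shp)
    return x.reshape(n_ele,-1)[np.arange(n_ele),m.ravel()].reshape(shp)

def indices_based(x,m): ## @B. M.'s solution
    # One full index grid per leading axis, plus m for the last axis
    firstdims=np.indices(x.shape[:-1])
    ind=tuple(firstdims)+(m,)
    return x[ind]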
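By default, np.indices builds N-1 dense integer grids, each the full size of the output. np.indices also accepts sparse=True (NumPy 1.17 and later), which returns one broadcastable array per axis instead. The following is a sketch of that variant. The name indices_based_sparse is hypothetical, and any speed difference would need to be measured rather than assumed.

```python
import numpy as np

def indices_based_sparse(x, m):
    # Hypothetical variant of indices_based: with sparse=True each index array
    # has length 1 on every axis but its own, and broadcasting fills in the rest.
    firstdims = np.indices(x.shape[:-1], sparse=True)
    return x[tuple(firstdims) + (m,)]

x = np.random.randint(0, 9, (4, 5, 3, 3, 2, 4))
m = x.argmax(axis=-1)
print(np.array_equal(indices_based_sparse(x, m), x.max(-1)))  # True
```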
Timings -
In [152]: x = np.random.randint(0,9,(4,5,3,3,4,3,6,2,4,2,5))
...: m = x.argmax(axis=-1)
...:
In [153]: %timeit indices_based(x,m)
10 loops, best of 3: 30.2 ms per loop
In [154]: %timeit reshape_based(x,m)
100 loops, best of 3: 5.14 ms per loop
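The %timeit calls above are IPython magics. The following sketch is a rough plain-script equivalent using the standard timeit module. The repeat/number settings are arbitrary choices, and the numbers it prints will depend on the machine and NumPy version.

```python
import timeit
import numpy as np

def reshape_based(x, m):
    shp = x.shape[:-1]
    n_ele = np.prod(shp)
    return x.reshape(n_ele, -1)[np.arange(n_ele), m.ravel()].reshape(shp)

def indices_based(x, m):
    firstdims = np.indices(x.shape[:-1])
    return x[tuple(firstdims) + (m,)]

x = np.random.randint(0, 9, (4, 5, 3, 3, 4, 3, 6, 2, 4, 2, 5))
m = x.argmax(axis=-1)

for fn in (indices_based, reshape_based):
    # Best of 3 repeats, each timing 10 calls; report the per-call time
    best = min(timeit.repeat(lambda: fn(x, m), number=10, repeat=3)) / 10
    print(f"{fn.__name__}: {best * 1e3:.2f} ms per call")
```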