多维数组上的np.argmax,保持一些索引不变 [英] np.argmax on multidimensional arrays, keeping some indexes fixed

查看:1156
本文介绍了多维数组上的np.argmax,保持一些索引不变的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有两个2D集合的集合,取决于两个整数索引,比如p1和p2,每个矩阵具有相同的形状。

I have a collection of 2D narrays, depending on two integer indexes, say p1 and p2, with each matrix of the same shape.

然后我需要查找,对于每对(p1,p2),矩阵的最大值和这些最大值的索引。
一个微不足道的,虽然很慢,这样做的方法就是做这样的事情

Then I need to find, for each pair (p1,p2), the maximum value of the matrix and the indexes of these maxima. A trivial, albeit slow, way to do this would would be to do something like this

import numpy as np
import itertools
range1=range(1,10)
range2=range(1,20)

for p1,p2 in itertools.product(range1,range1):
    mat=np.random.rand(10,10)
    index=np.unravel_index(mat.argmax(), mat.shape)
    m=mat[index]
    print m, index

对于我的应用程序,遗憾的是这太慢了,我想由于使用双循环。
因此我尝试将所有东西打包成一个4维数组(比如BigMatrix),其中前两个坐标是索引p1,p2,另外两个是矩阵的坐标。

For my application this is unfortunately too slow, I guess due to the usage of double for loops. Therefore I tried to pack everything in a 4-dimensional array (say BigMatrix), where the first two coordinates are the indexes p1,p2, and the other 2 are the coordinates of the matrices.

np.amax命令

    >>res=np.amax(BigMatrix,axis=(2,3))
    >>res.shape
         (10,20)
    >>res[p1,p2]==np.amax(BigMatrix[p1,p2,:,:])
         True

按预期工作,因为它循环通过2轴和3轴。我如何为np.argmax做同样的事情?请记住速度很重要。

works as expected, as it loops through the 2 and 3 axis. How can I do the same for np.argmax? Please keep in mind that speed is important.

提前非常感谢,

Enzo

推荐答案

这对我来说适用于 Mat 是大矩阵。

This here works for me where Mat is the big matrix.

# flatten the 3 and 4 dimensions of Mat and obtain the 1d index for the maximum
# for each p1 and p2
index1d = np.argmax(Mat.reshape(Mat.shape[0],Mat.shape[1],-1),axis=2)

# compute the indices of the 3 and 4 dimensionality for all p1 and p2
index_x, index_y = np.unravel_index(index1d,Mat[0,0].shape)

# bring the indices into the right shape
index = np.array((index_x,index_y)).reshape(2,-1).transpose()

# get the maxima
max_val = np.amax(Mat,axis=(2,3)).reshape(-1)

# combine maxima and indices
sol = np.column_stack((max_val,index))

print sol

这篇关于多维数组上的np.argmax,保持一些索引不变的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆