Cumulative argmax of a NumPy array
Question
Consider the array a:
import numpy as np
import pandas as pd

np.random.seed([3,1415])
a = np.random.randint(0, 10, (10, 2))
a
array([[0, 2],
       [7, 3],
       [8, 7],
       [0, 6],
       [8, 6],
       [0, 2],
       [0, 4],
       [9, 7],
       [3, 2],
       [4, 3]])
What is a vectorized way to get the cumulative argmax?
array([[0, 0],   <-- both columns start with position 0 as the max
       [1, 1],   <-- 7 > 0 so 1st col = 1; 3 > 2 so 2nd col = 1
       [2, 2],   <-- 8 > 7 so 1st col = 2; 7 > 3 so 2nd col = 2
       [2, 2],   <-- 0 < 8 so 1st col stays the same; 6 < 7 so 2nd col stays the same
       [2, 2],
       [2, 2],
       [2, 2],
       [7, 2],   <-- 9 is the new max of the 1st col, so its argmax is now 7
       [7, 2],
       [7, 2]])
Here is a non-vectorized way to do it. Notice that as the window expands, argmax is applied to the growing window.
pd.DataFrame(a).expanding().apply(np.argmax).astype(int).values
array([[0, 0],
       [1, 1],
       [2, 2],
       [2, 2],
       [2, 2],
       [2, 2],
       [2, 2],
       [7, 2],
       [7, 2],
       [7, 2]])
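For readers without pandas at hand, the same expanding-window idea can be written as an explicit loop. This is only a reference sketch for checking results, not part of the original question; the name `cumargmax_loop` is made up for illustration, and the array is hard-coded from the output shown above so the snippet does not depend on the random seed.

```python
import numpy as np

def cumargmax_loop(a):
    """Reference version: argmax of each growing prefix a[:i+1], per column."""
    a = np.asarray(a)
    out = np.empty(a.shape, dtype=np.intp)
    for i in range(a.shape[0]):
        # np.argmax returns the first position of the maximum,
        # so a later tie does not move the index
        out[i] = a[: i + 1].argmax(axis=0)
    return out

# Same values as the array a printed in the question
a = np.array([[0, 2], [7, 3], [8, 7], [0, 6], [8, 6],
              [0, 2], [0, 4], [9, 7], [3, 2], [4, 3]])
print(cumargmax_loop(a))
```

Because each step rescans the whole prefix, this does quadratic work in the number of rows, which is why a vectorized approach is worth asking for.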
Recommended answer
Here's a vectorized, pure-NumPy solution that performs pretty snappily:
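def cumargmax(a):
    # running maximum down each column
    m = np.maximum.accumulate(a, axis=0)
    # row index of every element, repeated across the columns
    x = np.repeat(np.arange(a.shape[0])[:, None], a.shape[1], axis=1)
    # keep an index only where the running maximum strictly increases
    x[1:] *= m[:-1] < m[1:]
    # carry the most recent kept index forward
    np.maximum.accumulate(x, axis=0, out=x)
    return x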
Then we have:
>>> cumargmax(a)
array([[0, 0],
       [1, 1],
       [2, 2],
       [2, 2],
       [2, 2],
       [2, 2],
       [2, 2],
       [7, 2],
       [7, 2],
       [7, 2]])
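To see why this works, it may help to trace the intermediate arrays for a single column. The sketch below repeats the same steps on the first column of a (values hard-coded from the question); the variable names `col`, `idx` and `is_new_max` are illustrative and do not appear in the answer.

```python
import numpy as np

col = np.array([0, 7, 8, 0, 8, 0, 0, 9, 3, 4])  # first column of a

m = np.maximum.accumulate(col)       # running maximum: [0 7 8 8 8 8 8 9 9 9]
idx = np.arange(col.size)            # candidate positions 0..9
is_new_max = m[1:] > m[:-1]          # True where the running maximum strictly increases
idx[1:] *= is_new_max                # zero out positions that did not set a new maximum
# idx is now [0 1 2 0 0 0 0 7 0 0]
result = np.maximum.accumulate(idx)  # carry the last record-setting position forward
print(result)                        # [0 1 2 2 2 2 2 7 7 7]
```

The strict comparison matters: the second 8 at position 4 only ties the running maximum, so its index is zeroed out and the result keeps position 2, matching np.argmax's first-occurrence rule.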
Some quick testing on arrays with thousands to millions of values suggests that this is roughly 10 to 50 times faster than looping at the Python level (either implicitly or explicitly).
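A rough way to reproduce that kind of comparison is sketched below. It is not the benchmark used for the figures above: the array size, the random data, and the baseline `cumargmax_loop` (a hypothetical explicit row loop) are all assumptions, and the measured ratio will depend on the machine, the NumPy version, and the array shape.

```python
import timeit
import numpy as np

def cumargmax(a):
    # Same steps as the answer's function, repeated so this snippet is self-contained
    m = np.maximum.accumulate(a, axis=0)
    x = np.repeat(np.arange(a.shape[0])[:, None], a.shape[1], axis=1)
    x[1:] *= m[:-1] < m[1:]
    np.maximum.accumulate(x, axis=0, out=x)
    return x

def cumargmax_loop(a):
    # Hypothetical baseline: explicit Python loop over rows,
    # tracking the best value and its row for every column
    out = np.zeros(a.shape, dtype=np.intp)
    best = a[0].copy()
    for i in range(1, a.shape[0]):
        out[i] = out[i - 1]          # default: keep the previous argmax
        newer = a[i] > best          # columns where row i sets a strict new maximum
        out[i, newer] = i
        best[newer] = a[i, newer]
    return out

rng = np.random.default_rng(0)
big = rng.integers(0, 1000, size=(5000, 4))

# Both versions should agree before their timings are compared
assert np.array_equal(cumargmax(big), cumargmax_loop(big))

t_vec = timeit.timeit(lambda: cumargmax(big), number=5)
t_loop = timeit.timeit(lambda: cumargmax_loop(big), number=5)
print(f"vectorized: {t_vec:.4f}s  loop: {t_loop:.4f}s  ratio: {t_loop / t_vec:.1f}x")
```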