在数组中查找最大值索引的最快方法是什么? [英] What's the fastest way of finding the index of the maximum value in an array?
问题描述
我有一个类型为 f32
的2D数组(来自 ndarray :: ArrayView2
),我想找到每行最大值的索引,并将索引值放入另一个数组。
I have a 2D array of type f32
(from ndarray::ArrayView2
) and I want to find the index of the maximum value in each row, and put the index value into another array.
Python中的等效项如下:
The equivalent in Python is something like:
import numpy as np
for i in range (0, max_val, batch_size):
sims = xp.dot(batch, vectors.T)
# sims is the dot product of batch and vectors.T
# the shape is, for example, (1024, 10000)
best_rows[i: i+batch_size] = sims.argmax(axis = 1)
In Python,函数 .argmax
很快,但是在Rust中没有类似的函数。最快的方法是什么?
In Python, the function .argmax
is very fast, but I don't see any function like that in Rust. What's the fastest way of doing so?
推荐答案
来自@David A 的方法很酷,但如上所述,有一个陷阱: f32
& f64
不实现 Ord :: cmp
。 (这在您所知的地方确实很痛苦。)
The approach from @David A is cool, but as mentioned, there's a catch: f32
& f64
do not implement Ord::cmp
. (Which is really a pain in your-know-where.)
有多种解决方法:您可以实现 cmp
自己,也可以使用 ordered-float
等。
There are multiple ways of solving that: You can implement cmp
yourself, or you can use ordered-float
, etc..
是更大项目的一部分,我们在使用外部软件包时非常小心。此外,我很确定我们没有任何 NaN
值。因此,我宁愿使用 fold
,如果您仔细查看 max_by_key
源代码,它们就是它们的意思。
In my case, this is a part of a bigger project and we are very careful about using external packages. Besides, I am pretty sure we don't have any NaN
values. Therefore I would prefer using fold
, which, if you take a close look at the max_by_key
source code, is what they have been using too.
for (i, row) in matrix.axis_iter(Axis(1)).enumerate() {
let (max_idx, max_val) =
row.iter()
.enumerate()
.fold((0, row[0]), |(idx_max, val_max), (idx, val)| {
if &val_max > val {
(idx_max, val_max)
} else {
(idx, *val)
}
});
}
这篇关于在数组中查找最大值索引的最快方法是什么?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!