Efficiently return the index of the first value satisfying condition in array

Question

I need to find the index of the first value in a 1d NumPy array, or Pandas numeric series, satisfying a condition. The array is large and the index may be near the start or end of the array, or the condition may not be met at all. I can't tell in advance which is more likely. If the condition is not met, the return value should be -1. I've considered a few approaches.

# func(arr) returns a Boolean array
idx = next(iter(np.where(func(arr))[0]), -1)

But this is often too slow as func(arr) applies a vectorised function on the entire array rather than stopping when the condition is met. Specifically, it is expensive when the condition is met near the start of the array.

np.argmax is marginally faster, but fails to identify when the condition is never met:

np.random.seed(0)
arr = np.random.rand(10**7)

assert next(iter(np.where(arr > 0.999999)[0]), -1) == np.argmax(arr > 0.999999)

%timeit next(iter(np.where(arr > 0.999999)[0]), -1)  # 21.2 ms
%timeit np.argmax(arr > 0.999999)                    # 17.7 ms

np.argmax(arr > 1.0) returns 0, which looks like a match at index 0 even though no element satisfies the condition.
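One way to work around the ambiguity is to check the element that np.argmax points at before trusting the result. The sketch below is illustrative only: first_true_index is a hypothetical helper name, not part of NumPy, and it still evaluates the whole Boolean mask, so it does not address the speed concern.

```python
import numpy as np

def first_true_index(mask):
    """Return the index of the first True in a 1d Boolean array, or -1."""
    # np.argmax raises on empty input, so handle that case explicitly.
    if mask.size == 0:
        return -1
    idx = int(np.argmax(mask))
    # argmax returns 0 for an all-False mask; verify the hit is real.
    return idx if mask[idx] else -1

arr = np.array([0.2, 0.7, 0.4])
found = first_true_index(arr > 0.5)    # condition met at index 1
missing = first_true_index(arr > 1.0)  # condition never met
```

Here found is 1 and missing is -1, which is the return convention asked for above.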

# func(arr) returns a Boolean scalar
idx = next((idx for idx, val in enumerate(arr) if func(val)), -1)

But this is too slow when the condition is met near the end of the array. Presumably this is because the generator expression has an expensive overhead from a large number of __next__ calls.

Is this always a compromise or is there a way, for generic func, to extract the first index efficiently?

For benchmarking, assume func finds the index when a value is greater than a given constant:

# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0
import numpy as np

np.random.seed(0)
arr = np.random.rand(10**7)
m = 0.9
n = 0.999999

# Start of array benchmark
%timeit next(iter(np.where(arr > m)[0]), -1)                       # 43.5 ms
%timeit next((idx for idx, val in enumerate(arr) if val > m), -1)  # 2.5 µs

# End of array benchmark
%timeit next(iter(np.where(arr > n)[0]), -1)                       # 21.4 ms
%timeit next((idx for idx, val in enumerate(arr) if val > n), -1)  # 39.2 ms

Recommended answer

numba

With numba it's possible to optimise both scenarios. Syntactically, you need only construct a function with a simple for loop:

from numba import njit

@njit
def get_first_index_nb(A, k):
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1

idx = get_first_index_nb(A, 0.9)

Numba improves performance by JIT ("Just In Time") compiling code and leveraging CPU-level optimisations. A regular for loop without the @njit decorator would typically be slower than the methods you've already tried for the case where the condition is met late.

For a Pandas numeric series df['data'], you can simply feed the NumPy representation to the JIT-compiled function:

idx = get_first_index_nb(df['data'].values, 0.9)
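Note that the value returned this way is a position in the underlying array, not an index label of the series. A minimal sketch of mapping it back to a label follows; the sample DataFrame and the plain-Python fallback for njit are assumptions added so the snippet runs even where Numba is not installed.

```python
import numpy as np
import pandas as pd

try:
    from numba import njit
except ImportError:
    # Assumption: without Numba, fall back to the undecorated (slow) function.
    def njit(f):
        return f

@njit
def get_first_index_nb(A, k):
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1

# Hypothetical sample data with a non-integer index.
df = pd.DataFrame({'data': [0.2, 0.5, 0.95, 0.7]}, index=list('abcd'))

pos = get_first_index_nb(df['data'].values, 0.9)  # positional index
label = df.index[pos] if pos != -1 else None      # corresponding index label
```

With this sample data pos is 2 and label is 'c'. The explicit check for -1 matters because df.index[-1] would otherwise silently return the last label.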

Generalisation

Since numba permits functions as arguments, and assuming the passed function can also be JIT-compiled, you can arrive at a method to calculate the nth index where a condition is met for an arbitrary func.

@njit
def get_nth_index_count(A, func, count):
    c = 0
    for i in range(len(A)):
        if func(A[i]):
            c += 1
            if c == count:
                return i
    return -1

@njit
def func(val):
    return val > 0.9

# get index of 3rd value where func evaluates to True
idx = get_nth_index_count(arr, func, 3)

For the 3rd last value, you can feed the reversed array, arr[::-1], and subtract the result from len(arr) - 1; the - 1 accounts for 0-indexing. A result of -1 from the reversed search should be passed through unchanged rather than subtracted.
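The index arithmetic for the reversed search can be sketched as follows. To keep the snippet runnable without Numba, get_nth_index_count is written here as a plain-Python stand-in with the same logic as the JIT-compiled version; get_nth_last_index is a hypothetical wrapper name used only for illustration.

```python
import numpy as np

def get_nth_index_count(A, func, count):
    # Plain-Python stand-in for the @njit version; same logic.
    c = 0
    for i in range(len(A)):
        if func(A[i]):
            c += 1
            if c == count:
                return i
    return -1

def get_nth_last_index(A, func, count):
    # Search the reversed view, then map the position back to the original.
    rev_idx = get_nth_index_count(A[::-1], func, count)
    if rev_idx == -1:
        return -1  # fewer than `count` matches: keep the sentinel as-is
    return len(A) - 1 - rev_idx

arr = np.array([0.95, 0.1, 0.92, 0.3, 0.99, 0.97])
idx = get_nth_last_index(arr, lambda v: v > 0.9, 3)
```

The values above 0.9 sit at positions 0, 2, 4 and 5, so the 3rd last match is at position 2, which is what idx holds.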

# Python 3.6.5, NumPy 1.14.3, Numba 0.38.0

np.random.seed(0)
arr = np.random.rand(10**7)
m = 0.9
n = 0.999999

@njit
def get_first_index_nb(A, k):
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1

def get_first_index_np(A, k):
    for i in range(len(A)):
        if A[i] > k:
            return i
    return -1

%timeit get_first_index_nb(arr, m)                                 # 375 ns
%timeit get_first_index_np(arr, m)                                 # 2.71 µs
%timeit next(iter(np.where(arr > m)[0]), -1)                       # 43.5 ms
%timeit next((idx for idx, val in enumerate(arr) if val > m), -1)  # 2.5 µs

%timeit get_first_index_nb(arr, n)                                 # 204 µs
%timeit get_first_index_np(arr, n)                                 # 44.8 ms
%timeit next(iter(np.where(arr > n)[0]), -1)                       # 21.4 ms
%timeit next((idx for idx, val in enumerate(arr) if val > n), -1)  # 39.2 ms
