Combine Numpy "where" statements
Problem description
I am trying to speed up code that uses NumPy's where() function. There are two calls to where(), each returning an array of the indices where its condition evaluates to True. The two index arrays are then compared for overlap with NumPy's intersect1d() function, and the length of the intersection is returned.
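import numpy as np

def find_match(x, y, z):
    # Indices where x matches z, and indices where y matches z
    A = np.where(x == z)
    B = np.where(y == z)
    # Timing experiment described below: swap in these lines to skip np.where()
    #A = True
    #B = True
    # Number of indices that appear in both A and B
    return len(np.intersect1d(A, B))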
N = np.power(10, 8)
M = 10
X = np.random.randint(M, size=N)
Y = np.random.randint(M, size=N)
Z = np.random.randint(M, size=N)
#print(X,Y,Z)
print(find_match(X,Y,Z))
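To make concrete what find_match computes, here is a small hand-checkable sketch; the arrays are hypothetical and chosen only for illustration. Note that np.where() called with just a condition returns a tuple of index arrays (one per dimension), and intersect1d flattens its inputs, so passing those tuples directly still works.

```python
import numpy as np

# Hypothetical tiny inputs, small enough to verify by hand
x = np.array([1, 2, 3, 4, 5])
y = np.array([1, 0, 3, 0, 5])
z = np.array([1, 2, 3, 0, 0])

A = np.where(x == z)           # tuple holding one index array: positions 0, 1, 2
B = np.where(y == z)           # positions 0, 2, 3
common = np.intersect1d(A, B)  # sorted indices present in both: 0 and 2
print(A[0], B[0], common, len(common))  # [0 1 2] [0 2 3] [0 2] 2
```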
Timing:
This code takes about 8 seconds on my laptop. If I replace both np.where() calls with A = True and B = True, it takes about 5 seconds. If I replace only one of the np.where() calls, it takes about 6 seconds.
Scaling up to N = np.power(10, 9), the code takes 87 seconds. Replacing both np.where() statements brings this down to 51 seconds, and replacing just one of them takes about 61 seconds.
My question: how can I merge the two np.where statements to speed up the code?
What I've tried: this iteration of the code is already about 4x faster than before, thanks to replacing a slower lookup with for-loops. Multiprocessing will be used at a higher level in this code, so I can't also apply it here.
For the record, the actual data are strings, so doing integer math won't help.
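Since the real data are strings, it may be worth checking that merging the two conditions does not depend on integer data: == on NumPy string arrays compares elementwise, so the two boolean masks can be combined with & and counted directly. A minimal sketch with hypothetical arrays; find_match_mask is an illustrative name, not part of the original code.

```python
import numpy as np

def find_match_mask(x, y, z):
    # Hypothetical helper: elementwise == works on NumPy string arrays,
    # so both conditions can be ANDed into one mask and counted
    return int(np.count_nonzero((x == z) & (y == z)))

# Hypothetical string data standing in for the real arrays
x = np.array(["a", "b", "c", "d"])
y = np.array(["a", "x", "c", "d"])
z = np.array(["a", "b", "c", "q"])
print(find_match_mask(x, y, z))  # 2
```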
Version info:
Python 3.9.1 (default, Jan 8 2021, 17:17:43)
[Clang 12.0.0 (clang-1200.0.32.28)] on darwin
>>> import numpy
>>> print(numpy.__version__)
1.19.5
Recommended answer
Does this help?
def find_match2(x, y, z):
    # Combine both conditions into one boolean mask, then count its True positions
    return len(np.nonzero(np.logical_and(x == z, y == z))[0])
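Why a single mask is enough: the two index sets only matter through their intersection, and the intersection is exactly the set of positions where both comparisons hold. Writing A and B for the index sets produced by the two np.where() calls, and [P] for 1 when P holds and 0 otherwise:

```latex
A = \{\, i : x_i = z_i \,\}, \qquad B = \{\, i : y_i = z_i \,\}

A \cap B = \{\, i : x_i = z_i \ \wedge\ y_i = z_i \,\}
\quad\Longrightarrow\quad
|A \cap B| = \sum_{i=0}^{N-1} [x_i = z_i]\,[y_i = z_i]
```

The product of the two brackets is the elementwise AND of the two boolean masks, so one pass over the data replaces two np.where() calls plus the sort-based intersect1d.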
Sample run:
In [227]: print(find_match(X,Y,Z))
1000896
In [228]: print(find_match2(X,Y,Z))
1000896
In [229]: %timeit find_match(X,Y,Z)
2.37 s ± 70.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [230]: %timeit find_match2(X,Y,Z)
272 ms ± 9.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
I've added np.random.seed(210)
before creating the arrays for the sake of reproducibility.
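For anyone wanting to reproduce the comparison at a smaller scale, here is a self-contained timing sketch. It reuses the question's find_match and the answer's find_match2, and adds a third variant, find_match3 (an assumption, not part of the answer), that calls np.count_nonzero on the combined mask instead of building the index array first. N is reduced to 10**6 so it runs quickly; timings will vary by machine, and no particular numbers are claimed here.

```python
import timeit
import numpy as np

def find_match(x, y, z):
    # Question's approach: two index arrays, then their sorted intersection
    A = np.where(x == z)
    B = np.where(y == z)
    return len(np.intersect1d(A, B))

def find_match2(x, y, z):
    # Answer's approach: one combined mask, then the indices of its True entries
    return len(np.nonzero(np.logical_and(x == z, y == z))[0])

def find_match3(x, y, z):
    # Hypothetical variant: count True entries without materializing indices
    return int(np.count_nonzero((x == z) & (y == z)))

np.random.seed(210)
N = 10**6  # much smaller than the question's 10**8, so the sketch runs quickly
M = 10
X = np.random.randint(M, size=N)
Y = np.random.randint(M, size=N)
Z = np.random.randint(M, size=N)

funcs = (find_match, find_match2, find_match3)
results = {f.__name__: f(X, Y, Z) for f in funcs}
assert len(set(results.values())) == 1, results  # all three must agree

for f in funcs:
    # Average wall-clock time per call over 5 runs
    t = timeit.timeit(lambda: f(X, Y, Z), number=5) / 5
    print(f"{f.__name__}: {t * 1000:.1f} ms per call")
```

Checking that the three results agree before timing guards against comparing a fast function that computes the wrong count.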