Combine Numpy "where" statements

Question

I am trying to speed up code that uses Numpy's where() function. There are two calls to where(), each returning an array of the indices where its condition evaluates to True. The two index arrays are then compared for overlap with numpy's intersect1d() function, and the length of the intersection is returned.

import numpy as np

def find_match(x,y,z):

    A = np.where(x == z)
    B = np.where(y == z)
    #A = True
    #B = True
    
    return len(np.intersect1d(A,B))

N = np.power(10, 8)
M = 10

X = np.random.randint(M, size=N)
Y = np.random.randint(M, size=N)
Z = np.random.randint(M, size=N)

#print(X,Y,Z)
print(find_match(X,Y,Z))

Timing:

  • This code takes about 8 seconds on my laptop. If I replace both np.where() calls with A = True and B = True, it takes about 5 seconds. If I replace only one of the np.where() calls, it takes about 6 seconds.

Scaling up by switching to N = np.power(10, 9), the code takes 87 seconds. Replacing both np.where() statements brings that down to 51 seconds. Replacing just one of them takes about 61 seconds.

My question: how can I merge the two np.where statements so that the code runs faster?

What have I tried? This iteration of the code is already about 4x faster than the previous one, after replacing a slower lookup with for-loops. Multiprocessing will be used at a higher level in this code, so I can't also apply it here.

For the record, the actual data are strings, so doing integer math won't help.

Version info:

Python 3.9.1 (default, Jan  8 2021, 17:17:43) 
[Clang 12.0.0 (clang-1200.0.32.28)] on darwin
>>> import numpy
>>> print(numpy.__version__)
1.19.5

Recommended answer

Does this help?

def find_match2(x, y, z):
    return len(np.nonzero(np.logical_and(x == z, y == z))[0])
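
To see why the single combined mask gives the same count as intersecting two index arrays, here is a small worked example on hand-picked values. The arrays x, y, z below are illustrative only and are not taken from the question's data.

```python
import numpy as np

# Tiny hand-picked arrays (illustrative, not the question's data).
x = np.array([1, 2, 3, 4, 5])
y = np.array([1, 9, 3, 9, 5])
z = np.array([1, 2, 3, 9, 0])

# Original approach: build two index arrays, then intersect them.
A = np.where(x == z)[0]          # indices where x matches z
B = np.where(y == z)[0]          # indices where y matches z
common = np.intersect1d(A, B)    # indices present in both A and B

# Combined approach: AND the two boolean masks, then extract indices once.
mask = np.logical_and(x == z, y == z)
combined = np.nonzero(mask)[0]

print(common, combined)
```

An index appears in both A and B exactly when both comparisons are True at that position, which is what logical_and computes elementwise. The combined version therefore skips the second index array and the sorting that intersect1d performs internally.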

Sample run:

In [227]: print(find_match(X,Y,Z))
1000896

In [228]: print(find_match2(X,Y,Z))
1000896

In [229]: %timeit find_match(X,Y,Z)
2.37 s ± 70.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [230]: %timeit find_match2(X,Y,Z)
272 ms ± 9.64 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

I've added np.random.seed(210) before creating the arrays for the sake of reproducibility.
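
Since only the number of matches is needed, a possible variation is to count the True entries of the mask directly instead of materializing the index array. The sketch below is an assumption on my part rather than something timed in this answer, and count_matches is a hypothetical name. It also uses small string arrays, since the question mentions that the real data are strings; elementwise == works on same-shape string arrays as well.

```python
import numpy as np

def count_matches(x, y, z):
    # Hypothetical helper: count positions where both x and y equal z.
    # np.count_nonzero counts True entries without building an index array.
    return int(np.count_nonzero((x == z) & (y == z)))

# Small string arrays for illustration (made-up values).
xs = np.array(["a", "b", "c", "d"])
ys = np.array(["a", "x", "c", "y"])
zs = np.array(["a", "b", "c", "y"])

print(count_matches(xs, ys, zs))
```

Whether this is measurably faster than len(np.nonzero(...)[0]) on the full-size arrays would need to be checked with %timeit on the actual data.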
