Performance of nested loops in Numba


Question

For performance reasons, I have started using Numba alongside NumPy. My Numba algorithm works, but I have the feeling that it should be faster. There is one spot that slows it down. Here is the code snippet:

import numpy
import numba as nb

@nb.njit
def rfunc1(ws, a, l):
    gn = a**l
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                y = 0.0
                for i in range(1, l):
                    if (numpy.all(ws[x1][0:i] == ws[x2][0:i]) and
                            numpy.all(ws[x1][i:l] == ws[x3][i:l])):
                        y += 1
                    if (numpy.all(ws[x1][0:i] == ws[x2][0:i]) and
                            numpy.all(ws[x1][i:l] == ws[x3][i:l])):
                        y += 1

In my opinion, the if statement is what slows it down. Is there a better way? (What I am trying to achieve here is related to a previously posted question: Count possibilities for single crossovers.) ws is a NumPy array of shape (gn, l) containing 0s and 1s.

Recommended answer

Since the goal is to check that all items are equal, you can take advantage of the fact that a single mismatch settles the result, so the comparison can short-circuit (i.e. stop early). I modified your original function slightly so that (1) the same comparison is not repeated twice, and (2) y is summed over all the nested loops, so that there is a return value to compare:

import numpy as np
import numba as nb

@nb.njit
def rfunc1(ws, a, l):
    gn = a**l
    ysum = 0
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                y = 0.0
                for i in range(1, l):
                    if np.all(ws[x1][0:i] == ws[x2][0:i]) and np.all(ws[x1][i:l] == ws[x3][i:l]):
                        y += 1
                        ysum += 1

    return ysum


@nb.njit
def rfunc2(ws, a, l):
    gn = a**l
    ysum = 0
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                y = 0.0
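                for i in range(1, l):
                    # Compare the prefix [0, i) of rows x1 and x2,
                    # stopping at the first mismatch.
                    incr_y = True
                    for j in range(i):
                        if ws[x1, j] != ws[x2, j]:
                            incr_y = False
                            break

                    # Only if the prefix matched, compare the suffix [i, l)
                    # of rows x1 and x3.
                    if incr_y:
                        for j in range(i, l):
                            if ws[x1, j] != ws[x3, j]:
                                incr_y = False
                                break
                    if incr_y:
                        y += 1
                        ysum += 1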
    return ysum

I don't know what the complete function looks like, but hopefully this gets you started on the right path.
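One way to keep the short-circuit logic readable is to move the element-wise comparison into a small helper. The following is only a sketch under a few assumptions: the names rows_match, count_matches and ws_small are made up for illustration, gn is taken from ws.shape rather than computed as a**l, and the try/except fallback exists solely so the snippet still runs where Numba is not installed.

```python
import numpy as np

try:
    import numba as nb
    jit = nb.njit
except ImportError:
    # Assumption: fall back to plain Python so the sketch runs without Numba.
    def jit(func):
        return func


@jit
def rows_match(ws, r1, r2, start, stop):
    """Return True if rows r1 and r2 agree on columns [start, stop)."""
    for j in range(start, stop):
        if ws[r1, j] != ws[r2, j]:
            return False  # first mismatch: stop comparing
    return True


@jit
def count_matches(ws):
    gn, l = ws.shape
    ysum = 0
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                for i in range(1, l):
                    # `and` skips the suffix test when the prefix already differs
                    if rows_match(ws, x1, x2, 0, i) and rows_match(ws, x1, x3, i, l):
                        ysum += 1
    return ysum


rng = np.random.default_rng(0)
# Small input so the sketch also finishes quickly without JIT compilation.
ws_small = rng.integers(0, 2, size=(16, 4))
print(count_matches(ws_small))
```

Unlike np.all(ws[x1][0:i] == ws[x2][0:i]), the helper never builds a temporary boolean array and never looks past the first differing element.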

Now some timings:

l = 7
a = 2
gn = a**l
ws = np.random.randint(0,2,size=(gn,l))
In [23]: %timeit rfunc1(ws, a, l)
1 loop, best of 3: 2.11 s per loop

%timeit rfunc2(ws, a, l)
1 loop, best of 3: 39.9 ms per loop

In [27]: rfunc1(ws, a , l)
Out[27]: 131919

In [30]: rfunc2(ws, a , l)
Out[30]: 131919

That gives you a speed-up of roughly 50x (2.11 s versus 39.9 ms).
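If only the total count is needed, and not y for each individual (x1, x2, x3), the triple loop can in principle be avoided. For a fixed x1 and split point i, the choice of x2 depends only on the prefix and the choice of x3 only on the suffix, so the number of matching pairs is the product of two independent counts. The sketch below illustrates this in plain NumPy. It is not part of the original answer, the function names are hypothetical, and it assumes gn equals ws.shape[0].

```python
import numpy as np


def count_matches_bruteforce(ws):
    """Reference count using the same nested loops as the original code."""
    gn, l = ws.shape
    total = 0
    for x1 in range(gn):
        for x2 in range(gn):
            for x3 in range(gn):
                for i in range(1, l):
                    if (np.array_equal(ws[x1, :i], ws[x2, :i])
                            and np.array_equal(ws[x1, i:], ws[x3, i:])):
                        total += 1
    return total


def count_matches_factored(ws):
    """Same total, computed as sum over x1 and i of (#prefix matches) * (#suffix matches)."""
    gn, l = ws.shape
    # eq[x, y, j] is True where rows x and y agree in column j.
    eq = ws[:, None, :] == ws[None, :, :]
    total = 0
    for i in range(1, l):
        # For each x1: how many rows x2 share its prefix [0, i).
        prefix_matches = eq[:, :, :i].all(axis=2).sum(axis=1)
        # For each x1: how many rows x3 share its suffix [i, l).
        suffix_matches = eq[:, :, i:].all(axis=2).sum(axis=1)
        total += int((prefix_matches * suffix_matches).sum())
    return total


rng = np.random.default_rng(0)
ws_small = rng.integers(0, 2, size=(8, 3))
print(count_matches_bruteforce(ws_small) == count_matches_factored(ws_small))
```

The factored version does work proportional to gn² · l² instead of looping over all gn³ triples, at the cost of a temporary boolean array of shape (gn, gn, l). Whether it applies depends on what the complete function does with y.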
