Numba code slower than pure python
Problem description
I've been working on speeding up a resampling calculation for a particle filter. As python has many ways to speed it up, I thought I'd try them all. Unfortunately, the numba version is incredibly slow. As Numba should result in a speed up, I assume this is an error on my part.
I tried 4 different versions:
- Numba
- Python
- Numpy
- Cython
The code for each is below:
```python
import numpy as np
import numba as nb
from cython_resample import cython_resample

@nb.autojit
def numba_resample(qs, xs, rands):
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)
    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results

def python_resample(qs, xs, rands):
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)
    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results

def numpy_resample(qs, xs, rands):
    results = np.empty_like(qs)
    lookup = np.cumsum(qs)
    for j, key in enumerate(rands):
        # index of the first cumulative weight greater than the draw
        i = np.argmax(lookup > key)
        results[j] = xs[i]
    return results

# The following is the code for the cython module. It was compiled in a
# separate file, but is included here to aid in the question.
"""
import numpy as np
cimport numpy as np
cimport cython

DTYPE = np.float64
ctypedef np.float64_t DTYPE_t

@cython.boundscheck(False)
def cython_resample(np.ndarray[DTYPE_t, ndim=1] qs,
                    np.ndarray[DTYPE_t, ndim=1] xs,
                    np.ndarray[DTYPE_t, ndim=1] rands):
    if qs.shape[0] != xs.shape[0] or qs.shape[0] != rands.shape[0]:
        raise ValueError("Arrays must have same shape")
    assert qs.dtype == xs.dtype == rands.dtype == DTYPE

    cdef unsigned int n = qs.shape[0]
    cdef unsigned int i, j
    cdef np.ndarray[DTYPE_t, ndim=1] lookup = np.cumsum(qs)
    cdef np.ndarray[DTYPE_t, ndim=1] results = np.zeros(n, dtype=DTYPE)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results
"""

# %timeit is an IPython magic, so this block has to be run in IPython/Jupyter.
if __name__ == '__main__':
    n = 100
    xs = np.arange(n, dtype=np.float64)
    qs = np.array([1.0/n,]*n)
    rands = np.random.rand(n)

    print("Timing Numba Function:")
    %timeit numba_resample(qs, xs, rands)
    print("Timing Python Function:")
    %timeit python_resample(qs, xs, rands)
    print("Timing Numpy Function:")
    %timeit numpy_resample(qs, xs, rands)
    print("Timing Cython Function:")
    %timeit cython_resample(qs, xs, rands)
```
This results in the following output:

```
Timing Numba Function:
1 loops, best of 3: 8.23 ms per loop
Timing Python Function:
100 loops, best of 3: 2.48 ms per loop
Timing Numpy Function:
1000 loops, best of 3: 793 µs per loop
Timing Cython Function:
10000 loops, best of 3: 25 µs per loop
```
Any idea why the numba code is so slow? I assumed it would be at least comparable to Numpy.
Note: If anyone has any ideas on how to speed up the Numpy or Cython code samples, that would be nice too :) My main question is about Numba.
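On the side question about the Numpy version: a minimal sketch that removes the Python loop entirely uses `np.searchsorted` on the cumulative weights. `np.searchsorted(lookup, rands, side="right")` returns, for each draw, the first index `i` with `rands[j] < lookup[i]`, which is the condition the loops test. One assumption to flag: when a draw is at or above `lookup[-1]` (possible through round-off), the loop versions leave that entry unset (zero in the Cython version) and `numpy_resample` picks `xs[0]`, whereas this sketch clips to the last particle. The function name `searchsorted_resample` is hypothetical.

```python
import numpy as np


def searchsorted_resample(qs, xs, rands):
    """Vectorised inverse-CDF resampling: one binary search per draw."""
    lookup = np.cumsum(qs)
    # side="right" gives the first i with rands[j] < lookup[i].
    idx = np.searchsorted(lookup, rands, side="right")
    # Clip draws beyond the last cumulative weight to the last particle
    # instead of indexing past the end of xs.
    idx = np.minimum(idx, len(xs) - 1)
    return xs[idx]


if __name__ == "__main__":
    qs = np.full(4, 0.25)
    xs = np.array([10.0, 20.0, 30.0, 40.0])
    rands = np.array([0.1, 0.3, 0.99, 0.5])
    print(searchsorted_resample(qs, xs, rands))  # [10. 20. 40. 30.]
```

Because `lookup` is sorted, each draw costs a binary search instead of a linear scan, so the whole call is O(n log n) rather than O(n²).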
Recommended answer
The problem is that numba can't intuit the type of `lookup`. If you put a `print(nb.typeof(lookup))` in your method, you'll see that numba is treating it as an object, which is slow. Normally I would just define the type of `lookup` in a `locals` dict, but I was getting a strange error. Instead I just created a little wrapper, so that I could explicitly define the input and output types.
```python
# Typed wrapper: tells numba that cumsum maps a 1-D float64 array to a 1-D float64 array
@nb.jit(nb.f8[:](nb.f8[:]))
def numba_cumsum(x):
    return np.cumsum(x)

@nb.autojit
def numba_resample2(qs, xs, rands):
    n = qs.shape[0]
    #lookup = np.cumsum(qs)
    lookup = numba_cumsum(qs)
    results = np.empty(n)
    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results
```
Then my timings are:

```python
print("Timing Numba Function:")
%timeit numba_resample(qs, xs, rands)
print("Timing Revised Numba Function:")
%timeit numba_resample2(qs, xs, rands)
```

```
Timing Numba Function:
100 loops, best of 3: 8.1 ms per loop
Timing Revised Numba Function:
100000 loops, best of 3: 15.3 µs per loop
```
You can go even a little faster still if you use `jit` instead of `autojit`:

```python
@nb.jit(nb.f8[:](nb.f8[:], nb.f8[:], nb.f8[:]))
```
For me that lowers it from 15.3 microseconds to 12.5 microseconds, but it's still impressive how well autojit does.