Speed up a function that takes a function as argument with numba
Problem description
I am trying to use numba to speed up a function that takes another function as an argument. A minimal example would be the following:
import numba as nb

def f(x):
    return x*x

@nb.jit(nopython=True)
def call_func(func,x):
    return func(x)

if __name__ == '__main__':
    print(call_func(f,5))
This, however, doesn't work, as numba apparently doesn't know what to do with that function argument. The traceback is quite long:
Traceback (most recent call last):
  File "numba_function.py", line 15, in <module>
    print(call_func(f,5))
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 330, in _compile_for_args
    raise e
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 307, in _compile_for_args
    return self.compile(tuple(argtypes))
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 579, in compile
    cres = self._compiler.compile(args, return_type)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 80, in compile
    flags=flags, locals=self.locals)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 740, in compile_extra
    return pipeline.compile_extra(func)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 360, in compile_extra
    return self._compile_bytecode()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 699, in _compile_bytecode
    return self._compile_core()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 686, in _compile_core
    res = pm.run(self.status)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 246, in run
    raise patched_exception
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 238, in run
    stage()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 452, in stage_nopython_frontend
    self.locals)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 841, in type_inference_stage
    infer.propagate()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 773, in propagate
    raise errors[0]
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 129, in propagate
    constraint(typeinfer)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 380, in __call__
    self.resolve(typeinfer, typevars, fnty)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 402, in resolve
    raise TypingError(msg, loc=self.loc)
numba.errors.TypingError: Failed at nopython (nopython frontend)
Invalid usage of pyobject with parameters (int64)
No type info available for pyobject as a callable.
File "numba_function.py", line 10
[1] During: resolving callee type: pyobject
[2] During: typing of call at numba_function.py (10)
This error may have been caused by the following argument(s):
- argument 0: cannot determine Numba type of <class 'function'>
Is there a way to fix this?
Recommended answer
It depends on whether the func you pass to call_func can be compiled in nopython mode.
If it can't be compiled in nopython mode, then it's impossible, because numba doesn't support calling plain Python functions inside a nopython function (that's the reason why it's called nopython).
However, if it can be compiled in nopython mode, you can use a closure:
import numba as nb

def f(x):
    return x*x

def call_func(func, x):
    func = nb.njit(func)   # compile func in nopython mode!
    @nb.njit
    def inner(x):
        return func(x)
    return inner(x)

if __name__ == '__main__':
    print(call_func(f,5))
That approach has an obvious downside: it needs to compile func and inner every time you call call_func. That means it's only viable if the speedup from compiling the function is bigger than the compilation cost. You can mitigate that overhead if you call call_func with the same function several times:
import numba as nb

def f(x):
    return x*x

def call_func(func):  # only take func
    func = nb.njit(func)   # compile func in nopython mode!
    @nb.njit
    def inner(x):
        return func(x)
    return inner  # return the closure

if __name__ == '__main__':
    call_func_with_f = call_func(f)   # compile once
    print(call_func_with_f(5))        # call the compiled version
    print(call_func_with_f(5))        # call the compiled version
    print(call_func_with_f(5))        # call the compiled version
    print(call_func_with_f(5))        # call the compiled version
    print(call_func_with_f(5))        # call the compiled version
Just a general note: I wouldn't create numba functions that take a function argument. If you can't hardcode the function, numba can't produce really fast functions, and if you also include the compilation cost for closures, it's mostly just not worth it.