Speed up a function that takes a function as an argument with numba

Problem description

I am trying to use numba to speed up a function that takes another function as an argument. A minimal example would be the following:

import numba as nb

def f(x):
    return x*x

@nb.jit(nopython=True)
def call_func(func,x):
    return func(x)

if __name__ == '__main__':
    print(call_func(f,5))

This, however, doesn't work, as apparently numba doesn't know what to do with that function argument. The traceback is quite long:

Traceback (most recent call last):
  File "numba_function.py", line 15, in <module>
    print(call_func(f,5))
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 330, in _compile_for_args
    raise e
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 307, in _compile_for_args
    return self.compile(tuple(argtypes))
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 579, in compile
    cres = self._compiler.compile(args, return_type)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/dispatcher.py", line 80, in compile
    flags=flags, locals=self.locals)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 740, in compile_extra
    return pipeline.compile_extra(func)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 360, in compile_extra
    return self._compile_bytecode()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 699, in _compile_bytecode
    return self._compile_core()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 686, in _compile_core
    res = pm.run(self.status)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 246, in run
    raise patched_exception
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 238, in run
    stage()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 452, in stage_nopython_frontend
    self.locals)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/compiler.py", line 841, in type_inference_stage
    infer.propagate()
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 773, in propagate
    raise errors[0]
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 129, in propagate
    constraint(typeinfer)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 380, in __call__
    self.resolve(typeinfer, typevars, fnty)
  File "/opt/local/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/numba/typeinfer.py", line 402, in resolve
    raise TypingError(msg, loc=self.loc)
numba.errors.TypingError: Failed at nopython (nopython frontend)
Invalid usage of pyobject with parameters (int64)
No type info available for pyobject as a callable.
File "numba_function.py", line 10
[1] During: resolving callee type: pyobject
[2] During: typing of call at numba_function.py (10)

This error may have been caused by the following argument(s):
- argument 0: cannot determine Numba type of <class 'function'>

Is there a way to fix this?

Recommended answer

It depends on whether the func you pass to call_func can be compiled in nopython mode.

If it can't be compiled in nopython mode, then it's impossible, because numba doesn't support calling Python functions inside a nopython function (that's why it's called nopython).

However, if it can be compiled in nopython mode, you can use a closure:

import numba as nb

def f(x):
    return x*x

def call_func(func, x):
    func = nb.njit(func)   # compile func in nopython mode!
    @nb.njit
    def inner(x):
        return func(x)
    return inner(x)

if __name__ == '__main__':
    print(call_func(f,5))

That approach has some obvious downsides, because it needs to compile func and inner every time you call call_func. That means it's only viable if the speedup gained by compiling the function outweighs the compilation cost. You can mitigate that overhead if you call call_func with the same function several times:

import numba as nb

def f(x):
    return x*x

def call_func(func):  # only take func
    func = nb.njit(func)   # compile func in nopython mode!
    @nb.njit
    def inner(x):
        return func(x)
    return inner  # return the closure

if __name__ == '__main__':
    call_func_with_f = call_func(f)   # compile once
    print(call_func_with_f(5))        # call the compiled version
    print(call_func_with_f(5))        # call the compiled version
    print(call_func_with_f(5))        # call the compiled version
    print(call_func_with_f(5))        # call the compiled version
    print(call_func_with_f(5))        # call the compiled version

Just a general note: I wouldn't create numba functions that take a function argument. If you can't hardcode the function, numba can't produce really fast code, and once you also include the compilation cost of the closures, it's mostly just not worth it.
