Numba:使用具有默认值的参数以显式签名调用jit [英] Numba: calling jit with explicit signature using arguments with default values
问题描述
我正在使用numba制作一些包含numpy数组上的循环的函数.
I'm using numba to make some functions containing cycles on numpy arrays.
一切都很好,我可以使用jit
,我学会了如何定义签名.
Everything is fine and dandy, I can use jit
and I learned how to define the signature.
现在,我尝试在具有可选参数的函数上使用jit,例如:
Now I tried using jit on a function with optional arguments, e.g.:
from numba import jit
import numpy as np
@jit(['float64(float64, float64)', 'float64(float64, optional(float))'])
def fun(a, b=3):
return a + b
这有效,但是如果我使用的是optional(float64)
而不是optional(float)
(与int
或int64
相同).我花了1个小时试图弄清楚这个语法(实际上,我的一个朋友偶然发现了这个解决方案,因为他忘记了在浮点数后写64
),但是,出于对我的爱,我不明白为什么会这样所以.我在互联网上找不到任何内容,并且有关该主题的numba的文档充其量是稀缺的(并且它们指定optional
应该采用numba类型).
This works, but if instead of optional(float)
I use optional(float64)
it doesn't (same thing with int
or int64
). I lost 1 hour trying to figure this syntax out (actually, a friend of mine found this solution by chance because he forgot to write the 64
after the float), but, for the love of me, I cannot understand why this is so. I can't find anything on the internet and numba's docs on the topic are scarce at best (and they specify that optional
should take a numba type).
有人知道这是怎么回事吗?我想念什么?
Does anyone know how this works? What am I missing?
推荐答案
啊,但是异常消息应该给出提示:
Ah, but the exception message should give a hint:
from numba import jit
import numpy as np
@jit(['float64(float64, float64)', 'float64(float64, optional(float64))'])
def fun(a, b=3.):
return a + b
>>> fun(10.)
TypeError: No matching definition for argument type(s) float64, omitted(default=3.0)
这意味着optional
在这里是错误的选择.实际上, optional
表示None
或该类型".但是您需要一个可选参数,而不是可能是float
和None
的参数,例如:
That means optional
is the wrong choice here. In fact optional
represents None
or "that type". But you want an optional argument, not an argument that could be a float
and None
, e.g.:
>>> fun(10, None) # doesn't fail because of the signature!
TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'
我怀疑它只是为optional(float)
工作的机会",因为从numbas的角度来看,float
仅仅是一个任意Python对象",因此使用optional(float)
,您可以传递任何内容在那里(这显然没有给出论点).对于optional(float64)
,它只能是None
或float64
.该类别的范围不足以允许不提供参数.
I suspect that it just "happens" to work for optional(float)
because float
is just an "arbitary Python object" from numbas point of view, so with optional(float)
you could pass anything in there (this apparently includs not giving the argument). With optional(float64)
it could only be None
or a float64
. That category isn't broad enough to allow not providing the argument.
如果输入类型Omitted
,则可以使用:
It works if you give the type Omitted
:
from numba import jit
import numpy as np
@jit(['float64(float64, float64)', 'float64(float64, Omitted(float64))'])
def fun(a, b=3.):
return a + b
>>> fun(10.)
13.0
但是,看来Omitted
实际上并未包含在文档中,并且它具有一些粗糙的边缘".例如,即使没有签名似乎也可能无法在具有该签名的nopython模式下进行编译:
However it seems like Omitted
isn't actually included in the documentation and that it has some "rough edges". For example it can't be compiled in nopython mode with that signature, even though it seems possible without signature:
@njit(['float64(float64, float64)', 'float64(float64, Omitted(float64))'])
def fun(a, b=3):
return a + b
TypingError: Failed at nopython (nopython frontend)
Invalid usage of + with parameters (float64, class(float64))
-----------
@njit(['float64(float64, float64)', 'float64(float64, Omitted(3.))'])
def fun(a, b=3):
return a + b
>>> fun(10.)
TypeError: No matching definition for argument type(s) float64, omitted(default=3)
-----------
@njit
def fun(a, b=3):
return a + b
>>> fun(10.)
13.0
这篇关于Numba:使用具有默认值的参数以显式签名调用jit的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!