Numba:使用具有默认值的参数以显式签名调用jit [英] Numba: calling jit with explicit signature using arguments with default values

查看:151
本文介绍了Numba:使用具有默认值的参数以显式签名调用jit的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在使用numba制作一些包含numpy数组上的循环的函数.

I'm using numba to make some functions containing cycles on numpy arrays.

一切都很好,我可以使用jit,我学会了如何定义签名.

Everything is fine and dandy, I can use jit and I learned how to define the signature.

现在,我尝试在具有可选参数的函数上使用jit,例如:

Now I tried using jit on a function with optional arguments, e.g.:

from numba import jit
import numpy as np

@jit(['float64(float64, float64)', 'float64(float64, optional(float))'])
def fun(a, b=3):
    return a + b

这有效,但是如果我使用的是optional(float64)而不是optional(float)(与intint64相同).我花了1个小时试图弄清楚这个语法(实际上,我的一个朋友偶然发现了这个解决方案,因为他忘记了在浮点数后写64),但是,出于对我的爱,我不明白为什么会这样所以.我在互联网上找不到任何内容,并且有关该主题的numba的文档充其量是稀缺的(并且它们指定optional应该采用numba类型).

This works, but if instead of optional(float) I use optional(float64) it doesn't (same thing with int or int64). I lost 1 hour trying to figure this syntax out (actually, a friend of mine found this solution by chance because he forgot to write the 64 after the float), but, for the love of me, I cannot understand why this is so. I can't find anything on the internet and numba's docs on the topic are scarce at best (and they specify that optional should take a numba type).

有人知道这是怎么回事吗?我想念什么?

Does anyone know how this works? What am I missing?

推荐答案

啊,但是异常消息应该给出提示:

Ah, but the exception message should give a hint:

from numba import jit
import numpy as np

@jit(['float64(float64, float64)', 'float64(float64, optional(float64))'])
def fun(a, b=3.):
    return a + b

>>> fun(10.)
TypeError: No matching definition for argument type(s) float64, omitted(default=3.0)

这意味着optional在这里是错误的选择.实际上, optional表示None或该类型".但是您需要一个可选参数,而不是可能是floatNone的参数,例如:

That means optional is the wrong choice here. In fact optional represents None or "that type". But you want an optional argument, not an argument that could be a float and None, e.g.:

>>> fun(10, None)  # doesn't fail because of the signature!
TypeError: unsupported operand type(s) for +: 'int' and 'NoneType'

我怀疑它只是为optional(float)工作的机会",因为从numbas的角度来看,float仅仅是一个任意Python对象",因此使用optional(float),您可以传递任何内容在那里(这显然没有给出论点).对于optional(float64),它只能是Nonefloat64.该类别的范围不足以允许不提供参数.

I suspect that it just "happens" to work for optional(float) because float is just an "arbitary Python object" from numbas point of view, so with optional(float) you could pass anything in there (this apparently includs not giving the argument). With optional(float64) it could only be None or a float64. That category isn't broad enough to allow not providing the argument.

如果输入类型Omitted,则可以使用:

It works if you give the type Omitted:

from numba import jit
import numpy as np

@jit(['float64(float64, float64)', 'float64(float64, Omitted(float64))'])
def fun(a, b=3.):
    return a + b

>>> fun(10.)
13.0

但是,看来Omitted实际上并未包含在文档中,并且它具有一些粗糙的边缘".例如,即使没有签名似乎也可能无法在具有该签名的nopython模式下进行编译:

However it seems like Omitted isn't actually included in the documentation and that it has some "rough edges". For example it can't be compiled in nopython mode with that signature, even though it seems possible without signature:

@njit(['float64(float64, float64)', 'float64(float64, Omitted(float64))'])
def fun(a, b=3):
    return a + b

TypingError: Failed at nopython (nopython frontend)
Invalid usage of + with parameters (float64, class(float64))

-----------

@njit(['float64(float64, float64)', 'float64(float64, Omitted(3.))'])
def fun(a, b=3):
    return a + b

>>> fun(10.)
TypeError: No matching definition for argument type(s) float64, omitted(default=3)

-----------

@njit
def fun(a, b=3):
    return a + b

>>> fun(10.)
13.0

这篇关于Numba:使用具有默认值的参数以显式签名调用jit的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆