使用 numba 索引 numpy 数组时出现类型错误 [英] TypeError when indexing numpy array using numba
问题描述
我需要对基于另一个数组的一维 numpy
数组(下面:data
)中的元素求和,其中包含有关类成员资格(labels
>).我在下面的代码中使用 numba
来加快速度.但是,如果我没有在 ret[int(find(labels, g))] += y
行中使用 int()
进行显式转换,我会收到一条错误消息:
I need to sum up elements in a 1D numpy
array (below: data
) based on another array with information on class memberships (labels
). I use numba
in the code below to speed it up. However, If I dot not explicitly cast with int()
in the line ret[int(find(labels, g))] += y
, I reveice an error message:
TypeError: 不支持的数组索引类型?int64
有没有比显式转换更好的解决方法?
Is there a better workaround that explicit casting?
import numpy as np
from numba import jit
labels = np.array([45, 85, 99, 89, 45, 86, 348, 764])
n = int(1e3)
data = np.random.random(n)
groups = np.random.choice(a=labels, size=n, replace=True)
@jit(nopython=True)
def find(seq, value):
for ct, x in enumerate(seq):
if x == value:
return ct
@jit(nopython=True)
def subsumNumba(data, groups, labels):
ret = np.zeros(len(labels))
for y, g in zip(data, groups):
# not working without casting with int()
ret[int(find(labels, g))] += y
return ret
推荐答案
问题在于 find
可以返回 int
或 None
如果它没有找到任何东西,因此我认为 ?int64
错误.为了避免强制转换,您需要在 find
退出时没有找到所需值时提供一个 int
返回值,然后在调用者中进行处理.
The problem is that find
can either return an int
or None
if it doesn't find anything, thus I think the ?int64
error. To avoid casting, you need to provide an int
return value when find
exits without finding the desired value and then handle it in the caller.
这篇关于使用 numba 索引 numpy 数组时出现类型错误的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!