使用 numba 索引 numpy 数组时出现类型错误 [英] TypeError when indexing numpy array using numba

查看:186
本文介绍了使用 numba 索引 numpy 数组时出现类型错误的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我需要对基于另一个数组的一维 numpy 数组(下面:data)中的元素求和,其中包含有关类成员资格(labels>).我在下面的代码中使用 numba 来加快速度.但是,如果我没有在 ret[int(find(labels, g))] += y 行中使用 int() 进行显式转换,我会收到一条错误消息:

I need to sum up elements in a 1D numpy array (below: data) based on another array with information on class memberships (labels). I use numbain the code below to speed it up. However, If I dot not explicitly cast with int() in the line ret[int(find(labels, g))] += y, I reveice an error message:

TypeError: 不支持的数组索引类型?int64

有没有比显式转换更好的解决方法?

Is there a better workaround that explicit casting?

import numpy as np
from numba import jit

labels = np.array([45, 85, 99, 89, 45, 86, 348, 764])
n = int(1e3)
data = np.random.random(n)
groups = np.random.choice(a=labels, size=n, replace=True)

@jit(nopython=True)
def find(seq, value):
    for ct, x in enumerate(seq):
        if x == value:
            return ct

@jit(nopython=True)
def subsumNumba(data, groups, labels):
    ret = np.zeros(len(labels))
    for y, g in zip(data, groups):
        # not working without casting with int()
        ret[int(find(labels, g))] += y
    return ret

推荐答案

问题在于 find 可以返回 intNone如果它没有找到任何东西,因此我认为 ?int64 错误.为了避免强制转换,您需要在 find 退出时没有找到所需值时提供一个 int 返回值,然后在调用者中进行处理.

The problem is that find can either return an int or None if it doesn't find anything, thus I think the ?int64 error. To avoid casting, you need to provide an int return value when find exits without finding the desired value and then handle it in the caller.

这篇关于使用 numba 索引 numpy 数组时出现类型错误的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆