Numba 数据类型错误:无法统一数组 [英] Numba data type error: Cannot unify array

查看:65
本文介绍了Numba 数据类型错误:无法统一数组的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在使用 Numba 来加速一系列功能,如下所示.如果我将函数 PosMomentSingle 中的 step_size 变量设置为浮点数(例如 step_size = 0.5),而不是整数(例如 step_size= 1.0),我收到以下错误:

I am using Numba to speed up a series of functions as shown below. if I set the step_size variable in function PosMomentSingle to a float (e.g. step_size = 0.5), instead of an integer (e.g step_size = 1.0), I get the following error:

Cannot unify array(float32, 1d, C) and array(float64, 1d, C) for 'axle_coords.2', defined at <ipython-input-182-37c789ca2187> (12)

File "<ipython-input-182-37c789ca2187>", line 12:
def nbSimpleSpanMoment(L, axles, spacings, step_size):
    <source elided>
    
    while np.min(axle_coords) < L:

我发现很难理解问题是什么,但我的猜测是 @jit (nbSimpleSpanMoment) 之后的函数存在问题,有某种数据类型不匹配.我尝试将所有变量设置为 float32,然后设置为 float64(例如 L = np.float32(L)),但无论我尝试什么都会产生一组新的错误.由于错误消息非常神秘,我无法调试问题.有 numba 经验的人可以解释一下我在这里做错了什么吗?

I found it quite hard to understand what the problem is, but my guess is there is an issue with the function after @jit (nbSimpleSpanMoment), with some kind of a datatype mismatch. I tried setting all variables to float32, then to float64 (e.g. L = np.float32(L)) but whatever I try creates a new set of errors. Since the error message is quite cryptic, I am unable to debug the issue. Can someone with numba experience explain what I am doing wrong here?

我把我的代码放在下面来重现这个问题.

I placed my code below to recreate the problem.

感谢您的帮助!

import numba as nb
import numpy as np

@nb.vectorize(nopython=True)
def nbvectMoment(L,x):
    if x<L/2.0:
        return 0.5*x
    else:
        return 0.5*(L-x)

@nb.jit(nopython=True)
def nbSimpleSpanMoment(L, axles, spacings, step_size):
    travel = L + np.sum(spacings)
    maxmoment = 0
    axle_coords = -np.cumsum(spacings)
    moment_inf = np.empty_like(axles)
    while np.min(axle_coords) < L:
        axle_coords = axle_coords + step_size
        y = nbvectMoment(L,axle_coords)
        for k in range(y.shape[0]):
            if axle_coords[k] >=0 and axle_coords[k] <= L:
                moment_inf[k] = y[k]
            else:
                moment_inf[k] = 0.0   
        moment = np.sum(moment_inf * axles)
        if maxmoment < moment:
            maxmoment = moment
    return np.around(maxmoment,1)

def PosMomentSingle(current_axles, current_spacings):
    data_list = []
    for L in range (1,201):
        L=float(L)        
        if L <= 40:
            step_size = 0.5
        else:
            step_size = 0.5            
        axles = np.array(current_axles, dtype='f')
        spacings = np.array(current_spacings, dtype='f')            
        axles_inv = axles[::-1]
        spacings_inv = spacings[::-1]           
        spacings = np.insert(spacings,0,0)
        spacings_inv = np.insert(spacings_inv,0,0)            
        left_to_right = nbSimpleSpanMoment(L, axles, spacings, step_size)
        right_to_left = nbSimpleSpanMoment(L, axles_inv, spacings_inv, step_size)            
        data_list.append(max(left_to_right, right_to_left))
    return data_list

load_effects = []
for v in range(14,31):
    load_effects.append(PosMomentSingle([8, 32, 32], [14, v]))
load_effects = np.array(load_effects)

推荐答案

删除代码中的所有类型转换后,返回以下错误

After removing all type conversions in your code, the following error was returned

TypingError: Cannot unify array(int64, 1d, C) and array(float64, 1d, C) for 'axle_coords.2'

这帮助我将错误追溯到 spacingsdtype.在您的代码中,它初始化为 C 兼容 single,这似乎与 python float32 不同,请参阅 此处.将其更改为 np.float64 后,代码现在可以运行了.

This helped me to trace back the error to the dtype of spacings. In your code this initialized as a C compatible single, which seems to be different from a python float32, see here. After changing this to np.float64 the code now runs.

下面的代码现在可以运行,并且不再出现统一错误.

The code below now runs and unify error does not occur anymore.

import numba as nb
import numpy as np

@nb.vectorize(nopython=True)
def nbvectMoment(L,x):
    if x<L/2.0:
        return 0.5*x
    else:
        return 0.5*(L-x)

@nb.jit(nopython=True)
def nbSimpleSpanMoment(L, axles, spacings, step_size):
    travel = L + np.sum(spacings)
    maxmoment = 0
    axle_coords = -np.cumsum(spacings)
    moment_inf = np.empty_like(axles)
    while np.min(axle_coords) < L:
        axle_coords = axle_coords + step_size
        y = nbvectMoment(L,axle_coords)
        for k in range(y.shape[0]):
            if axle_coords[k] >=0 and axle_coords[k] <= L:
                moment_inf[k] = y[k]
            else:
                moment_inf[k] = 0.0
        moment = np.sum(moment_inf * axles)
        if maxmoment < moment:
            maxmoment = moment
    return np.around(maxmoment,1)

def PosMomentSingle(current_axles, current_spacings):
    data_list = []
    for L in range (1,201):
        L=float(L)
        if L <= 40:
            step_size = 0.5
        else:
            step_size = 0.5
        axles = np.array(current_axles, np.float32)
        spacings = np.array(current_spacings, dtype=np.float64)
        axles_inv = axles[::-1]
        spacings_inv = spacings[::-1]
        spacings = np.insert(spacings,0,0)
        spacings_inv = np.insert(spacings_inv,0,0)
        left_to_right = nbSimpleSpanMoment(L, axles, spacings, step_size)
        right_to_left = nbSimpleSpanMoment(L, axles_inv, spacings_inv, step_size)
        data_list.append(max(left_to_right, right_to_left))
    return data_list

load_effects = []
for v in range(14,31):
    load_effects.append(PosMomentSingle([8, 32, 32], [14, v]))
load_effects = np.array(load_effects)

这篇关于Numba 数据类型错误:无法统一数组的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆