如何将数组指针传递给Numba函数? [英] How to pass array pointer to Numba function?

查看:231
本文介绍了如何将数组指针传递给Numba函数?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我想创建一个Numba编译的函数,该函数将指针或数组的内存地址作为参数并对其进行计算,例如,修改基础数据.

I'd like to create a Numba-compiled function that takes a pointer or the memory address of an array as an argument and does calculations on it, e.g., modifies the underlying data.

用于说明此问题的纯Python版本如下所示:

The pure-python version to illustrate this looks like this:

import ctypes
import numba as nb
import numpy as np

arr = np.arange(5).astype(np.double)  # create arbitrary numpy array


def modify_data(addr):
    """ a function taking the memory address of an array to modify it """
    ptr = ctypes.c_void_p(addr)
    data = nb.carray(ptr, arr.shape, dtype=arr.dtype)
    data += 2

addr = arr.ctypes.data
modify_data(addr)
arr
# >>> array([2., 3., 4., 5., 6.])

如您在示例中所见,数组arr被修改而没有将其显式传递给函数.在我的用例中,数组的形状和dtype是已知的,并且将始终保持不变,这将简化界面.

As you can see in the example, the array arr got modified without passing it to the function explicitly. In my use case, the shape and dtype of the array are known and will remain unchanged at all times, which should simplify the interface.

我现在尝试编译modify_data函数,但是失败了.我的第一次尝试是使用

I now tried to compile the modify_data function, but failed. My first attempt was to use

shape = arr.shape
dtype = arr.dtype

@nb.njit
def modify_data_nb(ptr):
    data = nb.carray(ptr, shape, dtype=dtype)
    data += 2


ptr = ctypes.c_void_p(addr)
modify_data_nb(ptr)   # <<< error

此操作失败,并显示cannot determine Numba type of <class 'ctypes.c_void_p'>,即它不知道如何解释指针.

This failed with cannot determine Numba type of <class 'ctypes.c_void_p'>, i.e., it does not know how to interpret the pointer.

我尝试放置显式类型,

arr_ptr_type = nb.types.CPointer(nb.float64)
shape = arr.shape

@nb.njit(nb.types.void(arr_ptr_type))
def modify_data_nb(ptr):
    """ a function taking the memory address of an array to modify it """
    data = nb.carray(ptr, shape)
    data += 2

但这没有帮助.它没有引发任何错误,但是我不知道如何调用函数modify_data_nb.我尝试了以下选项

but this did not help. It did not throw any errors, but I do not know how to call the function modify_data_nb. I tried the following options

modify_data_nb(arr.ctypes.data)
# TypeError: No matching definition for argument type(s) int64

ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
modify_data_nb(ptr)
# TypeError: No matching definition for argument type(s) pyobject

ptr = ctypes.c_void_p(arr.ctypes.data)
modify_data_nb(ptr)
# TypeError: No matching definition for argument type(s) pyobject

是否可以从arr获取正确的指针格式,以便将其传递给Numba编译的modify_data_nb函数?另外,还有另一种将内存位置传递给功能的方法.

Is there a way to obtain the correct pointer format from arr so I can pass it to the Numba-compiled modify_data_nb function? Alternatively, is there another way of passing the memory location to function.

我通过使用scipy.LowLevelCallable及其神奇之处取得了一些进步:

I made some progress by using scipy.LowLevelCallable and its magic:

arr = np.arange(3).astype(np.double)
print(arr)
# >>> array([0., 1., 2.])

# create the function taking a pointer
shape = arr.shape
dtype = arr.dtype

@nb.cfunc(nb.types.void(nb.types.CPointer(nb.types.double)))
def modify_data(ptr):
    data = nb.carray(ptr, shape, dtype=dtype)
    data += 2

modify_data_llc = LowLevelCallable(modify_data.ctypes).function    

# create pointer to array
ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))

# call the function only with the pointer
modify_data_llc(ptr)

# check whether array got modified
print(arr)
# >>> array([2., 3., 4.])

我现在可以调用一个函数来访问数组,但是该函数不再是Numba函数.特别是,它不能在其他Numba函数中使用.

I can now call a function to access the array, but this function is no longer a Numba function. In particular, it cannot be used in other Numba functions.

推荐答案

感谢伟大的@stuartarchibald,我现在有了一个可行的解决方案:

Thanks to the great @stuartarchibald, I now have a working solution:

import ctypes
import numba as nb
import numpy as np

arr = np.arange(5).astype(np.double)  # create arbitrary numpy array
print(arr)

@nb.extending.intrinsic
def address_as_void_pointer(typingctx, src):
    """ returns a void pointer from a given memory address """
    from numba.core import types, cgutils
    sig = types.voidptr(src)

    def codegen(cgctx, builder, sig, args):
        return builder.inttoptr(args[0], cgutils.voidptr_t)
    return sig, codegen

addr = arr.ctypes.data

@nb.njit
def modify_data():
    """ a function taking the memory address of an array to modify it """
    data = nb.carray(address_as_void_pointer(addr), arr.shape, dtype=arr.dtype)
    data += 2

modify_data()
print(arr)

键是新的address_as_void_pointer函数,它将内存地址(以int形式提供)转换为可由numba.carray使用的指针.

The key is the new address_as_void_pointer function that turns a memory address (given as an int) into a pointer that is usable by numba.carray.

这篇关于如何将数组指针传递给Numba函数?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆