如何将数组指针传递给Numba函数? [英] How to pass array pointer to Numba function?
问题描述
我想创建一个Numba编译的函数,该函数将指针或数组的内存地址作为参数并对其进行计算,例如,修改基础数据.
I'd like to create a Numba-compiled function that takes a pointer or the memory address of an array as an argument and does calculations on it, e.g., modifies the underlying data.
用于说明此问题的纯Python版本如下所示:
The pure-python version to illustrate this looks like this:
import ctypes
import numba as nb
import numpy as np
arr = np.arange(5).astype(np.double) # create arbitrary numpy array
def modify_data(addr):
""" a function taking the memory address of an array to modify it """
ptr = ctypes.c_void_p(addr)
data = nb.carray(ptr, arr.shape, dtype=arr.dtype)
data += 2
addr = arr.ctypes.data
modify_data(addr)
arr
# >>> array([2., 3., 4., 5., 6.])
如您在示例中所见,数组arr
被修改而没有将其显式传递给函数.在我的用例中,数组的形状和dtype是已知的,并且将始终保持不变,这将简化界面.
As you can see in the example, the array arr
got modified without passing it to the function explicitly. In my use case, the shape and dtype of the array are known and will remain unchanged at all times, which should simplify the interface.
我现在尝试编译modify_data
函数,但是失败了.我的第一次尝试是使用
I now tried to compile the modify_data
function, but failed. My first attempt was to use
shape = arr.shape
dtype = arr.dtype
@nb.njit
def modify_data_nb(ptr):
data = nb.carray(ptr, shape, dtype=dtype)
data += 2
ptr = ctypes.c_void_p(addr)
modify_data_nb(ptr) # <<< error
此操作失败,并显示cannot determine Numba type of <class 'ctypes.c_void_p'>
,即它不知道如何解释指针.
This failed with cannot determine Numba type of <class 'ctypes.c_void_p'>
, i.e., it does not know how to interpret the pointer.
我尝试放置显式类型,
arr_ptr_type = nb.types.CPointer(nb.float64)
shape = arr.shape
@nb.njit(nb.types.void(arr_ptr_type))
def modify_data_nb(ptr):
""" a function taking the memory address of an array to modify it """
data = nb.carray(ptr, shape)
data += 2
但这没有帮助.它没有引发任何错误,但是我不知道如何调用函数modify_data_nb
.我尝试了以下选项
but this did not help. It did not throw any errors, but I do not know how to call the function modify_data_nb
. I tried the following options
modify_data_nb(arr.ctypes.data)
# TypeError: No matching definition for argument type(s) int64
ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
modify_data_nb(ptr)
# TypeError: No matching definition for argument type(s) pyobject
ptr = ctypes.c_void_p(arr.ctypes.data)
modify_data_nb(ptr)
# TypeError: No matching definition for argument type(s) pyobject
是否可以从arr
获取正确的指针格式,以便将其传递给Numba编译的modify_data_nb
函数?另外,还有另一种将内存位置传递给功能的方法.
Is there a way to obtain the correct pointer format from arr
so I can pass it to the Numba-compiled modify_data_nb
function? Alternatively, is there another way of passing the memory location to function.
我通过使用scipy.LowLevelCallable
及其神奇之处取得了一些进步:
I made some progress by using scipy.LowLevelCallable
and its magic:
arr = np.arange(3).astype(np.double)
print(arr)
# >>> array([0., 1., 2.])
# create the function taking a pointer
shape = arr.shape
dtype = arr.dtype
@nb.cfunc(nb.types.void(nb.types.CPointer(nb.types.double)))
def modify_data(ptr):
data = nb.carray(ptr, shape, dtype=dtype)
data += 2
modify_data_llc = LowLevelCallable(modify_data.ctypes).function
# create pointer to array
ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
# call the function only with the pointer
modify_data_llc(ptr)
# check whether array got modified
print(arr)
# >>> array([2., 3., 4.])
我现在可以调用一个函数来访问数组,但是该函数不再是Numba函数.特别是,它不能在其他Numba函数中使用.
I can now call a function to access the array, but this function is no longer a Numba function. In particular, it cannot be used in other Numba functions.
推荐答案
感谢伟大的@stuartarchibald,我现在有了一个可行的解决方案:
Thanks to the great @stuartarchibald, I now have a working solution:
import ctypes
import numba as nb
import numpy as np
arr = np.arange(5).astype(np.double) # create arbitrary numpy array
print(arr)
@nb.extending.intrinsic
def address_as_void_pointer(typingctx, src):
""" returns a void pointer from a given memory address """
from numba.core import types, cgutils
sig = types.voidptr(src)
def codegen(cgctx, builder, sig, args):
return builder.inttoptr(args[0], cgutils.voidptr_t)
return sig, codegen
addr = arr.ctypes.data
@nb.njit
def modify_data():
""" a function taking the memory address of an array to modify it """
data = nb.carray(address_as_void_pointer(addr), arr.shape, dtype=arr.dtype)
data += 2
modify_data()
print(arr)
键是新的address_as_void_pointer
函数,它将内存地址(以int形式提供)转换为可由numba.carray
使用的指针.
The key is the new address_as_void_pointer
function that turns a memory address (given as an int) into a pointer that is usable by numba.carray
.
这篇关于如何将数组指针传递给Numba函数?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!