将字典与numba njit函数一起使用 [英] Using Dictionaries with numba njit function
问题描述
当输入和返回为字典时,如何加快numba的功能?
How to speed up a funtion with numba when input and return are dictionaries?
我熟悉将numba用于接受数字并返回数组的函数,例如:
I'm familiar with using numba for functions that accept numbers and return arrays, like this:
@numba.jit('float64[:](int32,int32)',nopython=True)
def f(a, b):
# returns array 1d array
现在我有一个接受和返回字典的函数.如何在这里申请numba?
Now I have a function that accepts and returns dictionaries. How can I apply numba here?
def collocation(aeolus_data,val_data):
...
return sample_aeolus, sample_valdata
推荐答案
Numba版本43.0
中已添加了对Dictionary的支持.尽管它非常有限(不支持列表并设置为键/值).不过,您可以在此处此处阅读更新的文档 .
这是一个例子
The support for Dictionary has now been added in Numba version 43.0
. Although it quite limited (does not support list and set as key/values). You can however read the updated documentation here for more info.
Here is an example
import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict
# First create a dictionary using Dict.empty()
# Specify the data types for both key and value pairs
# Dict with key as strings and values of type float array
dict_param1 = Dict.empty(
key_type=types.unicode_type,
value_type=types.float64[:],
)
# Dict with keys as string and values of type float
dict_param2 = Dict.empty(
key_type=types.unicode_type,
value_type=types.float64,
)
# Type-expressions are currently not supported inside jit functions.
float_array = types.float64[:]
@njit
def add_values(d_param1, d_param2):
# Make a result dictionary to store results
# Dict with keys as string and values of type float array
result_dict = Dict.empty(
key_type=types.unicode_type,
value_type=float_array,
)
for key in d_param1.keys():
result_dict[key] = d_param1[key] + d_param2[key]
return result_dict
dict_param1["hello"] = np.asarray([1.5, 2.5, 3.5], dtype='f8')
dict_param1["world"] = np.asarray([10.5, 20.5, 30.5], dtype='f8')
dict_param2["hello"] = 1.5
dict_param2["world"] = 10
final_dict = add_values(dict_param1, dict_param2)
print(final_dict)
# Output : {hello: [3. 4. 5.], world: [20.5 30.5 40.5]}
参考文献:
- https://github.com/numba/numba/issues/3644
- https://numba.pydata.org/numba-doc/dev/reference/pysupported.html#dict
References:
- https://github.com/numba/numba/issues/3644
- https://numba.pydata.org/numba-doc/dev/reference/pysupported.html#dict
这篇关于将字典与numba njit函数一起使用的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!