将字典与numba njit函数一起使用 [英] Using Dictionaries with numba njit function

查看:881
本文介绍了将字典与numba njit函数一起使用的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

当输入和返回为字典时,如何加快numba的功能?

How to speed up a funtion with numba when input and return are dictionaries?

我熟悉将numba用于接受数字并返回数组的函数,例如:

I'm familiar with using numba for functions that accept numbers and return arrays, like this:

@numba.jit('float64[:](int32,int32)',nopython=True)
def f(a, b):
    # returns array 1d array

现在我有一个接受和返回字典的函数.如何在这里申请numba?

Now I have a function that accepts and returns dictionaries. How can I apply numba here?

    def collocation(aeolus_data,val_data):

      ...

      return sample_aeolus, sample_valdata

推荐答案

Numba版本43.0中已添加了对Dictionary的支持.尽管它非常有限(不支持列表并设置为键/值).不过,您可以在此处此处阅读更新的文档 . 这是一个例子

The support for Dictionary has now been added in Numba version 43.0. Although it quite limited (does not support list and set as key/values). You can however read the updated documentation here for more info. Here is an example

import numpy as np
from numba import njit
from numba import types
from numba.typed import Dict

# First create a dictionary using Dict.empty()
# Specify the data types for both key and value pairs

# Dict with key as strings and values of type float array
dict_param1 = Dict.empty(
    key_type=types.unicode_type,
    value_type=types.float64[:],
)

# Dict with keys as string and values of type float
dict_param2 = Dict.empty(
    key_type=types.unicode_type,
    value_type=types.float64,
)

# Type-expressions are currently not supported inside jit functions.
float_array = types.float64[:]

@njit
def add_values(d_param1, d_param2):
    # Make a result dictionary to store results
    # Dict with keys as string and values of type float array
    result_dict = Dict.empty(
        key_type=types.unicode_type,
        value_type=float_array,
    )

    for key in d_param1.keys():
      result_dict[key] = d_param1[key] + d_param2[key]

    return result_dict

dict_param1["hello"]  = np.asarray([1.5, 2.5, 3.5], dtype='f8')
dict_param1["world"]  = np.asarray([10.5, 20.5, 30.5], dtype='f8')

dict_param2["hello"]  = 1.5
dict_param2["world"]  = 10

final_dict = add_values(dict_param1, dict_param2)

print(final_dict)
# Output : {hello: [3. 4. 5.], world: [20.5 30.5 40.5]}

链接到Google colab笔记本.

参考文献:
- https://github.com/numba/numba/issues/3644
- https://numba.pydata.org/numba-doc/dev/reference/pysupported.html#dict

References:
- https://github.com/numba/numba/issues/3644
- https://numba.pydata.org/numba-doc/dev/reference/pysupported.html#dict

这篇关于将字典与numba njit函数一起使用的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆