Python Numba Value in Array


Problem description

I am trying to check whether a number is in a NumPy array of int8s. I tried this, but it does not work.

from numba import njit
import numpy as np

@njit
def c(b):
    return 9 in b

a = np.array((9, 10, 11), 'int8')
print(c(a))

The error I get is

Invalid use of Function(<built-in function contains>) with argument(s) of type(s): (array(int8, 1d, C), Literal[int](9))
 * parameterized
In definition 0:
    All templates rejected with literals.
In definition 1:
    All templates rejected without literals.
In definition 2:
    All templates rejected with literals.
In definition 3:
    All templates rejected without literals.
In definition 4:
    All templates rejected with literals.
In definition 5:
    All templates rejected without literals.
This error is usually caused by passing an argument of a type that is unsupported by the named function.
[1] During: typing of intrinsic-call at .\emptyList.py (6)

How can I fix this while still maintaining performance? The arrays will be checked for two values, 1 and -1, and are 32 items long. They are not sorted.

Recommended answer

Checking if two values are in an array

To check only whether either of two values occurs in an array, I would recommend a simple brute-force algorithm.

Code

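import numba as nb
import numpy as np

@nb.njit(fastmath=True)
def isin(b):
  #Initialize once, before the loop; resetting it on every
  #iteration would make only the last element count
  res=False
  for i in range(b.shape[0]):
    if (b[i]==-1):
      res=True
    if (b[i]==1):
      res=True
  return res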

#Parallelized call to isin if the data is an array of shape (n,m)
@nb.njit(fastmath=True,parallel=True)
def isin_arr(b):
  res=np.empty(b.shape[0],dtype=nb.boolean)
  for i in nb.prange(b.shape[0]):
    res[i]=isin(b[i,:])

  return res
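
To cross-check the result of the compiled function, a plain NumPy expression can compute the same per-row answer. This is only a minimal sketch: the name isin_arr_reference is hypothetical and not part of the original answer, and it is meant as a correctness reference rather than a fast implementation.

```python
import numpy as np

def isin_arr_reference(b):
    """Hypothetical reference: for each row of a 2D array,
    report whether it contains 1 or -1."""
    # Elementwise comparison, then reduce along each row
    return ((b == 1) | (b == -1)).any(axis=1)

# Small hand-checkable input: only the first row has neither 1 nor -1
sample = np.array([[0, 2, 3],
                   [1, 0, 0],
                   [0, -1, 0],
                   [1, 0, 5]], dtype=np.int8)
print(isin_arr_reference(sample))
```

Comparing this output with isin_arr on the same input, for example with np.array_equal, should catch mistakes such as a result that depends only on the last element of each row.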

Performance

#Create some data (320MB)
A=(np.random.randn(10000000,32)-0.5)*5
A=A.astype(np.int8)
res=isin_arr(A) #11ms per call

With this method I get a throughput of about 29 GB/s, which is not far from memory bandwidth. You can also reduce the test data size so that it fits in the L3 cache to avoid the memory bandwidth limit. With 3.2 MB of test data I get a throughput of 100 GB/s (far beyond my memory bandwidth), which is a clear indicator that this implementation is memory bandwidth limited.
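
The 29 GB/s figure follows directly from the data size and the measured time. The sketch below redoes that arithmetic; the helper name throughput_gb_per_s is hypothetical, and GB is taken to mean 10^9 bytes.

```python
def throughput_gb_per_s(n_bytes, seconds):
    """Hypothetical helper: bytes processed per second, in GB/s (1 GB = 1e9 bytes)."""
    return n_bytes / seconds / 1e9

rows, cols = 10_000_000, 32   # shape of the test array
bytes_per_item = 1            # int8
total_bytes = rows * cols * bytes_per_item   # 320 MB
seconds_per_call = 11e-3      # 11 ms per call, as measured above

print(f"{throughput_gb_per_s(total_bytes, seconds_per_call):.1f} GB/s")
```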
