How to check if all elements of a numpy array are in another numpy array
Problem description
I have two 2D numpy arrays, for example:
A = numpy.array([[1, 2, 4, 8], [16, 32, 32, 8], [64, 32, 16, 8]])
and
B = numpy.array([[1, 2], [32, 32]])
I want to keep all rows of A that contain all the elements of at least one row of B. Where a row of B contains the same element twice, the matching row of A must contain it at least twice as well. For my example, I want to get:
A_filtered = [[1, 2, 4, 8], [16, 32, 32, 8]]
I have control over how the values are represented, so I chose numbers whose binary representation has exactly one bit set to 1 (for example 0b00000001, 0b00000010, etc.). This way I can easily check whether all types of values are present in a row by using the np.logical_or.reduce() function, but I cannot check whether a row of A contains each element at least as many times as the row of B does. I was really hoping to avoid a simple for loop and deep copies of the arrays, as performance is a very important aspect for me.
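The one-hot check described above can be sketched roughly as follows, assuming the example arrays from the question. The sketch uses np.bitwise_or.reduce rather than np.logical_or.reduce: on integer input np.logical_or returns booleans, so only the bitwise variant keeps track of which bits are set. The third row of A also passes, although it contains 32 only once, which is exactly the duplicate-count gap the question describes.

```python
import numpy as np

# Each value has exactly one bit set (one-hot encoding), as in the question.
A = np.array([[1, 2, 4, 8], [16, 32, 32, 8], [64, 32, 16, 8]])
B = np.array([[1, 2], [32, 32]])

# OR all elements of each row together: the result has one bit set for every
# distinct value occurring in the row. How often a value occurs is lost.
A_bits = np.bitwise_or.reduce(A, axis=1)  # shape (len(A),)
B_bits = np.bitwise_or.reduce(B, axis=1)  # shape (len(B),)

# Row i of A "covers" row j of B if every bit of B_bits[j] is set in A_bits[i].
covers = (A_bits[:, None] & B_bits[None, :]) == B_bits[None, :]
mask = covers.any(axis=1)
print(mask)  # all three rows pass, including [64, 32, 16, 8]
```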
How can I do that in numpy in an efficient way?
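One possible pure-NumPy answer, which is a swapped-in broadcasting technique and not the numba approach from the recommended answer, compares element counts for every (row of A, row of B) pair at once. The helper name filter_rows_broadcast is hypothetical. Its intermediate array has len(A) · len(B) · B.shape[1] · A.shape[1] elements (72 million for 300,000 × 4 and 30 × 2), so for very large A it would likely need to be applied in chunks.

```python
import numpy as np

def filter_rows_broadcast(A, B):
    """Hypothetical helper: keep the rows of A that contain every element of
    at least one row of B, counting duplicates."""
    # count_in_A[i, j, m] = number of times B[j, m] occurs in row A[i]
    count_in_A = (A[:, None, None, :] == B[None, :, :, None]).sum(axis=-1)
    # count_in_B[j, m] = number of times B[j, m] occurs in its own row B[j]
    count_in_B = (B[:, :, None] == B[:, None, :]).sum(axis=-1)
    # A[i] matches B[j] if it has at least as many copies of every element.
    matches = (count_in_A >= count_in_B[None, :, :]).all(axis=-1)  # (len(A), len(B))
    return A[matches.any(axis=1)]

A = np.array([[1, 2, 4, 8], [16, 32, 32, 8], [64, 32, 16, 8]])
B = np.array([[1, 2], [32, 32]])
A_filtered = filter_rows_broadcast(A, B)
print(A_filtered)  # keeps the first two rows
```

Nothing needs to be sorted here, and the rows of A keep their original element order.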
Update:
A solution from here may work, but performance is a big concern for me: A can be really big (>300,000 rows) and B can be moderate (>30 rows):
[set(row).issuperset(hand) for row in A.tolist() for hand in B.tolist()]
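The comprehension above returns one flat list with a boolean for every (row of A, row of B) pair, in row-major order. A minimal sketch of turning that list into a row filter, assuming the example arrays from the question:

```python
import numpy as np

A = np.array([[1, 2, 4, 8], [16, 32, 32, 8], [64, 32, 16, 8]])
B = np.array([[1, 2], [32, 32]])

# One boolean per (row, hand) pair: len(A) * len(B) entries in total.
flat = [set(row).issuperset(hand) for row in A.tolist() for hand in B.tolist()]

# Reshape to (len(A), len(B)) and keep rows that match any row of B.
mask = np.array(flat).reshape(len(A), len(B)).any(axis=1)
print(A[mask])  # the third row is kept too, because set() ignores duplicates
```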
Update 2:
The set() solution does not work, since set() drops all duplicate values.
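One way to keep duplicates in this pure-Python version (a swapped-in technique, not taken from the thread) is to compare element counts with collections.Counter instead of sets. The helper contains_multiset is hypothetical. It is still a Python loop over len(A) · len(B) pairs, so it is mainly useful as a slow but simple correctness reference for faster implementations.

```python
from collections import Counter

import numpy as np

def contains_multiset(row, hand):
    """True if `row` contains every element of `hand` at least as often."""
    # Counter subtraction keeps only positive counts, so an empty result
    # means no element of `hand` is missing (or under-represented) in `row`.
    return not (Counter(hand) - Counter(row))

A = np.array([[1, 2, 4, 8], [16, 32, 32, 8], [64, 32, 16, 8]])
B = np.array([[1, 2], [32, 32]])

hands = B.tolist()
mask = np.array([any(contains_multiset(row, hand) for hand in hands)
                 for row in A.tolist()])
print(A[mask])  # only the first two rows remain
```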
Recommended answer
I hope I understood your question correctly; at least this works for the problem as you described it. If the output should keep the original order of the elements within each row, sort copies of the arrays instead of sorting them in place.
The code looks quite ugly, but it should perform well and shouldn't be too hard to understand.
Code
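import time
import numba as nb
import numpy as np

@nb.njit(fastmath=True, parallel=True)
def filter(A, B):
    # Assumes every row of A and of B is sorted in ascending order.
    # Note: this name shadows Python's built-in filter().
    iFilter = np.zeros(A.shape[0], dtype=np.bool_)
    for i in nb.prange(A.shape[0]):  # rows of A are checked in parallel
        break_loop = False
        for j in range(B.shape[0]):
            # Walk through row A[i] and try to match B[j] as a subsequence.
            ind_to_B = 0
            for k in range(A.shape[1]):
                if A[i, k] == B[j, ind_to_B]:
                    ind_to_B += 1
                    if ind_to_B == B.shape[1]:
                        # All elements of B[j], duplicates included, occur in A[i].
                        iFilter[i] = True
                        break_loop = True
                        break
            if break_loop:
                break
    return A[iFilter, :]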
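Why the sort matters: when both rows are sorted in ascending order, a row of B is contained in a row of A (counting duplicates) exactly when it appears in that row as a subsequence, and the inner loop checks this greedily. Below is a plain-Python transcription of that loop, without numba, that returns a mask instead of the filtered rows; the name row_mask_sorted is hypothetical. Sorting copies with np.sort instead of calling A.sort() in place keeps the original element order in the output, which is one way to read the suggestion to change the in-place sort.

```python
import numpy as np

def row_mask_sorted(A_sorted, B_sorted):
    """Plain-Python version of the loop in filter(), returning a boolean mask.
    Both inputs must be sorted along axis 1."""
    mask = np.zeros(A_sorted.shape[0], dtype=bool)
    for i in range(A_sorted.shape[0]):
        for j in range(B_sorted.shape[0]):
            # Greedy subsequence match of B_sorted[j] inside A_sorted[i].
            ind_to_B = 0
            for k in range(A_sorted.shape[1]):
                if A_sorted[i, k] == B_sorted[j, ind_to_B]:
                    ind_to_B += 1
                    if ind_to_B == B_sorted.shape[1]:
                        break
            if ind_to_B == B_sorted.shape[1]:
                mask[i] = True
                break
    return mask

A = np.array([[8, 4, 2, 1], [32, 8, 16, 32], [64, 32, 16, 8]])
B = np.array([[2, 1], [32, 32]])
# Sort copies, so A itself keeps its original element order.
mask = row_mask_sorted(np.sort(A, axis=1), np.sort(B, axis=1))
print(A[mask])  # first two rows, in their original element order
```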
Measuring performance
####First call has some compilation overhead####
A=np.random.randint(low=0, high=60, size=300_000*4).reshape(300_000,4)
B=np.random.randint(low=0, high=60, size=30*2).reshape(30,2)
t1=time.time()
# First, sort each row of both arrays in place
A.sort()
B.sort()
A_filtered=filter(A,B)
print(time.time()-t1)
####Let's measure the second call too####
A=np.random.randint(low=0, high=60, size=300_000*4).reshape(300_000,4)
B=np.random.randint(low=0, high=60, size=30*2).reshape(30,2)
t1=time.time()
# First, sort each row of both arrays in place
A.sort()
B.sort()
A_filtered=filter(A,B)
print(time.time()-t1)
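The script above separates compile time from run time by calling filter twice on fresh data. A small helper can do the same more directly; time_call is an assumed helper, not part of the answer. It uses time.perf_counter, which is intended for measuring short durations, makes one warm-up call (where numba would compile), and reports the best of several runs.

```python
import time

def time_call(fn, *args, repeat=5):
    """Hypothetical helper: warm up once, then return the best wall-clock
    time in seconds over `repeat` calls of fn(*args)."""
    fn(*args)  # warm-up call (e.g. numba compilation), not measured
    best = float("inf")
    for _ in range(repeat):
        t0 = time.perf_counter()
        fn(*args)
        best = min(best, time.perf_counter() - t0)
    return best
```

With A and B already sorted, time_call(filter, A, B) would then measure the filtering alone, with sorting excluded.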
Results
46 ms per call after the first (compiling) run, on a dual-core notebook (sorting included)
32 ms per call (sorting excluded)