How to check if all elements of a numpy array are in another numpy array

Problem description

I have two 2D numpy arrays, for example:

A = numpy.array([[1, 2, 4, 8], [16, 32, 32, 8], [64, 32, 16, 8]])

B = numpy.array([[1, 2], [32, 32]])

I want all rows of A that contain all the elements of at least one row of B. If a row of B contains the same element twice, a matching row of A must contain that element at least twice as well. For my example, I want to get this:

A_filtered = [[1, 2, 4, 8], [16, 32, 32, 8]]

I have control over how the values are represented, so I chose numbers whose binary representation has exactly one bit set (for example 0b00000001, 0b00000010, etc.). This way I can easily check whether all kinds of values are present in a row by OR-ing the row together with np.bitwise_or.reduce(), but I cannot check whether a row of A contains each value at least as many times as the row of B does. I was really hoping to avoid plain for loops and deep copies of the arrays, because performance is a very important aspect for me.
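A minimal sketch of the bitmask check described above, using the example arrays (which already use one-bit values). It assumes that combining the flags means a bitwise OR, i.e. np.bitwise_or.reduce; np.logical_or.reduce would only give True/False per row. The sketch also reproduces the limitation: the third row is kept because its mask has the 32 bit set, although it contains 32 only once.

```python
import numpy as np

A = np.array([[1, 2, 4, 8], [16, 32, 32, 8], [64, 32, 16, 8]])
B = np.array([[1, 2], [32, 32]])

# OR the one-bit values of each row together: one presence mask per row
mask_A = np.bitwise_or.reduce(A, axis=1)
mask_B = np.bitwise_or.reduce(B, axis=1)

# Row i of A "contains" row j of B if every bit of mask_B[j] is also set in mask_A[i]
contains = (mask_A[:, None] & mask_B[None, :]) == mask_B[None, :]
keep = contains.any(axis=1)

print(mask_A)  # [ 15  56 120]
print(keep)    # [ True  True  True]: the third row is wrongly kept
```

Duplicates collapse into a single bit, so a mask alone cannot tell one 32 from two.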

How can I do that in numpy in an efficient way?

Update:

A solution from here may work, but performance is a big concern for me: A can be really big (>300,000 rows) and B moderately sized (>30 rows):

[set(row).issuperset(hand) for row in A.tolist() for hand in B.tolist()]

Update 2:

The set() solution does not work, because set() drops all duplicate values.
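A multiset variant of the set() idea, as a sketch: collections.Counter records how often each value occurs, so duplicates are respected. It is plain Python and therefore slow for 300,000 rows; it is mainly useful as a simple reference to check faster versions against. The function name filter_rows_reference is hypothetical, and unlike the comprehension above it returns the filtered rows rather than a flat list of booleans.

```python
from collections import Counter

import numpy as np


def filter_rows_reference(A, B):
    """Keep rows of A that contain every element of some row of B, counting duplicates."""
    b_counts = [Counter(hand) for hand in B.tolist()]
    keep = []
    for row in A.tolist():
        row_counts = Counter(row)
        # The row contains a hand (as a multiset) if it has at least
        # as many copies of each value; missing values count as 0
        keep.append(any(all(row_counts[v] >= c for v, c in hc.items())
                        for hc in b_counts))
    return A[np.array(keep, dtype=bool)]


A = np.array([[1, 2, 4, 8], [16, 32, 32, 8], [64, 32, 16, 8]])
B = np.array([[1, 2], [32, 32]])
print(filter_rows_reference(A, B).tolist())  # [[1, 2, 4, 8], [16, 32, 32, 8]]
```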

Recommended answer

I hope I understood your question correctly; at least this works for the problem as you described it. The code sorts each row of A and B in place, so if the elements of the output rows should keep their original order, change the in-place sort.
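To make the sorting remark concrete, a small illustration with made-up values: ndarray.sort() sorts in place along the last axis, i.e. within each row, while np.sort(A, axis=1) returns a sorted copy and leaves A unchanged. In both cases the order of the rows stays the same; only the order of the elements inside each row changes.

```python
import numpy as np

A = np.array([[8, 1, 4, 2], [32, 16, 8, 32]])

A_sorted = np.sort(A, axis=1)  # sorted copy; A keeps its original element order
print(A.tolist())         # [[8, 1, 4, 2], [32, 16, 8, 32]]
print(A_sorted.tolist())  # [[1, 2, 4, 8], [8, 16, 32, 32]]

A.sort()  # in place, along the last axis, i.e. within each row
print(A.tolist())         # [[1, 2, 4, 8], [8, 16, 32, 32]]
```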

The code looks quite ugly, but it should perform well and shouldn't be too hard to understand.

Code

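import time
import numba as nb
import numpy as np

# Keeps the rows of A that contain every element of at least one row of B,
# counting duplicates. Both A and B must already be sorted within each row.
# (The name `filter` shadows Python's builtin; it is kept from the original answer.)
@nb.njit(fastmath=True, parallel=True)
def filter(A, B):
    # iFilter[i] is set to True once row i of A is known to match some row of B
    iFilter = np.zeros(A.shape[0], dtype=np.bool_)

    # Rows of A are independent of each other, so they are processed in parallel
    for i in nb.prange(A.shape[0]):
        break_loop = False

        for j in range(B.shape[0]):
            # Walk the sorted row A[i] once; ind_to_B is the index of the
            # next element of the sorted row B[j] that still has to be found
            ind_to_B = 0
            for k in range(A.shape[1]):
                if A[i, k] == B[j, ind_to_B]:
                    ind_to_B += 1

                # Every element of B[j], repeats included, was found in A[i]
                if ind_to_B == B.shape[1]:
                    iFilter[i] = True
                    break_loop = True
                    break

            # One matching row of B is enough, skip the remaining rows of B
            if break_loop:
                break

    return A[iFilter, :]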
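The inner k loop relies on a property of sorted rows: a sorted row of B is contained in a sorted row of A, duplicates included, exactly when it is a subsequence of it, so one forward walk with a single pointer into B is enough. A plain-Python sketch of that check, with a hypothetical helper name contains_sorted and no Numba:

```python
def contains_sorted(a_row, b_row):
    """Return True if sorted b_row is a sub-multiset of sorted a_row.

    Both inputs must be sorted ascending; this mirrors the inner k loop.
    """
    if len(b_row) == 0:
        return True
    ind_to_b = 0  # index of the next element of b_row still to be matched
    for value in a_row:
        if value == b_row[ind_to_b]:
            ind_to_b += 1
            if ind_to_b == len(b_row):
                return True
    return False


print(contains_sorted([8, 16, 32, 32], [32, 32]))  # True
print(contains_sorted([8, 16, 32, 64], [32, 32]))  # False
```

Because both rows are sorted, the walk could also stop as soon as an element of a_row exceeds the element being searched for; neither this sketch nor the answer's code does that.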

Measuring performance

####First call has some compilation overhead####
A=np.random.randint(low=0, high=60, size=300_000*4).reshape(300_000,4)
B=np.random.randint(low=0, high=60, size=30*2).reshape(30,2)

t1=time.time()
# First sort each row of A and B (filter requires sorted rows)
A.sort()
B.sort()
A_filtered=filter(A,B)
print(time.time()-t1)

####Let's measure the second call too####
A=np.random.randint(low=0, high=60, size=300_000*4).reshape(300_000,4)
B=np.random.randint(low=0, high=60, size=30*2).reshape(30,2)

t1=time.time()
# First sort each row of A and B (filter requires sorted rows)
A.sort()
B.sort()
A_filtered=filter(A,B)
print(time.time()-t1)

Results

46 ms for calls after the first one, on a dual-core notebook (sorting included)
32 ms (sorting excluded)
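For comparison, a pure-NumPy sketch that uses a different technique from the answer: it counts how often each value occurring in B appears in every row of A and of B, then keeps the rows of A whose counts are at least those of some row of B. It assumes short rows and few distinct values in B, because the intermediate comparison array has len(A) × row length × (number of distinct values in B) entries. It has not been benchmarked here, and the name filter_by_counts is hypothetical.

```python
import numpy as np


def filter_by_counts(A, B):
    """Keep rows of A that contain all elements of some row of B, counting duplicates."""
    values = np.unique(B)  # only values that occur in B matter
    # counts_A[i, v]: how often values[v] occurs in row i of A, shape (len(A), len(values))
    counts_A = (A[:, :, None] == values).sum(axis=1)
    # counts_B[j, v]: the same for row j of B, shape (len(B), len(values))
    counts_B = (B[:, :, None] == values).sum(axis=1)

    keep = np.zeros(A.shape[0], dtype=bool)
    for cb in counts_B:  # loop over the (few) rows of B, vectorized over A
        keep |= (counts_A >= cb).all(axis=1)
    return A[keep]


A = np.array([[1, 2, 4, 8], [16, 32, 32, 8], [64, 32, 16, 8]])
B = np.array([[1, 2], [32, 32]])
print(filter_by_counts(A, B).tolist())  # [[1, 2, 4, 8], [16, 32, 32, 8]]
```

Unlike the Numba version, this does not need sorted rows, and the row order of A is preserved in the output.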
