有效删除数组的每一行(如果它以纯numpy出现在另一个数组中) [英] Efficiently delete each row of an array if it occurs in another array in pure numpy
问题描述
我有一个numpy数组,其中的索引以(n,2)
的形式存储.例如:
I have one numpy array, where indices are stored in the shape of (n, 2)
. E.g.:
[[0, 1],
[2, 3],
[1, 2],
[4, 2]]
然后我进行一些处理,并创建一个形状为(m,2)
的数组,其中 n>m
.例如:
Then I do some processing and create an array in the shape of (m, 2)
, where n > m
. E.g.:
[[2, 3]
[4, 2]]
现在我想删除第二个数组中也可以找到的第一个数组中的每一行.所以我想要的结果是:
Now I want to delete every row in the first array that can be found in the second array as well. So my wanted result is:
[[0, 1],
[1, 2]]
我当前的解决方法如下:
My current solution is as follows:
for row in second_array:
result = np.delete(first_array, np.where(np.all(first_array == second_array, axis=1)), axis=0)
但是,如果秒数很大,则这是安静的时间.有人知道仅numpy的解决方案,不需要循环吗?
However, this is quiet time consuming if the second is large. Does someone know a numpy only solution, which does not require a loop?
推荐答案
这里有一个事实,即它们使用正整数使用 matrix-multiplication
进行降维-
Here's one leveraging the fact that they are positive numbers using matrix-multiplication
for dimensionality-reduction -
def setdiff_nd_positivenums(a,b):
s = np.maximum(a.max(0)+1,b.max(0)+1)
return a[~np.isin(a.dot(s),b.dot(s))]
样品运行-
In [82]: a
Out[82]:
array([[0, 1],
[2, 3],
[1, 2],
[4, 2]])
In [83]: b
Out[83]:
array([[2, 3],
[4, 2]])
In [85]: setdiff_nd_positivenums(a,b)
Out[85]:
array([[0, 1],
[1, 2]])
此外,第二个数组 b
似乎是 a
的子集.因此,我们可以使用 np.searchsorted
来利用这种情况进一步提高性能,例如-
Also, it seems the second-array b
is a subset of a
. So, we can leverage that scenario to boost the performance even further using np.searchsorted
, like so -
def setdiff_nd_positivenums_searchsorted(a,b):
s = np.maximum(a.max(0)+1,b.max(0)+1)
a1D,b1D = a.dot(s),b.dot(s)
b1Ds = np.sort(b1D)
return a[b1Ds[np.searchsorted(b1Ds,a1D)] != a1D]
时间-
In [146]: np.random.seed(0)
...: a = np.random.randint(0,9,(1000000,2))
...: b = a[np.random.choice(len(a), 10000, replace=0)]
In [147]: %timeit setdiff_nd_positivenums(a,b)
...: %timeit setdiff_nd_positivenums_searchsorted(a,b)
10 loops, best of 3: 101 ms per loop
10 loops, best of 3: 70.9 ms per loop
对于通用数字,这是另一个使用 views
-
For generic numbers, here's another using views
-
# https://stackoverflow.com/a/45313353/ @Divakar
def view1D(a, b): # a, b are arrays
a = np.ascontiguousarray(a)
b = np.ascontiguousarray(b)
void_dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
return a.view(void_dt).ravel(), b.view(void_dt).ravel()
def setdiff_nd(a,b):
# a,b are the nD input arrays
A,B = view1D(a,b)
return a[~np.isin(A,B)]
样品运行-
In [94]: a
Out[94]:
array([[ 0, 1],
[-2, -3],
[ 1, 2],
[-4, -2]])
In [95]: b
Out[95]:
array([[-2, -3],
[ 4, 2]])
In [96]: setdiff_nd(a,b)
Out[96]:
array([[ 0, 1],
[ 1, 2],
[-4, -2]])
时间-
In [158]: np.random.seed(0)
...: a = np.random.randint(0,9,(1000000,2))
...: b = a[np.random.choice(len(a), 10000, replace=0)]
In [159]: %timeit setdiff_nd(a,b)
1 loop, best of 3: 352 ms per loop
这篇关于有效删除数组的每一行(如果它以纯numpy出现在另一个数组中)的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!