np.where checking also for subelements in multidimensional arrays

Question

I have two multidimensional arrays with the same second dimension. I want to be sure no element (i.e., no row) of the first array is also a row of the second array.

To do this I am using numpy.where, but its behaviour is also checking for sub-elements in the same position. For example consider this code:

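import numpy as np

x = np.array([[0, 1, 2, 3], [4, 0, 6, 9]])
z = np.array([[0, 1, 2, 3], [5, 11, 6, 98]])
for el in x:
    # z == el broadcasts el across the rows of z and compares element by element
    print(np.where(z == el))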

It prints:

(array([0, 0, 0, 0]), array([0, 1, 2, 3]))
(array([1]), array([2]))

The first result appears because the first rows of the two arrays are equal; the second appears because z[1] and x[1] both have 6 as their third element. Is there a way to tell np.where to return only the indexes of strictly equal rows, i.e. 0 in the example above?
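
For reference, the behaviour asked for can be sketched by reducing the element-wise comparison along each row before calling np.where. This is only an illustration of the requirement, not the method from the answer below, and the helper name matching_rows is made up for this example:

```python
import numpy as np

x = np.array([[0, 1, 2, 3], [4, 0, 6, 9]])
z = np.array([[0, 1, 2, 3], [5, 11, 6, 98]])

def matching_rows(row, table):
    """Return the indices of rows in `table` that equal `row` in every column."""
    # (table == row) is a boolean matrix; .all(axis=1) keeps only full-row matches.
    return np.where((table == row).all(axis=1))[0]

matches = [matching_rows(el, z).tolist() for el in x]
print(matches)  # [[0], []]
```

The partial match on the value 6 disappears because a row only counts when all of its columns agree.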

Recommended answer

Man I haven't had a chance to link to this answer since np.unique added an axis parameter. Credit to @Jaime

vview = lambda a: np.ascontiguousarray(a).view(np.dtype((np.void, a.dtype.itemsize * a.shape[1])))

Basically, that takes the "rows" of your matrix and turns them into a 1-d array of views on the raw datastream of the rows. This lets you compare rows as if they were single values.
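
To make the trick a little less opaque, here is a small sketch that inspects what the view looks like for the x from the question. It only uses the vview helper defined above and assumes nothing beyond standard NumPy attributes:

```python
import numpy as np

# Same helper as above: reinterpret each row as one opaque block of bytes.
vview = lambda a: np.ascontiguousarray(a).view(
    np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
)

x = np.array([[0, 1, 2, 3], [4, 0, 6, 9]])
v = vview(x)

print(v.shape)  # (2, 1): one void item per row
# Each item is as wide as a whole row of x.
print(v.dtype.itemsize == x.dtype.itemsize * x.shape[1])  # True
# The single item of row 0 holds exactly the bytes of x[0].
print(v[0, 0].tobytes() == x[0].tobytes())  # True
```

Because the result keeps a trailing axis of length 1, the later snippets use .T or .squeeze() to line the two views up.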

Then it's simple:

print(np.where(vview(x) == vview(z).T))
(array([0], dtype=int64), array([0], dtype=int64))

This means the first row of x matches the first row of z.

If you only want to know if rows of x are in rows of z:

print(np.where(np.isin(vview(x), vview(z)).squeeze()))
(array([0], dtype=int64),)
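
On small inputs, the same membership question can be cross-checked with plain broadcasting, as in the @orgoro line timed further down. The sketch below just spells out the intermediate shapes; the names pairwise, in_z and shared are chosen for this example only:

```python
import numpy as np

x = np.array([[0, 1, 2, 3], [4, 0, 6, 9]])
z = np.array([[0, 1, 2, 3], [5, 11, 6, 98]])

# x[:, None] has shape (2, 1, 4); against z (2, 4) it broadcasts to (2, 2, 4).
pairwise = (x[:, None] == z).all(-1)  # pairwise[i, j]: does x[i] equal z[j]?
in_z = pairwise.any(-1)               # in_z[i]: does x[i] appear anywhere in z?

shared = np.where(in_z)[0]
print(shared.tolist())   # [0]
print(not in_z.any())    # False: x and z do share a row
```

Note that the intermediate comparison holds len(x) * len(z) * columns booleans, so this form is best kept for checking results rather than for large arrays.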

Checking times compared to @mujjiga on big arrays:

x = np.random.randint(10, size=(1000, 4))
z = np.random.randint(10, size=(1000, 4))

%timeit np.where(np.isin(vview(x), vview(z)).squeeze())
365 µs ± 13.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit [i for i, e in enumerate(x) if (e == z).all(1).any()]  # @mujjiga
21.3 ms ± 1.28 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit np.where((x[:, None] == z).all(-1).any(-1))  # @orgoro
20 ms ± 767 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

So that's about a 60x speedup over looping and slicing, probably due to quick short-circuiting and comparing only 1/4 as many values, since each 4-element row is compared as a single item.
