Retrieve the positions of elements that meet a neighborhood criterion in NumPy

Problem description

For the 2D array of data below, how can I retrieve the positions (indices) of the 7 at (2, 5) and the 11 at (3, 9)? They are the only elements whose eight neighbours all have the same value as the element itself.

import numpy as np

data = np.array([
    [0,1,2,3,4,7,6,7,8,9,10],
    [3,3,3,4,7,7,7,8,11,12,11],
    [3,3,3,5,7,7,7,9,11,11,11],   # target: the 7 at index (2, 5)
    [3,4,3,6,7,7,7,10,11,11,11],  # target: the 11 at index (3, 9)
    [4,5,6,7,7,9,10,11,11,11,11]
    ])

print(data)
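To make the criterion concrete, the 3×3 block centred on each target cell can be sliced out and checked directly. This is only an illustration: the snippet re-declares the array so it runs on its own, and the helper name `block` is hypothetical.

```python
import numpy as np

# The question's array, repeated so this snippet runs on its own.
data = np.array([
    [0, 1, 2, 3, 4, 7, 6, 7, 8, 9, 10],
    [3, 3, 3, 4, 7, 7, 7, 8, 11, 12, 11],
    [3, 3, 3, 5, 7, 7, 7, 9, 11, 11, 11],
    [3, 4, 3, 6, 7, 7, 7, 10, 11, 11, 11],
    [4, 5, 6, 7, 7, 9, 10, 11, 11, 11, 11],
])


def block(a, i, j):
    """Hypothetical helper: the 3x3 block centred on interior cell (i, j)."""
    return a[i - 1:i + 2, j - 1:j + 2]


print(block(data, 2, 5))
print(np.all(block(data, 2, 5) == data[2, 5]))  # True: all nine values are 7
print(np.all(block(data, 3, 9) == data[3, 9]))  # True: all nine values are 11
print(np.all(block(data, 2, 1) == data[2, 1]))  # False: data[3, 1] is 4, not 3
```

Cells on the outer edge have fewer than eight neighbours, so the question implicitly only considers interior cells.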

Recommended answer

Using SciPy, you can characterize such points as those that are both the maximum and the minimum of their 3×3 neighborhood. If a cell equals both the largest and the smallest value around it, every value in that neighborhood must be the same:

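import numpy as np
from scipy import ndimage

def using_filters(data):
    # Use floats so that the +inf padding (cval) is representable.
    data = np.asarray(data, dtype=float)
    # A cell equal to both the max and the min of its 3x3 neighborhood has
    # nine equal values. Padding with +inf makes the max of every border
    # cell's window inf, so border cells never match.
    return np.where(np.logical_and.reduce(
        [data == f(data, footprint=np.ones((3,3)), mode='constant', cval=np.inf)
         for f in (ndimage.maximum_filter, ndimage.minimum_filter)]))

using_filters(data)
# (array([2, 3]), array([5, 9]))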
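The same max-equals-min test extends to larger neighborhoods by changing the window size. The sketch below is a generalization that is not part of the original answer; the function name `uniform_neighborhood_positions` and its `size` parameter are hypothetical. It relies on the `size` argument of `scipy.ndimage.maximum_filter` and `minimum_filter`, which is equivalent to an all-ones footprint of that shape.

```python
import numpy as np
from scipy import ndimage


def uniform_neighborhood_positions(data, size=3):
    """Indices of cells whose size x size neighborhood holds a single value.

    Hypothetical generalization of using_filters; `size` should be odd so the
    window is centred on each cell.
    """
    a = np.asarray(data, dtype=float)
    # Pad with +inf: any window that reaches past the edge has a maximum of
    # inf while its minimum stays finite, so those cells never qualify.
    hi = ndimage.maximum_filter(a, size=size, mode='constant', cval=np.inf)
    lo = ndimage.minimum_filter(a, size=size, mode='constant', cval=np.inf)
    # max == min over a window means every value in it is equal.
    return np.nonzero(hi == lo)


data = np.array([
    [0, 1, 2, 3, 4, 7, 6, 7, 8, 9, 10],
    [3, 3, 3, 4, 7, 7, 7, 8, 11, 12, 11],
    [3, 3, 3, 5, 7, 7, 7, 9, 11, 11, 11],
    [3, 4, 3, 6, 7, 7, 7, 10, 11, 11, 11],
    [4, 5, 6, 7, 7, 9, 10, 11, 11, 11, 11],
])

print(uniform_neighborhood_positions(data))          # (array([2, 3]), array([5, 9]))
print(uniform_neighborhood_positions(data, size=5))  # no cell qualifies
```

Comparing `hi == lo` directly avoids the second comparison against `data`, since the cell itself is always inside its own window.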

Using only NumPy, you can compare data with eight shifted slices of itself, one per neighbor direction, and keep the points where all eight comparisons are equal:

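def using_eight_shifts(data):
    h, w = data.shape
    # Float array holding a copy of data inside a one-cell NaN border.
    data2 = np.empty((h+2, w+2))
    data2[(0,-1),:] = np.nan
    data2[:,(0,-1)] = np.nan
    data2[1:1+h,1:1+w] = data

    # At position (r, c), data2[i:i+h, j:j+w] holds the neighbor
    # data[r+i-1, c+j-1]. Skip i == j == 1 (the cell itself) and keep the
    # cells where all eight comparisons are True.
    result = np.where(np.logical_and.reduce([
        (data2[i:i+h,j:j+w] == data)
        for i in range(3)
        for j in range(3)
        if not (i==1 and j==1)]))
    return result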

As you can see above, this strategy builds an expanded array with a border of NaNs around the data, which lets each shifted slice be written as data2[i:i+h,j:j+w]. Since NaN compares unequal to every value, cells on the edge of data can never match all eight of their neighbors.
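`np.pad` can build the same NaN-bordered array in a single call. The sketch below swaps that in for the manual `np.empty` construction; it is not the original author's code, and the name `using_eight_shifts_pad` is hypothetical. The `astype(float)` is required because an integer array cannot store NaN.

```python
import numpy as np


def using_eight_shifts_pad(data):
    """Hypothetical variant of using_eight_shifts that builds data2 with np.pad."""
    h, w = data.shape
    # One-cell NaN border around a float copy of data.
    data2 = np.pad(data.astype(float), 1, mode='constant',
                   constant_values=np.nan)
    checks = [data2[i:i + h, j:j + w] == data
              for i in range(3) for j in range(3)
              if (i, j) != (1, 1)]  # (1, 1) is the unshifted cell itself
    return np.nonzero(np.logical_and.reduce(checks))


data = np.array([
    [0, 1, 2, 3, 4, 7, 6, 7, 8, 9, 10],
    [3, 3, 3, 4, 7, 7, 7, 8, 11, 12, 11],
    [3, 3, 3, 5, 7, 7, 7, 9, 11, 11, 11],
    [3, 4, 3, 6, 7, 7, 7, 10, 11, 11, 11],
    [4, 5, 6, 7, 7, 9, 10, 11, 11, 11, 11],
])

print(using_eight_shifts_pad(data))  # (array([2, 3]), array([5, 9]))
```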

If you know that you are going to compare against neighbors, it may pay to define data with a border of NaNs from the very beginning (this requires a floating-point dtype, since integer arrays cannot hold NaN), so you don't have to make a second array as done above.
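A minimal sketch of that idea, assuming you control how the array is allocated: create the NaN-bordered array once, keep `data` as a view of its interior, and let the shifted comparisons read from the bordered array directly. The names `padded`, `rows` and `uniform_positions` are illustrative only. Because basic slicing returns a view, writes through `data` update `padded` as well.

```python
import numpy as np

rows = [
    [0, 1, 2, 3, 4, 7, 6, 7, 8, 9, 10],
    [3, 3, 3, 4, 7, 7, 7, 8, 11, 12, 11],
    [3, 3, 3, 5, 7, 7, 7, 9, 11, 11, 11],
    [3, 4, 3, 6, 7, 7, 7, 10, 11, 11, 11],
    [4, 5, 6, 7, 7, 9, 10, 11, 11, 11, 11],
]
h, w = len(rows), len(rows[0])

# Allocate the NaN-bordered array once; `data` is a view of its interior.
padded = np.full((h + 2, w + 2), np.nan)
data = padded[1:-1, 1:-1]  # basic slicing: a view, not a copy
data[...] = rows           # fills the interior; the NaN border is untouched


def uniform_positions(padded):
    """Hypothetical helper: eight-shift test reading straight from `padded`."""
    inner = padded[1:-1, 1:-1]
    h, w = inner.shape
    checks = [padded[i:i + h, j:j + w] == inner
              for i in range(3) for j in range(3)
              if (i, j) != (1, 1)]
    return np.nonzero(np.logical_and.reduce(checks))


print(np.shares_memory(data, padded))  # True
print(uniform_positions(padded))       # (array([2, 3]), array([5, 9]))
```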

Using eight shifts (and comparisons) is much faster than looping over each cell of data in Python and comparing it against its neighbors:

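def using_quadratic_loop(data):
    # Check each interior cell's 3x3 block in Python (border cells are skipped).
    # Returns a 2 x N array: row indices, then column indices.
    return np.array([[i,j]
            for i in range(1,np.shape(data)[0]-1)
            for j in range(1,np.shape(data)[1]-1)
            if np.all(data[i-1:i+2,j-1:j+2]==data[i,j])]).T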
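The per-cell loop can also be vectorized without building a NaN border, assuming NumPy 1.20 or later for `sliding_window_view`. This alternative was not part of the original answer or its benchmark, and the name `using_windows` is hypothetical. Every 3×3 window is exposed as a view and compared with its centre; only full windows are produced, so border cells are skipped just as in `using_quadratic_loop`, and the indices are shifted by 1 to map windows back to cells.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def using_windows(data):
    """Vectorized form of the per-cell loop (requires NumPy >= 1.20)."""
    # windows[r, c] is the 3x3 block data[r:r+3, c:c+3]; shape (h-2, w-2, 3, 3).
    windows = sliding_window_view(data, (3, 3))
    centres = windows[:, :, 1, 1]
    # True where all nine values of a window equal its centre.
    uniform = np.all(windows == centres[:, :, None, None], axis=(2, 3))
    rows, cols = np.nonzero(uniform)
    # Window (r, c) is centred on cell (r + 1, c + 1) of data.
    return rows + 1, cols + 1


data = np.array([
    [0, 1, 2, 3, 4, 7, 6, 7, 8, 9, 10],
    [3, 3, 3, 4, 7, 7, 7, 8, 11, 12, 11],
    [3, 3, 3, 5, 7, 7, 7, 9, 11, 11, 11],
    [3, 4, 3, 6, 7, 7, 7, 10, 11, 11, 11],
    [4, 5, 6, 7, 7, 9, 10, 11, 11, 11, 11],
])

print(using_windows(data))  # (array([2, 3]), array([5, 9]))
```

The windows themselves are a zero-copy view, but the `==` comparison materializes a boolean array about nine times the size of `data`, so memory use is higher than with the eight-shift approach.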

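Here is a benchmark (total seconds for 10 calls of each function on the example array tiled 50 × 50 times into a 250 × 550 array):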

using_filters            : 0.130
using_eight_shifts       : 0.340
using_quadratic_loop     : 18.794


Here is the code used to produce the benchmark:

import timeit
import operator
import numpy as np
from scipy import ndimage

data = np.array([
    [0,1,2,3,4,7,6,7,8,9,10],
    [3,3,3,4,7,7,7,8,11,12,11],
    [3,3,3,5,7,7,7,9,11,11,11],
    [3,4,3,6,7,7,7,10,11,11,11],
    [4,5,6,7,7,9,10,11,11,11,11]
    ])

# Tile the 5 x 11 example into a 250 x 550 array for timing.
data = np.tile(data, (50,50))

def using_filters(data):
    # Floats so that the +inf padding (cval) is representable.
    data = np.asarray(data, dtype=float)
    return np.where(np.logical_and.reduce(
        [data == f(data, footprint=np.ones((3,3)), mode='constant', cval=np.inf)
         for f in (ndimage.maximum_filter, ndimage.minimum_filter)]))


def using_eight_shifts(data):
    h, w = data.shape
    data2 = np.empty((h+2, w+2))
    data2[(0,-1),:] = np.nan
    data2[:,(0,-1)] = np.nan
    data2[1:1+h,1:1+w] = data

    result = np.where(np.logical_and.reduce([
        (data2[i:i+h,j:j+w] == data)
        for i in range(3)
        for j in range(3)
        if not (i==1 and j==1)]))
    return result


def using_quadratic_loop(data):
    return np.array([[i,j]
            for i in range(1,np.shape(data)[0]-1)
            for j in range(1,np.shape(data)[1]-1)
            if np.all(data[i-1:i+2,j-1:j+2]==data[i,j])]).T

np.testing.assert_equal(using_quadratic_loop(data), using_filters(data))
np.testing.assert_equal(using_eight_shifts(data), using_filters(data))

timing = dict()
for f in ('using_filters', 'using_eight_shifts', 'using_quadratic_loop'):
    timing[f] = timeit.timeit('{f}(data)'.format(f=f),
                              'from __main__ import data, {f}'.format(f=f),
                              number=10) 

for f, t in sorted(timing.items(), key=operator.itemgetter(1)):
    print('{f:25}: {t:.3f}'.format(f=f, t=t))
