How to get a list of all indices of repeated elements in a numpy array


Question

I'm trying to get the index of all repeated elements in a numpy array, but the solution I found for the moment is REALLY inefficient for a large (>20000 elements) input array (it takes more or less 9 seconds). The idea is simple:

  1. records_array is a numpy array of timestamps (datetime) from which we want to extract the indexes of repeated timestamps

  2. time_array is a numpy array containing all the timestamps that are repeated in records_array

  3. records is a django QuerySet (which can easily be converted to a list) containing some Record objects. We want to create a list of couples formed by all possible combinations of the tagId attributes of the Record objects corresponding to the repeated timestamps found in records_array.

Here is the working (but inefficient) code I have for the moment:

import itertools
import numpy as np

tag_couples = []
for t in time_array:
    # Indices of all records whose timestamp equals t
    users_inter = np.nonzero(records_array == t)[0]
    # tagIds recorded at time t
    tags = [str(records[i].tagId) for i in users_inter]
    # Skip timestamps where every record carries the same tag
    if tags.count(tags[0]) != len(tags):
        # De-duplicate with set(), then append all possible couples
        tag_couples += itertools.combinations(set(tags), 2)

I'm quite sure this can be optimized by using Numpy, but I can't find a way to compare records_array with each element of time_array without using a for loop (this can't be compared by just using ==, since they are both arrays).

Recommended answer

Based on unique().

import numpy as np

# create a test array
records_array = np.array([1, 2, 3, 1, 1, 3, 4, 3, 2])

# creates an array of indices, sorted by unique element
# (kind='stable' keeps the indices within each group in ascending order)
idx_sort = np.argsort(records_array, kind='stable')

# sorts records array so all unique elements are together 
sorted_records_array = records_array[idx_sort]

# returns the unique values, the index of the first occurrence of a value, and the count for each element
vals, idx_start, count = np.unique(sorted_records_array, return_counts=True, return_index=True)

# splits the indices into separate arrays
res = np.split(idx_sort, idx_start[1:])

# filter them with respect to their size, keeping only items occurring more than once
vals = vals[count > 1]
# a list comprehension is used because filter() returns a lazy iterator in Python 3
res = [x for x in res if x.size > 1]
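The steps above can be wrapped in a small helper and tried on timestamps, which is the data type the question mentions. This is only a sketch: the function name repeated_index_groups and the sample timestamps are made up for illustration, and a stable sort is assumed so that the indices inside each group stay in ascending order.

```python
import numpy as np

def repeated_index_groups(values):
    """Return (repeated_values, list_of_index_arrays) for a 1-D array."""
    values = np.asarray(values)
    # Stable sort keeps the original order of indices within each group.
    order = np.argsort(values, kind="stable")
    sorted_values = values[order]
    # first[i] is the position where the i-th unique value starts
    # in the sorted array; counts[i] is how often it occurs.
    uniq, first, counts = np.unique(
        sorted_values, return_index=True, return_counts=True
    )
    # Cut the sorted index array at every group boundary.
    groups = np.split(order, first[1:])
    keep = counts > 1
    return uniq[keep], [g for g, k in zip(groups, keep) if k]

# Hypothetical timestamps standing in for records_array.
stamps = np.array(
    ["2020-01-01T00:00", "2020-01-01T00:05", "2020-01-01T00:00",
     "2020-01-01T00:10", "2020-01-01T00:05", "2020-01-01T00:00"],
    dtype="datetime64[m]",
)
repeated, groups = repeated_index_groups(stamps)
for value, idx in zip(repeated, groups):
    print(value, idx.tolist())
```

The sort dominates the cost, so the whole thing runs in O(n log n) time instead of scanning records_array once per repeated timestamp.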


The following code was the original answer. It needs a bit more memory, and uses numpy broadcasting and two calls to unique:

import numpy as np

records_array = np.array([1, 2, 3, 1, 1, 3, 4, 3, 2])
vals, inverse, count = np.unique(records_array, return_inverse=True,
                                 return_counts=True)

# positions (within vals) of the values that occur more than once
idx_vals_repeated = np.where(count > 1)[0]
vals_repeated = vals[idx_vals_repeated]

# broadcast comparison: one row per repeated value, one column per record
rows, cols = np.where(inverse == idx_vals_repeated[:, np.newaxis])
_, inverse_rows = np.unique(rows, return_index=True)
res = np.split(cols, inverse_rows[1:])

As expected, res = [array([0, 3, 4]), array([1, 8]), array([2, 5, 7])]
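To connect this back to the question, the index groups can feed itertools.combinations to build the tag couples. The sketch below is an assumption about how that might look, not the answerer's code: tag_ids is a hypothetical array holding one tag per record (in the question these values would come from records[i].tagId), and the tag values are invented.

```python
import itertools
import numpy as np

# Hypothetical stand-ins for the question's data: one tag id per record,
# aligned position by position with records_array.
records_array = np.array([1, 2, 3, 1, 1, 3, 4, 3, 2])
tag_ids = np.array(["a", "b", "c", "a", "d", "c", "e", "f", "b"])

# Group the indices of equal timestamps (sort, unique, split).
order = np.argsort(records_array, kind="stable")
_, first, counts = np.unique(
    records_array[order], return_index=True, return_counts=True
)
groups = [g for g in np.split(order, first[1:]) if g.size > 1]

tag_couples = []
for idx in groups:
    # Distinct tags seen at this timestamp, sorted for a deterministic order.
    tags = sorted(set(tag_ids[idx].tolist()))
    # combinations() yields nothing when fewer than two distinct tags remain,
    # so timestamps with a single repeated tag contribute no couples.
    tag_couples.extend(itertools.combinations(tags, 2))

print(tag_couples)  # [('a', 'd'), ('c', 'f')]
```

The loop now runs once per repeated timestamp over a precomputed index group, so records_array is no longer rescanned for every element of time_array.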
