Get groups of consecutive elements of a NumPy array based on condition

Question

I have a NumPy array as follows:

import numpy as np
a = np.array([1, 4, 2, 6, 4, 4, 6, 2, 7, 6, 2, 8, 9, 3, 6, 3, 4, 4, 5, 8])

and a constant b = 6.

Based on a previous question, I can count the number c, defined as the number of times the elements of a are less than b two or more times consecutively:

from itertools import groupby
b = 6
c = sum(len(list(g)) >= 2 for i, g in groupby(a < b) if i)

So in this example c == 3.

Now, instead of counting how many times the condition is met, I would like to output an array each time it is met.

So for this example, the correct output would be:

array1 = [1, 4, 2]
array2 = [4, 4]
array3 = [3, 4, 4, 5]

because:

1, 4, 2, 6, 4, 4, 6, 2, 7, 6, 2, 8, 9, 3, 6, 3, 4, 4, 5, 8  # numbers in a
1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0  # (a<b)
^^^^^^^-----^^^^-----------------------------^^^^^^^^^^---  # (a<b) 2+ times consecutively
   1         2                                    3

So far I have tried different options:

np.isin((len(list(g))>=2 for i, g in groupby(a < b)if i), a)

np.extract((len(list(g))>=2 for i, g in groupby(a < b)if i), a)

But neither of them achieved what I am looking for. Can someone point me to the right Python tools to output the separate arrays that satisfy my condition?

Recommended answer

While measuring the performance of my other answer, I noticed that although it was faster than Austin's solution (for arrays of length < 15000), its complexity was not linear.

Based on this answer, I came up with the following solution using np.split, which is more efficient than both of the answers previously posted here:

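# Sentinel (< b) so a trailing run keeps its last element; -np.inf upcasts to float.
array = np.append(a, -np.inf)
mask = array >= b  # elements that end a run (b = 6 in this example)
split_indices = np.where(mask)[0]
# Splitting just after each such element leaves it as the last item of its chunk.
for subarray in np.split(array, split_indices + 1):
    if len(subarray) > 2:  # at least two kept values plus the trailing element
        print(subarray[:-1])  # drop the trailing element (or the sentinel)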

which gives:

[1. 4. 2.]
[4. 4.]
[3. 4. 4. 5.]
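
The output above is printed as floats because padding with -np.inf upcasts the array. As an illustration only, and not part of the original answer, the sketch below wraps the task in a reusable function that keeps the input dtype. It swaps in a different technique: instead of np.split with a sentinel, it locates run boundaries with np.diff on a padded boolean mask. The function name consecutive_groups and the min_len parameter are hypothetical.

```python
import numpy as np


def consecutive_groups(a, b, min_len=2):
    """Return runs of consecutive elements of `a` that are < b and at least min_len long."""
    a = np.asarray(a)
    mask = a < b
    # Pad the mask with False on both sides so every run has a start and an end.
    padded = np.concatenate(([False], mask, [False]))
    # In the differences, +1 marks the first index of a run, -1 marks one past its end.
    edges = np.diff(padded.astype(np.int8))
    starts = np.flatnonzero(edges == 1)
    stops = np.flatnonzero(edges == -1)
    return [a[s:e] for s, e in zip(starts, stops) if e - s >= min_len]


a = np.array([1, 4, 2, 6, 4, 4, 6, 2, 7, 6, 2, 8, 9, 3, 6, 3, 4, 4, 5, 8])
for group in consecutive_groups(a, 6):
    print(group)
# [1 4 2]
# [4 4]
# [3 4 4 5]
```

Each returned group is a slice (view) of the input array, so copy it first if you plan to modify it.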

[Figure: performance comparison of the answers, measured with perfplot]
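
To check the performance claim on your own machine, the rough sketch below times the np.split approach against a plain itertools.groupby baseline. It uses the standard-library timeit module rather than perfplot. Both function names are hypothetical, and the groupby version is only an assumed baseline, not necessarily what the other answers did. Timings depend on the machine and the data, so no results are quoted here.

```python
import timeit
from itertools import groupby

import numpy as np


def groups_split(a, b):
    # The np.split approach from the answer, wrapped in a function.
    array = np.append(a, -np.inf)  # sentinel; upcasts to float
    split_indices = np.where(array >= b)[0]
    return [sub[:-1] for sub in np.split(array, split_indices + 1) if len(sub) > 2]


def groups_groupby(a, b):
    # Assumed baseline: walk the runs of the boolean mask with itertools.groupby.
    out, pos = [], 0
    for is_small, run in groupby(a < b):
        n = len(list(run))
        if is_small and n >= 2:
            out.append(a[pos:pos + n])
        pos += n
    return out


rng = np.random.default_rng(0)
for n in (1_000, 10_000, 50_000):
    data = rng.integers(0, 10, size=n)
    # Both approaches should find the same groups before timings are compared.
    res_split, res_groupby = groups_split(data, 6), groups_groupby(data, 6)
    assert len(res_split) == len(res_groupby)
    assert all(np.array_equal(x, y) for x, y in zip(res_split, res_groupby))
    t_split = timeit.timeit(lambda: groups_split(data, 6), number=3)
    t_groupby = timeit.timeit(lambda: groups_groupby(data, 6), number=3)
    print(f"n={n}: np.split {t_split:.4f}s, groupby {t_groupby:.4f}s")
```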
