Count how many elements in a numpy array are within delta of every other element
Question
Consider an array x and a delta variable d:
import numpy as np

np.random.seed([3,1415])
x = np.random.randint(100, size=10)
d = 10
For each element in x, I want to count how many elements of x lie within distance d of it. As the expected result below shows, each element is included in its own count.
So x looks like
print(x)
[11 98 74 90 15 55 13 11 13 26]
and the result should be
[5 2 1 2 5 1 5 5 5 1]
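What I've tried
Strategy:
- Use broadcasting to take the outer difference
- Take the absolute value of the outer difference
- Count how many values fall within the threshold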
(np.abs(x[:, None] - x) <= d).sum(-1)
[5 2 1 2 5 1 5 5 5 1]
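The three steps of the strategy can also be spelled out one at a time. This is only a minimal sketch: it hard-codes the values of x printed above instead of relying on the seed, and the intermediate names diff, within, and counts are illustrative rather than part of the original post.

```python
import numpy as np

# Values of x as printed in the question, hard-coded so the
# result does not depend on the random number generator.
x = np.array([11, 98, 74, 90, 15, 55, 13, 11, 13, 26])
d = 10

# Step 1: broadcasting shape (n, 1) against (n,) yields the
# (n, n) outer difference, diff[i, j] = x[i] - x[j].
diff = x[:, None] - x

# Step 2: the absolute value turns differences into distances.
# Step 3: compare against d, then count the True entries per row.
within = np.abs(diff) <= d
counts = within.sum(-1)

print(counts)  # [5 2 1 2 5 1 5 5 5 1]
```

The diagonal of diff is zero, so every element counts itself. Subtracting 1 from counts would give the number of other elements within d.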
This works great. However, it doesn't scale: the outer difference takes O(n^2) time and builds an n-by-n intermediate array. How can I get the same result without quadratic scaling?
Recommended answer
Listed in this post are two more variants based on the searchsorted strategy from OP's answer post.
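def pir3(a, d):  # Short & less efficient
    sidx = a.argsort()
    # Search the unsorted array through its sorting permutation.
    p1 = a.searchsorted(a + d, 'right', sorter=sidx)
    p2 = a.searchsorted(a - d, sorter=sidx)
    return p1 - p2

def pir4(a, d):  # Long & more efficient
    s = a.argsort()
    # y is the inverse permutation of s (same values as s.argsort()).
    y = np.empty(s.size, dtype=np.int64)
    y[s] = np.arange(s.size)
    a_ = a[s]  # sorted copy of a
    # Count in sorted order, then index with y to restore the original order.
    return (
        a_.searchsorted(a_ + d, 'right')
        - a_.searchsorted(a_ - d)
    )[y]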
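To see why the searchsorted strategy gives the same counts, note that once the array is sorted, all values in [v - d, v + d] occupy one contiguous slice, so two binary searches per element are enough. The sketch below is an illustration under that reading, not code from either answer: the function names count_within_sorted and count_within_brute and the test array are hypothetical, and the brute-force version is only there as a cross-check.

```python
import numpy as np

def count_within_sorted(a, d):
    """For each a[i], count the values v in a with |v - a[i]| <= d (a[i] included)."""
    order = np.argsort(a)
    a_sorted = a[order]
    # Index one past the last value <= v + d.
    hi = np.searchsorted(a_sorted, a_sorted + d, side="right")
    # Index of the first value >= v - d.
    lo = np.searchsorted(a_sorted, a_sorted - d, side="left")
    counts_sorted = hi - lo
    # Scatter the counts back to the original positions.
    counts = np.empty_like(counts_sorted)
    counts[order] = counts_sorted
    return counts

def count_within_brute(a, d):
    """Quadratic reference implementation from the question."""
    return (np.abs(a[:, None] - a) <= d).sum(-1)

# Hypothetical test data, chosen only to compare the two versions.
rng = np.random.default_rng(0)
a = rng.integers(0, 1000, size=500)
assert np.array_equal(count_within_sorted(a, 10), count_within_brute(a, 10))
```

The sort and the binary searches each cost O(n log n), and no n-by-n intermediate array is built.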
The more efficient approach borrows, from this post, an efficient way to get s.argsort(): instead of sorting a second time, it assigns np.arange(s.size) into y[s], which builds the inverse permutation of s in linear time.
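The equivalence between s.argsort() and the scatter assignment can be checked directly. This is a small sketch with made-up data; the names inv_sort and inv_scatter are illustrative only.

```python
import numpy as np

# Hypothetical input; ties are allowed.
rng = np.random.default_rng(0)
a = rng.integers(0, 100, size=8)

s = a.argsort()  # s[k] is the index in a of the k-th smallest value

# Inverse permutation by sorting again: O(n log n).
inv_sort = s.argsort()

# Inverse permutation by scatter assignment: O(n).
inv_scatter = np.empty(s.size, dtype=np.int64)
inv_scatter[s] = np.arange(s.size)

# Both give, for each original index i, its rank in the sorted order.
assert np.array_equal(inv_sort, inv_scatter)
# Indexing the sorted array with the inverse restores the original order.
assert np.array_equal(a[s][inv_scatter], a)
```

Because s is a permutation of distinct indices, its argsort is unique, so the two results agree even when a contains repeated values.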
Runtime test -
In [155]: # Inputs
...: a = np.random.randint(0,1000000,(10000))
...: d = 10
In [156]: %timeit pir2(a,d) #@ piRSquared's post solution
...: %timeit pir3(a,d)
...: %timeit pir4(a,d)
...:
100 loops, best of 3: 2.43 ms per loop
100 loops, best of 3: 4.44 ms per loop
1000 loops, best of 3: 1.66 ms per loop