Fast numpy roll


Problem description

I have a 2D numpy array and I want to roll each row incrementally, so that row i is rolled by i positions. I am currently calling np.roll inside a for loop, but because this runs thousands of times, my code is really slow. How can I make it faster?

My input looks like

array([[4,1],
       [0,2]])

My output looks like

array([[4,1],
       [2,0]])

Here the zeroth row [4,1] is shifted by 0 and the first row [0,2] is shifted by 1. Similarly, a second row would be shifted by 2, and so on.

Edit

temp = np.zeros([dd,dd])
for i in range(min(t + 1, dd)):
    temp[i,:] = np.roll(y[i,:], i, axis=0)

Recommended answer

Here's one vectorized solution:

m,n = a.shape
idx = np.mod((n-1)*np.arange(m)[:,None] + np.arange(n), n)
out = a[np.arange(m)[:,None], idx]

Sample input and output:

In [256]: a
Out[256]: 
array([[73, 55, 79, 52, 15],
       [45, 11, 19, 93, 12],
       [78, 50, 30, 88, 53],
       [98, 13, 58, 34, 35]])

In [257]: out
Out[257]: 
array([[73, 55, 79, 52, 15],
       [12, 45, 11, 19, 93],
       [88, 53, 78, 50, 30],
       [58, 34, 35, 98, 13]])
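
As a sanity check, the vectorized indexing can be compared against the plain np.roll loop from the question. The sketch below reuses the sample array above; the names `rows` and `expected` are introduced here for illustration only. It also checks that `(n-1)*i + j` and `j - i` agree modulo `n`, which is why the index formula rolls row `i` to the right by `i`.

```python
import numpy as np

a = np.array([[73, 55, 79, 52, 15],
              [45, 11, 19, 93, 12],
              [78, 50, 30, 88, 53],
              [98, 13, 58, 34, 35]])
m, n = a.shape

# Reference result: roll row i to the right by i positions, one row at a time.
expected = np.stack([np.roll(a[i], i) for i in range(m)])

# Vectorized version: column index for output position (i, j).
rows = np.arange(m)[:, None]                      # shape (m, 1), broadcasts over columns
idx = np.mod((n - 1) * rows + np.arange(n), n)    # shape (m, n)
out = a[rows, idx]

# (n-1)*i is congruent to -i modulo n, so idx equals (j - i) mod n.
same_formula = np.array_equal(idx, (np.arange(n) - rows) % n)
matches_loop = np.array_equal(out, expected)
print(same_formula, matches_loop)
```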

Since you mentioned that you call this rolling routine many times, create the indexing array idx once and reuse it on later calls. It depends only on the shape of the array, not on its contents.
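
One way to arrange that reuse is sketched below, assuming every array passed in has the same shape. The helper name `make_roll_index` and the random test data are hypothetical and not part of the original answer.

```python
import numpy as np

def make_roll_index(m, n):
    """Build the row and column index arrays once for an (m, n) array.

    Hypothetical helper: the result can be reused for any array of this shape.
    """
    rows = np.arange(m)[:, None]
    idx = np.mod((n - 1) * rows + np.arange(n), n)
    return rows, idx

# Build the indices a single time...
rows, idx = make_roll_index(3, 4)

# ...then apply them to many arrays of the same shape.
rng = np.random.default_rng(0)
batches = [rng.integers(0, 100, size=(3, 4)) for _ in range(5)]
rolled = [b[rows, idx] for b in batches]
```

Only the cheap indexing step `b[rows, idx]` runs per call; the index arithmetic is paid once.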

Further improvement

For repeated use, you are better off creating the full linear (flattened) indices once and then using np.take to extract the rolled elements, like so:

full_idx = idx + n*np.arange(m)[:,None]
out = np.take(a,full_idx)
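
A small self-contained check of the linear-index variant follows, using a made-up 3x4 array. With no `axis` argument, np.take indexes into the flattened array, so element (i, j) lives at position `n*i + j`; adding `n*i` to each row of `idx` turns column indices into flat indices.

```python
import numpy as np

a = np.arange(12).reshape(3, 4)   # illustrative data, not from the original post
m, n = a.shape

rows = np.arange(m)[:, None]
idx = np.mod((n - 1) * rows + np.arange(n), n)

# Convert per-row column indices into indices of the flattened array.
full_idx = idx + n * rows

out_take = np.take(a, full_idx)   # flat lookup; result has full_idx's shape
out_fancy = a[rows, idx]          # two-array fancy indexing, for comparison
print(np.array_equal(out_take, out_fancy))
```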

Let's see how much of an improvement that is:

In [330]: a = np.random.randint(11,99,(600,600))

In [331]: m,n = a.shape
     ...: idx = np.mod((n-1)*np.arange(m)[:,None] + np.arange(n), n)
     ...: 

In [332]: full_idx = idx + n*np.arange(m)[:,None]

In [333]: %timeit a[np.arange(m)[:,None], idx] # Approach #1
1000 loops, best of 3: 1.42 ms per loop

In [334]: %timeit np.take(a,full_idx)          # Improvement
1000 loops, best of 3: 486 µs per loop

That's around a 3x improvement!
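
Outside IPython, a comparable measurement can be sketched with the standard timeit module. The function `loop_roll` stands in for the original per-row loop and is an assumption about how the baseline would be written. Absolute timings and the exact ratio will vary with machine and NumPy version, so no numbers are claimed here.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
a = rng.integers(11, 99, size=(600, 600))
m, n = a.shape

# Precompute both index arrays once, outside the timed code.
rows = np.arange(m)[:, None]
idx = np.mod((n - 1) * rows + np.arange(n), n)
full_idx = idx + n * rows

def loop_roll():
    # Baseline: one np.roll call per row.
    out = np.empty_like(a)
    for i in range(m):
        out[i] = np.roll(a[i], i)
    return out

repeats = 20
t_loop = timeit.timeit(loop_roll, number=repeats)
t_fancy = timeit.timeit(lambda: a[rows, idx], number=repeats)
t_take = timeit.timeit(lambda: np.take(a, full_idx), number=repeats)

for name, t in [("np.roll loop", t_loop), ("fancy indexing", t_fancy), ("np.take", t_take)]:
    print(f"{name:>15}: {1e3 * t / repeats:.3f} ms per call")
```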
