剪切一个numpy数组 [英] Shear a numpy array

查看:124
本文介绍了剪切一个numpy数组的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我想剪切"一个numpy数组.我不确定我是否正确使用了剪切"一词;通过剪切,我的意思是:

I'd like to 'shear' a numpy array. I'm not sure I'm using the term 'shear' correctly; by shear, I mean something like:

将第一列移动0位
将第二列移动1个地方
将第三列移动2个地方
等等...

Shift the first column by 0 places
Shift the second column by 1 place
Shift the third colum by 2 places
etc...

所以这个数组:

array([[11, 12, 13],
       [17, 18, 19],
       [35, 36, 37]])

将变成这个数组:

array([[11, 36, 19],
       [17, 12, 37],
       [35, 18, 13]])

或类似此数组的内容:

array([[11,  0,  0],
       [17, 12,  0],
       [35, 18, 13]])

取决于我们处理边缘的方式.我对边缘行为不太在意.

depending on how we handle the edges. I'm not too particular about edge behavior.

这是我尝试执行此操作的功能:

Here's my attempt at a function that does this:

import numpy

def shear(a, strength=1, shift_axis=0, increase_axis=1, edges='clip'):
    strength = int(strength)
    shift_axis = int(shift_axis)
    increase_axis = int(increase_axis)
    if shift_axis == increase_axis:
        raise UserWarning("Shear can't shift in the direction it increases")
    temp = numpy.zeros(a.shape, dtype=int)
    indices = []
    for d, num in enumerate(a.shape):
        coords = numpy.arange(num)
        shape = [1] * len(a.shape)
        shape[d] = num
        coords = coords.reshape(shape) + temp
        indices.append(coords)
    indices[shift_axis] -= strength * indices[increase_axis]
    if edges == 'clip':
        indices[shift_axis][indices[shift_axis] < 0] = -1
        indices[shift_axis][indices[shift_axis] >= a.shape[shift_axis]] = -1
        res = a[indices]
        res[indices[shift_axis] == -1] = 0
    elif edges == 'roll':
        indices[shift_axis] %= a.shape[shift_axis]
        res = a[indices]
    return res

if __name__ == '__main__':
    a = numpy.random.random((3,4))
    print a
    print shear(a)

似乎可行.请告诉我是否可以!

It seems to work. Please tell me if it doesn't!

它似乎也笨拙且不雅.我是否忽略了执行此操作的内置numpy/scipy函数?有没有更干净/更好/更有效的方式来做到这一点?我是在重新发明轮子吗?

It also seems clunky and inelegant. Am I overlooking a builtin numpy/scipy function that does this? Is there a cleaner/better/more efficient way to do this in numpy? Am I reinventing the wheel?


如果这适用于N维数组,而不仅仅是2D情况,则有加分.


Bonus points if this works on an N-dimensional array, instead of just the 2D case.

此功能将位于循环的最中心,我将在数据处理中重复多次,因此我怀疑它实际上值得优化.

This function will be at the very center of a loop I'll repeat many times in our data processing, so I suspect it's actually worth optimizing.

第二 我终于做了一些基准测试.尽管有循环,但看起来numpy.roll是要走的路.谢谢,tom10和Sven Marnach!

SECOND I finally did some benchmarking. It looks like numpy.roll is the way to go, despite the loop. Thanks, tom10 and Sven Marnach!

基准测试代码:(在Windows上运行,我认为在Linux上不要使用time.clock)

Benchmarking code: (run on Windows, don't use time.clock on Linux I think)

import time, numpy

def shear_1(a, strength=1, shift_axis=0, increase_axis=1, edges='roll'):
    strength = int(strength)
    shift_axis = int(shift_axis)
    increase_axis = int(increase_axis)
    if shift_axis == increase_axis:
        raise UserWarning("Shear can't shift in the direction it increases")
    temp = numpy.zeros(a.shape, dtype=int)
    indices = []
    for d, num in enumerate(a.shape):
        coords = numpy.arange(num)
        shape = [1] * len(a.shape)
        shape[d] = num
        coords = coords.reshape(shape) + temp
        indices.append(coords)
    indices[shift_axis] -= strength * indices[increase_axis]
    if edges == 'clip':
        indices[shift_axis][indices[shift_axis] < 0] = -1
        indices[shift_axis][indices[shift_axis] >= a.shape[shift_axis]] = -1
        res = a[indices]
        res[indices[shift_axis] == -1] = 0
    elif edges == 'roll':
        indices[shift_axis] %= a.shape[shift_axis]
        res = a[indices]
    return res

def shear_2(a, strength=1, shift_axis=0, increase_axis=1, edges='roll'):
    indices = numpy.indices(a.shape)
    indices[shift_axis] -= strength * indices[increase_axis]
    indices[shift_axis] %= a.shape[shift_axis]
    res = a[tuple(indices)]
    if edges == 'clip':
        res[indices[shift_axis] < 0] = 0
        res[indices[shift_axis] >= a.shape[shift_axis]] = 0
    return res

def shear_3(a, strength=1, shift_axis=0, increase_axis=1):
    if shift_axis > increase_axis:
        shift_axis -= 1
    res = numpy.empty_like(a)
    index = numpy.index_exp[:] * increase_axis
    roll = numpy.roll
    for i in range(0, a.shape[increase_axis]):
        index_i = index + (i,)
        res[index_i] = roll(a[index_i], i * strength, shift_axis)
    return res

numpy.random.seed(0)
for a in (
    numpy.random.random((3, 3, 3, 3)),
    numpy.random.random((50, 50, 50, 50)),
    numpy.random.random((300, 300, 10, 10)),
    ):
    print 'Array dimensions:', a.shape
    for sa, ia in ((0, 1), (1, 0), (2, 3), (0, 3)):
        print 'Shift axis:', sa
        print 'Increase axis:', ia
        ref = shear_1(a, shift_axis=sa, increase_axis=ia)
        for shear, label in ((shear_1, '1'), (shear_2, '2'), (shear_3, '3')):
            start = time.clock()
            b = shear(a, shift_axis=sa, increase_axis=ia)
            end = time.clock()
            print label + ': %0.6f seconds'%(end-start)
            if (b - ref).max() > 1e-9:
                print "Something's wrong."
        print

推荐答案

tom10的答案可以扩展为任意尺寸:

The approach in tom10's answer can be extended to arbitrary dimensions:

def shear3(a, strength=1, shift_axis=0, increase_axis=1):
    if shift_axis > increase_axis:
        shift_axis -= 1
    res = numpy.empty_like(a)
    index = numpy.index_exp[:] * increase_axis
    roll = numpy.roll
    for i in range(0, a.shape[increase_axis]):
        index_i = index + (i,)
        res[index_i] = roll(a[index_i], -i * strength, shift_axis)
    return res

这篇关于剪切一个numpy数组的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆