Subclass of numpy ndarray doesn't work as expected


Problem description

Hello, everyone.

I found some strange behavior when subclassing ndarray.

import numpy as np

class fooarray(np.ndarray):
    def __new__(cls, input_array, *args, **kwargs):
        obj = np.asarray(input_array).view(cls)
        return obj

    def __init__(self, *args, **kwargs):
        return

    def __array_finalize__(self, obj):
        return

a = fooarray(np.random.randn(3, 5))
b = np.random.randn(3, 5)

a_sum = np.sum(a, axis=0, keepdims=True)
b_sum = np.sum(b, axis=0, keepdims=True)

print(a_sum.ndim)  # 1
print(b_sum.ndim)  # 2

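As you can see, the keepdims argument doesn't work for my subclass fooarray: the result loses one of its axes. How can I avoid this problem? Or, more generally, how can I subclass numpy ndarray correctly?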

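Solution

np.sum accepts a variety of objects as input: not only ndarrays, but also lists, generators, and np.matrix instances, for example. The keepdims parameter obviously makes no sense for lists or generators. It is not appropriate for np.matrix instances either, since an np.matrix always has 2 dimensions. If you look at the call signature of np.matrix.sum, you see that its sum method has no keepdims parameter: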

Definition: np.matrix.sum(self, axis=None, dtype=None, out=None)

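So some subclasses of ndarray may have sum methods that do not take a keepdims parameter. This is an unfortunate violation of the Liskov substitution principle, and it is the origin of the pitfall you encountered.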
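The point about np.matrix.sum can be checked directly. The following is a minimal sketch, assuming a NumPy installation that still ships np.matrix; the variable names (matrix_sum_params, col_sums, keepdims_accepted) are made up for illustration.

```python
import inspect
import warnings

import numpy as np

# Collect the parameter names that np.matrix.sum accepts.
matrix_sum_params = list(inspect.signature(np.matrix.sum).parameters)
print(matrix_sum_params)
print("keepdims" in matrix_sum_params)

with warnings.catch_warnings():
    # Constructing an np.matrix may emit a PendingDeprecationWarning.
    warnings.simplefilter("ignore")
    m = np.matrix([[1, 2], [3, 4]])

# A reduction over one axis of a matrix still yields a 2-dimensional result.
col_sums = m.sum(axis=0)
print(col_sums.shape)

# Passing keepdims directly to the method is rejected with a TypeError.
try:
    m.sum(axis=0, keepdims=True)
    keepdims_accepted = True
except TypeError:
    keepdims_accepted = False
print(keepdims_accepted)
```

Because a matrix result is always 2-dimensional, the method has no need for keepdims, which is why a generic caller cannot safely forward that argument to every ndarray subclass.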

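Now, if you look at the source code for np.sum, you see that it is a delegating function that decides what to do based on the type of its first argument.

If the type of the first argument is not ndarray, it drops the keepdims parameter. It does this because passing keepdims to np.matrix.sum would raise an exception.

So, because np.sum tries to delegate in the most general way and makes no assumption about which arguments a subclass of ndarray may accept, it drops the keepdims parameter when it is passed a fooarray.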
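The delegation described above can be sketched roughly as follows. This is a hypothetical re-implementation for illustration only, not NumPy's actual source: the function name delegating_sum and the class Foo are made up, and the real dispatch logic may differ between NumPy versions.

```python
import numpy as np


class Foo(np.ndarray):
    """Hypothetical stand-in for the fooarray subclass."""


def delegating_sum(a, axis=None, dtype=None, out=None, keepdims=False):
    """Illustrative sketch of the dispatch described in the text."""
    if type(a) is not np.ndarray:
        try:
            method = a.sum
        except AttributeError:
            # Lists and similar inputs: convert to an array and reduce.
            return np.add.reduce(np.asarray(a), axis=axis, dtype=dtype,
                                 out=out, keepdims=keepdims)
        # Subclasses and other objects with their own sum method:
        # keepdims is NOT forwarded, since the method may not accept it.
        return method(axis=axis, dtype=dtype, out=out)
    # Exact ndarray: keepdims is honoured.
    return np.add.reduce(a, axis=axis, dtype=dtype, out=out,
                         keepdims=keepdims)


plain = np.ones((3, 5))
sub = plain.view(Foo)

plain_result = delegating_sum(plain, axis=0, keepdims=True)
sub_result = delegating_sum(sub, axis=0, keepdims=True)

print(plain_result.ndim)  # 2: keepdims applied
print(sub_result.ndim)    # 1: keepdims silently dropped
```

Under this sketch the subclass loses an axis for the same reason the question reports: the call is routed to the object's own sum method without keepdims.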

The workaround is to call a.sum instead of np.sum. This is more direct anyway, since np.sum is merely a delegating function.

import numpy as np


class fooarray(np.ndarray):
    def __new__(cls, input_array, *args, **kwargs):
        obj = np.asarray(input_array, *args, **kwargs).view(cls)
        return obj

a = fooarray(np.random.randn(3, 5))
b = np.random.randn(3, 5)

a_sum = a.sum(axis=0, keepdims=True)
b_sum = np.sum(b, axis=0, keepdims=True)

print(a_sum.ndim)  # 2
print(b_sum.ndim)  # 2
