如何覆盖NumPy的ndarray和我的类型之间的比较? [英] How can I override comparisons between NumPy's ndarray and my type?

查看:140
本文介绍了如何覆盖NumPy的ndarray和我的类型之间的比较?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

在NumPy中,可以使用__array_priority__属性来控制作用于ndarray和用户定义类型的二进制运算符.例如:

In NumPy, it is possible to use the __array_priority__ attribute to take control of binary operators acting on an ndarray and a user-defined type. For instance:

class Foo(object):
  def __radd__(self, lhs): return 0
  __array_priority__ = 100

a = np.random.random((100,100))
b = Foo()
a + b # calls b.__radd__(a) -> 0

但是,对于比较运算符来说,这似乎并不起作用.例如,如果我将以下行添加到Foo,则永远不会从表达式a < b中调用它:

The same thing, however, doesn't appear to work for comparison operators. For instance, if I add the following line to Foo, then it is never called from the expression a < b:

def __rlt__(self, lhs): return 0

我意识到__rlt__并不是一个真正的Python特殊名称,但是我认为它可能有用.我尝试了所有__lt____le____eq____ne____ge____gt__,并且也没有前面的r__cmp__,但是我无法将NumPy设置为打电话给他们中的任何一个.

I realize that __rlt__ is not really a Python special name, but I thought it might work. I tried all of __lt__, __le__, __eq__, __ne__, __ge__, __gt__ with and without a preceding r, plus __cmp__, too, but I could never get NumPy to call any of them.

这些比较可以覆盖吗?

为避免混淆,这是对NumPy行为的详细描述.首先,这是《 NumPy指南》书中所说的内容:

To avoid confusion, here is a longer description NumPy's behavior. For starters, here's what the Guide to NumPy book says:

If the ufunc has 2 inputs and 1 output and the second input is an Object array
then a special-case check is performed so that NotImplemented is returned if the
second input is not an ndarray, has the array priority attribute, and has an
r<op> special method.

我认为这是使+起作用的规则.这是一个示例:

I think this is the rule that makes + work. Here's an example:

import numpy as np
a = np.random.random((2,2))


class Bar0(object):
  def __add__(self, rhs): return 0
  def __radd__(self, rhs): return 1

b = Bar0()
print a + b # Calls __radd__ four times, returns an array
# [[1 1]
#  [1 1]]



class Bar1(object):
  def __add__(self, rhs): return 0
  def __radd__(self, rhs): return 1
  __array_priority__ = 100

b = Bar1()
print a + b # Calls __radd__ once, returns 1
# 1

如您所见,

在没有__array_priority__的情况下,NumPy会将用户定义的对象解释为标量类型,并将该操作应用于数组中的每个位置.那不是我想要的我的类型类似于数组(但不应从ndarray派生).

As you can see, without __array_priority__, NumPy interprets the user-defined object as a scalar type, and applies the operation at every position in the array. That's not what I want. My type is array-like (but should not be derived from ndarray).

下面是一个更长的示例,显示了在定义所有比较方法后如何失败:

Here's a longer example showing how this fails when all of the comparison methods are defined:

class Foo(object):
  def __cmp__(self, rhs): return 0
  def __lt__(self, rhs): return 1
  def __le__(self, rhs): return 2
  def __eq__(self, rhs): return 3
  def __ne__(self, rhs): return 4
  def __gt__(self, rhs): return 5
  def __ge__(self, rhs): return 6
  __array_priority__ = 100

b = Foo()
print a < b # Calls __cmp__ four times, returns an array
# [[False False]
#  [False False]]

推荐答案

看来我可以自己回答这个问题. np.set_numeric_ops可以按如下方式使用:

It looks like I can answer this myself. np.set_numeric_ops can be used as follows:

class Foo(object):
  def __lt__(self, rhs): return 0
  def __le__(self, rhs): return 1
  def __eq__(self, rhs): return 2
  def __ne__(self, rhs): return 3
  def __gt__(self, rhs): return 4
  def __ge__(self, rhs): return 5
  __array_priority__ = 100

def override(name):
  def ufunc(x,y):
    if isinstance(y,Foo): return NotImplemented
    return np.getattr(name)(x,y)
  return ufunc

np.set_numeric_ops(
    ** {
        ufunc : override(ufunc) for ufunc in (
            "less", "less_equal", "equal", "not_equal", "greater_equal"
          , "greater"
          )
    }
  )

a = np.random.random((2,2))
b = Foo()
print a < b
# 4

这篇关于如何覆盖NumPy的ndarray和我的类型之间的比较?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆