class Foo(object): def __radd__(self,lhs): return 0 __array_priority__ = 100a = np.random.random((100,100))b = Foo()a + b # calls b.__radd__(a) -> 0
然而,同样的事情似乎不适用于比较运算符.例如,如果我将以下行添加到Foo,那么它永远不会从表达式a< b:
def __rlt__(self,lhs): return 0
我意识到__rlt__并不是真正的Python特殊名称,但我认为它可能有用.我尝试了所有的__lt __,__ le __,__ eq __,__ ne__,__ ge __,__ gt__,有或没有前面的r,加上__cmp__,但我永远无法让NumPy调用它们中的任何一个.
这些比较可以被覆盖吗?
UPDATE
为了避免混淆,这里有一个更长的描述NumPy的行为.对于初学者来说,这是NumPy指南中所说的内容:
If the ufunc has 2 inputs and 1 output and the second input is an Object arraythen a special-case check is performed so that NotImplemented is returned if thesecond input is not an ndarray,has the array priority attribute,and has anr<op> special method.
我认为这是制定工作的规则.这是一个例子:
import numpy as npa = np.random.random((2,2))class bar0(object): def __add__(self,rhs): return 0 def __radd__(self,rhs): return 1b = bar0()print a + b # Calls __radd__ four times,returns an array# [[1 1]# [1 1]]class bar1(object): def __add__(self,rhs): return 1 __array_priority__ = 100b = bar1()print a + b # Calls __radd__ once,returns 1# 1
如您所见,在没有__array_priority__的情况下,NumPy将用户定义的对象解释为标量类型,并在数组中的每个位置应用该 *** 作.那不是我想要的.我的类型是数组(但不应该从ndarray派生).
这是一个较长的示例,显示了在定义所有比较方法时如何失败:
class Foo(object): def __cmp__(self,rhs): return 0 def __lt__(self,rhs): return 1 def __le__(self,rhs): return 2 def __eq__(self,rhs): return 3 def __ne__(self,rhs): return 4 def __gt__(self,rhs): return 5 def __ge__(self,rhs): return 6 __array_priority__ = 100b = Foo()print a < b # Calls __cmp__ four times,returns an array# [[False False]# [False False]]解决方法 看起来我自己可以回答这个问题. np.set_numeric_ops可以使用如下:
class Foo(object): def __lt__(self,rhs): return 0 def __le__(self,rhs): return 1 def __eq__(self,rhs): return 2 def __ne__(self,rhs): return 3 def __gt__(self,rhs): return 4 def __ge__(self,rhs): return 5 __array_priority__ = 100def overrIDe(name): def ufunc(x,y): if isinstance(y,Foo): return NotImplemented return np.getattr(name)(x,y) return ufuncnp.set_numeric_ops( ** { ufunc : overrIDe(ufunc) for ufunc in ( "less","less_equal","equal","not_equal","greater_equal","greater" ) } )a = np.random.random((2,2))b = Foo()print a < b# 4总结
以上是内存溢出为你收集整理的python – 如何覆盖NumPy的ndarray和我的类型之间的比较?全部内容,希望文章能够帮你解决python – 如何覆盖NumPy的ndarray和我的类型之间的比较?所遇到的程序开发问题。
如果觉得内存溢出网站内容还不错,欢迎将内存溢出网站推荐给程序员好友。
欢迎分享,转载请注明来源:内存溢出
评论列表(0条)