Friday, July 13, 2007

How to override comparison operators in Python

Python, like many languages, allows the behavior of operators to be
customized using a scheme based on the types of objects they are applied to.
The precise rules and intricacies of this customization are fairly involved,
though, and most people are unaware of their full scope. While it is sometimes
valuable to be able to control the behavior of an operator to the full extent
supported by Python, quite often the resulting complexity spills over into
simpler applications. This is visible as a general tendency on the part of
Python programmers to implement customizations which are correct for the
narrow case they have in mind at the moment, but are incorrect when
considered in a broader context. Since many parts of the runtime and standard
library rely on the behavior of these operators, this is somewhat more
egregious than a similar offense in an application-specific method, where
the author can simply claim that behavior beyond what was intended is
unsupported and undefined.



So, with my long-winded introduction out of the way, here are the basic
rules for the customization of ==, !=, <, >, <=, and >=:





  • For all six of the above operators, if the operation isn't handled by the
    corresponding method described below (because it isn't defined, or returns
    NotImplemented) and __cmp__ is defined on the left-hand argument,
    __cmp__ is called with the right-hand argument. A negative result
    indicates the LHS is less than the RHS, 0 indicates they are equal, and a
    positive result indicates the LHS is greater than the RHS.


  • For ==, if __eq__ is defined on the left-hand argument, it
    is called with the right-hand argument. A result of True indicates the
    objects are equal. A result of False indicates they are not equal. A result
    of NotImplemented indicates that the left-hand argument doesn't
    know how to test for equality with the given right-hand argument.
    __eq__ is not used for !=.


  • For !=, the special method __ne__ is used. The rules for
    its behavior are similar to those of __eq__, with the obvious
    adjustments.


  • For <, __lt__ is used. For >, __gt__.
    For <= and >=, __le__ and __ge__
    respectively.
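A quick way to see this mapping in action is a class that records which method Python calls for each operator. The following is a minimal sketch; Tracer is a hypothetical name, and every method simply logs itself and returns True so that no fallback is ever triggered. It should run unchanged on Python 2 or Python 3.

```python
class Tracer(object):
    """Hypothetical class: each rich comparison method records its own name."""

    def __init__(self):
        self.calls = []

    def __eq__(self, other):
        self.calls.append("__eq__")
        return True

    def __ne__(self, other):
        self.calls.append("__ne__")
        return True

    def __lt__(self, other):
        self.calls.append("__lt__")
        return True

    def __gt__(self, other):
        self.calls.append("__gt__")
        return True

    def __le__(self, other):
        self.calls.append("__le__")
        return True

    def __ge__(self, other):
        self.calls.append("__ge__")
        return True


t = Tracer()
# t is the left-hand operand and int is not a subclass of Tracer, so t's
# method is tried first; it never returns NotImplemented, so nothing else
# is consulted.
t == 1
t != 1
t < 1
t > 1
t <= 1
t >= 1
print(t.calls)  # ['__eq__', '__ne__', '__lt__', '__gt__', '__le__', '__ge__']
```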




So how should these be applied? This is best explained with an example.
While __cmp__ is often useful, I am going to ignore it for the
rest of this post, since it is easier to get right, particularly once
NotImplemented (which I will talk about) is understood.




class A(object):
    def __init__(self, foo):
        self.foo = foo
    def __eq__(self, other):
        if isinstance(other, A):
            return self.foo == other.foo
        return NotImplemented
    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result


That's it. (I'm not going to define the other four methods, which make
<, >, <=, and >= work; they follow basically the same rules as
__eq__ and __ne__.) Pretty straightforward,
but there are some points which are not always obvious:





  • __eq__ does an isinstance test on its argument. This lets
    it know if it is dealing with another object which is like itself. In the
    case of this example, I have implemented A to only know how to compare itself
    with other instances of A. If it is called with something which is not an A,
    it returns NotImplemented. I'll explain what the consequences
    of this are below.


  • __ne__ is also implemented, but only in terms of
    __eq__. If you implement __eq__ but not
    __ne__, then == and != will behave somewhat strangely, since the
    default implementation of __ne__ is based on identity, not the
    negation of equality. Quite often a class with only __eq__ will
    appear to work properly with !=, but it fails for various corner-cases (for
    example, an object which does not compare equal to itself, such as NaN).
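The NaN corner case can be sketched with a hypothetical Reading class that wraps a float and follows the same pattern as A. Because __ne__ is defined as the negation of __eq__, != stays consistent with == even for a value that is not equal to itself. Under the identity-based default described above (Python 2), leaving out __ne__ would make r != r evaluate to False, since r is r. Note that Python 3's default __ne__ already delegates to __eq__ and inverts the result, so there the explicit method is a harmless safety net.

```python
class Reading(object):
    """Hypothetical example: wraps a float and compares by value, like A."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        if isinstance(other, Reading):
            return self.value == other.value
        return NotImplemented

    def __ne__(self, other):
        # The exact negation of __eq__, passing NotImplemented through.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result


r = Reading(float("nan"))
print(r == r)  # False: nan == nan is False
print(r != r)  # True: consistent with ==
print(Reading(1.0) != Reading(2.0))  # True
```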




The major remaining point is NotImplemented: what is that
thing? NotImplemented signals to the runtime that it should ask
someone else to satisfy the operation. In the expression a == b,
if a.__eq__(b) returns NotImplemented, then Python
tries b.__eq__(a). If b knows enough to return True or False,
then the expression can succeed. If it doesn't, then the runtime will fall
back to the built-in behavior (which is based on identity for == and !=).
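Here is a minimal sketch of that fallback order, assuming Python 3 (Python 2's implementation may invoke the methods more than once along the way, so the log there can differ). Logged is a hypothetical class whose __eq__ records each call and always returns NotImplemented, so neither side can answer and the identity comparison decides.

```python
class Logged(object):
    """Hypothetical class whose __eq__ logs each call and always declines."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def __eq__(self, other):
        self.log.append("%s.__eq__(%s)" % (self.name, other.name))
        return NotImplemented


log = []
a = Logged("a", log)
b = Logged("b", log)

print(a == b)  # False: both sides declined, so identity was compared
print(log)     # ['a.__eq__(b)', 'b.__eq__(a)']
```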



Here's another class which customizes equality:




class B(object):
    def __init__(self, bar):
        self.bar = bar
    def __eq__(self, other):
        if isinstance(other, B):
            return self.bar == other.bar
        elif isinstance(other, A):
            return self.bar + 3 == other.foo
        else:
            return NotImplemented
    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result


Here we have a class which can compare its instances both to other instances
of B and to instances of A. Now, what would happen if we weren't careful
about returning NotImplemented at the right times?



One way it might go is...




>>> class A(object):
...     def __init__(self, foo):
...         self.foo = foo
...     def __eq__(self, other):
...         return self.foo == other.foo
...
>>> class B(object):
...     def __init__(self, bar):
...         self.bar = bar
...
>>> A(5) == B(6)
Traceback (most recent call last):
  File "<stdin>", line 1, in ?
  File "<stdin>", line 5, in __eq__
AttributeError: 'B' object has no attribute 'foo'
>>>


Another way it could go is...




>>> class A(object):
...     def __init__(self, foo):
...         self.foo = foo
...     def __eq__(self, other):
...         if isinstance(other, A):
...             return self.foo == other.foo
...
>>> class B(object):
...     def __init__(self, bar):
...         self.bar = bar
...     def __eq__(self, other):
...         if isinstance(other, A):
...             return self.bar + 3 == other.foo
...         else:
...             return self.bar == other.bar
...
>>> print A(3) == B(0)
None
>>> print B(0) == A(3)
True
>>>


That one's particularly nasty. ;) But here's what we get with correct
NotImplemented use:




>>> class A(object):
...     def __init__(self, foo):
...         self.foo = foo
...     def __eq__(self, other):
...         if isinstance(other, A):
...             return self.foo == other.foo
...         return NotImplemented
...
>>> class B(object):
...     def __init__(self, bar):
...         self.bar = bar
...     def __eq__(self, other):
...         if isinstance(other, A):
...             return self.bar + 3 == other.foo
...         elif isinstance(other, B):
...             return self.bar == other.bar
...         else:
...             return NotImplemented
...
>>> print A(3) == B(0)
True
>>> print B(0) == A(3)
True
>>>
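To catch the kind of asymmetry shown in the second transcript, one option is a small brute-force check over sample values. This is a minimal sketch, assuming Python 3 and a hypothetical helper named find_inconsistencies (not part of any library). It repeats the correct A and B from above, with the __ne__ methods recommended earlier, and reports any pair where == is not symmetric or != is not the negation of ==.

```python
import itertools


class A(object):
    def __init__(self, foo):
        self.foo = foo

    def __eq__(self, other):
        if isinstance(other, A):
            return self.foo == other.foo
        return NotImplemented

    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result


class B(object):
    def __init__(self, bar):
        self.bar = bar

    def __eq__(self, other):
        if isinstance(other, B):
            return self.bar == other.bar
        elif isinstance(other, A):
            return self.bar + 3 == other.foo
        else:
            return NotImplemented

    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result


def find_inconsistencies(samples):
    """Hypothetical helper: return every pair (x, y) from samples where
    == is not symmetric, or where != is not the negation of ==."""
    problems = []
    for x, y in itertools.product(samples, repeat=2):
        if (x == y) != (y == x):
            problems.append((x, y))
        elif (x != y) != (not (x == y)):
            problems.append((x, y))
    return problems


print(find_inconsistencies([A(3), A(4), B(0), B(1), "text", 3]))  # []
```

Running the same helper on just A(3) and B(0) built from the second transcript's classes should report both orderings of that pair, since A(3) == B(0) evaluates to None while B(0) == A(3) is True.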


Ahh, excellent. NotImplemented has uses for other operators in
Python as well. For example, if the + override, __add__, returns
it, then __radd__ is tried on the right-hand argument. These can
be useful as well, though equality and inequality are by far the more common
use cases.
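The same protocol for + can be sketched with a hypothetical Length class. Its __add__ accepts other Length objects and plain numbers and returns NotImplemented for anything else; __radd__ covers the case where the Length is on the right, such as 3 + Length(2), where int's __add__ returns NotImplemented first. Treating bare numbers as metres is just an assumption made for the example.

```python
class Length(object):
    """Hypothetical example: a length in metres."""

    def __init__(self, metres):
        self.metres = metres

    def __add__(self, other):
        if isinstance(other, Length):
            return Length(self.metres + other.metres)
        if isinstance(other, (int, float)):
            # Assumption for this sketch: bare numbers are taken as metres.
            return Length(self.metres + other)
        # Unknown type: let Python try the other operand's __radd__ instead.
        return NotImplemented

    def __radd__(self, other):
        # Called for `other + self` when other has no __add__ or its
        # __add__ returned NotImplemented. Addition is commutative here,
        # so reuse __add__.
        return self.__add__(other)

    def __repr__(self):
        return "Length(%r)" % (self.metres,)


print(Length(1) + Length(2))        # Length(3)
print(3 + Length(2))                # Length(5)
print(sum([Length(1), Length(2)]))  # Length(3)
```

One practical payoff is that sum() works: it starts from 0, so the first addition is 0 + Length(1), which only succeeds because of __radd__. Something like Length(1) + "x" still raises TypeError, since str has no __radd__ to fall back on.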



If you follow these examples, then in the general case you'll find yourself
with more consistently behaving objects. You may even want to implement a
mixin which provides the __ne__ implementation (and one of
__lt__ or __gt__), since it gets pretty boring typing
that out after a few times. ;)
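A minimal sketch of such a mixin follows. ComparisonMixin and Version are hypothetical names, and the mixin assumes the subclass supplies __eq__ and __lt__ (both returning NotImplemented for unknown types) and that the ordering is total. In later versions of Python (2.7 and 3.2+), the standard library's functools.total_ordering class decorator does a similar job, filling in the missing ordering methods from __eq__ and one of __lt__, __le__, __gt__ or __ge__.

```python
class ComparisonMixin(object):
    """Hypothetical mixin: derives !=, >, <= and >= from __eq__ and __lt__.

    Assumes the subclass defines __eq__ and __lt__, that both return
    NotImplemented for foreign types, and that the ordering is total.
    """

    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __gt__(self, other):
        # a > b exactly when neither a < b nor a == b.
        less = self.__lt__(other)
        if less is NotImplemented:
            return less
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not less and not equal

    def __le__(self, other):
        # a <= b exactly when a < b or a == b.
        less = self.__lt__(other)
        if less is NotImplemented:
            return less
        if less:
            return True
        return self.__eq__(other)

    def __ge__(self, other):
        # a >= b exactly when not a < b (total ordering assumed).
        less = self.__lt__(other)
        if less is NotImplemented:
            return less
        return not less


class Version(ComparisonMixin):
    """Hypothetical example class ordered by a single integer."""

    def __init__(self, number):
        self.number = number

    def __eq__(self, other):
        if isinstance(other, Version):
            return self.number == other.number
        return NotImplemented

    def __lt__(self, other):
        if isinstance(other, Version):
            return self.number < other.number
        return NotImplemented


print(Version(2) > Version(1))   # True
print(Version(1) <= Version(1))  # True
print(Version(1) != Version(2))  # True
```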



Of course, there are plenty of special cases where it makes sense to deviate
from this pattern. However, they are special. For most objects, this
is the behavior you want.



You can read about all the gory details of Python's operator overloading
system on the Python website:
http://docs.python.org/ref/specialnames.html

5 comments:

  1. Thanks. I've corrected these two errors in the post.

  2. Thanks for pointing this out. I almost made B a subclass of A in the original post, but I decided that would have made it too hard to follow. :) It's a good rule to know, though, since it means derived classes get the first chance to define how the operation is performed.

  3. NotImplemented is returned, not raised? Weird.

    Since the __ne__ implementation is always so stupid, why not define __cmp__ instead? I suppose "Called by comparison operations if rich comparison is not defined." means that it could fail if you unwittingly have the other operations defined somewhere else in your inheritance tree?

  4. __cmp__ is definitely much simpler. Pushing people to use it more frequently and avoid __xy__ might make another good post :) The one disadvantage of __cmp__ is that if you only care about == and !=, it's still very difficult to avoid implementing < and >. In fact, off the top of my head, I can't think of a __cmp__ implementation which would work for _only_ implementing == and !=. This might not be of much practical consequence (after all, object implements < and >, and it just does so in a pseudo-random way), but it's a bit annoying.

  5. NotImplementedError and NotImplemented are different things; the former is an exception, but the latter is not, so raising NotImplemented wouldn't make much sense ;)
