diff --git a/satella/coding/structures/mixins/hashable.py b/satella/coding/structures/mixins/hashable.py index 5ae8fc2a888b724014c7c8a104bee219cdf694c9..ff1c4106604fa9659b8455b9e212e6c1471f0dd7 100644 --- a/satella/coding/structures/mixins/hashable.py +++ b/satella/coding/structures/mixins/hashable.py @@ -1,3 +1,4 @@ +import operator import typing as tp from abc import ABCMeta, abstractmethod @@ -69,7 +70,7 @@ class ComparableAndHashableBy(metaclass=ABCMeta): class ComparableAndHashableByInt(metaclass=ABCMeta): """ - A mix-in. Provides comparision (lt, gt, ge, le, eq) and hashing by __int__ of this class. + A mix-in. Provides comparison (lt, gt, ge, le, eq) and hashing by __int__ of this class. """ __slots__ = () @@ -180,40 +181,27 @@ class OmniHashableMixin(metaclass=ABCMeta): """ Note that this will only compare _HASH_FIELDS_TO_USE """ - if not isinstance(other, type(self)): - return False - - if not isinstance(other, OmniHashableMixin): - return super().__eq__(other) - - cmpr_by = self._HASH_FIELDS_TO_USE - try: - if isinstance(cmpr_by, str): - return getattr(self, cmpr_by) == getattr(other, cmpr_by) - - for field_name in self._HASH_FIELDS_TO_USE: - if getattr(self, field_name) != getattr(other, field_name): - return False - return True - except AttributeError: - return False + return _generic_eq(self, other, False, operator.eq, 'eq', ) def __ne__(self, other) -> bool: - if not isinstance(other, type(self)): - return True + return _generic_eq(self, other, True, operator.ne, 'ne') - if not isinstance(other, OmniHashableMixin): - return super().__ne__(other) - cmpr_by = self._HASH_FIELDS_TO_USE +def _generic_eq(self, other, truth, comparator, name): + if not isinstance(other, type(self)): + return truth + + if not isinstance(other, OmniHashableMixin): + return comparator(self, other) - try: - if isinstance(cmpr_by, str): - return getattr(self, cmpr_by) != getattr(other, cmpr_by) + cmpr_by = self._HASH_FIELDS_TO_USE + try: + if isinstance(cmpr_by, str): + return comparator(getattr(self, cmpr_by), getattr(other, cmpr_by)) - for field_name in cmpr_by: - if getattr(self, field_name) != getattr(other, field_name): - return True - return False - except AttributeError: - return True + for field_name in cmpr_by: + if getattr(self, field_name) != getattr(other, field_name): + return truth + return not truth + except AttributeError: + return truth diff --git a/tests/test_coding/test_structures.py b/tests/test_coding/test_structures.py index f796e3692bfbe3166dd3d33f130a4c308558b27c..0fa0be560e1018474330cd5310168949ad1f80b1 100644 --- a/tests/test_coding/test_structures.py +++ b/tests/test_coding/test_structures.py @@ -273,6 +273,9 @@ class TestStructures(unittest.TestCase): b = MyClass(1) c = MyClass(2) self.assertEqual(a, b) + self.assertEqual(a, 1) + self.assertGreater(a, 0) + self.assertLess(c, 1) self.assertNotEqual(a, c) self.assertLess(a, c) self.assertGreaterEqual(a, b)