Skip to content
Snippets Groups Projects
Commit 87687b31 authored by Piotr Maślanka's avatar Piotr Maślanka
Browse files

improve OmniHashableMixin

parent ddd42940
No related branches found
No related tags found
No related merge requests found
import operator
import typing as tp
from abc import ABCMeta, abstractmethod
......@@ -69,7 +70,7 @@ class ComparableAndHashableBy(metaclass=ABCMeta):
class ComparableAndHashableByInt(metaclass=ABCMeta):
"""
A mix-in. Provides comparision (lt, gt, ge, le, eq) and hashing by __int__ of this class.
A mix-in. Provides comparison (lt, gt, ge, le, eq) and hashing by __int__ of this class.
"""
__slots__ = ()
......@@ -180,40 +181,27 @@ class OmniHashableMixin(metaclass=ABCMeta):
"""
Note that this will only compare _HASH_FIELDS_TO_USE
"""
if not isinstance(other, type(self)):
return False
if not isinstance(other, OmniHashableMixin):
return super().__eq__(other)
cmpr_by = self._HASH_FIELDS_TO_USE
try:
if isinstance(cmpr_by, str):
return getattr(self, cmpr_by) == getattr(other, cmpr_by)
for field_name in self._HASH_FIELDS_TO_USE:
if getattr(self, field_name) != getattr(other, field_name):
return False
return True
except AttributeError:
return False
return _generic_eq(self, other, False, operator.eq, 'eq', )
def __ne__(self, other) -> bool:
if not isinstance(other, type(self)):
return True
return _generic_eq(self, other, True, operator.ne, 'ne')
if not isinstance(other, OmniHashableMixin):
return super().__ne__(other)
cmpr_by = self._HASH_FIELDS_TO_USE
def _generic_eq(self, other, truth, comparator, name):
if not isinstance(other, type(self)):
return truth
if not isinstance(other, OmniHashableMixin):
return comparator(self, other)
try:
if isinstance(cmpr_by, str):
return getattr(self, cmpr_by) != getattr(other, cmpr_by)
cmpr_by = self._HASH_FIELDS_TO_USE
try:
if isinstance(cmpr_by, str):
return comparator(getattr(self, cmpr_by), getattr(other, cmpr_by))
for field_name in cmpr_by:
if getattr(self, field_name) != getattr(other, field_name):
return True
return False
except AttributeError:
return True
for field_name in cmpr_by:
if getattr(self, field_name) != getattr(other, field_name):
return truth
return not truth
except AttributeError:
return truth
......@@ -273,6 +273,9 @@ class TestStructures(unittest.TestCase):
b = MyClass(1)
c = MyClass(2)
self.assertEqual(a, b)
self.assertEqual(a, 1)
self.assertGreater(a, 0)
self.assertLess(c, 1)
self.assertNotEqual(a, c)
self.assertLess(a, c)
self.assertGreaterEqual(a, b)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment