diff --git a/CHANGELOG.md b/CHANGELOG.md index 3de8e23c9b6719ec1c7de822fa95a7dbd16dc773..ef13c26c6f4df724c2e0e77b4f3fb60689684c79 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # v2.7.16 -* _TBA_ +* extended `Proxy` # v2.7.15 diff --git a/satella/__init__.py b/satella/__init__.py index 23e24119ad11437a6b6de387f6bc584eda8e1a80..61ad89d8930c52561233eee1a38f3233119771c4 100644 --- a/satella/__init__.py +++ b/satella/__init__.py @@ -1 +1 @@ -__version__ = '2.7.16_a2' +__version__ = '2.7.16_a3' diff --git a/satella/coding/structures/proxy.py b/satella/coding/structures/proxy.py index 9ff50c61058893350ea28d494edc4969388d6d77..6408e8bcea2de7201c6a5a6565add2e31ed69699 100644 --- a/satella/coding/structures/proxy.py +++ b/satella/coding/structures/proxy.py @@ -14,11 +14,20 @@ class Proxy(tp.Generic[T]): Note that in-place operations will return the Proxy itself, whereas simple addition will shed this proxy, returning object wrapped plus something. + + :param object_to_wrap: object to wrap + :param wrap_operations: whether results of operations returning something else should be + also proxied. This will be done by the following code: + >>> a = a.__add__(b) + >>> return self.__class__(a) + Wrapped operations include ONLY add, sub, mul, all kinds of div. + If you want logical operations wrapped, file an issue. """ - __slots__ = ('__obj',) + __slots__ = ('__obj', '__wrap_operations') - def __init__(self, object_to_wrap: T): + def __init__(self, object_to_wrap: T, wrap_operations: bool = False): self.__obj = object_to_wrap # type: T + self.__wrap_operations = wrap_operations def __call__(self, *args, **kwargs): return self.__obj(*args, **kwargs) @@ -33,7 +42,7 @@ class Proxy(tp.Generic[T]): del self.__obj[key] def __setattr__(self, key, value): - if key in ('_Proxy__obj', ): + if key in ('_Proxy__obj', '_Proxy__wrap_operations'): super().__setattr__(key, value) else: setattr(self.__obj, key, value) @@ -57,30 +66,48 @@ class Proxy(tp.Generic[T]): return str(self.__obj) def __add__(self, other): - return self.__obj + other + result = self.__obj + other + if self.__wrap_operations: + result = self.__class__(result) + return result def __iadd__(self, other): self.__obj += other return self def __sub__(self, other): - return self.__obj - other + result = self.__obj - other + if self.__wrap_operations: + result = self.__class__(result) + return result def __isub__(self, other): self.__obj -= other return self def __mul__(self, other): - return self.__obj * other + result = self.__obj * other + if self.__wrap_operations: + result = self.__class__(result) + return result def __divmod__(self, other): - return divmod(self.__obj, other) + result = divmod(self.__obj, other) + if self.__wrap_operations: + result = self.__class__(result) + return result def __floordiv__(self, other): - return self.__obj // other + result = self.__obj // other + if self.__wrap_operations: + result = self.__class__(result) + return result def __truediv__(self, other): - return self.__obj / other + result = self.__obj / other + if self.__wrap_operations: + result = self.__class__(result) + return result def __imul__(self, other): self.__obj * other diff --git a/tests/test_coding/test_structures.py b/tests/test_coding/test_structures.py index cb83fb6b9061aef50a070f508d9fac28aa684779..01c6b124cf18c9567b4b11ae94d7e2cf280703f2 100644 --- a/tests/test_coding/test_structures.py +++ b/tests/test_coding/test_structures.py @@ -9,12 +9,19 @@ import mock from satella.coding.structures import TimeBasedHeap, Heap, typednamedtuple, \ OmniHashableMixin, DictObject, apply_dict_object, Immutable, frozendict, SetHeap, \ DictionaryView, HashableWrapper, TwoWayDictionary, Ranking, SortedList, SliceableDeque, \ - DirtyDict, KeyAwareDefaultDict + DirtyDict, KeyAwareDefaultDict, Proxy logger = logging.getLogger(__name__) class TestMisc(unittest.TestCase): + def test_proxy(self): + a = Proxy(5, wrap_operations=True) + self.assertIsInstance(a+5, Proxy) + + a = Proxy(5) + self.assertNotIsInstance(a+5, Proxy) + def test_key_aware_defaultdict(self): a = KeyAwareDefaultDict(int) self.assertEqual(a['1'], 1)