From 06493193827790e35c062a3dff14c35f1d7ce4f7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20Ma=C5=9Blanka?= <piotr.maslanka@henrietta.com.pl>
Date: Tue, 28 Apr 2020 15:32:29 +0200
Subject: [PATCH] extended Proxy

---
 CHANGELOG.md                         |  2 +-
 satella/__init__.py                  |  2 +-
 satella/coding/structures/proxy.py   | 45 ++++++++++++++++++++++------
 tests/test_coding/test_structures.py |  9 +++++-
 4 files changed, 46 insertions(+), 12 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3de8e23c..ef13c26c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,6 @@
 # v2.7.16
 
-* _TBA_
+* extended `Proxy`
 
 # v2.7.15
 
diff --git a/satella/__init__.py b/satella/__init__.py
index 23e24119..61ad89d8 100644
--- a/satella/__init__.py
+++ b/satella/__init__.py
@@ -1 +1 @@
-__version__ = '2.7.16_a2'
+__version__ = '2.7.16_a3'
diff --git a/satella/coding/structures/proxy.py b/satella/coding/structures/proxy.py
index 9ff50c61..6408e8bc 100644
--- a/satella/coding/structures/proxy.py
+++ b/satella/coding/structures/proxy.py
@@ -14,11 +14,20 @@ class Proxy(tp.Generic[T]):
 
     Note that in-place operations will return the Proxy itself, whereas simple addition will shed
     this proxy, returning object wrapped plus something.
+
+    :param object_to_wrap: object to wrap
+    :param wrap_operations: whether results of operations returning something else should be
+        also proxied. This will be done by the following code:
+        >>> a = a.__add__(b)
+        >>> return self.__class__(a)
+        Wrapped operations include ONLY add, sub, mul, all kinds of div.
+        If you want logical operations wrapped, file an issue.
     """
-    __slots__ = ('__obj',)
+    __slots__ = ('__obj', '__wrap_operations')
 
-    def __init__(self, object_to_wrap: T):
+    def __init__(self, object_to_wrap: T, wrap_operations: bool = False):
         self.__obj = object_to_wrap  # type: T
+        self.__wrap_operations = wrap_operations
 
     def __call__(self, *args, **kwargs):
         return self.__obj(*args, **kwargs)
@@ -33,7 +42,7 @@ class Proxy(tp.Generic[T]):
         del self.__obj[key]
 
     def __setattr__(self, key, value):
-        if key in ('_Proxy__obj', ):
+        if key in ('_Proxy__obj', '_Proxy__wrap_operations'):
             super().__setattr__(key, value)
         else:
             setattr(self.__obj, key, value)
@@ -57,30 +66,48 @@ class Proxy(tp.Generic[T]):
         return str(self.__obj)
 
     def __add__(self, other):
-        return self.__obj + other
+        result = self.__obj + other
+        if self.__wrap_operations:
+            result = self.__class__(result)
+        return result
 
     def __iadd__(self, other):
         self.__obj += other
         return self
 
     def __sub__(self, other):
-        return self.__obj - other
+        result = self.__obj - other
+        if self.__wrap_operations:
+            result = self.__class__(result)
+        return result
 
     def __isub__(self, other):
         self.__obj -= other
         return self
 
     def __mul__(self, other):
-        return self.__obj * other
+        result = self.__obj * other
+        if self.__wrap_operations:
+            result = self.__class__(result)
+        return result
 
     def __divmod__(self, other):
-        return divmod(self.__obj, other)
+        result = divmod(self.__obj, other)
+        if self.__wrap_operations:
+            result = self.__class__(result)
+        return result
 
     def __floordiv__(self, other):
-        return self.__obj // other
+        result = self.__obj // other
+        if self.__wrap_operations:
+            result = self.__class__(result)
+        return result
 
     def __truediv__(self, other):
-        return self.__obj / other
+        result = self.__obj / other
+        if self.__wrap_operations:
+            result = self.__class__(result)
+        return result
 
     def __imul__(self, other):
         self.__obj * other
diff --git a/tests/test_coding/test_structures.py b/tests/test_coding/test_structures.py
index cb83fb6b..01c6b124 100644
--- a/tests/test_coding/test_structures.py
+++ b/tests/test_coding/test_structures.py
@@ -9,12 +9,19 @@ import mock
 from satella.coding.structures import TimeBasedHeap, Heap, typednamedtuple, \
     OmniHashableMixin, DictObject, apply_dict_object, Immutable, frozendict, SetHeap, \
     DictionaryView, HashableWrapper, TwoWayDictionary, Ranking, SortedList, SliceableDeque, \
-    DirtyDict, KeyAwareDefaultDict
+    DirtyDict, KeyAwareDefaultDict, Proxy
 
 logger = logging.getLogger(__name__)
 
 
 class TestMisc(unittest.TestCase):
+    def test_proxy(self):
+        a = Proxy(5, wrap_operations=True)
+        self.assertIsInstance(a+5, Proxy)
+
+        a = Proxy(5)
+        self.assertNotIsInstance(a+5, Proxy)
+
     def test_key_aware_defaultdict(self):
         a = KeyAwareDefaultDict(int)
         self.assertEqual(a['1'], 1)
-- 
GitLab