From b0112abaf4eaf17e0bf645a194c31c7fa40fb226 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20Ma=C5=9Blanka?= <piotr.maslanka@henrietta.com.pl>
Date: Mon, 15 Feb 2021 16:57:57 +0100
Subject: [PATCH] added `call_on_failure` and `call_on_success` to `retry`

---
 CHANGELOG.md                           |  2 ++
 satella/__init__.py                    |  2 +-
 satella/coding/decorators/retry_dec.py | 17 ++++++++++++++---
 tests/test_coding/test_decorators.py   | 17 +++++++++++++++--
 4 files changed, 32 insertions(+), 6 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 096de83c..cc3036c1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1 +1,3 @@
 # v2.14.37
+
+* added `call_on_failure` and `call_on_success` to `retry`
diff --git a/satella/__init__.py b/satella/__init__.py
index 65418c7b..bf4b7052 100644
--- a/satella/__init__.py
+++ b/satella/__init__.py
@@ -1 +1 @@
-__version__ = '2.14.37a1'
+__version__ = '2.14.37'
diff --git a/satella/coding/decorators/retry_dec.py b/satella/coding/decorators/retry_dec.py
index 40a169b8..7cccdd74 100644
--- a/satella/coding/decorators/retry_dec.py
+++ b/satella/coding/decorators/retry_dec.py
@@ -8,7 +8,9 @@ from satella.coding.typing import ExceptionClassType
 def retry(times: tp.Optional[int] = None,
           exc_classes: tp.Union[ExceptionClassType, tp.Tuple[ExceptionClassType, ...]] = Exception,
           on_failure: tp.Callable[[Exception], None] = lambda e: None,
-          swallow_exception: bool = True):
+          swallow_exception: bool = True,
+          call_on_failure: tp.Optional[tp.Callable[[Exception], None]] = None,
+          call_on_success: tp.Optional[tp.Callable[[int], None]] = None):
     """
     A decorator retrying given operation, failing it when an exception shows up.
 
@@ -34,6 +36,10 @@ def retry(times: tp.Optional[int] = None,
         with a single argument, exception instance that was raised last. That exception will
         be swallowed, unless swallow_exception is set to False
     :param swallow_exception: the last exception will be swallowed, unless this is set to False
+    :param call_on_failure: a callable that will be called upon failing to do this, with an
+        exception as it's sole argument. It's result will be discarded.
+    :param call_on_success: a callable that will be called with a single argument: the number
+        of retries that it took to finish the job. It's result will be discarded.
     :return: function result
     """
     def outer(fun):
@@ -43,14 +49,19 @@ def retry(times: tp.Optional[int] = None,
                 iterator = itertools.count()
             else:
                 iterator = range(times)
-            for _ in iterator:
+            for i in iterator:
                 try:
-                    return fun(*args, **kwargs)
+                    y = fun(*args, **kwargs)
+                    if call_on_success is not None:
+                        call_on_success(i)
+                    return y
                 except exc_classes as e:
                     f = e
                     continue
             else:
                 on_failure(f)
+                if call_on_failure is not None:
+                    call_on_failure(f)
                 if not swallow_exception:
                     raise f
         return inner
diff --git a/tests/test_coding/test_decorators.py b/tests/test_coding/test_decorators.py
index 03a09337..e6200f63 100644
--- a/tests/test_coding/test_decorators.py
+++ b/tests/test_coding/test_decorators.py
@@ -77,17 +77,30 @@ class TestDecorators(unittest.TestCase):
         self.assertEqual(test(), [2, 3, None, 4])
 
     def test_retry(self):
-        a = {'test': 0, 'limit': 2}
+        a = {'test': 0, 'limit': 2, 'true': False, 'false': False}
 
-        @retry(3, ValueError, swallow_exception=False)
+        def on_failure(e):
+            nonlocal a
+            a['true'] = True
+
+        def on_success(retries):
+            nonlocal a
+            a['false'] = True
+
+        @retry(3, ValueError, swallow_exception=False, call_on_failure=on_failure,
+               call_on_success=on_success)
         def do_op():
             a['test'] += 1
             if a['test'] < a['limit']:
                 raise ValueError()
 
         do_op()
+        self.assertTrue(a['false'])
         a['limit'] = 10
+        a['false'] = False
         self.assertRaises(ValueError, do_op)
+        self.assertTrue(a['true'])
+        self.assertFalse(a['false'])
 
     def test_replace_argument_if(self):
         @replace_argument_if('y', x.int(), str)
-- 
GitLab