diff --git a/CHANGELOG.md b/CHANGELOG.md index 096de83c672a56e6bdd48725cb5e936b61f28dd3..cc3036c1611e2694eeb81217843f9226bc7ac978 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1 +1,3 @@ # v2.14.37 + +* added `call_on_failure` and `call_on_success` to `retry` diff --git a/satella/__init__.py b/satella/__init__.py index 65418c7bdaca5fbc790dd9033b3b89ac49097530..bf4b705224e019fe756b979ea7eb42217e61a31b 100644 --- a/satella/__init__.py +++ b/satella/__init__.py @@ -1 +1 @@ -__version__ = '2.14.37a1' +__version__ = '2.14.37' diff --git a/satella/coding/decorators/retry_dec.py b/satella/coding/decorators/retry_dec.py index 40a169b80c351f5576c008c0a6678c1cd5909263..7cccdd7485dfdca5794c391f6437f69c8da1e185 100644 --- a/satella/coding/decorators/retry_dec.py +++ b/satella/coding/decorators/retry_dec.py @@ -8,7 +8,9 @@ from satella.coding.typing import ExceptionClassType def retry(times: tp.Optional[int] = None, exc_classes: tp.Union[ExceptionClassType, tp.Tuple[ExceptionClassType, ...]] = Exception, on_failure: tp.Callable[[Exception], None] = lambda e: None, - swallow_exception: bool = True): + swallow_exception: bool = True, + call_on_failure: tp.Optional[tp.Callable[[Exception], None]] = None, + call_on_success: tp.Optional[tp.Callable[[int], None]] = None): """ A decorator retrying given operation, failing it when an exception shows up. @@ -34,6 +36,10 @@ def retry(times: tp.Optional[int] = None, with a single argument, exception instance that was raised last. That exception will be swallowed, unless swallow_exception is set to False :param swallow_exception: the last exception will be swallowed, unless this is set to False + :param call_on_failure: a callable that will be called upon failing to do this, with an + exception as it's sole argument. It's result will be discarded. + :param call_on_success: a callable that will be called with a single argument: the number + of retries that it took to finish the job. It's result will be discarded. :return: function result """ def outer(fun): @@ -43,14 +49,19 @@ def retry(times: tp.Optional[int] = None, iterator = itertools.count() else: iterator = range(times) - for _ in iterator: + for i in iterator: try: - return fun(*args, **kwargs) + y = fun(*args, **kwargs) + if call_on_success is not None: + call_on_success(i) + return y except exc_classes as e: f = e continue else: on_failure(f) + if call_on_failure is not None: + call_on_failure(f) if not swallow_exception: raise f return inner diff --git a/tests/test_coding/test_decorators.py b/tests/test_coding/test_decorators.py index 03a09337a3e4ab38ad0cfd1186e2c6fa14c6d6a8..e6200f63bbf436dbde63cf48c9faace37038db04 100644 --- a/tests/test_coding/test_decorators.py +++ b/tests/test_coding/test_decorators.py @@ -77,17 +77,30 @@ class TestDecorators(unittest.TestCase): self.assertEqual(test(), [2, 3, None, 4]) def test_retry(self): - a = {'test': 0, 'limit': 2} + a = {'test': 0, 'limit': 2, 'true': False, 'false': False} - @retry(3, ValueError, swallow_exception=False) + def on_failure(e): + nonlocal a + a['true'] = True + + def on_success(retries): + nonlocal a + a['false'] = True + + @retry(3, ValueError, swallow_exception=False, call_on_failure=on_failure, + call_on_success=on_success) def do_op(): a['test'] += 1 if a['test'] < a['limit']: raise ValueError() do_op() + self.assertTrue(a['false']) a['limit'] = 10 + a['false'] = False self.assertRaises(ValueError, do_op) + self.assertTrue(a['true']) + self.assertFalse(a['false']) def test_replace_argument_if(self): @replace_argument_if('y', x.int(), str)