diff --git a/satella/coding/generators.py b/satella/coding/generators.py index 3b834c824b35c97cc2f6f45910abe3d80adae8ce..2e96f53c98c97f6ca50ac30c2a7abdf4c5ff6b22 100644 --- a/satella/coding/generators.py +++ b/satella/coding/generators.py @@ -76,7 +76,7 @@ class RunActionAfterGeneratorCompletes(tp.Generator, metaclass=ABCMeta): """This will run when this generator throws any exception. Override it.""" -def run_when_generator_completes(gen: tp.Generator, call_on_done: tp.Callable +def run_when_generator_completes(gen: tp.Generator, call_on_done: tp.Callable, *args, **kwargs) -> RunActionAfterGeneratorCompletes: """ Return the generator with call_on_done to be called on when it finishes diff --git a/tests/test_coding/test_iterators.py b/tests/test_coding/test_iterators.py index 82459e3c2108403a9e8581265be0a22c639e77cf..73852f58adebc4b748cce40f215718532ae908f0 100644 --- a/tests/test_coding/test_iterators.py +++ b/tests/test_coding/test_iterators.py @@ -140,7 +140,7 @@ class TestIterators(unittest.TestCase): self.assertFalse(called) - def test_run_when_generator_closed_failure(self): + def test_run_when_generator_failure(self): called = False def generator(): @@ -161,10 +161,9 @@ class TestIterators(unittest.TestCase): called = True gen = Inner(generator()) - a = next(gen) - gen.close() - self.assertRaises(StopIteration, next, gen) - self.assertFalse(called) + for i in gen: + pass + self.assertTrue(called) def test_list_wrapper_iterator_contains(self):