diff --git a/CHANGELOG.md b/CHANGELOG.md index e8e94410dfda9584f2fa9fb5f89d17eb91716f07..b4a0d71604c1badcbbdfa0fa5ff1225d7e74f972 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +# v2.26.2 + +* RunActionAfterGeneratorCompletes won't call it's on_done action if closed prematurely +* more complete support for generators in RunActionAfterGeneratorCompletes + # v2.26.1 * added run_when_generator_completes and RunActionAfterGeneratorCompletes diff --git a/satella/__init__.py b/satella/__init__.py index acb6e2fcce6d933808f9f703c3213a6127defaf6..b9d4f6d2329e918767d04d086b16b04850acbb2a 100644 --- a/satella/__init__.py +++ b/satella/__init__.py @@ -1 +1 @@ -__version__ = '2.26.1' +__version__ = '2.26.2' diff --git a/satella/coding/generators.py b/satella/coding/generators.py index 256206969f1c31412f4acc2f2ff13c1dc608b801..fd29de6aef85ce0d154aba7ae9e9de88808e5012 100644 --- a/satella/coding/generators.py +++ b/satella/coding/generators.py @@ -2,13 +2,17 @@ import typing as tp from abc import ABCMeta, abstractmethod -class RunActionAfterGeneratorCompletes(metaclass=ABCMeta): + +class RunActionAfterGeneratorCompletes(tp.Generator, metaclass=ABCMeta): """ Run an action after a generator completes. An abstract class. + + Please note that this routine will be called only when the generator completes. If you abort it prematurely, + via close() """ - __slots__ = 'generator', 'args', 'kwargs' + __slots__ = 'generator', 'args', 'kwargs', 'closed' def __init__(self, generator: tp.Generator, *args, **kwargs): """ @@ -16,22 +20,34 @@ class RunActionAfterGeneratorCompletes(metaclass=ABCMeta): :param args: arguments to invoke action_to_run with :param kwargs: keyword arguments to invoke action_to_run with """ + self.closed = False self.generator = generator self.args = args self.kwargs = kwargs + def close(self): + self.closed = True + self.generator.close() + def send(self, value): """Send a value to the generator""" - self.generator.send(value) + return self.generator.send(value) + + def next(self): + return self.generator.__next__() def __iter__(self): return self + def throw(self, __typ, __val=None, __tb=None): + return self.generator.throw(__typ, __val, __tb) + def __next__(self): try: - return next(self.generator) + return self.generator.__next__() except StopIteration: - self.action_to_run(*self.args, **self.kwargs) + if not self.closed: + self.action_to_run(*self.args, **self.kwargs) raise @abstractmethod @@ -40,7 +56,7 @@ class RunActionAfterGeneratorCompletes(metaclass=ABCMeta): def run_when_generator_completes(gen: tp.Generator, call_on_done: tp.Callable[[], None], - *args, **kwargs) -> tp.Generator: + *args, **kwargs) -> RunActionAfterGeneratorCompletes: """ Return the generator with call_on_done to be called on when it finishes diff --git a/tests/test_coding/test_iterators.py b/tests/test_coding/test_iterators.py index c472327a7d7db57acb042889b36d86f99330f5d2..b1afe729373aa8a6e2f221d9db42dd3d9febe781 100644 --- a/tests/test_coding/test_iterators.py +++ b/tests/test_coding/test_iterators.py @@ -1,11 +1,16 @@ +import typing as tp import sys +import logging import unittest -from satella.coding import SelfClosingGenerator, hint_with_length, chain, run_when_generator_completes +from satella.coding import SelfClosingGenerator, hint_with_length, chain, run_when_generator_completes, typing from satella.coding.sequences import smart_enumerate, ConstruableIterator, walk, \ IteratorListAdapter, is_empty, ListWrapperIterator +logger = logging.getLogger(__name__) + + def iterate(): yield 1 yield 2 @@ -16,6 +21,35 @@ def iterate(): class TestIterators(unittest.TestCase): + def test_run_when_generator_completes_2(self): + called = False + + def generator(): + print('Starting generator') + c = yield 1 + assert c == 2 + print('Starting generator') + yield 2 + yield 2 + yield 3 + + def mark_done(f): + assert f == 2 + nonlocal called + called = True + + gen = run_when_generator_completes(generator(), mark_done, 2) + a = next(gen) + self.assertFalse(called) + self.assertEqual(a, 1) + b = gen.send(2) + self.assertEqual(b, 2) + self.assertIsInstance(gen, tp.Generator) + self.assertFalse(called) + for i in gen: + pass + self.assertTrue(called) + def test_run_when_generator_completes(self): called = False @@ -31,11 +65,33 @@ class TestIterators(unittest.TestCase): gen = run_when_generator_completes(generator(), mark_done, 2) a = next(gen) + self.assertIsInstance(gen, tp.Generator) self.assertFalse(called) - for i in gen: + def inner(): + yield from gen + for i in inner(): pass self.assertTrue(called) + def test_run_when_generator_closed(self): + called = False + + def generator(): + yield 1 + yield 2 + yield 3 + + def mark_done(f): + assert f == 2 + nonlocal called + called = True + + gen = run_when_generator_completes(generator(), mark_done, 2) + a = next(gen) + gen.close() + self.assertRaises(StopIteration, next, gen) + self.assertFalse(called) + def test_list_wrapper_iterator_contains(self): lwe = ListWrapperIterator(iterate())