diff --git a/CHANGELOG.md b/CHANGELOG.md index d61037f2018044dc8006f69da64b83c1d987361c..e8e94410dfda9584f2fa9fb5f89d17eb91716f07 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # v2.26.1 -* _TBA_ +* added run_when_generator_completes and RunActionAfterGeneratorCompletes # v2.26.0 diff --git a/docs/coding/sequences.rst b/docs/coding/sequences.rst index d0e9f8becc405dd0f36bb13c62f1d9812c9989ff..ab964020a92836a57dee24f52560b16eedaeacc4 100644 --- a/docs/coding/sequences.rst +++ b/docs/coding/sequences.rst @@ -1,6 +1,14 @@ Sequences and iterators ####################### +Generators +========== + +.. autoclass:: satella.coding.RunActionAfterGeneratorCompletes + :members: + +.. autoclass:: satella.coding. + Rolling averages ================ diff --git a/docs/conf.py b/docs/conf.py index b4e638a50cb59a0783d39c59ec9ffe157fc49c3e..6726251cef2bd15440ca9cb687675b441b4b9716 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -151,3 +151,11 @@ texinfo_documents = [ author, 'satella', 'One line description of project.', 'Miscellaneous'), ] + +autoclass_content = 'both' + +autodoc_default_options = { + 'members': True, +} +autodoc_typehints = "description" +autoclass_content = 'both' diff --git a/satella/__init__.py b/satella/__init__.py index 468444d3b671c0409e2431f5d56e7aa7e8216fdf..acb6e2fcce6d933808f9f703c3213a6127defaf6 100644 --- a/satella/__init__.py +++ b/satella/__init__.py @@ -1 +1 @@ -__version__ = '2.26.1a1' +__version__ = '2.26.1' diff --git a/satella/coding/__init__.py b/satella/coding/__init__.py index 6226b1c031661e6e117212f4cfea214291cc5cec..3c8acc60e4dd1ff9a71e083d043f5ca20a6bb2d3 100644 --- a/satella/coding/__init__.py +++ b/satella/coding/__init__.py @@ -23,8 +23,10 @@ from .misc import update_if_not_none, update_key_if_none, update_attr_if_none, q from .overloading import overload, class_or_instancemethod, TypeSignature from .recast_exceptions import rethrow_as, silence_excs, catch_exception, log_exceptions, \ raises_exception, reraise_as +from .generators import RunActionAfterGeneratorCompletes, run_when_generator_completes __all__ = [ + 'RunActionAfterGeneratorCompletes', 'run_when_generator_completes', 'EmptyContextManager', 'Context', 'length', 'assert_equal', 'InequalityReason', 'Inequal', 'wrap_callable_in_context_manager', 'Closeable', 'contains', 'enum_value', diff --git a/satella/coding/generators.py b/satella/coding/generators.py new file mode 100644 index 0000000000000000000000000000000000000000..256206969f1c31412f4acc2f2ff13c1dc608b801 --- /dev/null +++ b/satella/coding/generators.py @@ -0,0 +1,57 @@ +import typing as tp +from abc import ABCMeta, abstractmethod + + +class RunActionAfterGeneratorCompletes(metaclass=ABCMeta): + """ + Run an action after a generator completes. + An abstract class. + """ + + __slots__ = 'generator', 'args', 'kwargs' + + def __init__(self, generator: tp.Generator, *args, **kwargs): + """ + :param generator: generator to watch for + :param args: arguments to invoke action_to_run with + :param kwargs: keyword arguments to invoke action_to_run with + """ + self.generator = generator + self.args = args + self.kwargs = kwargs + + def send(self, value): + """Send a value to the generator""" + self.generator.send(value) + + def __iter__(self): + return self + + def __next__(self): + try: + return next(self.generator) + except StopIteration: + self.action_to_run(*self.args, **self.kwargs) + raise + + @abstractmethod + def action_to_run(self): + """This will run when this generator completes. Override it.""" + + +def run_when_generator_completes(gen: tp.Generator, call_on_done: tp.Callable[[], None], + *args, **kwargs) -> tp.Generator: + """ + Return the generator with call_on_done to be called on when it finishes + + :param gen: generator + :param call_on_done: callable/0 to call on generator's completion + :param args: args to pass to the callable + :param kwargs: kwargs to pass to the callable + :returns: generator + """ + class Inner(RunActionAfterGeneratorCompletes): + def action_to_run(self, *args, **kwargs): + call_on_done(*args, **kwargs) + + return Inner(gen, *args, **kwargs) diff --git a/tests/test_coding/test_iterators.py b/tests/test_coding/test_iterators.py index 28b2d74e5bf927df2507d27641c6729f3b71f174..c472327a7d7db57acb042889b36d86f99330f5d2 100644 --- a/tests/test_coding/test_iterators.py +++ b/tests/test_coding/test_iterators.py @@ -1,7 +1,7 @@ import sys import unittest -from satella.coding import SelfClosingGenerator, hint_with_length, chain +from satella.coding import SelfClosingGenerator, hint_with_length, chain, run_when_generator_completes from satella.coding.sequences import smart_enumerate, ConstruableIterator, walk, \ IteratorListAdapter, is_empty, ListWrapperIterator @@ -16,6 +16,26 @@ def iterate(): class TestIterators(unittest.TestCase): + def test_run_when_generator_completes(self): + called = False + + def generator(): + yield 1 + yield 2 + yield 3 + + def mark_done(f): + assert f == 2 + nonlocal called + called = True + + gen = run_when_generator_completes(generator(), mark_done, 2) + a = next(gen) + self.assertFalse(called) + for i in gen: + pass + self.assertTrue(called) + def test_list_wrapper_iterator_contains(self): lwe = ListWrapperIterator(iterate())