diff --git a/satella/__init__.py b/satella/__init__.py index f2acad5b2f0f280230e59b8d2146e7db129a325c..35e2f7b1de279e60a882bf5037c28908e500f7cb 100644 --- a/satella/__init__.py +++ b/satella/__init__.py @@ -1 +1 @@ -__version__ = '2.14.41a1' +__version__ = '2.14.41' diff --git a/satella/coding/concurrent/thread.py b/satella/coding/concurrent/thread.py index d67ef2c5b2a4b53326c425772b07958770667e9e..0158013d2f8b934d0795531229789a7b84e81f27 100644 --- a/satella/coding/concurrent/thread.py +++ b/satella/coding/concurrent/thread.py @@ -10,6 +10,7 @@ from threading import Condition as PythonCondition from satella.coding.decorators import wraps from satella.time import measure +from ..typing import ExceptionList from ...exceptions import ResourceLocked, WouldWaitMore @@ -201,13 +202,21 @@ class TerminableThread(threading.Thread): >>> self.assertFalse(a.is_alive()) """ - def __init__(self, *args, **kwargs): + def __init__(self, *args, terminate_on: tp.Optional[ExceptionList] = None, + **kwargs): """ Note that this is called in the constructor's thread. Use .prepare() to run statements that should be ran in new thread. + + :param terminate_on: if provided, and + :meth:`~satella.coding.concurrent.TerminableThread.loop` throws one of it, + swallow it and terminate the thread by calling + :meth:`~satella.coding.concurrent.TerminableThread.terminate`. Note that the + subclass check will be done via `isinstance` so you can use the metaclass magic :) """ super().__init__(*args, **kwargs) self._terminating = False # type: bool + self._terminate_on = terminate_on @property def terminating(self) -> bool: @@ -228,6 +237,10 @@ class TerminableThread(threading.Thread): This should block for as long as a single check will take, as termination checks take place between calls. + + Note that if it throws one of the exceptions given in `terminate_on` this thread will + terminate cleanly, whereas if it throws something else, the thread will be terminated with + a traceback. """ def start(self) -> 'TerminableThread': @@ -245,7 +258,14 @@ class TerminableThread(threading.Thread): try: self.prepare() while not self._terminating: - self.loop() + try: + self.loop() + except Exception as e: + if self._terminate_on is not None: + if isinstance(e, self._terminate_on): + self.terminate() + else: + raise except SystemExit: pass finally: diff --git a/tests/test_coding/test_concurrent.py b/tests/test_coding/test_concurrent.py index 6ecb665f99b6e5ccd6c776e72e46b6a28f9e56ea..c3bc04f8f3896e7b6a3f04e90083be7ac246cda8 100644 --- a/tests/test_coding/test_concurrent.py +++ b/tests/test_coding/test_concurrent.py @@ -433,6 +433,26 @@ class TestConcurrent(unittest.TestCase): time.sleep(0.1) self.assertTrue(dct['a']) + def test_terminate_on(self): + dct = {'a': False} + + class MyThread(TerminableThread): + def __init__(self): + super().__init__(terminate_on=ValueError) + + def loop(self) -> None: + nonlocal dct + if dct['a']: + raise ValueError() + + t = MyThread().start() + self.assertTrue(t.is_alive()) + time.sleep(1) + self.assertTrue(t.is_alive()) + dct['a'] = True + time.sleep(1) + self.assertFalse(t.is_alive()) + def test_cg_proforma(self): cg = CallableGroup() a = {}