From 3631331e29190b14437fb61d322d8a1d3a2207fe Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20Ma=C5=9Blanka?= <piotr.maslanka@henrietta.com.pl>
Date: Wed, 16 Jun 2021 15:49:11 +0200
Subject: [PATCH] build

---
 CHANGELOG.md                                  |  5 ++
 README.md                                     |  2 +-
 satella/__init__.py                           |  2 +-
 satella/coding/concurrent/thread.py           | 63 ++++++++++++-------
 .../instrumentation/cpu_time/concurrency.py   | 40 +++++++++---
 tests/test_coding/test_concurrent.py          | 22 +++++++
 tests/test_instrumentation/test_cpu_time.py   | 11 ++++
 7 files changed, 110 insertions(+), 35 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0942fa7d..19397ac0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1 +1,6 @@
 # v2.17.11
+
+* **bugfix** exceptions in `TerminableThread` that have definde 
+`terminate_on` won't be swallowed anymore.
+* added support for `terminate_on` to `IntervalTerminableThread`
+    and `CPUTimeAwareIntervalTerminableThread`
diff --git a/README.md b/README.md
index 68b70b93..df1a14eb 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 satella
 ========
-[![Build Status](https://travis-ci.org/piotrmaslanka/satella.svg)](https://travis-ci.org/piotrmaslanka/satella)
+[![Build Status](https://travis-ci.com/piotrmaslanka/satella.svg)](https://travis-ci.com/piotrmaslanka/satella)
 [![Test Coverage](https://api.codeclimate.com/v1/badges/34b392b61482d98ad3f0/test_coverage)](https://codeclimate.com/github/piotrmaslanka/satella/test_coverage)
 [![Code Climate](https://codeclimate.com/github/piotrmaslanka/satella/badges/gpa.svg)](https://codeclimate.com/github/piotrmaslanka/satella)
 [![Issue Count](https://codeclimate.com/github/piotrmaslanka/satella/badges/issue_count.svg)](https://codeclimate.com/github/piotrmaslanka/satella)
diff --git a/satella/__init__.py b/satella/__init__.py
index 07c3fc31..1848a06a 100644
--- a/satella/__init__.py
+++ b/satella/__init__.py
@@ -1 +1 @@
-__version__ = '2.17.11a1'
+__version__ = '2.17.11a2'
diff --git a/satella/coding/concurrent/thread.py b/satella/coding/concurrent/thread.py
index 9ca21855..fb7c7fc6 100644
--- a/satella/coding/concurrent/thread.py
+++ b/satella/coding/concurrent/thread.py
@@ -255,6 +255,11 @@ class TerminableThread(threading.Thread):
 
     If prepare() throws one of the terminate_on exceptions,
     :meth:`~satella.coding.concurrent.TerminableThread.loop` even won't be called.
+    However, :meth:`~satella.coding.concurrent.TerminableThread.terminate` will be automatically
+    called then.
+
+    Same applies for :class:`~satella.coding.concurrent.IntervalTerminableThread` and
+    :class:`~satella.instrumentation.cpu_time.CPUTimeAwareIntervalTerminableThread`.
     """
 
     def __init__(self, *args, terminate_on: tp.Optional[ExceptionList] = None,
@@ -272,12 +277,6 @@ class TerminableThread(threading.Thread):
         """
         super().__init__(*args, **kwargs)
         self._terminating = False  # type: bool
-        if terminate_on is None:
-            terminate_on = (SystemExit,)
-        elif isinstance(terminate_on, tuple):
-            terminate_on = (SystemExit, *terminate_on)
-        else:
-            terminate_on = (SystemExit, terminate_on)
         self._terminate_on = terminate_on
 
     @property
@@ -324,8 +323,8 @@ class TerminableThread(threading.Thread):
                 if self._terminate_on is not None:
                     if isinstance(e, self._terminate_on):
                         self.terminate()
-                else:
-                    raise
+                        return
+                raise
 
             while not self._terminating:
                 try:
@@ -334,8 +333,8 @@ class TerminableThread(threading.Thread):
                     if self._terminate_on is not None:
                         if isinstance(e, self._terminate_on):
                             self.terminate()
-                    else:
-                        raise
+                            return
+                    raise
         except SystemExit:
             pass
         finally:
@@ -486,16 +485,34 @@ class IntervalTerminableThread(TerminableThread, metaclass=ABCMeta):
     def run(self):
         from satella.time.measure import measure
 
-        self.prepare()
-        while not self._terminating:
-            with measure() as measurement:
-                self.loop()
-            if self._terminating:
-                break
-            time_taken = measurement()
-            time_to_sleep = self.seconds - time_taken
-            if time_to_sleep < 0:
-                self.on_overrun(time_taken)
-            else:
-                self.safe_sleep(time_to_sleep)
-        self.cleanup()
+        try:
+            try:
+                self.prepare()
+            except Exception as e:
+                if self._terminate_on is not None:
+                    if isinstance(e, self._terminate_on):
+                        self.terminate()
+                        return
+                raise
+            while not self._terminating:
+                with measure() as measurement:
+                    try:
+                        self.loop()
+                    except Exception as e:
+                        if self._terminate_on is not None:
+                            if isinstance(e, self._terminate_on):
+                                self.terminate()
+                                return
+                        raise
+                if self._terminating:
+                    break
+                time_taken = measurement()
+                time_to_sleep = self.seconds - time_taken
+                if time_to_sleep < 0:
+                    self.on_overrun(time_taken)
+                else:
+                    self.safe_sleep(time_to_sleep)
+        except SystemExit:
+            pass
+        finally:
+            self.cleanup()
diff --git a/satella/instrumentation/cpu_time/concurrency.py b/satella/instrumentation/cpu_time/concurrency.py
index 26f7438c..af45c57a 100644
--- a/satella/instrumentation/cpu_time/concurrency.py
+++ b/satella/instrumentation/cpu_time/concurrency.py
@@ -20,6 +20,9 @@ class CPUTimeAwareIntervalTerminableThread(IntervalTerminableThread, metaclass=A
     :param percentile: percentile that CPU usage has to fall below to call it earlier.
     :param wakeup_interval: amount of seconds to wake up between to check for _terminating status.
         Can be also a time string
+
+    Same concerns for :code:`terminate_on` as in
+    :class:`~satella.coding.concurrent.TerminableThread` apply.
     """
 
     def __init__(self, seconds: tp.Union[str, float], max_sooner: tp.Optional[float] = None,
@@ -63,13 +66,30 @@ class CPUTimeAwareIntervalTerminableThread(IntervalTerminableThread, metaclass=A
         self.__sleep_waiting_for_cpu(how_long)
 
     def run(self):
-        self.prepare()
-        while not self.terminating:
-            measured = self._execute_measured()
-            seconds_to_wait = self.seconds - measured
-            if seconds_to_wait > 0:
-                self.__sleep(seconds_to_wait)
-            elif seconds_to_wait < 0:
-                self.on_overrun(measured)
-
-        self.cleanup()
+        try:
+            try:
+                self.prepare()
+            except Exception as e:
+                if self._terminate_on is not None:
+                    if isinstance(e, self._terminate_on):
+                        self.terminate()
+                        return
+                raise
+            while not self.terminating:
+                try:
+                    measured = self._execute_measured()
+                except Exception as e:
+                    if self._terminate_on is not None:
+                        if isinstance(e, self._terminate_on):
+                            self.terminate()
+                            return
+                    raise
+                seconds_to_wait = self.seconds - measured
+                if seconds_to_wait > 0:
+                    self.__sleep(seconds_to_wait)
+                elif seconds_to_wait < 0:
+                    self.on_overrun(measured)
+        except SystemExit:
+            pass
+        finally:
+            self.cleanup()
diff --git a/tests/test_coding/test_concurrent.py b/tests/test_coding/test_concurrent.py
index d13d8d87..3112a9e8 100644
--- a/tests/test_coding/test_concurrent.py
+++ b/tests/test_coding/test_concurrent.py
@@ -701,6 +701,28 @@ class TestConcurrent(unittest.TestCase):
         self.assertTrue(mtt.overrun)
         mtt.terminate().join()
 
+    def test_interval_terminable_thread_terminates(self):
+        class MyTerminableThread(IntervalTerminableThread):
+            def __init__(self, a):
+                super().__init__(1, terminate_on=ValueError)
+                self.a = a
+                self.overrun = False
+
+            def prepare(self) -> None:
+                if self.a:
+                    raise ValueError()
+
+            def loop(self) -> None:
+                if not self.a:
+                    raise ValueError()
+
+        mtt_a = MyTerminableThread(True)
+        mtt_b = MyTerminableThread(False)
+        mtt_a.start()
+        mtt_b.start()
+        mtt_a.terminate().join()
+        mtt_b.terminate().join()
+
     @unittest.skipIf(platform.python_implementation() != 'PyPy', 'this requires PyPy')
     def test_terminable_thread_force_notimplementederror(self):
         class MyTerminableThread(TerminableThread):
diff --git a/tests/test_instrumentation/test_cpu_time.py b/tests/test_instrumentation/test_cpu_time.py
index d1afb658..b3e6b77a 100644
--- a/tests/test_instrumentation/test_cpu_time.py
+++ b/tests/test_instrumentation/test_cpu_time.py
@@ -6,6 +6,17 @@ from satella.instrumentation.cpu_time import calculate_occupancy_factor, sleep_c
 
 
 class TestCPUTime(unittest.TestCase):
+    def test_cpu_time_aware_terminable_thread_terminates(self):
+        class TestingThread(CPUTimeAwareIntervalTerminableThread):
+            def __init__(self):
+                super().__init__('5s', 3, 0.5, terminate_on=ValueError)
+                self.a = 0
+
+            def loop(self) -> None:
+                raise ValueError()
+
+        TestingThread().start().terminate().join()
+
     def test_cpu_time_aware_terminable_thread(self):
         class TestingThread(CPUTimeAwareIntervalTerminableThread):
             def __init__(self):
-- 
GitLab