From 4e0cb4368006f80678a21e44bb00dbe7ff8fee58 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20Ma=C5=9Blanka?= <piotr.maslanka@henrietta.com.pl>
Date: Tue, 23 Feb 2021 14:31:13 +0100
Subject: [PATCH] add safe_wait_condition, v2.14.43

---
 CHANGELOG.md                         |  1 +
 satella/__init__.py                  |  2 +-
 satella/coding/concurrent/thread.py  | 35 ++++++++++++++++++++++++++++
 tests/test_coding/test_concurrent.py | 23 ++++++++++++++++++
 4 files changed, 60 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a377372c..27aca0df 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,4 @@
 # v2.14.43
 
 * bug fixed for `SequentialIssuer`
+* added `TerminableThread.safe_wait_condition`
diff --git a/satella/__init__.py b/satella/__init__.py
index 6d124c8e..c8b678b7 100644
--- a/satella/__init__.py
+++ b/satella/__init__.py
@@ -1 +1 @@
-__version__ = '2.14.43a2'
+__version__ = '2.14.43'
diff --git a/satella/coding/concurrent/thread.py b/satella/coding/concurrent/thread.py
index 94521e15..3c35ce2c 100644
--- a/satella/coding/concurrent/thread.py
+++ b/satella/coding/concurrent/thread.py
@@ -213,9 +213,16 @@ class TerminableThread(threading.Thread):
             swallow it and terminate the thread by calling
             :meth:`~satella.coding.concurrent.TerminableThread.terminate`. Note that the
             subclass check will be done via `isinstance` so you can use the metaclass magic :)
+            Note that SystemExit will be automatically added to list of terminable exceptions.
         """
         super().__init__(*args, **kwargs)
         self._terminating = False  # type: bool
+        if terminate_on is None:
+            terminate_on = (SystemExit, )
+        elif isinstance(terminate_on, tuple):
+            terminate_on = (SystemExit, *terminate_on)
+        else:
+            terminate_on = (SystemExit, terminate_on)
         self._terminate_on = terminate_on
 
     @property
@@ -287,6 +294,32 @@ class TerminableThread(threading.Thread):
         self.terminate().join()
         return False
 
+    def safe_wait_condition(self, condition: Condition, timeout: float,
+                            wake_up_each: float = 2) -> None:
+        """
+        Wait for a condition, checking periodically if the thread is being terminated.
+
+        To be invoked only by the thread that's represented by the object!
+
+        :param condition: condition to wait on
+        :param timeout: maximum time to wait
+        :param wake_up_each: amount of seconds to wake up each to check for termination
+        :raises WouldWaitMore: timeout has passed and Condition has not happened
+        :raises SystemExit: thread is terminating
+        """
+        t = 0
+        while t < timeout:
+            if self._terminating:
+                raise SystemExit()
+            ttw = min(timeout-t, wake_up_each)
+            t += ttw
+            try:
+                condition.wait(ttw)
+                return
+            except WouldWaitMore:
+                pass
+        raise WouldWaitMore()
+
     def safe_sleep(self, interval: float, wake_up_each: float = 2) -> None:
         """
         Sleep for interval, waking up each wake_up_each seconds to check if terminating,
@@ -294,6 +327,8 @@ class TerminableThread(threading.Thread):
 
         This will do *the right thing* when passed a negative interval.
 
+        To be invoked only by the thread that's represented by the object!
+
         :param interval: Time to sleep in total
         :param wake_up_each: Amount of seconds to wake up each
         :raises SystemExit: thread is terminating
diff --git a/tests/test_coding/test_concurrent.py b/tests/test_coding/test_concurrent.py
index a12f0906..f0b26caf 100644
--- a/tests/test_coding/test_concurrent.py
+++ b/tests/test_coding/test_concurrent.py
@@ -433,6 +433,29 @@ class TestConcurrent(unittest.TestCase):
         time.sleep(0.1)
         self.assertTrue(dct['a'])
 
+    def test_terminablethread_condition(self):
+        a = {'dct': False}
+        condition = Condition()
+        slf = self
+
+        class MyThread(TerminableThread):
+            def __init__(self):
+                super().__init__()
+
+            def run(self) -> None:
+                nonlocal a, slf, condition
+                self.safe_wait_condition(condition, 5)
+                a['dct'] = True
+                slf.assertRaises(WouldWaitMore, lambda: self.safe_wait_condition(condition, 3))
+
+        t = MyThread().start()
+        time.sleep(0.2)
+        self.assertTrue(t.is_alive())
+        self.assertFalse(a['dct'])
+        condition.notify()
+        time.sleep(0.1)
+        self.assertTrue(a['dct'])
+
     def test_terminate_on(self):
         dct = {'a': False}
 
-- 
GitLab