From 819b77173307dc2a0524e054a0112ec058c9a33c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20Ma=C5=9Blanka?= <piotr.maslanka@henrietta.com.pl>
Date: Tue, 23 Feb 2021 20:05:14 +0100
Subject: [PATCH] add ThreadCollection

---
 CHANGELOG.md                                  |  2 +
 docs/coding/concurrent.rst                    |  6 ++
 satella/__init__.py                           |  2 +-
 satella/coding/concurrent/__init__.py         |  3 +-
 .../coding/concurrent/thread_collection.py    | 61 +++++++++++++++++++
 tests/test_coding/test_concurrent.py          | 20 +++++-
 6 files changed, 91 insertions(+), 3 deletions(-)
 create mode 100644 satella/coding/concurrent/thread_collection.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f4c334fa..f4c73255 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1 +1,3 @@
 # v2.14.45
+
+* add `ThreadCollection`
diff --git a/docs/coding/concurrent.rst b/docs/coding/concurrent.rst
index fef43b5f..ccdf924a 100644
--- a/docs/coding/concurrent.rst
+++ b/docs/coding/concurrent.rst
@@ -32,6 +32,12 @@ PeekableQueue
 .. autoclass:: satella.coding.concurrent.PeekableQueue
     :members:
 
+ThreadCollection
+================
+
+.. autoclass:: satella.coding.concurrent.ThreadCollection
+    :members:
+
 TerminableThread
 ================
 
diff --git a/satella/__init__.py b/satella/__init__.py
index e62b32c9..c27280e5 100644
--- a/satella/__init__.py
+++ b/satella/__init__.py
@@ -1 +1 @@
-__version__ = '2.14.45a1'
+__version__ = '2.14.45'
diff --git a/satella/coding/concurrent/__init__.py b/satella/coding/concurrent/__init__.py
index 5d74c02d..a786dfcc 100644
--- a/satella/coding/concurrent/__init__.py
+++ b/satella/coding/concurrent/__init__.py
@@ -10,6 +10,7 @@ from .sync import sync_threadpool
 from .thread import TerminableThread, Condition, SingleStartThread, call_in_separate_thread, \
     BogusTerminableThread, IntervalTerminableThread
 from .timer import Timer
+from .thread_collection import ThreadCollection
 from .queue import PeekableQueue
 
 __all__ = ['LockedDataset', 'Monitor', 'RMonitor', 'CallableGroup', 'TerminableThread',
@@ -18,5 +19,5 @@ __all__ = ['LockedDataset', 'Monitor', 'RMonitor', 'CallableGroup', 'TerminableT
            'BogusTerminableThread', 'Timer', 'parallel_execute', 'run_as_future',
            'sync_threadpool', 'IntervalTerminableThread', 'Future',
            'WrappingFuture', 'InvalidStateError', 'PeekableQueue',
-           'CancellableCallback',
+           'CancellableCallback', 'ThreadCollection',
            'SequentialIssuer']
diff --git a/satella/coding/concurrent/thread_collection.py b/satella/coding/concurrent/thread_collection.py
new file mode 100644
index 00000000..98ed4fe8
--- /dev/null
+++ b/satella/coding/concurrent/thread_collection.py
@@ -0,0 +1,61 @@
+import typing as tp
+from threading import Thread
+
+
+class ThreadCollection:
+    """
+    A collection of threads.
+
+    Create like:
+
+    >>> class MyThread(Thread):
+    >>>     def __init__(self, a):
+    >>>         ...
+    >>> tc = ThreadCollection.from_class(MyThread, [2, 4, 5])
+    >>> tc.start()
+    >>> tc.terminate()
+    >>> tc.join()
+    """
+
+    __slots__ = ('threads', )
+
+    @classmethod
+    def from_class(cls, cls_to_use, iteratable) -> 'ThreadCollection':
+        """
+        Build a thread collection
+
+        :param cls_to_use: class to instantiate with
+        :param iteratable: an iterable with the sole argument to this class
+        """
+        return ThreadCollection([cls_to_use(it) for it in iteratable])
+
+    def __init__(self, threads: tp.List[Thread]):
+        self.threads = threads
+
+    def start(self):
+        """
+        Start all threads
+        """
+        for thread in self.threads:
+            thread.start()
+
+    def terminate(self, *args, **kwargs):
+        """
+        Call terminate() on all threads that have this method
+        """
+        for thread in self.threads:
+            try:
+                thread.terminate(*args, **kwargs)
+            except AttributeError:
+                pass
+
+    def join(self):
+        """Join all threads"""
+        for thread in self.threads:
+            thread.join()
+
+    def is_alive(self):
+        """
+        Is at least one thread alive?
+        """
+        return any(thread.is_alive() for thread in self.threads)
diff --git a/tests/test_coding/test_concurrent.py b/tests/test_coding/test_concurrent.py
index f0b26caf..b7de4f5c 100644
--- a/tests/test_coding/test_concurrent.py
+++ b/tests/test_coding/test_concurrent.py
@@ -10,7 +10,7 @@ from concurrent.futures import ThreadPoolExecutor, Future as PythonFuture
 from satella.coding.concurrent import TerminableThread, CallableGroup, Condition, MonitorList, \
     LockedStructure, AtomicNumber, Monitor, IDAllocator, call_in_separate_thread, Timer, \
     parallel_execute, run_as_future, sync_threadpool, IntervalTerminableThread, Future, \
-    WrappingFuture, PeekableQueue, SequentialIssuer, CancellableCallback
+    WrappingFuture, PeekableQueue, SequentialIssuer, CancellableCallback, ThreadCollection
 from satella.coding.concurrent.futures import call_in_future, ExecutorWrapper
 from satella.coding.sequences import unique
 from satella.exceptions import WouldWaitMore, AlreadyAllocated, Empty
@@ -18,6 +18,24 @@ from satella.exceptions import WouldWaitMore, AlreadyAllocated, Empty
 
 class TestConcurrent(unittest.TestCase):
 
+    def test_thread_collection(self):
+        dct = {}
+
+        class Threading(threading.Thread):
+            def __init__(self, a):
+                super().__init__()
+                self.a = a
+
+            def run(self):
+                nonlocal dct
+                dct[self.a] = True
+
+        tc = ThreadCollection.from_class(Threading, [2, 3, 4])
+        tc.start()
+        tc.terminate()
+        tc.join()
+        self.assertEqual(dct, {2: True, 3: True, 4: True})
+
     def test_cancellable_callback(self):
         a = {'test': True}
 
-- 
GitLab