diff --git a/CHANGELOG.md b/CHANGELOG.md index 5420140222ee13fccd258d3c6fa00c2b159dcbe9..90b7a918bfe7fab79e0e0eaa7288c4944560b428 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,4 +3,4 @@ * more time-related calls will accept time strings * added optional `delay` argument to `call_in_separate_thread` * beefed up BogusTerminableThread docs -* added `ThreadCollection.add` +* added functionality to `ThreadCollection` diff --git a/satella/__init__.py b/satella/__init__.py index e6e7554c3527f2a9de0d73a564c911aeeb11b131..d85f422e6312a93dd185083c4dfe505885c32a19 100644 --- a/satella/__init__.py +++ b/satella/__init__.py @@ -1 +1 @@ -__version__ = '2.14.47a3' +__version__ = '2.14.47a4' diff --git a/satella/coding/concurrent/thread_collection.py b/satella/coding/concurrent/thread_collection.py index 0867811ccb7380bd03f2b88f77f650f8605fedf7..e138022b8fa434374b1630329f0b305f3be22237 100644 --- a/satella/coding/concurrent/thread_collection.py +++ b/satella/coding/concurrent/thread_collection.py @@ -1,3 +1,4 @@ +import threading import typing as tp from threading import Thread @@ -15,10 +16,36 @@ class ThreadCollection: >>> tc.start() >>> tc.terminate() >>> tc.join() + + This also implements iteration (it will return all the threads in the collection) and + length check. """ __slots__ = ('threads', ) + def __len__(self): + return len(self.threads) + + def __iter__(self): + return iter(self.threads) + + @classmethod + def get_currently_running(cls, include_main_thread: bool = True) -> 'ThreadCollection': + """ + Get all currently running threads as thread collection + + :param include_main_thread: whether to include the main thread + + :return: a thread collection representing all currently running threads + """ + result = [] + for thread in threading.enumerate(): + # noinspection PyProtectedMember + if not include_main_thread and isinstance(thread, threading._MainThread): + continue + result.append(thread) + return ThreadCollection(result) + @classmethod def from_class(cls, cls_to_use, iteratable, **kwargs) -> 'ThreadCollection': """ @@ -46,6 +73,14 @@ class ThreadCollection: def __init__(self, threads: tp.Sequence[Thread]): self.threads = list(threads) + def append(self, thread: Thread) -> None: + """ + Alias for :meth:`~satella.coding.concurrent.ThreadCollection.add` + + :param thread: thread to add + """ + self.add(thread) + def add(self, thread: Thread) -> None: """ Add a thread to the collection