Skip to content
Snippets Groups Projects
Commit 95053017 authored by Piotr Maślanka's avatar Piotr Maślanka
Browse files

added `timeout` to `sync_threadpool`

parent ed9412ba
No related branches found
Tags v2.12.4
No related merge requests found
# v2.12.4
* `tags_factory` in `trace_function` can be also a dict
* added `timeout` to `sync_threadpool`
__version__ = '2.12.4_a2'
__version__ = '2.12.4'
import typing as tp
import time
from concurrent.futures import wait, ThreadPoolExecutor
from .atomic import AtomicNumber
from .thread import Condition
from ...exceptions import WouldWaitMore
from ...time import measure
def sync_threadpool(tpe: ThreadPoolExecutor) -> None:
def sync_threadpool(tpe: ThreadPoolExecutor,
max_wait: tp.Optional[float] = None) -> None:
"""
Make sure that every thread of given thread pool executor is done processing
jobs scheduled until this moment.
......@@ -13,25 +17,42 @@ def sync_threadpool(tpe: ThreadPoolExecutor) -> None:
Make sure that other tasks do not submit anything to this thread pool executor.
:param tpe: thread pool executor to sync
:raises WouldWaitMore: timeout exceeded
"""
assert isinstance(tpe, ThreadPoolExecutor), 'Must be a ThreadPoolExecutor!'
workers = tpe._max_workers
atm_n = AtomicNumber(workers)
cond = Condition()
def decrease_atm():
nonlocal atm_n
atm_n -= 1
cond.wait()
futures = [tpe.submit(decrease_atm) for _ in range(workers)]
# wait for all currently scheduled jobs to be picked up
while tpe._work_queue.qsize() > 0:
time.sleep(0.5)
atm_n.wait_until_equal(0)
cond.notify_all()
wait(futures)
with measure() as measurement:
workers = tpe._max_workers
atm_n = AtomicNumber(workers)
cond = Condition()
def decrease_atm():
nonlocal atm_n
atm_n -= 1
cond.wait()
futures = [tpe.submit(decrease_atm) for _ in range(workers)]
# wait for all currently scheduled jobs to be picked up
while tpe._work_queue.qsize() > 0:
if max_wait is not None:
if measurement() > max_wait:
for future in futures:
future.cancel()
raise WouldWaitMore('timeout exceeded')
time.sleep(0.5)
if max_wait is None:
atm_n.wait_until_equal(0)
else:
while measurement() < max_wait:
try:
atm_n.wait_until_equal(0, 1)
break
except WouldWaitMore:
continue
else:
raise WouldWaitMore('timeout exceeded')
cond.notify_all()
wait(futures)
......@@ -74,6 +74,11 @@ class TestConcurrent(unittest.TestCase):
sync_threadpool(tp)
self.assertEqual(a['cond'], 0)
def test_sync_threadpool_wait_max(self):
tp = ThreadPoolExecutor(max_workers=1)
tp.submit(lambda: time.sleep(3))
self.assertRaises(WouldWaitMore, lambda: sync_threadpool(tp, 2))
def test_run_as_future(self):
a = {}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment