diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ac196d37d8426df263f555f4e32133af57bcd4e..386166b293f1d760dbd1ab3c5291e2a7815f57bd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1 +1,4 @@ # v2.21.2 + +* fixed that CPManager, this time with unit tests to back them up +* added option to invalidate all connections on this CPManager \ No newline at end of file diff --git a/satella/__init__.py b/satella/__init__.py index 9311d3af4feb5d10192d924ced985bafa731247c..345635996313fb9eae88a71f05fd9b996bec4ff8 100644 --- a/satella/__init__.py +++ b/satella/__init__.py @@ -1 +1 @@ -__version__ = '2.21.2a1' +__version__ = '2.21.2' diff --git a/satella/coding/resources/cp_manager.py b/satella/coding/resources/cp_manager.py index 6d0d714a8472308c7add622c0f398611f6c6b1f4..e7c75dad8d4f6ad039ccc7ee8ddef5e4e50a2353 100644 --- a/satella/coding/resources/cp_manager.py +++ b/satella/coding/resources/cp_manager.py @@ -54,9 +54,16 @@ class CPManager(Monitor, Closeable, tp.Generic[T], metaclass=abc.ABCMeta): def close(self) -> None: if super().close(): self.terminating = True - while self.spawned_connections > 0: - self.teardown_connection(self.connections.get()) - self.spawned_connections -= 1 + self.invalidate() + + @Monitor.synchronized + def invalidate(self) -> None: + """ + Close all connections. Connections have to be released first. Object is ready for use after this + """ + while self.spawned_connections > 0: + self.teardown_connection(self.connections.get()) + self.spawned_connections -= 1 def acquire_connection(self) -> T: """ @@ -72,20 +79,19 @@ class CPManager(Monitor, Closeable, tp.Generic[T], metaclass=abc.ABCMeta): except queue.Empty: while True: with silence_excs(queue.Empty), Monitor.acquire(self): - if self.spawned_connections >= self.max_number: + if self.connections.qsize() >= self.max_number: conn = self.connections.get(False, 5) break - elif self.spawned_connections < self.max_number: + elif self.connections.qsize() < self.max_number: conn = self.create_connection() - self.spawned_connections += 1 - self.connections.put(conn) break - obj_id = id(conn) - try: - self.id_to_times[obj_id] += 1 - except KeyError: - self.id_to_times[obj_id] = 1 - return conn + with Monitor.acquire(self): + obj_id = id(conn) + try: + self.id_to_times[obj_id] += 1 + except KeyError: + self.id_to_times[obj_id] = 1 + return conn def release_connection(self, connection: T) -> None: """ @@ -95,20 +101,21 @@ class CPManager(Monitor, Closeable, tp.Generic[T], metaclass=abc.ABCMeta): """ obj_id = id(connection) if self.id_to_times[obj_id] == self.max_cycle_no: - with Monitor.acquire(self), silence_excs(KeyError): - self.spawned_connections -= 1 - del self.id_to_times[obj_id] - - self.teardown_connection(connection) + self._kill_connection(connection) else: try: self.connections.put(connection, False) except queue.Full: - with Monitor.acquire(self), silence_excs(KeyError): - self.spawned_connections -= 1 - del self.id_to_times[obj_id] - self.teardown_connection(connection) + self._kill_connection(connection) + + def _kill_connection(self, connection): + obj_id = id(connection) + with Monitor.acquire(self): + del self.id_to_times[obj_id] + + self.teardown_connection(connection) + @Monitor.synchronized def fail_connection(self, connection: T) -> None: """ Signal that a given connection has been failed diff --git a/tests/test_coding/test_resources.py b/tests/test_coding/test_resources.py index 24182f90da83c7c5f1d0a22e521a892178737089..cde69630b16db7077828bd84fc0488095d613ecd 100644 --- a/tests/test_coding/test_resources.py +++ b/tests/test_coding/test_resources.py @@ -1,4 +1,3 @@ -import time import unittest from concurrent.futures import Future @@ -8,32 +7,68 @@ from satella.coding.resources import CPManager class TestResources(unittest.TestCase): - def test_something(self): + def test_cp_manager(self): + + class Connection: + total_connections = 0 + + def __init__(self): + self.i = 0 + Connection.total_connections += 1 + self.id = Connection.total_connections + self.value_error_emitted = False + + def do(self): + if self.value_error_emitted: + raise RuntimeError('do called despite raising ValueError earlier') + self.i += 1 + if self.i == 3: + self.value_error_emitted = True + raise ValueError() + class InheritCPManager(CPManager): def __init__(self, *args): super().__init__(*args) - self.resources = 0 + self.already_acquired = set() - def create_connection(self): - time.sleep(3) - self.resources += 1 - return lambda: self.resources + 1 + def create_connection(self) -> Connection: + return Connection() + + def acquire_connection(self): + v = super().acquire_connection() + if v.id in self.already_acquired: + raise RuntimeError('Reacquiring an acquired connection') + self.already_acquired.add(v.id) + return v + + def fail_connection(self, connection) -> None: + super().fail_connection(connection) + + def release_connection(self, connection) -> None: + self.already_acquired.remove(connection.id) + super().release_connection(connection) def teardown_connection(self, connection) -> None: ... - cp = InheritCPManager(5, 2) + cp = InheritCPManager(5, 3) conns = [cp.acquire_connection() for _ in range(5)] - @call_in_separate_thread() + @call_in_separate_thread(daemon=True) def do_call(): - conn = cp.acquire_connection() - cp.release_connection(conn) + for _ in range(10): + conn = cp.acquire_connection() + try: + conn.do() + except ValueError: + cp.fail_connection(conn) + + cp.release_connection(conn) ret = do_call() # type: Future cp.release_connection(conns.pop()) - ret.result(timeout=5) + ret.result(timeout=15) while conns: cp.release_connection(conns.pop()) diff --git a/tests/test_db.py b/tests/test_db.py index 1f8735c8587a82262d45d51e21d50099294c52f2..4a7c336b808971058a3e05a1d8f077898a78dfb4 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -13,9 +13,8 @@ class RealConnection: def cursor(self): self.cursor_called += 1 - return Mock() -rpjg hyosp yh self.commit_called += 1 + return Mock() def rollback(self): self.rollback_called += 1 @@ -23,6 +22,9 @@ rpjg hyosp yh def close(self): self.close_called += 1 + def commit(self): + self.commit_called += 1 + class TestDB(unittest.TestCase): def test_db(self):