Skip to content
Snippets Groups Projects
Commit 24e61b1c authored by Piotr Maślanka's avatar Piotr Maślanka
Browse files

fix CPManager, v2.21.2

parent da5fbb68
No related branches found
Tags v2.21.2
No related merge requests found
# v2.21.2 # v2.21.2
* fixed that CPManager, this time with unit tests to back them up
* added option to invalidate all connections on this CPManager
\ No newline at end of file
__version__ = '2.21.2a1' __version__ = '2.21.2'
...@@ -54,9 +54,16 @@ class CPManager(Monitor, Closeable, tp.Generic[T], metaclass=abc.ABCMeta): ...@@ -54,9 +54,16 @@ class CPManager(Monitor, Closeable, tp.Generic[T], metaclass=abc.ABCMeta):
def close(self) -> None: def close(self) -> None:
if super().close(): if super().close():
self.terminating = True self.terminating = True
while self.spawned_connections > 0: self.invalidate()
self.teardown_connection(self.connections.get())
self.spawned_connections -= 1 @Monitor.synchronized
def invalidate(self) -> None:
"""
Close all connections. Connections have to be released first. Object is ready for use after this
"""
while self.spawned_connections > 0:
self.teardown_connection(self.connections.get())
self.spawned_connections -= 1
def acquire_connection(self) -> T: def acquire_connection(self) -> T:
""" """
...@@ -72,20 +79,19 @@ class CPManager(Monitor, Closeable, tp.Generic[T], metaclass=abc.ABCMeta): ...@@ -72,20 +79,19 @@ class CPManager(Monitor, Closeable, tp.Generic[T], metaclass=abc.ABCMeta):
except queue.Empty: except queue.Empty:
while True: while True:
with silence_excs(queue.Empty), Monitor.acquire(self): with silence_excs(queue.Empty), Monitor.acquire(self):
if self.spawned_connections >= self.max_number: if self.connections.qsize() >= self.max_number:
conn = self.connections.get(False, 5) conn = self.connections.get(False, 5)
break break
elif self.spawned_connections < self.max_number: elif self.connections.qsize() < self.max_number:
conn = self.create_connection() conn = self.create_connection()
self.spawned_connections += 1
self.connections.put(conn)
break break
obj_id = id(conn) with Monitor.acquire(self):
try: obj_id = id(conn)
self.id_to_times[obj_id] += 1 try:
except KeyError: self.id_to_times[obj_id] += 1
self.id_to_times[obj_id] = 1 except KeyError:
return conn self.id_to_times[obj_id] = 1
return conn
def release_connection(self, connection: T) -> None: def release_connection(self, connection: T) -> None:
""" """
...@@ -95,20 +101,21 @@ class CPManager(Monitor, Closeable, tp.Generic[T], metaclass=abc.ABCMeta): ...@@ -95,20 +101,21 @@ class CPManager(Monitor, Closeable, tp.Generic[T], metaclass=abc.ABCMeta):
""" """
obj_id = id(connection) obj_id = id(connection)
if self.id_to_times[obj_id] == self.max_cycle_no: if self.id_to_times[obj_id] == self.max_cycle_no:
with Monitor.acquire(self), silence_excs(KeyError): self._kill_connection(connection)
self.spawned_connections -= 1
del self.id_to_times[obj_id]
self.teardown_connection(connection)
else: else:
try: try:
self.connections.put(connection, False) self.connections.put(connection, False)
except queue.Full: except queue.Full:
with Monitor.acquire(self), silence_excs(KeyError): self._kill_connection(connection)
self.spawned_connections -= 1
del self.id_to_times[obj_id] def _kill_connection(self, connection):
self.teardown_connection(connection) obj_id = id(connection)
with Monitor.acquire(self):
del self.id_to_times[obj_id]
self.teardown_connection(connection)
@Monitor.synchronized
def fail_connection(self, connection: T) -> None: def fail_connection(self, connection: T) -> None:
""" """
Signal that a given connection has been failed Signal that a given connection has been failed
......
import time
import unittest import unittest
from concurrent.futures import Future from concurrent.futures import Future
...@@ -8,32 +7,68 @@ from satella.coding.resources import CPManager ...@@ -8,32 +7,68 @@ from satella.coding.resources import CPManager
class TestResources(unittest.TestCase): class TestResources(unittest.TestCase):
def test_something(self): def test_cp_manager(self):
class Connection:
total_connections = 0
def __init__(self):
self.i = 0
Connection.total_connections += 1
self.id = Connection.total_connections
self.value_error_emitted = False
def do(self):
if self.value_error_emitted:
raise RuntimeError('do called despite raising ValueError earlier')
self.i += 1
if self.i == 3:
self.value_error_emitted = True
raise ValueError()
class InheritCPManager(CPManager): class InheritCPManager(CPManager):
def __init__(self, *args): def __init__(self, *args):
super().__init__(*args) super().__init__(*args)
self.resources = 0 self.already_acquired = set()
def create_connection(self): def create_connection(self) -> Connection:
time.sleep(3) return Connection()
self.resources += 1
return lambda: self.resources + 1 def acquire_connection(self):
v = super().acquire_connection()
if v.id in self.already_acquired:
raise RuntimeError('Reacquiring an acquired connection')
self.already_acquired.add(v.id)
return v
def fail_connection(self, connection) -> None:
super().fail_connection(connection)
def release_connection(self, connection) -> None:
self.already_acquired.remove(connection.id)
super().release_connection(connection)
def teardown_connection(self, connection) -> None: def teardown_connection(self, connection) -> None:
... ...
cp = InheritCPManager(5, 2) cp = InheritCPManager(5, 3)
conns = [cp.acquire_connection() for _ in range(5)] conns = [cp.acquire_connection() for _ in range(5)]
@call_in_separate_thread() @call_in_separate_thread(daemon=True)
def do_call(): def do_call():
conn = cp.acquire_connection() for _ in range(10):
cp.release_connection(conn) conn = cp.acquire_connection()
try:
conn.do()
except ValueError:
cp.fail_connection(conn)
cp.release_connection(conn)
ret = do_call() # type: Future ret = do_call() # type: Future
cp.release_connection(conns.pop()) cp.release_connection(conns.pop())
ret.result(timeout=5) ret.result(timeout=15)
while conns: while conns:
cp.release_connection(conns.pop()) cp.release_connection(conns.pop())
......
...@@ -13,9 +13,8 @@ class RealConnection: ...@@ -13,9 +13,8 @@ class RealConnection:
def cursor(self): def cursor(self):
self.cursor_called += 1 self.cursor_called += 1
return Mock()
rpjg hyosp yh
self.commit_called += 1 self.commit_called += 1
return Mock()
def rollback(self): def rollback(self):
self.rollback_called += 1 self.rollback_called += 1
...@@ -23,6 +22,9 @@ rpjg hyosp yh ...@@ -23,6 +22,9 @@ rpjg hyosp yh
def close(self): def close(self):
self.close_called += 1 self.close_called += 1
def commit(self):
self.commit_called += 1
class TestDB(unittest.TestCase): class TestDB(unittest.TestCase):
def test_db(self): def test_db(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment