From ddbaa811c0c29835d45b6aa39d708d5b830ac14f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20Ma=C5=9Blanka?= <piotr.maslanka@henrietta.com.pl> Date: Sat, 28 May 2022 20:50:55 +0200 Subject: [PATCH] added transaction idiom --- CHANGELOG.md | 11 +++-------- docs/db.rst | 9 +++++++++ docs/index.rst | 1 + satella/__init__.py | 2 +- satella/db.py | 48 +++++++++++++++++++++++++++++++++++++++++++++ tests/test_db.py | 40 +++++++++++++++++++++++++++++++++++++ 6 files changed, 102 insertions(+), 9 deletions(-) create mode 100644 docs/db.rst create mode 100644 satella/db.py create mode 100644 tests/test_db.py diff --git a/CHANGELOG.md b/CHANGELOG.md index c87dc97a..bc4504eb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,4 @@ -# v2.19.0 +# v2.19.2 + +* added `db_call` -* unit tests migrated to CircleCI -* added __len__ to FutureCollection -* fixed a bug in DictionaryEQAble -* fixed a bug in ListDeleter -* minor breaking change: changed semantics of ListDeleter -* added `CPManager` -* added `SetZip` \ No newline at end of file diff --git a/docs/db.rst b/docs/db.rst new file mode 100644 index 00000000..2df2fcb7 --- /dev/null +++ b/docs/db.rst @@ -0,0 +1,9 @@ +Python DB API 2 +=============== + +However imperfect may it be, it's here to stay. + +So enjoy! + +.. autoclass:: satella.db.transaction + :members: diff --git a/docs/index.rst b/docs/index.rst index 601fa559..d8d58cad 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -28,6 +28,7 @@ Visit the project's page at GitHub_! instrumentation/metrics exception_handling dao + db json posix import diff --git a/satella/__init__.py b/satella/__init__.py index 6d2db50c..a910817d 100644 --- a/satella/__init__.py +++ b/satella/__init__.py @@ -1 +1 @@ -__version__ = '2.19.0' +__version__ = '2.20.0' diff --git a/satella/db.py b/satella/db.py new file mode 100644 index 00000000..1d3ad8ef --- /dev/null +++ b/satella/db.py @@ -0,0 +1,48 @@ +import logging + +logger = logging.getLogger(__name__) + + +class transaction: + """ + A context manager for wrapping a transaction and getting a cursor from the Python DB API 2. + + Use it as a context manager. commit and rollback will be automatically called for you. + + Use like: + + >>> with transaction(conn) as cur: + >>> cur.execute('DROP DATABASE') + + Leaving the context manager will automatically close the cursor for you. + + :param connection: the connection object to use + :param close_the_connection_after: whether the connection should be closed after use, False by default + :param log_exception: whether to log an exception if it happens + """ + def __init__(self, connection, close_the_connection_after: bool = False, + log_exception: bool = True): + self.connection = connection + self.close_the_connection_after = close_the_connection_after + self.cursor = None + self.log_exception = log_exception + + def __enter__(self): + self.cursor = self.connection.cursor() + return self.cursor() + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_val is None: + self.connection.commit() + else: + self.connection.rollback() + + if self.log_exception: + logger.error('Exception occurred of type %s', exc_type, exc_info=exc_val) + + self.cursor.close() + + if self.close_the_connection_after: + self.connection.close() + + return False diff --git a/tests/test_db.py b/tests/test_db.py new file mode 100644 index 00000000..c5d86b02 --- /dev/null +++ b/tests/test_db.py @@ -0,0 +1,40 @@ +import unittest +from unittest.mock import Mock + +from satella.db import transaction + + +class TestDB(unittest.TestCase): + def test_something(self): + class RealConnection: + def __init__(self): + self.cursor_called = 0 + self.commit_called = 0 + self.rollback_called = 0 + self.close_called = 0 + + def cursor(self): + self.cursor_called += 1 + return Mock() + + def commit(self): + self.commit_called += 1 + + def rollback(self): + self.rollback_called += 1 + + def close(self): + self.close_called += 1 + + conn = RealConnection() + a = transaction(conn) + with a as cur: + self.assertEqual(conn.cursor_called, 1) + cur.execute('TEST') + self.assertEqual(conn.commit_called, 1) + try: + with a as cur: + raise ValueError() + except ValueError: + self.assertEqual(conn.commit_called, 1) + self.assertEqual(conn.rollback_called, 1) -- GitLab