From ddbaa811c0c29835d45b6aa39d708d5b830ac14f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20Ma=C5=9Blanka?= <piotr.maslanka@henrietta.com.pl>
Date: Sat, 28 May 2022 20:50:55 +0200
Subject: [PATCH] added transaction idiom

---
 CHANGELOG.md        | 11 +++--------
 docs/db.rst         |  9 +++++++++
 docs/index.rst      |  1 +
 satella/__init__.py |  2 +-
 satella/db.py       | 48 +++++++++++++++++++++++++++++++++++++++++++++
 tests/test_db.py    | 40 +++++++++++++++++++++++++++++++++++++
 6 files changed, 102 insertions(+), 9 deletions(-)
 create mode 100644 docs/db.rst
 create mode 100644 satella/db.py
 create mode 100644 tests/test_db.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c87dc97a..bc4504eb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,9 +1,4 @@
-# v2.19.0
+# v2.19.2
+
+* added `db_call`
 
-* unit tests migrated to CircleCI
-* added __len__ to FutureCollection
-* fixed a bug in DictionaryEQAble
-* fixed a bug in ListDeleter
-* minor breaking change: changed semantics of ListDeleter
-* added `CPManager`
-* added `SetZip`
\ No newline at end of file
diff --git a/docs/db.rst b/docs/db.rst
new file mode 100644
index 00000000..2df2fcb7
--- /dev/null
+++ b/docs/db.rst
@@ -0,0 +1,9 @@
+Python DB API 2
+===============
+
+However imperfect may it be, it's here to stay.
+
+So enjoy!
+
+.. autoclass:: satella.db.transaction
+    :members:
diff --git a/docs/index.rst b/docs/index.rst
index 601fa559..d8d58cad 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -28,6 +28,7 @@ Visit the project's page at GitHub_!
            instrumentation/metrics
            exception_handling
            dao
+           db
            json
            posix
            import
diff --git a/satella/__init__.py b/satella/__init__.py
index 6d2db50c..a910817d 100644
--- a/satella/__init__.py
+++ b/satella/__init__.py
@@ -1 +1 @@
-__version__ = '2.19.0'
+__version__ = '2.20.0'
diff --git a/satella/db.py b/satella/db.py
new file mode 100644
index 00000000..1d3ad8ef
--- /dev/null
+++ b/satella/db.py
@@ -0,0 +1,48 @@
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class transaction:
+    """
+    A context manager for wrapping a transaction and getting a cursor from the Python DB API 2.
+
+    Use it as a context manager. commit and rollback will be automatically called for you.
+
+    Use like:
+
+    >>> with transaction(conn) as cur:
+    >>>   cur.execute('DROP DATABASE')
+
+    Leaving the context manager will automatically close the cursor for you.
+
+    :param connection: the connection object to use
+    :param close_the_connection_after: whether the connection should be closed after use, False by default
+    :param log_exception: whether to log an exception if it happens
+    """
+    def __init__(self, connection, close_the_connection_after: bool = False,
+                 log_exception: bool = True):
+        self.connection = connection
+        self.close_the_connection_after = close_the_connection_after
+        self.cursor = None
+        self.log_exception = log_exception
+
+    def __enter__(self):
+        self.cursor = self.connection.cursor()
+        return self.cursor()
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if exc_val is None:
+            self.connection.commit()
+        else:
+            self.connection.rollback()
+
+        if self.log_exception:
+            logger.error('Exception occurred of type %s', exc_type, exc_info=exc_val)
+
+        self.cursor.close()
+
+        if self.close_the_connection_after:
+            self.connection.close()
+
+        return False
diff --git a/tests/test_db.py b/tests/test_db.py
new file mode 100644
index 00000000..c5d86b02
--- /dev/null
+++ b/tests/test_db.py
@@ -0,0 +1,40 @@
+import unittest
+from unittest.mock import Mock
+
+from satella.db import transaction
+
+
+class TestDB(unittest.TestCase):
+    def test_something(self):
+        class RealConnection:
+            def __init__(self):
+                self.cursor_called = 0
+                self.commit_called = 0
+                self.rollback_called = 0
+                self.close_called = 0
+
+            def cursor(self):
+                self.cursor_called += 1
+                return Mock()
+
+            def commit(self):
+                self.commit_called += 1
+
+            def rollback(self):
+                self.rollback_called += 1
+
+            def close(self):
+                self.close_called += 1
+
+        conn = RealConnection()
+        a = transaction(conn)
+        with a as cur:
+            self.assertEqual(conn.cursor_called, 1)
+            cur.execute('TEST')
+        self.assertEqual(conn.commit_called, 1)
+        try:
+            with a as cur:
+                raise ValueError()
+        except ValueError:
+            self.assertEqual(conn.commit_called, 1)
+            self.assertEqual(conn.rollback_called, 1)
-- 
GitLab