From b4ed382a87604cf07b4b9649fc7b788109156141 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20Ma=C5=9Blanka?= <piotr.maslanka@henrietta.com.pl>
Date: Sun, 29 May 2022 22:46:05 +0200
Subject: [PATCH] v2.30.3

---
 docs/db.rst         |  5 +++++
 satella/__init__.py |  2 +-
 satella/db.py       | 31 ++++++++++++++++++++++++++++---
 3 files changed, 34 insertions(+), 4 deletions(-)

diff --git a/docs/db.rst b/docs/db.rst
index 2df2fcb7..3aa17d88 100644
--- a/docs/db.rst
+++ b/docs/db.rst
@@ -7,3 +7,8 @@ So enjoy!
 
 .. autoclass:: satella.db.transaction
     :members:
+
+You might use is also a context decorator, eg.
+
+>>> @transaction(conn)
+>>>    def do_transaction():
\ No newline at end of file
diff --git a/satella/__init__.py b/satella/__init__.py
index dc0f3a2e..076535da 100644
--- a/satella/__init__.py
+++ b/satella/__init__.py
@@ -1 +1 @@
-__version__ = '2.20.3a1'
+__version__ = '2.20.3'
diff --git a/satella/db.py b/satella/db.py
index 13b5d08f..7f61f39d 100644
--- a/satella/db.py
+++ b/satella/db.py
@@ -1,5 +1,8 @@
+import inspect
 import logging
 
+from satella.coding import wraps
+
 logger = logging.getLogger(__name__)
 
 
@@ -16,21 +19,43 @@ class transaction:
 
     Leaving the context manager will automatically close the cursor for you.
 
-    :param connection: the connection object to use
+    >>> def conn_getter_function() -> connection:
+    >>>     ....
+    >>> @transaction(conn_getter_function)
+    >>>     ....
+
+    The same syntax can be used, if you session depends eg. on a thread.
+
+    :param connection_or_getter: the connection object to use, or a callable/0, that called with
+        this thread will provide us with a connection
     :param close_the_connection_after: whether the connection should be closed after use, False by default
     :param log_exception: whether to log an exception if it happens
     """
-    def __init__(self, connection, close_the_connection_after: bool = False,
+    def __init__(self, connection_or_getter, close_the_connection_after: bool = False,
                  log_exception: bool = True):
-        self.connection = connection
+        self._connection = connection_or_getter
         self.close_the_connection_after = close_the_connection_after
         self.cursor = None
         self.log_exception = log_exception
 
+    def __call__(self, fun):
+        @wraps(fun)
+        def inner(*args, **kwargs):
+            with self:
+                return fun(*args, **kwargs)
+        return inner
+
     def __enter__(self):
         self.cursor = self.connection.cursor()
         return self.cursor
 
+    @property
+    def connection(self):
+        if inspect.isfunction(self.connection):
+            return self._connection()
+        else:
+            return self._connection
+
     def __exit__(self, exc_type, exc_val, exc_tb):
         if exc_val is None:
             self.connection.commit()
-- 
GitLab