diff --git a/docs/db.rst b/docs/db.rst index 2df2fcb7d14341cc277f7551a1ae6ce15f5de95a..3aa17d8810bd458dbb56fbcd020832a21e14b7c2 100644 --- a/docs/db.rst +++ b/docs/db.rst @@ -7,3 +7,8 @@ So enjoy! .. autoclass:: satella.db.transaction :members: + +You might use is also a context decorator, eg. + +>>> @transaction(conn) +>>> def do_transaction(): \ No newline at end of file diff --git a/satella/__init__.py b/satella/__init__.py index dc0f3a2ebad8cf826d652b66e1eba059f03dcea5..076535da917296f9349c97bba67d5256be99a633 100644 --- a/satella/__init__.py +++ b/satella/__init__.py @@ -1 +1 @@ -__version__ = '2.20.3a1' +__version__ = '2.20.3' diff --git a/satella/db.py b/satella/db.py index 13b5d08f190fd62180f6a2b14ea263ea4eec3f5d..7f61f39d892a94f29117397ab6dc2d6d848d6cf7 100644 --- a/satella/db.py +++ b/satella/db.py @@ -1,5 +1,8 @@ +import inspect import logging +from satella.coding import wraps + logger = logging.getLogger(__name__) @@ -16,21 +19,43 @@ class transaction: Leaving the context manager will automatically close the cursor for you. - :param connection: the connection object to use + >>> def conn_getter_function() -> connection: + >>> .... + >>> @transaction(conn_getter_function) + >>> .... + + The same syntax can be used, if you session depends eg. on a thread. + + :param connection_or_getter: the connection object to use, or a callable/0, that called with + this thread will provide us with a connection :param close_the_connection_after: whether the connection should be closed after use, False by default :param log_exception: whether to log an exception if it happens """ - def __init__(self, connection, close_the_connection_after: bool = False, + def __init__(self, connection_or_getter, close_the_connection_after: bool = False, log_exception: bool = True): - self.connection = connection + self._connection = connection_or_getter self.close_the_connection_after = close_the_connection_after self.cursor = None self.log_exception = log_exception + def __call__(self, fun): + @wraps(fun) + def inner(*args, **kwargs): + with self: + return fun(*args, **kwargs) + return inner + def __enter__(self): self.cursor = self.connection.cursor() return self.cursor + @property + def connection(self): + if inspect.isfunction(self.connection): + return self._connection() + else: + return self._connection + def __exit__(self, exc_type, exc_val, exc_tb): if exc_val is None: self.connection.commit()