From 7fef38cb941dcdaf1fa4450ddcec26792058971a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20Ma=C5=9Blanka?= <piotr.maslanka@ericsson.com>
Date: Thu, 4 Apr 2024 13:11:28 +0200
Subject: [PATCH] added wrap_callable_in_context_manager

---
 CHANGELOG.md                         |  1 +
 docs/coding/ctxt_managers.rst        |  2 ++
 satella/__init__.py                  |  2 +-
 satella/coding/__init__.py           |  4 ++--
 satella/coding/ctxt_managers.py      | 25 +++++++++++++++++++++++++
 tests/test_coding/test_decorators.py | 19 ++++++++++++++++++-
 6 files changed, 49 insertions(+), 4 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 26358e1f..398a1c51 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,4 @@
 # v2.25.2
 
+* added wrap_callable_in_context_manager
 
diff --git a/docs/coding/ctxt_managers.rst b/docs/coding/ctxt_managers.rst
index 2bcd52fa..ab613eed 100644
--- a/docs/coding/ctxt_managers.rst
+++ b/docs/coding/ctxt_managers.rst
@@ -3,3 +3,5 @@ Context managers
 
 .. autoclass:: satella.coding.EmptyContextManager
     :members:
+
+.. autofunction:: satella.coding.wrap_callable_in_context_manager
diff --git a/satella/__init__.py b/satella/__init__.py
index dd476395..0f50d43c 100644
--- a/satella/__init__.py
+++ b/satella/__init__.py
@@ -1 +1 @@
-__version__ = '2.25.2a1'
+__version__ = '2.25.2a2'
diff --git a/satella/coding/__init__.py b/satella/coding/__init__.py
index d2965721..6226b1c0 100644
--- a/satella/coding/__init__.py
+++ b/satella/coding/__init__.py
@@ -4,7 +4,7 @@ Just useful objects to make your coding nicer every day
 
 from .algos import merge_dicts
 from .concurrent import Monitor, RMonitor
-from .ctxt_managers import EmptyContextManager
+from .ctxt_managers import EmptyContextManager, wrap_callable_in_context_manager
 from .decorators import precondition, short_none, has_keys, \
     wraps, chain_functions, postcondition, queue_get, auto_adapt_to_methods, \
     attach_arguments, for_argument
@@ -26,7 +26,7 @@ from .recast_exceptions import rethrow_as, silence_excs, catch_exception, log_ex
 
 __all__ = [
     'EmptyContextManager', 'Context', 'length',
-    'assert_equal', 'InequalityReason', 'Inequal',
+    'assert_equal', 'InequalityReason', 'Inequal', 'wrap_callable_in_context_manager',
     'Closeable', 'contains', 'enum_value',
     'expect_exception',
     'overload', 'class_or_instancemethod', 'TypeSignature',
diff --git a/satella/coding/ctxt_managers.py b/satella/coding/ctxt_managers.py
index 7dd1123b..d9928ba6 100644
--- a/satella/coding/ctxt_managers.py
+++ b/satella/coding/ctxt_managers.py
@@ -1,3 +1,6 @@
+from satella.coding.decorators import wraps
+
+
 class EmptyContextManager:
     """
     A context manager that does nothing. Only to support conditional change of context managers,
@@ -21,3 +24,25 @@ class EmptyContextManager:
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         return False
+
+
+def wrap_callable_in_context_manager(clbl, ctxt_mgr, *my_args, **my_kwargs):
+    """
+    Wrap a callable in context manager.
+
+    Roughly equivalent to:
+
+    >>> def inner(*args, **kwargs):
+    >>>     with ctxt_mgr(*my_args, **my_kwargs):
+    >>>         return clbl(*args, **kwargs)
+    >>> return inner
+
+    To be used as:
+
+    >>> clbl = wrap_callable_in_context_manager(lambda y: 5, tracing.start_new_span, 'New span')
+    """
+    @wraps(clbl)
+    def inner(*args, **kwargs):
+        with ctxt_mgr(*my_args, **my_kwargs):
+            return clbl(*args, **kwargs)
+    return inner
diff --git a/tests/test_coding/test_decorators.py b/tests/test_coding/test_decorators.py
index cb69d32b..38ff234a 100644
--- a/tests/test_coding/test_decorators.py
+++ b/tests/test_coding/test_decorators.py
@@ -5,7 +5,7 @@ from socket import socket
 
 import time
 from satella.coding import wraps, chain_functions, postcondition, \
-    log_exceptions, queue_get, precondition, short_none
+    log_exceptions, queue_get, precondition, short_none, wrap_callable_in_context_manager
 from satella.coding.decorators import auto_adapt_to_methods, attach_arguments, \
     execute_before, loop_while, memoize, copy_arguments, replace_argument_if, \
     retry, return_as_list, default_return, transform_result, transform_arguments, \
@@ -18,6 +18,23 @@ logger = logging.getLogger(__name__)
 
 
 class TestDecorators(unittest.TestCase):
+    def test_wrap_ctxt_mgr(self):
+        a = None
+        class CtxtMgr:
+            def __init__(self, value):
+                nonlocal a
+                a = value
+
+            def __enter__(self):
+                return self
+
+            def __exit__(self, exc_type, exc_val, exc_tb):
+                return False
+
+        clbl = lambda y: y
+        clbl = wrap_callable_in_context_manager(clbl, CtxtMgr, 5)
+        clbl(3)
+        self.assertEqual(a, 5)
 
     def test_cached_property(self):
         class Example:
-- 
GitLab