From b0681fef32cb2d42b14a90da9490718c3b95261d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20Ma=C5=9Blanka?= <piotr.maslanka@henrietta.com.pl>
Date: Fri, 25 Sep 2020 21:55:36 +0200
Subject: [PATCH] v2.11.17_a2

* added `wrap_future`
* refactored `satella.cassandra`
* deprecated tracing Cassandra's ResponseFutures directly
---
 CHANGELOG.md                  |  5 ++++
 docs/cassandra.rst            |  5 ++++
 satella/__init__.py           |  2 +-
 satella/cassandra/__init__.py | 52 ++---------------------------------
 satella/cassandra/common.py   |  5 ++++
 satella/cassandra/future.py   | 19 +++++++++++++
 satella/cassandra/parallel.py | 50 +++++++++++++++++++++++++++++++++
 satella/opentracing/trace.py  | 37 ++++++++++++-------------
 tests/test_cassandra.py       | 45 +++++++++++++++++++++++++++++-
 9 files changed, 149 insertions(+), 71 deletions(-)
 create mode 100644 satella/cassandra/common.py
 create mode 100644 satella/cassandra/future.py
 create mode 100644 satella/cassandra/parallel.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9d6aab57..527cc3a5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1 +1,6 @@
 # v2.11.17
+
+* added `wrap_future`
+* refactored `satella.cassandra`
+* deprecated tracing Cassandra's ResponseFutures directly
+
diff --git a/docs/cassandra.rst b/docs/cassandra.rst
index 4f8552e8..051c6ead 100644
--- a/docs/cassandra.rst
+++ b/docs/cassandra.rst
@@ -3,6 +3,11 @@ Cassandra
 
 **This module is available only if you have cassandra-driver installed**
 
+wrap_future
+-----------
+
+.. autofunction:: satella.cassandra.wrap_future
+
 parallel_for
 ------------
 
diff --git a/satella/__init__.py b/satella/__init__.py
index 2e935608..7a3edefe 100644
--- a/satella/__init__.py
+++ b/satella/__init__.py
@@ -1 +1 @@
-__version__ = '2.11.17_a1'
+__version__ = '2.11.17_a2'
diff --git a/satella/cassandra/__init__.py b/satella/cassandra/__init__.py
index 97912507..f5afbe58 100644
--- a/satella/cassandra/__init__.py
+++ b/satella/cassandra/__init__.py
@@ -1,50 +1,4 @@
-import itertools
-import typing as tp
-from collections import namedtuple
+from .parallel import parallel_for
+from .future import wrap_future
 
-
-def parallel_for(cursor, query: tp.Union[tp.List[str], str, 'Statement', tp.List['Statement']],
-                 arguments: tp.Iterable[tuple]) -> tp.Iterator[namedtuple]:
-    """
-    Syntactic sugar for
-
-    >>> futures = []
-    >>> for args in arguments:
-    >>>     futures.append(cursor.execute_async(query, args))
-    >>> for future in futures:
-    >>>     yield future.result()
-
-    If query is a string or a Cassandra Statement, or else
-
-    >>> futures = []
-    >>> for query, args in zip(query, arguments):
-    >>>     futures.append(cursor.execute_async(query, args))
-    >>> for future in futures:
-    >>>     yield future.result()
-
-    Note that if None is encountered in the argument iterable, session.execute() will
-    be called with a single argument. You better have it as a BoundStatement then!
-
-    :param cursor: the Cassandra cursor to use (obtained using connection.session())
-    :param query: base query or a list of queries, if a different one is to be used
-    :param arguments: iterable yielding arguments to use in execute_async
-    """
-    try:
-        from cassandra.query import Statement
-        query_classes = (str, Statement)
-    except ImportError:
-        query_classes = str
-
-    if isinstance(query, query_classes):
-        query = itertools.repeat(query)
-
-    futures = []
-    for query, args in zip(query, arguments):
-        if args is None:
-            future = cursor.execute_async(query)
-        else:
-            future = cursor.execute_async(query, args)
-        futures.append(future)
-
-    for future in futures:
-        yield future.result()
+__all__ = ['wrap_future', 'parallel_for']
diff --git a/satella/cassandra/common.py b/satella/cassandra/common.py
new file mode 100644
index 00000000..448f63ed
--- /dev/null
+++ b/satella/cassandra/common.py
@@ -0,0 +1,5 @@
+try:
+    from cassandra.cluster import ResponseFuture
+except ImportError:
+    class ResponseFuture:
+        pass
diff --git a/satella/cassandra/future.py b/satella/cassandra/future.py
new file mode 100644
index 00000000..1a3f9212
--- /dev/null
+++ b/satella/cassandra/future.py
@@ -0,0 +1,19 @@
+from concurrent.futures import Future
+
+from .common import ResponseFuture
+
+
+def wrap_future(future: ResponseFuture) -> Future:
+    """
+    Convert a Cassandra's future to a normal Python future.
+
+    :param future: cassandra future to wrap
+    :return: a standard Python future
+    """
+
+    fut = Future()
+    fut.set_running_or_notify_cancel()
+    future.add_callback(lambda result: fut.set_result(result))
+    future.add_errback(lambda exception: fut.set_exception(exception))
+    return fut
+
diff --git a/satella/cassandra/parallel.py b/satella/cassandra/parallel.py
new file mode 100644
index 00000000..97912507
--- /dev/null
+++ b/satella/cassandra/parallel.py
@@ -0,0 +1,50 @@
+import itertools
+import typing as tp
+from collections import namedtuple
+
+
+def parallel_for(cursor, query: tp.Union[tp.List[str], str, 'Statement', tp.List['Statement']],
+                 arguments: tp.Iterable[tuple]) -> tp.Iterator[namedtuple]:
+    """
+    Syntactic sugar for
+
+    >>> futures = []
+    >>> for args in arguments:
+    >>>     futures.append(cursor.execute_async(query, args))
+    >>> for future in futures:
+    >>>     yield future.result()
+
+    If query is a string or a Cassandra Statement, or else
+
+    >>> futures = []
+    >>> for query, args in zip(query, arguments):
+    >>>     futures.append(cursor.execute_async(query, args))
+    >>> for future in futures:
+    >>>     yield future.result()
+
+    Note that if None is encountered in the argument iterable, session.execute() will
+    be called with a single argument. You better have it as a BoundStatement then!
+
+    :param cursor: the Cassandra cursor to use (obtained using connection.session())
+    :param query: base query or a list of queries, if a different one is to be used
+    :param arguments: iterable yielding arguments to use in execute_async
+    """
+    try:
+        from cassandra.query import Statement
+        query_classes = (str, Statement)
+    except ImportError:
+        query_classes = str
+
+    if isinstance(query, query_classes):
+        query = itertools.repeat(query)
+
+    futures = []
+    for query, args in zip(query, arguments):
+        if args is None:
+            future = cursor.execute_async(query)
+        else:
+            future = cursor.execute_async(query, args)
+        futures.append(future)
+
+    for future in futures:
+        yield future.result()
diff --git a/satella/opentracing/trace.py b/satella/opentracing/trace.py
index fd11edff..52ae0d6c 100644
--- a/satella/opentracing/trace.py
+++ b/satella/opentracing/trace.py
@@ -1,6 +1,10 @@
 import typing as tp
+import sys
+import warnings
 from concurrent.futures import Future
 
+from ..cassandra.future import wrap_future
+from ..cassandra.common import ResponseFuture
 from satella.coding.decorators import wraps
 
 try:
@@ -9,12 +13,6 @@ except ImportError:
     class Span:
         pass
 
-try:
-    from cassandra.cluster import ResponseFuture
-except ImportError:
-    class ResponseFuture:
-        pass
-
 
 def trace_function(tracer, name: str, tags: tp.Optional[dict] = None):
     """
@@ -43,18 +41,17 @@ def trace_future(future: tp.Union[ResponseFuture, Future], span: Span):
     :param span: span to close on future's completion
     """
     if isinstance(future, ResponseFuture):
-        def close_exception(exc):
+        warnings.warn('Tracing Cassandra futures is deprecated. Use wrap_future() to '
+                      'convert it to a standard Python future. This feature will be '
+                      'deprecated in Satella 3.x', DeprecationWarning)
+        future = wrap_future(future)
+
+    def close_future(fut):
+        exc = fut.exception()
+        if exc is not None:
             # noinspection PyProtectedMember
-            Span._on_error(span, type(exc), exc, '<unavailable>')
-            span.finish()
-
-        future.add_callback(span.finish)
-        future.add_errback(close_exception)
-    else:
-        def close_future(fut):
-            exc = fut.exception()
-            if exc is not None:
-                # noinspection PyProtectedMember
-                Span._on_error(span, type(exc), exc, '<unavailable>')
-            span.finish()
-        future.add_done_callback(close_future)
+            exc_type, value, traceback = sys.exc_info()
+            Span._on_error(span, exc_type, value, traceback)
+        span.finish()
+
+    future.add_done_callback(close_future)
diff --git a/tests/test_cassandra.py b/tests/test_cassandra.py
index 1e59041a..fac49118 100644
--- a/tests/test_cassandra.py
+++ b/tests/test_cassandra.py
@@ -1,8 +1,51 @@
-from satella.cassandra import parallel_for
+from satella.coding.concurrent import CallableGroup
+
+from satella.cassandra import parallel_for, wrap_future
 import unittest
 
 
 class TestCassandra(unittest.TestCase):
+    def test_wrap_future(self):
+        class MockCassandraFuture:
+            def __init__(self):
+                self.value = None
+                self.callbacks = CallableGroup()
+                self.errbacks = CallableGroup()
+
+            def add_callback(self, callback):
+                self.callbacks.add(callback)
+
+            def add_errback(self, errback):
+                self.errbacks.add(errback)
+
+            def set_result(self, x):
+                self.value = x
+                if isinstance(x, Exception):
+                    self.errbacks(x)
+                else:
+                    self.callbacks(x)
+
+        mcf = MockCassandraFuture()
+        wrapped = wrap_future(mcf)
+        a = {}
+
+        def on_done(fut):
+            if fut.exception() is None:
+                a['success'] = True
+            else:
+                a['failure'] = True
+
+        wrapped.add_done_callback(on_done)
+
+        mcf.set_result(None)
+        self.assertTrue(a['success'])
+
+        mcf = MockCassandraFuture()
+        wrapped = wrap_future(mcf)
+        wrapped.add_done_callback(on_done)
+        mcf.set_result(Exception())
+        self.assertTrue(a['failure'])
+
     def test_parallel_for(self):
         class Cursor:
             def __init__(self):
-- 
GitLab