From 8e51465a42f291ddf8746d438c6b6233afc8dfd9 Mon Sep 17 00:00:00 2001
From: Piotr Maslanka <piotr.maslanka@henrietta.com.pl>
Date: Sun, 10 Dec 2017 03:07:50 +0100
Subject: [PATCH] started adding bundles

---
 firanka/builders.py        |  2 +-
 firanka/exceptions.py      |  2 +-
 firanka/series/__init__.py |  4 +++-
 firanka/series/base.py     | 14 +++++++-------
 firanka/series/bundle.py   | 28 ++++++++++++++++++++++++++++
 tests/test_builder.py      | 11 +++++------
 tests/test_series.py       | 15 +++++++++++++--
 7 files changed, 58 insertions(+), 18 deletions(-)
 create mode 100644 firanka/series/bundle.py

diff --git a/firanka/builders.py b/firanka/builders.py
index 2c6e0d8..e484a7e 100644
--- a/firanka/builders.py
+++ b/firanka/builders.py
@@ -23,7 +23,7 @@ class DiscreteSeriesBuilder(object):
             series = DiscreteSeries([])
 
         if not isinstance(series, DiscreteSeries):
-            raise TypeError('discrete knowledge builder supports only discrete series')
+            raise TypeError(u'discrete knowledge builder supports only discrete series')
 
         self.new_data = {}
         self.domain = series.domain
diff --git a/firanka/exceptions.py b/firanka/exceptions.py
index 7721760..bc1ef14 100644
--- a/firanka/exceptions.py
+++ b/firanka/exceptions.py
@@ -24,6 +24,6 @@ class NotInDomainError(DomainError):
     """
 
     def __init__(self, index, domain):
-        super(NotInDomainError, self).__init__('NotInDomainError: %s not in %s' % (index, domain))
+        super(NotInDomainError, self).__init__(u'NotInDomainError: %s not in %s' % (index, domain))
         self.index = index
         self.domain = domain
diff --git a/firanka/series/__init__.py b/firanka/series/__init__.py
index be43233..5ba9b3c 100644
--- a/firanka/series/__init__.py
+++ b/firanka/series/__init__.py
@@ -2,6 +2,7 @@
 from __future__ import absolute_import
 
 from .base import DiscreteSeries, Series
+from .bundle import SeriesBundle
 from .function import FunctionSeries
 from .interpolations import LinearInterpolationSeries, \
     SCALAR_LINEAR_INTERPOLATOR
@@ -12,6 +13,7 @@ __all__ = [
     'DiscreteSeries',
     'ModuloSeries',
     'Series',
-    'SCALAR_LINEAR_INTERPOLATOR',
     'LinearInterpolationSeries',
+    'SCALAR_LINEAR_INTERPOLATOR',
+    'SeriesBundle',
 ]
diff --git a/firanka/series/base.py b/firanka/series/base.py
index 381e749..9b82918 100644
--- a/firanka/series/base.py
+++ b/firanka/series/base.py
@@ -46,7 +46,7 @@ class Series(object):
             return self._get_for(item)
 
     def _get_for(self, item):
-        raise NotImplementedError('This is abstract, override me!')
+        raise NotImplementedError(u'This is abstract, override me!')
 
     def eval_points(self, points):
         """
@@ -90,7 +90,7 @@ class Series(object):
         :param fun: callable(t: float, v1: any, v2: any) => value
         :return: new Series instance
         """
-        assert _has_arguments(fun, 3), 'Callable to join needs 3 arguments'
+        assert _has_arguments(fun, 3), u'Callable to join needs 3 arguments'
 
         return JoinedSeries(self, series, fun)
 
@@ -122,10 +122,10 @@ class DiscreteSeries(Series):
 
         if len(data) > 0:
             if self.domain.start < data[0][0]:
-                raise DomainError('some domain space is not covered by definition!')
+                raise DomainError(u'some domain space is not covered by definition!')
 
     def apply(self, fun):
-        assert _has_arguments(fun, 2), 'fun must have at least 2 arguments'
+        assert _has_arguments(fun, 2), u'fun must have at least 2 arguments'
 
         return DiscreteSeries([(k, fun(k, v)) for k, v in self.data],
                               self.domain)
@@ -138,7 +138,7 @@ class DiscreteSeries(Series):
             if k <= item:
                 return v
 
-        raise RuntimeError('should never happen')
+        raise RuntimeError(u'should never happen')
 
     def translate(self, x):
         return DiscreteSeries([(k + x, v) for k, v in self.data],
@@ -198,7 +198,7 @@ class DiscreteSeries(Series):
         :param fun:
         :return:
         """
-        assert _has_arguments(fun, 3), 'fun must have at least 3 arguments!'
+        assert _has_arguments(fun, 3), u'fun must have at least 3 arguments!'
 
         new_domain = self.domain.intersection(series.domain)
 
@@ -261,7 +261,7 @@ class JoinedSeries(Series):
 
     def __init__(self, ser1, ser2, op, *args, **kwargs):
         """:type op: callable(time: float, v1, v2: any) -> v"""
-        assert _has_arguments(op, 3), 'op must have 3 arguments'
+        assert _has_arguments(op, 3), u'op must have 3 arguments'
 
         super(JoinedSeries, self).__init__(ser1.domain.intersection(ser2.domain), *args, **kwargs)
         self.ser1 = ser1
diff --git a/firanka/series/bundle.py b/firanka/series/bundle.py
new file mode 100644
index 0000000..8090f56
--- /dev/null
+++ b/firanka/series/bundle.py
@@ -0,0 +1,28 @@
+# coding=UTF-8
+from __future__ import print_function, absolute_import, division
+
+import functools
+import logging
+
+logger = logging.getLogger(__name__)
+
+from .base import Series
+from ..intervals import REAL_SET
+
+
+class SeriesBundle(Series):
+    """
+    Bundles a bunch of series together, returning a list from their outputs
+    """
+
+    def __init__(self, *series):
+        domain = functools.reduce(lambda x, y: x.intersection(y),
+                                  (p.domain for p in series),
+                                  REAL_SET)
+
+        super(SeriesBundle, self).__init__(domain)
+
+        self.series = series
+
+    def _get_for(self, item):
+        return [s._get_for(item) for s in self.series]
diff --git a/tests/test_builder.py b/tests/test_builder.py
index 9188215..d84e36c 100644
--- a/tests/test_builder.py
+++ b/tests/test_builder.py
@@ -1,6 +1,6 @@
 # coding=UTF-8
 from __future__ import print_function, absolute_import, division
-import six
+
 import unittest
 
 from firanka.builders import DiscreteSeriesBuilder
@@ -9,8 +9,7 @@ from firanka.series import DiscreteSeries
 
 class TestBuilder(unittest.TestCase):
     def test_t1(self):
-
-        ser = DiscreteSeries([(0,1), (1,2)])
+        ser = DiscreteSeries([(0, 1), (1, 2)])
 
         kb = DiscreteSeriesBuilder(ser)
 
@@ -22,7 +21,7 @@ class TestBuilder(unittest.TestCase):
         s2 = kb.as_series()
 
         self.assertTrue(s2.domain, '<-1;3>')
-        self.assertEqual(s2.data,[(-1,6), (0,2), (1,2), (3,4)])
+        self.assertEqual(s2.data, [(-1, 6), (0, 2), (1, 2), (3, 4)])
 
     def test_exnihilo(self):
         kb = DiscreteSeriesBuilder()
@@ -32,6 +31,6 @@ class TestBuilder(unittest.TestCase):
 
         s = kb.as_series()
 
-        self.assertEqual(s[0],0)
-        self.assertEqual(s[1],1)
+        self.assertEqual(s[0], 0)
+        self.assertEqual(s[1], 1)
         self.assertEqual(s.domain, '<0;1>')
diff --git a/tests/test_series.py b/tests/test_series.py
index c947e31..a706e92 100644
--- a/tests/test_series.py
+++ b/tests/test_series.py
@@ -7,7 +7,7 @@ import unittest
 from firanka.exceptions import NotInDomainError, DomainError
 from firanka.intervals import Interval
 from firanka.series import DiscreteSeries, FunctionSeries, ModuloSeries, \
-    LinearInterpolationSeries, Series
+    LinearInterpolationSeries, Series, SeriesBundle
 
 NOOP = lambda x: x
 
@@ -125,7 +125,7 @@ class TestDiscreteSeries(unittest.TestCase):
     def test_discretize(self):
         # note the invalid data for covering this domain
         self.assertRaises(DomainError, lambda: FunctionSeries(lambda x: x ** 2,
-                                                             '<-10;10)').discretize(
+                                                              '<-10;10)').discretize(
             [0, 1, 2, 3, 4, 5], '(-1;6)'))
 
         self.assertRaises(NotInDomainError, lambda: FunctionSeries(lambda x: x ** 2,
@@ -230,3 +230,14 @@ class TestLinearInterpolation(unittest.TestCase):
     def test_conf(self):
         self.assertRaises(TypeError, lambda: LinearInterpolationSeries(
             FunctionSeries(NOOP, '<0;3)')))
+
+
+class TestBundles(unittest.TestCase):
+    def test_base(self):
+        s = SeriesBundle(
+            DiscreteSeries([(0, 1), (1, 1), (2, 1)], '<0;inf)'),
+            DiscreteSeries([(0, 2), (1, 2), (2, 2), (3, 4)], '<0;inf)'),
+        )
+
+        self.assertEqual(s[0], [1, 2])
+        self.assertEqual(s[3], [1, 4])
-- 
GitLab