From 6dd302b54de56ae05d75a066e5e93573ffefccb8 Mon Sep 17 00:00:00 2001
From: Piotr Maslanka <piotr.maslanka@henrietta.com.pl>
Date: Sat, 9 Dec 2017 09:39:48 +0100
Subject: [PATCH] minor fixes

---
 firanka/ranges.py      |  7 ++--
 firanka/series/base.py | 79 ++++++++++++++++++++++--------------------
 tests/test_series.py   |  9 +++++
 3 files changed, 56 insertions(+), 39 deletions(-)

diff --git a/firanka/ranges.py b/firanka/ranges.py
index 3e06c1e..f0bbd7d 100644
--- a/firanka/ranges.py
+++ b/firanka/ranges.py
@@ -29,8 +29,11 @@ class Range(object):
     """
 
     def translate(self, x):
-        return Range(self.start + x, self.stop + x, self.left_inc,
-                     self.right_inc)
+        if x == 0:
+            return self
+        else:
+            return Range(self.start + x, self.stop + x, self.left_inc,
+                         self.right_inc)
 
     def __init__(self, *args):
         """
diff --git a/firanka/series/base.py b/firanka/series/base.py
index 7c0df92..a67bef9 100644
--- a/firanka/series/base.py
+++ b/firanka/series/base.py
@@ -35,8 +35,7 @@ class Series(object):
             if item not in self.domain:
                 raise NotInDomainError('slicing beyond series domain')
 
-            newdom = self.domain.intersection(item)
-            return SlicedSeries(self, newdom)
+            return AlteredSeries(self, domain=self.domain.intersection(item))
         else:
             if item not in self.domain:
                 raise NotInDomainError('item not in domain')
@@ -60,7 +59,7 @@ class Series(object):
         :param fun: callable/1 => 1
         :return: Series instance
         """
-        return AppliedAndTranslatedSeries(self, applyfun=lambda k, v: fun(v))
+        return AlteredSeries(self, applyfun=lambda k, v: fun(v))
 
     def apply_with_indices(self, fun):
         """
@@ -68,7 +67,7 @@ class Series(object):
         :param fun: callable(index: float, value: any) => 1
         :return: Series instance
         """
-        return AppliedAndTranslatedSeries(self, applyfun=fun)
+        return AlteredSeries(self, applyfun=fun)
 
     def discretize(self, points, domain=None):
         """
@@ -95,6 +94,16 @@ class Series(object):
         :param fun: callable/2 => value
         :return: new Series instance
         """
+        return JoinedSeries(self, series, lambda t, v1, v2: fun(v1, v2))
+
+    def join_with_indices(self, series, fun):
+        """
+        Return a new series with values of fun(index, v1, v2)
+
+        :param series: series to join against
+        :param fun: callable/3 => value
+        :return: new Series instance
+        """
         return JoinedSeries(self, series, fun)
 
     def translate(self, x):
@@ -103,7 +112,7 @@ class Series(object):
         :param x: a float
         :return: new Series instance
         """
-        return AppliedAndTranslatedSeries(self, x=x)
+        return AlteredSeries(self, x=x)
 
 
 class DiscreteSeries(Series):
@@ -148,7 +157,7 @@ class DiscreteSeries(Series):
         b = series.data[::-1]
 
         ptr = self.domain.start
-        c = [(ptr, fun(self._get_for(ptr), series._get_for(ptr)))]
+        c = [(ptr, fun(ptr, self._get_for(ptr), series._get_for(ptr)))]
 
         while len(a) > 0 and len(b) > 0:
             if a[-1] < b[-1]:
@@ -161,58 +170,64 @@ class DiscreteSeries(Series):
                 ptr, v1 = a.pop()
                 _, v2 = b.pop()
 
-            _appendif(c, ptr, fun(v1, v2))
+            _appendif(c, ptr, fun(ptr, v1, v2))
 
         if len(a) > 0 or len(b) > 0:
             if len(a) > 0:
                 rest = a
                 static_v = series._get_for(ptr)
-                op = lambda me, const: fun(me, const)
+                op = lambda ptr, me, const: fun(ptr, me, const)
             else:
                 rest = b
                 static_v = self._get_for(ptr)
-                op = lambda me, const: fun(const, me)
+                op = lambda ptr, me, const: fun(ptr, const, me)
 
             for ptr, v in rest:
-                _appendif(c, ptr, op(v, static_v))
+                _appendif(c, ptr, op(ptr, v, static_v))
 
         return DiscreteSeries(c, new_domain)
 
     def join_discrete(self, series, fun):
+        return self.join_discrete_with_indices(series, lambda i, v1, v2: fun(v1, v2))
+
+    def join_discrete_with_indices(self, series, fun):
         new_domain = self.domain.intersection(series.domain)
 
         if isinstance(series, DiscreteSeries):
             return self._join_discrete_other_discrete(series, fun)
 
         if new_domain.start > self.data[0][0]:
-            c = [(new_domain.start, fun(self._get_for(new_domain.start),
+            c = [(new_domain.start, fun(new_domain.start,
+                                        self._get_for(new_domain.start),
                                         series._get_for(new_domain.start)))]
         else:
             c = []
 
         for k, v in ((k, v) for k, v in self.data if
                      new_domain.start <= k <= new_domain.stop):
-            _appendif(c, k, fun(v, series._get_for(k)))
+            _appendif(c, k, fun(k, v, series._get_for(k)))
 
         if c[-1][0] != new_domain.stop:
-            c.append((new_domain.stop, fun(self._get_for(new_domain.stop),
+            c.append((new_domain.stop, fun(new_domain.stop,
+                                           self._get_for(new_domain.stop),
                                            series._get_for(new_domain.stop))))
 
         return DiscreteSeries(c, new_domain)
 
-    def compute(self):
-        """Simplify self"""
-        nd = [self.data[0]]
-        for i in six.moves.range(1, len(self.data)):
-            if self.data[i][1] != nd[-1][1]:
-                nd.append(self.data[i])
-        return DiscreteSeries(nd, self.domain)
-
 
-class AppliedAndTranslatedSeries(Series):
-    def __init__(self, series, applyfun=lambda k,v: v, x=0, *args, **kwargs):
-        """:type applyfun: callable(float, v) -> any"""
-        super(AppliedAndTranslatedSeries, self).__init__(series.domain.translate(x), *args, **kwargs)
+class AlteredSeries(Series):
+    """
+    Internal use - for applyings, translations and slicing
+    """
+    def __init__(self, series, domain=None, applyfun=lambda k,v: v, x=0, *args, **kwargs):
+        """
+        :param series: original series
+        :param domain: new domain to use [if sliced]
+        :param applyfun: (index, v) -> newV [if applied]
+        :param x: translation vector [if translated]
+        """
+        domain = domain or series.domain
+        super(AlteredSeries, self).__init__(domain.translate(x), *args, **kwargs)
         self.fun = applyfun
         self.series = series
         self.x = x
@@ -221,15 +236,6 @@ class AppliedAndTranslatedSeries(Series):
         return self.fun(item, self.series._get_for(item + self.x))
 
 
-class SlicedSeries(Series):
-    def __init__(self, parent, domain, *args, **kwargs):
-        super(SlicedSeries, self).__init__(domain, *args, **kwargs)
-        self.parent = parent
-
-    def _get_for(self, item):
-        return self.parent._get_for(item)
-
-
 def _appendif(lst, ptr, v):
     if len(lst) > 0:
         assert lst[-1][0] <= ptr
@@ -246,11 +252,10 @@ class JoinedSeries(Series):
     """
 
     def __init__(self, ser1, ser2, op, *args, **kwargs):
-        domain = ser1.domain.intersection(ser2.domain)
-        super(JoinedSeries, self).__init__(domain, *args, **kwargs)
+        super(JoinedSeries, self).__init__(ser1.domain.intersection(ser2.domain), *args, **kwargs)
         self.ser1 = ser1
         self.ser2 = ser2
         self.op = op
 
     def _get_for(self, item):
-        return self.op(self.ser1._get_for(item), self.ser2._get_for(item))
+        return self.op(item, self.ser1._get_for(item), self.ser2._get_for(item))
diff --git a/tests/test_series.py b/tests/test_series.py
index c8608aa..49027f9 100644
--- a/tests/test_series.py
+++ b/tests/test_series.py
@@ -11,6 +11,7 @@ from firanka.series import DiscreteSeries, FunctionSeries, ModuloSeries, \
 
 NOOP = lambda x: x
 
+HUGE_IDENTITY = FunctionSeries(NOOP, '(-inf;inf)')
 
 class TestBase(unittest.TestCase):
     def test_abstract(self):
@@ -97,6 +98,13 @@ class TestDiscreteSeries(unittest.TestCase):
         self.assertIsInstance(sc, DiscreteSeries)
         self.assertEqual(sc.data, [(0, 0), (1, 2), (2, 4)])
 
+    def test_eval2i(self):
+        sa = DiscreteSeries([[0, 0], [1, 1], [2, 2]])
+        sc = sa.join_discrete_with_indices(HUGE_IDENTITY, lambda i, a, b: i)
+        self.assertEqual(sc.eval_points([0, 1, 2]), [0, 1, 2])
+        self.assertIsInstance(sc, DiscreteSeries)
+        self.assertEqual(sc.data, [(0, 0), (1, 1), (2, 2)])
+
     def test_apply(self):
         sa = DiscreteSeries([[0, 0], [1, 1], [2, 2]]).apply(lambda x: x + 1)
         self.assertEquals(sa.data, [(0, 1), (1, 2), (2, 3)])
@@ -175,6 +183,7 @@ class TestFunctionSeries(unittest.TestCase):
                           lambda: dirs.join_discrete(logs, lambda x, y: x + y))
 
 
+
 class TestModuloSeries(unittest.TestCase):
     def test_exceptions(self):
         self.assertRaises(ValueError, lambda: ModuloSeries(
-- 
GitLab