diff --git a/firanka/series/__init__.py b/firanka/series/__init__.py index 3b3bbcdf59397899bf54eae56cc58f275b2abf5b..ce1102ae4ee89f9bfb4f99537480d04660c06230 100644 --- a/firanka/series/__init__.py +++ b/firanka/series/__init__.py @@ -5,9 +5,13 @@ import logging from .exceptions import OutOfRangeError, EmptyDomainError from .range import Range +from .series import DiscreteSeries, FunctionBasedSeries + __all__ = [ 'OutOfRangeError', 'EmptyDomainError', - 'Range' + 'Range', + 'FunctionBasedSeries', + 'DiscreteSeries' ] diff --git a/firanka/series/range.py b/firanka/series/range.py index c3cd8158635daed17edf1e7d460e45ac4952791b..2fdc599bccf80b3d779960dd3ea9306ef72749ee 100644 --- a/firanka/series/range.py +++ b/firanka/series/range.py @@ -28,6 +28,8 @@ class Range(object): rs, = args if isinstance(rs, type(self)): args = rs.start, rs.stop, rs.left_inc, rs.right_inc + elif isinstance(rs, slice): + args = rs.start, rs.stop, True, True else: if rs[0] not in '<(': raise ValueError('Must start with ( or <') if rs[-1] not in '>)': raise ValueError('Must end with ) or >') @@ -41,20 +43,37 @@ class Range(object): if q(2, 0, args) or q(3, 1, args): raise ValueError('Set with sharp closing but infinity set') - print(args) self.start, self.stop, self.left_inc, self.right_inc = args def __contains__(self, x): - if x == self.start: - return self.left_inc + """ + :type x: index or a Range + """ - if x == self.stop: - return self.right_inc + if isinstance(x, (Range, six.text_type)): + if isinstance(x, six.text_type): + x = Range(x) + print('does ', self, 'contain', x) - return self.start < x < self.stop + if x.start == self.start: + if x.left_inc ^ self.left_inc: + return False + + if x.stop == self.stop: + if x.right_inc ^ self.right_inc: + return False + + return (x.start >= self.start) and (x.stop <= self.stop) + else: + if x == self.start: + return self.left_inc + + if x == self.stop: + return self.right_inc + + return self.start < x < self.stop def is_empty(self): - print(self.start, self.stop, self.left_inc, self.right_inc) return (self.start == self.stop) and not (self.left_inc or self.right_inc) def __len__(self): diff --git a/firanka/series/series.py b/firanka/series/series.py new file mode 100644 index 0000000000000000000000000000000000000000..142dae2ef0aef55e896ff6fb72b75aaf48c3e680 --- /dev/null +++ b/firanka/series/series.py @@ -0,0 +1,147 @@ + +# coding=UTF-8 +from __future__ import print_function, absolute_import, division +import six +import functools +import itertools + +from .range import Range + + +class Series(object): + + def __init__(self, domain): + if not isinstance(domain, Range): + domain = Range(domain) + self.domain = domain + + def __getitem__(self, item): + if isinstance(item, slice): + item = Range(item) + if item not in self.domain: + raise ValueError('slicing beyond series domain') + + newdom = self.domain.intersection(item) + return SlicedSeries(self, newdom) + else: + if item not in self.domain: + raise ValueError('item not in domain') + + return self._get_for(item) + + def _get_for(self, item): + raise NotImplementedError + + def eval_points(self, points): + return [self[p] for p in points] + + def apply(self, series, fun): + return AppliedSeries(self, series, fun) + + def compute(self): + """Simplify self""" + return self + + +class SlicedSeries(Series): + def __init__(self, parent, domain): + super(SlicedSeries, self).__init__(domain) + self.parent = parent + + def _get_for(self, item): + return self.parent._get_for(item) + +class DiscreteSeries(Series): + + def __init__(self, data, domain=None): + if domain is None: + domain = Range(data[0][0], data[-1][0], True, True) + + self.data = data + super(DiscreteSeries, self).__init__(domain) + + def _get_for(self, item): + for k, v in reversed(self.data): + if k <= item: + return v + + raise RuntimeError('should never happen') + + def compute(self): + """Simplify self""" + nd = [self.data[0]] + for i in six.moves.range(1, len(self.data)): + if self.data[i][1] != nd[-1][1]: + nd.append(self.data[i]) + return DiscreteSeries(nd, self.domain) + + +class FunctionBasedSeries(Series): + def __init__(self, fun, domain): + super(FunctionBasedSeries, self).__init__(domain) + self.fun = fun + self._get_for = fun + + +class AppliedSeries(Series): + def __init__(self, ser1, ser2, op): + super(AppliedSeries, self).__init__( + ser1.domain.intersection(ser2.domain)) + self.ser1 = ser1 + self.ser2 = ser2 + self.op = op + + def _get_for(self, item): + return self.op(self.ser1._get_for(item), self.ser2._get_for(item)) + + def compute(self): + """ + Attempt to simplify the call tree + """ + if isinstance(self.ser1, DiscreteSeries) and isinstance(self.ser2, + DiscreteSeries): + a = [p for p, q in reversed(self.ser1.data)] + b = [p for p, q in reversed(self.ser2.data)] + + ptr = self.domain.start + c = [(ptr, self._get_for(ptr))] + + while len(a) > 0 or len(b) > 0: + if len(a) > 0 and len(b) > 0: + if a[-1] < b[-1]: + ptr = a.pop() + elif a[-1] > b[-1]: + ptr = b.pop() + else: + a.pop() + ptr = b.pop() + + assert ptr >= c[-1][0] + + if ptr > c[-1][0]: + c.append((ptr, self._get_for(ptr))) + + else: + rest = a if len(a) > 0 else b + c.extend((ptr, self._get_for(ptr)) for ptr, v in rest) + break + + return DiscreteSeries(c, self.domain) + elif isinstance(self.ser1, DiscreteSeries) or isinstance(self.ser2, + DiscreteSeries): + dis, nds = (self.ser1, self.ser2) if isinstance(self.ser1, DiscreteSeries) else (self.ser2, self.ser1) + + if dis.data[0][0] != self.domain.start: + p = [(self.domain.start, self._get_for(self.domain.start))] + else: + p = [] + + for ptr, v in dis.data: + p.append((ptr, self._get_for(ptr))) + + if dis.data[-1][0] != self.domain.stop: + dis.data.append((self.domain.stop, self._get_for(ptr))) + + return DiscreteSeries(p, self.domain) + else: + return self diff --git a/tests/test_series/test_range.py b/tests/test_series/test_range.py index 852239f9f59454e510e3b3fd5951a57ea07145c2..2bcf86757599255908c535bb87286047e63ff506 100644 --- a/tests/test_series/test_range.py +++ b/tests/test_series/test_range.py @@ -51,3 +51,5 @@ class TestRange(unittest.TestCase): self.assertTrue(-5 in Range('(-10;-1>')) self.assertFalse(-20 in Range('(-10;-1>')) self.assertFalse(1 in Range('(-10;-1>')) + + self.assertTrue(Range('<-5;5>') in Range('<-10;10>')) \ No newline at end of file diff --git a/tests/test_series/test_series.py b/tests/test_series/test_series.py new file mode 100644 index 0000000000000000000000000000000000000000..3ba6c1ecc2546783b7d8c2064e95301d101d8f4f --- /dev/null +++ b/tests/test_series/test_series.py @@ -0,0 +1,78 @@ +# coding=UTF-8 +from __future__ import print_function, absolute_import, division +import six +import unittest +from firanka.series import DiscreteSeries, FunctionBasedSeries, Range + + +class TestDiscreteSeries(unittest.TestCase): + + + def test_base(self): + + s = DiscreteSeries([[0,0], [1,1], [2,2]]) + + self.assertEqual(s[0], 0) + self.assertEqual(s[0.5], 0) + self.assertEqual(s[1], 1) + + self.assertRaises(ValueError, lambda: s[-1]) + self.assertRaises(ValueError, lambda: s[2.5]) + + + s = DiscreteSeries([[0,0], [1,1], [2,2]], domain=Range(0,3,True,True)) + self.assertEqual(s[0], 0) + self.assertEqual(s[0.5], 0) + self.assertEqual(s[1], 1) + + self.assertRaises(ValueError, lambda: s[-1]) + self.assertEqual(s[2.5], 2) + + + def test_slice(self): + series = DiscreteSeries([[0, 0], [1, 1], [2, 2]]) + + sp = series[0.5:1.5] + + self.assertEqual(sp[0.5], 0) + self.assertEqual(sp[1.5], 1) + self.assertRaises(ValueError, lambda: sp[0]) + self.assertRaises(ValueError, lambda: sp[2]) + self.assertEqual(sp.domain.start, 0.5) + self.assertEqual(sp.domain.stop, 1.5) + + def test_eval(self): + sa = DiscreteSeries([[0, 0], [1, 1], [2, 2]]) + sb = DiscreteSeries([[0, 1], [1, 2], [2, 3]]) + + sc = sa.apply(sb, lambda a, b: a+b) + sd = sc.compute() + self.assertEqual(sc.eval_points([0,1,2]), [1,3,5]) + self.assertEqual(sd.eval_points([0,1,2]), sd.eval_points([0,1,2])) + + self.assertEqual(sd.data, [(0,1),(1,3),(2,5)]) + + def test_eval2(self): + sa = DiscreteSeries([[0, 0], [1, 1], [2, 2]]) + sb = FunctionBasedSeries(lambda x: x, '<0;2)') + + sc = sa.apply(sb, lambda a, b: a+b) + sd = sc.compute() + self.assertEqual(sc.eval_points([0,1,2]), [0,2,4]) + self.assertEqual(sd.eval_points([0,1,2]), sd.eval_points([0,1,2])) + + self.assertEqual(sd.data, [(0,0),(1,2),(2,4)]) + + +class TestFunctionBasedSeries(unittest.TestCase): + def test_slice(self): + series = FunctionBasedSeries(lambda x: x, '<0;2>') + + sp = series[0.5:1.5] + + self.assertEqual(sp[0.5], 0.5) + self.assertEqual(sp[1.5], 1.5) + self.assertRaises(ValueError, lambda: sp[0]) + self.assertRaises(ValueError, lambda: sp[2]) + self.assertEqual(sp.domain.start, 0.5) + self.assertEqual(sp.domain.stop, 1.5)