From 8e51465a42f291ddf8746d438c6b6233afc8dfd9 Mon Sep 17 00:00:00 2001 From: Piotr Maslanka <piotr.maslanka@henrietta.com.pl> Date: Sun, 10 Dec 2017 03:07:50 +0100 Subject: [PATCH] started adding bundles --- firanka/builders.py | 2 +- firanka/exceptions.py | 2 +- firanka/series/__init__.py | 4 +++- firanka/series/base.py | 14 +++++++------- firanka/series/bundle.py | 28 ++++++++++++++++++++++++++++ tests/test_builder.py | 11 +++++------ tests/test_series.py | 15 +++++++++++++-- 7 files changed, 58 insertions(+), 18 deletions(-) create mode 100644 firanka/series/bundle.py diff --git a/firanka/builders.py b/firanka/builders.py index 2c6e0d8..e484a7e 100644 --- a/firanka/builders.py +++ b/firanka/builders.py @@ -23,7 +23,7 @@ class DiscreteSeriesBuilder(object): series = DiscreteSeries([]) if not isinstance(series, DiscreteSeries): - raise TypeError('discrete knowledge builder supports only discrete series') + raise TypeError(u'discrete knowledge builder supports only discrete series') self.new_data = {} self.domain = series.domain diff --git a/firanka/exceptions.py b/firanka/exceptions.py index 7721760..bc1ef14 100644 --- a/firanka/exceptions.py +++ b/firanka/exceptions.py @@ -24,6 +24,6 @@ class NotInDomainError(DomainError): """ def __init__(self, index, domain): - super(NotInDomainError, self).__init__('NotInDomainError: %s not in %s' % (index, domain)) + super(NotInDomainError, self).__init__(u'NotInDomainError: %s not in %s' % (index, domain)) self.index = index self.domain = domain diff --git a/firanka/series/__init__.py b/firanka/series/__init__.py index be43233..5ba9b3c 100644 --- a/firanka/series/__init__.py +++ b/firanka/series/__init__.py @@ -2,6 +2,7 @@ from __future__ import absolute_import from .base import DiscreteSeries, Series +from .bundle import SeriesBundle from .function import FunctionSeries from .interpolations import LinearInterpolationSeries, \ SCALAR_LINEAR_INTERPOLATOR @@ -12,6 +13,7 @@ __all__ = [ 'DiscreteSeries', 'ModuloSeries', 'Series', - 'SCALAR_LINEAR_INTERPOLATOR', 'LinearInterpolationSeries', + 'SCALAR_LINEAR_INTERPOLATOR', + 'SeriesBundle', ] diff --git a/firanka/series/base.py b/firanka/series/base.py index 381e749..9b82918 100644 --- a/firanka/series/base.py +++ b/firanka/series/base.py @@ -46,7 +46,7 @@ class Series(object): return self._get_for(item) def _get_for(self, item): - raise NotImplementedError('This is abstract, override me!') + raise NotImplementedError(u'This is abstract, override me!') def eval_points(self, points): """ @@ -90,7 +90,7 @@ class Series(object): :param fun: callable(t: float, v1: any, v2: any) => value :return: new Series instance """ - assert _has_arguments(fun, 3), 'Callable to join needs 3 arguments' + assert _has_arguments(fun, 3), u'Callable to join needs 3 arguments' return JoinedSeries(self, series, fun) @@ -122,10 +122,10 @@ class DiscreteSeries(Series): if len(data) > 0: if self.domain.start < data[0][0]: - raise DomainError('some domain space is not covered by definition!') + raise DomainError(u'some domain space is not covered by definition!') def apply(self, fun): - assert _has_arguments(fun, 2), 'fun must have at least 2 arguments' + assert _has_arguments(fun, 2), u'fun must have at least 2 arguments' return DiscreteSeries([(k, fun(k, v)) for k, v in self.data], self.domain) @@ -138,7 +138,7 @@ class DiscreteSeries(Series): if k <= item: return v - raise RuntimeError('should never happen') + raise RuntimeError(u'should never happen') def translate(self, x): return DiscreteSeries([(k + x, v) for k, v in self.data], @@ -198,7 +198,7 @@ class DiscreteSeries(Series): :param fun: :return: """ - assert _has_arguments(fun, 3), 'fun must have at least 3 arguments!' + assert _has_arguments(fun, 3), u'fun must have at least 3 arguments!' new_domain = self.domain.intersection(series.domain) @@ -261,7 +261,7 @@ class JoinedSeries(Series): def __init__(self, ser1, ser2, op, *args, **kwargs): """:type op: callable(time: float, v1, v2: any) -> v""" - assert _has_arguments(op, 3), 'op must have 3 arguments' + assert _has_arguments(op, 3), u'op must have 3 arguments' super(JoinedSeries, self).__init__(ser1.domain.intersection(ser2.domain), *args, **kwargs) self.ser1 = ser1 diff --git a/firanka/series/bundle.py b/firanka/series/bundle.py new file mode 100644 index 0000000..8090f56 --- /dev/null +++ b/firanka/series/bundle.py @@ -0,0 +1,28 @@ +# coding=UTF-8 +from __future__ import print_function, absolute_import, division + +import functools +import logging + +logger = logging.getLogger(__name__) + +from .base import Series +from ..intervals import REAL_SET + + +class SeriesBundle(Series): + """ + Bundles a bunch of series together, returning a list from their outputs + """ + + def __init__(self, *series): + domain = functools.reduce(lambda x, y: x.intersection(y), + (p.domain for p in series), + REAL_SET) + + super(SeriesBundle, self).__init__(domain) + + self.series = series + + def _get_for(self, item): + return [s._get_for(item) for s in self.series] diff --git a/tests/test_builder.py b/tests/test_builder.py index 9188215..d84e36c 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -1,6 +1,6 @@ # coding=UTF-8 from __future__ import print_function, absolute_import, division -import six + import unittest from firanka.builders import DiscreteSeriesBuilder @@ -9,8 +9,7 @@ from firanka.series import DiscreteSeries class TestBuilder(unittest.TestCase): def test_t1(self): - - ser = DiscreteSeries([(0,1), (1,2)]) + ser = DiscreteSeries([(0, 1), (1, 2)]) kb = DiscreteSeriesBuilder(ser) @@ -22,7 +21,7 @@ class TestBuilder(unittest.TestCase): s2 = kb.as_series() self.assertTrue(s2.domain, '<-1;3>') - self.assertEqual(s2.data,[(-1,6), (0,2), (1,2), (3,4)]) + self.assertEqual(s2.data, [(-1, 6), (0, 2), (1, 2), (3, 4)]) def test_exnihilo(self): kb = DiscreteSeriesBuilder() @@ -32,6 +31,6 @@ class TestBuilder(unittest.TestCase): s = kb.as_series() - self.assertEqual(s[0],0) - self.assertEqual(s[1],1) + self.assertEqual(s[0], 0) + self.assertEqual(s[1], 1) self.assertEqual(s.domain, '<0;1>') diff --git a/tests/test_series.py b/tests/test_series.py index c947e31..a706e92 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -7,7 +7,7 @@ import unittest from firanka.exceptions import NotInDomainError, DomainError from firanka.intervals import Interval from firanka.series import DiscreteSeries, FunctionSeries, ModuloSeries, \ - LinearInterpolationSeries, Series + LinearInterpolationSeries, Series, SeriesBundle NOOP = lambda x: x @@ -125,7 +125,7 @@ class TestDiscreteSeries(unittest.TestCase): def test_discretize(self): # note the invalid data for covering this domain self.assertRaises(DomainError, lambda: FunctionSeries(lambda x: x ** 2, - '<-10;10)').discretize( + '<-10;10)').discretize( [0, 1, 2, 3, 4, 5], '(-1;6)')) self.assertRaises(NotInDomainError, lambda: FunctionSeries(lambda x: x ** 2, @@ -230,3 +230,14 @@ class TestLinearInterpolation(unittest.TestCase): def test_conf(self): self.assertRaises(TypeError, lambda: LinearInterpolationSeries( FunctionSeries(NOOP, '<0;3)'))) + + +class TestBundles(unittest.TestCase): + def test_base(self): + s = SeriesBundle( + DiscreteSeries([(0, 1), (1, 1), (2, 1)], '<0;inf)'), + DiscreteSeries([(0, 2), (1, 2), (2, 2), (3, 4)], '<0;inf)'), + ) + + self.assertEqual(s[0], [1, 2]) + self.assertEqual(s[3], [1, 4]) -- GitLab