From 5efaaf30b481defbf3b9aec7fb35eab1c2d88aa4 Mon Sep 17 00:00:00 2001 From: Piotr Maslanka <piotr.maslanka@henrietta.com.pl> Date: Sat, 9 Dec 2017 23:06:05 +0100 Subject: [PATCH] reformat --- firanka/builders.py | 16 +++++++++------- firanka/ranges.py | 2 +- firanka/series/base.py | 18 ++++++++++++------ firanka/series/function.py | 5 ----- tests/test_series.py | 29 ++++++++++++++++++----------- 5 files changed, 40 insertions(+), 30 deletions(-) diff --git a/firanka/builders.py b/firanka/builders.py index ba49333..d01d1d4 100644 --- a/firanka/builders.py +++ b/firanka/builders.py @@ -1,12 +1,13 @@ # coding=UTF-8 from __future__ import print_function, absolute_import, division -import six -import logging + import copy -from .series import Series, DiscreteSeries -from .ranges import Range + from sortedcontainers import SortedList +from .ranges import Range +from .series import DiscreteSeries + """ Update knowledge of current discrete series """ @@ -15,6 +16,7 @@ __all__ = [ 'DiscreteSeriesBuilder', ] + class DiscreteSeriesBuilder(object): def __init__(self, series=None): @@ -51,10 +53,10 @@ class DiscreteSeriesBuilder(object): for k, v in self.series.data: if k in cp_new_data: v = cp_new_data.pop(k) - new_data.append((k,v)) + new_data.append((k, v)) # Add those that remained - for k,v in cp_new_data.items(): - new_data.add((k,v)) + for k, v in cp_new_data.items(): + new_data.add((k, v)) return DiscreteSeries(new_data, self.domain) diff --git a/firanka/ranges.py b/firanka/ranges.py index 491ed7b..0d1256b 100644 --- a/firanka/ranges.py +++ b/firanka/ranges.py @@ -96,7 +96,7 @@ class Range(object): if isinstance(x, Range): if ((x.start == self.start) and (x.left_inc ^ self.left_inc)) \ - or ((x.stop == self.stop) and (x.right_inc ^ self.right_inc)): + or ((x.stop == self.stop) and (x.right_inc ^ self.right_inc)): return False return (x.start >= self.start) and (x.stop <= self.stop) diff --git a/firanka/series/base.py b/firanka/series/base.py index 2001146..efaa6cd 100644 --- a/firanka/series/base.py +++ b/firanka/series/base.py @@ -1,17 +1,17 @@ # coding=UTF-8 from __future__ import print_function, absolute_import, division -import six import inspect +from sortedcontainers import SortedList + from firanka.exceptions import NotInDomainError from firanka.ranges import Range, EMPTY_SET -from sortedcontainers import SortedList -def _has_arguments(fun, n): +def _has_arguments(fun, n): # used only in assert clauses assert hasattr(fun, '__call__'), 'function is not callable!' - return len(inspect.getargspec(fun).args) == n + return len(inspect.getargspec(fun).args) >= n class Series(object): @@ -131,6 +131,8 @@ class DiscreteSeries(Series): 'some domain space is not covered by definition!') def apply(self, fun): + assert _has_arguments(fun, 2), 'fun must have at least 2 arguments' + return DiscreteSeries([(k, fun(k, v)) for k, v in self.data], self.domain) @@ -174,10 +176,12 @@ class DiscreteSeries(Series): if len(a) > 0 or len(b) > 0: if len(a) > 0: + assert len(b) == 0 rest = a static_v = series._get_for(ptr) op = lambda ptr, me, const: fun(ptr, me, const) else: + assert len(a) == 0 rest = b static_v = self._get_for(ptr) op = lambda ptr, me, const: fun(ptr, const, me) @@ -187,8 +191,9 @@ class DiscreteSeries(Series): return DiscreteSeries(c, new_domain) - def join_discrete(self, series, fun): + assert _has_arguments(fun, 3), 'fun must have at least 3 arguments!' + new_domain = self.domain.intersection(series.domain) if isinstance(series, DiscreteSeries): @@ -217,7 +222,8 @@ class AlteredSeries(Series): """ Internal use - for applyings, translations and slicing """ - def __init__(self, series, domain=None, applyfun=lambda k,v: v, x=0, *args, **kwargs): + + def __init__(self, series, domain=None, applyfun=lambda k, v: v, x=0, *args, **kwargs): """ :param series: original series :param domain: new domain to use [if sliced] diff --git a/firanka/series/function.py b/firanka/series/function.py index c80e6b3..7778afd 100644 --- a/firanka/series/function.py +++ b/firanka/series/function.py @@ -1,7 +1,5 @@ # coding=UTF-8 from __future__ import print_function, absolute_import, division -import six -import logging from .base import Series @@ -17,6 +15,3 @@ class FunctionSeries(Series): def _get_for(self, item): return self.fun(item) - - - diff --git a/tests/test_series.py b/tests/test_series.py index 8c4301b..d35862a 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -13,18 +13,18 @@ NOOP = lambda x: x HUGE_IDENTITY = FunctionSeries(NOOP, '(-inf;inf)') + class TestBase(unittest.TestCase): def test_abstract(self): self.assertRaises(NotImplementedError, lambda: Series('<-1;1>')[0]) class TestDiscreteSeries(unittest.TestCase): - def test_redundancy_skip(self): - a = DiscreteSeries([(0,0), (1,0), (2,0)], '<0;5>') - b = DiscreteSeries([(0,0), (1,0)], '<0;5>') + a = DiscreteSeries([(0, 0), (1, 0), (2, 0)], '<0;5>') + b = DiscreteSeries([(0, 0), (1, 0)], '<0;5>') - a.join(b, lambda i,x,y: x+y) + a.join(b, lambda i, x, y: x + y) def test_uncov(self): self.assertRaises(ValueError, @@ -129,10 +129,9 @@ class TestDiscreteSeries(unittest.TestCase): [0, 1, 2, 3, 4, 5], '(-1;6)')) self.assertRaises(NotInDomainError, lambda: FunctionSeries(lambda x: x ** 2, - '<-10;10)').discretize( + '<-10;10)').discretize( [-100, 0, 1, 2, 3, 4, 5], '(-1;6)')) - PTS = [-1, 0, 1, 2, 3, 4, 5] sa = FunctionSeries(lambda x: x ** 2, '<-10;10)').discretize(PTS, '(-1;6)') @@ -166,13 +165,21 @@ class TestFunctionSeries(unittest.TestCase): self.assertEqual(series.eval_points(PTS), [x for x in PTS]) + def test_apply_wild(self): + def dzika(k, x, a=5, *args, **kwargs): + return k + + PTS = [-1, -2, -3, 1, 2, 3] + series = FunctionSeries(NOOP, '<-5;5>').apply(dzika) + + self.assertEqual(series.eval_points(PTS), [x for x in PTS]) + def test_domain_sensitivity(self): logs = FunctionSeries(math.log, '(0;5>') dirs = DiscreteSeries([(0, 1), (1, 2), (3, 4)], '<0;5>') self.assertRaises(ValueError, - lambda: dirs.join_discrete(logs, lambda x, y: x + y)) - + lambda: dirs.join_discrete(logs, lambda i, x, y: x + y)) class TestModuloSeries(unittest.TestCase): @@ -198,10 +205,10 @@ class TestModuloSeries(unittest.TestCase): self.assertEqual(series.period, 3.0) - self.assertEqual(series.eval_points([-1,0,1]), [1,2,3]) + self.assertEqual(series.eval_points([-1, 0, 1]), [1, 2, 3]) - self.assertEqual(series.eval_points([-5,-4,-3,-2,-1,0,1,2,3,4,5]), - [ 3, 1, 2, 3, 1,2,3,1,2,3,1]) + self.assertEqual(series.eval_points([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]), + [3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1]) def test_comp_discrete(self): ser1 = ModuloSeries(FunctionSeries(lambda x: x ** 2, '<0;3)')) -- GitLab