diff --git a/firanka/series/range.py b/firanka/series/range.py index 3779a143742c996ef6d567f5e8f1702d8e33450a..994677021fb7a1513d4572ef70b65467a05e691f 100644 --- a/firanka/series/range.py +++ b/firanka/series/range.py @@ -24,12 +24,12 @@ class Range(object): """ def __init__(self, *args): if len(args) == 1: - - if isinstance(args[0], type(self)): - start = args[0].start - stop = args[0].stop - lend_inclusive = args[0].lend_inclusive - rend_inclusive = args[0].rend_inclusive + rs, = args + if isinstance(rs, type(self)): + start = rs.start + stop = rs.stop + lend_inclusive = rs.lend_inclusive + rend_inclusive = rs.rend_inclusive else: rs, = args @@ -40,8 +40,7 @@ class Range(object): lend_inclusive = rs[0] == '<' rend_inclusive = rs[-1] == '>' - rs = rs[1:-1] - start, stop = rs.split(';') + start, stop = rs[1:-1].split(';') start = float(start) stop = float(stop) else: @@ -98,7 +97,7 @@ class Range(object): if (self.stop < y.start) or (y.stop < y.start): return EMPTY_RANGE - if self.stop == y.start and not (self.rend_inclusive or y.lend_inclusive): + if self.stop == y.start and not (self.rend_inclusive and y.lend_inclusive): return EMPTY_RANGE if self.start == y.start: diff --git a/tests/test_series/test_range.py b/tests/test_series/test_range.py index 5fa18b72030becac9cb876d3fa3dfbf8af8543d2..1966cf84b255c3a158aa62f7128f6f508b6bb0cb 100644 --- a/tests/test_series/test_range.py +++ b/tests/test_series/test_range.py @@ -9,8 +9,10 @@ from firanka.series import Range class TestRange(unittest.TestCase): def do_intersect(self, a, b, val): - self.assertEqual(bool(Range(a).intersection(b)), val) - self.assertEqual(bool(Range(b).intersection(a)), val) + if bool(Range(a).intersection(b)) != val: + self.fail('%s ^ %s != %s' % (Range(a), Range(b), val)) + if bool(Range(b).intersection(a)) != val: + self.fail('%s ^ %s != %s' % (Range(b), Range(a), val)) def test_intersection(self): self.do_intersect(Range(-10, -1, True, True), '<2;3>', False) @@ -23,6 +25,7 @@ class TestRange(unittest.TestCase): def test_constructor(self): self.assertRaises(ValueError, lambda: Range('#2;3>')) self.assertRaises(ValueError, lambda: Range('(2;3!')) + self.assertEqual(Range(1,2,True,False), Range('<1;2)')) def test_contains(self): self.assertFalse(-1 in Range('<-10;-1)'))