From eda3be1961800ebf288caf10a05789a17d854bfd Mon Sep 17 00:00:00 2001
From: Piotr Maslanka <piotr.maslanka@henrietta.com.pl>
Date: Fri, 8 Dec 2017 20:43:08 +0100
Subject: [PATCH] passes tests

---
 firanka/series/range.py         | 54 ++++++++++++++++++---------------
 setup.py                        |  2 --
 tests/test_series/test_range.py |  2 +-
 3 files changed, 30 insertions(+), 28 deletions(-)

diff --git a/firanka/series/range.py b/firanka/series/range.py
index bf60693..81a0123 100644
--- a/firanka/series/range.py
+++ b/firanka/series/range.py
@@ -4,10 +4,20 @@ import six
 import logging
 import re
 from satella.coding import for_argument
+import functools
 
 logger = logging.getLogger(__name__)
 
 
+def pre_range(fun):
+    @functools.wraps(fun)
+    def inner(self, arg, *args, **kwargs):
+        if not isinstance(arg, Range):
+            arg = Range(arg)
+        return fun(self, arg, *args, **kwargs)
+    return inner
+
+
 class Range(object):
     """
     Range of real numbers
@@ -68,45 +78,39 @@ class Range(object):
             '>' if self.rend_inclusive else ')',
         )
 
+    @pre_range
     def intersection(self, y):
-        if not isinstance(y, Range): y = Range(y)
+        print(str(self), str(y))
+        if self.start > y.start:
+            return y.intersection(self)
 
-        x = self
+        assert self.start <= y.start
 
-        # Check for intersection being impossible
-        if (x.stop < y.start) or (x.start > y.stop) or \
-            (x.stop == y.start and not x.rend_inclusive and not y.lend_inclusive) or \
-            (x.start == x.stop and not x.lend_inclusive and not y.rend_inclusive):
+        if (self.stop < y.start) or (y.stop < y.start):
             return EMPTY_RANGE
 
-        # Check for range extension
-        if (x.start == y.stop) and (x.lend_inclusive or y.lend_inclusive):
-            return Range(y.start, x.stop, y.lend_inclusive, x.rend_inclusive)
-
-        if (x.start == y.stop) and (x.lend_inclusive or y.lend_inclusive):
-            return Range(y.start, x.stop, y.lend_inclusive, x.rend_inclusive)
-
+        if self.stop == y.start and not (self.rend_inclusive or y.lend_inclusive):
+            return EMPTY_RANGE
 
-        if x.start == y.start:
-            start = x.start
-            lend_inclusive = x.lend_inclusive or y.lend_inclusive
+        if self.start == y.start:
+            start = self.start
+            lend_inclusive = self.lend_inclusive or y.lend_inclusive
         else:
-            p = x if x.start > y.start else y
-            start = p.start
-            lend_inclusive = p.lend_inclusive
+            start = y.start
+            lend_inclusive = y.lend_inclusive
 
-        if x.stop == y.stop:
-            stop = x.stop
-            rend_inclusive = x.rend_inclusive or y.rend_inclusive
+        if self.stop == y.stop:
+            stop = self.stop
+            rend_inclusive = self.rend_inclusive or y.rend_inclusive
         else:
-            p = x if x.stop < y.stop else y
+            p, q = (self, y) if self.stop < y.stop else (y, self)
             stop = p.stop
-            rend_inclusive = p.rend_inclusive
+            rend_inclusive = p.rend_inclusive and (stop in q)
 
         return Range(start, stop, lend_inclusive, rend_inclusive)
 
+    @pre_range
     def __eq__(self, other):
-        if not isinstance(other, Range): other = Range(other)
         return self.start == other.start and self.stop == other.stop and self.lend_inclusive == other.lend_inclusive and self.rend_inclusive == other.rend_inclusive
 
     def __hash__(self):
diff --git a/setup.py b/setup.py
index 0099f5c..ca77cde 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,5 @@
 # coding=UTF-8
 from setuptools import setup, find_packages
-from pip.req import parse_requirements
-
 
 from firanka import __version__
 
diff --git a/tests/test_series/test_range.py b/tests/test_series/test_range.py
index 7c54260..8398c3c 100644
--- a/tests/test_series/test_range.py
+++ b/tests/test_series/test_range.py
@@ -9,7 +9,7 @@ class TestRange(unittest.TestCase):
 
         self.assertFalse(Range(-10, -1, True, True).intersection('<2;3>'))
         self.assertFalse(Range(-10, -1, True, False).intersection('(-1;3>'))
-        self.assertEquals(Range('<-10;-1)').intersection('<-1;1>'), '<-1;1>')
+        self.assertFalse(Range('<-10;-1)').intersection('<-1;1>'))
 
 
     def test_str(self):
-- 
GitLab