From 63a4a02b781c64a12865b320d9427d8eda28f1df Mon Sep 17 00:00:00 2001
From: Piotr Maslanka <piotr.maslanka@henrietta.com.pl>
Date: Sat, 9 Dec 2017 21:12:36 +0100
Subject: [PATCH] smarter builder

---
 firanka/builder.py    | 28 ++++++++++++++++++++++------
 tests/test_builder.py |  5 +++--
 2 files changed, 25 insertions(+), 8 deletions(-)

diff --git a/firanka/builder.py b/firanka/builder.py
index 1b15543..b9d499a 100644
--- a/firanka/builder.py
+++ b/firanka/builder.py
@@ -15,7 +15,10 @@ __all__ = [
 ]
 
 class DiscreteKnowledgeBuilder(object):
-    def __init__(self, series):
+    def __init__(self, series=None):
+
+        if series is None:
+            series = DiscreteSeries([], '(0;0)')
 
         if not isinstance(series, DiscreteSeries):
             raise TypeError('discrete knowledge builder supports only discrete series')
@@ -34,11 +37,24 @@ class DiscreteKnowledgeBuilder(object):
 
         self.new_data[index] = value
 
-    def update_series(self):
-        """:return: a new DiscreteSeries instance"""
+    def as_series(self):
+        """
+        Update
+        :return: a new DiscreteSeries instance
+        """
+
+        new_data = []
+
+        cp_new_data = copy.copy(self.new_data)
+
+        # Update
+        for k, v in self.series.data:
+            if k in cp_new_data:
+                v = cp_new_data.pop(k)
+            new_data.append((k,v))
 
-        new_data = copy.copy(self.series.data)
-        for k,v in self.new_data.items():
-            new_data.add((k,v))
+        # Add those that remained
+        for k,v in cp_new_data.items():
+            new_data.append((k,v))
 
         return DiscreteSeries(new_data, self.domain)
diff --git a/tests/test_builder.py b/tests/test_builder.py
index 0c2b704..51c816a 100644
--- a/tests/test_builder.py
+++ b/tests/test_builder.py
@@ -16,9 +16,10 @@ class TestBuilder(unittest.TestCase):
 
         kb.put(3, 4)
         kb.put(-1, 5)
+        kb.put(0, 2)
         kb.put(-1, 6)
 
-        s2 = kb.update_series()
+        s2 = kb.as_series()
 
         self.assertTrue(s2.domain, '<-1;3>')
-        self.assertEqual(s2.data,[(-1,6), (0,1), (1,2), (3,4)])
+        self.assertEqual(s2.data,[(-1,6), (0,2), (1,2), (3,4)])
-- 
GitLab