From 8d5d4cfc638b5c80dbb932fe7b8c97ac665f5b13 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20Ma=C5=9Blanka?= <piotr.maslanka@henrietta.com.pl>
Date: Wed, 26 May 2021 14:30:59 +0200
Subject: [PATCH] fix a bug with decoding long lists

---
 CHANGELOG.md           |   8 ++
 Dockerfile             |   2 +
 minijson/__init__.py   |   9 +-
 minijson/routines.pyx  | 238 +++++++++++++++++++++++------------------
 setup.py               |   2 +-
 tests/test_minijson.py |  15 +++
 6 files changed, 163 insertions(+), 111 deletions(-)
 create mode 100644 CHANGELOG.md

diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..0450230
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,8 @@
+The full changelog is kept on [GitHub](https://github.com/Dronehub/minijson/releases);
+this file covers only the version currently in development.
+
+# v1.1
+
+* restored compatibility with older Python versions (removed the f-strings)
+* fixed docstrings to document which exceptions are raised
+* fixed a bug with decoding lists of more than 15 elements
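
A minimal round-trip sketch of the long-list fix noted above (assumes this
patch is applied, e.g. as minijson 1.1rc1); it mirrors the new
test_long_lists and test_long_dicts tests added further down in this patch:

    from minijson import dumps, loads

    long_list = [None] * 17                      # does not fit the 4-bit short-list header
    assert loads(dumps(long_list)) == long_list

    long_dict = {str(i): i for i in range(17)}   # likewise for the short-dict header
    assert loads(dumps(long_dict)) == long_dict
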
diff --git a/Dockerfile b/Dockerfile
index aa2b5e4..56d6625 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -2,6 +2,8 @@ FROM smokserwis/build:python3
 
 RUN pip install snakehouse Cython satella pytest
 
+ENV DEBUG=1
+
 WORKDIR /tmp/compile
 ADD minijson /tmp/compile/minijson
 ADD setup.py /tmp/compile/setup.py
diff --git a/minijson/__init__.py b/minijson/__init__.py
index fd95c4a..5308c40 100644
--- a/minijson/__init__.py
+++ b/minijson/__init__.py
@@ -1,6 +1,3 @@
-from .routines import dumps, loads, switch_default_double, switch_default_float, \
-    dumps_object, loads_object, parse, dump
-from .exceptions import MiniJSONError, EncodingError, DecodingError
-
-__all__ = ['dumps', 'loads', 'switch_default_float', 'switch_default_double',
-           'dumps_object', 'loads_object', 'parse', 'dump']
+from .routines import dumps, loads, dump, parse, dumps_object, loads_object, \
+    switch_default_float, switch_default_double
+from .exceptions import MiniJSONError, DecodingError, EncodingError
diff --git a/minijson/routines.pyx b/minijson/routines.pyx
index 2dc8422..06e4616 100644
--- a/minijson/routines.pyx
+++ b/minijson/routines.pyx
@@ -12,29 +12,76 @@ STRUCT_H = struct.Struct('>H')
 STRUCT_l = struct.Struct('>l')
 STRUCT_L = struct.Struct('>L')
 
-cdef int coding_mode = 0     # 0 for default FLOAT
-                             # 1 for default DOUBLE
+cdef int float_encoding_mode = 0     # 0 for default FLOAT
+                                     # 1 for default DOUBLE
 
 cpdef void switch_default_float():
     """
     Set default encoding of floats to IEEE 754 single
     """
-    global coding_mode
-    coding_mode = 0
+    global float_encoding_mode
+    float_encoding_mode = 0
 
 cpdef void switch_default_double():
     """
     Set default encoding of floats to IEEE 754 double
     """
-    global coding_mode
-    coding_mode = 1
+    global float_encoding_mode
+    float_encoding_mode = 1
 
-cdef inline tuple parse_cstring(bytes data, int starting_position):
+cdef tuple parse_cstring(bytes data, int starting_position):
     cdef:
         int strlen = data[starting_position]
         bytes subdata = data[starting_position+1:starting_position+1+strlen]
     return strlen+1, subdata
 
+cdef tuple parse_list(bytes data, int elem_count, int starting_position):
+    """
+    Parse a list with this many elements
+
+    :param data: data to parse as a list
+    :param elem_count: number of elements
+    :param starting_position: starting position
+
+    :return: tuple of (number of bytes the elements took up, the list itself)
+    """
+    cdef:
+        list lst = []
+        int i, ofs, offset = 0
+    for i in range(elem_count):
+        ofs, elem = parse(data, starting_position+offset)
+        offset += ofs
+        lst.append(elem)
+    return offset, lst
+
+cdef inline tuple parse_dict(bytes data, int elem_count, int starting_position):
+    """
+    Parse a dict with this many elements
+
+    :param data: data to parse as a dict
+    :param elem_count: number of key-value pairs
+    :param starting_position: starting position
+
+    :return: tuple of (number of bytes the entries took up, the dict itself)
+    """
+    cdef:
+        dict dct = {}
+        bytes b_field_name
+        str s_field_name
+        int i, ofs, offset = 0
+    for i in range(elem_count):
+        ofs, b_field_name = parse_cstring(data, starting_position+offset)
+        try:
+            s_field_name = b_field_name.decode('utf-8')
+        except UnicodeDecodeError as e:
+            raise DecodingError('Invalid UTF-8 field name!') from e
+        offset += ofs
+        ofs, elem = parse(data, starting_position+offset)
+        offset += ofs
+        dct[s_field_name] = elem
+    return offset, dct
+
+
 cpdef tuple parse(bytes data, int starting_position):
     """
     Parse given stream of data starting at a position
@@ -44,10 +91,11 @@ cpdef tuple parse(bytes data, int starting_position):
     :param starting_position: first position in the bytestring at which to look
     :return: a tuple of (how many bytes does this piece of data take, the piece of data itself)
     :rtype: tp.Tuple[int, tp.Any]
+    :raises DecodingError: the stream is invalid or too short
     """
     cdef:
         int value_type = data[starting_position]
-        int string_length
+        int string_length, elements, i, offset, length
         unsigned int uint32
         int sint32
         unsigned short uint16
@@ -56,99 +104,75 @@ cpdef tuple parse(bytes data, int starting_position):
         char sint8
         list e_list
         dict e_dict
-        int elements, i, offset, length
         bytes b_field_name
         str s_field_name
-    if value_type & 0x80:
-        string_length = value_type & 0x7F
-        try:
-            return string_length+1, data[starting_position+1:starting_position+string_length+1].decode('utf-8')
-        except UnicodeDecodeError as e:
-            raise DecodingError('Invalid UTF-8') from e
-    elif value_type & 0xF0 == 0x40:
-        elements = value_type & 0xF
-        offset = 1
-        e_list = []
-        for i in range(elements):
-            length, elem = parse(data, starting_position+offset)
-            offset += length
-            e_list.append(elem)
-        return offset, e_list
-    elif value_type & 0xF0 == 0x50:
-        e_dict = {}
-        offset = 1
-        elements = value_type & 0xF
-
-        for i in range(elements):
-            length, b_field_name = parse_cstring(data, starting_position+offset)
-            s_field_name = b_field_name.decode('utf-8')
-            offset += length
-            length, elem = parse(data, starting_position+offset)
-            offset += length
-            e_dict[s_field_name] = elem
-        return offset, e_dict
-    elif value_type == 0:
-        string_length = data[starting_position+1]
-        offset, b_field_name = parse_cstring(data, starting_position+1)
-        try:
-            return offset+1, b_field_name.decode('utf-8')
-        except UnicodeDecodeError as e:
-            raise DecodingError('Invalid UTF-8') from e
-    elif value_type in (1, 4):
-        uint32 = (data[starting_position+1] << 24) | (data[starting_position+2] << 16) | (data[starting_position+3] << 8) | data[starting_position+4]
-        if value_type == 4:
-            return 5, uint32
-        else:
-            sint32 = uint32
-            return 5, sint32
-    elif value_type in (2, 5):
-        uint16 = (data[starting_position+1] << 8) | data[starting_position+2]
-        if value_type == 5:
-            return 3, uint16
-        else:
-            sint16 = uint16
-            return 3, sint16
-    elif value_type in (3, 6):
-        uint8 = data[starting_position+1]
-        if value_type == 6:
-            return 2, uint8
-        else:
-            sint8 = uint8
-            return 2, sint8
-    elif value_type == 7:
-        elements = data[starting_position+1]
-        e_list = []
-        offset = 2
-        for i in range(elements):
-            length, elem = parse(data, starting_position+offset)
-            offset += length
-            e_list.append(elem)
-        return e_list
-    elif value_type == 8:
-        return 1, None
-    elif value_type == 9:
-        return 5, *STRUCT_f.unpack(data[starting_position+1:starting_position+5])
-    elif value_type == 10:
-        return 9, *STRUCT_d.unpack(data[starting_position+1:starting_position+9])
-    elif value_type == 12:
-        uint32 = (data[starting_position+1] << 16) | (data[starting_position+2] << 8) | data[starting_position+3]
-        return 4, uint32
-    elif value_type == 11:
-        elements = data[starting_position+1]
-        e_dict = {}
-        offset = 2
-
-        for i in range(elements):
-            length, b_field_name = parse_cstring(data, starting_position+offset)
-            s_field_name = b_field_name.decode('utf-8')
-            offset += length
-            length, elem = parse(data, starting_position+offset)
-
-            offset += length
-            e_dict[s_field_name] = elem
-        return offset, e_dict
-    raise DecodingError(f'Unknown sequence type {value_type}!')
-
+    try:
+        if value_type & 0x80:
+            string_length = value_type & 0x7F
+            try:
+                return string_length+1, data[starting_position+1:starting_position+string_length+1].decode('utf-8')
+            except UnicodeDecodeError as e:
+                raise DecodingError('Invalid UTF-8') from e
+        elif value_type & 0xF0 == 0x40:
+            # short list: 4-bit element count packed into the type byte
+            elements = value_type & 0xF
+            offset, e_list = parse_list(data, elements, starting_position+1)
+            return offset+1, e_list
+        elif value_type & 0xF0 == 0x50:
+            # short dict: 4-bit entry count packed into the type byte
+            elements = value_type & 0xF
+            offset, e_dict = parse_dict(data, elements, starting_position+1)
+            return offset+1, e_dict
+        elif value_type == 0:
+            # a string with a single-byte length prefix
+            offset, b_field_name = parse_cstring(data, starting_position+1)
+            try:
+                return offset+1, b_field_name.decode('utf-8')
+            except UnicodeDecodeError as e:
+                raise DecodingError('Invalid UTF-8') from e
+        elif value_type in (1, 4):
+            uint32 = (data[starting_position+1] << 24) | (data[starting_position+2] << 16) | (data[starting_position+3] << 8) | data[starting_position+4]
+            if value_type == 4:
+                return 5, uint32
+            else:
+                sint32 = uint32
+                return 5, sint32
+        elif value_type in (2, 5):
+            uint16 = (data[starting_position+1] << 8) | data[starting_position+2]
+            if value_type == 5:
+                return 3, uint16
+            else:
+                sint16 = uint16
+                return 3, sint16
+        elif value_type in (3, 6):
+            uint8 = data[starting_position+1]
+            if value_type == 6:
+                return 2, uint8
+            else:
+                sint8 = uint8
+                return 2, sint8
+        elif value_type == 7:
+            # long list: element count is carried in the following byte
+            elements = data[starting_position+1]
+            offset, e_list = parse_list(data, elements, starting_position+2)
+            return offset+2, e_list
+        elif value_type == 8:
+            return 1, None
+        elif value_type == 9:
+            return 5, *STRUCT_f.unpack(data[starting_position+1:starting_position+5])
+        elif value_type == 10:
+            return 9, *STRUCT_d.unpack(data[starting_position+1:starting_position+9])
+        elif value_type == 12:
+            uint32 = (data[starting_position+1] << 16) | (data[starting_position+2] << 8) | data[starting_position+3]
+            return 4, uint32
+        elif value_type == 11:
+            # long dict: entry count is carried in the following byte
+            elements = data[starting_position+1]
+            offset, e_dict = parse_dict(data, elements, starting_position+2)
+            return offset+2, e_dict
+        raise DecodingError('Unknown sequence type %s!' % (value_type, ))
+    except IndexError as e:
+        raise DecodingError('Data stream too short!') from e
 
 cpdef object loads(bytes data):
     """
@@ -167,7 +191,8 @@ cpdef int dump(object data, cio: io.BytesIO) except -1:
 
     :param data: object to write
     :param cio: stream to write to
-    :return: bytes written
+    :return: number of bytes written
+    :raises EncodingError: invalid data
     """
     cdef:
         str field_name
@@ -216,9 +241,9 @@ cpdef int dump(object data, cio: io.BytesIO) except -1:
             cio.write(STRUCT_L.pack(data))
             return 5
         else:
-            raise EncodingError(f'Too large integer {data}')
+            raise EncodingError('Too large integer %s' % (data, ))
     elif isinstance(data, float):
-        if coding_mode == 0:
+        if float_encoding_mode == 0:
             cio.write(b'\x09')
             cio.write(STRUCT_f.pack(data))
             return 5
@@ -255,7 +280,8 @@ cpdef int dump(object data, cio: io.BytesIO) except -1:
             length += dump(elem, cio)
         return length
     else:
-        raise EncodingError(f'Unknown value type {data}')
+        raise EncodingError('Unknown value type %s' % (data, ))
+
 
 cpdef bytes dumps(object data):
     """
@@ -290,5 +316,9 @@ cpdef object loads_object(bytes data, object obj_class):
     :return: instance of obj_class
     :raises DecodingError: decoding error
     """
-    cdef dict kwargs = loads(data)
+    cdef dict kwargs
+    try:
+        kwargs = loads(data)
+    except TypeError as e:
+        raise DecodingError('Expected the payload to decode to a dict!') from e
     return obj_class(**kwargs)
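
A short usage sketch of the parse() contract documented above: parse(data, pos)
returns a tuple of (bytes consumed, decoded value), so a caller can walk a
buffer holding several values back to back (illustrative only):

    from minijson import dumps, parse

    buf = dumps('ab') + dumps(5)

    consumed, first = parse(buf, 0)        # first == 'ab'
    _, second = parse(buf, consumed)       # second == 5, decoded from where the string ended
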
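A sketch of the stricter loads_object() behaviour introduced in the last hunk
above: when the decoded payload is not a dict, the TypeError raised by the typed
assignment is converted into a DecodingError. Point is a hypothetical example
class, not part of minijson:

    from minijson import dumps, loads_object, DecodingError

    class Point:
        def __init__(self, x, y):
            self.x, self.y = x, y

    p = loads_object(dumps({'x': 1, 'y': 2}), Point)   # payload is a dict: works
    try:
        loads_object(dumps([1, 2, 3]), Point)          # payload decodes to a list
    except DecodingError:
        pass                                           # raised instead of a bare TypeError
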
diff --git a/setup.py b/setup.py
index 3169dc6..0f8a79c 100644
--- a/setup.py
+++ b/setup.py
@@ -14,7 +14,7 @@ if 'DEBUG' in os.environ:
     directives['embedsignature'] = True
 
 
-setup(version='1.0',
+setup(version='1.1rc1',
       packages=find_packages(include=['minijson', 'minijson.*']),
       ext_modules=build([Multibuild('minijson', find_pyx('minijson'),
                                     dont_snakehouse=dont_snakehouse), ],
diff --git a/tests/test_minijson.py b/tests/test_minijson.py
index e9ed1fe..4319416 100644
--- a/tests/test_minijson.py
+++ b/tests/test_minijson.py
@@ -13,6 +13,21 @@ class TestMiniJSON(unittest.TestCase):
         a = [None]*256
         self.assertRaises(EncodingError, lambda: dumps(a))
 
+    def test_long_lists(self):
+        a = [None]*17
+        b = dumps(a)
+        print('Encoded %s' % (b, ))
+        c = loads(b)
+        self.assertEqual(a, c)
+
+    def test_long_dicts(self):
+        a = {}
+        for i in range(17):
+            a[str(i)] = i
+        b = dumps(a)
+        c = loads(b)
+        self.assertEqual(a, c)
+
     def test_exceptions(self):
         a = {}
         for i in range(65535):
-- 
GitLab