From 19d19f3ba6a54b2ea2d04e86f7e11392b9d44cf4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20Ma=C5=9Blanka?= <piotr.maslanka@henrietta.com.pl>
Date: Wed, 26 May 2021 16:44:40 +0200
Subject: [PATCH] improved checking for malformed structures during parse, v1.6

---
 CHANGELOG.md           |  2 ++
 docs/specification.rst | 24 ++++++++++++------------
 docs/usage.rst         |  3 ++-
 minijson/routines.pyx  | 24 +++++++++++++++++++-----
 setup.py               |  2 +-
 tests/test_minijson.py |  6 ++++++
 6 files changed, 42 insertions(+), 19 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1714661..c5a4e37 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,3 +2,5 @@ Changelog is kept at [GitHub](https://github.com/Dronehub/minijson/releases),
 here's only the changelog for the version in development
 
 # v1.6
+
+* improved checking for malformed structures during parse
diff --git a/docs/specification.rst b/docs/specification.rst
index c789bfd..a64aebe 100644
--- a/docs/specification.rst
+++ b/docs/specification.rst
@@ -15,7 +15,7 @@ Type Value consists of:
 * unsigned char * data
 
 * If value's highest bit is turned on, then remains are a UTF-8 string
-with len of (value & 0x7F)
+  with len of (value & 0x7F)
 * If value's two highest bits are 0100 or 0101, then four lowest bits encode the number of elements,
   and the four highest bits encode type of the object:
   * 0100 - a list
@@ -32,21 +32,21 @@ with len of (value & 0x7F)
 * If value is 5, then next data is unsigned short
 * If value is 6, then next data is unsigned char
 * If value is 7, then next data is number of elements of a list,
- follows by Value of each element
+  follows by Value of each element
 * If value is 8, the value is a NULL
 * If value is 9, then next element is a IEEE single
 * If value is 10, then next element is a IEEE double
 * If value is 11, then next element is amount of entries for
-    an object, then there goes the length of the field name,
-    followed by field name in UTF-8, and then goes the Value
-    of the element
+  an object, then there goes the length of the field name,
+  followed by field name in UTF-8, and then goes the Value
+  of the element
 * If value is 12, then next data is unsigned int24
 * If value is 13, then next data is an unsigned short representing the count
-    of characters, and then these characters follow and are
-    interpreted as a UTF-8 string
+  of characters, and then these characters follow and are
+  interpreted as a UTF-8 string
 * If value is 14, then next data is an unsigned int representing the count
-    of characters, and then these characters follow and are
-    interpreted as a UTF-8 string
+  of characters, and then these characters follow and are
+  interpreted as a UTF-8 string
 * If value is 15, then next data is a unsigned short,
   and then a list follows of that many elements
 * If value is 16, then next data is a unsigned int,
@@ -56,8 +56,8 @@ with len of (value & 0x7F)
 * If value is 18, then next data is an unsigned int,
   and then an object follows of that many elements
 * If value is 19, then next data is an unsigned int,
-    and then follow that many pairs of Values (key: value)
+  and then follow that many pairs of Values (key: value)
 * If value is 20, then next data is an unsigned char,
-    and then follow that many pairs of Values (key: value)
+  and then follow that many pairs of Values (key: value)
 * If value is 21, then next data is an unsigned short,
-    and then follow that many pairs of Values (key: value)
+  and then follow that many pairs of Values (key: value)
diff --git a/docs/usage.rst b/docs/usage.rst
index f985e4b..af88af6 100644
--- a/docs/usage.rst
+++ b/docs/usage.rst
@@ -59,4 +59,5 @@ Example:
 
     a = Test(3)
     b = dumps_object(a)
-    loads_object(b, Test)
+    c = loads_object(b, Test)
+    assert a.a == c.a
diff --git a/minijson/routines.pyx b/minijson/routines.pyx
index 48a8f2c..4447b04 100644
--- a/minijson/routines.pyx
+++ b/minijson/routines.pyx
@@ -136,13 +136,17 @@ cpdef tuple parse(bytes data, int starting_position):
         char sint8
         list e_list
         dict e_dict
-        bytes b_field_name
+        bytes b_field_name, byte_data
         str s_field_name
     try:
         if value_type & 0x80:
             string_length = value_type & 0x7F
             try:
-                return string_length+1, data[starting_position+1:starting_position+string_length+1].decode('utf-8')
+                byte_data = data[starting_position+1:starting_position+string_length+1]
+                if len(byte_data) != string_length:
+                    raise DecodingError('Too short a frame, expected %s bytes got %s' % (string_length,
+                                                                                         len(byte_data)))
+                return string_length+1, byte_data.decode('utf-8')
             except UnicodeDecodeError as e:
                 raise DecodingError('Invalid UTF-8') from e
         elif value_type & 0xF0 == 0x40:
@@ -163,6 +167,8 @@ cpdef tuple parse(bytes data, int starting_position):
         elif value_type == 0:
             string_length = data[starting_position+1]
             offset, b_field_name = parse_cstring(data, starting_position+1)
+            if len(b_field_name) != string_length:
+                raise DecodingError('Expected %s bytes, got %s' % (string_length, len(b_field_name)))
             try:
                 return offset+1, b_field_name.decode('utf-8')
             except UnicodeDecodeError as e:
@@ -209,10 +215,18 @@ cpdef tuple parse(bytes data, int starting_position):
             return offset+2, e_dict
         elif value_type == 13:
             string_length, = STRUCT_H.unpack(data[starting_position+1:starting_position+3])
-            return 3+string_length, data[starting_position+3:starting_position+string_length+3].decode('utf-8')
+            byte_data = data[starting_position+3:starting_position+string_length+3]
+            if len(byte_data) != string_length:
+                raise DecodingError('Too short a frame, expected %s bytes got %s' % (string_length,
+                                                                                     len(byte_data)))
+            return 3+string_length, byte_data.decode('utf-8')
         elif value_type == 14:
             string_length, = STRUCT_L.unpack(data[starting_position+1:starting_position+5])
-            return 5+string_length, data[starting_position+5:starting_position+string_length+5].decode('utf-8')
+            byte_data = data[starting_position+5:starting_position+string_length+5]
+            if len(byte_data) != string_length:
+                raise DecodingError('Too short a frame, expected %s bytes got %s' % (string_length,
+                                                                                     len(byte_data)))
+            return 5+string_length, byte_data.decode('utf-8')
         elif value_type == 15:
             elements, = STRUCT_H.unpack(data[starting_position+1:starting_position+3])
             offset, e_list = parse_list(data, elements, starting_position+3)
@@ -242,7 +256,7 @@ cpdef tuple parse(bytes data, int starting_position):
             offset, e_dict = parse_sdict(data, elements, starting_position+3)
             return offset+3, e_dict
         raise DecodingError('Unknown sequence type %s!' % (value_type, ))
-    except IndexError as e:
+    except (IndexError, struct.error) as e:
         raise DecodingError('String too short!') from e
 
 cpdef object loads(bytes data):
diff --git a/setup.py b/setup.py
index ee331f8..51b3c78 100644
--- a/setup.py
+++ b/setup.py
@@ -14,7 +14,7 @@ if 'DEBUG' in os.environ:
     directives['embedsignature'] = True
 
 
-setup(version='1.6a1',
+setup(version='1.6',
       packages=find_packages(include=['minijson', 'minijson.*']),
       ext_modules=build([Multibuild('minijson', find_pyx('minijson'),
                                     dont_snakehouse=dont_snakehouse), ],
diff --git a/tests/test_minijson.py b/tests/test_minijson.py
index 7737d97..55dd3ea 100644
--- a/tests/test_minijson.py
+++ b/tests/test_minijson.py
@@ -7,6 +7,12 @@ class TestMiniJSON(unittest.TestCase):
     def assertSameAfterDumpsAndLoads(self, c):
         self.assertEqual(loads(dumps(c)), c)
 
+    def test_malformed(self):
+        self.assertRaises(EncodingError, lambda: dumps(2+3j))
+        self.assertRaises(DecodingError, lambda: loads(b'\x00\x02a'))
+        self.assertRaises(DecodingError, lambda: loads(b'\x00\x02a'))
+        self.assertRaises(DecodingError, lambda: loads(b'\x09\x00'))
+
     def test_short_nonstring_key_dicts(self):
         a = {}
         for i in range(20):
-- 
GitLab