From be7d4258689e24c66dd35992263047688a025ba3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20Ma=C5=9Blanka?= <piotr.maslanka@henrietta.com.pl>
Date: Wed, 26 May 2021 15:55:23 +0200
Subject: [PATCH] improved encoding of small non-all-keys-are-strings
 dictionaries

---
 CHANGELOG.md           |  3 ++-
 docs/specification.rst | 16 +++++++++++-----
 minijson/routines.pyx  | 31 ++++++++++++++++++++++++++++---
 setup.py               |  2 +-
 tests/test_minijson.py | 13 +++++++++++++
 5 files changed, 55 insertions(+), 10 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index bcef5ea..2167775 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,4 +3,5 @@ here's only the changelog for the version in development
 
 # v1.4
 
-* _TBA_
+* improved encoding of small non-all-keys-are-strings dictionaries
+
diff --git a/docs/specification.rst b/docs/specification.rst
index f18f4ad..c789bfd 100644
--- a/docs/specification.rst
+++ b/docs/specification.rst
@@ -19,9 +19,11 @@ with len of (value & 0x7F)
 * If value's two highest bits are 0100 or 0101, then four lowest bits encode the number of elements,
   and the four highest bits encode type of the object:
   * 0100 - a list
-  * 0101 - an object
-  Standard representation for an object or list follows,
-  sans the element count.
+  * 0101 - an object whose keys are all strings
+  * 0110 - an object whose keys are not all strings (see value of 19 and 20 to know how it's
+    represented).
+  * Standard representation for a non-key-string object (value 19), string key object (value 11) or list (value 7) follows,
+    sans the element count.
 * If value is zero, then next character is the length of the string followed by the string
 * If value is 1, then next data is signed int
 * If value is 2, then next data is signed short
@@ -51,7 +53,11 @@ with len of (value & 0x7F)
   and then a list follows of that many elements
 * If value is 17, then next data is a unsigned short,
   and then an object follows of that many elements
-* If value is 18, then next data is a unsigned int,
+* If value is 18, then next data is an unsigned int,
   and then an object follows of that many elements
-* If value is 19, then next data is a unsigned int,
+* If value is 19, then next data is an unsigned int,
+    and then follow that many pairs of Values (key: value)
+* If value is 20, then next data is an unsigned char,
+    and then follow that many pairs of Values (key: value)
+* If value is 21, then next data is an unsigned short,
     and then follow that many pairs of Values (key: value)
diff --git a/minijson/routines.pyx b/minijson/routines.pyx
index 166dc98..5e108e7 100644
--- a/minijson/routines.pyx
+++ b/minijson/routines.pyx
@@ -155,6 +155,11 @@ cpdef tuple parse(bytes data, int starting_position):
             elements = value_type & 0xF
             offset, e_dict = parse_dict(data, elements, starting_position+1)
             return offset+1, e_dict
+        elif value_type & 0xF0 == 0x60:
+            e_dict = {}
+            elements = value_type & 0xF
+            offset, e_dict = parse_sdict(data, elements, starting_position+1)
+            return offset+1, e_dict
         elif value_type == 0:
             string_length = data[starting_position+1]
             offset, b_field_name = parse_cstring(data, starting_position+1)
@@ -228,6 +233,14 @@ cpdef tuple parse(bytes data, int starting_position):
             elements, = STRUCT_L.unpack(data[starting_position+1:starting_position+5])
             offset, e_dict = parse_sdict(data, elements, starting_position+5)
             return offset+5, e_dict
+        elif value_type == 20:
+            elements = data[starting_position+1]
+            offset, e_dict = parse_sdict(data, elements, starting_position+2)
+            return offset+2, e_dict
+        elif value_type == 21:
+            elements, = STRUCT_H.unpack(data[starting_position+1:starting_position+3])
+            offset, e_dict = parse_sdict(data, elements, starting_position+3)
+            return offset+3, e_dict
         raise DecodingError('Unknown sequence type %s!' % (value_type, ))
     except IndexError as e:
         raise DecodingError('String too short!') from e
@@ -362,9 +375,21 @@ cpdef int dump(object data, cio: io.BytesIO) except -1:
                 raise EncodingError('Keys have to be strings!') from e
             return length
         else:
-            cio.write(b'\x13')
-            cio.write(STRUCT_L.pack(length))
-            offset = 5
+            if length < 16:
+                cio.write(bytearray([0b01100000 | length]))
+                offset = 1
+            elif length < 256:
+                cio.write(bytearray([20, length]))
+                offset = 2
+            elif length < 0xFFFF:
+                cio.write(b'\x15')
+                cio.write(STRUCT_H.pack(length))
+                offset = 3
+            else:
+                cio.write(b'\x13')
+                cio.write(STRUCT_L.pack(length))
+                offset = 5
+
             for key, value in data.items():
                 offset += dump(key, cio)
                 offset += dump(value, cio)
diff --git a/setup.py b/setup.py
index 0c5f945..0ae1e65 100644
--- a/setup.py
+++ b/setup.py
@@ -14,7 +14,7 @@ if 'DEBUG' in os.environ:
     directives['embedsignature'] = True
 
 
-setup(version='1.4a1',
+setup(version='1.4',
       packages=find_packages(include=['minijson', 'minijson.*']),
       ext_modules=build([Multibuild('minijson', find_pyx('minijson'),
                                     dont_snakehouse=dont_snakehouse), ],
diff --git a/tests/test_minijson.py b/tests/test_minijson.py
index 0b2cacc..d839248 100644
--- a/tests/test_minijson.py
+++ b/tests/test_minijson.py
@@ -7,6 +7,19 @@ class TestMiniJSON(unittest.TestCase):
     def assertSameAfterDumpsAndLoads(self, c):
         self.assertEqual(loads(dumps(c)), c)
 
+    def test_short_nonstring_key_dicts(self):
+        a = {}
+        for i in range(20):
+            a[i] = i
+        self.assertSameAfterDumpsAndLoads(a)
+        a = {}
+        for i in range(300):
+            a[i] = i
+        self.assertSameAfterDumpsAndLoads(a)
+        for i in range(700000):
+            a[i] = i
+        self.assertSameAfterDumpsAndLoads(a)
+
     def test_string(self):
         a = 'test'
         b = 't'*128
-- 
GitLab