From 4715553016bc72af449ffafe8a2257d23809453d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20Ma=C5=9Blanka?= <piotr.maslanka@henrietta.com.pl>
Date: Tue, 29 Jun 2021 18:45:59 +0200
Subject: [PATCH] fix minijson

---
 .coveragerc            |  1 +
 CHANGELOG.md           |  3 +++
 Dockerfile             |  6 +++---
 docs/index.rst         |  3 +++
 minijson.pyx           | 44 +++++++++++++++++++++++++++++-------------
 setup.cfg              |  2 +-
 tests/test_minijson.py | 31 +++++++++++++++++++++++++++++
 7 files changed, 73 insertions(+), 17 deletions(-)

diff --git a/.coveragerc b/.coveragerc
index 947f627..5c37658 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -2,4 +2,5 @@
 omit=
     setup.py
     docs/*
+    tests/*
 plugins = Cython.Coverage
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 725c3e0..0ed6772 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,3 +2,6 @@ Changelog is kept at [GitHub](https://github.com/Dronehub/minijson/releases),
 here's only the changelog for the version in development
 
 # v2.4
+
+* added argument default
+* fixing issue with serializing classes that subclass dict, list and tuple
diff --git a/Dockerfile b/Dockerfile
index 3b29805..d01bdd6 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,14 +1,14 @@
-FROM pypy:3.5
+FROM python:3.5
 RUN apt-get update && \
     apt-get install -y patchelf
-RUN pypy3 -m pip install Cython pytest coverage pytest-cov auditwheel doctor-wheel twine
+RUN python -m pip install Cython pytest coverage pytest-cov auditwheel doctor-wheel twine
 
 ENV DEBUG=1
 
 WORKDIR /tmp/compile
 ADD . /tmp/compile/
 
-RUN pypy3 setup.py install && \
+RUN python setup.py install && \
     chmod ugo+x /tmp/compile/tests/test.sh
 
 CMD ["/tmp/compile/tests/test.sh"]
diff --git a/docs/index.rst b/docs/index.rst
index 0c4df87..c1a16a3 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -20,6 +20,9 @@ whose all keys are strings.
 You should avoid objects with keys different than strings, since they will always use a
 4-byte length field. This is to be improved in a future release.
 
+.. warning:: Take care for your data to be without cycles. Feeding the encoder cycles
+   will probably dump your interpreter's core.
+
 Indices and tables
 ==================
 
diff --git a/minijson.pyx b/minijson.pyx
index f6ec5a9..c518961 100644
--- a/minijson.pyx
+++ b/minijson.pyx
@@ -120,7 +120,7 @@ cdef inline tuple parse_sdict(bytes data, int elem_count, int starting_position)
     return offset, dct
 
 
-cdef inline bint can_be_encoded_as_a_dict(dict dct):
+cdef inline bint can_be_encoded_as_a_dict(dct):
     for key in dct.keys():
         if not isinstance(key, str):
             return False
@@ -298,12 +298,18 @@ cpdef object loads(object data):
     return parse(data, 0)[1]
 
 
-cpdef int dump(object data, cio: io.BytesIO) except -1:
+cdef inline bint is_jsonable(y):
+    return y is None or isinstance(y, (int, float, str, dict, list, tuple))
+
+
+cpdef int dump(object data, cio: io.BytesIO, default: tp.Optional[tp.Callable] = None) except -1:
     """
     Write an object to a stream
 
     :param data: object to write
     :param cio: stream to write to
+    :param default: a function that should be used to convert non-JSONable objects to JSONable ones.
+        Default, None, will raise an EncodingError upon encountering such a value
     :return: amount of bytes written
     :raises EncodingError: invalid data
     """
@@ -404,7 +410,7 @@ cpdef int dump(object data, cio: io.BytesIO) except -1:
             cio.write(STRUCT_L.pack(length))
             length = 5
         for elem in data:
-            length += dump(elem, cio)
+            length += dump(elem, cio, default)
         return length
     elif isinstance(data, dict):
         length = len(data)
@@ -426,7 +432,7 @@ cpdef int dump(object data, cio: io.BytesIO) except -1:
             for field_name, elem in data.items():
                 cio.write(bytearray([len(field_name)]))
                 cio.write(field_name.encode('utf-8'))
-                length += dump(elem, cio)
+                length += dump(elem, cio, default)
             return length
         else:
             if length <= 0xF:
@@ -445,35 +451,47 @@ cpdef int dump(object data, cio: io.BytesIO) except -1:
                 offset = 5
 
             for key, value in data.items():
-                offset += dump(key, cio)
-                offset += dump(value, cio)
+                offset += dump(key, cio, default)
+                offset += dump(value, cio, default)
             return offset
-    else:
+    elif default is None:
         raise EncodingError('Unknown value type %s' % (data, ))
+    else:
+        v = default(data)
+        if not is_jsonable(v):
+            raise EncodingError('Default returned type %s, which is not jsonable' % (type(v), ))
+        return dump(v, cio, default)
 
 
-cpdef bytes dumps(object data):
+cpdef bytes dumps(object data, default: tp.Optional[tp.Callable] = None):
     """
     Serialize given data to a MiniJSON representation
 
     :param data: data to serialize
+    :param default: a function that should be used to convert non-JSONable objects to JSONable ones.
+        Default, None, will raise an EncodingError upon encountering such a value
     :return: return MiniJSON representation
-    :raises DecodingError: object not serializable
+    :raises EncodingError: object not serializable
     """
     cio = io.BytesIO()
-    dump(data, cio)
+    dump(data, cio, default)
     return cio.getvalue()
 
 
-cpdef bytes dumps_object(object data):
+cpdef bytes dumps_object(object data, default: tp.Optional[tp.Callable] = None):
     """
-    Dump an object's __dict__
+    Dump an object's :code:`__dict__`.
+    
+    Note that subobject's :code:`__dict__` will not be copied. Use default for that.
     
     :param data: object to dump 
+    :param default: a function that should be used to convert non-JSONable objects to JSONable ones.
+        Default, None, will raise an EncodingError upon encountering such a value
     :return: resulting bytes
     :raises EncodingError: encoding error
     """
-    return dumps(data.__dict__)
+    return dumps(data.__dict__, default)
+
 
 cpdef object loads_object(data, object obj_class):
     """
diff --git a/setup.cfg b/setup.cfg
index d96fdd4..ba492f5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,6 +1,6 @@
 # coding: utf-8
 [metadata]
-version = 2.4a2
+version = 2.4a4
 name = minijson
 long_description = file: README.md
 long_description_content_type = text/markdown; charset=UTF-8
diff --git a/tests/test_minijson.py b/tests/test_minijson.py
index 3db5d9d..c9a381b 100644
--- a/tests/test_minijson.py
+++ b/tests/test_minijson.py
@@ -17,6 +17,37 @@ class TestMiniJSON(unittest.TestCase):
     def assertSameAfterDumpsAndLoads(self, c):
         self.assertEqual(loads(dumps(c)), c)
 
+    def test_default(self):
+        def transform(c):
+            return c.real, c.imag
+
+        dumps(2 + 3j, transform)
+        dumps({'test': 2 + 3j}, transform)
+
+    def test_subclasses_of_dicts(self):
+        class Subclass(dict):
+            pass
+
+        a = Subclass({1: 2, 3: 4})
+        b = dumps(a)
+        self.assertEquals(loads(b), {1: 2, 3: 4})
+
+    def test_subclasses_of_lists(self):
+        class Subclass(list):
+            pass
+
+        a = Subclass([1, 2, 3])
+        b = dumps(a)
+        self.assertEquals(loads(b), [1, 2, 3])
+
+    def test_subclasses_of_tuples(self):
+        class Subclass(tuple):
+            pass
+
+        a = Subclass((1, 2, 3))
+        b = dumps(a)
+        self.assertEquals(loads(b), [1, 2, 3])
+
     def test_malformed(self):
         self.assertRaises(EncodingError, lambda: dumps(2 + 3j))
         self.assertLoadingIsDecodingError(b'\x00\x02a')
-- 
GitLab