From afe820f03fc36643ebcd26d4be27fe0972d33b22 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Piotr=20Ma=C5=9Blanka?= <piotr.maslanka@henrietta.com.pl>
Date: Wed, 30 Jun 2021 15:07:58 +0200
Subject: [PATCH] try to add MiniJSONEncoder

---
 CHANGELOG.md           |   2 +
 Dockerfile             |   6 +-
 docs/usage.rst         |   9 ++
 minijson.pyx           | 352 ++++++++++++++++++++++++-----------------
 tests/test_minijson.py |  21 ++-
 5 files changed, 244 insertions(+), 146 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index adae924..759546d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,3 +2,5 @@ Changelog is kept at [GitHub](https://github.com/Dronehub/minijson/releases),
 here's only the changelog for the version in development
 
 # v2.5
+
+* added `MiniJSONEncoder`
diff --git a/Dockerfile b/Dockerfile
index 5a642ca..a6138e8 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,12 +1,12 @@
-FROM pypy:3.5
+FROM python:3.8
 RUN apt-get update && \
     apt-get install -y patchelf
-RUN pypy3 -m pip install Cython pytest coverage pytest-cov auditwheel doctor-wheel twine
+RUN python -m pip install Cython pytest coverage pytest-cov auditwheel doctor-wheel twine
 
 WORKDIR /tmp/compile
 ADD . /tmp/compile/
 
-RUN pypy3 setup.py install && \
+RUN python setup.py install && \
     chmod ugo+x /tmp/compile/tests/test.sh
 
 CMD ["/tmp/compile/tests/test.sh"]
diff --git a/docs/usage.rst b/docs/usage.rst
index 807988e..3b52e3f 100644
--- a/docs/usage.rst
+++ b/docs/usage.rst
@@ -66,3 +66,12 @@ Example:
     b = dumps_object(a)
     c = loads_object(b, Test)
     assert a.a == c.a
+
+MiniJSONEncoder
+---------------
+
+There's also a class available for encoding. Use it like you would a normal Python
+:code:`JSONEncoder`:
+
+.. autoclass:: minijson.MiniJSONEncoder
+    :members:
diff --git a/minijson.pyx b/minijson.pyx
index c518961..b996151 100644
--- a/minijson.pyx
+++ b/minijson.pyx
@@ -302,165 +302,233 @@ cdef inline bint is_jsonable(y):
     return y is None or isinstance(y, (int, float, str, dict, list, tuple))
 
 
-cpdef int dump(object data, cio: io.BytesIO, default: tp.Optional[tp.Callable] = None) except -1:
+cdef class MiniJSONEncoder:
     """
-    Write an object to a stream
+    A base class for encoders.
 
-    :param data: object to write
-    :param cio: stream to write to
-    :param default: a function that should be used to convert non-JSONable objects to JSONable ones.
-        Default, None, will raise an EncodingError upon encountering such a value
-    :return: amount of bytes written
-    :raises EncodingError: invalid data
+    It is advised to use this class over :meth:`~minijson.dump` and
+    :meth:`~minijson.dumps` due to finer grained control over floats.
+
+    :param default: a callable used to convert non-JSONable objects to JSONable ones
+    :param use_double: whether to use doubles instead of floats to represent floating point numbers
+
+    :ivar use_double: (bool) whether to use doubles instead of floats (used when
+        :meth:`~minijson.MiniJSONEncoder.should_double_be_used` is not overridden)
     """
     cdef:
-        str field_name
-        int length
-        bytes b_data
-    if data is None:
-        cio.write(b'\x08')
-        return 1
-    elif data is True:
-        cio.write(b'\x16')
-        return 1
-    elif data is False:
-        cio.write(b'\x17')
-        return 1
-    elif isinstance(data, str):
-        length = len(data)
-        if length < 128:
-            cio.write(bytearray([0x80 | length]))
-            cio.write(data.encode('utf-8'))
-            return 1+length
-        elif length <= 0xFF:
-            cio.write(bytearray([0, length]))
-            cio.write(data.encode('utf-8'))
-            return 2+length
-        elif length <= 0xFFFF:
-            cio.write(b'\x0D')
-            cio.write(STRUCT_H.pack(length))
-            cio.write(data.encode('utf-8'))
-            return 3+length
-        else:       # Python strings cannot grow past 0xFFFFFFFF characters
-            cio.write(b'\x0E')
-            cio.write(STRUCT_L.pack(length))
-            cio.write(data.encode('utf-8'))
-            return 5+length
-    elif isinstance(data, int):
-        if -128 <= data <= 127: # signed char, type 3
-            cio.write(b'\x03')
-            cio.write(STRUCT_b.pack(data))
-            return 2
-        elif 0 <= data <= 255:  # unsigned char, type 6
-            cio.write(bytearray([6, data]))
-            return 2
-        elif -32768 <= data <= 32767:   # signed short, type 2
-            cio.write(b'\x02')
-            cio.write(STRUCT_h.pack(data))
-            return 3
-        elif 0 <= data <= 65535:        # unsigned short, type 5
-            cio.write(b'\x05')
-            cio.write(STRUCT_H.pack(data))
-            return 3
-        elif 0 <= data <= 0xFFFFFF:         # unsigned 3byte, type 12
-            cio.write(b'\x0C')
-            cio.write(STRUCT_L.pack(data)[1:])
-            return 4
-        elif -2147483648 <= data <= 2147483647:     # signed int, type 1
-            cio.write(b'\x01')
-            cio.write(STRUCT_l.pack(data))
-            return 5
-        elif 0 <= data <= 0xFFFFFFFF:       # unsigned int, type 4
-            cio.write(b'\x04')
-            cio.write(STRUCT_L.pack(data))
-            return 5
+        object _default
+        bint use_double
+
+    def __init__(self, default: tp.Optional[tp.Callable] = None,
+                 use_double: bool = False):
+        self._default = default         # consulted by default(); None makes default() raise EncodingError
+        self.use_double = use_double    # returned by should_double_be_used() unless that is overridden
+
+    cpdef bint should_double_be_used(self, double x):
+        """
+        A method that you are meant to override; it decides, on a per-case basis,
+        which representation should be used for the given number.
+        
+        :param x: number to check 
+        :return: True if double should be used, else False
+        """
+        return self.use_double
+
+    cpdef default(self, v):
+        """
+        Convert an object to a JSON-able representation.
+        
+        Override this to provide your default function in a way other than giving
+        the callable as a parameter.
+
+        :param v: object to convert
+        :return: a JSONable representation
+        """
+        if self._default is None:
+            raise EncodingError('Unknown value type %s' % (v, ))
         else:
-            length = 5
-            while True:
-                try:
-                    b_data = data.to_bytes(length, 'big', signed=True)
-                    break
-                except OverflowError:
-                    length += 1
-            cio.write(bytearray([0x18, length]))
-            cio.write(b_data)
-    elif isinstance(data, float):
-        if float_encoding_mode == 0:
-            cio.write(b'\x09')
-            cio.write(STRUCT_f.pack(data))
-            return 5
-        else:
-            cio.write(b'\x0A')
-            cio.write(STRUCT_d.pack(data))
-            return 9
-    elif isinstance(data, (tuple, list)):
-        length = len(data)
-        if length < 16:
-            cio.write(bytearray([0b01000000 | length]))
-            length = 1
-        elif length < 256:
-            cio.write(bytearray([7, length]))
-            length = 2
-        elif length < 65536:
-            cio.write(b'\x0F')
-            cio.write(STRUCT_H.pack(length))
-            length = 3
-        elif length <= 0xFFFFFFFF:
-            cio.write(b'\x10')
-            cio.write(STRUCT_L.pack(length))
-            length = 5
-        for elem in data:
-            length += dump(elem, cio, default)
-        return length
-    elif isinstance(data, dict):
-        length = len(data)
-        if can_be_encoded_as_a_dict(data):
+            return self._default(v)
+
+    def encode(self, v) -> bytes:
+        """
+        Encode a provided object
+
+        :param v: object to encode
+        :return: returned bytes
+        """
+        cio = io.BytesIO()
+        self.dump(v, cio)
+        return cio.getvalue()
+
+    cpdef int dump(self, object data, cio: io.BytesIO) except -1:
+        """
+        Write an object to a stream
+    
+        :param data: object to write
+        :param cio: stream to write to
+        :return: amount of bytes written
+        :raises EncodingError: invalid data
+        """
+        cdef:
+            str field_name
+            int length
+            bytes b_data
+        if data is None:
+            cio.write(b'\x08')
+            return 1
+        elif data is True:
+            cio.write(b'\x16')
+            return 1
+        elif data is False:
+            cio.write(b'\x17')
+            return 1
+        elif isinstance(data, str):
+            length = len(data)
+            if length < 128:
+                cio.write(bytearray([0x80 | length]))
+                cio.write(data.encode('utf-8'))
+                return 1+length
+            elif length <= 0xFF:
+                cio.write(bytearray([0, length]))
+                cio.write(data.encode('utf-8'))
+                return 2+length
+            elif length <= 0xFFFF:
+                cio.write(b'\x0D')
+                cio.write(STRUCT_H.pack(length))
+                cio.write(data.encode('utf-8'))
+                return 3+length
+            else:       # Python strings cannot grow past 0xFFFFFFFF characters
+                cio.write(b'\x0E')
+                cio.write(STRUCT_L.pack(length))
+                cio.write(data.encode('utf-8'))
+                return 5+length
+        elif isinstance(data, int):
+            if -128 <= data <= 127: # signed char, type 3
+                cio.write(b'\x03')
+                cio.write(STRUCT_b.pack(data))
+                return 2
+            elif 0 <= data <= 255:  # unsigned char, type 6
+                cio.write(bytearray([6, data]))
+                return 2
+            elif -32768 <= data <= 32767:   # signed short, type 2
+                cio.write(b'\x02')
+                cio.write(STRUCT_h.pack(data))
+                return 3
+            elif 0 <= data <= 65535:        # unsigned short, type 5
+                cio.write(b'\x05')
+                cio.write(STRUCT_H.pack(data))
+                return 3
+            elif 0 <= data <= 0xFFFFFF:         # unsigned 3byte, type 12
+                cio.write(b'\x0C')
+                cio.write(STRUCT_L.pack(data)[1:])
+                return 4
+            elif -2147483648 <= data <= 2147483647:     # signed int, type 1
+                cio.write(b'\x01')
+                cio.write(STRUCT_l.pack(data))
+                return 5
+            elif 0 <= data <= 0xFFFFFFFF:       # unsigned int, type 4
+                cio.write(b'\x04')
+                cio.write(STRUCT_L.pack(data))
+                return 5
+            else:       # outside every fixed-width range above: tag 0x18, length byte, signed big-endian bytes
+                length = 5      # values reaching this branch do not fit in 4 bytes
+                while True:
+                    try:
+                        b_data = data.to_bytes(length, 'big', signed=True)
+                        break
+                    except OverflowError:
+                        length += 1     # widen until the value fits
+                cio.write(bytearray([0x18, length]) + b_data)
+                return 2 + length       # tag byte + length byte + payload
+        elif isinstance(data, float):
+            if self.should_double_be_used(data):
+                cio.write(b'\x0A')
+                cio.write(STRUCT_d.pack(data))
+                return 9
+            else:
+                cio.write(b'\x09')
+                cio.write(STRUCT_f.pack(data))
+                return 5
+        elif isinstance(data, (tuple, list)):
+            length = len(data)
             if length < 16:
-                cio.write(bytearray([0b01010000 | length]))
+                cio.write(bytearray([0b01000000 | length]))
                 length = 1
             elif length < 256:
-                cio.write(bytearray([11, len(data)]))
+                cio.write(bytearray([7, length]))
                 length = 2
             elif length < 65536:
-                cio.write(b'\x11')
+                cio.write(b'\x0F')
                 cio.write(STRUCT_H.pack(length))
                 length = 3
             elif length <= 0xFFFFFFFF:
-                cio.write(b'\x12')
+                cio.write(b'\x10')
                 cio.write(STRUCT_L.pack(length))
                 length = 5
-            for field_name, elem in data.items():
-                cio.write(bytearray([len(field_name)]))
-                cio.write(field_name.encode('utf-8'))
-                length += dump(elem, cio, default)
+            for elem in data:
+                length += self.dump(elem, cio)
             return length
+        elif isinstance(data, dict):
+            length = len(data)
+            if can_be_encoded_as_a_dict(data):
+                if length < 16:
+                    cio.write(bytearray([0b01010000 | length]))
+                    length = 1
+                elif length < 256:
+                    cio.write(bytearray([11, len(data)]))
+                    length = 2
+                elif length < 65536:
+                    cio.write(b'\x11')
+                    cio.write(STRUCT_H.pack(length))
+                    length = 3
+                elif length <= 0xFFFFFFFF:
+                    cio.write(b'\x12')
+                    cio.write(STRUCT_L.pack(length))
+                    length = 5
+                for field_name, elem in data.items():
+                    cio.write(bytearray([len(field_name)]))
+                    cio.write(field_name.encode('utf-8'))
+                    length += self.dump(elem, cio)
+                return length
+            else:
+                if length <= 0xF:
+                    cio.write(bytearray([0b01100000 | length]))
+                    offset = 1
+                elif length <= 0xFF:
+                    cio.write(bytearray([20, length]))
+                    offset = 2
+                elif length <= 0xFFFF:
+                    cio.write(b'\x15')
+                    cio.write(STRUCT_H.pack(length))
+                    offset = 3
+                else:       # Python objects cannot grow to have more than 0xFFFFFFFF members
+                    cio.write(b'\x13')
+                    cio.write(STRUCT_L.pack(length))
+                    offset = 5
+
+                for key, value in data.items():
+                    offset += self.dump(key, cio)
+                    offset += self.dump(value, cio)
+                return offset
         else:
-            if length <= 0xF:
-                cio.write(bytearray([0b01100000 | length]))
-                offset = 1
-            elif length <= 0xFF:
-                cio.write(bytearray([20, length]))
-                offset = 2
-            elif length <= 0xFFFF:
-                cio.write(b'\x15')
-                cio.write(STRUCT_H.pack(length))
-                offset = 3
-            else:       # Python objects cannot grow to have more than 0xFFFFFFFF members
-                cio.write(b'\x13')
-                cio.write(STRUCT_L.pack(length))
-                offset = 5
-
-            for key, value in data.items():
-                offset += dump(key, cio, default)
-                offset += dump(value, cio, default)
-            return offset
-    elif default is None:
-        raise EncodingError('Unknown value type %s' % (data, ))
-    else:
-        v = default(data)
-        if not is_jsonable(v):
-            raise EncodingError('Default returned type %s, which is not jsonable' % (type(v), ))
-        return dump(v, cio, default)
+            v = self.default(data)
+            return self.dump(v, cio)
+
+
+cpdef int dump(object data, cio: io.BytesIO, default: tp.Optional[tp.Callable] = None) except -1:
+    """
+    Write an object to a stream
+
+    :param data: object to write
+    :param cio: stream to write to
+    :param default: a callable/1 that should return a JSON-able representation for objects
+        that can't be JSONed otherwise. Default, None, raises EncodingError for such objects
+    :return: amount of bytes written
+    :raises EncodingError: invalid data
+    """
+    global float_encoding_mode
+    cdef MiniJSONEncoder mje = MiniJSONEncoder(default, float_encoding_mode == 1)
+    return mje.dump(data, cio)
 
 
 cpdef bytes dumps(object data, default: tp.Optional[tp.Callable] = None):
diff --git a/tests/test_minijson.py b/tests/test_minijson.py
index 46e8a4c..4d3f2a8 100644
--- a/tests/test_minijson.py
+++ b/tests/test_minijson.py
@@ -1,10 +1,29 @@
 import unittest
 from minijson import dumps, loads, dumps_object, loads_object, EncodingError, DecodingError, \
-    switch_default_double, switch_default_float
+    switch_default_double, switch_default_float, MiniJSONEncoder
 
 
 class TestMiniJSON(unittest.TestCase):
 
+    def test_encoder_overrided_default(self):
+        class Encoder(MiniJSONEncoder):
+            def default(self, v):
+                return v.real, v.imag
+
+        e = Encoder()
+        e.encode(2+3j)
+
+    def test_encoder_given_default(self):
+        def encode(v):
+            return v.real, v.imag
+
+        e = MiniJSONEncoder(default=encode)
+        e.encode(2 + 3j)
+
+    def test_encoder_no_default(self):
+        e = MiniJSONEncoder()
+        self.assertRaises(EncodingError, lambda: e.encode(2+3j))
+
     def test_accepts_bytearrays(self):
         b = {'test': 'hello'}
         a = dumps(b)
-- 
GitLab