From 9fcafffbf472219ffa77d9f81eb29905831e8b1a Mon Sep 17 00:00:00 2001
From: twmht <qrnnis2623891@gmail.com>
Date: Sun, 16 Apr 2017 14:13:01 +0800
Subject: [PATCH] add default merge operator

---
 rocksdb/_rocksdb.pyx          |  48 +++++++++++++++-
 rocksdb/merge_operator.pxd    |  21 +++++++
 rocksdb/options.pxd           |  12 +++-
 rocksdb/tests/test_db.py      | 101 ++++++++++++++++++++++++++++++++++
 rocksdb/tests/test_options.py |  33 +++++++++++
 5 files changed, 211 insertions(+), 4 deletions(-)

diff --git a/rocksdb/_rocksdb.pyx b/rocksdb/_rocksdb.pyx
index b39fd67..ca31ec8 100644
--- a/rocksdb/_rocksdb.pyx
+++ b/rocksdb/_rocksdb.pyx
@@ -318,8 +318,8 @@ cdef class PyMergeOperator(object):
     cdef object ob
 
     def __cinit__(self, object ob):
+        self.ob = ob
         if isinstance(ob, IAssociativeMergeOperator):
-            self.ob = ob
             self.merge_op.reset(
                 <merge_operator.MergeOperator*>
                     new merge_operator.AssociativeMergeOperatorWrapper(
@@ -328,7 +328,6 @@ cdef class PyMergeOperator(object):
                         merge_callback))
 
         elif isinstance(ob, IMergeOperator):
-            self.ob = ob
             self.merge_op.reset(
                 <merge_operator.MergeOperator*>
                     new merge_operator.MergeOperatorWrapper(
@@ -337,11 +336,29 @@ cdef class PyMergeOperator(object):
                         <void*>ob,
                         full_merge_callback,
                         partial_merge_callback))
+        elif isinstance(ob, str):
+            if ob == "put":
+              self.merge_op = merge_operator.MergeOperators.CreatePutOperator()
+            elif ob == "put_v1":
+              self.merge_op = merge_operator.MergeOperators.CreateDeprecatedPutOperator()
+            elif ob == "uint64add":
+              self.merge_op = merge_operator.MergeOperators.CreateUInt64AddOperator()
+            elif ob == "stringappend":
+              self.merge_op = merge_operator.MergeOperators.CreateStringAppendOperator()
+            #TODO: necessary?
+            #  elif ob == "stringappendtest":
+              #  self.merge_op = merge_operator.MergeOperators.CreateStringAppendTESTOperator()
+            elif ob == "max":
+              self.merge_op = merge_operator.MergeOperators.CreateMaxOperator()
+            else:
+                msg = "{0} is not the default type".format(ob)
+                raise TypeError(msg)
         else:
             msg = "%s is not of this types %s"
             msg %= (ob, (IAssociativeMergeOperator, IMergeOperator))
             raise TypeError(msg)
 
+
     cdef object get_ob(self):
         return self.ob
 
@@ -695,6 +712,7 @@ cdef class HashLinkListMemtableFactory(PyMemtableFactory):
         self.factory.reset(memtablerep.NewHashLinkListRepFactory(bucket_count))
 ##################################
 
+
 cdef class CompressionType(object):
     no_compression = u'no_compression'
     snappy_compression = u'snappy_compression'
@@ -787,6 +805,32 @@ cdef class Options(object):
         def __set__(self, value):
             self.opts.max_open_files = value
 
+    property compression_opts:
+        # View of the C++ CompressionOptions member as a plain Python dict.
+        def __get__(self):
+            # Build a fresh dict on every read from the four C fields.
+            return {
+                'window_bits': self.opts.compression_opts.window_bits,
+                'level': self.opts.compression_opts.level,
+                'strategy': self.opts.compression_opts.strategy,
+                'max_dict_bytes': self.opts.compression_opts.max_dict_bytes,
+            }
+
+        def __set__(self, dict value):
+            # Partial update: only keys present in ``value`` are written,
+            # every other field keeps its current setting.
+            cdef options.CompressionOptions* copts = &self.opts.compression_opts
+            if 'window_bits' in value:
+                copts.window_bits = value['window_bits']
+            if 'level' in value:
+                copts.level = value['level']
+            if 'strategy' in value:
+                copts.strategy = value['strategy']
+            if 'max_dict_bytes' in value:
+                copts.max_dict_bytes = value['max_dict_bytes']
+
+
+
     property compression:
         def __get__(self):
             if self.opts.compression == options.kNoCompression:
diff --git a/rocksdb/merge_operator.pxd b/rocksdb/merge_operator.pxd
index b8a95da..36ae98e 100644
--- a/rocksdb/merge_operator.pxd
+++ b/rocksdb/merge_operator.pxd
@@ -3,11 +3,32 @@ from libcpp cimport bool as cpp_bool
 from libcpp.deque cimport deque
 from slice_ cimport Slice
 from logger cimport Logger
+from std_memory cimport shared_ptr
 
 cdef extern from "rocksdb/merge_operator.h" namespace "rocksdb":
     cdef cppclass MergeOperator:
         pass
 
+# Static factories of rocksdb::MergeOperators. NOTE(review): header appears to live in the RocksDB source tree, not include/rocksdb -- confirm include path.
+cdef extern from  "utilities/merge_operators.h" namespace "rocksdb":
+    cdef cppclass MergeOperators:
+        @staticmethod
+        shared_ptr[MergeOperator] CreatePutOperator()
+        @staticmethod
+        shared_ptr[MergeOperator] CreateDeprecatedPutOperator()
+        @staticmethod
+        shared_ptr[MergeOperator] CreateUInt64AddOperator()
+        @staticmethod
+        shared_ptr[MergeOperator] CreateStringAppendOperator()
+        @staticmethod
+        shared_ptr[MergeOperator] CreateStringAppendTESTOperator()  # only referenced from commented-out code in _rocksdb.pyx
+        @staticmethod
+        shared_ptr[MergeOperator] CreateMaxOperator()
+        @staticmethod
+        shared_ptr[MergeOperator] CreateFromStringId(const string &)  # not used by the hunks in this patch
+
+
+
 ctypedef cpp_bool (*merge_func)(
     void*,
     const Slice&,
diff --git a/rocksdb/options.pxd b/rocksdb/options.pxd
index 93395ab..e54bf69 100644
--- a/rocksdb/options.pxd
+++ b/rocksdb/options.pxd
@@ -16,6 +16,14 @@ from universal_compaction cimport CompactionOptionsUniversal
 from cache cimport Cache
 
 cdef extern from "rocksdb/options.h" namespace "rocksdb":
+    cdef cppclass CompressionOptions:  # fields surfaced through Options.compression_opts
+        int window_bits
+        int level
+        int strategy
+        uint32_t max_dict_bytes
+        CompressionOptions() except +
+        CompressionOptions(int, int, int, int) except +
+
     ctypedef enum CompactionStyle:
         kCompactionStyleLevel
         kCompactionStyleUniversal
@@ -61,7 +69,6 @@ cdef extern from "rocksdb/options.h" namespace "rocksdb":
         CompressionType compression
         CompactionPri compaction_pri
         # TODO: compression_per_level
-        # TODO: compression_opts
         shared_ptr[SliceTransform] prefix_extractor
         int num_levels
         int level0_file_num_compaction_trigger
@@ -121,7 +128,8 @@ cdef extern from "rocksdb/options.h" namespace "rocksdb":
         size_t inplace_update_num_locks
         shared_ptr[Cache] row_cache
         # TODO: remove options source_compaction_factor, max_grandparent_overlap_bytes and expanded_compaction_factor from document
-        uint64_t max_compaction_bytes;
+        uint64_t max_compaction_bytes
+        CompressionOptions compression_opts
 
     cdef cppclass WriteOptions:
         cpp_bool sync
diff --git a/rocksdb/tests/test_db.py b/rocksdb/tests/test_db.py
index 4b25a6d..7eebd27 100644
--- a/rocksdb/tests/test_db.py
+++ b/rocksdb/tests/test_db.py
@@ -4,6 +4,7 @@ import gc
 import unittest
 import rocksdb
 from itertools import takewhile
+import struct
 
 def int_to_bytes(ob):
     return str(ob).encode('ascii')
@@ -230,6 +231,106 @@ class AssocCounter(rocksdb.interfaces.AssociativeMergeOperator):
     def name(self):
         return b'AssocCounter'
 
+class TestUint64Merge(unittest.TestCase, TestHelper):
+    def setUp(self):
+        opts = rocksdb.Options()
+        opts.create_if_missing = True
+        opts.merge_operator = "uint64add"  # built-in operator selected by name
+        self._clean()
+        self.db = rocksdb.DB('/tmp/test', opts)
+
+    def tearDown(self):
+        self._close_db()
+
+    def test_merge(self):
+        self.db.put(b'a', struct.pack('<Q', 5566))  # '<Q': 8 bytes, little-endian on every host
+        for x in range(1000):
+            self.db.merge(b"a", struct.pack('<Q', x))
+        merged = struct.unpack('<Q', self.db.get(b'a'))[0]  # read once; no debug print
+        self.assertEqual(5566 + sum(range(1000)), merged)
+
+class TestUint64MergeNoInitialValue(unittest.TestCase, TestHelper):
+    def setUp(self):
+        opts = rocksdb.Options()
+        opts.create_if_missing = True
+        opts.merge_operator = "uint64add"
+        self._clean()
+        self.db = rocksdb.DB('/tmp/test', opts)
+
+    def tearDown(self):
+        self._close_db()
+
+    def test_merge(self):
+        # No put() first: uint64add is expected to treat a missing key as 0.
+        for x in range(1000):
+            self.db.merge(b"a", struct.pack('<Q', x))
+        merged = struct.unpack('<Q', self.db.get(b'a'))[0]
+        self.assertEqual(sum(range(1000)), merged)
+
+class TestPutMerge(unittest.TestCase, TestHelper):
+    def setUp(self):
+        opts = rocksdb.Options()
+        opts.create_if_missing = True
+        opts.merge_operator = "put"  # built-in operator selected by name
+        self._clean()
+        self.db = rocksdb.DB('/tmp/test', opts)
+
+    def tearDown(self):
+        self._close_db()
+
+    def test_merge(self):
+        self.db.put(b'a', b'ccc')
+        self.db.merge(b'a', b'ddd')  # the merged operand should replace b'ccc'
+        self.assertEqual(self.db.get(b'a'), b'ddd')  # bytes literal: bytes != str on Python 3
+
+class TestPutV1Merge(unittest.TestCase, TestHelper):
+    def setUp(self):
+        opts = rocksdb.Options()
+        opts.create_if_missing = True
+        opts.merge_operator = "put_v1"  # maps to CreateDeprecatedPutOperator()
+        self._clean()
+        self.db = rocksdb.DB('/tmp/test', opts)
+
+    def tearDown(self):
+        self._close_db()
+
+    def test_merge(self):
+        self.db.put(b'a', b'ccc')
+        self.db.merge(b'a', b'ddd')  # the merged operand should replace b'ccc'
+        self.assertEqual(self.db.get(b'a'), b'ddd')  # bytes literal: bytes != str on Python 3
+
+class TestStringAppendOperatorMerge(unittest.TestCase, TestHelper):
+    def setUp(self):
+        opts = rocksdb.Options()
+        opts.create_if_missing = True
+        opts.merge_operator = "stringappend"  # built-in operator selected by name
+        self._clean()
+        self.db = rocksdb.DB('/tmp/test', opts)
+
+    def tearDown(self):
+        self._close_db()
+
+    def test_merge(self):
+        self.db.put(b'a', b'ccc')
+        self.db.merge(b'a', b'ddd')  # expected to be appended after a ',' delimiter
+        self.assertEqual(self.db.get(b'a'), b'ccc,ddd')  # bytes literal: bytes != str on Python 3
+
+class TestStringMaxOperatorMerge(unittest.TestCase, TestHelper):
+    def setUp(self):
+        opts = rocksdb.Options()
+        opts.create_if_missing = True
+        opts.merge_operator = "max"  # built-in operator selected by its string name
+        self._clean()
+        self.db = rocksdb.DB('/tmp/test', opts)
+
+    def tearDown(self):
+        self._close_db()
+
+    def test_merge(self):
+        self.db.put(b'a', int_to_bytes(55))  # stored as the ASCII digits b'55'
+        self.db.merge(b'a', int_to_bytes(56))
+        self.assertEqual(int(self.db.get(b'a')), 56)  # the larger operand is expected to win
+
 
 class TestAssocMerge(unittest.TestCase, TestHelper):
     def setUp(self):
diff --git a/rocksdb/tests/test_options.py b/rocksdb/tests/test_options.py
index 09786a2..ec2db4c 100644
--- a/rocksdb/tests/test_options.py
+++ b/rocksdb/tests/test_options.py
@@ -22,6 +22,39 @@ class TestMergeOperator(rocksdb.interfaces.MergeOperator):
         return b'testmergeop'
 
 class TestOptions(unittest.TestCase):
+    def test_default_merge_operator(self):
+        opts = rocksdb.Options()
+        self.assertEqual(True, opts.paranoid_checks)  # same checks as test_simple; not merge-specific
+        opts.paranoid_checks = False
+        self.assertEqual(False, opts.paranoid_checks)
+
+        self.assertIsNone(opts.merge_operator)  # nothing configured on a fresh Options
+        opts.merge_operator = "uint64add"  # select a built-in operator by name
+        self.assertIsNotNone(opts.merge_operator)
+        self.assertEqual(opts.merge_operator, "uint64add")  # getter hands back the assigned string
+        with self.assertRaises(TypeError):
+            opts.merge_operator = "not an operator"  # unknown names are rejected with TypeError
+
+    def test_compression_opts(self):
+        opts = rocksdb.Options()
+        compression_opts = opts.compression_opts
+        # defaults of a freshly constructed Options, returned as a dict
+        self.assertIsInstance(compression_opts, dict)
+        self.assertEqual(compression_opts['window_bits'], -14)
+        self.assertEqual(compression_opts['level'], -1)
+        self.assertEqual(compression_opts['strategy'], 0)
+        self.assertEqual(compression_opts['max_dict_bytes'], 0)
+
+        with self.assertRaises(TypeError):
+            opts.compression_opts = [1, 2]  # a real list, so the setter itself must reject it
+
+        opts.compression_opts = {'window_bits': 1, 'level': 2, 'strategy': 3, 'max_dict_bytes': 4}
+        compression_opts = opts.compression_opts
+        self.assertEqual(compression_opts['window_bits'], 1)
+        self.assertEqual(compression_opts['level'], 2)
+        self.assertEqual(compression_opts['strategy'], 3)
+        self.assertEqual(compression_opts['max_dict_bytes'], 4)
+
     def test_simple(self):
         opts = rocksdb.Options()
         self.assertEqual(True, opts.paranoid_checks)
-- 
GitLab