From ca76ffb03bfe00703814ac6f48040436c1b51685 Mon Sep 17 00:00:00 2001
From: Mandar Harshe <mandarharshe@gmail.com>
Date: Thu, 30 Dec 2021 08:14:26 +0100
Subject: [PATCH] Add fixed prefix extractor

---
 rocksdb/_rocksdb.pyx        | 19 +++++++++++++------
 rocksdb/slice_transform.pxd |  6 ++++++
 rocksdb/tests/test_db.py    | 38 +++++++++++++++++++++++++++++++++++++
 3 files changed, 57 insertions(+), 6 deletions(-)

diff --git a/rocksdb/_rocksdb.pyx b/rocksdb/_rocksdb.pyx
index befc51c..c085051 100644
--- a/rocksdb/_rocksdb.pyx
+++ b/rocksdb/_rocksdb.pyx
@@ -1258,8 +1258,11 @@ cdef class ColumnFamilyOptions(object):
             return self.py_prefix_extractor.get_ob()
 
         def __set__(self, value):
-            self.py_prefix_extractor = PySliceTransform(value)
-            self.copts.prefix_extractor = self.py_prefix_extractor.get_transformer()
+            if isinstance(value, int):  # int: install the built-in fixed-length prefix transform
+                self.copts.prefix_extractor.reset(slice_transform.ST_NewFixedPrefixTransform(value))
+            self.py_prefix_extractor = None if isinstance(value, int) else PySliceTransform(value)
+            if self.py_prefix_extractor is not None:  # Python object: install its wrapper (None = int case, stale wrapper cleared)
+                self.copts.prefix_extractor = self.py_prefix_extractor.get_transformer()
 
     property optimize_filters_for_hits:
         def __get__(self):
@@ -2180,10 +2183,11 @@ cdef class DB(object):
 
     @staticmethod
     def __parse_read_opts(
-        verify_checksums=False,
-        fill_cache=True,
-        snapshot=None,
-        read_tier="all"):
+            verify_checksums=False,
+            fill_cache=True,
+            snapshot=None,  # None means no snapshot is attached
+            read_tier="all",  # e.g. "all" or "cache"
+            total_order_seek=False):  # later copied to opts.total_order_seek
 
         # TODO: Is this really effiencet ?
         return locals()
@@ -2195,6 +2199,9 @@ cdef class DB(object):
         if py_opts['snapshot'] is not None:
             opts.snapshot = (<Snapshot?>(py_opts['snapshot'])).ptr
 
+        if py_opts['total_order_seek'] is not None:
+            opts.total_order_seek = py_opts['total_order_seek']
+
         if py_opts['read_tier'] == "all":
             opts.read_tier = options.kReadAllTier
         elif py_opts['read_tier'] == 'cache':
diff --git a/rocksdb/slice_transform.pxd b/rocksdb/slice_transform.pxd
index 37d7740..dd73502 100644
--- a/rocksdb/slice_transform.pxd
+++ b/rocksdb/slice_transform.pxd
@@ -8,6 +8,12 @@ cdef extern from "rocksdb/slice_transform.h" namespace "rocksdb":
     cdef cppclass SliceTransform:
         pass
 
+    cdef const SliceTransform* ST_NewCappedPrefixTransform "rocksdb::NewCappedPrefixTransform"(
+        size_t) nogil except+  # declared only; not referenced elsewhere in this patch
+
+    cdef const SliceTransform* ST_NewFixedPrefixTransform "rocksdb::NewFixedPrefixTransform"(
+        size_t) nogil except+  # called by the prefix_extractor setter when given an int
+
 ctypedef Slice (*transform_func)(
     void*,
     Logger*,
diff --git a/rocksdb/tests/test_db.py b/rocksdb/tests/test_db.py
index 900f278..201bfb7 100644
--- a/rocksdb/tests/test_db.py
+++ b/rocksdb/tests/test_db.py
@@ -489,6 +489,44 @@ class TestPrefixExtractor(TestHelper):
         ret = takewhile(lambda item: item[0].startswith(b'00002'), it)
         self.assertEqual(ref, dict(ret))
 
+class TestFixedPrefixExtractor(TestHelper):
+    """Iterator checks on a DB opened with an integer ``prefix_extractor``."""
+
+    def setUp(self):
+        TestHelper.setUp(self)
+        db_path = os.path.join(self.db_loc, 'test')
+        # The prefix extractor is given as a plain int (4), not an object.
+        options = rocksdb.Options(create_if_missing=True, prefix_extractor=4)
+        self.db = rocksdb.DB(db_path, options)
+
+    def _fill_db(self):
+        """Store b'<5 hex digits>.x|.y|.z' -> b'x'|b'y'|b'z' for 0..2999."""
+        for number in range(3000):
+            stem = hex(number)[2:].zfill(5).encode('utf8')
+            for suffix in (b'x', b'y', b'z'):
+                self.db.put(stem + b'.' + suffix, suffix)
+
+    def test_prefix_iterkeys(self):
+        """Point lookups work and a seek yields the three b'00002' keys."""
+        self._fill_db()
+        for suffix in (b'x', b'y', b'z'):
+            self.assertEqual(suffix, self.db.get(b'00001.' + suffix))
+
+        keys = self.db.iterkeys()
+        keys.seek(b'00002')
+        matched = list(takewhile(lambda k: k.startswith(b'00002'), keys))
+        self.assertEqual([b'00002.x', b'00002.y', b'00002.z'], matched)
+
+    def test_prefix_iteritems(self):
+        """A seek on the item iterator yields the three b'00002' pairs."""
+        self._fill_db()
+        items = self.db.iteritems()
+        items.seek(b'00002')
+        matched = dict(takewhile(lambda kv: kv[0].startswith(b'00002'), items))
+        expected = {b'00002.x': b'x', b'00002.y': b'y', b'00002.z': b'z'}
+        self.assertEqual(expected, matched)
+
+
 class TestDBColumnFamilies(TestHelper):
     def setUp(self):
         TestHelper.setUp(self)
-- 
GitLab