From 87a7ddfe1c8f3c6b76c158a86d0641cfdfcf398a Mon Sep 17 00:00:00 2001
From: hofmockel <dreagonfly@gmx.de>
Date: Wed, 22 Oct 2014 09:43:47 +0200
Subject: [PATCH] Move filter_policy to block_based_table_factory.

---
 rocksdb/_rocksdb.pyx      | 84 +++++++++++++++++++--------------------
 rocksdb/options.pxd       |  2 -
 rocksdb/table_factory.pxd |  2 +
 3 files changed, 42 insertions(+), 46 deletions(-)

diff --git a/rocksdb/_rocksdb.pyx b/rocksdb/_rocksdb.pyx
index 254b209..1a81fee 100644
--- a/rocksdb/_rocksdb.pyx
+++ b/rocksdb/_rocksdb.pyx
@@ -49,6 +49,8 @@ from interfaces import SliceTransform as ISliceTransform
 import traceback
 import errors
 
+ctypedef const filter_policy.FilterPolicy ConstFilterPolicy
+
 cdef extern from "cpp/utils.hpp" namespace "py_rocks":
     cdef const Slice* vector_data(vector[Slice]&)
 
@@ -202,41 +204,36 @@ cdef class PyFilterPolicy(object):
     cdef object get_ob(self):
         return None
 
-    cdef const filter_policy.FilterPolicy* get_policy(self):
-        return NULL
+    cdef shared_ptr[ConstFilterPolicy] get_policy(self):
+        return shared_ptr[ConstFilterPolicy]()
 
     cdef set_info_log(self, shared_ptr[logger.Logger] info_log):
         pass
 
 @cython.internal
 cdef class PyGenericFilterPolicy(PyFilterPolicy):
-    cdef filter_policy.FilterPolicyWrapper* policy
+    cdef shared_ptr[filter_policy.FilterPolicyWrapper] policy
     cdef object ob
 
     def __cinit__(self, object ob):
-        self.policy = NULL
         if not isinstance(ob, IFilterPolicy):
             raise TypeError("%s is not of type %s" % (ob, IFilterPolicy))
 
         self.ob = ob
-        self.policy = new filter_policy.FilterPolicyWrapper(
+        self.policy.reset(new filter_policy.FilterPolicyWrapper(
                 bytes_to_string(ob.name()),
                 <void*>ob,
                 create_filter_callback,
-                key_may_match_callback)
-
-    def __dealloc__(self):
-        if not self.policy == NULL:
-            del self.policy
+                key_may_match_callback))
 
     cdef object get_ob(self):
         return self.ob
 
-    cdef const filter_policy.FilterPolicy* get_policy(self):
-        return <filter_policy.FilterPolicy*> self.policy
+    cdef shared_ptr[ConstFilterPolicy] get_policy(self):
+        return <shared_ptr[ConstFilterPolicy]>(self.policy)
 
     cdef set_info_log(self, shared_ptr[logger.Logger] info_log):
-        self.policy.set_info_log(info_log)
+        self.policy.get().set_info_log(info_log)
 
 
 cdef void create_filter_callback(
@@ -274,18 +271,13 @@ cdef cpp_bool key_may_match_callback(
 
 @cython.internal
 cdef class PyBloomFilterPolicy(PyFilterPolicy):
-    cdef const filter_policy.FilterPolicy* policy
+    cdef shared_ptr[ConstFilterPolicy] policy
 
     def __cinit__(self, int bits_per_key):
-        self.policy = NULL
-        self.policy = filter_policy.NewBloomFilterPolicy(bits_per_key)
-
-    def __dealloc__(self):
-        if not self.policy == NULL:
-            del self.policy
+        self.policy.reset(filter_policy.NewBloomFilterPolicy(bits_per_key))
 
     def name(self):
-        return PyBytes_FromString(self.policy.Name())
+        return PyBytes_FromString(self.policy.get().Name())
 
     def create_filter(self, keys):
         cdef string dst
@@ -294,7 +286,7 @@ cdef class PyBloomFilterPolicy(PyFilterPolicy):
         for key in keys:
             c_keys.push_back(bytes_to_slice(key))
 
-        self.policy.CreateFilter(
+        self.policy.get().CreateFilter(
             vector_data(c_keys),
             c_keys.size(),
             cython.address(dst))
@@ -302,14 +294,14 @@ cdef class PyBloomFilterPolicy(PyFilterPolicy):
         return string_to_bytes(dst)
 
     def key_may_match(self, key, filter_):
-        return self.policy.KeyMayMatch(
+        return self.policy.get().KeyMayMatch(
             bytes_to_slice(key),
             bytes_to_slice(filter_))
 
     cdef object get_ob(self):
         return self
 
-    cdef const filter_policy.FilterPolicy* get_policy(self):
+    cdef shared_ptr[ConstFilterPolicy] get_policy(self):
         return self.policy
 
 BloomFilterPolicy = PyBloomFilterPolicy
@@ -561,13 +553,19 @@ cdef class PyTableFactory(object):
     cdef shared_ptr[table_factory.TableFactory] get_table_factory(self):
         return self.factory
 
+    cdef set_info_log(self, shared_ptr[logger.Logger] info_log):
+        pass
+
 cdef class BlockBasedTableFactory(PyTableFactory):
+    cdef PyFilterPolicy py_filter_policy
+
     def __init__(self,
             index_type='binary_search',
             py_bool hash_index_allow_collision=True,
             checksum='crc32',
             PyCache block_cache=None,
             PyCache block_cache_compressed=None,
+            filter_policy=None,
             no_block_cache=False,
             block_size=None,
             block_size_deviation=None,
@@ -622,8 +620,24 @@ cdef class BlockBasedTableFactory(PyTableFactory):
         if block_cache_compressed is not None:
             table_options.block_cache_compressed = block_cache_compressed.get_cache()
 
+        # Set the filter_policy
+        self.py_filter_policy = None
+        if filter_policy is not None:
+            if isinstance(filter_policy, PyFilterPolicy):
+                if (<PyFilterPolicy?>filter_policy).get_policy().get() == NULL:
+                    raise Exception("Cannot set filter policy: %s" % filter_policy)
+                self.py_filter_policy = filter_policy
+            else:
+                self.py_filter_policy = PyGenericFilterPolicy(filter_policy)
+
+            table_options.filter_policy = self.py_filter_policy.get_policy()
+
         self.factory.reset(table_factory.NewBlockBasedTableFactory(table_options))
 
+    cdef set_info_log(self, shared_ptr[logger.Logger] info_log):
+        if self.py_filter_policy is not None:
+            self.py_filter_policy.set_info_log(info_log)
+
 cdef class PlainTableFactory(PyTableFactory):
     def __init__(
             self,
@@ -701,7 +715,6 @@ cdef class Options(object):
     cdef options.Options* opts
     cdef PyComparator py_comparator
     cdef PyMergeOperator py_merge_operator
-    cdef PyFilterPolicy py_filter_policy
     cdef PySliceTransform py_prefix_extractor
     cdef PyTableFactory py_table_factory
     cdef PyMemtableFactory py_memtable_factory
@@ -721,7 +734,6 @@ cdef class Options(object):
     def __init__(self, **kwargs):
         self.py_comparator = BytewiseComparator()
         self.py_merge_operator = None
-        self.py_filter_policy = None
         self.py_prefix_extractor = None
         self.py_table_factory = None
         self.py_memtable_factory = None
@@ -1196,22 +1208,6 @@ cdef class Options(object):
             self.py_merge_operator = PyMergeOperator(value)
             self.opts.merge_operator = self.py_merge_operator.get_operator()
 
-    property filter_policy:
-        def __get__(self):
-            if self.py_filter_policy is None:
-                return None
-            return self.py_filter_policy.get_ob()
-
-        def __set__(self, value):
-            if isinstance(value, PyFilterPolicy):
-                if (<PyFilterPolicy?>value).get_policy() == NULL:
-                    raise Exception("Cannot set filter policy: %s" % value)
-                self.py_filter_policy = value
-            else:
-                self.py_filter_policy = PyGenericFilterPolicy(value)
-
-            self.opts.filter_policy = self.py_filter_policy.get_policy()
-
     property prefix_extractor:
         def __get__(self):
             if self.py_prefix_extractor is None:
@@ -1297,8 +1293,8 @@ cdef class DB(object):
         if opts.py_comparator is not None:
             opts.py_comparator.set_info_log(info_log)
 
-        if opts.py_filter_policy is not None:
-            opts.py_filter_policy.set_info_log(info_log)
+        if opts.py_table_factory is not None:
+            opts.py_table_factory.set_info_log(info_log)
 
         if opts.prefix_extractor is not None:
             opts.py_prefix_extractor.set_info_log(info_log)
diff --git a/rocksdb/options.pxd b/rocksdb/options.pxd
index b17a5ab..a2be987 100644
--- a/rocksdb/options.pxd
+++ b/rocksdb/options.pxd
@@ -5,7 +5,6 @@ from libc.stdint cimport uint64_t
 from std_memory cimport shared_ptr
 from comparator cimport Comparator
 from merge_operator cimport MergeOperator
-from filter_policy cimport FilterPolicy
 from logger cimport Logger
 from slice_ cimport Slice
 from snapshot cimport Snapshot
@@ -32,7 +31,6 @@ cdef extern from "rocksdb/options.h" namespace "rocksdb":
     cdef cppclass Options:
         const Comparator* comparator
         shared_ptr[MergeOperator] merge_operator
-        const FilterPolicy* filter_policy
         # TODO: compaction_filter
         # TODO: compaction_filter_factory
         cpp_bool create_if_missing
diff --git a/rocksdb/table_factory.pxd b/rocksdb/table_factory.pxd
index 2c61e64..2359292 100644
--- a/rocksdb/table_factory.pxd
+++ b/rocksdb/table_factory.pxd
@@ -3,6 +3,7 @@ from libcpp cimport bool as cpp_bool
 from std_memory cimport shared_ptr
 
 from cache cimport Cache
+from filter_policy cimport FilterPolicy
 
 cdef extern from "rocksdb/table.h" namespace "rocksdb":
     cdef cppclass TableFactory:
@@ -28,6 +29,7 @@ cdef extern from "rocksdb/table.h" namespace "rocksdb":
         cpp_bool whole_key_filtering
         shared_ptr[Cache] block_cache
         shared_ptr[Cache] block_cache_compressed
+        shared_ptr[FilterPolicy] filter_policy
 
     cdef TableFactory* NewBlockBasedTableFactory(const BlockBasedTableOptions&)
 
-- 
GitLab