diff --git a/rocksdb/_rocksdb.pyx b/rocksdb/_rocksdb.pyx index 9106f057bf6f1d9b0b0d5a00cb0c57499915094d..b3816f98c8bab017630457a066a180a4a5b8426c 100644 --- a/rocksdb/_rocksdb.pyx +++ b/rocksdb/_rocksdb.pyx @@ -109,9 +109,12 @@ cdef class PyComparator(object): cdef const comparator.Comparator* get_comparator(self): return NULL + cdef set_info_log(self, shared_ptr[logger.Logger] info_log): + pass + @cython.internal cdef class PyGenericComparator(PyComparator): - cdef const comparator.Comparator* comparator_ptr + cdef comparator.ComparatorWrapper* comparator_ptr cdef object ob def __cinit__(self, object ob): @@ -121,11 +124,10 @@ cdef class PyGenericComparator(PyComparator): raise TypeError("Cannot set comparator: %s" % ob) self.ob = ob - self.comparator_ptr = <comparator.Comparator*>( - new comparator.ComparatorWrapper( + self.comparator_ptr = new comparator.ComparatorWrapper( bytes_to_string(ob.name()), <void*>ob, - compare_callback)) + compare_callback) def __dealloc__(self): if not self.comparator_ptr == NULL: @@ -135,7 +137,10 @@ cdef class PyGenericComparator(PyComparator): return self.ob cdef const comparator.Comparator* get_comparator(self): - return self.comparator_ptr + return <comparator.Comparator*> self.comparator_ptr + + cdef set_info_log(self, shared_ptr[logger.Logger] info_log): + self.comparator_ptr.set_info_log(info_log) @cython.internal cdef class PyBytewiseComparator(PyComparator): @@ -160,10 +165,17 @@ cdef class PyBytewiseComparator(PyComparator): cdef int compare_callback( void* ctx, + logger.Logger* log, + string& error_msg, const Slice& a, const Slice& b) with gil: - return (<object>ctx).compare(slice_to_bytes(a), slice_to_bytes(b)) + try: + return (<object>ctx).compare(slice_to_bytes(a), slice_to_bytes(b)) + except BaseException as error: + tb = traceback.format_exc() + logger.Log(log, "Error in compare callback: %s", <bytes>tb) + error_msg.assign(<bytes>str(error)) BytewiseComparator = PyBytewiseComparator ######################################### @@ -1090,6 +1102,12 @@ cdef class DB(object): cython.address(self.db)) check_status(st) + + # Inject the loggers into the python callbacks + cdef shared_ptr[logger.Logger] info_log = self.db.GetOptions().info_log + if opts.py_comparator is not None: + opts.py_comparator.set_info_log(info_log) + self.opts = opts self.opts.in_use = True diff --git a/rocksdb/comparator.pxd b/rocksdb/comparator.pxd index 4ff8b777bb2f5ac719e4a2c01b11139e3d1df11d..c54c26169ede8657dfacc4fe6a3aa2bc878e4a70 100644 --- a/rocksdb/comparator.pxd +++ b/rocksdb/comparator.pxd @@ -1,5 +1,8 @@ from libcpp.string cimport string from slice_ cimport Slice +from logger cimport Logger +from std_memory cimport shared_ptr + cdef extern from "rocksdb/comparator.h" namespace "rocksdb": cdef cppclass Comparator: const char* Name() @@ -7,8 +10,14 @@ cdef extern from "rocksdb/comparator.h" namespace "rocksdb": cdef extern const Comparator* BytewiseComparator() nogil except + -ctypedef int (*compare_func)(void*, const Slice&, const Slice&) +ctypedef int (*compare_func)( + void*, + Logger*, + string&, + const Slice&, + const Slice&) cdef extern from "cpp/comparator_wrapper.hpp" namespace "py_rocks": cdef cppclass ComparatorWrapper: ComparatorWrapper(string, void*, compare_func) nogil except + + void set_info_log(shared_ptr[Logger]) nogil except+ diff --git a/rocksdb/cpp/comparator_wrapper.hpp b/rocksdb/cpp/comparator_wrapper.hpp index 7f17a04725086d5986190511ae1ea7754620cb08..1d10b9df3a0621b39b8783735e4b99cba025ef76 100644 --- a/rocksdb/cpp/comparator_wrapper.hpp +++ b/rocksdb/cpp/comparator_wrapper.hpp @@ -1,13 +1,21 @@ #include "rocksdb/comparator.h" +#include "rocksdb/env.h" +#include <stdexcept> using std::string; using rocksdb::Comparator; using rocksdb::Slice; +using rocksdb::Logger; namespace py_rocks { class ComparatorWrapper: public Comparator { public: - typedef int (*compare_func)(void*, const Slice&, const Slice&); + typedef int (*compare_func)( + void*, + Logger*, + string&, + const Slice&, + const Slice&); ComparatorWrapper( string name, @@ -19,7 +27,20 @@ namespace py_rocks { {} int Compare(const Slice& a, const Slice& b) const { - return this->compare_callback(this->compare_context, a, b); + string error_msg; + int val; + + val = this->compare_callback( + this->compare_context, + this->info_log.get(), + error_msg, + a, + b); + + if (error_msg.size()) { + throw std::runtime_error(error_msg.c_str()); + } + return val; } const char* Name() const { @@ -29,9 +50,14 @@ namespace py_rocks { void FindShortestSeparator(string* start, const Slice& limit) const {} void FindShortSuccessor(string* key) const {} + void set_info_log(std::shared_ptr<Logger> info_log) { + this->info_log = info_log; + } + private: string name; void* compare_context; compare_func compare_callback; + std::shared_ptr<Logger> info_log; }; }