From 6277f9ab5c7ca85ec54d68a7905118c3a956350e Mon Sep 17 00:00:00 2001
From: hofmockel <dreagonfly@gmx.de>
Date: Thu, 16 Jan 2014 08:51:01 +0100
Subject: [PATCH] Support unicode objects for paths

Use sys.getfilesystemencoding() for encoding
---
 docs/api/database.rst         |  2 +-
 docs/api/options.rst          |  4 ++--
 rocksdb/_rocksdb.pyx          | 31 +++++++++++++++++++++++++------
 rocksdb/tests/test_db.py      |  6 ++++++
 rocksdb/tests/test_options.py |  9 +++++++++
 5 files changed, 43 insertions(+), 9 deletions(-)

diff --git a/docs/api/database.rst b/docs/api/database.rst
index 2ebab20..1e21ebf 100644
--- a/docs/api/database.rst
+++ b/docs/api/database.rst
@@ -8,7 +8,7 @@ Database object
 
     .. py:method:: __init__(db_name, Options opts, read_only=False)
 
-        :param string db_name:  Name of the database to open
+        :param unicode db_name:  Name of the database to open
         :param opts: Options for this specific database
         :type opts: :py:class:`rocksdb.Options`
         :param bool read_only: If ``True`` the database is opened read-only.
diff --git a/docs/api/options.rst b/docs/api/options.rst
index 14d3579..4593250 100644
--- a/docs/api/options.rst
+++ b/docs/api/options.rst
@@ -329,7 +329,7 @@ Options object
         and the db data dir's absolute path will be used as the log file
         name's prefix.
 
-        | *Type:* ``string``
+        | *Type:* ``unicode``
         | *Default:* ``""``
 
     .. py:attribute:: wal_dir
@@ -340,7 +340,7 @@ Options object
         If it is non empty, the log files will be in kept the specified dir.
         When destroying the db, all log files in wal_dir and the dir itself is deleted
 
-        | *Type:* ``string``
+        | *Type:* ``unicode``
         | *Default:* ``""``
 
     .. py:attribute:: disable_seek_compaction
diff --git a/rocksdb/_rocksdb.pyx b/rocksdb/_rocksdb.pyx
index e357b66..d9f6541 100644
--- a/rocksdb/_rocksdb.pyx
+++ b/rocksdb/_rocksdb.pyx
@@ -7,6 +7,7 @@ from cython.operator cimport dereference as deref
 from cpython.string cimport PyString_AsString
 from cpython.string cimport PyString_Size
 from cpython.string cimport PyString_FromString
+from cpython.unicode cimport PyUnicode_Decode
 
 from std_memory cimport shared_ptr
 cimport options
@@ -24,6 +25,7 @@ from slice_ cimport slice_to_str
 from slice_ cimport str_to_slice
 from status cimport Status
 
+import sys
 from interfaces import MergeOperator as IMergeOperator
 from interfaces import AssociativeMergeOperator as IAssociativeMergeOperator
 from interfaces import FilterPolicy as IFilterPolicy
@@ -64,6 +66,23 @@ cdef check_status(const Status& st):
 ######################################################
 
 
+cdef string bytes_to_string(bytes path) except *:
+    return string(PyBytes_AsString(path), PyBytes_Size(path))
+
+## only for filsystem paths
+cdef string path_to_string(object path) except *:
+    if isinstance(path, bytes):
+        return bytes_to_string(path)
+    if isinstance(path, unicode):
+        path = path.encode(sys.getfilesystemencoding())
+        return bytes_to_string(path)
+    else:
+       raise TypeError("Wrong type for path: %s" % path)
+
+cdef object string_to_path(string path):
+    fs_encoding = sys.getfilesystemencoding()
+    return PyUnicode_Decode(path.c_str(), path.size(), fs_encoding, "replace")
+
 ## Here comes the stuff for the comparator
 @cython.internal
 cdef class PyComparator(object):
@@ -609,15 +628,15 @@ cdef class Options(object):
 
     property db_log_dir:
         def __get__(self):
-            return self.opts.db_log_dir
+            return string_to_path(self.opts.db_log_dir)
         def __set__(self, value):
-            self.opts.db_log_dir = value
+            self.opts.db_log_dir = path_to_string(value)
 
     property wal_dir:
         def __get__(self):
-            return self.opts.wal_dir
+            return string_to_path(self.opts.wal_dir)
         def __set__(self, value):
-            self.opts.wal_dir = value
+            self.opts.wal_dir = path_to_string(value)
 
     property disable_seek_compaction:
         def __get__(self):
@@ -946,14 +965,14 @@ cdef class DB(object):
             check_status(
                 db.DB_OpenForReadOnly(
                     deref(opts.opts),
-                    db_name,
+                    path_to_string(db_name),
                     cython.address(self.db),
                     False))
         else:
             check_status(
                 db.DB_Open(
                     deref(opts.opts),
-                    db_name,
+                    path_to_string(db_name),
                     cython.address(self.db)))
 
         self.opts = opts
diff --git a/rocksdb/tests/test_db.py b/rocksdb/tests/test_db.py
index 202cc31..0348c25 100644
--- a/rocksdb/tests/test_db.py
+++ b/rocksdb/tests/test_db.py
@@ -24,6 +24,12 @@ class TestDB(unittest.TestCase, TestHelper):
     def tearDown(self):
         self._close_db()
 
+    def test_unicode_path(self):
+        name = b'/tmp/M\xc3\xbcnchen'.decode('utf8')
+        rocksdb.DB(name, rocksdb.Options(create_if_missing=True))
+        self.addCleanup(shutil.rmtree, name)
+        self.assertTrue(os.path.isdir(name))
+
     def test_get_none(self):
         self.assertIsNone(self.db.get('xxx'))
 
diff --git a/rocksdb/tests/test_options.py b/rocksdb/tests/test_options.py
index 599f3f3..a9b1851 100644
--- a/rocksdb/tests/test_options.py
+++ b/rocksdb/tests/test_options.py
@@ -52,3 +52,12 @@ class TestOptions(unittest.TestCase):
         ob = rocksdb.LRUCache(100)
         opts.block_cache = ob
         self.assertEqual(ob, opts.block_cache)
+
+    def test_unicode_path(self):
+        name = b'/tmp/M\xc3\xbcnchen'.decode('utf8')
+        opts = rocksdb.Options()
+        opts.db_log_dir = name
+        opts.wal_dir = name
+
+        self.assertEqual(name, opts.db_log_dir)
+        self.assertEqual(name, opts.wal_dir)
-- 
GitLab