Skip to content
Snippets Groups Projects
Commit d48845a8 authored by Piotr Maślanka's avatar Piotr Maślanka
Browse files

v2.7.8: add KeyAwareDefaultDict

parent 2d02f6bd
No related branches found
Tags v2.7.8
No related merge requests found
# v2.7.8
* _TBA_
* added `KeyAwareDefaultDict`
# v2.7.7
......
......@@ -164,3 +164,8 @@ set to False.
.. autoclass:: satella.coding.structures.DirtyDict
:members:
KeyAwareDefaultDict
===================
.. autoclass:: satella.coding.structures.KeyAwareDefaultDict
__version__ = '2.7.8_a1'
__version__ = '2.7.8'
from .dictionaries import DictObject, apply_dict_object, DictionaryView, TwoWayDictionary, \
DirtyDict
DirtyDict, KeyAwareDefaultDict
from .hashable_objects import HashableWrapper
from .immutable import Immutable, frozendict
from .singleton import Singleton, SingletonWithRegardsTo
......@@ -10,6 +10,7 @@ from .ranking import Ranking
from .sorted_list import SortedList, SliceableDeque
__all__ = [
'KeyAwareDefaultDict',
'Proxy',
'DirtyDict',
'SortedList',
......
......@@ -6,7 +6,8 @@ from satella.coding.recast_exceptions import rethrow_as
from satella.configuration.schema import Descriptor, descriptor_from_dict
from satella.exceptions import ConfigurationValidationError
__all__ = ['DictObject', 'apply_dict_object', 'DictionaryView', 'TwoWayDictionary', 'DirtyDict']
__all__ = ['DictObject', 'apply_dict_object', 'DictionaryView', 'TwoWayDictionary', 'DirtyDict',
'KeyAwareDefaultDict']
K, V, T = tp.TypeVar('K'), tp.TypeVar('V'), tp.TypeVar('T')
......@@ -359,3 +360,35 @@ class DirtyDict(collections.UserDict, tp.Generic[K, V]):
a = self.data.copy()
self.dirty = False
return a
class KeyAwareDefaultDict(collections.abc.MutableMapping):
"""
A defaultdict whose factory function accepts the key to provide a default value for the key
:param factory_function: a callable that accepts a single argument, a key, for which it is to provide
a default value
"""
def __len__(self) -> int:
return len(self.dict)
def __iter__(self):
return iter(self.dict)
def __init__(self, factory_function: tp.Callable[[K], V], *args, **kwargs):
self.dict = dict(*args, **kwargs)
self.factory_function = factory_function
def __getitem__(self, item):
if item in self.dict:
return self.dict[item]
else:
self.dict[item] = self.factory_function(item)
return self.dict[item]
def __setitem__(self, key, value):
self.dict[key] = value
def __delitem__(self, key):
del self.dict[key]
\ No newline at end of file
......@@ -9,12 +9,16 @@ import mock
from satella.coding.structures import TimeBasedHeap, Heap, typednamedtuple, \
OmniHashableMixin, DictObject, apply_dict_object, Immutable, frozendict, SetHeap, \
DictionaryView, HashableWrapper, TwoWayDictionary, Ranking, SortedList, SliceableDeque, \
DirtyDict
DirtyDict, KeyAwareDefaultDict
logger = logging.getLogger(__name__)
class TestMisc(unittest.TestCase):
def test_key_aware_defaultdict(self):
a = KeyAwareDefaultDict(int)
self.assertEqual(a['1'], 1)
def test_dirty_dict(self):
a = DirtyDict({1: 2, 3: 4})
self.assertFalse(a.dirty)
......
import os
from os.path import join
import tempfile
import unittest
import shutil
from satella.files import read_re_sub_and_write, find_files, split
def putfile(path: str) -> None:
with open(path, 'wb') as f_out:
f_out.write(b'\x32')
class TestFiles(unittest.TestCase):
def test_split(self):
self.assertIn(split('c:/windows/system32/system32.exe'), [['c:', 'windows', 'system32',
......@@ -26,15 +32,23 @@ class TestFiles(unittest.TestCase):
def test_find_files(self):
directory = tempfile.mkdtemp()
os.mkdir(os.path.join(directory, 'test'))
with open(os.path.join(directory, 'test', 'test.txt'), 'wb') as f_out:
f_out.write(b'test')
self.assertEqual(list(find_files(directory, r'(.*)test(.*)\.txt',
apply_wildcard_to_entire_path=True)), [
os.path.join(directory, 'test', 'test.txt')])
self.assertEqual(list(find_files(directory, r'(.*)\.txt')), [
os.path.join(directory, 'test', 'test.txt')])
os.mkdir(join(directory, 'test'))
putfile(join(directory, 'test', 'test.txt'))
os.mkdir(join(directory, 'test', 'test'))
putfile(join(directory, 'test', 'test', 'test.txt'))
putfile(join(directory, 'test', 'test', 'test2.txt'))
self.assertEqual(set(find_files(directory, r'(.*)test(.*)\.txt',
apply_wildcard_to_entire_path=True)), {
join(directory, 'test', 'test.txt'),
join(directory, 'test', 'test', 'test.txt'),
join(directory, 'test', 'test', 'test2.txt')
})
self.assertEqual(set(find_files(directory, r'(.*)\.txt')), {
join(directory, 'test', 'test.txt'),
join(directory, 'test', 'test', 'test.txt'),
join(directory, 'test', 'test', 'test2.txt')
})
shutil.rmtree(directory)
def test_read_re_sub_and_write(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment