diff --git a/CHANGELOG.md b/CHANGELOG.md index e46d40e5fe64d0813d7fb8cd47d0787df558281a..48f86137b59dbe07939d39c5e37120691d33dac0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # v2.2.10 -* _TBA_ +* added [import_from](satella/imports.py) # v2.2.9 diff --git a/docs/import.rst b/docs/import.rst new file mode 100644 index 0000000000000000000000000000000000000000..37905c222af74c414da3351eb90bec928a397c52 --- /dev/null +++ b/docs/import.rst @@ -0,0 +1,24 @@ +Import +====== + +Sometimes you just have a fairly nested module hierarchy, +and you want to import everything. Don't worry, Satella's got +you covered. Just use this function + +.. autofunction:: satella.imports.import_from + +An example use would be your module's __init__.py containing +following code: + +:: + + from satella.imports import import_from + + __all__ = [] + + import_from(__path__, __name__, __all__, locals()) + +In this case, everything will be accessible from this module. +This will examine the __all__ of your submodules, and dir() it +if __all__'s not available. Note that lack of availability of +__all__ will emit a log warning. diff --git a/docs/index.rst b/docs/index.rst index 5684be53e281db6ea00a6b8d766377736b7e00c4..3168e26af403c6b7a8dcd1591cf786b1d8f47ecf 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -16,6 +16,7 @@ Welcome to satella's documentation! exception_handling json posix + import Indices and tables diff --git a/satella/__init__.py b/satella/__init__.py index d4503531548ae9321e60aaee9364dde5bafb5ca3..9a9fd1389d4dac1a8f0830fdf95b9fc561852bfd 100644 --- a/satella/__init__.py +++ b/satella/__init__.py @@ -1,2 +1,2 @@ # coding=UTF-8 -__version__ = '2.2.10a1' +__version__ = '2.2.10' diff --git a/satella/imports.py b/satella/imports.py new file mode 100644 index 0000000000000000000000000000000000000000..06bc5723f5ccde8d803767c344e2e72448585a03 --- /dev/null +++ b/satella/imports.py @@ -0,0 +1,62 @@ +import typing as tp +import importlib +import pkgutil +import logging +import os + +__all__ = ['import_from'] + +logger = logging.getLogger(__name__) + + +def import_from(path: tp.List[str], package_prefix: str, all_: tp.List[str], locals: tp.Dict[str, tp.Any], recursive: bool = True, + fail_on_attributerror: bool = True, add_all: bool = True) -> None: + """ + Import everything from a given module. Append these module's all to. + + This will examine __all__ of given module (if it has any, else it will just import everything + from it, which is probably a bad practice and will heavily pollute the namespace. + + As a side effect, this will equip all of your packages with __all__. + + :param path: module's __path__ + :param package_prefix: package prefix to import from. Use __name__ + :param all_: module's __all__ to append to + :param recursive: whether to import packages as well + :param fail_on_attributerror: whether to fail if a module reports something in their __all__ that + is physically not there (ie. getattr() raised AttributeError + :param locals: module's locals, obtain them by calling locals() in importing module's context + :param add_all: whether to create artificial __all__'s for modules that don't have them + :raise AttributeError: module's __all__ contained entry that was not in this module + """ + logger.warning('Invoking with path=%s', path) + for importer, modname, ispkg in pkgutil.walk_packages(path, onerror=lambda x: None): + if recursive and ispkg: + module = importlib.import_module(package_prefix+'.'+modname) + logger.warning(repr(package_prefix)) + logger.warning(repr(modname)) + try: + mod_all = module.__all__ + except AttributeError: + mod_all = [] + if add_all: + module.__all__ = mod_all + import_from([os.path.join(path[0], modname)], package_prefix+'.'+modname, mod_all, module.__dict__, recursive=recursive, fail_on_attributerror=fail_on_attributerror), + locals[modname] = module + __all__.append(modname) + elif not ispkg: + module = importlib.import_module(package_prefix+'.'+modname) + try: + package_ref = module.__all__ + except AttributeError: + logger.warning('Module %s does not contain __all__, enumerating it instead', package_prefix+'.'+modname) + package_ref = dir(module) + + for item in package_ref: + try: + locals[item] = getattr(module, item) + except AttributeError: + if fail_on_attributerror: + raise + else: + all_.append(item) diff --git a/tests/test_imports/__init__.py b/tests/test_imports/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fa84e47ee529b348bfcf12c1c64433e3c127dde4 --- /dev/null +++ b/tests/test_imports/__init__.py @@ -0,0 +1,5 @@ +import logging +import typing as tp + +logger = logging.getLogger(__name__) + diff --git a/tests/test_imports/importa/__init__.py b/tests/test_imports/importa/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0c476295a6beb7f506572b7b538ddd5408fb8647 --- /dev/null +++ b/tests/test_imports/importa/__init__.py @@ -0,0 +1,14 @@ +import logging +import typing as tp + +logger = logging.getLogger(__name__) + +from satella.imports import import_from + +__all__ = [] + + +def do_import(): + logger.warning(repr(__path__)) + logger.warning(repr(__name__)) + import_from(__path__, __name__, __all__, locals(), recursive=True, fail_on_attributerror=False, add_all=True) diff --git a/tests/test_imports/importa/importb/__init__.py b/tests/test_imports/importa/importb/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fa84e47ee529b348bfcf12c1c64433e3c127dde4 --- /dev/null +++ b/tests/test_imports/importa/importb/__init__.py @@ -0,0 +1,5 @@ +import logging +import typing as tp + +logger = logging.getLogger(__name__) + diff --git a/tests/test_imports/importa/importb/dontimportme.py b/tests/test_imports/importa/importb/dontimportme.py new file mode 100644 index 0000000000000000000000000000000000000000..cc6304852ff21b3b5805dee53c14b5d71a38add7 --- /dev/null +++ b/tests/test_imports/importa/importb/dontimportme.py @@ -0,0 +1,7 @@ +import logging +import typing as tp + +logger = logging.getLogger(__name__) + +def sub(a: float, b: float) -> float: + return a-b \ No newline at end of file diff --git a/tests/test_imports/importa/importb/importme.py b/tests/test_imports/importa/importb/importme.py new file mode 100644 index 0000000000000000000000000000000000000000..08c9b8ea81cd9c4b1b65c59aff25b01fc78f68c4 --- /dev/null +++ b/tests/test_imports/importa/importb/importme.py @@ -0,0 +1,10 @@ +import logging +import typing as tp + +logger = logging.getLogger(__name__) + +__all__ = ['add', 'would_have_failed'] + +def add(a: float, b: float) -> float: + return a+b + diff --git a/tests/test_imports/test_import.py b/tests/test_imports/test_import.py new file mode 100644 index 0000000000000000000000000000000000000000..912d1f1f115822a8178ad92cc992f9274e938539 --- /dev/null +++ b/tests/test_imports/test_import.py @@ -0,0 +1,15 @@ +import logging +import unittest + +logger = logging.getLogger(__name__) + + +class TestImports(unittest.TestCase): + def test_imports(self): + import tests.test_imports.importa + tests.test_imports.importa.do_import() + + tests.test_imports.importa.importb.__all__ + + self.assertEqual(tests.test_imports.importa.importb.add(4, 5), 9) + self.assertEqual(tests.test_imports.importa.importb.sub(4, 5), -1)