diff --git a/docs/coding/functions.rst b/docs/coding/functions.rst index 97d4a0a8dd4585f23dd8606920b5eee073a95fae..43e04dd27060bc0d30ec65b8728d20d276a4f3d8 100644 --- a/docs/coding/functions.rst +++ b/docs/coding/functions.rst @@ -69,6 +69,9 @@ Function overloading .. autoclass:: satella.coding.overload :members: +.. autoclass:: satella.coding.ARGS + :members: + .. autofunction:: satella.coding.is_signature_a_more_generic_than_b diff --git a/satella/coding/__init__.py b/satella/coding/__init__.py index 9752da5be0a9fcef608046b273b8c8b21582e1e8..7b3fbb64be6d54a7e625901608ed452312e1e3ca 100644 --- a/satella/coding/__init__.py +++ b/satella/coding/__init__.py @@ -18,8 +18,7 @@ from .misc import update_if_not_none, update_key_if_none, update_attr_if_none, q get_arguments, call_with_arguments, chain_callables, Closeable, contains, \ enum_value, length from .environment import Context -from .overloading import overload, class_or_instancemethod, is_signature_a_more_generic_than_b, \ - is_type_a_more_generic_than_b, extract_type_signature_from +from .overloading import overload, class_or_instancemethod, TypeSignature from .recast_exceptions import rethrow_as, silence_excs, catch_exception, log_exceptions, \ raises_exception, reraise_as from .expect_exception import expect_exception @@ -28,10 +27,9 @@ from .deep_compare import assert_equal, InequalityReason, Inequal __all__ = [ 'EmptyContextManager', 'Context', 'length', 'assert_equal', 'InequalityReason', 'Inequal', - 'Closeable', 'contains', 'enum_value', 'reraise_as', + 'Closeable', 'contains', 'enum_value', 'reraise_as' 'expect_exception', - 'overload', 'class_or_instancemethod', 'extract_type_signature_from', 'is_signature_a_more_generic_than_b', - 'is_type_a_more_generic_than_b', + 'overload', 'class_or_instancemethod', 'TypeSignature', 'update_if_not_none', 'DocsFromParent', 'update_key_if_none', 'queue_iterator', 'update_attr_if_none', 'update_key_if_not_none', 'source_to_function', 'update_key_if_true', diff --git a/satella/coding/overloading.py b/satella/coding/overloading.py index d52567bed94b500a7fa4d08ad6f57040bdd1cb89..f2450d49364511969f77abf3e93feb1ebf1a9047 100644 --- a/satella/coding/overloading.py +++ b/satella/coding/overloading.py @@ -1,30 +1,9 @@ +from __future__ import annotations + import inspect import typing as tp -from inspect import Parameter -def extract_type_signature_from(fun: tp.Callable) -> tp.Dict[str, type]: - """ - Extract type signature of a function - :param fun: function to extract signature from - :return: a dict, having all parameters normally passed to the function - """ - sign = {} - params = inspect.signature(fun).parameters - for parameter in params.values(): - if parameter.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD): - if parameter.annotation == Parameter.empty: - sign[parameter.name] = None - else: - sign[parameter.name] = parameter.annotation - elif parameter.kind in (Parameter.KEYWORD_ONLY, Parameter.VAR_KEYWORD): - if parameter.annotation == Parameter.empty: - annot = None - else: - annot = parameter.annotation - sign[parameter.name] = (parameter.kind == Parameter.KEYWORD_ONLY), annot - return sign - # Taken from https://stackoverflow.com/questions/28237955/same-name-for-classmethod-and- # instancemethod @@ -49,35 +28,52 @@ class class_or_instancemethod(classmethod): return descr_get(instance, type_) -def is_type_a_more_generic_than_b(a: tp.Dict[str, tp.Type], b: tp.Dict[str, tp.Type]) -> bool: +class TypeSignature(inspect.Signature): """ - Can it be said that type a is more generic than b - :param a: type extracted from a with :func:`~satella.coding.overloading. - :param b: - :return: + Augmented type signature. """ - if a is None: + __slots__ = () + + def __init__(self, t_sign: inspect.Signature): + self._return_annotation = t_sign._return_annotation + self._parameters = t_sign._parameters + + @staticmethod + def from_fun(fun): + return TypeSignature(inspect.Signature.from_callable(fun)) + + def can_be_called_with_args(self, *args, **kwargs) -> bool: + called = self._bind(*args, **kwargs) + return all(issubclass(self.signature.parameters.get(arg_name, NONEARGS)._annotation, arg_value) + for arg_name, arg_value in called.items()) + + def is_more_generic_than(self, b: TypeSignature) -> bool: + if self == {}: + for key in self: + key1 = self[key] + key2 = b.get(key, None) + if key2 is None: + return key2 == {} + + if key2.is_more_generic_than(key1): + return False return True - for key in a: - key = b.get(key, None) - - a_ = a[key] - b_ = b[key] - if isinstance(a_, tuple): - if not isinstance(b_, tuple): - raise TypeError('Type mismatch %s to %s' % (a_, b_)) - if issubclass(b_, a_): - return True - return False + def __lt__(self, other: TypeSignature) -> bool: + return self.is_more_generic_than(other) -def is_signature_a_more_generic_than_b(a: tp.Tuple[{}, [], []], b) -> bool: - """ - - :param a: - @param b: - :return: is A more generic than B - """ + def matches(self, *args, **kwargs) -> bool: + """ + Does this invocation match this signature? + """ + bound_args = self.bind(*args, **kwargs) + bound_args.apply_defaults() + for param_name, param_value in bound_args.arguments.items(): + if isinstance(param_value, self._parameters[param_name].annotation): + continue + else: + return False + return True class overload: @@ -99,12 +95,11 @@ class overload: >>> print('Int') Note that this instance's __wrapped__ will refer to the first function. + TypeError will be called if no signatures match arguments. """ def __init__(self, fun: tp.Callable): - self.type_signatures_to_functions = { - extract_type_signature_from(fun): fun - } # type: tp.Dict[tp.Tuple[type, ...], tp.Callable] + self.type_signatures_to_functions = {TypeSignature.from_fun(fun): fun} if hasattr(fun, '__doc__'): self.__doc__ = fun.__doc__ self.__wrapped__ = fun @@ -113,10 +108,9 @@ class overload: """ :raises ValueError: this signature already has an overload """ - sign = extract_type_signature_from(fun) + sign = TypeSignature.from_fun(fun) if sign in self.type_signatures_to_functions: - f = self.type_signatures_to_functions[sign] - raise ValueError('Method of this signature is already overloaded with %s' % (f,)) + raise TypeError('Method of this signature is already overloaded with %s' % (f,)) self.type_signatures_to_functions[sign] = fun return self @@ -124,17 +118,14 @@ class overload: """ Call one of the overloaded functions. - :raises TypeError: no type signature matched + :raises TypeError: no type signature given """ + matching = [] for sign, fun in self.type_signatures_to_functions.items(): - try: - for type_, arg in zip(sign, args): - if type_ is None: - continue - if not isinstance(arg, type_): - raise ValueError() - - return fun(*args, **kwargs) - except ValueError: - pass - raise TypeError('No matching functions found') + if sign.matches(*args, **kwargs): + matching.append((sign, fun)) + matching.sort() + if not matching: + raise TypeError('No matching entries!') + else: + return matching[-1][1](*args, **kwargs) # call the most specific function you could find diff --git a/tests/test_coding/test_misc.py b/tests/test_coding/test_misc.py index 05c89d34c29e81389a954aff4bb0dd769e2607d5..4086058674bb94b38bb95fee925fac60ccaf8672 100644 --- a/tests/test_coding/test_misc.py +++ b/tests/test_coding/test_misc.py @@ -5,7 +5,7 @@ import unittest from satella.coding import update_key_if_not_none, overload, class_or_instancemethod, \ update_key_if_true, get_arguments, call_with_arguments, chain_callables, Closeable, \ - contains, enum_value, for_argument, length + contains, enum_value, for_argument, length, distance from satella.coding.structures import HashableMixin, ComparableEnum from satella.coding.transforms import jsonify, intify @@ -102,7 +102,7 @@ class TestCase(unittest.TestCase): def test_length(self): y = [1, 2, 3] x = (z for z in y) - self.assertEqual(length(x, 3)) + self.assertEqual(length(x), 3) def test_execute_with_locals(self): def fun(a, b, *args, c=None, **kwargs): @@ -230,6 +230,18 @@ class TestCase(unittest.TestCase): self.assertEqual(a['type'], 'int') self.assertRaises(TypeError, lambda: what_type(2.0)) + def test_distance(self): + class A: + pass + + class B(A): + pass + + class C(B): + pass + + self.assertEqual(distance(A, C), 2) + def test_update_key_if_not_none(self): a = {} update_key_if_not_none(a, 'test', None)