diff --git a/CHANGELOG.md b/CHANGELOG.md index ee5008b94c0dcc5f2bafee5874b360d30402d287..16a491f84b89cf355eef3a15a554f6ef68ab1f28 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1 +1,4 @@ -# v2.23.1 \ No newline at end of file +# v2.24.0 + +* completely overhauled [overload](https://satella.readthedocs.io/en/latest/coding/functions.html#function-overloading), + ie. fixed #57. diff --git a/docs/coding/functions.rst b/docs/coding/functions.rst index a929934062f5d55757d6e714f87cf04da7ce9e93..05c7a01a5793129c38ffce68663ca436ac4f7058 100644 --- a/docs/coding/functions.rst +++ b/docs/coding/functions.rst @@ -66,9 +66,19 @@ You can also decorate given callables in order not to be wrapped with Function overloading -------------------- +.. warning:: This is coded for cases where the function prototypes differ significantly, for ex. matches + only one prototype. For cases where a single call might match multiple prototypes, and if it's + desired that the implementation tells them apart, this implementation might not be of sufficient complexity. + Go file a ticket that you cannot use Satella with some implementation. Just type down what kind of implementation + that was. + .. autoclass:: satella.coding.overload :members: +.. autoclass:: satella.coding.TypeSignature + :members: + + DocsFromParent -------------- diff --git a/satella/__init__.py b/satella/__init__.py index 32815b254730f32de279dbc62335d4a5877b39aa..1c24e3331d9c331f3786944f0d8c003c36aa6c6a 100644 --- a/satella/__init__.py +++ b/satella/__init__.py @@ -1 +1 @@ -__version__ = '2.23.1a1' +__version__ = '2.23.1b1' diff --git a/satella/coding/__init__.py b/satella/coding/__init__.py index 72369906bb3573798f10f2a4e46f1d937e6bf79d..20d3b7ea441e9502cb0e78b04d0f56074af3f7f4 100644 --- a/satella/coding/__init__.py +++ b/satella/coding/__init__.py @@ -18,7 +18,7 @@ from .misc import update_if_not_none, update_key_if_none, update_attr_if_none, q get_arguments, call_with_arguments, chain_callables, Closeable, contains, \ enum_value, length from .environment import Context -from .overloading import overload, class_or_instancemethod +from .overloading import overload, class_or_instancemethod, TypeSignature from .recast_exceptions import rethrow_as, silence_excs, catch_exception, log_exceptions, \ raises_exception, reraise_as from .expect_exception import expect_exception @@ -27,9 +27,9 @@ from .deep_compare import assert_equal, InequalityReason, Inequal __all__ = [ 'EmptyContextManager', 'Context', 'length', 'assert_equal', 'InequalityReason', 'Inequal', - 'Closeable', 'contains', 'enum_value', 'reraise_as', + 'Closeable', 'contains', 'enum_value', 'expect_exception', - 'overload', 'class_or_instancemethod', + 'overload', 'class_or_instancemethod', 'TypeSignature', 'update_if_not_none', 'DocsFromParent', 'update_key_if_none', 'queue_iterator', 'update_attr_if_none', 'update_key_if_not_none', 'source_to_function', 'update_key_if_true', diff --git a/satella/coding/overloading.py b/satella/coding/overloading.py index 8270de48803d56eab6fe99ccac20020a7109b7bc..c4f1c0eb46dcc79e365ff58692a2e76e7234e9c7 100644 --- a/satella/coding/overloading.py +++ b/satella/coding/overloading.py @@ -1,18 +1,13 @@ +from __future__ import annotations + +import functools import inspect +import operator import typing as tp from inspect import Parameter +from satella.coding.structures import frozendict -def extract_type_signature_from(fun: tp.Callable) -> tp.Tuple[type, ...]: - sign = [] - params = inspect.signature(fun).parameters - for parameter in params.values(): - if parameter.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD): - if parameter.annotation == Parameter.empty: - sign.append(None) - else: - sign.append(parameter.annotation) - return tuple(sign) # Taken from https://stackoverflow.com/questions/28237955/same-name-for-classmethod-and- @@ -38,6 +33,52 @@ class class_or_instancemethod(classmethod): return descr_get(instance, type_) +class TypeSignature(inspect.Signature): + + __slots__ = () + + def __init__(self, t_sign: inspect.Signature): + self._return_annotation = t_sign._return_annotation + self._parameters = t_sign._parameters + + @staticmethod + def from_fun(fun): + return TypeSignature(inspect.Signature.from_callable(fun)) + + def can_be_called_with_args(self, *args, **kwargs) -> bool: + called = self._bind(*args, **kwargs) + return all(issubclass(self.signature.parameters.get(arg_name, NONEARGS)._annotation, arg_value) + for arg_name, arg_value in called.items()) + + def is_more_generic_than(self, b: TypeSignature) -> bool: + if self == {}: + for key in self: + key1 = self[key] + key2 = b.get(key, None) + if key2 is None: + return key2 == {} + + if key2.is_more_generic_than(key1): + return False + return True + + def __lt__(self, other: TypeSignature) -> bool: + return self.is_more_generic_than(other) + + def matches(self, *args, **kwargs) -> bool: + """ + Does this invocation match this signature? + """ + bound_args = self.bind(*args, **kwargs) + bound_args.apply_defaults() + for param_name, param_value in bound_args.arguments.items(): + if isinstance(param_value, self._parameters[param_name].annotation): + continue + else: + return False + return True + + class overload: """ A class used for method overloading. @@ -57,12 +98,11 @@ class overload: >>> print('Int') Note that this instance's __wrapped__ will refer to the first function. + TypeError will be called if no signatures match arguments. """ def __init__(self, fun: tp.Callable): - self.type_signatures_to_functions = { - extract_type_signature_from(fun): fun - } # type: tp.Dict[tp.Tuple[type, ...], tp.Callable] + self.type_signatures_to_functions = {TypeSignature.from_fun(fun): fun} if hasattr(fun, '__doc__'): self.__doc__ = fun.__doc__ self.__wrapped__ = fun @@ -71,10 +111,9 @@ class overload: """ :raises ValueError: this signature already has an overload """ - sign = extract_type_signature_from(fun) + sign = TypeSignature.from_fun(fun) if sign in self.type_signatures_to_functions: - f = self.type_signatures_to_functions[sign] - raise ValueError('Method of this signature is already overloaded with %s' % (f,)) + raise TypeError('Method of this signature is already overloaded with %s' % (f,)) self.type_signatures_to_functions[sign] = fun return self @@ -82,17 +121,14 @@ class overload: """ Call one of the overloaded functions. - :raises TypeError: no type signature matched + :raises TypeError: no type signature given """ + matching = [] for sign, fun in self.type_signatures_to_functions.items(): - try: - for type_, arg in zip(sign, args): - if type_ is None: - continue - if not isinstance(arg, type_): - raise ValueError() - - return fun(*args, **kwargs) - except ValueError: - pass - raise TypeError('No matching functions found') + if sign.matches(*args, **kwargs): + matching.append((sign, fun)) + matching.sort() + if not matching: + raise TypeError('No matching entries!') + else: + return matching[-1][1](*args, **kwargs) # call the most specific function you could find