diff --git a/CHANGELOG.md b/CHANGELOG.md index b8a8d96a8c6f5586e209b0dbc3020da0b28eb8c3..13caa2c20a0c3d2d2ad88dda56f08171a103a8cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,3 +3,4 @@ * minor pylint improvements * better coverage * SyncableDroppable.cleanup() bugs out, wrote a quick patch for it to do nothing and filed as #61. +* unit tests for overloading \ No newline at end of file diff --git a/satella/__init__.py b/satella/__init__.py index 49bfef55083a1ffb111a3567c9f4bb3aa50a9ef2..13bf4972493defe1d6bcf7b659a4c4395c8278a7 100644 --- a/satella/__init__.py +++ b/satella/__init__.py @@ -1 +1 @@ -__version__ = '2.24.1a4' +__version__ = '2.24.1a5' diff --git a/satella/coding/overloading.py b/satella/coding/overloading.py index 7a1eb9bdf4db2c977b148157cdcd411687d3a200..db65cb1da544ce644b5654c7db0df658f7f7ee35 100644 --- a/satella/coding/overloading.py +++ b/satella/coding/overloading.py @@ -28,17 +28,33 @@ class class_or_instancemethod(classmethod): class TypeSignature(inspect.Signature): + """ + A type signature. + + You can compare signatures: + + >>> def a(a: object): + >>> pass + >>> def b(a: int): + >>> pass + >>> TypeSignature.from_fun(a) < TypeSignature(b) + """ __slots__ = () def __init__(self, t_sign: inspect.Signature): + """ + :param t_sign: a inspect.Signature + """ self._return_annotation = t_sign._return_annotation self._parameters = t_sign._parameters @staticmethod - def from_fun(fun): + def from_fun(fun) -> TypeSignature: + """Return a type signature from a function""" return TypeSignature(inspect.Signature.from_callable(fun)) def can_be_called_with_args(self, *args, **kwargs) -> bool: + """Can this type signature be called with following arguments?""" called = self._bind(*args, **kwargs) # pylint: disable=protected-access @@ -46,6 +62,7 @@ class TypeSignature(inspect.Signature): for arg_name, arg_value in called.items()) def is_more_generic_than(self, b: TypeSignature) -> bool: + """Is this type signature more generic than an other?""" if self == {}: for key in self: key1 = self[key] @@ -67,9 +84,7 @@ class TypeSignature(inspect.Signature): bound_args = self.bind(*args, **kwargs) bound_args.apply_defaults() for param_name, param_value in bound_args.arguments.items(): - if isinstance(param_value, self._parameters[param_name].annotation): - continue - else: + if not isinstance(param_value, self._parameters[param_name].annotation): return False return True @@ -125,13 +140,9 @@ class overload: """ matchings = [] for sign, fun in self.type_signatures_to_functions.items(): - print('Matching %s against %s', sign, fun) if sign.matches(*args, **kwargs): matchings.append((sign, fun)) - else: - print('Did not score a math between %s:%s and %s', args, kwargs, ) matchings.sort() - print(matchings) if not matchings: raise TypeError('No matching entries!') else: diff --git a/tests/test_coding/test_overloading.py b/tests/test_coding/test_overloading.py new file mode 100644 index 0000000000000000000000000000000000000000..931bed3f1a3191ed8a8d8216518783ada89e3f91 --- /dev/null +++ b/tests/test_coding/test_overloading.py @@ -0,0 +1,26 @@ +import unittest + +from satella.coding import overload, TypeSignature + + +class TestOverloading(unittest.TestCase): + def test_type_signature(self): + def a(a: object): + pass + + def b(a: int): + pass + + self.assertLess(TypeSignature.from_fun(a), TypeSignature.from_fun(b)) + + def test_something(self): + @overload + def fun(i: int): + self.assertIsInstance(i, int) + + @fun.overload + def fun(i: str): + self.assertIsInstance(i, str) + + fun(2) + fun('test')