"""Unit tests for the decorators in ``satella.coding`` (test_decorators.py)."""
import logging
import queue
import unittest
from socket import socket

import time
from satella.coding import wraps, chain_functions, postcondition, \
    log_exceptions, queue_get, precondition, short_none
from satella.coding.decorators import auto_adapt_to_methods, attach_arguments, \
    execute_before, loop_while, memoize, copy_arguments, replace_argument_if, \
    retry, return_as_list, default_return, transform_result, transform_arguments, \
    cache_memoize
from satella.coding.predicates import x
from satella.exceptions import PreconditionError

logger = logging.getLogger(__name__)


class TestDecorators(unittest.TestCase):
    """Exercise the decorators imported from ``satella.coding``.

    Every test applies one decorator to a small throwaway function or class
    and asserts on the observable outcome.
    """

    def test_cached_memoizer(self):
        """``cache_memoize(1)`` reuses a result, then recomputes it after the sleep."""
        counters = {'calls': 0}

        @cache_memoize(1)
        def returns(b):
            # The dict is mutated in place, so no ``nonlocal`` is required.
            counters['calls'] += 1
            return b

        self.assertEqual(returns(6), 6)
        self.assertEqual(counters['calls'], 1)
        # Immediate repeat with the same argument: the counter must not move.
        self.assertEqual(returns(6), 6)
        self.assertEqual(counters['calls'], 1)
        # Sleep slightly longer than the 1 given to cache_memoize (presumably a
        # timeout in seconds) so that the next call is expected to recompute.
        time.sleep(1.1)
        self.assertEqual(returns(6), 6)
        self.assertEqual(counters['calls'], 2)

    def test_transform_arguments(self):
        """``transform_arguments(a='a*a')`` makes the function receive ``a*a``."""
        @transform_arguments(a='a*a')
        def square(a):
            return a

        self.assertEqual(square(4), 16)

    def test_transform_result(self):
        """``transform_result('x*a')`` turns a returned 4 (with ``a`` == 4) into 16."""
        @transform_result('x*a')
        def square(a):
            return a

        self.assertEqual(square(4), 16)

    def test_default_returns(self):
        """``default_return(6)`` substitutes 6 for a ``None`` result only."""
        @default_return(6)
        def returns(v):
            return v

        self.assertEqual(returns(None), 6)
        self.assertEqual(returns(4), 4)

    def test_return_as_list(self):
        """``return_as_list`` collects yielded items, optionally dropping ``None``."""
        @return_as_list(ignore_nulls=True)
        def skipping_nulls():
            yield 2
            yield 3
            yield None
            yield 4

        self.assertEqual(skipping_nulls(), [2, 3, 4])

        @return_as_list()
        def keeping_nulls():
            yield 2
            yield 3
            yield None
            yield 4

        self.assertEqual(keeping_nulls(), [2, 3, None, 4])

    def test_retry(self):
        """``retry`` reports success via one callback and exhaustion via the other."""
        state = {
            'attempts': 0,
            'limit': 2,
            'failure_called': False,
            'success_called': False,
        }

        def on_failure(e):
            state['failure_called'] = True

        def on_success(retries):
            state['success_called'] = True

        @retry(3, ValueError, swallow_exception=False, call_on_failure=on_failure,
               call_on_success=on_success)
        def do_op():
            # Raises until the attempt counter reaches the configured limit.
            state['attempts'] += 1
            if state['attempts'] < state['limit']:
                raise ValueError()

        # With limit == 2 the first attempt raises and the second one passes.
        do_op()
        self.assertTrue(state['success_called'])

        # Raise the limit beyond what the retries can reach: the ValueError is
        # expected to propagate because swallow_exception is False.
        state['limit'] = 10
        state['success_called'] = False
        self.assertRaises(ValueError, do_op)
        self.assertTrue(state['failure_called'])
        self.assertFalse(state['success_called'])

    def test_replace_argument_if(self):
        """``replace_argument_if`` rewrites a named argument by type or predicate."""
        @replace_argument_if('y', x.int(), str)
        def ints_only(y):
            self.assertEqual(y, 2)

        ints_only(2)
        # A str argument is expected to arrive converted to the int 2.
        ints_only('2')

        @replace_argument_if('y', 1, predicate=x % 2 == 0)
        def only_odds(y):
            self.assertEqual(y % 2, 1)

        only_odds(0)
        only_odds(1)

        @replace_argument_if('args', (x[0], ), predicate=x.len() > 1)
        def args_len_1(*args):
            self.assertEqual(len(args), 1)

        args_len_1(1)
        args_len_1(1, 1)

    def test_copy_arguments(self):
        """``copy_arguments`` shields the caller's dict from mutation in the callee."""
        @copy_arguments()
        def alter_dict(dct):
            return dct.pop('a')

        original = {'a': 5}
        self.assertEqual(alter_dict(original), 5)
        # The pop above must not be visible on the caller's object.
        self.assertEqual(original, {'a': 5})

    def test_memoize(self):
        """``memoize`` calls the function once per distinct argument."""
        counters = {'call_count': 0}

        @memoize
        def memoizer(b):
            counters['call_count'] += 1
            return b

        # Check the returned value as well as the call count, so a memoizer
        # that handed back a wrong value would be caught too.
        self.assertEqual(memoizer(5), 5)
        self.assertEqual(counters['call_count'], 1)
        self.assertEqual(memoizer(5), 5)
        self.assertEqual(counters['call_count'], 1)
        self.assertEqual(memoizer(6), 6)
        self.assertEqual(counters['call_count'], 2)

    def test_loop_while(self):
        """A single call to a ``loop_while`` function is expected to loop to 10."""
        class MyLooped:
            terminating = False  # not referenced anywhere in this test
            i = 0

            @loop_while(x.i < 10)
            def run(self):
                self.i += 1

        looped = MyLooped()
        looped.run()
        self.assertGreaterEqual(looped.i, 10)

        counter = {'i': 0}

        @loop_while(lambda: counter['i'] < 10)
        def run():
            counter['i'] += 1

        run()
        self.assertGreaterEqual(counter['i'], 10)

    def test_execute_before(self):
        """An ``execute_before`` decorator runs its body ahead of the wrapped call."""
        total = 0

        @execute_before
        def increase_a(factor=1):
            nonlocal total
            total += factor

        @increase_a(factor=2)
        def launch_me():
            nonlocal total
            total += 1

        launch_me()
        # 2 from increase_a(factor=2) plus 1 from launch_me itself.
        self.assertEqual(total, 3)

    def test_precondition_none(self):
        """``short_none('x == 2')`` as a precondition accepts 2 and ``None``, not 3."""
        # Named ``identity`` rather than ``x`` so that the module-level ``x``
        # predicate import is not shadowed inside this test.
        @precondition(short_none('x == 2'))
        def identity(y):
            return y

        identity(2)
        identity(None)
        self.assertRaises(PreconditionError, identity, 3)

    def test_queue_get(self):
        """``queue_get`` feeds an item to the method, or calls the on-empty hook."""
        class Consumer:
            def __init__(self):
                self.queue = queue.Queue()
                self.on_empty_called = False

            @queue_get('queue', timeout=0, method_to_execute_on_empty='process_on_empty')
            def process(self, item):
                pass

            def process_on_empty(self):
                self.on_empty_called = True

        consumer = Consumer()
        consumer.queue.put(True)
        # First call consumes the single queued item; the hook must stay idle.
        consumer.process()
        self.assertFalse(consumer.on_empty_called)
        # Second call finds the queue empty and is expected to fire the hook.
        consumer.process()
        self.assertTrue(consumer.on_empty_called)

    def test_log_exceptions(self):
        """``log_exceptions`` must let the exception propagate to the caller."""
        with self.assertRaises(ValueError, msg='exception swallowed!'):
            with log_exceptions(logger):
                int('a')

    def test_postcondition(self):
        """``postcondition`` passes a result of 2 and raises for a result of 3."""
        @postcondition(lambda x: x == 2)
        def return_a_value(x):
            return x

        self.assertEqual(return_a_value(2), 2)
        self.assertRaises(PreconditionError, return_a_value, 3)

    def test_auto_adapt_to_methods(self):
        """One decorator works on both a plain function and an instance method."""
        @auto_adapt_to_methods
        def times_two(fun):
            def outer(a):
                return fun(a * 2)

            return outer

        class Test:
            @times_two
            def twice(self, a):
                return a * 2

        @times_two
        def twice(a):
            return a * 2

        # 2 doubled by the decorator, then doubled again by the function body.
        self.assertEqual(Test().twice(2), 8)
        self.assertEqual(twice(2), 8)

    def test_chain_kwargs(self):
        """``chain_functions`` forwards a returned dict as keyword arguments."""
        @chain_functions
        def double_arguments(**kwargs):
            kwargs['a'] = kwargs['a'] * 2
            return kwargs

        @double_arguments
        def multiply_times_two(**kwargs):
            return kwargs['a'] * 2

        self.assertEqual(multiply_times_two(a=2), 8)

    def test_chain(self):
        """``chain_functions`` forwards a returned value as the positional argument."""
        @chain_functions
        def double_arguments(a):
            return a * 2

        @double_arguments
        def multiply_times_two(a):
            return a * 2

        self.assertEqual(multiply_times_two(2), 8)

    def test_attach_arguments(self):
        """``attach_arguments(label=2)`` adds ``label`` to the caller's kwargs."""
        @attach_arguments(label=2)
        def test_me(**kwargs):
            self.assertEqual(kwargs, {'label': 2, 'value': 4})

        test_me(value=4)

    def test_wraps(self):
        """``wraps`` applied to a class copies name, docstring and module."""
        @wraps(socket)
        class MySocket(socket):
            pass

        for attribute in ('__name__', '__doc__', '__module__'):
            self.assertEqual(getattr(MySocket, attribute), getattr(socket, attribute))

    def test_wraps_onfunction(self):
        """``wraps`` applied to a function copies docstring, name and module."""
        def my_fun(a):
            """Returns the argument"""
            return a

        @wraps(my_fun)
        def f(a):
            return my_fun(a)

        for attribute in ('__doc__', '__name__', '__module__'):
            self.assertEqual(getattr(f, attribute), getattr(my_fun, attribute))