# test_decorators.py (7.55 KiB) — authored by Piotr Maślanka, commit b0112aba
import logging
import queue
import unittest
from socket import socket
import time
from satella.coding import wraps, chain_functions, postcondition, \
log_exceptions, queue_get, precondition, short_none
from satella.coding.decorators import auto_adapt_to_methods, attach_arguments, \
execute_before, loop_while, memoize, copy_arguments, replace_argument_if, \
retry, return_as_list, default_return, transform_result, transform_arguments, \
cache_memoize
from satella.coding.predicates import x
from satella.exceptions import PreconditionError
# Module-level logger; handed to log_exceptions() by test_log_exceptions.
logger = logging.getLogger(name=__name__)
class TestDecorators(unittest.TestCase):
def test_cached_memoizer(self):
a = {'calls': 0}
@cache_memoize(1)
def returns(b):
nonlocal a
a['calls'] += 1
return b
self.assertEqual(returns(6), 6)
self.assertEqual(a['calls'], 1)
self.assertEqual(returns(6), 6)
self.assertEqual(a['calls'], 1)
time.sleep(1.1)
self.assertEqual(returns(6), 6)
self.assertEqual(a['calls'], 2)
def test_transform_arguments(self):
@transform_arguments(a='a*a')
def square(a):
return a
self.assertEqual(square(4), 16)
def test_transform_result(self):
@transform_result('x*a')
def square(a):
return a
self.assertEqual(square(4), 16)
def test_default_returns(self):
@default_return(6)
def returns(v):
return v
self.assertEqual(returns(None), 6)
self.assertEqual(returns(4), 4)
def test_return_as_list(self):
@return_as_list(ignore_nulls=True)
def test():
yield 2
yield 3
yield None
yield 4
self.assertEqual(test(), [2, 3, 4])
@return_as_list()
def test():
yield 2
yield 3
yield None
yield 4
self.assertEqual(test(), [2, 3, None, 4])
def test_retry(self):
a = {'test': 0, 'limit': 2, 'true': False, 'false': False}
def on_failure(e):
nonlocal a
a['true'] = True
def on_success(retries):
nonlocal a
a['false'] = True
@retry(3, ValueError, swallow_exception=False, call_on_failure=on_failure,
call_on_success=on_success)
def do_op():
a['test'] += 1
if a['test'] < a['limit']:
raise ValueError()
do_op()
self.assertTrue(a['false'])
a['limit'] = 10
a['false'] = False
self.assertRaises(ValueError, do_op)
self.assertTrue(a['true'])
self.assertFalse(a['false'])
def test_replace_argument_if(self):
@replace_argument_if('y', x.int(), str)
def ints_only(y):
self.assertEqual(y, 2)
ints_only(2)
ints_only('2')
@replace_argument_if('y', 1, predicate=x % 2 == 0)
def only_odds(y):
self.assertEqual(y % 2, 1)
only_odds(0)
only_odds(1)
@replace_argument_if('args', (x[0], ), predicate=x.len() > 1)
def args_len_1(*args):
self.assertEqual(len(args), 1)
args_len_1(1)
args_len_1(1, 1)
def test_copy_arguments(self):
@copy_arguments()
def alter_dict(dct):
return dct.pop('a')
b = {'a': 5}
self.assertEqual(alter_dict(b), 5)
self.assertEqual(b, {'a': 5})
def test_memoize(self):
a = {'call_count': 0}
@memoize
def memoizer(b):
a['call_count'] += 1
return b
five = memoizer(5)
self.assertEqual(a['call_count'], 1)
five = memoizer(5)
self.assertEqual(a['call_count'], 1)
five = memoizer(6)
self.assertEqual(a['call_count'], 2)
def test_loop_while(self):
class MyLooped:
terminating = False
i = 0
@loop_while(x.i < 10)
def run(self):
self.i += 1
a = MyLooped()
a.run()
self.assertGreaterEqual(a.i, 10)
b = {'i': 0}
@loop_while(lambda: b['i'] < 10)
def run():
nonlocal b
b['i'] += 1
run()
self.assertGreaterEqual(b['i'], 10)
def test_execute_before(self):
a = 0
@execute_before
def increase_a(factor=1):
nonlocal a
a += factor
@increase_a(factor=2)
def launch_me():
nonlocal a
a += 1
launch_me()
self.assertEqual(a, 3)
def test_precondition_none(self):
@precondition(short_none('x == 2'))
def x(y):
return y
x(2)
x(None)
self.assertRaises(PreconditionError, lambda: x(3))
def test_queue_get(self):
class Queue:
def __init__(self):
self.queue = queue.Queue()
self.on_empty_called = False
@queue_get('queue', timeout=0, method_to_execute_on_empty='process_on_empty')
def process(self, item):
pass
def process_on_empty(self):
self.on_empty_called = True
q = Queue()
q.queue.put(True)
q.process()
q.process()
self.assertTrue(q.on_empty_called)
def test_log_exceptions(self):
try:
with log_exceptions(logger):
int('a')
except ValueError:
pass
else:
self.fail('exception swallowed!')
def test_postcondition(self):
@postcondition(lambda x: x == 2)
def return_a_value(x):
return x
self.assertEqual(return_a_value(2), 2)
self.assertRaises(PreconditionError, lambda: return_a_value(3))
def test_auto_adapt_to_methods(self):
@auto_adapt_to_methods
def times_two(fun):
def outer(a):
return fun(a * 2)
return outer
class Test:
@times_two
def twice(self, a):
return a * 2
@times_two
def twice(a):
return a * 2
self.assertEqual(Test().twice(2), 8)
self.assertEqual(twice(2), 8)
def test_chain_kwargs(self):
@chain_functions
def double_arguments(**kwargs):
kwargs['a'] = kwargs['a'] * 2
return kwargs
@double_arguments
def multiply_times_two(**kwargs):
return kwargs['a'] * 2
self.assertEqual(multiply_times_two(a=2), 8)
def test_chain(self):
@chain_functions
def double_arguments(a):
return a * 2
@double_arguments
def multiply_times_two(a):
return a * 2
self.assertEqual(multiply_times_two(2), 8)
def test_attach_arguments(self):
@attach_arguments(label=2)
def test_me(**kwargs):
self.assertEqual(kwargs, {'label': 2, 'value': 4})
test_me(value=4)
def test_wraps(self):
@wraps(socket)
class MySocket(socket):
pass
self.assertEqual(MySocket.__name__, socket.__name__)
self.assertEqual(MySocket.__doc__, socket.__doc__)
self.assertEqual(MySocket.__module__, socket.__module__)
def test_wraps_onfunction(self):
def my_fun(a):
"""Returns the argument"""
return a
@wraps(my_fun)
def f(a):
return my_fun(a)
self.assertEqual(f.__doc__, my_fun.__doc__)
self.assertEqual(f.__name__, my_fun.__name__)
self.assertEqual(f.__module__, my_fun.__module__)