diff --git a/CHANGELOG.md b/CHANGELOG.md index f6a2d1372e81ec75b6258856da32fb864e5e90f6..90edbe54c8a4c5b781ad2204caccb5263db95007 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # v2.7.12 -* _TBA_ +* bugfix release: `even`, `odd` and `count` will now accept Iterables # v2.7.11 diff --git a/satella/__init__.py b/satella/__init__.py index fa63300a5c1c276afd2227714d1af8757a381fe2..c2da2f54816d9b6e42648433657414100d10edbb 100644 --- a/satella/__init__.py +++ b/satella/__init__.py @@ -1 +1 @@ -__version__ = '2.7.12_a1' +__version__ = '2.7.12' diff --git a/satella/coding/sequences/iterators.py b/satella/coding/sequences/iterators.py index 3f87aedfc4a4546d9a7d3042851b6dedc1e27cc3..34e31033472dc73a92aeaa334eb2750602c9955d 100644 --- a/satella/coding/sequences/iterators.py +++ b/satella/coding/sequences/iterators.py @@ -3,10 +3,15 @@ import itertools import typing as tp import warnings +from ..decorators import for_argument + T, U = tp.TypeVar('T'), tp.TypeVar('U') +IteratorOrIterable = tp.Union[tp.Iterator[T], tp.Iterable[T]] + -def even(sq: tp.Iterator[T]) -> tp.Iterator[T]: +@for_argument(iter) +def even(sq: IteratorOrIterable) -> tp.Iterator[T]: """ Return only elements with even indices in this iterable (first element will be returned, as indices are counted from 0) @@ -16,7 +21,8 @@ def even(sq: tp.Iterator[T]) -> tp.Iterator[T]: next(sq) -def odd(sq: tp.Iterator[T]) -> tp.Iterator[T]: +@for_argument(iter) +def odd(sq: IteratorOrIterable) -> tp.Iterator[T]: """ Return only elements with odd indices in this iterable. """ @@ -25,7 +31,7 @@ def odd(sq: tp.Iterator[T]) -> tp.Iterator[T]: yield next(sq) -def count(sq: tp.Iterator, start: tp.Optional[int] = None, step: int = 1) -> tp.Iterator[int]: +def count(sq: IteratorOrIterable, start: tp.Optional[int] = None, step: int = 1) -> tp.Iterator[int]: """ Return a sequence of integers, for each entry in the sequence with provided step. @@ -78,9 +84,6 @@ def is_instance(classes: tp.Union[tp.Tuple[type, ...], type]) -> tp.Callable[[ob return inner -T = tp.TypeVar('T') - -IteratorOrIterable = tp.Union[tp.Iterator[T], tp.Iterable[T]] def other_sequence_no_longer_than(base_sequence: IteratorOrIterable, diff --git a/tests/test_coding/test_sequences.py b/tests/test_coding/test_sequences.py index 281cc994d056ef1e5a447658e6b75b98fbc6e260..753c0d8c15999bb6621b03fc462250a113d21c3e 100644 --- a/tests/test_coding/test_sequences.py +++ b/tests/test_coding/test_sequences.py @@ -3,13 +3,18 @@ import unittest from satella.coding.sequences import choose, infinite_counter, take_n, is_instance, is_last, \ add_next, half_product, skip_first, zip_shifted, stop_after, group_quantity, \ - iter_dict_of_list, shift, other_sequence_no_longer_than, count + iter_dict_of_list, shift, other_sequence_no_longer_than, count, even, odd logger = logging.getLogger(__name__) class TestSequences(unittest.TestCase): + def test_even_and_odd(self): + a = [0, 1, 2, 3, 4, 5, 6] + self.assertEqual(list(even(a)), [0, 2, 4, 6]) + self.assertEqual(list(odd(a)), [1, 3, 5]) + def test_count(self): self.assertEqual(list(count([None, None, None], 5, -2)), [5, 3, 1])