diff --git a/CHANGELOG.md b/CHANGELOG.md index 179308ea60b21f733c6274f007e12af22c8f13ef..601f12c8e40519eae868a43cb6c34597f6c2ae17 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,4 +5,5 @@ * SyncableDroppable.cleanup() bugs out, wrote a quick patch for it to do nothing and filed as #61. * unit tests for overloading * fixed #60 +* fixed #59 diff --git a/satella/__init__.py b/satella/__init__.py index 265ec16a287bac6a0672bb9d403d240a0c2b3aae..57a522ce8b02bcfff918389cf07ecf0e62cf7e1f 100644 --- a/satella/__init__.py +++ b/satella/__init__.py @@ -1 +1 @@ -__version__ = '2.24.1a7' +__version__ = '2.24.1' diff --git a/satella/coding/sequences/iterators.py b/satella/coding/sequences/iterators.py index c589ac53abc90ddfc9c128dd6f3a1798f9ac2fcd..99a059ba508d89b85be509fd80ed821c1bd1f934 100644 --- a/satella/coding/sequences/iterators.py +++ b/satella/coding/sequences/iterators.py @@ -480,8 +480,8 @@ def skip_first(iterator: Iteratable, n: int) -> tp.Iterator[T]: class _ListWrapperIteratorIterator(tp.Iterator[T]): __slots__ = 'parent', 'pos' - def __init__(self, parent): - self.parent = parent + def __init__(self, parent: Iteratable[T]): + self.parent = parent if isinstance(parent, tp.Iterator) else iter(parent) self.pos = 0 def __length_hint__(self) -> int: @@ -514,25 +514,21 @@ class ListWrapperIterator(tp.Iterator[T]): This is additionally a generic class. """ - __slots__ = 'iterator', 'exhausted', 'list' - - def __next__(self) -> T: - return self.next() + __slots__ = 'iterator', 'exhausted', 'list', 'internal_pointer' def __contains__(self, item: T) -> bool: if item not in self.list and self.exhausted: return False - for item2 in self: - self.list.append(item2) + for item2 in itertools.chain(self.list, self): if item2 == item: return True - else: - return False + return False - def __init__(self, iterator: Iteratable): - self.iterator = iter(iterator) + def __init__(self, iterator: Iteratable[T]): + self.iterator = iterator if isinstance(iterator, tp.Iterator) else iter(iterator) self.exhausted = False self.list = [] + self.internal_pointer = 0 def exhaust(self) -> None: """ @@ -553,16 +549,20 @@ class ListWrapperIterator(tp.Iterator[T]): while len(self.list) < i: try: - self.list.append(next(self.iterator)) + item = next(self.iterator) + self.list.append(item) except StopIteration: self.exhausted = True - return + break def __len__(self) -> int: self.exhaust() return len(self.list) def __getitem__(self, item: tp.Union[slice, int]) -> tp.Union[tp.List[T], T]: + """ + :raises IndexError: invalid index + """ if isinstance(item, int): if len(self.list) < item + 1: self.advance_to_item(item + 1) @@ -570,21 +570,37 @@ class ListWrapperIterator(tp.Iterator[T]): self.advance_to_item(item.stop) return self.list[item] - def next(self) -> T: + def __next__(self) -> T: """ Get the next item :raises StopIteration: next element is not available due to iterator finishing """ - if self.exhausted: + if self.internal_pointer == len(self.list) and self.exhausted: raise StopIteration() - else: + + # We can serve that from memory + if len(self.list) > self.internal_pointer: + in_ptr = self.internal_pointer + elem = self.list[in_ptr] + self.internal_pointer += 1 + return elem + + # We cannot serve it from memory + try: item = next(self.iterator) self.list.append(item) + self.internal_pointer += 1 return item + except StopIteration: + self.exhausted = True + raise def __iter__(self) -> tp.Iterator[T]: - return _ListWrapperIteratorIterator(self) + """ + Return a brand new iterator, that will use this iterator + """ + return self @silence_excs(StopIteration) diff --git a/tests/test_coding/test_iterators.py b/tests/test_coding/test_iterators.py index a8ac7cc11179c8a13426b1aea2c322b9d5484b0a..28b2d74e5bf927df2507d27641c6729f3b71f174 100644 --- a/tests/test_coding/test_iterators.py +++ b/tests/test_coding/test_iterators.py @@ -6,27 +6,30 @@ from satella.coding.sequences import smart_enumerate, ConstruableIterator, walk, IteratorListAdapter, is_empty, ListWrapperIterator +def iterate(): + yield 1 + yield 2 + yield 3 + yield 4 + yield 5 + + class TestIterators(unittest.TestCase): - @unittest.skip("Fails the entire suite by hanging") def test_list_wrapper_iterator_contains(self): - def iterate(): - yield 1 - yield 2 - yield 3 - yield 4 - yield 5 lwe = ListWrapperIterator(iterate()) self.assertTrue(2 in lwe) + self.assertEqual(lwe.internal_pointer, 2) self.assertLessEqual(len(lwe.list), 2) self.assertFalse(6 in lwe) self.assertEqual(len(lwe.list), 5) + self.assertEqual(lwe.internal_pointer, 5) def test_list_wrapper_iterator(self): a = {'count': 0} - def iterate(): + def iterate2(): nonlocal a a['count'] += 1 @@ -36,12 +39,13 @@ class TestIterators(unittest.TestCase): a['count'] += 1 yield 3 - lwe = ListWrapperIterator(iterate()) - self.assertEqual(list(iter(lwe)), [1, 2, 3]) + lwe = ListWrapperIterator(iterate2()) + self.assertEqual(list(lwe), [1, 2, 3]) + return self.assertEqual(a['count'], 3) - self.assertEqual(list(iter(lwe)), [1, 2, 3]) + self.assertEqual(list(lwe), []) self.assertEqual(a['count'], 3) - lwe2 = ListWrapperIterator(iterate()) + lwe2 = ListWrapperIterator(iterate2()) self.assertEqual(lwe2[1:2], [2]) self.assertEqual(a['count'], 5) self.assertEqual(lwe2[2], 3)