From 3fcbc31489cafc731d8c7212ffc7341fa5d80299 Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Tue, 21 Apr 2015 17:43:10 -0400 Subject: [PATCH] Optimize iterator adapters. Specifically, make count, nth, and last call the corresponding methods on the underlying iterator where possible. This way, if the underlying iterator has an optimized count, nth, or last implementations (e.g. slice::Iter), these methods will propagate these optimizations. Additionally, change Skip::next to take advantage of a potentially optimized nth method on the underlying iterator. --- src/libcore/iter.rs | 163 ++++++++++++++++++++++++++---- src/libcoretest/iter.rs | 215 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 360 insertions(+), 18 deletions(-) diff --git a/src/libcore/iter.rs b/src/libcore/iter.rs index be66623d0e151..37cabe6021698 100644 --- a/src/libcore/iter.rs +++ b/src/libcore/iter.rs @@ -1472,6 +1472,32 @@ impl Iterator for Chain where } } + #[inline] + fn count(self) -> usize { + (if !self.flag { self.a.count() } else { 0 }) + self.b.count() + } + + #[inline] + fn nth(&mut self, mut n: usize) -> Option { + if !self.flag { + for x in self.a.by_ref() { + if n == 0 { + return Some(x) + } + n -= 1; + } + self.flag = true; + } + self.b.nth(n) + } + + #[inline] + fn last(self) -> Option { + let a_last = if self.flag { None } else { self.a.last() }; + let b_last = self.b.last(); + b_last.or(a_last) + } + #[inline] fn size_hint(&self) -> (usize, Option) { let (a_lower, a_upper) = self.a.size_hint(); @@ -1777,6 +1803,20 @@ impl Iterator for Enumerate where I: Iterator { fn size_hint(&self) -> (usize, Option) { self.iter.size_hint() } + + #[inline] + fn nth(&mut self, n: usize) -> Option<(usize, I::Item)> { + self.iter.nth(n).map(|a| { + let i = self.count + n; + self.count = i + 1; + (i, a) + }) + } + + #[inline] + fn count(self) -> usize { + self.iter.count() + } } #[stable(feature = "rust1", since = "1.0.0")] @@ -1834,6 +1874,28 @@ impl Iterator for Peekable { } } + #[inline] + fn count(self) -> usize { + (if self.peeked.is_some() { 1 } else { 0 }) + self.iter.count() + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + match self.peeked { + Some(_) if n == 0 => self.peeked.take(), + Some(_) => { + self.peeked = None; + self.iter.nth(n-1) + }, + None => self.iter.nth(n) + } + } + + #[inline] + fn last(self) -> Option { + self.iter.last().or(self.peeked) + } + #[inline] fn size_hint(&self) -> (usize, Option) { let (lo, hi) = self.iter.size_hint(); @@ -1960,27 +2022,49 @@ impl Iterator for Skip where I: Iterator { type Item = ::Item; #[inline] - fn next(&mut self) -> Option<::Item> { - let mut next = self.iter.next(); + fn next(&mut self) -> Option { if self.n == 0 { - next + self.iter.next() } else { - let mut n = self.n; - while n > 0 { - n -= 1; - match next { - Some(_) => { - next = self.iter.next(); - continue - } - None => { - self.n = 0; - return None - } - } - } + let old_n = self.n; self.n = 0; - next + self.iter.nth(old_n) + } + } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + // Can't just add n + self.n due to overflow. + if self.n == 0 { + self.iter.nth(n) + } else { + let to_skip = self.n; + self.n = 0; + // nth(n) skips n+1 + if self.iter.nth(to_skip-1).is_none() { + return None; + } + self.iter.nth(n) + } + } + + #[inline] + fn count(self) -> usize { + self.iter.count().saturating_sub(self.n) + } + + #[inline] + fn last(mut self) -> Option { + if self.n == 0 { + self.iter.last() + } else { + let next = self.next(); + if next.is_some() { + // recurse. n should be 0. + self.last().or(next) + } else { + None + } } } @@ -2038,6 +2122,20 @@ impl Iterator for Take where I: Iterator{ } } + #[inline] + fn nth(&mut self, n: usize) -> Option { + if self.n > n { + self.n -= n + 1; + self.iter.nth(n) + } else { + if self.n > 0 { + self.iter.nth(self.n - 1); + self.n = 0; + } + None + } + } + #[inline] fn size_hint(&self) -> (usize, Option) { let (lower, upper) = self.iter.size_hint(); @@ -2199,6 +2297,35 @@ impl Iterator for Fuse where I: Iterator { } } + #[inline] + fn nth(&mut self, n: usize) -> Option { + if self.done { + None + } else { + let nth = self.iter.nth(n); + self.done = nth.is_none(); + nth + } + } + + #[inline] + fn last(self) -> Option { + if self.done { + None + } else { + self.iter.last() + } + } + + #[inline] + fn count(self) -> usize { + if self.done { + 0 + } else { + self.iter.count() + } + } + #[inline] fn size_hint(&self) -> (usize, Option) { if self.done { diff --git a/src/libcoretest/iter.rs b/src/libcoretest/iter.rs index 95a6236e9c394..0415c75aa5204 100644 --- a/src/libcoretest/iter.rs +++ b/src/libcoretest/iter.rs @@ -100,6 +100,42 @@ fn test_iterator_chain() { assert_eq!(i, expected.len()); } +#[test] +fn test_iterator_chain_nth() { + let xs = [0, 1, 2, 3, 4, 5]; + let ys = [30, 40, 50, 60]; + let zs = []; + let expected = [0, 1, 2, 3, 4, 5, 30, 40, 50, 60]; + for (i, x) in expected.iter().enumerate() { + assert_eq!(Some(x), xs.iter().chain(ys.iter()).nth(i)); + } + assert_eq!(zs.iter().chain(xs.iter()).nth(0), Some(&0)); + + let mut it = xs.iter().chain(zs.iter()); + assert_eq!(it.nth(5), Some(&5)); + assert_eq!(it.next(), None); +} + +#[test] +fn test_iterator_chain_last() { + let xs = [0, 1, 2, 3, 4, 5]; + let ys = [30, 40, 50, 60]; + let zs = []; + assert_eq!(xs.iter().chain(ys.iter()).last(), Some(&60)); + assert_eq!(zs.iter().chain(ys.iter()).last(), Some(&60)); + assert_eq!(ys.iter().chain(zs.iter()).last(), Some(&60)); + assert_eq!(zs.iter().chain(zs.iter()).last(), None); +} + +#[test] +fn test_iterator_chain_count() { + let xs = [0, 1, 2, 3, 4, 5]; + let ys = [30, 40, 50, 60]; + let zs = []; + assert_eq!(xs.iter().chain(ys.iter()).count(), 10); + assert_eq!(zs.iter().chain(ys.iter()).count(), 4); +} + #[test] fn test_filter_map() { let it = (0..).step_by(1).take(10) @@ -116,6 +152,34 @@ fn test_iterator_enumerate() { } } +#[test] +fn test_iterator_enumerate_nth() { + let xs = [0, 1, 2, 3, 4, 5]; + for (i, &x) in xs.iter().enumerate() { + assert_eq!(i, x); + } + + let mut it = xs.iter().enumerate(); + while let Some((i, &x)) = it.nth(0) { + assert_eq!(i, x); + } + + let mut it = xs.iter().enumerate(); + while let Some((i, &x)) = it.nth(1) { + assert_eq!(i, x); + } + + let (i, &x) = xs.iter().enumerate().nth(3).unwrap(); + assert_eq!(i, x); + assert_eq!(i, 3); +} + +#[test] +fn test_iterator_enumerate_count() { + let xs = [0, 1, 2, 3, 4, 5]; + assert_eq!(xs.iter().count(), 6); +} + #[test] fn test_iterator_peekable() { let xs = vec![0, 1, 2, 3, 4, 5]; @@ -148,6 +212,59 @@ fn test_iterator_peekable() { assert_eq!(it.len(), 0); } +#[test] +fn test_iterator_peekable_count() { + let xs = [0, 1, 2, 3, 4, 5]; + let ys = [10]; + let zs: [i32; 0] = []; + + assert_eq!(xs.iter().peekable().count(), 6); + + let mut it = xs.iter().peekable(); + assert_eq!(it.peek(), Some(&&0)); + assert_eq!(it.count(), 6); + + assert_eq!(ys.iter().peekable().count(), 1); + + let mut it = ys.iter().peekable(); + assert_eq!(it.peek(), Some(&&10)); + assert_eq!(it.count(), 1); + + assert_eq!(zs.iter().peekable().count(), 0); + + let mut it = zs.iter().peekable(); + assert_eq!(it.peek(), None); + +} + +#[test] +fn test_iterator_peekable_nth() { + let xs = [0, 1, 2, 3, 4, 5]; + let mut it = xs.iter().peekable(); + + assert_eq!(it.peek(), Some(&&0)); + assert_eq!(it.nth(0), Some(&0)); + assert_eq!(it.peek(), Some(&&1)); + assert_eq!(it.nth(1), Some(&2)); + assert_eq!(it.peek(), Some(&&3)); + assert_eq!(it.nth(2), Some(&5)); + assert_eq!(it.next(), None); +} + +#[test] +fn test_iterator_peekable_last() { + let xs = [0, 1, 2, 3, 4, 5]; + let ys = [0]; + + let mut it = xs.iter().peekable(); + assert_eq!(it.peek(), Some(&&0)); + assert_eq!(it.last(), Some(&5)); + + let mut it = ys.iter().peekable(); + assert_eq!(it.peek(), Some(&&0)); + assert_eq!(it.last(), Some(&0)); +} + #[test] fn test_iterator_take_while() { let xs = [0, 1, 2, 3, 5, 13, 15, 16, 17, 19]; @@ -189,6 +306,49 @@ fn test_iterator_skip() { assert_eq!(it.len(), 0); } +#[test] +fn test_iterator_skip_nth() { + let xs = [0, 1, 2, 3, 5, 13, 15, 16, 17, 19, 20, 30]; + + let mut it = xs.iter().skip(0); + assert_eq!(it.nth(0), Some(&0)); + assert_eq!(it.nth(1), Some(&2)); + + let mut it = xs.iter().skip(5); + assert_eq!(it.nth(0), Some(&13)); + assert_eq!(it.nth(1), Some(&16)); + + let mut it = xs.iter().skip(12); + assert_eq!(it.nth(0), None); + +} + +#[test] +fn test_iterator_skip_count() { + let xs = [0, 1, 2, 3, 5, 13, 15, 16, 17, 19, 20, 30]; + + assert_eq!(xs.iter().skip(0).count(), 12); + assert_eq!(xs.iter().skip(1).count(), 11); + assert_eq!(xs.iter().skip(11).count(), 1); + assert_eq!(xs.iter().skip(12).count(), 0); + assert_eq!(xs.iter().skip(13).count(), 0); +} + +#[test] +fn test_iterator_skip_last() { + let xs = [0, 1, 2, 3, 5, 13, 15, 16, 17, 19, 20, 30]; + + assert_eq!(xs.iter().skip(0).last(), Some(&30)); + assert_eq!(xs.iter().skip(1).last(), Some(&30)); + assert_eq!(xs.iter().skip(11).last(), Some(&30)); + assert_eq!(xs.iter().skip(12).last(), None); + assert_eq!(xs.iter().skip(13).last(), None); + + let mut it = xs.iter().skip(5); + assert_eq!(it.next(), Some(&13)); + assert_eq!(it.last(), Some(&30)); +} + #[test] fn test_iterator_take() { let xs = [0, 1, 2, 3, 5, 13, 15, 16, 17, 19]; @@ -205,6 +365,30 @@ fn test_iterator_take() { assert_eq!(it.len(), 0); } +#[test] +fn test_iterator_take_nth() { + let xs = [0, 1, 2, 4, 5]; + let mut it = xs.iter(); + { + let mut take = it.by_ref().take(3); + let mut i = 0; + while let Some(&x) = take.nth(0) { + assert_eq!(x, i); + i += 1; + } + } + assert_eq!(it.nth(1), Some(&5)); + assert_eq!(it.nth(0), None); + + let xs = [0, 1, 2, 3, 4]; + let mut it = xs.iter().take(7); + let mut i = 1; + while let Some(&x) = it.nth(1) { + assert_eq!(x, i); + i += 2; + } +} + #[test] fn test_iterator_take_short() { let xs = [0, 1, 2, 3]; @@ -881,6 +1065,37 @@ fn test_fuse() { assert_eq!(it.len(), 0); } +#[test] +fn test_fuse_nth() { + let xs = [0, 1, 2]; + let mut it = xs.iter(); + + assert_eq!(it.len(), 3); + assert_eq!(it.nth(2), Some(&2)); + assert_eq!(it.len(), 0); + assert_eq!(it.nth(2), None); + assert_eq!(it.len(), 0); +} + +#[test] +fn test_fuse_last() { + let xs = [0, 1, 2]; + let it = xs.iter(); + + assert_eq!(it.len(), 3); + assert_eq!(it.last(), Some(&2)); +} + +#[test] +fn test_fuse_count() { + let xs = [0, 1, 2]; + let it = xs.iter(); + + assert_eq!(it.len(), 3); + assert_eq!(it.count(), 3); + // Can't check len now because count consumes. +} + #[bench] fn bench_rposition(b: &mut Bencher) { let it: Vec = (0..300).collect();