diff --git a/src/seq/mod.rs b/src/seq/mod.rs
index 4dea058c2db..f8d21c9a444 100644
--- a/src/seq/mod.rs
+++ b/src/seq/mod.rs
@@ -292,10 +292,11 @@ pub trait IteratorRandom: Iterator + Sized {
     /// available, complexity is `O(n)` where `n` is the iterator length.
     /// Partial hints (where `lower > 0`) also improve performance.
     ///
-    /// Note that the output values and the the number of RNG samples used
+    /// Note that the output values and the number of RNG samples used
     /// depends on size hints. In particular, `Iterator` combinators that don't
     /// change the values yielded but change the size hints may result in
-    /// `choose` returning different elements.
+    /// `choose` returning different elements. If you want consistent results
+    /// and RNG usage consider using [`IteratorRandom::choose_stable`].
     fn choose<R>(mut self, rng: &mut R) -> Option<Self::Item>
     where R: Rng + ?Sized {
         let (mut lower, mut upper) = self.size_hint();
@@ -347,6 +348,62 @@ pub trait IteratorRandom: Iterator + Sized {
         }
     }
 
+    /// Choose one element at random from the iterator.
+    ///
+    /// Returns `None` if and only if the iterator is empty.
+    ///
+    /// This method is very similar to [`IteratorRandom::choose`] except that the result
+    /// only depends on the length of the iterator and the values produced by
+    /// `rng`. Notably for any iterator of a given length this will make the
+    /// same requests to `rng` and if the same sequence of values is produced
+    /// the same index will be selected from `self`. This may be useful if you
+    /// need consistent results no matter what type of iterator you are working
+    /// with. If you do not need this stability prefer [`IteratorRandom::choose`].
+    ///
+    /// Note that this method still uses [`Iterator::size_hint`] to skip
+    /// constructing elements where possible, however the selection and `rng`
+    /// calls are the same in the face of this optimization. If you want to
+    /// force every element to be created regardless call `.inspect(|_| ())`.
+    fn choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item>
+    where R: Rng + ?Sized {
+        let mut consumed = 0; // elements accounted for so far (yielded or skipped)
+        let mut result = None; // current pick; replaced whenever a later draw is 0
+
+        loop {
+            // Currently the only way to skip elements is `nth()`. So we need to
+            // store what index to access next here.
+            // This should be replaced by `advance_by()` once it is stable:
+            // https://github.com/rust-lang/rust/issues/77404
+            let mut next = 0;
+
+            let (lower, _) = self.size_hint();
+            if lower >= 2 {
+                let highest_selected = (0..lower)
+                    .filter(|ix| gen_index(rng, consumed+ix+1) == 0)
+                    .last(); // only the last hit is kept, so earlier hits are never built
+
+                consumed += lower;
+                next = lower;
+
+                if let Some(ix) = highest_selected {
+                    result = self.nth(ix);
+                    next -= ix + 1; // part of this batch still to be skipped after the pick
+                    debug_assert!(result.is_some(), "iterator shorter than size_hint().0");
+                }
+            }
+
+            let elem = self.nth(next); // skips the rest of the batch, yields the element after it
+            if elem.is_none() {
+                return result
+            }
+
+            if gen_index(rng, consumed+1) == 0 {
+                result = elem;
+            }
+            consumed += 1;
+        }
+    }
+
     /// Collects values at random from the iterator into a supplied buffer
     /// until that buffer is filled.
     ///
@@ -794,6 +851,103 @@ mod test {
         assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None);
     }
 
+    #[test]
+    #[cfg_attr(miri, ignore)] // Miri is too slow
+    fn test_iterator_choose_stable() {
+        let r = &mut crate::test::rng(109);
+        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
+            let mut chosen = [0i32; 9];
+            for _ in 0..1000 {
+                let picked = iter.clone().choose_stable(r).unwrap();
+                chosen[picked] += 1;
+            }
+            for count in chosen.iter() {
+                // Samples should follow Binomial(1000, 1/9)
+                // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
+                // Note: have seen 153, which is unlikely but not impossible.
+                assert!(
+                    72 < *count && *count < 154,
+                    "count not close to 1000/9: {}",
+                    count
+                );
+            }
+        }
+
+        test_iter(r, 0..9);
+        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
+        #[cfg(feature = "alloc")]
+        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
+        test_iter(r, UnhintedIterator { iter: 0..9 });
+        test_iter(r, ChunkHintedIterator {
+            iter: 0..9,
+            chunk_size: 4,
+            chunk_remaining: 4,
+            hint_total_size: false,
+        });
+        test_iter(r, ChunkHintedIterator {
+            iter: 0..9,
+            chunk_size: 4,
+            chunk_remaining: 4,
+            hint_total_size: true,
+        });
+        test_iter(r, WindowHintedIterator {
+            iter: 0..9,
+            window_size: 2,
+            hint_total_size: false,
+        });
+        test_iter(r, WindowHintedIterator {
+            iter: 0..9,
+            window_size: 2,
+            hint_total_size: true,
+        });
+
+        assert_eq!((0..0).choose_stable(r), None); // empty input, with an exact size hint
+        assert_eq!(UnhintedIterator { iter: 0..0 }.choose_stable(r), None); // empty, via UnhintedIterator
+    }
+
+    #[test]
+    #[cfg_attr(miri, ignore)] // Miri is too slow
+    fn test_iterator_choose_stable_stability() {
+        fn test_iter(iter: impl Iterator<Item = usize> + Clone) -> [i32; 9] {
+            let r = &mut crate::test::rng(109); // fresh, identically seeded RNG per iterator kind
+            let mut chosen = [0i32; 9];
+            for _ in 0..1000 {
+                let picked = iter.clone().choose_stable(r).unwrap();
+                chosen[picked] += 1;
+            }
+            chosen
+        }
+
+        let reference = test_iter(0..9);
+        assert_eq!(test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()), reference);
+
+        #[cfg(feature = "alloc")]
+        assert_eq!(test_iter((0..9).collect::<Vec<_>>().into_iter()), reference);
+        assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference);
+        assert_eq!(test_iter(ChunkHintedIterator {
+            iter: 0..9,
+            chunk_size: 4,
+            chunk_remaining: 4,
+            hint_total_size: false,
+        }), reference);
+        assert_eq!(test_iter(ChunkHintedIterator {
+            iter: 0..9,
+            chunk_size: 4,
+            chunk_remaining: 4,
+            hint_total_size: true,
+        }), reference);
+        assert_eq!(test_iter(WindowHintedIterator {
+            iter: 0..9,
+            window_size: 2,
+            hint_total_size: false,
+        }), reference);
+        assert_eq!(test_iter(WindowHintedIterator {
+            iter: 0..9,
+            window_size: 2,
+            hint_total_size: true,
+        }), reference);
+    }
+
     #[test]
     #[cfg_attr(miri, ignore)] // Miri is too slow
     fn test_shuffle() {
@@ -999,6 +1153,52 @@ mod test {
         );
     }
 
+    #[test]
+    fn value_stability_choose_stable() {
+        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
+            let mut rng = crate::test::rng(411); // same seed for every call, so picks are comparable
+            iter.choose_stable(&mut rng)
+        }
+
+        assert_eq!(choose([].iter().cloned()), None);
+        assert_eq!(choose(0..100), Some(40));
+        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(40));
+        assert_eq!(
+            choose(ChunkHintedIterator {
+                iter: 0..100,
+                chunk_size: 32,
+                chunk_remaining: 32,
+                hint_total_size: false,
+            }),
+            Some(40)
+        );
+        assert_eq!(
+            choose(ChunkHintedIterator {
+                iter: 0..100,
+                chunk_size: 32,
+                chunk_remaining: 32,
+                hint_total_size: true,
+            }),
+            Some(40)
+        );
+        assert_eq!(
+            choose(WindowHintedIterator {
+                iter: 0..100,
+                window_size: 32,
+                hint_total_size: false,
+            }),
+            Some(40)
+        );
+        assert_eq!(
+            choose(WindowHintedIterator {
+                iter: 0..100,
+                window_size: 32,
+                hint_total_size: true,
+            }),
+            Some(40)
+        );
+    }
+
     #[test]
     fn value_stability_choose_multiple() {
         fn do_test<I: Iterator<Item = u32>>(iter: I, v: &[u32]) {