Skip to content

Commit

Permalink
Add Iterator::choose_stable()
Browse files Browse the repository at this point in the history
This function is similar to Iterator::choose(), except that for a given PRNG state and any iterator of the same length, it always selects the same element and makes the same calls to the PRNG.

Closes #1051
  • Loading branch information
kevincox committed Oct 10, 2020
1 parent 14ba1bb commit 7b35598
Showing 1 changed file with 153 additions and 0 deletions.
153 changes: 153 additions & 0 deletions src/seq/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,62 @@ pub trait IteratorRandom: Iterator + Sized {
}
}

/// Pick a single element from the iterator at random.
///
/// The result is `None` exactly when the iterator is empty.
///
/// This behaves much like [`choose`], but the outcome is determined solely
/// by the iterator's length and the values `rng` produces. For any two
/// iterators of equal length the same requests are made to `rng`, and if
/// `rng` answers with the same sequence of values, the same index of `self`
/// is selected. Use this when results must be consistent regardless of the
/// kind of iterator being sampled; if that stability is not required,
/// prefer [`choose`].
///
/// [`Iterator::size_hint`] is still consulted so that elements which end up
/// not being selected need not be constructed, but neither the selection nor
/// the calls to `rng` are affected by this optimization. To force every
/// element to be created regardless, call `.inspect(|e| ())` first.
fn choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item>
where R: Rng + ?Sized {
    // Number of elements accounted for so far, and the current pick.
    let mut seen = 0;
    let mut winner = None;

    loop {
        let (guaranteed, _) = self.size_hint();

        // `nth()` is currently the only way to skip elements, so work out how
        // many still have to be skipped before the next element is examined.
        // This should be replaced by `advance_by()` once it is stable:
        // https://github.com/rust-lang/rust/issues/77404
        let skip = if guaranteed >= 2 {
            // Draw once for every element the hint promises, in order, without
            // materialising any of them. The element at overall position `k`
            // wins when a draw from `0..=k` comes up zero; only the last
            // winner of the batch matters.
            let mut last_hit = None;
            for offset in 0..guaranteed {
                if rng.gen_range(0..=seen + offset) == 0 {
                    last_hit = Some(offset);
                }
            }
            seen += guaranteed;

            match last_hit {
                Some(offset) => {
                    // Fetch the batch winner, then skip the rest of the batch.
                    winner = self.nth(offset);
                    debug_assert!(winner.is_some(), "iterator shorter than size_hint().0");
                    guaranteed - (offset + 1)
                }
                // Nothing in the batch won: skip all of it.
                None => guaranteed,
            }
        } else {
            0
        };

        // Examine one element directly; running out of elements ends the search.
        match self.nth(skip) {
            None => return winner,
            Some(item) => {
                if rng.gen_range(0..=seen) == 0 {
                    winner = Some(item);
                }
                seen += 1;
            }
        }
    }
}

/// Collects values at random from the iterator into a supplied buffer
/// until that buffer is filled.
///
Expand Down Expand Up @@ -795,6 +851,103 @@ mod test {
assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None);
}

/// Checks that `choose_stable` picks each of nine values with roughly equal
/// frequency across iterators with differing `size_hint` behaviour, and that
/// it returns `None` for empty iterators.
#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_iterator_choose_stable() {
    let r = &mut crate::test::rng(109);

    // Samples `iter` 1000 times with `choose_stable` and asserts that every
    // value in 0..9 was picked a plausible number of times.
    fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
        let mut chosen = [0i32; 9];
        for _ in 0..1000 {
            let picked = iter.clone().choose_stable(r).unwrap();
            chosen[picked] += 1;
        }
        for count in chosen.iter() {
            // Samples should follow Binomial(1000, 1/9)
            // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
            // Note: have seen 153, which is unlikely but not impossible.
            assert!(
                72 < *count && *count < 154,
                "count not close to 1000/9: {}",
                count
            );
        }
    }

    test_iter(r, 0..9);
    test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
    #[cfg(feature = "alloc")]
    test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
    test_iter(r, UnhintedIterator { iter: 0..9 });
    test_iter(r, ChunkHintedIterator {
        iter: 0..9,
        chunk_size: 4,
        chunk_remaining: 4,
        hint_total_size: false,
    });
    test_iter(r, ChunkHintedIterator {
        iter: 0..9,
        chunk_size: 4,
        chunk_remaining: 4,
        hint_total_size: true,
    });
    test_iter(r, WindowHintedIterator {
        iter: 0..9,
        window_size: 2,
        hint_total_size: false,
    });
    test_iter(r, WindowHintedIterator {
        iter: 0..9,
        window_size: 2,
        hint_total_size: true,
    });

    // Empty iterators must yield `None`. These checks exercise
    // `choose_stable` itself; they previously called `choose`, which left the
    // empty case of the method under test unverified.
    assert_eq!((0..0).choose_stable(r), None);
    assert_eq!(UnhintedIterator { iter: 0..0 }.choose_stable(r), None);
}

/// Checks that `choose_stable` yields identical pick counts for iterators of
/// the same length and contents, whatever their concrete type.
#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_iterator_choose_stable_stability() {
    // Tallies 1000 picks from `source`, starting from a freshly seeded RNG
    // each time so that every call sees the same random sequence.
    fn tally(source: impl Iterator<Item = usize> + Clone) -> [i32; 9] {
        let rng = &mut crate::test::rng(109);
        (0..1000).fold([0i32; 9], |mut counts, _| {
            let ix = source.clone().choose_stable(rng).unwrap();
            counts[ix] += 1;
            counts
        })
    }

    // A plain range serves as the baseline every other iterator must match.
    let expected = tally(0..9);

    let from_array = tally([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
    assert_eq!(from_array, expected);

    #[cfg(feature = "alloc")]
    assert_eq!(tally((0..9).collect::<Vec<_>>().into_iter()), expected);

    let unhinted = tally(UnhintedIterator { iter: 0..9 });
    assert_eq!(unhinted, expected);

    let chunked_partial = tally(ChunkHintedIterator {
        iter: 0..9,
        chunk_size: 4,
        chunk_remaining: 4,
        hint_total_size: false,
    });
    assert_eq!(chunked_partial, expected);

    let chunked_total = tally(ChunkHintedIterator {
        iter: 0..9,
        chunk_size: 4,
        chunk_remaining: 4,
        hint_total_size: true,
    });
    assert_eq!(chunked_total, expected);

    let windowed_partial = tally(WindowHintedIterator {
        iter: 0..9,
        window_size: 2,
        hint_total_size: false,
    });
    assert_eq!(windowed_partial, expected);

    let windowed_total = tally(WindowHintedIterator {
        iter: 0..9,
        window_size: 2,
        hint_total_size: true,
    });
    assert_eq!(windowed_total, expected);
}

#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_shuffle() {
Expand Down

0 comments on commit 7b35598

Please sign in to comment.