Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions src/combinations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,51 @@ impl<I> Iterator for Combinations<I>
// Create result vector based on the indices
Some(self.indices.iter().map(|i| self.pool[*i].clone()).collect())
}

fn size_hint(&self) -> (usize, Option<usize>) {
let (mut low, mut upp) = self.pool.size_hint();
low = remaining_for(low, self.first, &self.indices).unwrap_or(usize::MAX);
upp = upp.and_then(|upp| remaining_for(upp, self.first, &self.indices));
(low, upp)
}

fn count(self) -> usize {
let Self { indices, pool, first } = self;
let n = pool.len() + pool.it.count();
remaining_for(n, first, &indices).expect("Iterator count greater than usize::MAX")
}
}

impl<I> FusedIterator for Combinations<I>
where I: Iterator,
I::Item: Clone
{}

pub(crate) fn checked_binomial(mut n: usize, mut k: usize) -> Option<usize> {
if n < k {
return Some(0);
}
// `factorial(n) / factorial(n - k) / factorial(k)` but trying to avoid it overflows:
k = (n - k).min(k); // symmetry
let mut c = 1;
for i in 1..=k {
c = (c / i).checked_mul(n)?.checked_add((c % i).checked_mul(n)? / i)?;
n -= 1;
}
Some(c)
}

/// For a given size `n`, return the count of remaining combinations or None if it would overflow.
fn remaining_for(n: usize, first: bool, indices: &[usize]) -> Option<usize> {
let k = indices.len();
if first {
checked_binomial(n, k)
} else {
indices
.iter()
.enumerate()
.fold(Some(0), |sum, (k0, n0)| {
sum.and_then(|s| s.checked_add(checked_binomial(n - 1 - *n0, k - k0)?))
})
}
}
6 changes: 6 additions & 0 deletions src/lazy_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use std::iter::Fuse;
use std::ops::Index;
use alloc::vec::Vec;

use crate::size_hint::{self, SizeHint};

#[derive(Debug, Clone)]
pub struct LazyBuffer<I: Iterator> {
pub it: Fuse<I>,
Expand All @@ -23,6 +25,10 @@ where
self.buffer.len()
}

pub fn size_hint(&self) -> SizeHint {
size_hint::add_scalar(self.it.size_hint(), self.len())
}

pub fn get_next(&mut self) -> bool {
if let Some(x) = self.it.next() {
self.buffer.push(x);
Expand Down
15 changes: 15 additions & 0 deletions tests/test_std.rs
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,21 @@ fn combinations_zero() {
it::assert_equal((0..0).combinations(0), vec![vec![]]);
}

#[test]
fn combinations_range_count() {
for n in 0..6 {
for k in 0..=n {
let len = (n - k + 1..=n).product::<usize>() / (1..=k).product::<usize>();
let mut it = (0..n).combinations(k);
for count in (0..=len).rev() {
assert_eq!(it.size_hint(), (count, Some(count)));
assert_eq!(it.clone().count(), count);
assert_eq!(it.next().is_none(), count == 0);
}
}
}
}

#[test]
fn permutations_zero() {
it::assert_equal((1..3).permutations(0), vec![vec![]]);
Expand Down