Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions src/combinations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,95 @@ impl<I> Iterator for Combinations<I>
// Create result vector based on the indices
Some(self.indices.iter().map(|i| self.pool[*i].clone()).collect())
}

fn size_hint(&self) -> (usize, Option<usize>) {
let (mut low, mut upp) = self.pool.size_hint();
low = remaining_for(low, self.first, &self.indices).unwrap_or(usize::MAX);
upp = upp.and_then(|upp| remaining_for(upp, self.first, &self.indices));
(low, upp)
}

fn count(self) -> usize {
let Self { indices, pool, first } = self;
// TODO: make `pool.it` private
let n = pool.len() + pool.it.count();
remaining_for(n, first, &indices).unwrap()
}
}

impl<I> FusedIterator for Combinations<I>
where I: Iterator,
I::Item: Clone
{}

// https://en.wikipedia.org/wiki/Binomial_coefficient#In_programming_languages
pub(crate) fn checked_binomial(mut n: usize, mut k: usize) -> Option<usize> {
if n < k {
return Some(0);
}
// `factorial(n) / factorial(n - k) / factorial(k)` but trying to avoid it overflows:
k = (n - k).min(k); // symmetry
let mut c = 1;
for i in 1..=k {
c = (c / i).checked_mul(n)?.checked_add((c % i).checked_mul(n)? / i)?;
n -= 1;
}
Some(c)
}

#[test]
fn test_checked_binomial() {
// With the first row: [1, 0, 0, ...] and the first column full of 1s, we check
// row by row the recurrence relation of binomials (which is an equivalent definition).
// For n >= 1 and k >= 1 we have:
// binomial(n, k) == binomial(n - 1, k - 1) + binomial(n - 1, k)
const LIMIT: usize = 500;
let mut row = vec![Some(0); LIMIT + 1];
row[0] = Some(1);
for n in 0..=LIMIT {
for k in 0..=LIMIT {
assert_eq!(row[k], checked_binomial(n, k));
}
row = std::iter::once(Some(1))
.chain((1..=LIMIT).map(|k| row[k - 1]?.checked_add(row[k]?)))
.collect();
}
}

/// For a given size `n`, return the count of remaining combinations or None if it would overflow.
fn remaining_for(n: usize, first: bool, indices: &[usize]) -> Option<usize> {
let k = indices.len();
if n < k {
Some(0)
} else if first {
checked_binomial(n, k)
} else {
// https://en.wikipedia.org/wiki/Combinatorial_number_system
// http://www.site.uottawa.ca/~lucia/courses/5165-09/GenCombObj.pdf

// The combinations generated after the current one can be counted by counting as follows:
// - The subsequent combinations that differ in indices[0]:
// If subsequent combinations differ in indices[0], then their value for indices[0]
// must be at least 1 greater than the current indices[0].
// As indices is strictly monotonically sorted, this means we can effectively choose k values
// from (n - 1 - indices[0]), leading to binomial(n - 1 - indices[0], k) possibilities.
// - The subsequent combinations with same indices[0], but differing indices[1]:
// Here we can choose k - 1 values from (n - 1 - indices[1]) values,
// leading to binomial(n - 1 - indices[1], k - 1) possibilities.
// - (...)
// - The subsequent combinations with same indices[0..=i], but differing indices[i]:
// Here we can choose k - i values from (n - 1 - indices[i]) values: binomial(n - 1 - indices[i], k - i).
// Since subsequent combinations can in any index, we must sum up the aforementioned binomial coefficients.

// Below, `n0` resembles indices[i].
indices
.iter()
.enumerate()
// TODO: Once the MSRV hits 1.37.0, we can sum options instead:
// .map(|(i, n0)| checked_binomial(n - 1 - *n0, k - i))
// .sum()
.fold(Some(0), |sum, (i, n0)| {
sum.and_then(|s| s.checked_add(checked_binomial(n - 1 - *n0, k - i)?))
})
}
}
6 changes: 6 additions & 0 deletions src/lazy_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ use std::iter::Fuse;
use std::ops::Index;
use alloc::vec::Vec;

use crate::size_hint::{self, SizeHint};

#[derive(Debug, Clone)]
pub struct LazyBuffer<I: Iterator> {
pub it: Fuse<I>,
Expand All @@ -23,6 +25,10 @@ where
self.buffer.len()
}

pub fn size_hint(&self) -> SizeHint {
size_hint::add_scalar(self.it.size_hint(), self.len())
}

pub fn get_next(&mut self) -> bool {
if let Some(x) = self.it.next() {
self.buffer.push(x);
Expand Down
22 changes: 22 additions & 0 deletions tests/test_std.rs
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,28 @@ fn combinations_zero() {
it::assert_equal((0..0).combinations(0), vec![vec![]]);
}

#[test]
fn combinations_range_count() {
for n in 0..=10 {
for k in 0..=n {
let len = (n - k + 1..=n).product::<usize>() / (1..=k).product::<usize>();
let mut it = (0..n).combinations(k);
assert_eq!(len, it.clone().count());
assert_eq!(len, it.size_hint().0);
assert_eq!(Some(len), it.size_hint().1);
for count in (0..len).rev() {
let elem = it.next();
assert!(elem.is_some());
assert_eq!(count, it.clone().count());
assert_eq!(count, it.size_hint().0);
assert_eq!(Some(count), it.size_hint().1);
}
let should_be_none = it.next();
assert!(should_be_none.is_none());
}
}
}

#[test]
fn permutations_zero() {
it::assert_equal((1..3).permutations(0), vec![vec![]]);
Expand Down