From a7aa74628609c4b1514e5100d6d5f60ee0bc97b3 Mon Sep 17 00:00:00 2001 From: xmakro Date: Wed, 10 Jan 2024 20:11:03 +0800 Subject: [PATCH 01/20] Add WeightedIndexTree to rand_distr --- rand_distr/src/lib.rs | 10 +- rand_distr/src/weighted_tree.rs | 332 ++++++++++++++++++++++++++++ src/distributions/weighted_index.rs | 54 +++-- 3 files changed, 373 insertions(+), 23 deletions(-) create mode 100644 rand_distr/src/weighted_tree.rs diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index c8fd298171d..ed6d1bc75f5 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -133,6 +133,7 @@ pub use rand::distributions::{WeightedError, WeightedIndex}; #[cfg(feature = "alloc")] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] pub use weighted_alias::WeightedAliasIndex; +pub use weighted_tree::WeightedTreeIndex; pub use num_traits; @@ -174,10 +175,14 @@ mod test { macro_rules! assert_almost_eq { ($a:expr, $b:expr, $prec:expr) => { let diff = ($a - $b).abs(); - assert!(diff <= $prec, + assert!( + diff <= $prec, "assertion failed: `abs(left - right) = {:.1e} < {:e}`, \ (left: `{}`, right: `{}`)", - diff, $prec, $a, $b + diff, + $prec, + $a, + $b ); }; } @@ -186,6 +191,7 @@ mod test { #[cfg(feature = "alloc")] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] pub mod weighted_alias; +pub mod weighted_tree; mod binomial; mod cauchy; diff --git a/rand_distr/src/weighted_tree.rs b/rand_distr/src/weighted_tree.rs new file mode 100644 index 00000000000..08d2b59f4c8 --- /dev/null +++ b/rand_distr/src/weighted_tree.rs @@ -0,0 +1,332 @@ +// Copyright 2024 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! This module contains an implementation of a tree sttructure for sampling random +//! indices with probabilities proportional to a collection of weights. + +use core::ops::{Add, AddAssign, Sub, SubAssign}; + +use super::WeightedError; +use crate::Distribution; +use alloc::{vec, vec::Vec}; +use num_traits::Zero; +use rand::{distributions::uniform::SampleUniform, Rng}; +#[cfg(feature = "serde1")] +use serde::{Deserialize, Serialize}; + +/// A distribution using weighted sampling to pick a discretely selected item. +/// +/// Sampling a [`WeightedTreeIndex`] distribution returns the index of a randomly +/// selected element from the vector used to create the [`WeightedTreeIndex`]. +/// The chance of a given element being picked is proportional to the value of +/// the element. The weights can have any type `W` for which a implementation of +/// [`Weight`] exists. +/// +/// # Key differences +/// +/// The main distinction between [`WeightedTreeIndex`] and [`rand::distributions::WeightedIndex`] +/// lies in the internal representation of weights. In [`WeightedTreeIndex`], +/// weights are structured as a tree, which is optimized for frequent updates of the weights. +/// +/// # Performance +/// +/// A [`WeightedTreeIndex`] with `n` elements requires `O(n)` memory. +/// +/// Time complexity for the operations of a [`WeightedTreeIndex`] are: +/// * Constructing: Building the initial tree from a slice of weights takes `O(n)` time. +/// * Sampling: Choosing an index (traversing down the tree) requires `O(log n)` time. +/// * Weight Update: Modifying a weight (traversing up the tree), requires `O(log n)` time. +/// * Weight Addition (Pushing): Adding a new weight (traversing up the tree), requires `O(log n)` time. +/// * Weight Removal (Popping): Removing a weight (traversing up the tree), requires `O(log n)` time. +/// +/// # Example +/// +/// ``` +/// use rand_distr::WeightedTreeIndex; +/// use rand::prelude::*; +/// +/// let choices = vec!['a', 'b', 'c']; +/// let weights = vec![2, 1, 1]; +/// let dist = WeightedTreeIndex::new(&weights).unwrap(); +/// let mut rng = thread_rng(); +/// for _ in 0..100 { +/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' +/// let i = dist.sample(&mut rng).unwrap(); +/// println!("{}", choices[i]); +/// } +/// ``` +/// +/// [`WeightedTreeIndex`]: WeightedTreeIndex +/// [`Uniform::sample`]: Distribution::sample +#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr( + feature = "serde1", + serde(bound(serialize = "W: Serialize, W::Sampler: Serialize")) +)] +#[cfg_attr( + feature = "serde1 ", + serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")) +)] +#[derive(Clone, Default, Debug, PartialEq)] +pub struct WeightedTreeIndex { + subtotals: Vec, +} + +impl WeightedTreeIndex { + /// Creates a new [`WeightedTreeIndex`] from a slice of weights. + pub fn new(weights: &[W]) -> Result { + for &weight in weights { + if weight < W::zero() { + return Err(WeightedError::InvalidWeight); + } + } + let n = weights.len(); + let mut subtotals = vec![W::zero(); n]; + for i in (0..n).rev() { + let left_index = 2 * i + 1; + let left_subtotal = if left_index < n { + subtotals[left_index] + } else { + W::zero() + }; + let right_index = 2 * i + 2; + let right_subtotal = if right_index < n { + subtotals[right_index] + } else { + W::zero() + }; + subtotals[i] = weights[i] + left_subtotal + right_subtotal; + } + Ok(Self { subtotals }) + } + + /// Returns `true` if the tree contains no weights. + pub fn is_empty(&self) -> bool { + self.subtotals.is_empty() + } + + /// Returns the number of weights. + pub fn len(&self) -> usize { + self.subtotals.len() + } + + /// Returns `true` if we can sample. + /// + /// This is the case if the total weight of the tree is greater than zero. + pub fn can_sample(&self) -> bool { + self.subtotals.first().is_some_and(|x| *x > W::zero()) + } + + /// Gets the weight at an index. + pub fn get(&self, index: usize) -> W { + let left_index = 2 * index + 1; + let right_index = 2 * index + 2; + self.subtotals[index] - self.subtotal(left_index) - self.subtotal(right_index) + } + + /// Removes the last weight and returns it, or [`None`] if it is empty. + pub fn pop(&mut self) -> Option { + self.subtotals.pop().map(|weight| { + let mut index = self.len(); + while index != 0 { + index = (index - 1) / 2; + self.subtotals[index] -= weight; + } + weight + }) + } + + /// Appends a new weight at the end. + pub fn push(&mut self, weight: W) -> Result<(), WeightedError> { + if weight < W::zero() { + return Err(WeightedError::InvalidWeight); + } + let mut index = self.len(); + self.subtotals.push(weight); + while index != 0 { + index = (index - 1) / 2; + self.subtotals[index] += weight; + } + Ok(()) + } + + /// Updates the weight at an index. + pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightedError> { + if weight < W::zero() { + return Err(WeightedError::InvalidWeight); + } + let difference = weight - self.get(index); + if difference == W::zero() { + return Ok(()); + } + self.subtotals[index] += difference; + while index != 0 { + index = (index - 1) / 2; + self.subtotals[index] += difference; + } + Ok(()) + } + + fn subtotal(&self, index: usize) -> W { + if index < self.subtotals.len() { + self.subtotals[index] + } else { + W::zero() + } + } +} + +impl Distribution> for WeightedTreeIndex { + fn sample(&self, rng: &mut R) -> Result { + if self.subtotals.is_empty() { + return Err(WeightedError::NoItem); + } + let total_weight = self.subtotals[0]; + if total_weight == W::zero() { + return Err(WeightedError::AllWeightsZero); + } + let mut target_weight = rng.gen_range(W::zero()..total_weight); + let mut index = 0; + loop { + // Maybe descend into the left sub tree. + let left_index = 2 * index + 1; + let left_subtotal = self.subtotal(left_index); + if target_weight < left_subtotal { + index = left_index; + continue; + } + target_weight -= left_subtotal; + + // Maybe descend into the right sub tree. + let right_index = 2 * index + 2; + let right_subtotal = self.subtotal(right_index); + if target_weight < right_subtotal { + index = right_index; + continue; + } + target_weight -= right_subtotal; + + // Otherwise we found the index with the target weight. + break; + } + Ok(index) + } +} + +/// Trait that must be implemented for weights, that are used with +/// [`WeightedTreeIndex`]. Currently no guarantees on the correctness of +/// [`WeightedTreeIndex`] are given for custom implementations of this trait. +#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +pub trait Weight: + Sized + + Copy + + SampleUniform + + PartialOrd + + Add + + AddAssign + + Sub + + SubAssign + + Zero +{ +} + +impl Weight for T where + T: Sized + + Copy + + SampleUniform + + PartialOrd + + Add + + AddAssign + + Sub + + SubAssign + + Zero +{ +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_no_item_error() { + let mut rng = crate::test::rng(0x9c9fa0b0580a7031); + let mut tree = WeightedTreeIndex::::new(&[]).unwrap(); + assert_eq!(tree.sample(&mut rng).unwrap_err(), WeightedError::NoItem); + } + + #[test] + fn test_all_weights_zero_error() { + let tree = WeightedTreeIndex::::new(&[0.0, 0.0]).unwrap(); + let mut rng = crate::test::rng(0x9c9fa0b0580a7031); + assert_eq!( + tree.sample(&mut rng).unwrap_err(), + WeightedError::AllWeightsZero + ); + } + + #[test] + fn test_invalid_weight_error() { + assert_eq!( + WeightedTreeIndex::::new(&[1, -1]).unwrap_err(), + WeightedError::InvalidWeight + ); + let mut tree = WeightedTreeIndex::::new(&[]).unwrap(); + assert_eq!(tree.push(-1).unwrap_err(), WeightedError::InvalidWeight); + tree.push(1).unwrap(); + assert_eq!( + tree.update(0, -1).unwrap_err(), + WeightedError::InvalidWeight + ); + } + + #[test] + fn test_tree_modifications() { + let mut tree = WeightedTreeIndex::new(&[9, 1, 2]).unwrap(); + tree.push(3).unwrap(); + tree.push(5).unwrap(); + tree.update(0, 0).unwrap(); + assert_eq!(tree.pop(), Some(5)); + let expected = WeightedTreeIndex::new(&[0, 1, 2, 3]).unwrap(); + assert_eq!(tree, expected); + } + + #[test] + fn test_sample_counts_match_probabilities() { + let start = 1; + let end = 3; + let samples = 20; + let mut rng = crate::test::rng(0x9c9fa0b0580a7031); + let weights: Vec<_> = (0..end).map(|_| rng.gen()).collect(); + let mut tree = WeightedTreeIndex::new(&weights).unwrap(); + let mut total_weight = 0.0; + let mut weights = vec![0.0; end]; + for i in 0..end { + tree.update(i, i as f64).unwrap(); + weights[i] = i as f64; + total_weight += i as f64; + } + for i in 0..start { + tree.update(i, 0.0).unwrap(); + weights[i] = 0.0; + total_weight -= i as f64; + } + let mut counts = vec![0_usize; end]; + for _ in 0..samples { + let i = tree.sample(&mut rng).unwrap(); + counts[i] += 1; + } + for i in 0..start { + assert_eq!(counts[i], 0); + } + for i in start..end { + let diff = counts[i] as f64 / samples as f64 - weights[i] / total_weight; + assert!(diff.abs() < 0.05); + } + } +} diff --git a/src/distributions/weighted_index.rs b/src/distributions/weighted_index.rs index de3628b5ead..b759351fb86 100644 --- a/src/distributions/weighted_index.rs +++ b/src/distributions/weighted_index.rs @@ -18,7 +18,7 @@ use core::fmt; use alloc::vec::Vec; #[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; /// A distribution using weighted sampling of discrete items /// @@ -144,15 +144,17 @@ impl WeightedIndex { /// allocation internally. /// /// In case of error, `self` is not modified. - /// + /// /// Note: Updating floating-point weights may cause slight inaccuracies in the total weight. /// This method may not return `WeightedError::AllWeightsZero` when all weights - /// are zero if using floating-point weights. + /// are zero if using floating-point weights. pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError> - where X: for<'a> ::core::ops::AddAssign<&'a X> + where + X: for<'a> ::core::ops::AddAssign<&'a X> + for<'a> ::core::ops::SubAssign<&'a X> + Clone - + Default { + + Default, + { if new_weights.is_empty() { return Ok(()); } @@ -230,12 +232,14 @@ impl WeightedIndex { } impl Distribution for WeightedIndex -where X: SampleUniform + PartialOrd +where + X: SampleUniform + PartialOrd, { fn sample(&self, rng: &mut R) -> usize { let chosen_weight = self.weight_distribution.sample(rng); // Find the first item which has a weight *higher* than the chosen weight. - self.cumulative_weights.partition_point(|w| w <= &chosen_weight) + self.cumulative_weights + .partition_point(|w| w <= &chosen_weight) } } @@ -288,7 +292,7 @@ macro_rules! impl_weight_float { Ok(()) } } - } + }; } impl_weight_float!(f32); impl_weight_float!(f64); @@ -314,7 +318,7 @@ mod test { } #[test] - fn test_accepting_nan(){ + fn test_accepting_nan() { assert_eq!( WeightedIndex::new(&[core::f32::NAN, 0.5]).unwrap_err(), WeightedError::InvalidWeight, @@ -337,7 +341,6 @@ mod test { ) } - #[test] #[cfg_attr(miri, ignore)] // Miri is too slow fn test_weightedindex() { @@ -461,15 +464,21 @@ mod test { } let mut buf = [0; 10]; - test_samples(&[1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[ - 0, 6, 2, 6, 3, 4, 7, 8, 2, 5, - ]); - test_samples(&[0.7f32, 0.1, 0.1, 0.1], &mut buf, &[ - 0, 0, 0, 1, 0, 0, 2, 3, 0, 0, - ]); - test_samples(&[1.0f64, 0.999, 0.998, 0.997], &mut buf, &[ - 2, 2, 1, 3, 2, 1, 3, 3, 2, 1, - ]); + test_samples( + &[1i32, 1, 1, 1, 1, 1, 1, 1, 1], + &mut buf, + &[0, 6, 2, 6, 3, 4, 7, 8, 2, 5], + ); + test_samples( + &[0.7f32, 0.1, 0.1, 0.1], + &mut buf, + &[0, 0, 0, 1, 0, 0, 2, 3, 0, 0], + ); + test_samples( + &[1.0f64, 0.999, 0.998, 0.997], + &mut buf, + &[2, 2, 1, 3, 2, 1, 3, 3, 2, 1], + ); } #[test] @@ -479,11 +488,14 @@ mod test { #[test] fn overflow() { - assert_eq!(WeightedIndex::new([2, usize::MAX]), Err(WeightedError::Overflow)); + assert_eq!( + WeightedIndex::new([2, usize::MAX]), + Err(WeightedError::Overflow) + ); } } -/// Error type returned from `WeightedIndex::new`. +/// Error type returned from [`WeightedIndex`] operations. #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum WeightedError { From 191a1fedd42df1f24166059e956f6dee333c0b99 Mon Sep 17 00:00:00 2001 From: xmakro Date: Wed, 10 Jan 2024 20:16:32 +0800 Subject: [PATCH 02/20] revert --- src/distributions/weighted_index.rs | 684 ++++++++++------------------ 1 file changed, 242 insertions(+), 442 deletions(-) diff --git a/src/distributions/weighted_index.rs b/src/distributions/weighted_index.rs index b759351fb86..08d2b59f4c8 100644 --- a/src/distributions/weighted_index.rs +++ b/src/distributions/weighted_index.rs @@ -1,4 +1,4 @@ -// Copyright 2018 Developers of the Rand project. +// Copyright 2024 Developers of the Rand project. // // Licensed under the Apache License, Version 2.0 or the MIT license @@ -6,527 +6,327 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! Weighted index sampling +//! This module contains an implementation of a tree sttructure for sampling random +//! indices with probabilities proportional to a collection of weights. -use crate::distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler}; -use crate::distributions::Distribution; -use crate::Rng; -use core::cmp::PartialOrd; -use core::fmt; - -// Note that this whole module is only imported if feature="alloc" is enabled. -use alloc::vec::Vec; +use core::ops::{Add, AddAssign, Sub, SubAssign}; +use super::WeightedError; +use crate::Distribution; +use alloc::{vec, vec::Vec}; +use num_traits::Zero; +use rand::{distributions::uniform::SampleUniform, Rng}; #[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; -/// A distribution using weighted sampling of discrete items +/// A distribution using weighted sampling to pick a discretely selected item. /// -/// Sampling a `WeightedIndex` distribution returns the index of a randomly -/// selected element from the iterator used when the `WeightedIndex` was -/// created. The chance of a given element being picked is proportional to the -/// weight of the element. The weights can use any type `X` for which an -/// implementation of [`Uniform`] exists. The implementation guarantees that -/// elements with zero weight are never picked, even when the weights are -/// floating point numbers. +/// Sampling a [`WeightedTreeIndex`] distribution returns the index of a randomly +/// selected element from the vector used to create the [`WeightedTreeIndex`]. +/// The chance of a given element being picked is proportional to the value of +/// the element. The weights can have any type `W` for which a implementation of +/// [`Weight`] exists. /// -/// # Performance +/// # Key differences /// -/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where -/// `N` is the number of weights. As an alternative, -/// [`rand_distr::weighted_alias`](https://docs.rs/rand_distr/*/rand_distr/weighted_alias/index.html) -/// supports `O(1)` sampling, but with much higher initialisation cost. +/// The main distinction between [`WeightedTreeIndex`] and [`rand::distributions::WeightedIndex`] +/// lies in the internal representation of weights. In [`WeightedTreeIndex`], +/// weights are structured as a tree, which is optimized for frequent updates of the weights. /// -/// A `WeightedIndex` contains a `Vec` and a [`Uniform`] and so its -/// size is the sum of the size of those objects, possibly plus some alignment. +/// # Performance /// -/// Creating a `WeightedIndex` will allocate enough space to hold `N - 1` -/// weights of type `X`, where `N` is the number of weights. However, since -/// `Vec` doesn't guarantee a particular growth strategy, additional memory -/// might be allocated but not used. Since the `WeightedIndex` object also -/// contains an instance of `X::Sampler`, this might cause additional allocations, -/// though for primitive types, [`Uniform`] doesn't allocate any memory. +/// A [`WeightedTreeIndex`] with `n` elements requires `O(n)` memory. /// -/// Sampling from `WeightedIndex` will result in a single call to -/// `Uniform::sample` (method of the [`Distribution`] trait), which typically -/// will request a single value from the underlying [`RngCore`], though the -/// exact number depends on the implementation of `Uniform::sample`. +/// Time complexity for the operations of a [`WeightedTreeIndex`] are: +/// * Constructing: Building the initial tree from a slice of weights takes `O(n)` time. +/// * Sampling: Choosing an index (traversing down the tree) requires `O(log n)` time. +/// * Weight Update: Modifying a weight (traversing up the tree), requires `O(log n)` time. +/// * Weight Addition (Pushing): Adding a new weight (traversing up the tree), requires `O(log n)` time. +/// * Weight Removal (Popping): Removing a weight (traversing up the tree), requires `O(log n)` time. /// /// # Example /// /// ``` +/// use rand_distr::WeightedTreeIndex; /// use rand::prelude::*; -/// use rand::distributions::WeightedIndex; /// -/// let choices = ['a', 'b', 'c']; -/// let weights = [2, 1, 1]; -/// let dist = WeightedIndex::new(&weights).unwrap(); +/// let choices = vec!['a', 'b', 'c']; +/// let weights = vec![2, 1, 1]; +/// let dist = WeightedTreeIndex::new(&weights).unwrap(); /// let mut rng = thread_rng(); /// for _ in 0..100 { /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' -/// println!("{}", choices[dist.sample(&mut rng)]); -/// } -/// -/// let items = [('a', 0.0), ('b', 3.0), ('c', 7.0)]; -/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap(); -/// for _ in 0..100 { -/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' -/// println!("{}", items[dist2.sample(&mut rng)].0); +/// let i = dist.sample(&mut rng).unwrap(); +/// println!("{}", choices[i]); /// } /// ``` /// -/// [`Uniform`]: crate::distributions::Uniform -/// [`RngCore`]: crate::RngCore -#[derive(Debug, Clone, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +/// [`WeightedTreeIndex`]: WeightedTreeIndex +/// [`Uniform::sample`]: Distribution::sample #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub struct WeightedIndex { - cumulative_weights: Vec, - total_weight: X, - weight_distribution: X::Sampler, +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +#[cfg_attr( + feature = "serde1", + serde(bound(serialize = "W: Serialize, W::Sampler: Serialize")) +)] +#[cfg_attr( + feature = "serde1 ", + serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")) +)] +#[derive(Clone, Default, Debug, PartialEq)] +pub struct WeightedTreeIndex { + subtotals: Vec, } -impl WeightedIndex { - /// Creates a new a `WeightedIndex` [`Distribution`] using the values - /// in `weights`. The weights can use any type `X` for which an - /// implementation of [`Uniform`] exists. - /// - /// Returns an error if the iterator is empty, if any weight is `< 0`, or - /// if its total value is 0. - /// - /// [`Uniform`]: crate::distributions::uniform::Uniform - pub fn new(weights: I) -> Result, WeightedError> - where - I: IntoIterator, - I::Item: SampleBorrow, - X: Weight, - { - let mut iter = weights.into_iter(); - let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone(); - - let zero = X::ZERO; - if !(total_weight >= zero) { - return Err(WeightedError::InvalidWeight); - } - - let mut weights = Vec::::with_capacity(iter.size_hint().0); - for w in iter { - // Note that `!(w >= x)` is not equivalent to `w < x` for partially - // ordered types due to NaNs which are equal to nothing. - if !(w.borrow() >= &zero) { +impl WeightedTreeIndex { + /// Creates a new [`WeightedTreeIndex`] from a slice of weights. + pub fn new(weights: &[W]) -> Result { + for &weight in weights { + if weight < W::zero() { return Err(WeightedError::InvalidWeight); } - weights.push(total_weight.clone()); - - if let Err(()) = total_weight.checked_add_assign(w.borrow()) { - return Err(WeightedError::Overflow); - } } - - if total_weight == zero { - return Err(WeightedError::AllWeightsZero); + let n = weights.len(); + let mut subtotals = vec![W::zero(); n]; + for i in (0..n).rev() { + let left_index = 2 * i + 1; + let left_subtotal = if left_index < n { + subtotals[left_index] + } else { + W::zero() + }; + let right_index = 2 * i + 2; + let right_subtotal = if right_index < n { + subtotals[right_index] + } else { + W::zero() + }; + subtotals[i] = weights[i] + left_subtotal + right_subtotal; } - let distr = X::Sampler::new(zero, total_weight.clone()).unwrap(); - - Ok(WeightedIndex { - cumulative_weights: weights, - total_weight, - weight_distribution: distr, - }) + Ok(Self { subtotals }) } - /// Update a subset of weights, without changing the number of weights. - /// - /// `new_weights` must be sorted by the index. - /// - /// Using this method instead of `new` might be more efficient if only a small number of - /// weights is modified. No allocations are performed, unless the weight type `X` uses - /// allocation internally. - /// - /// In case of error, `self` is not modified. - /// - /// Note: Updating floating-point weights may cause slight inaccuracies in the total weight. - /// This method may not return `WeightedError::AllWeightsZero` when all weights - /// are zero if using floating-point weights. - pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError> - where - X: for<'a> ::core::ops::AddAssign<&'a X> - + for<'a> ::core::ops::SubAssign<&'a X> - + Clone - + Default, - { - if new_weights.is_empty() { - return Ok(()); - } + /// Returns `true` if the tree contains no weights. + pub fn is_empty(&self) -> bool { + self.subtotals.is_empty() + } - let zero = ::default(); + /// Returns the number of weights. + pub fn len(&self) -> usize { + self.subtotals.len() + } - let mut total_weight = self.total_weight.clone(); + /// Returns `true` if we can sample. + /// + /// This is the case if the total weight of the tree is greater than zero. + pub fn can_sample(&self) -> bool { + self.subtotals.first().is_some_and(|x| *x > W::zero()) + } - // Check for errors first, so we don't modify `self` in case something - // goes wrong. - let mut prev_i = None; - for &(i, w) in new_weights { - if let Some(old_i) = prev_i { - if old_i >= i { - return Err(WeightedError::InvalidWeight); - } - } - if !(*w >= zero) { - return Err(WeightedError::InvalidWeight); - } - if i > self.cumulative_weights.len() { - return Err(WeightedError::TooMany); - } + /// Gets the weight at an index. + pub fn get(&self, index: usize) -> W { + let left_index = 2 * index + 1; + let right_index = 2 * index + 2; + self.subtotals[index] - self.subtotal(left_index) - self.subtotal(right_index) + } - let mut old_w = if i < self.cumulative_weights.len() { - self.cumulative_weights[i].clone() - } else { - self.total_weight.clone() - }; - if i > 0 { - old_w -= &self.cumulative_weights[i - 1]; + /// Removes the last weight and returns it, or [`None`] if it is empty. + pub fn pop(&mut self) -> Option { + self.subtotals.pop().map(|weight| { + let mut index = self.len(); + while index != 0 { + index = (index - 1) / 2; + self.subtotals[index] -= weight; } + weight + }) + } - total_weight -= &old_w; - total_weight += w; - prev_i = Some(i); + /// Appends a new weight at the end. + pub fn push(&mut self, weight: W) -> Result<(), WeightedError> { + if weight < W::zero() { + return Err(WeightedError::InvalidWeight); } - if total_weight <= zero { - return Err(WeightedError::AllWeightsZero); + let mut index = self.len(); + self.subtotals.push(weight); + while index != 0 { + index = (index - 1) / 2; + self.subtotals[index] += weight; } + Ok(()) + } - // Update the weights. Because we checked all the preconditions in the - // previous loop, this should never panic. - let mut iter = new_weights.iter(); - - let mut prev_weight = zero.clone(); - let mut next_new_weight = iter.next(); - let &(first_new_index, _) = next_new_weight.unwrap(); - let mut cumulative_weight = if first_new_index > 0 { - self.cumulative_weights[first_new_index - 1].clone() - } else { - zero.clone() - }; - for i in first_new_index..self.cumulative_weights.len() { - match next_new_weight { - Some(&(j, w)) if i == j => { - cumulative_weight += w; - next_new_weight = iter.next(); - } - _ => { - let mut tmp = self.cumulative_weights[i].clone(); - tmp -= &prev_weight; // We know this is positive. - cumulative_weight += &tmp; - } - } - prev_weight = cumulative_weight.clone(); - core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]); + /// Updates the weight at an index. + pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightedError> { + if weight < W::zero() { + return Err(WeightedError::InvalidWeight); + } + let difference = weight - self.get(index); + if difference == W::zero() { + return Ok(()); + } + self.subtotals[index] += difference; + while index != 0 { + index = (index - 1) / 2; + self.subtotals[index] += difference; } - - self.total_weight = total_weight; - self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone()).unwrap(); - Ok(()) } -} -impl Distribution for WeightedIndex -where - X: SampleUniform + PartialOrd, -{ - fn sample(&self, rng: &mut R) -> usize { - let chosen_weight = self.weight_distribution.sample(rng); - // Find the first item which has a weight *higher* than the chosen weight. - self.cumulative_weights - .partition_point(|w| w <= &chosen_weight) + fn subtotal(&self, index: usize) -> W { + if index < self.subtotals.len() { + self.subtotals[index] + } else { + W::zero() + } } } -/// Bounds on a weight -/// -/// See usage in [`WeightedIndex`]. -pub trait Weight: Clone { - /// Representation of 0 - const ZERO: Self; - - /// Checked addition - /// - /// - `Result::Ok`: On success, `v` is added to `self` - /// - `Result::Err`: Returns an error when `Self` cannot represent the - /// result of `self + v` (i.e. overflow). The value of `self` should be - /// discarded. - fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()>; -} - -macro_rules! impl_weight_int { - ($t:ty) => { - impl Weight for $t { - const ZERO: Self = 0; - fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> { - match self.checked_add(*v) { - Some(sum) => { - *self = sum; - Ok(()) - } - None => Err(()), - } +impl Distribution> for WeightedTreeIndex { + fn sample(&self, rng: &mut R) -> Result { + if self.subtotals.is_empty() { + return Err(WeightedError::NoItem); + } + let total_weight = self.subtotals[0]; + if total_weight == W::zero() { + return Err(WeightedError::AllWeightsZero); + } + let mut target_weight = rng.gen_range(W::zero()..total_weight); + let mut index = 0; + loop { + // Maybe descend into the left sub tree. + let left_index = 2 * index + 1; + let left_subtotal = self.subtotal(left_index); + if target_weight < left_subtotal { + index = left_index; + continue; } + target_weight -= left_subtotal; + + // Maybe descend into the right sub tree. + let right_index = 2 * index + 2; + let right_subtotal = self.subtotal(right_index); + if target_weight < right_subtotal { + index = right_index; + continue; + } + target_weight -= right_subtotal; + + // Otherwise we found the index with the target weight. + break; } - }; - ($t:ty, $($tt:ty),*) => { - impl_weight_int!($t); - impl_weight_int!($($tt),*); + Ok(index) } } -impl_weight_int!(i8, i16, i32, i64, i128, isize); -impl_weight_int!(u8, u16, u32, u64, u128, usize); -macro_rules! impl_weight_float { - ($t:ty) => { - impl Weight for $t { - const ZERO: Self = 0.0; - fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> { - // Floats have an explicit representation for overflow - *self += *v; - Ok(()) - } - } - }; +/// Trait that must be implemented for weights, that are used with +/// [`WeightedTreeIndex`]. Currently no guarantees on the correctness of +/// [`WeightedTreeIndex`] are given for custom implementations of this trait. +#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +pub trait Weight: + Sized + + Copy + + SampleUniform + + PartialOrd + + Add + + AddAssign + + Sub + + SubAssign + + Zero +{ +} + +impl Weight for T where + T: Sized + + Copy + + SampleUniform + + PartialOrd + + Add + + AddAssign + + Sub + + SubAssign + + Zero +{ } -impl_weight_float!(f32); -impl_weight_float!(f64); #[cfg(test)] mod test { use super::*; - #[cfg(feature = "serde1")] #[test] - fn test_weightedindex_serde1() { - let weighted_index = WeightedIndex::new(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap(); - - let ser_weighted_index = bincode::serialize(&weighted_index).unwrap(); - let de_weighted_index: WeightedIndex = - bincode::deserialize(&ser_weighted_index).unwrap(); - - assert_eq!( - de_weighted_index.cumulative_weights, - weighted_index.cumulative_weights - ); - assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight); + fn test_no_item_error() { + let mut rng = crate::test::rng(0x9c9fa0b0580a7031); + let mut tree = WeightedTreeIndex::::new(&[]).unwrap(); + assert_eq!(tree.sample(&mut rng).unwrap_err(), WeightedError::NoItem); } #[test] - fn test_accepting_nan() { + fn test_all_weights_zero_error() { + let tree = WeightedTreeIndex::::new(&[0.0, 0.0]).unwrap(); + let mut rng = crate::test::rng(0x9c9fa0b0580a7031); assert_eq!( - WeightedIndex::new(&[core::f32::NAN, 0.5]).unwrap_err(), - WeightedError::InvalidWeight, - ); - assert_eq!( - WeightedIndex::new(&[core::f32::NAN]).unwrap_err(), - WeightedError::InvalidWeight, - ); - assert_eq!( - WeightedIndex::new(&[0.5, core::f32::NAN]).unwrap_err(), - WeightedError::InvalidWeight, + tree.sample(&mut rng).unwrap_err(), + WeightedError::AllWeightsZero ); - - assert_eq!( - WeightedIndex::new(&[0.5, 7.0]) - .unwrap() - .update_weights(&[(0, &core::f32::NAN)]) - .unwrap_err(), - WeightedError::InvalidWeight, - ) } #[test] - #[cfg_attr(miri, ignore)] // Miri is too slow - fn test_weightedindex() { - let mut r = crate::test::rng(700); - const N_REPS: u32 = 5000; - let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; - let total_weight = weights.iter().sum::() as f32; - - let verify = |result: [i32; 14]| { - for (i, count) in result.iter().enumerate() { - let exp = (weights[i] * N_REPS) as f32 / total_weight; - let mut err = (*count as f32 - exp).abs(); - if err != 0.0 { - err /= exp; - } - assert!(err <= 0.25); - } - }; - - // WeightedIndex from vec - let mut chosen = [0i32; 14]; - let distr = WeightedIndex::new(weights.to_vec()).unwrap(); - for _ in 0..N_REPS { - chosen[distr.sample(&mut r)] += 1; - } - verify(chosen); - - // WeightedIndex from slice - chosen = [0i32; 14]; - let distr = WeightedIndex::new(&weights[..]).unwrap(); - for _ in 0..N_REPS { - chosen[distr.sample(&mut r)] += 1; - } - verify(chosen); - - // WeightedIndex from iterator - chosen = [0i32; 14]; - let distr = WeightedIndex::new(weights.iter()).unwrap(); - for _ in 0..N_REPS { - chosen[distr.sample(&mut r)] += 1; - } - verify(chosen); - - for _ in 0..5 { - assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1); - assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0); - assert_eq!( - WeightedIndex::new(&[0, 0, 0, 0, 10, 0]) - .unwrap() - .sample(&mut r), - 4 - ); - } - - assert_eq!( - WeightedIndex::new(&[10][0..0]).unwrap_err(), - WeightedError::NoItem - ); + fn test_invalid_weight_error() { assert_eq!( - WeightedIndex::new(&[0]).unwrap_err(), - WeightedError::AllWeightsZero - ); - assert_eq!( - WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), + WeightedTreeIndex::::new(&[1, -1]).unwrap_err(), WeightedError::InvalidWeight ); + let mut tree = WeightedTreeIndex::::new(&[]).unwrap(); + assert_eq!(tree.push(-1).unwrap_err(), WeightedError::InvalidWeight); + tree.push(1).unwrap(); assert_eq!( - WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), - WeightedError::InvalidWeight - ); - assert_eq!( - WeightedIndex::new(&[-10]).unwrap_err(), + tree.update(0, -1).unwrap_err(), WeightedError::InvalidWeight ); } #[test] - fn test_update_weights() { - let data = [ - ( - &[10u32, 2, 3, 4][..], - &[(1, &100), (2, &4)][..], // positive change - &[10, 100, 4, 4][..], - ), - ( - &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], - &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element - &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..], - ), - ]; - - for (weights, update, expected_weights) in data.iter() { - let total_weight = weights.iter().sum::(); - let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); - assert_eq!(distr.total_weight, total_weight); - - distr.update_weights(update).unwrap(); - let expected_total_weight = expected_weights.iter().sum::(); - let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap(); - assert_eq!(distr.total_weight, expected_total_weight); - assert_eq!(distr.total_weight, expected_distr.total_weight); - assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights); - } + fn test_tree_modifications() { + let mut tree = WeightedTreeIndex::new(&[9, 1, 2]).unwrap(); + tree.push(3).unwrap(); + tree.push(5).unwrap(); + tree.update(0, 0).unwrap(); + assert_eq!(tree.pop(), Some(5)); + let expected = WeightedTreeIndex::new(&[0, 1, 2, 3]).unwrap(); + assert_eq!(tree, expected); } #[test] - fn value_stability() { - fn test_samples( - weights: I, buf: &mut [usize], expected: &[usize], - ) where - I: IntoIterator, - I::Item: SampleBorrow, - { - assert_eq!(buf.len(), expected.len()); - let distr = WeightedIndex::new(weights).unwrap(); - let mut rng = crate::test::rng(701); - for r in buf.iter_mut() { - *r = rng.sample(&distr); - } - assert_eq!(buf, expected); + fn test_sample_counts_match_probabilities() { + let start = 1; + let end = 3; + let samples = 20; + let mut rng = crate::test::rng(0x9c9fa0b0580a7031); + let weights: Vec<_> = (0..end).map(|_| rng.gen()).collect(); + let mut tree = WeightedTreeIndex::new(&weights).unwrap(); + let mut total_weight = 0.0; + let mut weights = vec![0.0; end]; + for i in 0..end { + tree.update(i, i as f64).unwrap(); + weights[i] = i as f64; + total_weight += i as f64; + } + for i in 0..start { + tree.update(i, 0.0).unwrap(); + weights[i] = 0.0; + total_weight -= i as f64; + } + let mut counts = vec![0_usize; end]; + for _ in 0..samples { + let i = tree.sample(&mut rng).unwrap(); + counts[i] += 1; + } + for i in 0..start { + assert_eq!(counts[i], 0); + } + for i in start..end { + let diff = counts[i] as f64 / samples as f64 - weights[i] / total_weight; + assert!(diff.abs() < 0.05); } - - let mut buf = [0; 10]; - test_samples( - &[1i32, 1, 1, 1, 1, 1, 1, 1, 1], - &mut buf, - &[0, 6, 2, 6, 3, 4, 7, 8, 2, 5], - ); - test_samples( - &[0.7f32, 0.1, 0.1, 0.1], - &mut buf, - &[0, 0, 0, 1, 0, 0, 2, 3, 0, 0], - ); - test_samples( - &[1.0f64, 0.999, 0.998, 0.997], - &mut buf, - &[2, 2, 1, 3, 2, 1, 3, 3, 2, 1], - ); - } - - #[test] - fn weighted_index_distributions_can_be_compared() { - assert_eq!(WeightedIndex::new(&[1, 2]), WeightedIndex::new(&[1, 2])); - } - - #[test] - fn overflow() { - assert_eq!( - WeightedIndex::new([2, usize::MAX]), - Err(WeightedError::Overflow) - ); - } -} - -/// Error type returned from [`WeightedIndex`] operations. -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum WeightedError { - /// The provided weight collection contains no items. - NoItem, - - /// A weight is either less than zero, greater than the supported maximum, - /// NaN, or otherwise invalid. - InvalidWeight, - - /// All items in the provided weight collection are zero. - AllWeightsZero, - - /// Too many weights are provided (length greater than `u32::MAX`) - TooMany, - - /// The sum of weights overflows - Overflow, -} - -#[cfg(feature = "std")] -impl std::error::Error for WeightedError {} - -impl fmt::Display for WeightedError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.write_str(match *self { - WeightedError::NoItem => "No weights provided in distribution", - WeightedError::InvalidWeight => "A weight is invalid in distribution", - WeightedError::AllWeightsZero => "All weights are zero in distribution", - WeightedError::TooMany => "Too many weights (hit u32::MAX) in distribution", - WeightedError::Overflow => "The sum of weights overflowed", - }) } } From 3672f2243092ea426fe019158aa452e34779c86e Mon Sep 17 00:00:00 2001 From: xmakro Date: Wed, 10 Jan 2024 20:17:42 +0800 Subject: [PATCH 03/20] revert2 --- src/distributions/weighted_index.rs | 674 ++++++++++++++++++---------- 1 file changed, 431 insertions(+), 243 deletions(-) diff --git a/src/distributions/weighted_index.rs b/src/distributions/weighted_index.rs index 08d2b59f4c8..e8c30b5213f 100644 --- a/src/distributions/weighted_index.rs +++ b/src/distributions/weighted_index.rs @@ -1,4 +1,4 @@ -// Copyright 2024 Developers of the Rand project. +// Copyright 2018 Developers of the Rand project. // // Licensed under the Apache License, Version 2.0 or the MIT license @@ -6,327 +6,515 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! This module contains an implementation of a tree sttructure for sampling random -//! indices with probabilities proportional to a collection of weights. +//! Weighted index sampling -use core::ops::{Add, AddAssign, Sub, SubAssign}; +use crate::distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler}; +use crate::distributions::Distribution; +use crate::Rng; +use core::cmp::PartialOrd; +use core::fmt; + +// Note that this whole module is only imported if feature="alloc" is enabled. +use alloc::vec::Vec; -use super::WeightedError; -use crate::Distribution; -use alloc::{vec, vec::Vec}; -use num_traits::Zero; -use rand::{distributions::uniform::SampleUniform, Rng}; #[cfg(feature = "serde1")] -use serde::{Deserialize, Serialize}; +use serde::{Serialize, Deserialize}; -/// A distribution using weighted sampling to pick a discretely selected item. +/// A distribution using weighted sampling of discrete items /// -/// Sampling a [`WeightedTreeIndex`] distribution returns the index of a randomly -/// selected element from the vector used to create the [`WeightedTreeIndex`]. -/// The chance of a given element being picked is proportional to the value of -/// the element. The weights can have any type `W` for which a implementation of -/// [`Weight`] exists. +/// Sampling a `WeightedIndex` distribution returns the index of a randomly +/// selected element from the iterator used when the `WeightedIndex` was +/// created. The chance of a given element being picked is proportional to the +/// weight of the element. The weights can use any type `X` for which an +/// implementation of [`Uniform`] exists. The implementation guarantees that +/// elements with zero weight are never picked, even when the weights are +/// floating point numbers. /// -/// # Key differences +/// # Performance /// -/// The main distinction between [`WeightedTreeIndex`] and [`rand::distributions::WeightedIndex`] -/// lies in the internal representation of weights. In [`WeightedTreeIndex`], -/// weights are structured as a tree, which is optimized for frequent updates of the weights. +/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where +/// `N` is the number of weights. As an alternative, +/// [`rand_distr::weighted_alias`](https://docs.rs/rand_distr/*/rand_distr/weighted_alias/index.html) +/// supports `O(1)` sampling, but with much higher initialisation cost. /// -/// # Performance +/// A `WeightedIndex` contains a `Vec` and a [`Uniform`] and so its +/// size is the sum of the size of those objects, possibly plus some alignment. /// -/// A [`WeightedTreeIndex`] with `n` elements requires `O(n)` memory. +/// Creating a `WeightedIndex` will allocate enough space to hold `N - 1` +/// weights of type `X`, where `N` is the number of weights. However, since +/// `Vec` doesn't guarantee a particular growth strategy, additional memory +/// might be allocated but not used. Since the `WeightedIndex` object also +/// contains an instance of `X::Sampler`, this might cause additional allocations, +/// though for primitive types, [`Uniform`] doesn't allocate any memory. /// -/// Time complexity for the operations of a [`WeightedTreeIndex`] are: -/// * Constructing: Building the initial tree from a slice of weights takes `O(n)` time. -/// * Sampling: Choosing an index (traversing down the tree) requires `O(log n)` time. -/// * Weight Update: Modifying a weight (traversing up the tree), requires `O(log n)` time. -/// * Weight Addition (Pushing): Adding a new weight (traversing up the tree), requires `O(log n)` time. -/// * Weight Removal (Popping): Removing a weight (traversing up the tree), requires `O(log n)` time. +/// Sampling from `WeightedIndex` will result in a single call to +/// `Uniform::sample` (method of the [`Distribution`] trait), which typically +/// will request a single value from the underlying [`RngCore`], though the +/// exact number depends on the implementation of `Uniform::sample`. /// /// # Example /// /// ``` -/// use rand_distr::WeightedTreeIndex; /// use rand::prelude::*; +/// use rand::distributions::WeightedIndex; /// -/// let choices = vec!['a', 'b', 'c']; -/// let weights = vec![2, 1, 1]; -/// let dist = WeightedTreeIndex::new(&weights).unwrap(); +/// let choices = ['a', 'b', 'c']; +/// let weights = [2, 1, 1]; +/// let dist = WeightedIndex::new(&weights).unwrap(); /// let mut rng = thread_rng(); /// for _ in 0..100 { /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' -/// let i = dist.sample(&mut rng).unwrap(); -/// println!("{}", choices[i]); +/// println!("{}", choices[dist.sample(&mut rng)]); +/// } +/// +/// let items = [('a', 0.0), ('b', 3.0), ('c', 7.0)]; +/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap(); +/// for _ in 0..100 { +/// // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c' +/// println!("{}", items[dist2.sample(&mut rng)].0); /// } /// ``` /// -/// [`WeightedTreeIndex`]: WeightedTreeIndex -/// [`Uniform::sample`]: Distribution::sample -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +/// [`Uniform`]: crate::distributions::Uniform +/// [`RngCore`]: crate::RngCore +#[derive(Debug, Clone, PartialEq)] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -#[cfg_attr( - feature = "serde1", - serde(bound(serialize = "W: Serialize, W::Sampler: Serialize")) -)] -#[cfg_attr( - feature = "serde1 ", - serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")) -)] -#[derive(Clone, Default, Debug, PartialEq)] -pub struct WeightedTreeIndex { - subtotals: Vec, +#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +pub struct WeightedIndex { + cumulative_weights: Vec, + total_weight: X, + weight_distribution: X::Sampler, } -impl WeightedTreeIndex { - /// Creates a new [`WeightedTreeIndex`] from a slice of weights. - pub fn new(weights: &[W]) -> Result { - for &weight in weights { - if weight < W::zero() { +impl WeightedIndex { + /// Creates a new a `WeightedIndex` [`Distribution`] using the values + /// in `weights`. The weights can use any type `X` for which an + /// implementation of [`Uniform`] exists. + /// + /// Returns an error if the iterator is empty, if any weight is `< 0`, or + /// if its total value is 0. + /// + /// [`Uniform`]: crate::distributions::uniform::Uniform + pub fn new(weights: I) -> Result, WeightedError> + where + I: IntoIterator, + I::Item: SampleBorrow, + X: Weight, + { + let mut iter = weights.into_iter(); + let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone(); + + let zero = X::ZERO; + if !(total_weight >= zero) { + return Err(WeightedError::InvalidWeight); + } + + let mut weights = Vec::::with_capacity(iter.size_hint().0); + for w in iter { + // Note that `!(w >= x)` is not equivalent to `w < x` for partially + // ordered types due to NaNs which are equal to nothing. + if !(w.borrow() >= &zero) { return Err(WeightedError::InvalidWeight); } + weights.push(total_weight.clone()); + + if let Err(()) = total_weight.checked_add_assign(w.borrow()) { + return Err(WeightedError::Overflow); + } } - let n = weights.len(); - let mut subtotals = vec![W::zero(); n]; - for i in (0..n).rev() { - let left_index = 2 * i + 1; - let left_subtotal = if left_index < n { - subtotals[left_index] - } else { - W::zero() - }; - let right_index = 2 * i + 2; - let right_subtotal = if right_index < n { - subtotals[right_index] - } else { - W::zero() - }; - subtotals[i] = weights[i] + left_subtotal + right_subtotal; - } - Ok(Self { subtotals }) - } - /// Returns `true` if the tree contains no weights. - pub fn is_empty(&self) -> bool { - self.subtotals.is_empty() - } + if total_weight == zero { + return Err(WeightedError::AllWeightsZero); + } + let distr = X::Sampler::new(zero, total_weight.clone()).unwrap(); - /// Returns the number of weights. - pub fn len(&self) -> usize { - self.subtotals.len() + Ok(WeightedIndex { + cumulative_weights: weights, + total_weight, + weight_distribution: distr, + }) } - /// Returns `true` if we can sample. + /// Update a subset of weights, without changing the number of weights. /// - /// This is the case if the total weight of the tree is greater than zero. - pub fn can_sample(&self) -> bool { - self.subtotals.first().is_some_and(|x| *x > W::zero()) - } + /// `new_weights` must be sorted by the index. + /// + /// Using this method instead of `new` might be more efficient if only a small number of + /// weights is modified. No allocations are performed, unless the weight type `X` uses + /// allocation internally. + /// + /// In case of error, `self` is not modified. + /// + /// Note: Updating floating-point weights may cause slight inaccuracies in the total weight. + /// This method may not return `WeightedError::AllWeightsZero` when all weights + /// are zero if using floating-point weights. + pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError> + where X: for<'a> ::core::ops::AddAssign<&'a X> + + for<'a> ::core::ops::SubAssign<&'a X> + + Clone + + Default { + if new_weights.is_empty() { + return Ok(()); + } - /// Gets the weight at an index. - pub fn get(&self, index: usize) -> W { - let left_index = 2 * index + 1; - let right_index = 2 * index + 2; - self.subtotals[index] - self.subtotal(left_index) - self.subtotal(right_index) - } + let zero = ::default(); - /// Removes the last weight and returns it, or [`None`] if it is empty. - pub fn pop(&mut self) -> Option { - self.subtotals.pop().map(|weight| { - let mut index = self.len(); - while index != 0 { - index = (index - 1) / 2; - self.subtotals[index] -= weight; + let mut total_weight = self.total_weight.clone(); + + // Check for errors first, so we don't modify `self` in case something + // goes wrong. + let mut prev_i = None; + for &(i, w) in new_weights { + if let Some(old_i) = prev_i { + if old_i >= i { + return Err(WeightedError::InvalidWeight); + } + } + if !(*w >= zero) { + return Err(WeightedError::InvalidWeight); + } + if i > self.cumulative_weights.len() { + return Err(WeightedError::TooMany); } - weight - }) - } - /// Appends a new weight at the end. - pub fn push(&mut self, weight: W) -> Result<(), WeightedError> { - if weight < W::zero() { - return Err(WeightedError::InvalidWeight); - } - let mut index = self.len(); - self.subtotals.push(weight); - while index != 0 { - index = (index - 1) / 2; - self.subtotals[index] += weight; - } - Ok(()) - } + let mut old_w = if i < self.cumulative_weights.len() { + self.cumulative_weights[i].clone() + } else { + self.total_weight.clone() + }; + if i > 0 { + old_w -= &self.cumulative_weights[i - 1]; + } - /// Updates the weight at an index. - pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightedError> { - if weight < W::zero() { - return Err(WeightedError::InvalidWeight); + total_weight -= &old_w; + total_weight += w; + prev_i = Some(i); } - let difference = weight - self.get(index); - if difference == W::zero() { - return Ok(()); + if total_weight <= zero { + return Err(WeightedError::AllWeightsZero); } - self.subtotals[index] += difference; - while index != 0 { - index = (index - 1) / 2; - self.subtotals[index] += difference; + + // Update the weights. Because we checked all the preconditions in the + // previous loop, this should never panic. + let mut iter = new_weights.iter(); + + let mut prev_weight = zero.clone(); + let mut next_new_weight = iter.next(); + let &(first_new_index, _) = next_new_weight.unwrap(); + let mut cumulative_weight = if first_new_index > 0 { + self.cumulative_weights[first_new_index - 1].clone() + } else { + zero.clone() + }; + for i in first_new_index..self.cumulative_weights.len() { + match next_new_weight { + Some(&(j, w)) if i == j => { + cumulative_weight += w; + next_new_weight = iter.next(); + } + _ => { + let mut tmp = self.cumulative_weights[i].clone(); + tmp -= &prev_weight; // We know this is positive. + cumulative_weight += &tmp; + } + } + prev_weight = cumulative_weight.clone(); + core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]); } + + self.total_weight = total_weight; + self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone()).unwrap(); + Ok(()) } +} - fn subtotal(&self, index: usize) -> W { - if index < self.subtotals.len() { - self.subtotals[index] - } else { - W::zero() - } +impl Distribution for WeightedIndex +where X: SampleUniform + PartialOrd +{ + fn sample(&self, rng: &mut R) -> usize { + let chosen_weight = self.weight_distribution.sample(rng); + // Find the first item which has a weight *higher* than the chosen weight. + self.cumulative_weights.partition_point(|w| w <= &chosen_weight) } } -impl Distribution> for WeightedTreeIndex { - fn sample(&self, rng: &mut R) -> Result { - if self.subtotals.is_empty() { - return Err(WeightedError::NoItem); - } - let total_weight = self.subtotals[0]; - if total_weight == W::zero() { - return Err(WeightedError::AllWeightsZero); - } - let mut target_weight = rng.gen_range(W::zero()..total_weight); - let mut index = 0; - loop { - // Maybe descend into the left sub tree. - let left_index = 2 * index + 1; - let left_subtotal = self.subtotal(left_index); - if target_weight < left_subtotal { - index = left_index; - continue; - } - target_weight -= left_subtotal; - - // Maybe descend into the right sub tree. - let right_index = 2 * index + 2; - let right_subtotal = self.subtotal(right_index); - if target_weight < right_subtotal { - index = right_index; - continue; - } - target_weight -= right_subtotal; +/// Bounds on a weight +/// +/// See usage in [`WeightedIndex`]. +pub trait Weight: Clone { + /// Representation of 0 + const ZERO: Self; - // Otherwise we found the index with the target weight. - break; - } - Ok(index) - } + /// Checked addition + /// + /// - `Result::Ok`: On success, `v` is added to `self` + /// - `Result::Err`: Returns an error when `Self` cannot represent the + /// result of `self + v` (i.e. overflow). The value of `self` should be + /// discarded. + fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()>; } -/// Trait that must be implemented for weights, that are used with -/// [`WeightedTreeIndex`]. Currently no guarantees on the correctness of -/// [`WeightedTreeIndex`] are given for custom implementations of this trait. -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub trait Weight: - Sized - + Copy - + SampleUniform - + PartialOrd - + Add - + AddAssign - + Sub - + SubAssign - + Zero -{ +macro_rules! impl_weight_int { + ($t:ty) => { + impl Weight for $t { + const ZERO: Self = 0; + fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> { + match self.checked_add(*v) { + Some(sum) => { + *self = sum; + Ok(()) + } + None => Err(()), + } + } + } + }; + ($t:ty, $($tt:ty),*) => { + impl_weight_int!($t); + impl_weight_int!($($tt),*); + } } +impl_weight_int!(i8, i16, i32, i64, i128, isize); +impl_weight_int!(u8, u16, u32, u64, u128, usize); -impl Weight for T where - T: Sized - + Copy - + SampleUniform - + PartialOrd - + Add - + AddAssign - + Sub - + SubAssign - + Zero -{ +macro_rules! impl_weight_float { + ($t:ty) => { + impl Weight for $t { + const ZERO: Self = 0.0; + fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> { + // Floats have an explicit representation for overflow + *self += *v; + Ok(()) + } + } + } } +impl_weight_float!(f32); +impl_weight_float!(f64); #[cfg(test)] mod test { use super::*; + #[cfg(feature = "serde1")] #[test] - fn test_no_item_error() { - let mut rng = crate::test::rng(0x9c9fa0b0580a7031); - let mut tree = WeightedTreeIndex::::new(&[]).unwrap(); - assert_eq!(tree.sample(&mut rng).unwrap_err(), WeightedError::NoItem); + fn test_weightedindex_serde1() { + let weighted_index = WeightedIndex::new(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap(); + + let ser_weighted_index = bincode::serialize(&weighted_index).unwrap(); + let de_weighted_index: WeightedIndex = + bincode::deserialize(&ser_weighted_index).unwrap(); + + assert_eq!( + de_weighted_index.cumulative_weights, + weighted_index.cumulative_weights + ); + assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight); } #[test] - fn test_all_weights_zero_error() { - let tree = WeightedTreeIndex::::new(&[0.0, 0.0]).unwrap(); - let mut rng = crate::test::rng(0x9c9fa0b0580a7031); + fn test_accepting_nan(){ assert_eq!( - tree.sample(&mut rng).unwrap_err(), - WeightedError::AllWeightsZero + WeightedIndex::new(&[core::f32::NAN, 0.5]).unwrap_err(), + WeightedError::InvalidWeight, + ); + assert_eq!( + WeightedIndex::new(&[core::f32::NAN]).unwrap_err(), + WeightedError::InvalidWeight, ); + assert_eq!( + WeightedIndex::new(&[0.5, core::f32::NAN]).unwrap_err(), + WeightedError::InvalidWeight, + ); + + assert_eq!( + WeightedIndex::new(&[0.5, 7.0]) + .unwrap() + .update_weights(&[(0, &core::f32::NAN)]) + .unwrap_err(), + WeightedError::InvalidWeight, + ) } + #[test] - fn test_invalid_weight_error() { + #[cfg_attr(miri, ignore)] // Miri is too slow + fn test_weightedindex() { + let mut r = crate::test::rng(700); + const N_REPS: u32 = 5000; + let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7]; + let total_weight = weights.iter().sum::() as f32; + + let verify = |result: [i32; 14]| { + for (i, count) in result.iter().enumerate() { + let exp = (weights[i] * N_REPS) as f32 / total_weight; + let mut err = (*count as f32 - exp).abs(); + if err != 0.0 { + err /= exp; + } + assert!(err <= 0.25); + } + }; + + // WeightedIndex from vec + let mut chosen = [0i32; 14]; + let distr = WeightedIndex::new(weights.to_vec()).unwrap(); + for _ in 0..N_REPS { + chosen[distr.sample(&mut r)] += 1; + } + verify(chosen); + + // WeightedIndex from slice + chosen = [0i32; 14]; + let distr = WeightedIndex::new(&weights[..]).unwrap(); + for _ in 0..N_REPS { + chosen[distr.sample(&mut r)] += 1; + } + verify(chosen); + + // WeightedIndex from iterator + chosen = [0i32; 14]; + let distr = WeightedIndex::new(weights.iter()).unwrap(); + for _ in 0..N_REPS { + chosen[distr.sample(&mut r)] += 1; + } + verify(chosen); + + for _ in 0..5 { + assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1); + assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0); + assert_eq!( + WeightedIndex::new(&[0, 0, 0, 0, 10, 0]) + .unwrap() + .sample(&mut r), + 4 + ); + } + assert_eq!( - WeightedTreeIndex::::new(&[1, -1]).unwrap_err(), + WeightedIndex::new(&[10][0..0]).unwrap_err(), + WeightedError::NoItem + ); + assert_eq!( + WeightedIndex::new(&[0]).unwrap_err(), + WeightedError::AllWeightsZero + ); + assert_eq!( + WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), WeightedError::InvalidWeight ); - let mut tree = WeightedTreeIndex::::new(&[]).unwrap(); - assert_eq!(tree.push(-1).unwrap_err(), WeightedError::InvalidWeight); - tree.push(1).unwrap(); assert_eq!( - tree.update(0, -1).unwrap_err(), + WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), + WeightedError::InvalidWeight + ); + assert_eq!( + WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::InvalidWeight ); } #[test] - fn test_tree_modifications() { - let mut tree = WeightedTreeIndex::new(&[9, 1, 2]).unwrap(); - tree.push(3).unwrap(); - tree.push(5).unwrap(); - tree.update(0, 0).unwrap(); - assert_eq!(tree.pop(), Some(5)); - let expected = WeightedTreeIndex::new(&[0, 1, 2, 3]).unwrap(); - assert_eq!(tree, expected); + fn test_update_weights() { + let data = [ + ( + &[10u32, 2, 3, 4][..], + &[(1, &100), (2, &4)][..], // positive change + &[10, 100, 4, 4][..], + ), + ( + &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..], + &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element + &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..], + ), + ]; + + for (weights, update, expected_weights) in data.iter() { + let total_weight = weights.iter().sum::(); + let mut distr = WeightedIndex::new(weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, total_weight); + + distr.update_weights(update).unwrap(); + let expected_total_weight = expected_weights.iter().sum::(); + let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap(); + assert_eq!(distr.total_weight, expected_total_weight); + assert_eq!(distr.total_weight, expected_distr.total_weight); + assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights); + } } #[test] - fn test_sample_counts_match_probabilities() { - let start = 1; - let end = 3; - let samples = 20; - let mut rng = crate::test::rng(0x9c9fa0b0580a7031); - let weights: Vec<_> = (0..end).map(|_| rng.gen()).collect(); - let mut tree = WeightedTreeIndex::new(&weights).unwrap(); - let mut total_weight = 0.0; - let mut weights = vec![0.0; end]; - for i in 0..end { - tree.update(i, i as f64).unwrap(); - weights[i] = i as f64; - total_weight += i as f64; - } - for i in 0..start { - tree.update(i, 0.0).unwrap(); - weights[i] = 0.0; - total_weight -= i as f64; - } - let mut counts = vec![0_usize; end]; - for _ in 0..samples { - let i = tree.sample(&mut rng).unwrap(); - counts[i] += 1; - } - for i in 0..start { - assert_eq!(counts[i], 0); - } - for i in start..end { - let diff = counts[i] as f64 / samples as f64 - weights[i] / total_weight; - assert!(diff.abs() < 0.05); + fn value_stability() { + fn test_samples( + weights: I, buf: &mut [usize], expected: &[usize], + ) where + I: IntoIterator, + I::Item: SampleBorrow, + { + assert_eq!(buf.len(), expected.len()); + let distr = WeightedIndex::new(weights).unwrap(); + let mut rng = crate::test::rng(701); + for r in buf.iter_mut() { + *r = rng.sample(&distr); + } + assert_eq!(buf, expected); } + + let mut buf = [0; 10]; + test_samples(&[1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[ + 0, 6, 2, 6, 3, 4, 7, 8, 2, 5, + ]); + test_samples(&[0.7f32, 0.1, 0.1, 0.1], &mut buf, &[ + 0, 0, 0, 1, 0, 0, 2, 3, 0, 0, + ]); + test_samples(&[1.0f64, 0.999, 0.998, 0.997], &mut buf, &[ + 2, 2, 1, 3, 2, 1, 3, 3, 2, 1, + ]); } + + #[test] + fn weighted_index_distributions_can_be_compared() { + assert_eq!(WeightedIndex::new(&[1, 2]), WeightedIndex::new(&[1, 2])); + } + + #[test] + fn overflow() { + assert_eq!(WeightedIndex::new([2, usize::MAX]), Err(WeightedError::Overflow)); + } +} + +/// Error type returned from `WeightedIndex::new`. +#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum WeightedError { + /// The provided weight collection contains no items. + NoItem, + + /// A weight is either less than zero, greater than the supported maximum, + /// NaN, or otherwise invalid. + InvalidWeight, + + /// All items in the provided weight collection are zero. + AllWeightsZero, + + /// Too many weights are provided (length greater than `u32::MAX`) + TooMany, + + /// The sum of weights overflows + Overflow, } + +#[cfg(feature = "std")] +impl std::error::Error for WeightedError {} + +impl fmt::Display for WeightedError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.write_str(match *self { + WeightedError::NoItem => "No weights provided in distribution", + WeightedError::InvalidWeight => "A weight is invalid in distribution", + WeightedError::AllWeightsZero => "All weights are zero in distribution", + WeightedError::TooMany => "Too many weights (hit u32::MAX) in distribution", + WeightedError::Overflow => "The sum of weights overflowed", + }) + } +} \ No newline at end of file From b20ac94b03c92f8cd336a6d4cfbdecf3e343d783 Mon Sep 17 00:00:00 2001 From: xmakro Date: Wed, 10 Jan 2024 20:18:08 +0800 Subject: [PATCH 04/20] revert3 --- src/distributions/weighted_index.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distributions/weighted_index.rs b/src/distributions/weighted_index.rs index e8c30b5213f..de3628b5ead 100644 --- a/src/distributions/weighted_index.rs +++ b/src/distributions/weighted_index.rs @@ -517,4 +517,4 @@ impl fmt::Display for WeightedError { WeightedError::Overflow => "The sum of weights overflowed", }) } -} \ No newline at end of file +} From f90661e517e4290feb2ac7fdc64c25275a09c82a Mon Sep 17 00:00:00 2001 From: xmakro Date: Wed, 10 Jan 2024 20:18:44 +0800 Subject: [PATCH 05/20] d --- rand_distr/src/lib.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index ed6d1bc75f5..54b4fb93bf0 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -175,14 +175,10 @@ mod test { macro_rules! assert_almost_eq { ($a:expr, $b:expr, $prec:expr) => { let diff = ($a - $b).abs(); - assert!( - diff <= $prec, + assert!(diff <= $prec, "assertion failed: `abs(left - right) = {:.1e} < {:e}`, \ (left: `{}`, right: `{}`)", - diff, - $prec, - $a, - $b + diff, $prec, $a, $b ); }; } From 8188aa9895ee4ef17384f822f9a1b86ea20d9f44 Mon Sep 17 00:00:00 2001 From: xmakro Date: Wed, 10 Jan 2024 20:30:31 +0800 Subject: [PATCH 06/20] a --- rand_distr/src/weighted_tree.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/rand_distr/src/weighted_tree.rs b/rand_distr/src/weighted_tree.rs index 08d2b59f4c8..bfa24681af7 100644 --- a/rand_distr/src/weighted_tree.rs +++ b/rand_distr/src/weighted_tree.rs @@ -32,7 +32,17 @@ use serde::{Deserialize, Serialize}; /// The main distinction between [`WeightedTreeIndex`] and [`rand::distributions::WeightedIndex`] /// lies in the internal representation of weights. In [`WeightedTreeIndex`], /// weights are structured as a tree, which is optimized for frequent updates of the weights. -/// +/// +/// # Caution: Floating point types +/// +/// When utilizing [`WeightedTreeIndex`] with floating point types (such as f32 or f64), +/// exercise caution due to the inherent nature of floating point arithmetic. Floating point types +/// are susceptible to numerical rounding errors. Since operations on floating point weights are +/// repeated numerous times, rounding errors can accumulate, potentially leading to noticeable +/// deviations from the expected behavior. +/// +/// Ideally, use fixed point or integer types whenever possible. +/// /// # Performance /// /// A [`WeightedTreeIndex`] with `n` elements requires `O(n)` memory. From 42e21231f659f3824563197ce8c6b7b51a7148fe Mon Sep 17 00:00:00 2001 From: xmakro Date: Wed, 10 Jan 2024 20:32:11 +0800 Subject: [PATCH 07/20] a --- rand_distr/src/weighted_tree.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/rand_distr/src/weighted_tree.rs b/rand_distr/src/weighted_tree.rs index bfa24681af7..fd556b89363 100644 --- a/rand_distr/src/weighted_tree.rs +++ b/rand_distr/src/weighted_tree.rs @@ -61,8 +61,10 @@ use serde::{Deserialize, Serialize}; /// use rand::prelude::*; /// /// let choices = vec!['a', 'b', 'c']; -/// let weights = vec![2, 1, 1]; -/// let dist = WeightedTreeIndex::new(&weights).unwrap(); +/// let weights = vec![2, 0]; +/// let mut dist = WeightedTreeIndex::new(&weights).unwrap(); +/// dist.push(1).unwrap(); +/// dist.update(1, 1).unwrap(); /// let mut rng = thread_rng(); /// for _ in 0..100 { /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' @@ -266,7 +268,7 @@ mod test { #[test] fn test_no_item_error() { let mut rng = crate::test::rng(0x9c9fa0b0580a7031); - let mut tree = WeightedTreeIndex::::new(&[]).unwrap(); + let tree = WeightedTreeIndex::::new(&[]).unwrap(); assert_eq!(tree.sample(&mut rng).unwrap_err(), WeightedError::NoItem); } From 0eb69d42917587008c6fe12639ba8bf28e1da18a Mon Sep 17 00:00:00 2001 From: xmakro Date: Wed, 10 Jan 2024 20:35:13 +0800 Subject: [PATCH 08/20] a --- rand_distr/src/weighted_tree.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/rand_distr/src/weighted_tree.rs b/rand_distr/src/weighted_tree.rs index fd556b89363..6340660e257 100644 --- a/rand_distr/src/weighted_tree.rs +++ b/rand_distr/src/weighted_tree.rs @@ -227,6 +227,8 @@ impl Distribution> for WeightedTreeIndex // Otherwise we found the index with the target weight. break; } + assert!(target_weight >= W::zero()); + assert!(target_weight < self.subtotal(index)); Ok(index) } } From b00c7852ae4990ac87a3c1384ce49de9da8ced14 Mon Sep 17 00:00:00 2001 From: xmakro Date: Wed, 10 Jan 2024 20:41:08 +0800 Subject: [PATCH 09/20] a --- rand_distr/src/lib.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index 54b4fb93bf0..7172e8d4322 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -133,6 +133,8 @@ pub use rand::distributions::{WeightedError, WeightedIndex}; #[cfg(feature = "alloc")] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] pub use weighted_alias::WeightedAliasIndex; +#[cfg(feature = "alloc")] +#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] pub use weighted_tree::WeightedTreeIndex; pub use num_traits; @@ -187,6 +189,8 @@ mod test { #[cfg(feature = "alloc")] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] pub mod weighted_alias; +#[cfg(feature = "alloc")] +#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] pub mod weighted_tree; mod binomial; From 75c150f13c9a2bb677aae107094de9d5cbd7e1f4 Mon Sep 17 00:00:00 2001 From: xmakro Date: Wed, 10 Jan 2024 20:46:06 +0800 Subject: [PATCH 10/20] a --- rand_distr/src/weighted_tree.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/rand_distr/src/weighted_tree.rs b/rand_distr/src/weighted_tree.rs index 6340660e257..045c9e1f2b8 100644 --- a/rand_distr/src/weighted_tree.rs +++ b/rand_distr/src/weighted_tree.rs @@ -132,7 +132,11 @@ impl WeightedTreeIndex { /// /// This is the case if the total weight of the tree is greater than zero. pub fn can_sample(&self) -> bool { - self.subtotals.first().is_some_and(|x| *x > W::zero()) + if let Some(&w) = self.subtotals.first() { + w > W::zero() + } else { + false + } } /// Gets the weight at an index. From 30866d62d19effbfbbb5db7dcb7ede8aeed6b7e2 Mon Sep 17 00:00:00 2001 From: xmakro Date: Fri, 12 Jan 2024 12:00:45 +0800 Subject: [PATCH 11/20] Make it safe against overflows --- rand_distr/src/lib.rs | 3 +- rand_distr/src/weighted_tree.rs | 80 ++++++++++++++++++++++------- src/distributions/weighted_index.rs | 4 +- 3 files changed, 66 insertions(+), 21 deletions(-) diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index 7172e8d4322..1e28aaaa79a 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -76,8 +76,9 @@ //! - [`UnitBall`] distribution //! - [`UnitCircle`] distribution //! - [`UnitDisc`] distribution -//! - Alternative implementation for weighted index sampling +//! - Alternative implementations for weighted index sampling //! - [`WeightedAliasIndex`] distribution +//! - [`WeightedTreeIndex`] distribution //! - Misc. distributions //! - [`InverseGaussian`] distribution //! - [`NormalInverseGaussian`] distribution diff --git a/rand_distr/src/weighted_tree.rs b/rand_distr/src/weighted_tree.rs index 045c9e1f2b8..2f1349126ee 100644 --- a/rand_distr/src/weighted_tree.rs +++ b/rand_distr/src/weighted_tree.rs @@ -9,12 +9,12 @@ //! This module contains an implementation of a tree sttructure for sampling random //! indices with probabilities proportional to a collection of weights. -use core::ops::{Add, AddAssign, Sub, SubAssign}; +use core::ops::{Sub, SubAssign}; use super::WeightedError; use crate::Distribution; use alloc::{vec, vec::Vec}; -use num_traits::Zero; +use num_traits::{Zero, CheckedAdd}; use rand::{distributions::uniform::SampleUniform, Rng}; #[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; @@ -113,7 +113,8 @@ impl WeightedTreeIndex { } else { W::zero() }; - subtotals[i] = weights[i] + left_subtotal + right_subtotal; + let children_subtotal = left_subtotal.checked_add(&right_subtotal).ok_or(WeightedError::Overflow)?; + subtotals[i] = weights[i].checked_add(&children_subtotal).ok_or(WeightedError::Overflow)?; } Ok(Self { subtotals }) } @@ -163,11 +164,16 @@ impl WeightedTreeIndex { if weight < W::zero() { return Err(WeightedError::InvalidWeight); } + if let Some(total) = self.subtotals.first() { + if total.checked_add(&weight).is_none() { + return Err(WeightedError::Overflow); + } + } let mut index = self.len(); self.subtotals.push(weight); while index != 0 { index = (index - 1) / 2; - self.subtotals[index] += weight; + self.subtotals[index] = self.subtotals[index].checked_add(&weight).unwrap(); } Ok(()) } @@ -181,10 +187,15 @@ impl WeightedTreeIndex { if difference == W::zero() { return Ok(()); } - self.subtotals[index] += difference; + if let Some(total) = self.subtotals.first() { + if total.checked_add(&difference).is_none() { + return Err(WeightedError::Overflow); + } + } + self.subtotals[index] = self.subtotals[index].checked_add(&difference).unwrap(); while index != 0 { index = (index - 1) / 2; - self.subtotals[index] += difference; + self.subtotals[index] = self.subtotals[index].checked_add(&difference).unwrap(); } Ok(()) } @@ -246,27 +257,49 @@ pub trait Weight: + Copy + SampleUniform + PartialOrd - + Add - + AddAssign + Sub + SubAssign + Zero { + /// Adds two numbers, checking for overflow. If overflow happens, None is returned. + fn checked_add(&self, b: &Self) -> Option; } -impl Weight for T where - T: Sized - + Copy - + SampleUniform - + PartialOrd - + Add - + AddAssign - + Sub - + SubAssign - + Zero -{ +macro_rules! impl_weight_for_float { + ($T: ident) => { + impl Weight for $T { + fn checked_add(&self, b: &Self) -> Option { + Some(self + b) + } + } + }; } +macro_rules! impl_weight_for_int { + ($T: ident) => { + impl Weight for $T { + fn checked_add(&self, b: &Self) -> Option { + CheckedAdd::checked_add(self, b) + } + } + }; +} + +impl_weight_for_float!(f64); +impl_weight_for_float!(f32); +impl_weight_for_int!(usize); +impl_weight_for_int!(u128); +impl_weight_for_int!(u64); +impl_weight_for_int!(u32); +impl_weight_for_int!(u16); +impl_weight_for_int!(u8); +impl_weight_for_int!(isize); +impl_weight_for_int!(i128); +impl_weight_for_int!(i64); +impl_weight_for_int!(i32); +impl_weight_for_int!(i16); +impl_weight_for_int!(i8); + #[cfg(test)] mod test { use super::*; @@ -278,6 +311,15 @@ mod test { assert_eq!(tree.sample(&mut rng).unwrap_err(), WeightedError::NoItem); } + #[test] + fn test_overflow_error() { + assert_eq!(WeightedTreeIndex::new(&[i32::MAX, 2]), Err(WeightedError::Overflow)); + let mut tree = WeightedTreeIndex::new(&[i32::MAX - 2, 1]).unwrap(); + assert_eq!(tree.push(3), Err(WeightedError::Overflow)); + assert_eq!(tree.update(1, 4), Err(WeightedError::Overflow)); + tree.update(1, 2).unwrap(); + } + #[test] fn test_all_weights_zero_error() { let tree = WeightedTreeIndex::::new(&[0.0, 0.0]).unwrap(); diff --git a/src/distributions/weighted_index.rs b/src/distributions/weighted_index.rs index de3628b5ead..5223af594d6 100644 --- a/src/distributions/weighted_index.rs +++ b/src/distributions/weighted_index.rs @@ -35,7 +35,9 @@ use serde::{Serialize, Deserialize}; /// Time complexity of sampling from `WeightedIndex` is `O(log N)` where /// `N` is the number of weights. As an alternative, /// [`rand_distr::weighted_alias`](https://docs.rs/rand_distr/*/rand_distr/weighted_alias/index.html) -/// supports `O(1)` sampling, but with much higher initialisation cost. +/// supports `O(1)` sampling, but with much higher initialisation cost, +/// and [`rand_distr::weighted_tree`](https://docs.rs/rand_distr/*/rand_distr/weighted_tree/index.html) +/// supports `O(log n)` updates with O /// /// A `WeightedIndex` contains a `Vec` and a [`Uniform`] and so its /// size is the sum of the size of those objects, possibly plus some alignment. From 3b6229ee4662d4903c58cb2c5221a6afed74b4b3 Mon Sep 17 00:00:00 2001 From: xmakro Date: Fri, 12 Jan 2024 13:08:21 +0800 Subject: [PATCH 12/20] Checked adds and docs --- rand_distr/src/weighted_tree.rs | 190 ++++++++++++---------------- src/distributions/weighted_index.rs | 67 ++++++---- 2 files changed, 123 insertions(+), 134 deletions(-) diff --git a/rand_distr/src/weighted_tree.rs b/rand_distr/src/weighted_tree.rs index 2f1349126ee..f241477e0f7 100644 --- a/rand_distr/src/weighted_tree.rs +++ b/rand_distr/src/weighted_tree.rs @@ -9,13 +9,18 @@ //! This module contains an implementation of a tree sttructure for sampling random //! indices with probabilities proportional to a collection of weights. -use core::ops::{Sub, SubAssign}; +use core::ops::SubAssign; use super::WeightedError; use crate::Distribution; -use alloc::{vec, vec::Vec}; -use num_traits::{Zero, CheckedAdd}; -use rand::{distributions::uniform::SampleUniform, Rng}; +use alloc::vec::Vec; +use rand::{ + distributions::{ + uniform::{SampleBorrow, SampleUniform}, + Weight, + }, + Rng, +}; #[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; @@ -32,17 +37,17 @@ use serde::{Deserialize, Serialize}; /// The main distinction between [`WeightedTreeIndex`] and [`rand::distributions::WeightedIndex`] /// lies in the internal representation of weights. In [`WeightedTreeIndex`], /// weights are structured as a tree, which is optimized for frequent updates of the weights. -/// +/// /// # Caution: Floating point types -/// +/// /// When utilizing [`WeightedTreeIndex`] with floating point types (such as f32 or f64), /// exercise caution due to the inherent nature of floating point arithmetic. Floating point types /// are susceptible to numerical rounding errors. Since operations on floating point weights are /// repeated numerous times, rounding errors can accumulate, potentially leading to noticeable /// deviations from the expected behavior. -/// +/// /// Ideally, use fixed point or integer types whenever possible. -/// +/// /// # Performance /// /// A [`WeightedTreeIndex`] with `n` elements requires `O(n)` memory. @@ -86,35 +91,30 @@ use serde::{Deserialize, Serialize}; serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")) )] #[derive(Clone, Default, Debug, PartialEq)] -pub struct WeightedTreeIndex { +pub struct WeightedTreeIndex { subtotals: Vec, } -impl WeightedTreeIndex { +impl WeightedTreeIndex { /// Creates a new [`WeightedTreeIndex`] from a slice of weights. - pub fn new(weights: &[W]) -> Result { - for &weight in weights { - if weight < W::zero() { + pub fn new(weights: I) -> Result + where + I: IntoIterator, + I::Item: SampleBorrow, + { + let mut subtotals: Vec = weights.into_iter().map(|x| x.borrow().clone()).collect(); + for weight in subtotals.iter() { + if *weight < W::ZERO { return Err(WeightedError::InvalidWeight); } } - let n = weights.len(); - let mut subtotals = vec![W::zero(); n]; - for i in (0..n).rev() { - let left_index = 2 * i + 1; - let left_subtotal = if left_index < n { - subtotals[left_index] - } else { - W::zero() - }; - let right_index = 2 * i + 2; - let right_subtotal = if right_index < n { - subtotals[right_index] - } else { - W::zero() - }; - let children_subtotal = left_subtotal.checked_add(&right_subtotal).ok_or(WeightedError::Overflow)?; - subtotals[i] = weights[i].checked_add(&children_subtotal).ok_or(WeightedError::Overflow)?; + let n = subtotals.len(); + for i in (1..n).rev() { + let w = subtotals[i].clone(); + let parent = (i - 1) / 2; + subtotals[parent] + .checked_add_assign(&w) + .map_err(|()| WeightedError::Overflow)?; } Ok(Self { subtotals }) } @@ -133,27 +133,36 @@ impl WeightedTreeIndex { /// /// This is the case if the total weight of the tree is greater than zero. pub fn can_sample(&self) -> bool { - if let Some(&w) = self.subtotals.first() { - w > W::zero() + if let Some(weight) = self.subtotals.first() { + *weight > W::ZERO } else { false } } /// Gets the weight at an index. - pub fn get(&self, index: usize) -> W { + pub fn get(&self, index: usize) -> W + where + W: for<'a> SubAssign<&'a W>, + { let left_index = 2 * index + 1; let right_index = 2 * index + 2; - self.subtotals[index] - self.subtotal(left_index) - self.subtotal(right_index) + let mut w = self.subtotals[index].clone(); + w -= &self.subtotal(left_index); + w -= &self.subtotal(right_index); + w } /// Removes the last weight and returns it, or [`None`] if it is empty. - pub fn pop(&mut self) -> Option { + pub fn pop(&mut self) -> Option + where + W: for<'a> SubAssign<&'a W>, + { self.subtotals.pop().map(|weight| { let mut index = self.len(); while index != 0 { index = (index - 1) / 2; - self.subtotals[index] -= weight; + self.subtotals[index] -= &weight; } weight }) @@ -161,64 +170,76 @@ impl WeightedTreeIndex { /// Appends a new weight at the end. pub fn push(&mut self, weight: W) -> Result<(), WeightedError> { - if weight < W::zero() { + if weight < W::ZERO { return Err(WeightedError::InvalidWeight); } if let Some(total) = self.subtotals.first() { - if total.checked_add(&weight).is_none() { + let mut total = total.clone(); + if total.checked_add_assign(&weight).is_err() { return Err(WeightedError::Overflow); } } let mut index = self.len(); - self.subtotals.push(weight); + self.subtotals.push(weight.clone()); while index != 0 { index = (index - 1) / 2; - self.subtotals[index] = self.subtotals[index].checked_add(&weight).unwrap(); + self.subtotals[index].checked_add_assign(&weight).unwrap(); } Ok(()) } /// Updates the weight at an index. - pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightedError> { - if weight < W::zero() { + pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightedError> + where + W: for<'a> SubAssign<&'a W>, + { + if weight < W::ZERO { return Err(WeightedError::InvalidWeight); } - let difference = weight - self.get(index); - if difference == W::zero() { + let mut difference = weight; + difference -= &self.get(index); + if difference == W::ZERO { return Ok(()); } if let Some(total) = self.subtotals.first() { - if total.checked_add(&difference).is_none() { + let mut total = total.clone(); + if total.checked_add_assign(&difference).is_err() { return Err(WeightedError::Overflow); } } - self.subtotals[index] = self.subtotals[index].checked_add(&difference).unwrap(); + self.subtotals[index] + .checked_add_assign(&difference) + .unwrap(); while index != 0 { index = (index - 1) / 2; - self.subtotals[index] = self.subtotals[index].checked_add(&difference).unwrap(); + self.subtotals[index] + .checked_add_assign(&difference) + .unwrap(); } Ok(()) } fn subtotal(&self, index: usize) -> W { if index < self.subtotals.len() { - self.subtotals[index] + self.subtotals[index].clone() } else { - W::zero() + W::ZERO } } } -impl Distribution> for WeightedTreeIndex { +impl + Weight> + Distribution> for WeightedTreeIndex +{ fn sample(&self, rng: &mut R) -> Result { if self.subtotals.is_empty() { return Err(WeightedError::NoItem); } - let total_weight = self.subtotals[0]; - if total_weight == W::zero() { + let total_weight = self.subtotals[0].clone(); + if total_weight == W::ZERO { return Err(WeightedError::AllWeightsZero); } - let mut target_weight = rng.gen_range(W::zero()..total_weight); + let mut target_weight = rng.gen_range(W::ZERO..total_weight); let mut index = 0; loop { // Maybe descend into the left sub tree. @@ -242,64 +263,12 @@ impl Distribution> for WeightedTreeIndex // Otherwise we found the index with the target weight. break; } - assert!(target_weight >= W::zero()); + assert!(target_weight >= W::ZERO); assert!(target_weight < self.subtotal(index)); Ok(index) } } -/// Trait that must be implemented for weights, that are used with -/// [`WeightedTreeIndex`]. Currently no guarantees on the correctness of -/// [`WeightedTreeIndex`] are given for custom implementations of this trait. -#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] -pub trait Weight: - Sized - + Copy - + SampleUniform - + PartialOrd - + Sub - + SubAssign - + Zero -{ - /// Adds two numbers, checking for overflow. If overflow happens, None is returned. - fn checked_add(&self, b: &Self) -> Option; -} - -macro_rules! impl_weight_for_float { - ($T: ident) => { - impl Weight for $T { - fn checked_add(&self, b: &Self) -> Option { - Some(self + b) - } - } - }; -} - -macro_rules! impl_weight_for_int { - ($T: ident) => { - impl Weight for $T { - fn checked_add(&self, b: &Self) -> Option { - CheckedAdd::checked_add(self, b) - } - } - }; -} - -impl_weight_for_float!(f64); -impl_weight_for_float!(f32); -impl_weight_for_int!(usize); -impl_weight_for_int!(u128); -impl_weight_for_int!(u64); -impl_weight_for_int!(u32); -impl_weight_for_int!(u16); -impl_weight_for_int!(u8); -impl_weight_for_int!(isize); -impl_weight_for_int!(i128); -impl_weight_for_int!(i64); -impl_weight_for_int!(i32); -impl_weight_for_int!(i16); -impl_weight_for_int!(i8); - #[cfg(test)] mod test { use super::*; @@ -313,7 +282,10 @@ mod test { #[test] fn test_overflow_error() { - assert_eq!(WeightedTreeIndex::new(&[i32::MAX, 2]), Err(WeightedError::Overflow)); + assert_eq!( + WeightedTreeIndex::new(&[i32::MAX, 2]), + Err(WeightedError::Overflow) + ); let mut tree = WeightedTreeIndex::new(&[i32::MAX - 2, 1]).unwrap(); assert_eq!(tree.push(3), Err(WeightedError::Overflow)); assert_eq!(tree.update(1, 4), Err(WeightedError::Overflow)); @@ -365,7 +337,7 @@ mod test { let weights: Vec<_> = (0..end).map(|_| rng.gen()).collect(); let mut tree = WeightedTreeIndex::new(&weights).unwrap(); let mut total_weight = 0.0; - let mut weights = vec![0.0; end]; + let mut weights = alloc::vec![0.0; end]; for i in 0..end { tree.update(i, i as f64).unwrap(); weights[i] = i as f64; @@ -376,7 +348,7 @@ mod test { weights[i] = 0.0; total_weight -= i as f64; } - let mut counts = vec![0_usize; end]; + let mut counts = alloc::vec![0_usize; end]; for _ in 0..samples { let i = tree.sample(&mut rng).unwrap(); counts[i] += 1; diff --git a/src/distributions/weighted_index.rs b/src/distributions/weighted_index.rs index 5223af594d6..0b1b4da947c 100644 --- a/src/distributions/weighted_index.rs +++ b/src/distributions/weighted_index.rs @@ -18,7 +18,7 @@ use core::fmt; use alloc::vec::Vec; #[cfg(feature = "serde1")] -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; /// A distribution using weighted sampling of discrete items /// @@ -33,11 +33,12 @@ use serde::{Serialize, Deserialize}; /// # Performance /// /// Time complexity of sampling from `WeightedIndex` is `O(log N)` where -/// `N` is the number of weights. As an alternative, -/// [`rand_distr::weighted_alias`](https://docs.rs/rand_distr/*/rand_distr/weighted_alias/index.html) -/// supports `O(1)` sampling, but with much higher initialisation cost, -/// and [`rand_distr::weighted_tree`](https://docs.rs/rand_distr/*/rand_distr/weighted_tree/index.html) -/// supports `O(log n)` updates with O +/// `N` is the number of weights. There are two alternative implementations with +/// different runtimes characteristics: +/// * [`rand_distr::weighted_alias`](https://docs.rs/rand_distr/*/rand_distr/weighted_alias/index.html) +/// supports `O(1)` sampling, but with much higher initialisation cost. +/// * [`rand_distr::weighted_tree`](https://docs.rs/rand_distr/*/rand_distr/weighted_tree/index.html) +/// keeps the weights in a tree structure where sampling and updating is `O(log N)`. /// /// A `WeightedIndex` contains a `Vec` and a [`Uniform`] and so its /// size is the sum of the size of those objects, possibly plus some alignment. @@ -146,15 +147,21 @@ impl WeightedIndex { /// allocation internally. /// /// In case of error, `self` is not modified. - /// + /// + /// Updates take `O(N)` time. If you need to frequently update weights, consider + /// [`rand_distr::weighted_tree`](https://docs.rs/rand_distr/*/rand_distr/weighted_tree/index.html) + /// as an alternative where an update is `O(log N)`. + /// /// Note: Updating floating-point weights may cause slight inaccuracies in the total weight. /// This method may not return `WeightedError::AllWeightsZero` when all weights - /// are zero if using floating-point weights. + /// are zero if using floating-point weights. pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), WeightedError> - where X: for<'a> ::core::ops::AddAssign<&'a X> + where + X: for<'a> ::core::ops::AddAssign<&'a X> + for<'a> ::core::ops::SubAssign<&'a X> + Clone - + Default { + + Default, + { if new_weights.is_empty() { return Ok(()); } @@ -232,12 +239,14 @@ impl WeightedIndex { } impl Distribution for WeightedIndex -where X: SampleUniform + PartialOrd +where + X: SampleUniform + PartialOrd, { fn sample(&self, rng: &mut R) -> usize { let chosen_weight = self.weight_distribution.sample(rng); // Find the first item which has a weight *higher* than the chosen weight. - self.cumulative_weights.partition_point(|w| w <= &chosen_weight) + self.cumulative_weights + .partition_point(|w| w <= &chosen_weight) } } @@ -290,7 +299,7 @@ macro_rules! impl_weight_float { Ok(()) } } - } + }; } impl_weight_float!(f32); impl_weight_float!(f64); @@ -316,7 +325,7 @@ mod test { } #[test] - fn test_accepting_nan(){ + fn test_accepting_nan() { assert_eq!( WeightedIndex::new(&[core::f32::NAN, 0.5]).unwrap_err(), WeightedError::InvalidWeight, @@ -339,7 +348,6 @@ mod test { ) } - #[test] #[cfg_attr(miri, ignore)] // Miri is too slow fn test_weightedindex() { @@ -463,15 +471,21 @@ mod test { } let mut buf = [0; 10]; - test_samples(&[1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[ - 0, 6, 2, 6, 3, 4, 7, 8, 2, 5, - ]); - test_samples(&[0.7f32, 0.1, 0.1, 0.1], &mut buf, &[ - 0, 0, 0, 1, 0, 0, 2, 3, 0, 0, - ]); - test_samples(&[1.0f64, 0.999, 0.998, 0.997], &mut buf, &[ - 2, 2, 1, 3, 2, 1, 3, 3, 2, 1, - ]); + test_samples( + &[1i32, 1, 1, 1, 1, 1, 1, 1, 1], + &mut buf, + &[0, 6, 2, 6, 3, 4, 7, 8, 2, 5], + ); + test_samples( + &[0.7f32, 0.1, 0.1, 0.1], + &mut buf, + &[0, 0, 0, 1, 0, 0, 2, 3, 0, 0], + ); + test_samples( + &[1.0f64, 0.999, 0.998, 0.997], + &mut buf, + &[2, 2, 1, 3, 2, 1, 3, 3, 2, 1], + ); } #[test] @@ -481,7 +495,10 @@ mod test { #[test] fn overflow() { - assert_eq!(WeightedIndex::new([2, usize::MAX]), Err(WeightedError::Overflow)); + assert_eq!( + WeightedIndex::new([2, usize::MAX]), + Err(WeightedError::Overflow) + ); } } From 45a03aca04eced8c41e90efefe2bb4f5d132afbd Mon Sep 17 00:00:00 2001 From: xmakro Date: Fri, 12 Jan 2024 13:09:52 +0800 Subject: [PATCH 13/20] a --- rand_distr/src/weighted_tree.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rand_distr/src/weighted_tree.rs b/rand_distr/src/weighted_tree.rs index f241477e0f7..c2a242bfd09 100644 --- a/rand_distr/src/weighted_tree.rs +++ b/rand_distr/src/weighted_tree.rs @@ -53,7 +53,7 @@ use serde::{Deserialize, Serialize}; /// A [`WeightedTreeIndex`] with `n` elements requires `O(n)` memory. /// /// Time complexity for the operations of a [`WeightedTreeIndex`] are: -/// * Constructing: Building the initial tree from a slice of weights takes `O(n)` time. +/// * Constructing: Building the initial tree from an iterator of weights takes `O(n)` time. /// * Sampling: Choosing an index (traversing down the tree) requires `O(log n)` time. /// * Weight Update: Modifying a weight (traversing up the tree), requires `O(log n)` time. /// * Weight Addition (Pushing): Adding a new weight (traversing up the tree), requires `O(log n)` time. From bca832d3b60dbd09dcde391cedd835c5c4c5a3c2 Mon Sep 17 00:00:00 2001 From: xmakro Date: Fri, 12 Jan 2024 13:14:51 +0800 Subject: [PATCH 14/20] x --- rand_distr/src/weighted_tree.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rand_distr/src/weighted_tree.rs b/rand_distr/src/weighted_tree.rs index c2a242bfd09..84fe80b1856 100644 --- a/rand_distr/src/weighted_tree.rs +++ b/rand_distr/src/weighted_tree.rs @@ -95,7 +95,7 @@ pub struct WeightedTreeIndex { subtotals: Vec, } -impl WeightedTreeIndex { +impl WeightedTreeIndex { /// Creates a new [`WeightedTreeIndex`] from a slice of weights. pub fn new(weights: I) -> Result where From 7a0e234547005f1df40ca1ed9be6f5e738c73be6 Mon Sep 17 00:00:00 2001 From: xmakro Date: Fri, 12 Jan 2024 13:16:39 +0800 Subject: [PATCH 15/20] a --- rand_distr/src/weighted_tree.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rand_distr/src/weighted_tree.rs b/rand_distr/src/weighted_tree.rs index 84fe80b1856..ce5aad4dc25 100644 --- a/rand_distr/src/weighted_tree.rs +++ b/rand_distr/src/weighted_tree.rs @@ -91,7 +91,7 @@ use serde::{Deserialize, Serialize}; serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")) )] #[derive(Clone, Default, Debug, PartialEq)] -pub struct WeightedTreeIndex { +pub struct WeightedTreeIndex { subtotals: Vec, } From a23e842d5a5e8debc1c8632e1fdbabc40a063762 Mon Sep 17 00:00:00 2001 From: xmakro Date: Wed, 7 Feb 2024 16:49:28 +0000 Subject: [PATCH 16/20] Address comments --- rand_distr/src/weighted_tree.rs | 53 ++++++++++++++++++++++----------- 1 file changed, 35 insertions(+), 18 deletions(-) diff --git a/rand_distr/src/weighted_tree.rs b/rand_distr/src/weighted_tree.rs index ce5aad4dc25..370c38914db 100644 --- a/rand_distr/src/weighted_tree.rs +++ b/rand_distr/src/weighted_tree.rs @@ -6,7 +6,7 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! This module contains an implementation of a tree sttructure for sampling random +//! This module contains an implementation of a tree structure for sampling random //! indices with probabilities proportional to a collection of weights. use core::ops::SubAssign; @@ -14,13 +14,9 @@ use core::ops::SubAssign; use super::WeightedError; use crate::Distribution; use alloc::vec::Vec; -use rand::{ - distributions::{ - uniform::{SampleBorrow, SampleUniform}, - Weight, - }, - Rng, -}; +use rand::distributions::uniform::{SampleBorrow, SampleUniform}; +use rand::distributions::Weight; +use rand::Rng; #[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; @@ -29,7 +25,7 @@ use serde::{Deserialize, Serialize}; /// Sampling a [`WeightedTreeIndex`] distribution returns the index of a randomly /// selected element from the vector used to create the [`WeightedTreeIndex`]. /// The chance of a given element being picked is proportional to the value of -/// the element. The weights can have any type `W` for which a implementation of +/// the element. The weights can have any type `W` for which an implementation of /// [`Weight`] exists. /// /// # Key differences @@ -71,15 +67,16 @@ use serde::{Deserialize, Serialize}; /// dist.push(1).unwrap(); /// dist.update(1, 1).unwrap(); /// let mut rng = thread_rng(); +/// let mut samples = [0; 3]; /// for _ in 0..100 { /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c' -/// let i = dist.sample(&mut rng).unwrap(); -/// println!("{}", choices[i]); +/// let i = dist.sample(&mut rng); +/// samples[i] += 1; /// } +/// println!("Results: {:?}", choices.iter().zip(samples.iter()).collect::>()); /// ``` /// /// [`WeightedTreeIndex`]: WeightedTreeIndex -/// [`Uniform::sample`]: Distribution::sample #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr( @@ -132,7 +129,7 @@ impl WeightedTreeInd /// Returns `true` if we can sample. /// /// This is the case if the total weight of the tree is greater than zero. - pub fn can_sample(&self) -> bool { + pub fn is_valid(&self) -> bool { if let Some(weight) = self.subtotals.first() { *weight > W::ZERO } else { @@ -229,9 +226,13 @@ impl WeightedTreeInd } impl + Weight> - Distribution> for WeightedTreeIndex + WeightedTreeIndex { - fn sample(&self, rng: &mut R) -> Result { + /// Samples a randomly selected index from the weighted distribution. + /// + /// Returns an error if there are no elements or all weights are zero. This + /// is unlike [`Distribution::sample`], which panics in those cases. + fn safe_sample(&self, rng: &mut R) -> Result { if self.subtotals.is_empty() { return Err(WeightedError::NoItem); } @@ -269,6 +270,19 @@ impl + Weight> } } +impl + Weight> Distribution + for WeightedTreeIndex +{ + /// Samples a randomly selected index from the weighted distribution. + /// + /// Caution: This method panics if there are no elements or all weights are zero. However, + /// it is guaranteed that this method will not panic if a call to [`WeightedTreeIndex::is_valid`] + /// returns `true`. + fn sample(&self, rng: &mut R) -> usize { + self.safe_sample(rng).unwrap() + } +} + #[cfg(test)] mod test { use super::*; @@ -277,7 +291,10 @@ mod test { fn test_no_item_error() { let mut rng = crate::test::rng(0x9c9fa0b0580a7031); let tree = WeightedTreeIndex::::new(&[]).unwrap(); - assert_eq!(tree.sample(&mut rng).unwrap_err(), WeightedError::NoItem); + assert_eq!( + tree.safe_sample(&mut rng).unwrap_err(), + WeightedError::NoItem + ); } #[test] @@ -297,7 +314,7 @@ mod test { let tree = WeightedTreeIndex::::new(&[0.0, 0.0]).unwrap(); let mut rng = crate::test::rng(0x9c9fa0b0580a7031); assert_eq!( - tree.sample(&mut rng).unwrap_err(), + tree.safe_sample(&mut rng).unwrap_err(), WeightedError::AllWeightsZero ); } @@ -350,7 +367,7 @@ mod test { } let mut counts = alloc::vec![0_usize; end]; for _ in 0..samples { - let i = tree.sample(&mut rng).unwrap(); + let i = tree.sample(&mut rng); counts[i] += 1; } for i in 0..start { From c8e5e35b22b8a139882fe7fe377be133992456d5 Mon Sep 17 00:00:00 2001 From: xmakro Date: Wed, 7 Feb 2024 17:32:22 +0000 Subject: [PATCH 17/20] more comments --- rand_distr/src/weighted_tree.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/rand_distr/src/weighted_tree.rs b/rand_distr/src/weighted_tree.rs index 370c38914db..68839679816 100644 --- a/rand_distr/src/weighted_tree.rs +++ b/rand_distr/src/weighted_tree.rs @@ -232,7 +232,7 @@ impl + Weight> /// /// Returns an error if there are no elements or all weights are zero. This /// is unlike [`Distribution::sample`], which panics in those cases. - fn safe_sample(&self, rng: &mut R) -> Result { + fn try_sample(&self, rng: &mut R) -> Result { if self.subtotals.is_empty() { return Err(WeightedError::NoItem); } @@ -270,16 +270,16 @@ impl + Weight> } } +/// Samples a randomly selected index from the weighted distribution. +/// +/// Caution: This method panics if there are no elements or all weights are zero. However, +/// it is guaranteed that this method will not panic if a call to [`WeightedTreeIndex::is_valid`] +/// returns `true`. impl + Weight> Distribution for WeightedTreeIndex { - /// Samples a randomly selected index from the weighted distribution. - /// - /// Caution: This method panics if there are no elements or all weights are zero. However, - /// it is guaranteed that this method will not panic if a call to [`WeightedTreeIndex::is_valid`] - /// returns `true`. fn sample(&self, rng: &mut R) -> usize { - self.safe_sample(rng).unwrap() + self.try_sample(rng).unwrap() } } @@ -292,7 +292,7 @@ mod test { let mut rng = crate::test::rng(0x9c9fa0b0580a7031); let tree = WeightedTreeIndex::::new(&[]).unwrap(); assert_eq!( - tree.safe_sample(&mut rng).unwrap_err(), + tree.try_sample(&mut rng).unwrap_err(), WeightedError::NoItem ); } @@ -314,7 +314,7 @@ mod test { let tree = WeightedTreeIndex::::new(&[0.0, 0.0]).unwrap(); let mut rng = crate::test::rng(0x9c9fa0b0580a7031); assert_eq!( - tree.safe_sample(&mut rng).unwrap_err(), + tree.try_sample(&mut rng).unwrap_err(), WeightedError::AllWeightsZero ); } From 689ac4875f83d6eb36f0452d63fdbd5a9d429697 Mon Sep 17 00:00:00 2001 From: xmakro Date: Thu, 8 Feb 2024 13:00:42 +0000 Subject: [PATCH 18/20] a --- rand_distr/src/weighted_tree.rs | 31 +++++++++++++------------------ 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/rand_distr/src/weighted_tree.rs b/rand_distr/src/weighted_tree.rs index 68839679816..5eb82d3f45f 100644 --- a/rand_distr/src/weighted_tree.rs +++ b/rand_distr/src/weighted_tree.rs @@ -88,11 +88,15 @@ use serde::{Deserialize, Serialize}; serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")) )] #[derive(Clone, Default, Debug, PartialEq)] -pub struct WeightedTreeIndex { +pub struct WeightedTreeIndex< + W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign + Weight, +> { subtotals: Vec, } -impl WeightedTreeIndex { +impl + Weight> + WeightedTreeIndex +{ /// Creates a new [`WeightedTreeIndex`] from a slice of weights. pub fn new(weights: I) -> Result where @@ -138,28 +142,22 @@ impl WeightedTreeInd } /// Gets the weight at an index. - pub fn get(&self, index: usize) -> W - where - W: for<'a> SubAssign<&'a W>, - { + pub fn get(&self, index: usize) -> W { let left_index = 2 * index + 1; let right_index = 2 * index + 2; let mut w = self.subtotals[index].clone(); - w -= &self.subtotal(left_index); - w -= &self.subtotal(right_index); + w -= self.subtotal(left_index); + w -= self.subtotal(right_index); w } /// Removes the last weight and returns it, or [`None`] if it is empty. - pub fn pop(&mut self) -> Option - where - W: for<'a> SubAssign<&'a W>, - { + pub fn pop(&mut self) -> Option { self.subtotals.pop().map(|weight| { let mut index = self.len(); while index != 0 { index = (index - 1) / 2; - self.subtotals[index] -= &weight; + self.subtotals[index] -= weight.clone(); } weight }) @@ -186,15 +184,12 @@ impl WeightedTreeInd } /// Updates the weight at an index. - pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightedError> - where - W: for<'a> SubAssign<&'a W>, - { + pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), WeightedError> { if weight < W::ZERO { return Err(WeightedError::InvalidWeight); } let mut difference = weight; - difference -= &self.get(index); + difference -= self.get(index); if difference == W::ZERO { return Ok(()); } From 57158b7db538b031e829492cf26bc8182b823241 Mon Sep 17 00:00:00 2001 From: xmakro Date: Thu, 8 Feb 2024 13:08:45 +0000 Subject: [PATCH 19/20] c --- rand_distr/src/weighted_tree.rs | 38 ++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/rand_distr/src/weighted_tree.rs b/rand_distr/src/weighted_tree.rs index 5eb82d3f45f..e3ab1fd981e 100644 --- a/rand_distr/src/weighted_tree.rs +++ b/rand_distr/src/weighted_tree.rs @@ -188,25 +188,33 @@ impl + Weight> if weight < W::ZERO { return Err(WeightedError::InvalidWeight); } - let mut difference = weight; - difference -= self.get(index); - if difference == W::ZERO { - return Ok(()); - } - if let Some(total) = self.subtotals.first() { - let mut total = total.clone(); - if total.checked_add_assign(&difference).is_err() { - return Err(WeightedError::Overflow); + let old_weight = self.get(index); + if weight > old_weight { + let mut difference = weight; + difference -= old_weight; + if let Some(total) = self.subtotals.first() { + let mut total = total.clone(); + if total.checked_add_assign(&difference).is_err() { + return Err(WeightedError::Overflow); + } } - } - self.subtotals[index] - .checked_add_assign(&difference) - .unwrap(); - while index != 0 { - index = (index - 1) / 2; self.subtotals[index] .checked_add_assign(&difference) .unwrap(); + while index != 0 { + index = (index - 1) / 2; + self.subtotals[index] + .checked_add_assign(&difference) + .unwrap(); + } + } else if weight < old_weight { + let mut difference = old_weight; + difference -= weight; + self.subtotals[index] -= difference.clone(); + while index != 0 { + index = (index - 1) / 2; + self.subtotals[index] -= difference.clone(); + } } Ok(()) } From e4596454791bee7b15887a60b30a7c11d97e2a06 Mon Sep 17 00:00:00 2001 From: xmakro Date: Thu, 8 Feb 2024 14:14:02 +0000 Subject: [PATCH 20/20] a --- rand_distr/src/weighted_tree.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rand_distr/src/weighted_tree.rs b/rand_distr/src/weighted_tree.rs index e3ab1fd981e..b308cdb2c04 100644 --- a/rand_distr/src/weighted_tree.rs +++ b/rand_distr/src/weighted_tree.rs @@ -268,7 +268,7 @@ impl + Weight> break; } assert!(target_weight >= W::ZERO); - assert!(target_weight < self.subtotal(index)); + assert!(target_weight < self.get(index)); Ok(index) } }