diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index a923f879d2..5adb82f811 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -126,7 +126,7 @@ pub use self::slice::Slice; #[doc(inline)] pub use self::uniform::Uniform; #[cfg(feature = "alloc")] -pub use self::weighted_index::{WeightedError, WeightedIndex}; +pub use self::weighted_index::{Weight, WeightedError, WeightedIndex}; #[allow(unused)] use crate::Rng; diff --git a/src/distributions/weighted_index.rs b/src/distributions/weighted_index.rs index c71842d66b..de3628b5ea 100644 --- a/src/distributions/weighted_index.rs +++ b/src/distributions/weighted_index.rs @@ -99,12 +99,12 @@ impl WeightedIndex { where I: IntoIterator, I::Item: SampleBorrow, - X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default, + X: Weight, { let mut iter = weights.into_iter(); let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone(); - let zero = ::default(); + let zero = X::ZERO; if !(total_weight >= zero) { return Err(WeightedError::InvalidWeight); } @@ -118,8 +118,7 @@ impl WeightedIndex { } weights.push(total_weight.clone()); - total_weight += w.borrow(); - if total_weight < *w.borrow() { + if let Err(()) = total_weight.checked_add_assign(w.borrow()) { return Err(WeightedError::Overflow); } } @@ -240,6 +239,60 @@ where X: SampleUniform + PartialOrd } } +/// Bounds on a weight +/// +/// See usage in [`WeightedIndex`]. +pub trait Weight: Clone { + /// Representation of 0 + const ZERO: Self; + + /// Checked addition + /// + /// - `Result::Ok`: On success, `v` is added to `self` + /// - `Result::Err`: Returns an error when `Self` cannot represent the + /// result of `self + v` (i.e. overflow). The value of `self` should be + /// discarded. + fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()>; +} + +macro_rules! impl_weight_int { + ($t:ty) => { + impl Weight for $t { + const ZERO: Self = 0; + fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> { + match self.checked_add(*v) { + Some(sum) => { + *self = sum; + Ok(()) + } + None => Err(()), + } + } + } + }; + ($t:ty, $($tt:ty),*) => { + impl_weight_int!($t); + impl_weight_int!($($tt),*); + } +} +impl_weight_int!(i8, i16, i32, i64, i128, isize); +impl_weight_int!(u8, u16, u32, u64, u128, usize); + +macro_rules! impl_weight_float { + ($t:ty) => { + impl Weight for $t { + const ZERO: Self = 0.0; + fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> { + // Floats have an explicit representation for overflow + *self += *v; + Ok(()) + } + } + } +} +impl_weight_float!(f32); +impl_weight_float!(f64); + #[cfg(test)] mod test { use super::*; @@ -392,12 +445,11 @@ mod test { #[test] fn value_stability() { - fn test_samples( + fn test_samples( weights: I, buf: &mut [usize], expected: &[usize], ) where I: IntoIterator, I::Item: SampleBorrow, - X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default, { assert_eq!(buf.len(), expected.len()); let distr = WeightedIndex::new(weights).unwrap(); diff --git a/src/seq/mod.rs b/src/seq/mod.rs index bbb46fc55f..9012b21b90 100644 --- a/src/seq/mod.rs +++ b/src/seq/mod.rs @@ -40,7 +40,7 @@ use alloc::vec::Vec; #[cfg(feature = "alloc")] use crate::distributions::uniform::{SampleBorrow, SampleUniform}; #[cfg(feature = "alloc")] -use crate::distributions::WeightedError; +use crate::distributions::{Weight, WeightedError}; use crate::Rng; use self::coin_flipper::CoinFlipper; @@ -170,11 +170,7 @@ pub trait SliceRandom { R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow, - X: SampleUniform - + for<'a> ::core::ops::AddAssign<&'a X> - + ::core::cmp::PartialOrd - + Clone - + Default; + X: SampleUniform + Weight + ::core::cmp::PartialOrd; /// Biased sampling for one element (mut) /// @@ -203,11 +199,7 @@ pub trait SliceRandom { R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow, - X: SampleUniform - + for<'a> ::core::ops::AddAssign<&'a X> - + ::core::cmp::PartialOrd - + Clone - + Default; + X: SampleUniform + Weight + ::core::cmp::PartialOrd; /// Biased sampling of `amount` distinct elements /// @@ -585,11 +577,7 @@ impl SliceRandom for [T] { R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow, - X: SampleUniform - + for<'a> ::core::ops::AddAssign<&'a X> - + ::core::cmp::PartialOrd - + Clone - + Default, + X: SampleUniform + Weight + ::core::cmp::PartialOrd, { use crate::distributions::{Distribution, WeightedIndex}; let distr = WeightedIndex::new(self.iter().map(weight))?; @@ -604,11 +592,7 @@ impl SliceRandom for [T] { R: Rng + ?Sized, F: Fn(&Self::Item) -> B, B: SampleBorrow, - X: SampleUniform - + for<'a> ::core::ops::AddAssign<&'a X> - + ::core::cmp::PartialOrd - + Clone - + Default, + X: SampleUniform + Weight + ::core::cmp::PartialOrd, { use crate::distributions::{Distribution, WeightedIndex}; let distr = WeightedIndex::new(self.iter().map(weight))?;