From 8441893b93d6d07955d83653459c77cf4dbd2acb Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Mon, 24 Jun 2024 11:32:52 +0100 Subject: [PATCH 1/8] Use simpler definition of max_rand --- src/distributions/uniform.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/distributions/uniform.rs b/src/distributions/uniform.rs index 34a6b252f4..2dfb8e7ba5 100644 --- a/src/distributions/uniform.rs +++ b/src/distributions/uniform.rs @@ -931,9 +931,7 @@ macro_rules! uniform_float_impl { if !(low.all_lt(high)) { return Err(Error::EmptyRange); } - let max_rand = <$ty>::splat( - ($u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0, - ); + let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON); let mut scale = high - low; if !(scale.all_finite()) { @@ -967,9 +965,7 @@ macro_rules! uniform_float_impl { if !low.all_le(high) { return Err(Error::EmptyRange); } - let max_rand = <$ty>::splat( - ($u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0, - ); + let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON); let mut scale = (high - low) / max_rand; if !scale.all_finite() { From 9abc8aa4f5258d9d7f3e3a7abfc8b8d967c4467c Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Mon, 24 Jun 2024 12:05:52 +0100 Subject: [PATCH 2/8] Permit samples of UniformFloat::new to equal high --- src/distributions/uniform.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/distributions/uniform.rs b/src/distributions/uniform.rs index 2dfb8e7ba5..6f6d75e951 100644 --- a/src/distributions/uniform.rs +++ b/src/distributions/uniform.rs @@ -939,7 +939,7 @@ macro_rules! uniform_float_impl { } loop { - let mask = (scale * max_rand + low).ge_mask(high); + let mask = (scale * max_rand + low).gt_mask(high); if !mask.any() { break; } @@ -1461,14 +1461,14 @@ mod tests { let my_incl_uniform = Uniform::new_inclusive(low, high).unwrap(); for _ in 0..100 { let v = rng.sample(my_uniform).extract(lane); - assert!(low_scalar <= v && v < high_scalar); + assert!(low_scalar <= v && v <= high_scalar); let v = rng.sample(my_incl_uniform).extract(lane); assert!(low_scalar <= v && v <= high_scalar); let v = <$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng) .unwrap() .extract(lane); - assert!(low_scalar <= v && v < high_scalar); + assert!(low_scalar <= v && v <= high_scalar); let v = <$ty as SampleUniform>::Sampler::sample_single_inclusive( low, high, &mut rng, ) @@ -1506,12 +1506,12 @@ mod tests { low_scalar ); - assert!(max_rng.sample(my_uniform).extract(lane) < high_scalar); + assert!(max_rng.sample(my_uniform).extract(lane) <= high_scalar); assert!(max_rng.sample(my_incl_uniform).extract(lane) <= high_scalar); // sample_single cannot cope with max_rng: // assert!(<$ty as SampleUniform>::Sampler // ::sample_single(low, high, &mut max_rng).unwrap() - // .extract(lane) < high_scalar); + // .extract(lane) <= high_scalar); assert!( <$ty as SampleUniform>::Sampler::sample_single_inclusive( low, @@ -1539,7 +1539,7 @@ mod tests { ) .unwrap() .extract(lane) - < high_scalar + <= high_scalar ); } } From ba593dc76c8beb3f696e218d613f12d28ea58d69 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Mon, 24 Jun 2024 12:07:27 +0100 Subject: [PATCH 3/8] Add (private) fn UniformFloat::new_bounded --- src/distributions/uniform.rs | 58 ++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/src/distributions/uniform.rs b/src/distributions/uniform.rs index 6f6d75e951..06ad46d42d 100644 --- a/src/distributions/uniform.rs +++ b/src/distributions/uniform.rs @@ -908,6 +908,33 @@ pub struct UniformFloat { macro_rules! uniform_float_impl { ($($meta:meta)?, $ty:ty, $uty:ident, $f_scalar:ident, $u_scalar:ident, $bits_to_discard:expr) => { + $(#[cfg($meta)])? + impl UniformFloat<$ty> { + /// Construct, reducing `scale` as required to ensure that rounding + /// can never yield values greater than `high`. + /// + /// Note: though it may be tempting to use a variant of this method + /// to ensure that samples from `[low, high)` are always strictly + /// less than `high`, this approach may be very slow where + /// `scale.abs()` is much smaller than `high.abs()` + /// (example: `low=0.99999999997819644, high=1.`). + fn new_bounded(low: $ty, high: $ty, mut scale: $ty) -> Self { + let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON); + + loop { + let mask = (scale * max_rand + low).gt_mask(high); + if !mask.any() { + break; + } + scale = scale.decrease_masked(mask); + } + + debug_assert!(<$ty>::splat(0.0).all_le(scale)); + + UniformFloat { low, scale } + } + } + $(#[cfg($meta)])? impl SampleUniform for $ty { type Sampler = UniformFloat<$ty>; @@ -931,24 +958,13 @@ macro_rules! uniform_float_impl { if !(low.all_lt(high)) { return Err(Error::EmptyRange); } - let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON); - let mut scale = high - low; + let scale = high - low; if !(scale.all_finite()) { return Err(Error::NonFinite); } - loop { - let mask = (scale * max_rand + low).gt_mask(high); - if !mask.any() { - break; - } - scale = scale.decrease_masked(mask); - } - - debug_assert!(<$ty>::splat(0.0).all_le(scale)); - - Ok(UniformFloat { low, scale }) + Ok(Self::new_bounded(low, high, scale)) } fn new_inclusive(low_b: B1, high_b: B2) -> Result @@ -965,24 +981,14 @@ macro_rules! uniform_float_impl { if !low.all_le(high) { return Err(Error::EmptyRange); } - let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON); - let mut scale = (high - low) / max_rand; + let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON); + let scale = (high - low) / max_rand; if !scale.all_finite() { return Err(Error::NonFinite); } - loop { - let mask = (scale * max_rand + low).gt_mask(high); - if !mask.any() { - break; - } - scale = scale.decrease_masked(mask); - } - - debug_assert!(<$ty>::splat(0.0).all_le(scale)); - - Ok(UniformFloat { low, scale }) + Ok(Self::new_bounded(low, high, scale)) } fn sample(&self, rng: &mut R) -> Self::X { From c0defa99400a1357bad239031517fbceecb2a28d Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Mon, 24 Jun 2024 12:14:27 +0100 Subject: [PATCH 4/8] test_float_assertions: we no longer need to use catch_unwind --- src/distributions/uniform.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/distributions/uniform.rs b/src/distributions/uniform.rs index 06ad46d42d..7557c8c48e 100644 --- a/src/distributions/uniform.rs +++ b/src/distributions/uniform.rs @@ -1592,10 +1592,9 @@ mod tests { #[cfg(all(feature = "std", panic = "unwind"))] fn test_float_assertions() { use super::SampleUniform; - use std::panic::catch_unwind; - fn range(low: T, high: T) { + fn range(low: T, high: T) -> Result { let mut rng = crate::test::rng(253); - T::Sampler::sample_single(low, high, &mut rng).unwrap(); + T::Sampler::sample_single(low, high, &mut rng) } macro_rules! t { @@ -1618,10 +1617,10 @@ mod tests { for lane in 0..<$ty>::LEN { let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar); let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar); - assert!(catch_unwind(|| range(low, high)).is_err()); + assert!(range(low, high).is_err()); assert!(Uniform::new(low, high).is_err()); assert!(Uniform::new_inclusive(low, high).is_err()); - assert!(catch_unwind(|| range(low, low)).is_err()); + assert!(range(low, low).is_err()); assert!(Uniform::new(low, low).is_err()); } } From 789c6a0ad5f9b4d8682352592b4f87c73bbbf41d Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Mon, 24 Jun 2024 12:16:41 +0100 Subject: [PATCH 5/8] UniformFloat::sample_single: no longer omit high from result range --- src/distributions/uniform.rs | 68 +----------------------------------- 1 file changed, 1 insertion(+), 67 deletions(-) diff --git a/src/distributions/uniform.rs b/src/distributions/uniform.rs index 7557c8c48e..09be2d87c0 100644 --- a/src/distributions/uniform.rs +++ b/src/distributions/uniform.rs @@ -1012,72 +1012,7 @@ macro_rules! uniform_float_impl { B1: SampleBorrow + Sized, B2: SampleBorrow + Sized, { - let low = *low_b.borrow(); - let high = *high_b.borrow(); - #[cfg(debug_assertions)] - if !low.all_finite() || !high.all_finite() { - return Err(Error::NonFinite); - } - if !low.all_lt(high) { - return Err(Error::EmptyRange); - } - let mut scale = high - low; - if !scale.all_finite() { - return Err(Error::NonFinite); - } - - loop { - // Generate a value in the range [1, 2) - let value1_2 = - (rng.random::<$uty>() >> $uty::splat($bits_to_discard)).into_float_with_exponent(0); - - // Get a value in the range [0, 1) to avoid overflow when multiplying by scale - let value0_1 = value1_2 - <$ty>::splat(1.0); - - // Doing multiply before addition allows some architectures - // to use a single instruction. - let res = value0_1 * scale + low; - - debug_assert!(low.all_le(res) || !scale.all_finite()); - if res.all_lt(high) { - return Ok(res); - } - - // This handles a number of edge cases. - // * `low` or `high` is NaN. In this case `scale` and - // `res` are going to end up as NaN. - // * `low` is negative infinity and `high` is finite. - // `scale` is going to be infinite and `res` will be - // NaN. - // * `high` is positive infinity and `low` is finite. - // `scale` is going to be infinite and `res` will - // be infinite or NaN (if value0_1 is 0). - // * `low` is negative infinity and `high` is positive - // infinity. `scale` will be infinite and `res` will - // be NaN. - // * `low` and `high` are finite, but `high - low` - // overflows to infinite. `scale` will be infinite - // and `res` will be infinite or NaN (if value0_1 is 0). - // So if `high` or `low` are non-finite, we are guaranteed - // to fail the `res < high` check above and end up here. - // - // While we technically should check for non-finite `low` - // and `high` before entering the loop, by doing the checks - // here instead, we allow the common case to avoid these - // checks. But we are still guaranteed that if `low` or - // `high` are non-finite we'll end up here and can do the - // appropriate checks. - // - // Likewise, `high - low` overflowing to infinity is also - // rare, so handle it here after the common case. - let mask = !scale.finite_mask(); - if mask.any() { - if !(low.all_finite() && high.all_finite()) { - return Err(Error::NonFinite); - } - scale = scale.decrease_masked(mask); - } - } + Self::sample_single_inclusive(low_b, high_b, rng) } #[inline] @@ -1620,7 +1555,6 @@ mod tests { assert!(range(low, high).is_err()); assert!(Uniform::new(low, high).is_err()); assert!(Uniform::new_inclusive(low, high).is_err()); - assert!(range(low, low).is_err()); assert!(Uniform::new(low, low).is_err()); } } From a7805b6ac83025c21f3198532d91d21d8d74a2f3 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Mon, 24 Jun 2024 12:18:35 +0100 Subject: [PATCH 6/8] Remove unused fns --- src/distributions/utils.rs | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/src/distributions/utils.rs b/src/distributions/utils.rs index aee92b6790..7e84665ec4 100644 --- a/src/distributions/utils.rs +++ b/src/distributions/utils.rs @@ -218,9 +218,7 @@ pub(crate) trait FloatSIMDUtils { fn all_finite(self) -> bool; type Mask; - fn finite_mask(self) -> Self::Mask; fn gt_mask(self, other: Self) -> Self::Mask; - fn ge_mask(self, other: Self) -> Self::Mask; // Decrease all lanes where the mask is `true` to the next lower value // representable by the floating-point type. At least one of the lanes @@ -292,21 +290,11 @@ macro_rules! scalar_float_impl { self.is_finite() } - #[inline(always)] - fn finite_mask(self) -> Self::Mask { - self.is_finite() - } - #[inline(always)] fn gt_mask(self, other: Self) -> Self::Mask { self > other } - #[inline(always)] - fn ge_mask(self, other: Self) -> Self::Mask { - self >= other - } - #[inline(always)] fn decrease_masked(self, mask: Self::Mask) -> Self { debug_assert!(mask, "At least one lane must be set"); @@ -368,21 +356,11 @@ macro_rules! simd_impl { self.is_finite().all() } - #[inline(always)] - fn finite_mask(self) -> Self::Mask { - self.is_finite() - } - #[inline(always)] fn gt_mask(self, other: Self) -> Self::Mask { self.simd_gt(other) } - #[inline(always)] - fn ge_mask(self, other: Self) -> Self::Mask { - self.simd_ge(other) - } - #[inline(always)] fn decrease_masked(self, mask: Self::Mask) -> Self { // Casting a mask into ints will produce all bits set for From f576a93bb87d8b939a7ea36341551d7856e9cb2f Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Mon, 24 Jun 2024 12:40:54 +0100 Subject: [PATCH 7/8] Document that UniformFloat samples may equal high --- src/distributions/uniform.rs | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/distributions/uniform.rs b/src/distributions/uniform.rs index 09be2d87c0..5540b74e46 100644 --- a/src/distributions/uniform.rs +++ b/src/distributions/uniform.rs @@ -51,7 +51,8 @@ //! Those methods should include an assertion to check the range is valid (i.e. //! `low < high`). The example below merely wraps another back-end. //! -//! The `new`, `new_inclusive` and `sample_single` functions use arguments of +//! The `new`, `new_inclusive`, `sample_single` and `sample_single_inclusive` +//! functions use arguments of //! type `SampleBorrow` to support passing in values by reference or //! by value. In the implementation of these functions, you can choose to //! simply use the reference returned by [`SampleBorrow::borrow`], or you can choose @@ -207,6 +208,11 @@ impl Uniform { /// Create a new `Uniform` instance, which samples uniformly from the half /// open range `[low, high)` (excluding `high`). /// + /// For discrete types (e.g. integers), samples will always be strictly less + /// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`), + /// samples may equal `high` due to loss of precision but may not be + /// greater than `high`. + /// /// Fails if `low >= high`, or if `low`, `high` or the range `high - low` is /// non-finite. In release mode, only the range is checked. pub fn new(low: B1, high: B2) -> Result, Error> @@ -265,6 +271,11 @@ pub trait UniformSampler: Sized { /// Construct self, with inclusive lower bound and exclusive upper bound `[low, high)`. /// + /// For discrete types (e.g. integers), samples will always be strictly less + /// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`), + /// samples may equal `high` due to loss of precision but may not be + /// greater than `high`. + /// /// Usually users should not call this directly but prefer to use /// [`Uniform::new`]. fn new(low: B1, high: B2) -> Result @@ -287,6 +298,11 @@ pub trait UniformSampler: Sized { /// Sample a single value uniformly from a range with inclusive lower bound /// and exclusive upper bound `[low, high)`. /// + /// For discrete types (e.g. integers), samples will always be strictly less + /// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`), + /// samples may equal `high` due to loss of precision but may not be + /// greater than `high`. + /// /// By default this is implemented using /// `UniformSampler::new(low, high).sample(rng)`. However, for some types /// more optimal implementations for single usage may be provided via this From e5e6f45e2df65f305307e7aa4b6cd0d4b293fca0 Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Mon, 24 Jun 2024 12:54:50 +0100 Subject: [PATCH 8/8] Add CHANGELOG entry --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 18ce72a533..39e3c5c727 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update. - Move all benchmarks to new `benches` crate (#1439) - Annotate panicking methods with `#[track_caller]` (#1442, #1447) - Enable feature `small_rng` by default (#1455) +- Allow `UniformFloat::new` samples and `UniformFloat::sample_single` to yield `high` (#1462) ## [0.9.0-alpha.1] - 2024-03-18 - Add the `Slice::num_choices` method to the Slice distribution (#1402)