diff --git a/rand_distr/src/poisson.rs b/rand_distr/src/poisson.rs index 332d1fc06c..26e7712b2c 100644 --- a/rand_distr/src/poisson.rs +++ b/rand_distr/src/poisson.rs @@ -45,14 +45,10 @@ use rand::Rng; /// ``` #[derive(Clone, Copy, Debug, PartialEq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] -pub struct Poisson +pub struct Poisson(Method) where F: Float + FloatConst, - Standard: Distribution, -{ - lambda: F, - method: Method, -} + Standard: Distribution; /// Error type returned from [`Poisson::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -82,6 +78,7 @@ impl std::error::Error for Error {} pub(crate) struct KnuthMethod { exp_lambda: F, } + impl KnuthMethod { pub(crate) fn new(lambda: F) -> Self { KnuthMethod { @@ -93,10 +90,26 @@ impl KnuthMethod { #[derive(Clone, Copy, Debug, PartialEq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] struct RejectionMethod { + lambda: F, log_lambda: F, sqrt_2lambda: F, magic_val: F, } + +impl RejectionMethod { + pub(crate) fn new(lambda: F) -> Self { + let log_lambda = lambda.ln(); + let sqrt_2lambda = (F::from(2.0).unwrap() * lambda).sqrt(); + let magic_val = lambda * log_lambda - crate::utils::log_gamma(F::one() + lambda); + RejectionMethod { + lambda, + log_lambda, + sqrt_2lambda, + magic_val, + } + } +} + #[derive(Clone, Copy, Debug, PartialEq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] enum Method { @@ -132,17 +145,10 @@ where let method = if lambda < F::from(12.0).unwrap() { Method::Knuth(KnuthMethod::new(lambda)) } else { - let log_lambda = lambda.ln(); - let sqrt_2lambda = (F::from(2.0).unwrap() * lambda).sqrt(); - let magic_val = lambda * log_lambda - crate::utils::log_gamma(F::one() + lambda); - Method::Rejection(RejectionMethod { - log_lambda, - sqrt_2lambda, - magic_val, - }) + Method::Rejection(RejectionMethod::new(lambda)) }; - Ok(Poisson { lambda, method }) + Ok(Poisson(method)) } } @@ -161,12 +167,13 @@ where result - F::one() } } -impl RejectionMethod + +impl Distribution for RejectionMethod where F: Float + FloatConst, Standard: Distribution, { - fn sample(&self, lambda: F, rng: &mut R) -> F { + fn sample(&self, rng: &mut R) -> F { // The algorithm from Numerical Recipes in C // we use the Cauchy distribution as the comparison distribution @@ -181,7 +188,7 @@ where // draw from the Cauchy distribution comp_dev = rng.sample(cauchy); // shift the peak of the comparison distribution - result = self.sqrt_2lambda * comp_dev + lambda; + result = self.sqrt_2lambda * comp_dev + self.lambda; // repeat the drawing until we are in the range of possible values if result >= F::zero() { break; @@ -210,6 +217,7 @@ where result } } + impl Distribution for Poisson where F: Float + FloatConst, @@ -217,9 +225,9 @@ where { #[inline] fn sample(&self, rng: &mut R) -> F { - match &self.method { + match &self.0 { Method::Knuth(method) => method.sample(rng), - Method::Rejection(method) => method.sample(self.lambda, rng), + Method::Rejection(method) => method.sample(rng), } } }