diff --git a/rand_distr/Cargo.toml b/rand_distr/Cargo.toml index ff1400d6e6..2c1d6753f2 100644 --- a/rand_distr/Cargo.toml +++ b/rand_distr/Cargo.toml @@ -21,12 +21,13 @@ default = ["std"] std = ["alloc", "rand/std"] alloc = ["rand/alloc"] std_math = ["num-traits/std"] -serde1 = ["serde", "rand/serde1"] +serde1 = ["serde", "serde_with", "rand/serde1"] [dependencies] rand = { path = "..", version = "0.9.0", default-features = false } num-traits = { version = "0.2", default-features = false, features = ["libm"] } serde = { version = "1.0.103", features = ["derive"], optional = true } +serde_with = { version = "1.14.0", optional = true } [dev-dependencies] rand_pcg = { version = "0.4.0", path = "../rand_pcg" } diff --git a/rand_distr/src/dirichlet.rs b/rand_distr/src/dirichlet.rs index 462e9f901f..72df35a203 100644 --- a/rand_distr/src/dirichlet.rs +++ b/rand_distr/src/dirichlet.rs @@ -13,6 +13,8 @@ use num_traits::Float; use crate::{Distribution, Exp1, Gamma, Open01, StandardNormal}; use rand::Rng; use core::fmt; +#[cfg(feature = "serde1")] +use serde_with::serde_as; /// The Dirichlet distribution `Dirichlet(alpha)`. /// @@ -31,8 +33,8 @@ use core::fmt; /// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples); /// ``` #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] +#[cfg_attr(feature = "serde1", serde_as)] #[derive(Clone, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Dirichlet where F: Float, @@ -41,6 +43,7 @@ where Open01: Distribution, { /// Concentration parameters (alpha) + #[cfg_attr(feature = "serde1", serde_as(as = "[_; N]"))] alpha: [F; N], }