From 554d3311a32ce91305d644d2c4250fdf05f6699d Mon Sep 17 00:00:00 2001 From: Diggory Hardy Date: Tue, 19 Nov 2024 10:23:49 +0000 Subject: [PATCH] Fix calculated CDF in choose_two_weighted_indexed --- distr_test/tests/weighted.rs | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/distr_test/tests/weighted.rs b/distr_test/tests/weighted.rs index cb89287556..0d24c88909 100644 --- a/distr_test/tests/weighted.rs +++ b/distr_test/tests/weighted.rs @@ -145,20 +145,25 @@ fn choose_two_weighted_indexed() { let pmf1 = (0..num).map(|i| weight(i as i64)).collect::>(); let sum: f64 = pmf1.iter().sum(); - let sum_sq: f64 = pmf1.iter().map(|x| x * x).sum(); - let frac = 2.0 / (sum * sum - sum_sq); + let frac = 1.0 / sum; let mut ac = 0.0; let mut cdf = Vec::with_capacity(num * num); for a in 0..num { for b in 0..num { if a < b { - ac += pmf1[a] * pmf1[b]; + let pa = pmf1[a] * frac; + let pab = pa * pmf1[b] / (sum - pmf1[a]); + + let pb = pmf1[b] * frac; + let pba = pb * pmf1[a] / (sum - pmf1[b]); + + ac += pab + pba; } - cdf.push(ac * frac); + cdf.push(ac); } } - assert!((ac * frac - 1.0).abs() < 1e-9); + assert!((cdf.last().unwrap() - 1.0).abs() < 1e-9); let cdf = |i| { if i < 0 { @@ -173,9 +178,9 @@ fn choose_two_weighted_indexed() { test_weights(100, |_| 1.0); test_weights(100, |i| ((i + 1) as f64).ln()); - // test_weights(100, |i| i as f64); - // test_weights(100, |i| (i as f64).powi(3)); - // test_weights(100, |i| 1.0 / ((i + 1) as f64)); + test_weights(100, |i| i as f64); + test_weights(100, |i| (i as f64).powi(3)); + test_weights(100, |i| 1.0 / ((i + 1) as f64)); } #[test]