Skip to content

Commit

Permalink
Fix calculated CDF in choose_two_weighted_indexed
Browse files Browse the repository at this point in the history
  • Loading branch information
dhardy committed Nov 19, 2024
1 parent 16a16c6 commit 554d331
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions distr_test/tests/weighted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,20 +145,25 @@ fn choose_two_weighted_indexed() {

let pmf1 = (0..num).map(|i| weight(i as i64)).collect::<Vec<f64>>();
let sum: f64 = pmf1.iter().sum();
let sum_sq: f64 = pmf1.iter().map(|x| x * x).sum();
let frac = 2.0 / (sum * sum - sum_sq);
let frac = 1.0 / sum;

let mut ac = 0.0;
let mut cdf = Vec::with_capacity(num * num);
for a in 0..num {
for b in 0..num {
if a < b {
ac += pmf1[a] * pmf1[b];
let pa = pmf1[a] * frac;
let pab = pa * pmf1[b] / (sum - pmf1[a]);

let pb = pmf1[b] * frac;
let pba = pb * pmf1[a] / (sum - pmf1[b]);

ac += pab + pba;
}
cdf.push(ac * frac);
cdf.push(ac);
}
}
assert!((ac * frac - 1.0).abs() < 1e-9);
assert!((cdf.last().unwrap() - 1.0).abs() < 1e-9);

let cdf = |i| {
if i < 0 {
Expand All @@ -173,9 +178,9 @@ fn choose_two_weighted_indexed() {

test_weights(100, |_| 1.0);
test_weights(100, |i| ((i + 1) as f64).ln());
// test_weights(100, |i| i as f64);
// test_weights(100, |i| (i as f64).powi(3));
// test_weights(100, |i| 1.0 / ((i + 1) as f64));
test_weights(100, |i| i as f64);
test_weights(100, |i| (i as f64).powi(3));
test_weights(100, |i| 1.0 / ((i + 1) as f64));
}

#[test]
Expand Down

0 comments on commit 554d331

Please sign in to comment.