Skip to content

Commit 7471609

Browse files
authored
Merge pull request #37 from ncsa/issue-23
Issue 23: Expand test for simple_rng
2 parents f0c0c54 + 5e654cc commit 7471609

File tree

1 file changed

+108
-84
lines changed

1 file changed

+108
-84
lines changed

simple_rng/src/lib.rs

+108-84
Original file line numberDiff line numberDiff line change
@@ -107,33 +107,34 @@ impl Rng {
107107
self.s2
108108
}
109109

110-
/// Returns a bool with a probability `frac` of being true.
110+
/// Returns a bool with a probability `p` of being true.
111111
///
112112
/// # Panics
113113
///
114-
/// If `frac` does *not* fall in [0, 1] range.
115-
pub fn gen_bool(&mut self, frac: f64) -> bool {
116-
// Uses a Bernoulli distribution to generate a fractional probability
117-
// Then an input from the RNG to sample
118-
//
119-
// This method is lifted from rand 0.85, but with restrictions edited so we can use our
120-
// custom RNG
121-
if !(0.0..1.0).contains(&frac) {
122-
if frac == 1.0 {
114+
/// If `p` does *not* fall in range [0, 1].
115+
116+
// This method is lifted from rand 0.85, but with restrictions edited so we can use our
117+
// custom RNG
118+
pub fn gen_bool(&mut self, p: f64) -> bool {
119+
if !(0.0..1.0).contains(&p) {
120+
if p == 1.0 {
123121
return true;
124122
}
125-
panic!("Invalid frac for gen_bool {} (must be in [0.0, 1.0)", frac)
123+
panic!("p={:?} is outside range [0.0, 1.0]", p);
126124
}
125+
127126
// This is just `2.0.powi(64)`, but written this way because it is not available
128-
// in `no_std` mode. (from rand 0.8.5 docs). We used u64 max + 1, the equivalent
129-
let p_int = (frac * (u64::MAX as f64 + 1.0)) as u64;
127+
// in `no_std` mode. (from rand 0.8.5 docs). We used u64 max + 1, the equivalent.
128+
let scale = u64::MAX as f64 + 1.0;
129+
130+
let p_int = (p * scale) as u64;
130131
let x = self.rand_u64();
131132
x < p_int
132133
}
133134

134135
/// Shuffles elements in a sequence in place.
135136
pub fn shuffle<T: Clone>(&mut self, a: &mut Vec<T>) {
136-
for i in (0..a.len()).rev() {
137+
for i in (1..a.len()).rev() {
137138
let j = (self.random() * i as f64).floor() as usize;
138139
let temp = a[i].clone();
139140
a[i] = a[j].clone();
@@ -160,17 +161,10 @@ impl Rng {
160161
(ret_num % (u32::MAX as u64)) as u32
161162
}
162163

163-
/// Picks a element from a sequence at random.
164+
/// Picks an element from a sequence at random.
164165
pub fn choose<T: Clone>(&mut self, a: &[T]) -> T {
165-
// Randomly select based on which calculation comes up as 0 first
166-
// since i = 0 will force j = 0, this will always return an element
167-
for i in (0..=(a.len() - 1)).rev() {
168-
let j = (self.random() * i as f64).floor() as usize;
169-
if j == 0 {
170-
return a[i].clone();
171-
}
172-
}
173-
a[0].clone()
166+
let i = (self.random() * (a.len() - 1) as f64).floor() as usize;
167+
a[i].clone()
174168
}
175169
}
176170

@@ -223,26 +217,28 @@ pub struct DiscreteDistribution {
223217
}
224218

225219
impl DiscreteDistribution {
226-
/// Constructs a new discrete distribution.
227-
pub fn new<T>(w_vec: &Vec<T>, degenerate: bool) -> Self
220+
/// Constructs a new discrete distribution with the probability masses defined by `prob_mass`.
221+
pub fn new<T>(prob_mass: &[T], degenerate: bool) -> Self
228222
where
229223
f64: From<T>,
230224
T: Copy,
231-
T: Into<f64> + Copy,
232225
{
233-
// let's first convert weights to f64
234-
let mut w_vec_64: Vec<f64> = Vec::with_capacity(w_vec.len());
235-
for number in w_vec {
236-
w_vec_64.push(f64::from(*number).into());
237-
}
238226
let cumulative_probability = if !degenerate {
239-
let sum_weights: f64 = w_vec_64.iter().sum();
240-
let mut normalized_weights = Vec::with_capacity(w_vec.len());
241-
// we no longer need the w_vec_64 after this, so we consume it
242-
for weight in w_vec_64 {
243-
normalized_weights.push(weight / sum_weights);
244-
}
245-
cumulative_sum(&mut normalized_weights)
227+
// Convert probability masses to floats.
228+
let mut prob_mass: Vec<f64> = prob_mass.iter().map(|&w| f64::from(w)).collect();
229+
230+
// Normalize them.
231+
let sum: f64 = prob_mass.iter().sum();
232+
prob_mass = prob_mass.iter().map(|&x| x / sum).collect();
233+
234+
// Calculate the their cumulative sum.
235+
prob_mass
236+
.iter()
237+
.scan(0.0, |acc, &x| {
238+
*acc += x;
239+
Some(*acc)
240+
})
241+
.collect::<Vec<_>>()
246242
} else {
247243
vec![1.0]
248244
};
@@ -275,50 +271,47 @@ impl DiscreteDistribution {
275271
}
276272
}
277273

278-
// Calculates cumulative sum.
279-
fn cumulative_sum(a: &[f64]) -> Vec<f64> {
280-
let mut acc = 0.0;
281-
let mut cumvec = Vec::with_capacity(a.len());
282-
for x in a {
283-
acc += *x;
284-
cumvec.push(acc);
285-
}
286-
cumvec
287-
}
288-
289274
#[cfg(test)]
290275
mod tests {
291276
use super::*;
277+
292278
#[test]
293-
fn test_discrete_distribution() {
294-
let weights: Vec<f64> = vec![1.1, 2.0, 1.0, 8.0, 0.2, 2.0];
295-
let d = DiscreteDistribution::new(&weights, false);
279+
fn test_discrete_distribution_sample() {
296280
let mut rng = Rng::from_seed(vec![
297281
"Hello".to_string(),
298282
"Cruel".to_string(),
299283
"World".to_string(),
300284
]);
301-
let x = d.sample(&mut rng);
302-
assert_eq!(x, 5);
303-
let x = d.sample(&mut rng);
304-
assert_eq!(x, 3);
305-
let x = d.sample(&mut rng);
306-
assert_eq!(x, 3);
307-
let x = d.sample(&mut rng);
308-
assert_eq!(x, 1);
285+
let weights: Vec<f64> = vec![1.1, 2.0, 1.0, 8.0, 0.2, 2.0];
286+
let d = DiscreteDistribution::new(&weights, false);
287+
assert_eq!(d.sample(&mut rng), 5);
288+
assert_eq!(d.sample(&mut rng), 3);
289+
assert_eq!(d.sample(&mut rng), 3);
309290
}
310291

311292
#[test]
312-
fn test_gen_bool() {
293+
#[should_panic]
294+
fn test_normal_distribution_panic() {
295+
let _n = NormalDistribution::new(0.0, 0.0);
296+
}
297+
298+
#[test]
299+
fn test_normal_distribution_inverse_cdf() {
300+
let n = NormalDistribution::new(0.0, 1.0);
301+
assert_eq!(n.inverse_cdf(0.5), 0.0);
302+
}
303+
304+
#[test]
305+
fn test_normal_distribution_sample() {
313306
let mut rng = Rng::from_seed(vec![
314307
"Hello".to_string(),
315308
"Cruel".to_string(),
316309
"World".to_string(),
317310
]);
318-
assert_eq!(rng.gen_bool(0.5), false);
319-
assert_eq!(rng.gen_bool(0.5), true);
320-
assert_eq!(rng.gen_bool(0.5), false);
321-
assert_eq!(rng.gen_bool(0.5), true);
311+
let n = NormalDistribution::new(0.0, 1.0);
312+
assert_eq!(n.sample(&mut rng), 1.181326790364034);
313+
assert_eq!(n.sample(&mut rng), -0.005968986117806898);
314+
assert_eq!(n.sample(&mut rng), 0.11272207687563508);
322315
}
323316

324317
#[test]
@@ -328,41 +321,47 @@ mod tests {
328321
"cruel".to_string(),
329322
"world".to_string(),
330323
]);
331-
let test = rng.random();
332-
assert_eq!(test, 0.8797469853889197);
333-
let test2 = rng.random();
334-
assert_eq!(test2, 0.5001547893043607);
335-
let test3 = rng.random();
336-
assert_eq!(test3, 0.6195652585010976);
324+
assert_eq!(rng.random(), 0.8797469853889197);
325+
assert_eq!(rng.random(), 0.5001547893043607);
326+
assert_eq!(rng.random(), 0.6195652585010976);
327+
}
328+
329+
#[test]
330+
fn test_gen_bool() {
331+
let mut rng = Rng::from_seed(vec![
332+
"Hello".to_string(),
333+
"Cruel".to_string(),
334+
"World".to_string(),
335+
]);
336+
assert_eq!(rng.gen_bool(0.5), false);
337+
assert_eq!(rng.gen_bool(0.5), true);
338+
assert_eq!(rng.gen_bool(0.5), false);
339+
assert_eq!(rng.gen_bool(0.5), true);
337340
}
338341

339342
#[test]
340343
fn test_shuffle() {
341-
let mut my_vec = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
344+
let mut vec = vec![1, 2, 3, 4, 5, 6, 7, 8, 9];
342345
let mut rng = Rng::from_seed(vec![
343346
"hello".to_string(),
344347
"cruel".to_string(),
345348
"world".to_string(),
346349
]);
347-
rng.shuffle(&mut my_vec);
348-
assert_eq!(my_vec, vec![5, 7, 6, 3, 2, 1, 9, 4, 8]);
349-
rng.shuffle(&mut my_vec);
350-
assert_eq!(my_vec, vec![4, 9, 1, 8, 3, 7, 2, 6, 5]);
350+
rng.shuffle(&mut vec);
351+
assert_eq!(vec, vec![5, 7, 6, 3, 2, 1, 9, 4, 8]);
352+
rng.shuffle(&mut vec);
353+
assert_eq!(vec, vec![8, 1, 4, 9, 7, 3, 6, 5, 2]);
351354
}
352355

353356
#[test]
354-
fn test_range() {
355-
let min = 0;
356-
let max = 10;
357+
fn test_range_i64() {
357358
let mut rng = Rng::from_seed(vec![
358359
"hello".to_string(),
359360
"cruel".to_string(),
360361
"world".to_string(),
361362
]);
362-
let num = rng.range_i64(min, max);
363-
assert_eq!(num, 8);
364-
let num2 = rng.range_i64(min, max);
365-
assert_eq!(num2, 5);
363+
assert_eq!(rng.range_i64(0, 10), 8);
364+
assert_eq!(rng.range_i64(0, 10), 5);
366365
assert_eq!(rng.range_i64(-3, 5), 1);
367366
}
368367

@@ -377,4 +376,29 @@ mod tests {
377376
assert_eq!(rng.rand_u64(), 9226227395537666048);
378377
assert_eq!(rng.rand_u64(), 11428961760531447808);
379378
}
379+
380+
#[test]
381+
fn test_rand_u32() {
382+
let mut rng = Rng::from_seed(vec![
383+
"hello".to_string(),
384+
"cruel".to_string(),
385+
"world".to_string(),
386+
]);
387+
assert_eq!(rng.rand_u32(), 3778484531);
388+
assert_eq!(rng.rand_u32(), 2148148463);
389+
assert_eq!(rng.rand_u32(), 2661012523);
390+
}
391+
392+
#[test]
393+
fn test_choose() {
394+
let mut rng = Rng::from_seed(vec![
395+
"hello".to_string(),
396+
"cruel".to_string(),
397+
"world".to_string(),
398+
]);
399+
let vec: Vec<u32> = (0..10).collect();
400+
assert_eq!(rng.choose(&vec), 7);
401+
assert_eq!(rng.choose(&vec), 4);
402+
assert_eq!(rng.choose(&vec), 5);
403+
}
380404
}

0 commit comments

Comments
 (0)