@@ -107,33 +107,34 @@ impl Rng {
107
107
self . s2
108
108
}
109
109
110
- /// Returns a bool with a probability `frac ` of being true.
110
+ /// Returns a bool with a probability `p ` of being true.
111
111
///
112
112
/// # Panics
113
113
///
114
- /// If `frac` does *not* fall in [0, 1] range.
115
- pub fn gen_bool ( & mut self , frac : f64 ) -> bool {
116
- // Uses a Bernoulli distribution to generate a fractional probability
117
- // Then an input from the RNG to sample
118
- //
119
- // This method is lifted from rand 0.85, but with restrictions edited so we can use our
120
- // custom RNG
121
- if !( 0.0 ..1.0 ) . contains ( & frac) {
122
- if frac == 1.0 {
114
+ /// If `p` does *not* fall in range [0, 1].
115
+
116
+ // This method is lifted from rand 0.85, but with restrictions edited so we can use our
117
+ // custom RNG
118
+ pub fn gen_bool ( & mut self , p : f64 ) -> bool {
119
+ if !( 0.0 ..1.0 ) . contains ( & p) {
120
+ if p == 1.0 {
123
121
return true ;
124
122
}
125
- panic ! ( "Invalid frac for gen_bool {} (must be in [0.0, 1.0) " , frac )
123
+ panic ! ( "p={:?} is outside range [0.0, 1.0] " , p ) ;
126
124
}
125
+
127
126
// This is just `2.0.powi(64)`, but written this way because it is not available
128
- // in `no_std` mode. (from rand 0.8.5 docs). We used u64 max + 1, the equivalent
129
- let p_int = ( frac * ( u64:: MAX as f64 + 1.0 ) ) as u64 ;
127
+ // in `no_std` mode. (from rand 0.8.5 docs). We used u64 max + 1, the equivalent.
128
+ let scale = u64:: MAX as f64 + 1.0 ;
129
+
130
+ let p_int = ( p * scale) as u64 ;
130
131
let x = self . rand_u64 ( ) ;
131
132
x < p_int
132
133
}
133
134
134
135
/// Shuffles elements in a sequence in place.
135
136
pub fn shuffle < T : Clone > ( & mut self , a : & mut Vec < T > ) {
136
- for i in ( 0 ..a. len ( ) ) . rev ( ) {
137
+ for i in ( 1 ..a. len ( ) ) . rev ( ) {
137
138
let j = ( self . random ( ) * i as f64 ) . floor ( ) as usize ;
138
139
let temp = a[ i] . clone ( ) ;
139
140
a[ i] = a[ j] . clone ( ) ;
@@ -160,17 +161,10 @@ impl Rng {
160
161
( ret_num % ( u32:: MAX as u64 ) ) as u32
161
162
}
162
163
163
- /// Picks a element from a sequence at random.
164
+ /// Picks an element from a sequence at random.
164
165
pub fn choose < T : Clone > ( & mut self , a : & [ T ] ) -> T {
165
- // Randomly select based on which calculation comes up as 0 first
166
- // since i = 0 will force j = 0, this will always return an element
167
- for i in ( 0 ..=( a. len ( ) - 1 ) ) . rev ( ) {
168
- let j = ( self . random ( ) * i as f64 ) . floor ( ) as usize ;
169
- if j == 0 {
170
- return a[ i] . clone ( ) ;
171
- }
172
- }
173
- a[ 0 ] . clone ( )
166
+ let i = ( self . random ( ) * ( a. len ( ) - 1 ) as f64 ) . floor ( ) as usize ;
167
+ a[ i] . clone ( )
174
168
}
175
169
}
176
170
@@ -223,26 +217,28 @@ pub struct DiscreteDistribution {
223
217
}
224
218
225
219
impl DiscreteDistribution {
226
- /// Constructs a new discrete distribution.
227
- pub fn new < T > ( w_vec : & Vec < T > , degenerate : bool ) -> Self
220
+ /// Constructs a new discrete distribution with the probability masses defined by `prob_mass` .
221
+ pub fn new < T > ( prob_mass : & [ T ] , degenerate : bool ) -> Self
228
222
where
229
223
f64 : From < T > ,
230
224
T : Copy ,
231
- T : Into < f64 > + Copy ,
232
225
{
233
- // let's first convert weights to f64
234
- let mut w_vec_64: Vec < f64 > = Vec :: with_capacity ( w_vec. len ( ) ) ;
235
- for number in w_vec {
236
- w_vec_64. push ( f64:: from ( * number) . into ( ) ) ;
237
- }
238
226
let cumulative_probability = if !degenerate {
239
- let sum_weights: f64 = w_vec_64. iter ( ) . sum ( ) ;
240
- let mut normalized_weights = Vec :: with_capacity ( w_vec. len ( ) ) ;
241
- // we no longer need the w_vec_64 after this, so we consume it
242
- for weight in w_vec_64 {
243
- normalized_weights. push ( weight / sum_weights) ;
244
- }
245
- cumulative_sum ( & mut normalized_weights)
227
+ // Convert probability masses to floats.
228
+ let mut prob_mass: Vec < f64 > = prob_mass. iter ( ) . map ( |& w| f64:: from ( w) ) . collect ( ) ;
229
+
230
+ // Normalize them.
231
+ let sum: f64 = prob_mass. iter ( ) . sum ( ) ;
232
+ prob_mass = prob_mass. iter ( ) . map ( |& x| x / sum) . collect ( ) ;
233
+
234
+ // Calculate the their cumulative sum.
235
+ prob_mass
236
+ . iter ( )
237
+ . scan ( 0.0 , |acc, & x| {
238
+ * acc += x;
239
+ Some ( * acc)
240
+ } )
241
+ . collect :: < Vec < _ > > ( )
246
242
} else {
247
243
vec ! [ 1.0 ]
248
244
} ;
@@ -275,50 +271,47 @@ impl DiscreteDistribution {
275
271
}
276
272
}
277
273
278
- // Calculates cumulative sum.
279
- fn cumulative_sum ( a : & [ f64 ] ) -> Vec < f64 > {
280
- let mut acc = 0.0 ;
281
- let mut cumvec = Vec :: with_capacity ( a. len ( ) ) ;
282
- for x in a {
283
- acc += * x;
284
- cumvec. push ( acc) ;
285
- }
286
- cumvec
287
- }
288
-
289
274
#[ cfg( test) ]
290
275
mod tests {
291
276
use super :: * ;
277
+
292
278
#[ test]
293
- fn test_discrete_distribution ( ) {
294
- let weights: Vec < f64 > = vec ! [ 1.1 , 2.0 , 1.0 , 8.0 , 0.2 , 2.0 ] ;
295
- let d = DiscreteDistribution :: new ( & weights, false ) ;
279
+ fn test_discrete_distribution_sample ( ) {
296
280
let mut rng = Rng :: from_seed ( vec ! [
297
281
"Hello" . to_string( ) ,
298
282
"Cruel" . to_string( ) ,
299
283
"World" . to_string( ) ,
300
284
] ) ;
301
- let x = d. sample ( & mut rng) ;
302
- assert_eq ! ( x, 5 ) ;
303
- let x = d. sample ( & mut rng) ;
304
- assert_eq ! ( x, 3 ) ;
305
- let x = d. sample ( & mut rng) ;
306
- assert_eq ! ( x, 3 ) ;
307
- let x = d. sample ( & mut rng) ;
308
- assert_eq ! ( x, 1 ) ;
285
+ let weights: Vec < f64 > = vec ! [ 1.1 , 2.0 , 1.0 , 8.0 , 0.2 , 2.0 ] ;
286
+ let d = DiscreteDistribution :: new ( & weights, false ) ;
287
+ assert_eq ! ( d. sample( & mut rng) , 5 ) ;
288
+ assert_eq ! ( d. sample( & mut rng) , 3 ) ;
289
+ assert_eq ! ( d. sample( & mut rng) , 3 ) ;
309
290
}
310
291
311
292
#[ test]
312
- fn test_gen_bool ( ) {
293
+ #[ should_panic]
294
+ fn test_normal_distribution_panic ( ) {
295
+ let _n = NormalDistribution :: new ( 0.0 , 0.0 ) ;
296
+ }
297
+
298
+ #[ test]
299
+ fn test_normal_distribution_inverse_cdf ( ) {
300
+ let n = NormalDistribution :: new ( 0.0 , 1.0 ) ;
301
+ assert_eq ! ( n. inverse_cdf( 0.5 ) , 0.0 ) ;
302
+ }
303
+
304
+ #[ test]
305
+ fn test_normal_distribution_sample ( ) {
313
306
let mut rng = Rng :: from_seed ( vec ! [
314
307
"Hello" . to_string( ) ,
315
308
"Cruel" . to_string( ) ,
316
309
"World" . to_string( ) ,
317
310
] ) ;
318
- assert_eq ! ( rng . gen_bool ( 0.5 ) , false ) ;
319
- assert_eq ! ( rng . gen_bool ( 0.5 ) , true ) ;
320
- assert_eq ! ( rng . gen_bool ( 0.5 ) , false ) ;
321
- assert_eq ! ( rng . gen_bool ( 0.5 ) , true ) ;
311
+ let n = NormalDistribution :: new ( 0.0 , 1.0 ) ;
312
+ assert_eq ! ( n . sample ( & mut rng ) , 1.181326790364034 ) ;
313
+ assert_eq ! ( n . sample ( & mut rng ) , - 0.005968986117806898 ) ;
314
+ assert_eq ! ( n . sample ( & mut rng ) , 0.11272207687563508 ) ;
322
315
}
323
316
324
317
#[ test]
@@ -328,41 +321,47 @@ mod tests {
328
321
"cruel" . to_string( ) ,
329
322
"world" . to_string( ) ,
330
323
] ) ;
331
- let test = rng. random ( ) ;
332
- assert_eq ! ( test, 0.8797469853889197 ) ;
333
- let test2 = rng. random ( ) ;
334
- assert_eq ! ( test2, 0.5001547893043607 ) ;
335
- let test3 = rng. random ( ) ;
336
- assert_eq ! ( test3, 0.6195652585010976 ) ;
324
+ assert_eq ! ( rng. random( ) , 0.8797469853889197 ) ;
325
+ assert_eq ! ( rng. random( ) , 0.5001547893043607 ) ;
326
+ assert_eq ! ( rng. random( ) , 0.6195652585010976 ) ;
327
+ }
328
+
329
+ #[ test]
330
+ fn test_gen_bool ( ) {
331
+ let mut rng = Rng :: from_seed ( vec ! [
332
+ "Hello" . to_string( ) ,
333
+ "Cruel" . to_string( ) ,
334
+ "World" . to_string( ) ,
335
+ ] ) ;
336
+ assert_eq ! ( rng. gen_bool( 0.5 ) , false ) ;
337
+ assert_eq ! ( rng. gen_bool( 0.5 ) , true ) ;
338
+ assert_eq ! ( rng. gen_bool( 0.5 ) , false ) ;
339
+ assert_eq ! ( rng. gen_bool( 0.5 ) , true ) ;
337
340
}
338
341
339
342
#[ test]
340
343
fn test_shuffle ( ) {
341
- let mut my_vec = vec ! [ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ] ;
344
+ let mut vec = vec ! [ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 ] ;
342
345
let mut rng = Rng :: from_seed ( vec ! [
343
346
"hello" . to_string( ) ,
344
347
"cruel" . to_string( ) ,
345
348
"world" . to_string( ) ,
346
349
] ) ;
347
- rng. shuffle ( & mut my_vec ) ;
348
- assert_eq ! ( my_vec , vec![ 5 , 7 , 6 , 3 , 2 , 1 , 9 , 4 , 8 ] ) ;
349
- rng. shuffle ( & mut my_vec ) ;
350
- assert_eq ! ( my_vec , vec![ 4 , 9 , 1 , 8 , 3 , 7 , 2 , 6 , 5 ] ) ;
350
+ rng. shuffle ( & mut vec ) ;
351
+ assert_eq ! ( vec , vec![ 5 , 7 , 6 , 3 , 2 , 1 , 9 , 4 , 8 ] ) ;
352
+ rng. shuffle ( & mut vec ) ;
353
+ assert_eq ! ( vec , vec![ 8 , 1 , 4 , 9 , 7 , 3 , 6 , 5 , 2 ] ) ;
351
354
}
352
355
353
356
#[ test]
354
- fn test_range ( ) {
355
- let min = 0 ;
356
- let max = 10 ;
357
+ fn test_range_i64 ( ) {
357
358
let mut rng = Rng :: from_seed ( vec ! [
358
359
"hello" . to_string( ) ,
359
360
"cruel" . to_string( ) ,
360
361
"world" . to_string( ) ,
361
362
] ) ;
362
- let num = rng. range_i64 ( min, max) ;
363
- assert_eq ! ( num, 8 ) ;
364
- let num2 = rng. range_i64 ( min, max) ;
365
- assert_eq ! ( num2, 5 ) ;
363
+ assert_eq ! ( rng. range_i64( 0 , 10 ) , 8 ) ;
364
+ assert_eq ! ( rng. range_i64( 0 , 10 ) , 5 ) ;
366
365
assert_eq ! ( rng. range_i64( -3 , 5 ) , 1 ) ;
367
366
}
368
367
@@ -377,4 +376,29 @@ mod tests {
377
376
assert_eq ! ( rng. rand_u64( ) , 9226227395537666048 ) ;
378
377
assert_eq ! ( rng. rand_u64( ) , 11428961760531447808 ) ;
379
378
}
379
+
380
+ #[ test]
381
+ fn test_rand_u32 ( ) {
382
+ let mut rng = Rng :: from_seed ( vec ! [
383
+ "hello" . to_string( ) ,
384
+ "cruel" . to_string( ) ,
385
+ "world" . to_string( ) ,
386
+ ] ) ;
387
+ assert_eq ! ( rng. rand_u32( ) , 3778484531 ) ;
388
+ assert_eq ! ( rng. rand_u32( ) , 2148148463 ) ;
389
+ assert_eq ! ( rng. rand_u32( ) , 2661012523 ) ;
390
+ }
391
+
392
+ #[ test]
393
+ fn test_choose ( ) {
394
+ let mut rng = Rng :: from_seed ( vec ! [
395
+ "hello" . to_string( ) ,
396
+ "cruel" . to_string( ) ,
397
+ "world" . to_string( ) ,
398
+ ] ) ;
399
+ let vec: Vec < u32 > = ( 0 ..10 ) . collect ( ) ;
400
+ assert_eq ! ( rng. choose( & vec) , 7 ) ;
401
+ assert_eq ! ( rng. choose( & vec) , 4 ) ;
402
+ assert_eq ! ( rng. choose( & vec) , 5 ) ;
403
+ }
380
404
}
0 commit comments