@@ -22,6 +22,7 @@ import scala.reflect.ClassTag
2222
2323import org .scalatest .FunSuite
2424
25+ import org .apache .commons .math3 .distribution .PoissonDistribution
2526import org .apache .spark ._
2627import org .apache .spark .SparkContext ._
2728import org .apache .spark .rdd ._
@@ -494,56 +495,84 @@ class RDDSuite extends FunSuite with SharedSparkContext {
494495 assert(sortedTopK === nums.sorted(ord).take(5 ))
495496 }
496497
498+ test(" computeFraction" ) {
499+ // test that the computed fraction guarantees enough datapoints in the sample with a failure rate <= 0.0001
500+ val data = new EmptyRDD [Int ](sc)
501+ val n = 100000
502+
503+ for (s <- 1 to 15 ) {
504+ val frac = data.computeFraction(s, n, true )
505+ val qpois = new PoissonDistribution (frac * n)
506+ assert(qpois.inverseCumulativeProbability(0.0001 ) >= s, " Computed fraction is too low" )
507+ }
508+ for (s <- 1 to 15 ) {
509+ val frac = data.computeFraction(s, n, false )
510+ val qpois = new PoissonDistribution (frac * n)
511+ assert(qpois.inverseCumulativeProbability(0.0001 ) >= s, " Computed fraction is too low" )
512+ }
513+ for (s <- List (1 , 10 , 100 , 1000 )) {
514+ val frac = data.computeFraction(s, n, true )
515+ val qpois = new PoissonDistribution (frac * n)
516+ assert(qpois.inverseCumulativeProbability(0.0001 ) >= s, " Computed fraction is too low" )
517+ }
518+ for (s <- List (1 , 10 , 100 , 1000 )) {
519+ val frac = data.computeFraction(s, n, false )
520+ val qpois = new PoissonDistribution (frac * n)
521+ assert(qpois.inverseCumulativeProbability(0.0001 ) >= s, " Computed fraction is too low" )
522+ }
523+ }
524+
497525 test(" takeSample" ) {
498- val data = sc.parallelize(1 to 100 , 2 )
526+ val n = 1000000
527+ val data = sc.parallelize(1 to n, 2 )
499528
500529 for (num <- List (5 , 20 , 100 )) {
501530 val sample = data.takeSample(withReplacement= false , num= num)
502531 assert(sample.size === num) // Got exactly num elements
503532 assert(sample.toSet.size === num) // Elements are distinct
504- assert(sample.forall(x => 1 <= x && x <= 100 ), " elements not in [1, 100]" )
533+ assert(sample.forall(x => 1 <= x && x <= n ), " elements not in [1, 100]" )
505534 }
506535 for (seed <- 1 to 5 ) {
507536 val sample = data.takeSample(withReplacement= false , 20 , seed)
508537 assert(sample.size === 20 ) // Got exactly 20 elements
509538 assert(sample.toSet.size === 20 ) // Elements are distinct
510- assert(sample.forall(x => 1 <= x && x <= 100 ), " elements not in [1, 100]" )
539+ assert(sample.forall(x => 1 <= x && x <= n ), " elements not in [1, 100]" )
511540 }
512541 for (seed <- 1 to 5 ) {
513- val sample = data.takeSample(withReplacement= false , 200 , seed)
542+ val sample = data.takeSample(withReplacement= false , 100 , seed)
514543 assert(sample.size === 100 ) // Got only 100 elements
515544 assert(sample.toSet.size === 100 ) // Elements are distinct
516- assert(sample.forall(x => 1 <= x && x <= 100 ), " elements not in [1, 100]" )
545+ assert(sample.forall(x => 1 <= x && x <= n ), " elements not in [1, 100]" )
517546 }
518547 for (seed <- 1 to 5 ) {
519548 val sample = data.takeSample(withReplacement= true , 20 , seed)
520549 assert(sample.size === 20 ) // Got exactly 20 elements
521- assert(sample.forall(x => 1 <= x && x <= 100 ), " elements not in [1, 100]" )
550+ assert(sample.forall(x => 1 <= x && x <= n ), " elements not in [1, 100]" )
522551 }
523552 {
524553 val sample = data.takeSample(withReplacement= true , num= 20 )
525554 assert(sample.size === 20 ) // Got exactly 100 elements
526555 assert(sample.toSet.size <= 20 , " sampling with replacement returned all distinct elements" )
527- assert(sample.forall(x => 1 <= x && x <= 100 ), " elements not in [1, 100]" )
556+ assert(sample.forall(x => 1 <= x && x <= n ), " elements not in [1, 100]" )
528557 }
529558 {
530- val sample = data.takeSample(withReplacement= true , num= 100 )
531- assert(sample.size === 100 ) // Got exactly 100 elements
559+ val sample = data.takeSample(withReplacement= true , num= n )
560+ assert(sample.size === n ) // Got exactly 100 elements
532561 // Chance of getting all distinct elements is astronomically low, so test we got < 100
533- assert(sample.toSet.size < 100 , " sampling with replacement returned all distinct elements" )
534- assert(sample.forall(x => 1 <= x && x <= 100 ), " elements not in [1, 100]" )
562+ assert(sample.toSet.size < n , " sampling with replacement returned all distinct elements" )
563+ assert(sample.forall(x => 1 <= x && x <= n ), " elements not in [1, 100]" )
535564 }
536565 for (seed <- 1 to 5 ) {
537- val sample = data.takeSample(withReplacement= true , 100 , seed)
538- assert(sample.size === 100 ) // Got exactly 100 elements
566+ val sample = data.takeSample(withReplacement= true , n , seed)
567+ assert(sample.size === n ) // Got exactly 100 elements
539568 // Chance of getting all distinct elements is astronomically low, so test we got < 100
540- assert(sample.toSet.size < 100 , " sampling with replacement returned all distinct elements" )
569+ assert(sample.toSet.size < n , " sampling with replacement returned all distinct elements" )
541570 }
542571 for (seed <- 1 to 5 ) {
543- val sample = data.takeSample(withReplacement= true , 200 , seed)
544- assert(sample.size === 200 ) // Got exactly 200 elements
572+ val sample = data.takeSample(withReplacement= true , 2 * n , seed)
573+ assert(sample.size === 2 * n ) // Got exactly 200 elements
545574 // Chance of getting all distinct elements is still quite low, so test we got < 100
546- assert(sample.toSet.size < 100 , " sampling with replacement returned all distinct elements" )
575+ assert(sample.toSet.size < n , " sampling with replacement returned all distinct elements" )
547576 }
548577 }
549578
0 commit comments