Skip to content

Commit 99ecfa5

Browse files
vinodkc authored and Andrew Or committed
[SPARK-10575] [SPARK CORE] Wrapped RDD.takeSample with Scope
Remove return statements in RDD.takeSample and wrap it in withScope. Author: vinodkc <[email protected]> Author: vinodkc <[email protected]> Author: Vinod K C <[email protected]> Closes #8730 from vinodkc/fix_takesample_return.
1 parent a63cdc7 commit 99ecfa5

File tree

1 file changed

+31
-37
lines changed
  • core/src/main/scala/org/apache/spark/rdd

1 file changed

+31
-37
lines changed

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 31 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -469,50 +469,44 @@ abstract class RDD[T: ClassTag](
469469
* @param seed seed for the random number generator
470470
* @return sample of specified size in an array
471471
*/
472-
// TODO: rewrite this without return statements so we can wrap it in a scope
473472
def takeSample(
474473
withReplacement: Boolean,
475474
num: Int,
476-
seed: Long = Utils.random.nextLong): Array[T] = {
475+
seed: Long = Utils.random.nextLong): Array[T] = withScope {
477476
val numStDev = 10.0
478477

479-
if (num < 0) {
480-
throw new IllegalArgumentException("Negative number of elements requested")
481-
} else if (num == 0) {
482-
return new Array[T](0)
483-
}
484-
485-
val initialCount = this.count()
486-
if (initialCount == 0) {
487-
return new Array[T](0)
488-
}
489-
490-
val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt
491-
if (num > maxSampleSize) {
492-
throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " +
493-
s"$numStDev * math.sqrt(Int.MaxValue)")
494-
}
495-
496-
val rand = new Random(seed)
497-
if (!withReplacement && num >= initialCount) {
498-
return Utils.randomizeInPlace(this.collect(), rand)
499-
}
500-
501-
val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount,
502-
withReplacement)
503-
504-
var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
478+
require(num >= 0, "Negative number of elements requested")
479+
require(num <= (Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt),
480+
"Cannot support a sample size > Int.MaxValue - " +
481+
s"$numStDev * math.sqrt(Int.MaxValue)")
505482

506-
// If the first sample didn't turn out large enough, keep trying to take samples;
507-
// this shouldn't happen often because we use a big multiplier for the initial size
508-
var numIters = 0
509-
while (samples.length < num) {
510-
logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters")
511-
samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
512-
numIters += 1
483+
if (num == 0) {
484+
new Array[T](0)
485+
} else {
486+
val initialCount = this.count()
487+
if (initialCount == 0) {
488+
new Array[T](0)
489+
} else {
490+
val rand = new Random(seed)
491+
if (!withReplacement && num >= initialCount) {
492+
Utils.randomizeInPlace(this.collect(), rand)
493+
} else {
494+
val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount,
495+
withReplacement)
496+
var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
497+
498+
// If the first sample didn't turn out large enough, keep trying to take samples;
499+
// this shouldn't happen often because we use a big multiplier for the initial size
500+
var numIters = 0
501+
while (samples.length < num) {
502+
logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters")
503+
samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
504+
numIters += 1
505+
}
506+
Utils.randomizeInPlace(samples, rand).take(num)
507+
}
508+
}
513509
}
514-
515-
Utils.randomizeInPlace(samples, rand).take(num)
516510
}
517511

518512
/**

0 commit comments

Comments
 (0)