Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 32 additions & 37 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -469,50 +469,45 @@ abstract class RDD[T: ClassTag](
* @param seed seed for the random number generator
* @return sample of specified size in an array
*/
// TODO: rewrite this without return statements so we can wrap it in a scope
def takeSample(
withReplacement: Boolean,
num: Int,
seed: Long = Utils.random.nextLong): Array[T] = {
seed: Long = Utils.random.nextLong): Array[T] = withScope {
val numStDev = 10.0

if (num < 0) {
throw new IllegalArgumentException("Negative number of elements requested")
} else if (num == 0) {
return new Array[T](0)
}

val initialCount = this.count()
if (initialCount == 0) {
return new Array[T](0)
}

val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt
if (num > maxSampleSize) {
throw new IllegalArgumentException("Cannot support a sample size > Int.MaxValue - " +
s"$numStDev * math.sqrt(Int.MaxValue)")
}

val rand = new Random(seed)
if (!withReplacement && num >= initialCount) {
return Utils.randomizeInPlace(this.collect(), rand)
}

val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount,
withReplacement)

var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
require(num >= 0, "Negative number of elements requested")
require(num <= (Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt),
"Cannot support a sample size > Int.MaxValue - " +
s"$numStDev * math.sqrt(Int.MaxValue)")

// If the first sample didn't turn out large enough, keep trying to take samples;
// this shouldn't happen often because we use a big multiplier for the initial size
var numIters = 0
while (samples.length < num) {
logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters")
samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
numIters += 1
if (num == 0) {
new Array[T](0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, the only difference here is that if num == 0 we still do a count(), whereas before we just return quickly. I think we should preserve the old behavior even though it adds another layer of nesting and unfortunately makes the code harder to read.

} else {
val initialCount = this.count()
if(initialCount ==0)
{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these two lines have a lot of style errors... they should be

if (initialCount == 0) {
  new Array[T](0)
} else {
  ...
}

new Array[T](0)
} else {
val rand = new Random(seed)
if (!withReplacement && num >= initialCount) {
Utils.randomizeInPlace(this.collect(), rand)
} else {
val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount,
withReplacement)
var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()

// If the first sample didn't turn out large enough, keep trying to take samples;
// this shouldn't happen often because we use a big multiplier for the initial size
var numIters = 0
while (samples.length < num) {
logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters")
samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
numIters += 1
}
Utils.randomizeInPlace(samples, rand).take(num)
}
}
}

Utils.randomizeInPlace(samples, rand).take(num)
}

/**
Expand Down