diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 214f22bc5b603..73ba73203c50f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -392,29 +392,28 @@ abstract class RDD[T: ClassTag](
         this, new BernoulliCellSampler[T](x(0), x(1)), true, seed)
     }.toArray
   }
-
+
   /**
-   * Return a fixed-size sampled subset of this RDD in an array
-   *
-   * @param withReplacement whether sampling is done with replacement
-   * @param num size of the returned sample
-   * @param seed seed for the random number generator
-   * @return sample of specified size in an array
+   * Return a fixed-size sampled subset of this RDD as a new RDD
+   * @param withReplacement whether sampling is done with replacement
+   * @param num size of the returned sample
+   * @param seed seed for the random number generator
+   * @return sample of the specified size as an RDD
    */
-  def takeSample(withReplacement: Boolean,
-      num: Int,
-      seed: Long = Utils.random.nextLong): Array[T] = {
+  def sampleByCount(withReplacement: Boolean,
+      num: Int,
+      seed: Long = Utils.random.nextLong): RDD[T] = {
     val numStDev = 10.0

     if (num < 0) {
       throw new IllegalArgumentException("Negative number of elements requested")
     } else if (num == 0) {
-      return new Array[T](0)
+      return new EmptyRDD[T](this.sc)
     }

     val initialCount = this.count()
     if (initialCount == 0) {
-      return new Array[T](0)
+      return new EmptyRDD[T](this.sc)
     }

     val maxSampleSize = Int.MaxValue - (numStDev * math.sqrt(Int.MaxValue)).toInt
@@ -425,26 +424,57 @@ abstract class RDD[T: ClassTag](

     val rand = new Random(seed)
     if (!withReplacement && num >= initialCount) {
-      return Utils.randomizeInPlace(this.collect(), rand)
+      return this
     }

+    // Because sampling is stochastic, compute the sample fraction needed to yield at least
+    // "num" elements with a 99.99% success rate
     val fraction = SamplingUtils.computeFractionForSampleSize(num, initialCount,
       withReplacement)
-
-    var samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
+
+    var samples = this.sample(withReplacement, fraction, rand.nextInt())

     // If the first sample didn't turn out large enough, keep trying to take samples;
     // this shouldn't happen often because we use a big multiplier for the initial size
     var numIters = 0
-    while (samples.length < num) {
+    var count = samples.count()
+
+    // Since computeFractionForSampleSize yields an upper bound, each sample will typically
+    // contain at least "num" elements, and often more; any excess is dropped once the loop
+    // below completes.
+    while (count < num) {
       logWarning(s"Needed to re-sample due to insufficient sample size. Repeat #$numIters")
-      samples = this.sample(withReplacement, fraction, rand.nextInt()).collect()
+      samples = this.sample(withReplacement, fraction, rand.nextInt())
       numIters += 1
+      count = samples.count()
+    }
+
+    // After sampling is complete, we may have more than "num" samples. Therefore, as the
+    // final step, pare the result down to exactly "num" elements.
+    if (count > num) {
+      samples = samples.zipWithIndex().filter(_._2 < num).map(_._1)
     }

-    Utils.randomizeInPlace(samples, rand).take(num)
+    samples
   }

+  /**
+   * Return a fixed-size sampled subset of this RDD in an array
+   *
+   * @param withReplacement whether sampling is done with replacement
+   * @param num size of the returned sample
+   * @param seed seed for the random number generator
+   * @return sample of specified size in an array
+   */
+  def takeSample(withReplacement: Boolean,
+      num: Int,
+      seed: Long = Utils.random.nextLong): Array[T] = {
+
+    // To maintain the behavior of the previous implementation, collect the sampled RDD
+    // and randomize the resulting array before returning it
+    Utils.randomizeInPlace(sampleByCount(withReplacement, num, seed).collect(), new Random(seed))
+  }
+
   /**
    * Return the union of this RDD and another one. Any identical elements will appear multiple
    * times (use `.distinct()` to eliminate them).
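Usage sketch (illustrative only, not part of the patch): the split above leaves two entry
points, with sampleByCount keeping the sample distributed and takeSample preserving the old
local-array contract. The values below are made up; only the methods in this patch and
standard RDD operations are assumed.

    // sampleByCount returns an RDD, so the sample can feed further transformations
    // without being materialized on the driver:
    val data = sc.parallelize(1 to 1000000, 8)
    val subset: RDD[Int] = data.sampleByCount(withReplacement = false, num = 1000, seed = 42L)
    val doubled = subset.map(_ * 2)  // still distributed

    // takeSample behaves as before: a randomized Array of exactly `num` elements on the driver.
    val local: Array[Int] = data.takeSample(withReplacement = false, num = 1000, seed = 42L)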
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 6836e9ab0fd6b..e70fef03e4dc4 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -17,6 +17,8 @@

 package org.apache.spark.rdd

+import java.util.Random
+
 import scala.collection.mutable.{ArrayBuffer, HashMap}
 import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
@@ -85,7 +87,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(error(simpleRdd.countApproxDistinct(8, 0), size) < 0.2)
     assert(error(simpleRdd.countApproxDistinct(12, 0), size) < 0.1)
   }
-
+
   test("SparkContext.union") {
     val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
     assert(sc.union(nums).collect().toList === List(1, 2, 3, 4))
@@ -549,6 +551,60 @@ class RDDSuite extends FunSuite with SharedSparkContext {
       assert(sampled.partitioner === rdd.partitioner)
     }
   }
+
+  test("sampleByCount") {
+    val count = 10000
+    val largeSize = 1000000
+    val smallSize = 50
+    val data = sc.parallelize(1 to largeSize, 2)
+    val dataSmall = sc.parallelize(1 to smallSize, 2)
+
+    val testCount = 10
+
+    val seed = System.currentTimeMillis()
+    val rand = new Random(seed)
+
+    for (i <- 1 to testCount) {
+      // When sampling without replacement, ensure all elements are distinct and we get
+      // the right number of them.
+
+      val sampleSize = rand.nextInt(count)
+      val samples = data.sampleByCount(withReplacement=false, sampleSize, seed)
+      assert(samples.count() == sampleSize)
+      assert(samples.distinct().count() == sampleSize)
+
+      // *********************************************************************
+      // When sampling with replacement, ensure we get the right
+      // number of elements.
+      val sampleSize2 = rand.nextInt(smallSize) + smallSize
+      val samples2 = dataSmall.sampleByCount(withReplacement=true, sampleSize2, seed)
+      assert(samples2.count() == sampleSize2)
+
+      // *********************************************************************
+      // When sampling without replacement and requesting more elements than the RDD
+      // contains, ensure that the appropriate number of elements is returned.
+      // sampleSize3 >= smallSize, so we always request more elements than exist.
+      val sampleSize3 = rand.nextInt(smallSize) + smallSize
+      val samples3 = dataSmall.sampleByCount(withReplacement=false, sampleSize3, seed)
+
+      assert(samples3.count() == smallSize)
+
+      // Values should still be distinct because the original RDD is still 1 to smallSize
+      assert(samples3.distinct().count() == smallSize)
+
+      // *********************************************************************
+      // When sampling with replacement using a large sample size, ensure that the
+      // returned elements are not all distinct
+      val sampleSize4 = count + rand.nextInt(count)
+      val samples4 = data.sampleByCount(withReplacement=true, sampleSize4, seed)
+
+      assert(samples4.count() == sampleSize4)
+
+      // Chance of getting all distinct elements is astronomically low, confirm that this
+      // doesn't happen
+      assert(samples4.distinct().count() < sampleSize4)
+    }
+  }

   test("takeSample") {
     val n = 1000000
@@ -579,13 +635,13 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     }
     {
       val sample = data.takeSample(withReplacement=true, num=20)
-      assert(sample.size === 20) // Got exactly 100 elements
+      assert(sample.size === 20) // Got exactly 20 elements
       assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements")
       assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
     }
     {
       val sample = data.takeSample(withReplacement=true, num=n)
-      assert(sample.size === n) // Got exactly 100 elements
+      assert(sample.size === n) // Got exactly 1000000 elements
       // Chance of getting all distinct elements is astronomically low, so test we got < 100
       assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
       assert(sample.forall(x => 1 <= x && x <= n), s"elements not in [1, $n]")
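Sketch of the pare-down idiom used above (illustrative only; the example values are made up):
sampleByCount trims an oversampled RDD to exactly `num` elements by pairing each element with
a stable index via zipWithIndex and keeping the first `num` indices.

    // Trim an oversampled RDD down to exactly `num` elements, mirroring the patch's final step.
    val oversampled = sc.parallelize(Seq(5, 3, 8, 1, 9, 2), 3)
    val num = 4
    val trimmed = oversampled.zipWithIndex()     // RDD[(Int, Long)]; indices follow partition order
      .filter { case (_, idx) => idx < num }     // keep indices 0 until num
      .map { case (elem, _) => elem }            // drop the index
    assert(trimmed.count() == num)

Note that this keeps elements in partition order rather than choosing a random subset of the
oversample, which is why takeSample still shuffles the collected array afterwards.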