diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index 4330cef3965e..868ed141effa 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -133,7 +133,13 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja
   /**
    * Return a sampled subset of this RDD.
    */
-  def sample(withReplacement: Boolean, fraction: JDouble, seed: Int): JavaDoubleRDD =
+  def sample(withReplacement: Boolean, fraction: JDouble): JavaDoubleRDD =
+    sample(withReplacement, fraction, System.nanoTime)
+
+  /**
+   * Return a sampled subset of this RDD.
+   */
+  def sample(withReplacement: Boolean, fraction: JDouble, seed: Long): JavaDoubleRDD =
     fromRDD(srdd.sample(withReplacement, fraction, seed))
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index b3ec270281ae..02a5e20073f8 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -119,7 +119,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   /**
    * Return a sampled subset of this RDD.
    */
-  def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] =
+  def sample(withReplacement: Boolean, fraction: Double): JavaPairRDD[K, V] =
+    sample(withReplacement, fraction, System.nanoTime)
+
+  /**
+   * Return a sampled subset of this RDD.
+   */
+  def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaPairRDD[K, V] =
     new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed))
 
   /**
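The three Java wrapper classes in this patch (JavaDoubleRDD above, JavaPairRDD above, and JavaRDD just below) all gain the same pair of overloads: a seedless sample() that delegates with a System.nanoTime seed, plus the existing form widened to take a Long seed. A minimal sketch of calling the new overloads; the local-mode JavaSparkContext setup is illustrative, not part of the patch:

```scala
import org.apache.spark.api.java.JavaSparkContext
import scala.collection.JavaConverters._

object SeedlessSampleSketch {
  def main(args: Array[String]): Unit = {
    val jsc = new JavaSparkContext("local", "seedless-sample-sketch")
    // Parallelize a java.util.List of boxed ints into a JavaRDD[Integer].
    val rdd = jsc.parallelize((1 to 100).map(i => Integer.valueOf(i)).asJava)
    val unseeded = rdd.sample(false, 0.1)       // new overload: seed defaults to System.nanoTime
    val seeded   = rdd.sample(false, 0.1, 42L)  // seed is now a Long; same seed, same sample
    println(s"unseeded=${unseeded.count()}, seeded=${seeded.count()}")
    jsc.stop()
  }
}
```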
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index 327c1552dc94..374ff7808c21 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -98,7 +98,13 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
   /**
    * Return a sampled subset of this RDD.
    */
-  def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
+  def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] =
+    sample(withReplacement, fraction, System.nanoTime)
+
+  /**
+   * Return a sampled subset of this RDD.
+   */
+  def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] =
     wrapRDD(rdd.sample(withReplacement, fraction, seed))
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 725c423a53e3..bbf84ee7ffb0 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -394,7 +394,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
     new java.util.ArrayList(arr)
   }
 
-  def takeSample(withReplacement: Boolean, num: Int, seed: Int): JList[T] = {
+  def takeSample(withReplacement: Boolean, num: Int): JList[T] =
+    takeSample(withReplacement, num, System.nanoTime)
+
+  def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = {
     import scala.collection.JavaConversions._
     val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq
     new java.util.ArrayList(arr)
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 891efccf23b6..125f1da220fd 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -321,7 +321,7 @@ abstract class RDD[T: ClassTag](
   /**
    * Return a sampled subset of this RDD.
    */
-  def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = {
+  def sample(withReplacement: Boolean, fraction: Double, seed: Long = System.nanoTime): RDD[T] = {
     require(fraction >= 0.0, "Invalid fraction value: " + fraction)
     if (withReplacement) {
       new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed)
@@ -346,7 +346,7 @@
     }.toArray
   }
 
-  def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
+  def takeSample(withReplacement: Boolean, num: Int, seed: Long = System.nanoTime): Array[T] = {
     var fraction = 0.0
     var total = 0
     val multiplier = 3.0
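In the core Scala API the same change is made with default parameter values rather than overloads: seed on both sample() and takeSample() becomes a Long defaulting to System.nanoTime. A short sketch of the resulting call sites (the local-mode setup is assumed for illustration):

```scala
import org.apache.spark.SparkContext

object DefaultSeedSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "default-seed-sketch")
    val data = sc.parallelize(1 to 1000)
    val timeSeeded = data.sample(withReplacement = false, fraction = 0.1)            // System.nanoTime seed
    val fixedSeed  = data.sample(withReplacement = false, fraction = 0.1, seed = 7L) // reproducible
    val handful    = data.takeSample(withReplacement = true, num = 5)                // default seed here too
    println(s"timeSeeded=${timeSeeded.count()}, fixedSeed=${fixedSeed.count()}, handful=${handful.mkString(",")}")
    sc.stop()
  }
}
```

The nanoTime default means two unseeded calls practically never share a seed, while passing an explicit Long seed keeps a sample reproducible across runs.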
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 1901330d8b18..6bbaaf87cd4b 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -466,6 +466,12 @@ class RDDSuite extends FunSuite with SharedSparkContext {
 
   test("takeSample") {
     val data = sc.parallelize(1 to 100, 2)
+    for (num <- List(5, 20, 100)) {
+      val sample = data.takeSample(withReplacement = false, num = num)
+      assert(sample.size === num) // Got exactly num elements
+      assert(sample.toSet.size === num) // Elements are distinct
+      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+    }
     for (seed <- 1 to 5) {
       val sample = data.takeSample(withReplacement=false, 20, seed)
       assert(sample.size === 20) // Got exactly 20 elements
@@ -483,6 +489,19 @@
       assert(sample.size === 20) // Got exactly 20 elements
       assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
     }
+    {
+      val sample = data.takeSample(withReplacement = true, num = 20)
+      assert(sample.size === 20) // Got exactly 20 elements
+      assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements")
+      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+    }
+    {
+      val sample = data.takeSample(withReplacement = true, num = 100)
+      assert(sample.size === 100) // Got exactly 100 elements
+      // Chance of getting all distinct elements is astronomically low, so test we got < 100
+      assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+    }
     for (seed <- 1 to 5) {
       val sample = data.takeSample(withReplacement=true, 100, seed)
       assert(sample.size === 100) // Got exactly 100 elements
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 91fc7e637e2c..70a78eda8c01 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -30,6 +30,7 @@
 from threading import Thread
 import warnings
 import heapq
+import random
 
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
@@ -332,7 +333,7 @@ def distinct(self):
             .reduceByKey(lambda x, _: x) \
             .map(lambda (x, _): x)
 
-    def sample(self, withReplacement, fraction, seed):
+    def sample(self, withReplacement, fraction, seed=None):
         """
         Return a sampled subset of this RDD (relies on numpy and falls back
         on default random generator if numpy is unavailable).
@@ -344,7 +345,7 @@ def sample(self, withReplacement, fraction, seed):
         return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
 
     # this is ported from scala/spark/RDD.scala
-    def takeSample(self, withReplacement, num, seed):
+    def takeSample(self, withReplacement, num, seed=None):
         """
         Return a fixed-size sampled subset of this RDD (currently requires numpy).
 
@@ -381,13 +382,11 @@ def takeSample(self, withReplacement, num, seed):
         # If the first sample didn't turn out large enough, keep trying to take samples;
         # this shouldn't happen often because we use a big multiplier for their initial size.
         # See: scala/spark/RDD.scala
+        random.seed(seed)
         while len(samples) < total:
-            if seed > sys.maxint - 2:
-                seed = -1
-            seed += 1
-            samples = self.sample(withReplacement, fraction, seed).collect()
+            samples = self.sample(withReplacement, fraction, random.randint(0, sys.maxint)).collect()
 
-        sampler = RDDSampler(withReplacement, fraction, seed+1)
+        sampler = RDDSampler(withReplacement, fraction, random.randint(0, sys.maxint))
         sampler.shuffle(samples)
         return samples[0:total]
 
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
index aca2ef3b51e9..5ce330ee239b 100644
--- a/python/pyspark/rddsampler.py
+++ b/python/pyspark/rddsampler.py
@@ -19,7 +19,7 @@
 import random
 
 class RDDSampler(object):
-    def __init__(self, withReplacement, fraction, seed):
+    def __init__(self, withReplacement, fraction, seed=None):
         try:
             import numpy
             self._use_numpy = True
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 397473e17886..732708e146b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -168,7 +168,7 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode {
   def references = Set.empty
 }
 
-case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: LogicalPlan)
+case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
   extends UnaryNode {
   def output = child.output
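The catalyst change above simply widens the seed carried by the logical Sample operator to a Long. A hedged sketch of constructing the node directly; the sampledPlan helper and its child argument are hypothetical, since in practice the planner builds this node:

```scala
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sample}

// Hypothetical helper: wrap any child plan in a 10% sample, seeding it the
// same way the new RDD.sample default does.
def sampledPlan(child: LogicalPlan): LogicalPlan =
  Sample(fraction = 0.1, withReplacement = false, seed = System.nanoTime, child = child)
```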
index f2ae5b0fe612..6658d306c987 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -256,10 +256,11 @@
    * @group Query
    */
   @Experimental
+  override
   def sample(
-      fraction: Double,
       withReplacement: Boolean = true,
-      seed: Int = (math.random * 1000).toInt) =
+      fraction: Double,
+      seed: Long) =
     new SchemaRDD(sqlContext, Sample(fraction, withReplacement, seed, logicalPlan))
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index eedcc7dda02d..288094f97f9f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -44,7 +44,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
   }
 }
 
-case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: SparkPlan)
+case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan)
   extends UnaryNode {
   override def output = child.output
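Finally, a hedged sketch of calling the updated SchemaRDD.sample, whose seed is now a Long with no random-Int default; the SQLContext setup, the Person case class, and the createSchemaRDD implicit conversion are illustrative assumptions, not part of the patch:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

case class Person(name: String, age: Int)

object SchemaSampleSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "schema-sample-sketch")
    val sqlContext = new SQLContext(sc)
    import sqlContext.createSchemaRDD  // implicit RDD[Product] => SchemaRDD

    val people: org.apache.spark.sql.SchemaRDD =
      sc.parallelize(1 to 100).map(i => Person("p" + i, i))
    // withReplacement keeps its default of true; seed must now be a Long.
    val sampled = people.sample(withReplacement = false, fraction = 0.1, seed = 42L)
    println(sampled.count())
    sc.stop()
  }
}
```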