@@ -133,7 +133,13 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja
   /**
    * Return a sampled subset of this RDD.
    */
-  def sample(withReplacement: Boolean, fraction: JDouble, seed: Int): JavaDoubleRDD =
+  def sample(withReplacement: Boolean, fraction: JDouble): JavaDoubleRDD =
+    sample(withReplacement, fraction, System.nanoTime)
+
+  /**
+   * Return a sampled subset of this RDD.
+   */
+  def sample(withReplacement: Boolean, fraction: JDouble, seed: Long): JavaDoubleRDD =
     fromRDD(srdd.sample(withReplacement, fraction, seed))

   /**
@@ -119,7 +119,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   /**
    * Return a sampled subset of this RDD.
    */
-  def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] =
+  def sample(withReplacement: Boolean, fraction: Double): JavaPairRDD[K, V] =
+    sample(withReplacement, fraction, System.nanoTime)
+
+  /**
+   * Return a sampled subset of this RDD.
+   */
+  def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaPairRDD[K, V] =
Contributor: I think it's important to use Int instead of Long here, since the current code is written against Int. If we change to Long, old code using the sample API cannot compile because of the type mismatch.

Member: Long is going to be more standard; Java certainly uses long seeds in its APIs. It also gives more bits of seed, which is good. There may be some API changes, but anyone calling with an Int seed should still be able to call an API that takes a Long seed, right?
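(To illustrate the widening point: a minimal, self-contained Scala sketch, not Spark code; the sample method below just mimics the shape of the new signature.)

```scala
object SeedWidening {
  // Stand-in with the same shape as the new API: the seed parameter is a Long.
  def sample(withReplacement: Boolean, fraction: Double, seed: Long): String =
    s"sample(withReplacement=$withReplacement, fraction=$fraction, seed=$seed)"

  def main(args: Array[String]): Unit = {
    val intSeed: Int = 42
    // Scala widens the Int argument to Long automatically, so existing callers
    // that pass Int seeds still compile against the Long signature.
    println(sample(withReplacement = true, fraction = 0.5, seed = intSeed))
  }
}
```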

Contributor: Yes, it is more standard to use Long, but we would need to rewrite some code. We should ask @mateiz or @pwendell for advice; maybe they chose Int for some reason we don't know.

Contributor (author): We can add deprecated overloaded methods for backward compatibility if needed.
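(A minimal sketch of what such a deprecated shim could look like; the class name and version string are illustrative, not from this PR.)

```scala
class SampleApi {
  // New primary API: Long seed.
  def sample(withReplacement: Boolean, fraction: Double, seed: Long): Unit =
    println(s"sampling with seed=$seed")

  // Hypothetical compatibility shim for callers built against the old Int signature.
  @deprecated("Use the Long-seed overload instead", "1.0")
  def sample(withReplacement: Boolean, fraction: Double, seed: Int): Unit =
    sample(withReplacement, fraction, seed.toLong)
}
```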

     new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed))

   /**
8 changes: 7 additions & 1 deletion core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -98,7 +98,13 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])
   /**
    * Return a sampled subset of this RDD.
    */
-  def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
+  def sample(withReplacement: Boolean, fraction: Double): JavaRDD[T] =
+    sample(withReplacement, fraction, System.nanoTime)
+
+  /**
+   * Return a sampled subset of this RDD.
+   */
+  def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaRDD[T] =
     wrapRDD(rdd.sample(withReplacement, fraction, seed))

   /**
@@ -394,7 +394,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
     new java.util.ArrayList(arr)
   }

-  def takeSample(withReplacement: Boolean, num: Int, seed: Int): JList[T] = {
+  def takeSample(withReplacement: Boolean, num: Int): JList[T] =
+    takeSample(withReplacement, num, System.nanoTime)
+
+  def takeSample(withReplacement: Boolean, num: Int, seed: Long): JList[T] = {
     import scala.collection.JavaConversions._
     val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq
     new java.util.ArrayList(arr)
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -321,7 +321,7 @@ abstract class RDD[T: ClassTag](
   /**
    * Return a sampled subset of this RDD.
    */
-  def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] = {
+  def sample(withReplacement: Boolean, fraction: Double, seed: Long = System.nanoTime): RDD[T] = {
     require(fraction >= 0.0, "Invalid fraction value: " + fraction)
     if (withReplacement) {
       new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed)
@@ -346,7 +346,7 @@ abstract class RDD[T: ClassTag](
     }.toArray
   }

-  def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
+  def takeSample(withReplacement: Boolean, num: Int, seed: Long = System.nanoTime): Array[T] = {
     var fraction = 0.0
     var total = 0
     val multiplier = 3.0
19 changes: 19 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -466,6 +466,12 @@ class RDDSuite extends FunSuite with SharedSparkContext {
   test("takeSample") {
     val data = sc.parallelize(1 to 100, 2)

+    for (num <- List(5,20,100)) {
+      val sample = data.takeSample(withReplacement=false, num=num)
Contributor: I think there might be some tab-character weirdness going on, because these statements don't line up correctly.

Contributor (author): @advancedxy For java.util.Random, the default seed is a function of System.nanoTime (at least in the OpenJDK code). In Python it's based on time.time(); Python's time.time() is at millisecond precision, and I'm not sure there is a Python method to get nanoTime.
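(A small sketch of the seeding approach being discussed — explicitly seeding java.util.Random from System.nanoTime, as the PR's default-seed overloads do; plain Scala, not Spark code.)

```scala
import java.util.Random

object NanoTimeSeed {
  def main(args: Array[String]): Unit = {
    // Explicit nanoTime seeding: each generator gets a different seed,
    // mirroring the seed: Long = System.nanoTime defaults added in this PR.
    val r1 = new Random(System.nanoTime)
    val r2 = new Random(System.nanoTime)
    println((r1.nextInt(100), r2.nextInt(100)))
  }
}
```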

Contributor: @smartnut007 Based on https://docs.python.org/2/library/time.html#time.time, time.time() is at second precision. But since it returns a float, I think we can use long(time.time() * 10**9) to get nanosecond precision.

+      assert(sample.size === num) // Got exactly num elements
+      assert(sample.toSet.size === num) // Elements are distinct
+      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+    }
     for (seed <- 1 to 5) {
       val sample = data.takeSample(withReplacement=false, 20, seed)
       assert(sample.size === 20) // Got exactly 20 elements
@@ -483,6 +489,19 @@ class RDDSuite extends FunSuite with SharedSparkContext {
       assert(sample.size === 20) // Got exactly 20 elements
       assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
     }
+    {
+      val sample = data.takeSample(withReplacement=true, num=20)
+      assert(sample.size === 20) // Got exactly 20 elements
+      assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements")
+      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+    }
+    {
+      val sample = data.takeSample(withReplacement=true, num=100)
+      assert(sample.size === 100) // Got exactly 100 elements
+      // Chance of getting all distinct elements is astronomically low, so test we got < 100
+      assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+      assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+    }
     for (seed <- 1 to 5) {
       val sample = data.takeSample(withReplacement=true, 100, seed)
       assert(sample.size === 100) // Got exactly 100 elements
13 changes: 6 additions & 7 deletions python/pyspark/rdd.py
@@ -30,6 +30,7 @@
 from threading import Thread
 import warnings
 import heapq
+import random

 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
@@ -332,7 +333,7 @@ def distinct(self):
             .reduceByKey(lambda x, _: x) \
             .map(lambda (x, _): x)

-    def sample(self, withReplacement, fraction, seed):
+    def sample(self, withReplacement, fraction, seed=None):
         """
         Return a sampled subset of this RDD (relies on numpy and falls back
         on default random generator if numpy is unavailable).
@@ -344,7 +345,7 @@ def sample(self, withReplacement, fraction, seed):
         return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)

     # this is ported from scala/spark/RDD.scala
-    def takeSample(self, withReplacement, num, seed):
+    def takeSample(self, withReplacement, num, seed=None):
         """
         Return a fixed-size sampled subset of this RDD (currently requires numpy).

@@ -381,13 +382,11 @@ def takeSample(self, withReplacement, num, seed):
         # If the first sample didn't turn out large enough, keep trying to take samples;
         # this shouldn't happen often because we use a big multiplier for their initial size.
         # See: scala/spark/RDD.scala
+        random.seed(seed)
         while len(samples) < total:
-            if seed > sys.maxint - 2:
-                seed = -1
-            seed += 1
-            samples = self.sample(withReplacement, fraction, seed).collect()
+            samples = self.sample(withReplacement, fraction, random.randint(0,sys.maxint)).collect()

-        sampler = RDDSampler(withReplacement, fraction, seed+1)
+        sampler = RDDSampler(withReplacement, fraction, random.randint(0,sys.maxint))
         sampler.shuffle(samples)
         return samples[0:total]

2 changes: 1 addition & 1 deletion python/pyspark/rddsampler.py
@@ -19,7 +19,7 @@
 import random

 class RDDSampler(object):
-    def __init__(self, withReplacement, fraction, seed):
+    def __init__(self, withReplacement, fraction, seed=None):
         try:
             import numpy
             self._use_numpy = True
@@ -168,7 +168,7 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode {
   def references = Set.empty
 }

-case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: LogicalPlan)
+case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: LogicalPlan)
   extends UnaryNode {

   def output = child.output
5 changes: 3 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -256,10 +256,11 @@ class SchemaRDD(
    * @group Query
    */
   @Experimental
+  override
   def sample(
-      fraction: Double,
-      withReplacement: Boolean = true,
-      seed: Int = (math.random * 1000).toInt) =
+      withReplacement: Boolean,
+      fraction: Double,
+      seed: Long) =
Contributor: Did you intend to remove the default behavior here?

Contributor (author): Scala does not allow multiple overloaded methods to have default params, so once seed was made a default in RDD.sample, this had to be modified; I modified it in the standard way. Also, I believe the author intended to override RDD.sample, not overload it. More details in the PR comment.
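(A minimal sketch of the restriction being described, with illustrative names; uncommenting the second definition fails to compile with an error along the lines of "multiple overloaded alternatives of method sample define default arguments".)

```scala
class Sampler {
  def sample(fraction: Double, seed: Long = System.nanoTime): Unit =
    println(s"fraction=$fraction, seed=$seed")

  // Does not compile alongside the definition above, because only one
  // overloaded alternative may define default arguments:
  // def sample(fraction: Double, withReplacement: Boolean,
  //            seed: Long = System.nanoTime): Unit = ()
}
```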

     new SchemaRDD(sqlContext, Sample(fraction, withReplacement, seed, logicalPlan))

   /**
@@ -44,7 +44,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
     }
   }

-case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: SparkPlan)
+case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child: SparkPlan)
   extends UnaryNode {

   override def output = child.output