Skip to content

Commit 1441977

Browse files
committed
SPARK-1939 Refactor takeSample method in RDD to use ScaSRS
1 parent 60b89fe commit 1441977

File tree

7 files changed

+100
-22
lines changed

7 files changed

+100
-22
lines changed

core/pom.xml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@
6767
<groupId>org.apache.commons</groupId>
6868
<artifactId>commons-lang3</artifactId>
6969
</dependency>
70+
<dependency>
71+
<groupId>org.apache.commons</groupId>
72+
<artifactId>commons-math3</artifactId>
73+
</dependency>
7074
<dependency>
7175
<groupId>com.google.code.findbugs</groupId>
7276
<artifactId>jsr305</artifactId>

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -379,8 +379,17 @@ abstract class RDD[T: ClassTag](
379379
}.toArray
380380
}
381381

382-
def takeSample(withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] =
383-
{
382+
/**
383+
* Return a fixed-size sampled subset of this RDD in an array
384+
*
385+
* @param withReplacement whether sampling is done with replacement
386+
* @param num size of the returned sample
387+
* @param seed seed for the random number generator
388+
* @return sample of specified size in an array
389+
*/
390+
def takeSample(withReplacement: Boolean,
391+
num: Int,
392+
seed: Long = Utils.random.nextLong): Array[T] = {
384393
var fraction = 0.0
385394
var total = 0
386395
val multiplier = 3.0
@@ -402,10 +411,11 @@ abstract class RDD[T: ClassTag](
402411
}
403412

404413
if (num > initialCount && !withReplacement) {
414+
// special case not covered in computeFraction
405415
total = maxSelected
406416
fraction = multiplier * (maxSelected + 1) / initialCount
407417
} else {
408-
fraction = multiplier * (num + 1) / initialCount
418+
fraction = computeFraction(num, initialCount, withReplacement)
409419
total = num
410420
}
411421

@@ -421,6 +431,22 @@ abstract class RDD[T: ClassTag](
421431
Utils.randomizeInPlace(samples, rand).take(total)
422432
}
423433

434+
private[spark] def computeFraction(num: Int, total: Long, withReplacement: Boolean) : Double = {
435+
val fraction = num.toDouble / total
436+
if (withReplacement) {
437+
var numStDev = 5
438+
if (num < 12) {
439+
// special case to guarantee sample size for small s
440+
numStDev = 9
441+
}
442+
fraction + numStDev * math.sqrt(fraction / total)
443+
} else {
444+
val delta = 0.00005
445+
val gamma = - math.log(delta)/total
446+
math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))
447+
}
448+
}
449+
424450
/**
425451
* Return the union of this RDD and another one. Any identical elements will appear multiple
426452
* times (use `.distinct()` to eliminate them).

core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
7070
}
7171

7272
/**
73-
* Return a sampler with is the complement of the range specified of the current sampler.
73+
* Return a sampler which is the complement of the range specified of the current sampler.
7474
*/
7575
def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)
7676

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import scala.reflect.ClassTag
2222

2323
import org.scalatest.FunSuite
2424

25+
import org.apache.commons.math3.distribution.PoissonDistribution
2526
import org.apache.spark._
2627
import org.apache.spark.SparkContext._
2728
import org.apache.spark.rdd._
@@ -494,56 +495,84 @@ class RDDSuite extends FunSuite with SharedSparkContext {
494495
assert(sortedTopK === nums.sorted(ord).take(5))
495496
}
496497

498+
test("computeFraction") {
499+
// test that the computed fraction guarantees enough datapoints in the sample with a failure rate <= 0.0001
500+
val data = new EmptyRDD[Int](sc)
501+
val n = 100000
502+
503+
for (s <- 1 to 15) {
504+
val frac = data.computeFraction(s, n, true)
505+
val qpois = new PoissonDistribution(frac * n)
506+
assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
507+
}
508+
for (s <- 1 to 15) {
509+
val frac = data.computeFraction(s, n, false)
510+
val qpois = new PoissonDistribution(frac * n)
511+
assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
512+
}
513+
for (s <- List(1, 10, 100, 1000)) {
514+
val frac = data.computeFraction(s, n, true)
515+
val qpois = new PoissonDistribution(frac * n)
516+
assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
517+
}
518+
for (s <- List(1, 10, 100, 1000)) {
519+
val frac = data.computeFraction(s, n, false)
520+
val qpois = new PoissonDistribution(frac * n)
521+
assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
522+
}
523+
}
524+
497525
test("takeSample") {
498-
val data = sc.parallelize(1 to 100, 2)
526+
val n = 1000000
527+
val data = sc.parallelize(1 to n, 2)
499528

500529
for (num <- List(5, 20, 100)) {
501530
val sample = data.takeSample(withReplacement=false, num=num)
502531
assert(sample.size === num) // Got exactly num elements
503532
assert(sample.toSet.size === num) // Elements are distinct
504-
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
533+
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
505534
}
506535
for (seed <- 1 to 5) {
507536
val sample = data.takeSample(withReplacement=false, 20, seed)
508537
assert(sample.size === 20) // Got exactly 20 elements
509538
assert(sample.toSet.size === 20) // Elements are distinct
510-
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
539+
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
511540
}
512541
for (seed <- 1 to 5) {
513-
val sample = data.takeSample(withReplacement=false, 200, seed)
542+
val sample = data.takeSample(withReplacement=false, 100, seed)
514543
assert(sample.size === 100) // Got only 100 elements
515544
assert(sample.toSet.size === 100) // Elements are distinct
516-
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
545+
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
517546
}
518547
for (seed <- 1 to 5) {
519548
val sample = data.takeSample(withReplacement=true, 20, seed)
520549
assert(sample.size === 20) // Got exactly 20 elements
521-
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
550+
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
522551
}
523552
{
524553
val sample = data.takeSample(withReplacement=true, num=20)
525554
assert(sample.size === 20) // Got exactly 20 elements
526555
assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements")
527-
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
556+
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
528557
}
529558
{
530-
val sample = data.takeSample(withReplacement=true, num=100)
531-
assert(sample.size === 100) // Got exactly 100 elements
559+
val sample = data.takeSample(withReplacement=true, num=n)
560+
assert(sample.size === n) // Got exactly n elements
532561
// Chance of getting all distinct elements is astronomically low, so test we got < n
533-
assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
534-
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
562+
assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
563+
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
535564
}
536565
for (seed <- 1 to 5) {
537-
val sample = data.takeSample(withReplacement=true, 100, seed)
538-
assert(sample.size === 100) // Got exactly 100 elements
566+
val sample = data.takeSample(withReplacement=true, n, seed)
567+
assert(sample.size === n) // Got exactly n elements
539568
// Chance of getting all distinct elements is astronomically low, so test we got < n
540-
assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
569+
assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
541570
}
542571
for (seed <- 1 to 5) {
543-
val sample = data.takeSample(withReplacement=true, 200, seed)
544-
assert(sample.size === 200) // Got exactly 200 elements
572+
val sample = data.takeSample(withReplacement=true, 2*n, seed)
573+
assert(sample.size === 2*n) // Got exactly 2*n elements
545574
// Chance of getting all distinct elements is still quite low, so test we got < n
546-
assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
575+
assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
547576
}
548577
}
549578

pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,11 @@
245245
<artifactId>commons-codec</artifactId>
246246
<version>1.5</version>
247247
</dependency>
248+
<dependency>
249+
<groupId>org.apache.commons</groupId>
250+
<artifactId>commons-math3</artifactId>
251+
<version>3.2</version>
252+
</dependency>
248253
<dependency>
249254
<groupId>com.google.code.findbugs</groupId>
250255
<artifactId>jsr305</artifactId>

project/SparkBuild.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ object SparkBuild extends Build {
331331
libraryDependencies ++= Seq(
332332
"com.google.guava" % "guava" % "14.0.1",
333333
"org.apache.commons" % "commons-lang3" % "3.3.2",
334+
"org.apache.commons" % "commons-math3" % "3.2",
334335
"com.google.code.findbugs" % "jsr305" % "1.3.9",
335336
"log4j" % "log4j" % "1.2.17",
336337
"org.slf4j" % "slf4j-api" % slf4jVersion,

python/pyspark/rdd.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import warnings
3232
import heapq
3333
from random import Random
34+
# `min` is a Python builtin, not a member of `math`; importing it from `math` raises ImportError.
from math import sqrt, log
3435

3536
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
3637
BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
@@ -374,7 +375,7 @@ def takeSample(self, withReplacement, num, seed=None):
374375
total = maxSelected
375376
fraction = multiplier * (maxSelected + 1) / initialCount
376377
else:
377-
fraction = multiplier * (num + 1) / initialCount
378+
fraction = self._computeFraction(num, initialCount, withReplacement)
378379
total = num
379380

380381
samples = self.sample(withReplacement, fraction, seed).collect()
@@ -390,6 +391,18 @@ def takeSample(self, withReplacement, num, seed=None):
390391
sampler.shuffle(samples)
391392
return samples[0:total]
392393

394+
def _computeFraction(self, num, total, withReplacement):
395+
fraction = float(num)/total
396+
if withReplacement:
397+
numStDev = 5
398+
if (num < 12):
399+
numStDev = 9
400+
return fraction + numStDev * sqrt(fraction/total)
401+
else:
402+
delta = 0.00005
403+
gamma = - log(delta)/total
404+
return min(1, fraction + gamma + sqrt(gamma * gamma + 2* gamma * fraction))
405+
393406
def union(self, other):
394407
"""
395408
Return the union of this RDD and another one.

0 commit comments

Comments
 (0)