-
Notifications
You must be signed in to change notification settings - Fork 29k
SPARK-1939 Refactor takeSample method in RDD to use ScaSRS #916
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
1441977
ffea61a
7cab53a
e3fd6a6
9bdd36e
065ebcd
ae3ad04
f80f270
0a9b3e3
ecab508
eff89e2
55518ed
64e445b
dc699f3
1481b01
fb1452f
48d954d
82dde31
3de882b
444e750
5b061ae
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -379,8 +379,17 @@ abstract class RDD[T: ClassTag]( | |
| }.toArray | ||
| } | ||
|
|
||
| def takeSample(withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] = | ||
| { | ||
| /** | ||
| * Return a fixed-size sampled subset of this RDD in an array | ||
| * | ||
| * @param withReplacement whether sampling is done with replacement | ||
| * @param num size of the returned sample | ||
| * @param seed seed for the random number generator | ||
| * @return sample of specified size in an array | ||
| */ | ||
| def takeSample(withReplacement: Boolean, | ||
| num: Int, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. use 4-space indentation |
||
| seed: Long = Utils.random.nextLong): Array[T] = { | ||
| var fraction = 0.0 | ||
| var total = 0 | ||
| val multiplier = 3.0 | ||
|
|
@@ -402,10 +411,11 @@ abstract class RDD[T: ClassTag]( | |
| } | ||
|
|
||
| if (num > initialCount && !withReplacement) { | ||
| // special case not covered in computeFraction | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If sample without replacement,
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Legacy code to prevent overflow if initialCount = Integer.MAX_VALUE
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think it can really prevent overflow. The fraction is chosen as |
||
| total = maxSelected | ||
| fraction = multiplier * (maxSelected + 1) / initialCount | ||
| } else { | ||
| fraction = multiplier * (num + 1) / initialCount | ||
| fraction = computeFraction(num, initialCount, withReplacement) | ||
| total = num | ||
| } | ||
|
|
||
|
|
@@ -421,6 +431,22 @@ abstract class RDD[T: ClassTag]( | |
| Utils.randomizeInPlace(samples, rand).take(total) | ||
| } | ||
|
|
||
| private[spark] def computeFraction(num: Int, total: Long, withReplacement: Boolean) : Double = { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We only need this function in test. So it could be
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove the space between
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, this function needs some docs about the theory. |
||
| val fraction = num.toDouble / total | ||
| if (withReplacement) { | ||
| var numStDev = 5 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| if (num < 12) { | ||
| // special case to guarantee sample size for small s | ||
| numStDev = 9 | ||
| } | ||
| fraction + numStDev * math.sqrt(fraction / total) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I understand what this expression is trying to bound but if you have a moment to comment it, would be great for the likes of me! |
||
| } else { | ||
| val delta = 0.00005 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Choose |
||
| val gamma = - math.log(delta)/total | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction)) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Return the union of this RDD and another one. Any identical elements will appear multiple | ||
| * times (use `.distinct()` to eliminate them). | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -70,7 +70,7 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false) | |
| } | ||
|
|
||
| /** | ||
| * Return a sampler with is the complement of the range specified of the current sampler. | ||
| * Return a sampler which is the complement of the range specified of the current sampler. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and I thought I was being a grammar nazi.... |
||
| */ | ||
| def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -22,6 +22,7 @@ import scala.reflect.ClassTag | |
|
|
||
| import org.scalatest.FunSuite | ||
|
|
||
| import org.apache.commons.math3.distribution.PoissonDistribution | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add an empty line after this line to organize imports into groups
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you need those two? |
||
| import org.apache.spark._ | ||
| import org.apache.spark.SparkContext._ | ||
| import org.apache.spark.rdd._ | ||
|
|
@@ -494,56 +495,84 @@ class RDDSuite extends FunSuite with SharedSparkContext { | |
| assert(sortedTopK === nums.sorted(ord).take(5)) | ||
| } | ||
|
|
||
| test("computeFraction") { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add tests for |
||
| // test that the computed fraction guarantees enough datapoints in the sample with a failure rate <= 0.0001 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. line too wide? |
||
| val data = new EmptyRDD[Int](sc) | ||
| val n = 100000 | ||
|
|
||
| for (s <- 1 to 15) { | ||
| val frac = data.computeFraction(s, n, true) | ||
| val qpois = new PoissonDistribution(frac * n) | ||
| assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") | ||
| } | ||
| for (s <- 1 to 15) { | ||
| val frac = data.computeFraction(s, n, false) | ||
| val qpois = new PoissonDistribution(frac * n) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. BinomailDistribution should be used instead of Pois. |
||
| assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") | ||
| } | ||
| for (s <- List(1, 10, 100, 1000)) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 1 and 10 are already tested. |
||
| val frac = data.computeFraction(s, n, true) | ||
| val qpois = new PoissonDistribution(frac * n) | ||
| assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") | ||
| } | ||
| for (s <- List(1, 10, 100, 1000)) { | ||
| val frac = data.computeFraction(s, n, false) | ||
| val qpois = new PoissonDistribution(frac * n) | ||
| assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low") | ||
| } | ||
| } | ||
|
|
||
| test("takeSample") { | ||
| val data = sc.parallelize(1 to 100, 2) | ||
| val n = 1000000 | ||
| val data = sc.parallelize(1 to n, 2) | ||
|
|
||
| for (num <- List(5, 20, 100)) { | ||
| val sample = data.takeSample(withReplacement=false, num=num) | ||
| assert(sample.size === num) // Got exactly num elements | ||
| assert(sample.toSet.size === num) // Elements are distinct | ||
| assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") | ||
| assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Error messages need to be updated to use |
||
| } | ||
| for (seed <- 1 to 5) { | ||
| val sample = data.takeSample(withReplacement=false, 20, seed) | ||
| assert(sample.size === 20) // Got exactly 20 elements | ||
| assert(sample.toSet.size === 20) // Elements are distinct | ||
| assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") | ||
| assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]") | ||
| } | ||
| for (seed <- 1 to 5) { | ||
| val sample = data.takeSample(withReplacement=false, 200, seed) | ||
| val sample = data.takeSample(withReplacement=false, 100, seed) | ||
| assert(sample.size === 100) // Got only 100 elements | ||
| assert(sample.toSet.size === 100) // Elements are distinct | ||
| assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") | ||
| assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]") | ||
| } | ||
| for (seed <- 1 to 5) { | ||
| val sample = data.takeSample(withReplacement=true, 20, seed) | ||
| assert(sample.size === 20) // Got exactly 20 elements | ||
| assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") | ||
| assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]") | ||
| } | ||
| { | ||
| val sample = data.takeSample(withReplacement=true, num=20) | ||
| assert(sample.size === 20) // Got exactly 100 elements | ||
| assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements") | ||
| assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") | ||
| assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]") | ||
| } | ||
| { | ||
| val sample = data.takeSample(withReplacement=true, num=100) | ||
| assert(sample.size === 100) // Got exactly 100 elements | ||
| val sample = data.takeSample(withReplacement=true, num=n) | ||
| assert(sample.size === n) // Got exactly 100 elements | ||
| // Chance of getting all distinct elements is astronomically low, so test we got < 100 | ||
| assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") | ||
| assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]") | ||
| assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") | ||
| assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]") | ||
| } | ||
| for (seed <- 1 to 5) { | ||
| val sample = data.takeSample(withReplacement=true, 100, seed) | ||
| assert(sample.size === 100) // Got exactly 100 elements | ||
| val sample = data.takeSample(withReplacement=true, n, seed) | ||
| assert(sample.size === n) // Got exactly 100 elements | ||
| // Chance of getting all distinct elements is astronomically low, so test we got < 100 | ||
| assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") | ||
| assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") | ||
| } | ||
| for (seed <- 1 to 5) { | ||
| val sample = data.takeSample(withReplacement=true, 200, seed) | ||
| assert(sample.size === 200) // Got exactly 200 elements | ||
| val sample = data.takeSample(withReplacement=true, 2*n, seed) | ||
|
||
| assert(sample.size === 2*n) // Got exactly 200 elements | ||
|
||
| // Chance of getting all distinct elements is still quite low, so test we got < 100 | ||
| assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements") | ||
| assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements") | ||
| } | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -245,6 +245,11 @@ | |
| <artifactId>commons-codec</artifactId> | ||
| <version>1.5</version> | ||
| </dependency> | ||
| <dependency> | ||
| <groupId>org.apache.commons</groupId> | ||
| <artifactId>commons-math3</artifactId> | ||
| <version>3.2</version> | ||
|
||
| </dependency> | ||
| <dependency> | ||
| <groupId>com.google.code.findbugs</groupId> | ||
| <artifactId>jsr305</artifactId> | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,6 +31,7 @@ | |
| import warnings | ||
| import heapq | ||
| from random import Random | ||
| from math import sqrt, log, min | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
| from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ | ||
| BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long | ||
|
|
@@ -374,7 +375,7 @@ def takeSample(self, withReplacement, num, seed=None): | |
| total = maxSelected | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add the same check for the max sample size? |
||
| fraction = multiplier * (maxSelected + 1) / initialCount | ||
| else: | ||
| fraction = multiplier * (num + 1) / initialCount | ||
| fraction = self._computeFraction(num, initialCount, withReplacement) | ||
| total = num | ||
|
|
||
| samples = self.sample(withReplacement, fraction, seed).collect() | ||
|
|
@@ -390,6 +391,18 @@ def takeSample(self, withReplacement, num, seed=None): | |
| sampler.shuffle(samples) | ||
| return samples[0:total] | ||
|
|
||
| def _computeFraction(self, num, total, withReplacement): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add |
||
| fraction = float(num)/total | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| if withReplacement: | ||
| numStDev = 5 | ||
| if (num < 12): | ||
| numStDev = 9 | ||
| return fraction + numStDev * sqrt(fraction/total) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| else: | ||
| delta = 0.00005 | ||
| gamma = - log(delta)/total | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| return min(1, fraction + gamma + sqrt(gamma * gamma + 2* gamma * fraction)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
| def union(self, other): | ||
| """ | ||
| Return the union of this RDD and another one. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this be
<scope>test</scope>if it's a test-only dependency?