Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
</dependency>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be <scope>test</scope> if it's a test-only dependency?

<dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
Expand Down
32 changes: 29 additions & 3 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,17 @@ abstract class RDD[T: ClassTag](
}.toArray
}

def takeSample(withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] =
{
/**
* Return a fixed-size sampled subset of this RDD in an array
*
* @param withReplacement whether sampling is done with replacement
* @param num size of the returned sample
* @param seed seed for the random number generator
* @return sample of specified size in an array
*/
def takeSample(withReplacement: Boolean,
num: Int,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use 4-space indentation

seed: Long = Utils.random.nextLong): Array[T] = {
var fraction = 0.0
var total = 0
val multiplier = 3.0
Expand All @@ -402,10 +411,11 @@ abstract class RDD[T: ClassTag](
}

if (num > initialCount && !withReplacement) {
// special case not covered in computeFraction
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If sample without replacement, num cannot be greater than initialCount. What is block for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Legacy code to prevent overflow if initialCount = Integer.MAX_VALUE

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it can really prevent overflow. The fraction is chosen as 3 * INT_MAX / count, which means the expect sample size is 3 * INT_MAX > INT_MAX. So collect() will throw an exception almost surely.

total = maxSelected
fraction = multiplier * (maxSelected + 1) / initialCount
} else {
fraction = multiplier * (num + 1) / initialCount
fraction = computeFraction(num, initialCount, withReplacement)
total = num
}

Expand All @@ -421,6 +431,22 @@ abstract class RDD[T: ClassTag](
Utils.randomizeInPlace(samples, rand).take(total)
}

private[spark] def computeFraction(num: Int, total: Long, withReplacement: Boolean) : Double = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We only need this function in test. So it could be private[rdd].

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the space between ) and :.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, this function needs some docs about the theory.

val fraction = num.toDouble / total
if (withReplacement) {
var numStDev = 5
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

val numStd = if (num < 12) 9 else 5

if (num < 12) {
// special case to guarantee sample size for small s
numStDev = 9
}
fraction + numStDev * math.sqrt(fraction / total)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I understand what this expression is trying to bound but if you have a moment to comment it, would be great for the likes of me!

} else {
val delta = 0.00005
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Choose delta = 1e-4 to have success rate greater than 99.99%.

val gamma = - math.log(delta)/total
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

math.log(delta) / total (space around operators)

math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))
}
}

/**
* Return the union of this RDD and another one. Any identical elements will appear multiple
* times (use `.distinct()` to eliminate them).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
}

/**
* Return a sampler with is the complement of the range specified of the current sampler.
* Return a sampler which is the complement of the range specified of the current sampler.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

which -> , which or which -> that

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and I thought I was being a grammar nazi....

*/
def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)

Expand Down
63 changes: 46 additions & 17 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import scala.reflect.ClassTag

import org.scalatest.FunSuite

import org.apache.commons.math3.distribution.PoissonDistribution
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add an empty line after this line to organize imports into groups

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need those two?

import org.apache.spark._
import org.apache.spark.SparkContext._
import org.apache.spark.rdd._
Expand Down Expand Up @@ -494,56 +495,84 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(sortedTopK === nums.sorted(ord).take(5))
}

test("computeFraction") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add tests for BinomialDistribution as well?

// test that the computed fraction guarantees enough datapoints in the sample with a failure rate <= 0.0001
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

line too wide?

val data = new EmptyRDD[Int](sc)
val n = 100000

for (s <- 1 to 15) {
val frac = data.computeFraction(s, n, true)
val qpois = new PoissonDistribution(frac * n)
assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
}
for (s <- 1 to 15) {
val frac = data.computeFraction(s, n, false)
val qpois = new PoissonDistribution(frac * n)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BinomailDistribution should be used instead of Pois.

assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
}
for (s <- List(1, 10, 100, 1000)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1 and 10 are already tested.

val frac = data.computeFraction(s, n, true)
val qpois = new PoissonDistribution(frac * n)
assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
}
for (s <- List(1, 10, 100, 1000)) {
val frac = data.computeFraction(s, n, false)
val qpois = new PoissonDistribution(frac * n)
assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
}
}

test("takeSample") {
val data = sc.parallelize(1 to 100, 2)
val n = 1000000
val data = sc.parallelize(1 to n, 2)

for (num <- List(5, 20, 100)) {
val sample = data.takeSample(withReplacement=false, num=num)
assert(sample.size === num) // Got exactly num elements
assert(sample.toSet.size === num) // Elements are distinct
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error messages need to be updated to use n instead of 100.

}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=false, 20, seed)
assert(sample.size === 20) // Got exactly 20 elements
assert(sample.toSet.size === 20) // Elements are distinct
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=false, 200, seed)
val sample = data.takeSample(withReplacement=false, 100, seed)
assert(sample.size === 100) // Got only 100 elements
assert(sample.toSet.size === 100) // Elements are distinct
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=true, 20, seed)
assert(sample.size === 20) // Got exactly 20 elements
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
}
{
val sample = data.takeSample(withReplacement=true, num=20)
assert(sample.size === 20) // Got exactly 100 elements
assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements")
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
}
{
val sample = data.takeSample(withReplacement=true, num=100)
assert(sample.size === 100) // Got exactly 100 elements
val sample = data.takeSample(withReplacement=true, num=n)
assert(sample.size === n) // Got exactly 100 elements
// Chance of getting all distinct elements is astronomically low, so test we got < 100
assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, 100]")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=true, 100, seed)
assert(sample.size === 100) // Got exactly 100 elements
val sample = data.takeSample(withReplacement=true, n, seed)
assert(sample.size === n) // Got exactly 100 elements
// Chance of getting all distinct elements is astronomically low, so test we got < 100
assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=true, 200, seed)
assert(sample.size === 200) // Got exactly 200 elements
val sample = data.takeSample(withReplacement=true, 2*n, seed)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 * n.

assert(sample.size === 2*n) // Got exactly 200 elements
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 * n

// Chance of getting all distinct elements is still quite low, so test we got < 100
assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
}
}

Expand Down
5 changes: 5 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,11 @@
<artifactId>commons-codec</artifactId>
<version>1.5</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.2</version>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Commons Math 3.3 is out now

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a risk of conflicting with existing hadoop/hive distribution?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. If it's test-only, should be no risk. Although, hm, can this be expressed as test-only in SBT? If not then if it were 3.3 it would affect the final non-test artifact. Hm. Up to you all's judgment on that one. Heh, as it happens I did fix a bug in 3.3 in the Poisson distribution for CM: https://issues.apache.org/jira/browse/MATH-1056 Probably has no effect here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sbt allows test-only dependencies. forging ahead with 3.3. thanks for the info!

</dependency>
<dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
Expand Down
1 change: 1 addition & 0 deletions project/SparkBuild.scala
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ object SparkBuild extends Build {
libraryDependencies ++= Seq(
"com.google.guava" % "guava" % "14.0.1",
"org.apache.commons" % "commons-lang3" % "3.3.2",
"org.apache.commons" % "commons-math3" % "3.2",
"com.google.code.findbugs" % "jsr305" % "1.3.9",
"log4j" % "log4j" % "1.2.17",
"org.slf4j" % "slf4j-api" % slf4jVersion,
Expand Down
15 changes: 14 additions & 1 deletion python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import warnings
import heapq
from random import Random
from math import sqrt, log, min
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

min is a built-in function, which doesn't need to be imported. Please run bin/pyspark path/to/rdd.py to test a single file.


from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
Expand Down Expand Up @@ -374,7 +375,7 @@ def takeSample(self, withReplacement, num, seed=None):
total = maxSelected
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add the same check for the max sample size?

fraction = multiplier * (maxSelected + 1) / initialCount
else:
fraction = multiplier * (num + 1) / initialCount
fraction = self._computeFraction(num, initialCount, withReplacement)
total = num

samples = self.sample(withReplacement, fraction, seed).collect()
Expand All @@ -390,6 +391,18 @@ def takeSample(self, withReplacement, num, seed=None):
sampler.shuffle(samples)
return samples[0:total]

def _computeFraction(self, num, total, withReplacement):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add @staticmethod

fraction = float(num)/total
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/ -> / (space)

if withReplacement:
numStDev = 5
if (num < 12):
numStDev = 9
return fraction + numStDev * sqrt(fraction/total)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fraction / total

else:
delta = 0.00005
gamma = - log(delta)/total
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/ (space)

return min(1, fraction + gamma + sqrt(gamma * gamma + 2* gamma * fraction))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 *


def union(self, other):
"""
Return the union of this RDD and another one.
Expand Down