Skip to content

Commit 586e716

Browse files
committed
Reservoir sampling implementation.
This is going to be used in https://issues.apache.org/jira/browse/SPARK-2568 Author: Reynold Xin <[email protected]> Closes #1478 from rxin/reservoirSample and squashes the following commits: 17bcbf3 [Reynold Xin] Added seed. badf20d [Reynold Xin] Renamed the method. 6940010 [Reynold Xin] Reservoir sampling implementation.
1 parent 7f87ab9 commit 586e716

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,54 @@
1717

1818
package org.apache.spark.util.random
1919

20+
import scala.reflect.ClassTag
21+
import scala.util.Random
22+
2023
private[spark] object SamplingUtils {
2124

25+
/**
 * Reservoir sampling implementation that also returns the input size.
 *
 * The first `k` items fill the reservoir. Every later item, being the i-th item consumed,
 * replaces a uniformly chosen reservoir slot with probability k / i.
 *
 * The input iterator is consumed completely. The element count is tracked as an `Int`, so
 * inputs with more than `Int.MaxValue` elements are not supported.
 *
 * @param input iterator over the items to sample from
 * @param k reservoir size
 * @param seed random seed
 * @return (samples, input size); when the input has fewer than `k` items, `samples` holds
 *         all of them in their original order
 */
def reservoirSampleAndCount[T: ClassTag](
    input: Iterator[T],
    k: Int,
    seed: Long = Random.nextLong())
  : (Array[T], Int) = {
  val reservoir = new Array[T](k)
  // Put the first k elements in the reservoir.
  var i = 0
  while (i < k && input.hasNext) {
    reservoir(i) = input.next()
    i += 1
  }

  // If we have consumed all the elements, return them. Otherwise do the replacement.
  if (i < k) {
    // If input size < k, trim the array to return only an array of input size.
    val trimReservoir = new Array[T](i)
    System.arraycopy(reservoir, 0, trimReservoir, 0, i)
    (trimReservoir, i)
  } else {
    // If input size > k, continue the sampling process.
    val rand = new XORShiftRandom(seed)
    while (input.hasNext) {
      val item = input.next()
      // Count the current item *before* drawing: with i items consumed so far (including
      // this one), the draw is uniform over [0, i), so the item is kept with probability
      // k / i. Drawing from [0, i - 1) instead would always keep the (k + 1)-th item and
      // would throw for k == 0.
      i += 1
      val replacementIndex = rand.nextInt(i)
      if (replacementIndex < k) {
        reservoir(replacementIndex) = item
      }
    }
    (reservoir, i)
  }
}
67+
2268
/**
2369
* Returns a sampling rate that guarantees a sample of size >= sampleSizeLowerBound 99.99% of
2470
* the time.

core/src/test/scala/org/apache/spark/util/random/SamplingUtilsSuite.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,32 @@
1717

1818
package org.apache.spark.util.random
1919

20+
import scala.util.Random
21+
2022
import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution}
2123
import org.scalatest.FunSuite
2224

2325
class SamplingUtilsSuite extends FunSuite {
2426

27+
// Checks the returned sample and element count for k above, equal to, and below the input size.
test("reservoirSampleAndCount") {
  val input = Seq.fill(100)(Random.nextInt())

  // input size < k (k = 150) and input size == k (k = 100): in both cases the whole
  // input is expected back, in order, together with the full count.
  Seq(150, 100).foreach { reservoirSize =>
    val (sample, count) = SamplingUtils.reservoirSampleAndCount(input.iterator, reservoirSize)
    assert(count === 100)
    assert(input === sample.toSeq)
  }

  // input size > k: only the reservoir size and the total count are checked.
  val (smallSample, totalCount) = SamplingUtils.reservoirSampleAndCount(input.iterator, 10)
  assert(totalCount === 100)
  assert(smallSample.length === 10)
}
45+
2546
test("computeFraction") {
2647
// test that the computed fraction guarantees enough data points
2748
// in the sample with a failure rate <= 0.0001

0 commit comments

Comments
 (0)