Skip to content

Commit ad3bd0d

Browse files
erikerlandson authored and mengxr committed
[SPARK-3250] Implement Gap Sampling optimization for random sampling
More efficient sampling, based on Gap Sampling optimization: http://erikerlandson.github.io/blog/2014/09/11/faster-random-samples-with-gap-sampling/ Author: Erik Erlandson <[email protected]> Closes #2455 from erikerlandson/spark-3250-pr and squashes the following commits: 72496bc [Erik Erlandson] [SPARK-3250] Implement Gap Sampling optimization for random sampling
1 parent 872fc66 commit ad3bd0d

File tree

5 files changed

+790
-121
lines changed

5 files changed

+790
-121
lines changed

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,8 @@ import org.apache.spark.partial.PartialResult
4343
import org.apache.spark.storage.StorageLevel
4444
import org.apache.spark.util.{BoundedPriorityQueue, Utils, CallSite}
4545
import org.apache.spark.util.collection.OpenHashMap
46-
import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}
46+
import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, BernoulliCellSampler,
47+
SamplingUtils}
4748

4849
/**
4950
* A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -375,7 +376,8 @@ abstract class RDD[T: ClassTag](
375376
val sum = weights.sum
376377
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
377378
normalizedCumWeights.sliding(2).map { x =>
378-
new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), true, seed)
379+
new PartitionwiseSampledRDD[T, T](
380+
this, new BernoulliCellSampler[T](x(0), x(1)), true, seed)
379381
}.toArray
380382
}
381383

core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala

Lines changed: 264 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,9 @@ package org.apache.spark.util.random
1919

2020
import java.util.Random
2121

22+
import scala.reflect.ClassTag
23+
import scala.collection.mutable.ArrayBuffer
24+
2225
import org.apache.commons.math3.distribution.PoissonDistribution
2326

2427
import org.apache.spark.annotation.DeveloperApi
@@ -38,71 +41,310 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable
3841
/** take a random sample */
3942
def sample(items: Iterator[T]): Iterator[U]
4043

44+
/** return a copy of the RandomSampler object */
4145
override def clone: RandomSampler[T, U] =
4246
throw new NotImplementedError("clone() is not implemented.")
4347
}
4448

49+
private[spark]
50+
object RandomSampler {
51+
/** Default random number generator used by random samplers. */
52+
def newDefaultRNG: Random = new XORShiftRandom
53+
54+
/**
55+
* Default maximum gap-sampling fraction.
56+
* For sampling fractions <= this value, the gap sampling optimization will be applied.
57+
* Above this value, it is assumed that "traditional" Bernoulli sampling is faster. The
58+
* optimal value for this will depend on the RNG. More expensive RNGs will tend to make
59+
* the optimal value higher. The most reliable way to determine this value for a new RNG
60+
* is to experiment. When tuning for a new RNG, I would expect a value of 0.5 to be close
61+
* in most cases, as an initial guess.
62+
*/
63+
val defaultMaxGapSamplingFraction = 0.4
64+
65+
/**
66+
* Default epsilon for floating point numbers sampled from the RNG.
67+
* The gap-sampling compute logic requires taking log(x), where x is sampled from an RNG.
68+
* To guard against errors from taking log(0), a positive epsilon lower bound is applied.
69+
* A good value for this parameter is at or near the minimum positive floating
70+
* point value returned by "nextDouble()" (or equivalent), for the RNG being used.
71+
*/
72+
val rngEpsilon = 5e-11
73+
74+
/**
75+
* Sampling fraction arguments may be results of computation, and subject to floating
76+
* point jitter. I check the arguments with this epsilon slop factor to prevent spurious
77+
* warnings for cases such as summing some numbers to get a sampling fraction of 1.000000001
78+
*/
79+
val roundingEpsilon = 1e-6
80+
}
81+
4582
/**
4683
* :: DeveloperApi ::
47-
* A sampler based on Bernoulli trials.
84+
* A sampler based on Bernoulli trials for partitioning a data sequence.
4885
*
4986
* @param lb lower bound of the acceptance range
5087
* @param ub upper bound of the acceptance range
5188
* @param complement whether to use the complement of the range specified, defaults to false
5289
* @tparam T item type
5390
*/
5491
@DeveloperApi
55-
class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
92+
class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = false)
5693
extends RandomSampler[T, T] {
5794

58-
private[random] var rng: Random = new XORShiftRandom
95+
/** epsilon slop to avoid failure from floating point jitter. */
96+
require(
97+
lb <= (ub + RandomSampler.roundingEpsilon),
98+
s"Lower bound ($lb) must be <= upper bound ($ub)")
99+
require(
100+
lb >= (0.0 - RandomSampler.roundingEpsilon),
101+
s"Lower bound ($lb) must be >= 0.0")
102+
require(
103+
ub <= (1.0 + RandomSampler.roundingEpsilon),
104+
s"Upper bound ($ub) must be <= 1.0")
59105

60-
def this(ratio: Double) = this(0.0d, ratio)
106+
private val rng: Random = new XORShiftRandom
61107

62108
override def setSeed(seed: Long) = rng.setSeed(seed)
63109

64110
override def sample(items: Iterator[T]): Iterator[T] = {
65-
items.filter { item =>
66-
val x = rng.nextDouble()
67-
(x >= lb && x < ub) ^ complement
111+
if (ub - lb <= 0.0) {
112+
if (complement) items else Iterator.empty
113+
} else {
114+
if (complement) {
115+
items.filter { item => {
116+
val x = rng.nextDouble()
117+
(x < lb) || (x >= ub)
118+
}}
119+
} else {
120+
items.filter { item => {
121+
val x = rng.nextDouble()
122+
(x >= lb) && (x < ub)
123+
}}
124+
}
68125
}
69126
}
70127

71128
/**
72129
* Return a sampler that selects the complement of the range specified for the current sampler.
73130
*/
74-
def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)
131+
def cloneComplement(): BernoulliCellSampler[T] =
132+
new BernoulliCellSampler[T](lb, ub, !complement)
133+
134+
override def clone = new BernoulliCellSampler[T](lb, ub, complement)
135+
}
136+
137+
138+
/**
139+
* :: DeveloperApi ::
140+
* A sampler based on Bernoulli trials.
141+
*
142+
* @param fraction the sampling fraction, aka Bernoulli sampling probability
143+
* @tparam T item type
144+
*/
145+
@DeveloperApi
146+
class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] {
147+
148+
/** epsilon slop to avoid failure from floating point jitter */
149+
require(
150+
fraction >= (0.0 - RandomSampler.roundingEpsilon)
151+
&& fraction <= (1.0 + RandomSampler.roundingEpsilon),
152+
s"Sampling fraction ($fraction) must be on interval [0, 1]")
75153

76-
override def clone = new BernoulliSampler[T](lb, ub, complement)
154+
private val rng: Random = RandomSampler.newDefaultRNG
155+
156+
override def setSeed(seed: Long) = rng.setSeed(seed)
157+
158+
override def sample(items: Iterator[T]): Iterator[T] = {
159+
if (fraction <= 0.0) {
160+
Iterator.empty
161+
} else if (fraction >= 1.0) {
162+
items
163+
} else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
164+
new GapSamplingIterator(items, fraction, rng, RandomSampler.rngEpsilon)
165+
} else {
166+
items.filter { _ => rng.nextDouble() <= fraction }
167+
}
168+
}
169+
170+
override def clone = new BernoulliSampler[T](fraction)
77171
}
78172

173+
79174
/**
80175
* :: DeveloperApi ::
81-
* A sampler based on values drawn from Poisson distribution.
176+
* A sampler for sampling with replacement, based on values drawn from Poisson distribution.
82177
*
83-
* @param mean Poisson mean
178+
* @param fraction the sampling fraction (with replacement)
84179
* @tparam T item type
85180
*/
86181
@DeveloperApi
87-
class PoissonSampler[T](mean: Double) extends RandomSampler[T, T] {
182+
class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] {
183+
184+
/** Epsilon slop to avoid failure from floating point jitter. */
185+
require(
186+
fraction >= (0.0 - RandomSampler.roundingEpsilon),
187+
s"Sampling fraction ($fraction) must be >= 0")
88188

89-
private[random] var rng = new PoissonDistribution(mean)
189+
// PoissonDistribution throws an exception when fraction <= 0
190+
// If fraction is <= 0, Iterator.empty is used below, so we can use any placeholder value.
191+
private val rng = new PoissonDistribution(if (fraction > 0.0) fraction else 1.0)
192+
private val rngGap = RandomSampler.newDefaultRNG
90193

91194
override def setSeed(seed: Long) {
92-
rng = new PoissonDistribution(mean)
93195
rng.reseedRandomGenerator(seed)
196+
rngGap.setSeed(seed)
94197
}
95198

96199
override def sample(items: Iterator[T]): Iterator[T] = {
97-
items.flatMap { item =>
98-
val count = rng.sample()
99-
if (count == 0) {
100-
Iterator.empty
101-
} else {
102-
Iterator.fill(count)(item)
103-
}
200+
if (fraction <= 0.0) {
201+
Iterator.empty
202+
} else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
203+
new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon)
204+
} else {
205+
items.flatMap { item => {
206+
val count = rng.sample()
207+
if (count == 0) Iterator.empty else Iterator.fill(count)(item)
208+
}}
209+
}
210+
}
211+
212+
override def clone = new PoissonSampler[T](fraction)
213+
}
214+
215+
216+
private[spark]
217+
class GapSamplingIterator[T: ClassTag](
218+
var data: Iterator[T],
219+
f: Double,
220+
rng: Random = RandomSampler.newDefaultRNG,
221+
epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] {
222+
223+
require(f > 0.0 && f < 1.0, s"Sampling fraction ($f) must reside on open interval (0, 1)")
224+
require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0")
225+
226+
/** implement efficient linear-sequence drop until Scala includes fix for jira SI-8835. */
227+
private val iterDrop: Int => Unit = {
228+
val arrayClass = Array.empty[T].iterator.getClass
229+
val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass
230+
data.getClass match {
231+
case `arrayClass` => ((n: Int) => { data = data.drop(n) })
232+
case `arrayBufferClass` => ((n: Int) => { data = data.drop(n) })
233+
case _ => ((n: Int) => {
234+
var j = 0
235+
while (j < n && data.hasNext) {
236+
data.next()
237+
j += 1
238+
}
239+
})
240+
}
241+
}
242+
243+
override def hasNext: Boolean = data.hasNext
244+
245+
override def next(): T = {
246+
val r = data.next()
247+
advance
248+
r
249+
}
250+
251+
private val lnq = math.log1p(-f)
252+
253+
/** skip elements that won't be sampled, according to geometric dist P(k) = (f)(1-f)^k. */
254+
private def advance: Unit = {
255+
val u = math.max(rng.nextDouble(), epsilon)
256+
val k = (math.log(u) / lnq).toInt
257+
iterDrop(k)
258+
}
259+
260+
/** advance to first sample as part of object construction. */
261+
advance
262+
// Attempting to invoke this closer to the top with other object initialization
263+
// was causing it to break in strange ways, so I'm invoking it last, which seems to
264+
// work reliably.
265+
}
266+
267+
private[spark]
268+
class GapSamplingReplacementIterator[T: ClassTag](
269+
var data: Iterator[T],
270+
f: Double,
271+
rng: Random = RandomSampler.newDefaultRNG,
272+
epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] {
273+
274+
require(f > 0.0, s"Sampling fraction ($f) must be > 0")
275+
require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0")
276+
277+
/** implement efficient linear-sequence drop until Scala includes fix for jira SI-8835. */
278+
private val iterDrop: Int => Unit = {
279+
val arrayClass = Array.empty[T].iterator.getClass
280+
val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass
281+
data.getClass match {
282+
case `arrayClass` => ((n: Int) => { data = data.drop(n) })
283+
case `arrayBufferClass` => ((n: Int) => { data = data.drop(n) })
284+
case _ => ((n: Int) => {
285+
var j = 0
286+
while (j < n && data.hasNext) {
287+
data.next()
288+
j += 1
289+
}
290+
})
291+
}
292+
}
293+
294+
/** current sampling value, and its replication factor, as we are sampling with replacement. */
295+
private var v: T = _
296+
private var rep: Int = 0
297+
298+
override def hasNext: Boolean = data.hasNext || rep > 0
299+
300+
override def next(): T = {
301+
val r = v
302+
rep -= 1
303+
if (rep <= 0) advance
304+
r
305+
}
306+
307+
/**
308+
* Skip elements with replication factor zero (i.e. elements that won't be sampled).
309+
* Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is
310+
* q is the probability of Poisson(0; f)
311+
*/
312+
private def advance: Unit = {
313+
val u = math.max(rng.nextDouble(), epsilon)
314+
val k = (math.log(u) / (-f)).toInt
315+
iterDrop(k)
316+
// set the value and replication factor for the next value
317+
if (data.hasNext) {
318+
v = data.next()
319+
rep = poissonGE1
320+
}
321+
}
322+
323+
private val q = math.exp(-f)
324+
325+
/**
326+
* Sample from Poisson distribution, conditioned such that the sampled value is >= 1.
327+
* This is an adaptation from the algorithm for Generating Poisson distributed random variables:
328+
* http://en.wikipedia.org/wiki/Poisson_distribution
329+
*/
330+
private def poissonGE1: Int = {
331+
// simulate that the standard poisson sampling
332+
// gave us at least one iteration, for a sample of >= 1
333+
var pp = q + ((1.0 - q) * rng.nextDouble())
334+
var r = 1
335+
336+
// now continue with standard poisson sampling algorithm
337+
pp *= rng.nextDouble()
338+
while (pp > q) {
339+
r += 1
340+
pp *= rng.nextDouble()
104341
}
342+
r
105343
}
106344

107-
override def clone = new PoissonSampler[T](mean)
345+
/** advance to first sample as part of object construction. */
346+
advance
347+
// Attempting to invoke this closer to the top with other object initialization
348+
// was causing it to break in strange ways, so I'm invoking it last, which seems to
349+
// work reliably.
108350
}

core/src/test/java/org/apache/spark/JavaAPISuite.java

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -140,11 +140,10 @@ public void intersection() {
140140
public void sample() {
141141
List<Integer> ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
142142
JavaRDD<Integer> rdd = sc.parallelize(ints);
143-
JavaRDD<Integer> sample20 = rdd.sample(true, 0.2, 11);
144-
// expected 2 but of course result varies randomly a bit
145-
Assert.assertEquals(1, sample20.count());
146-
JavaRDD<Integer> sample20NoReplacement = rdd.sample(false, 0.2, 11);
147-
Assert.assertEquals(2, sample20NoReplacement.count());
143+
JavaRDD<Integer> sample20 = rdd.sample(true, 0.2, 3);
144+
Assert.assertEquals(2, sample20.count());
145+
JavaRDD<Integer> sample20WithoutReplacement = rdd.sample(false, 0.2, 5);
146+
Assert.assertEquals(2, sample20WithoutReplacement.count());
148147
}
149148

150149
@Test

0 commit comments

Comments
 (0)