@@ -19,6 +19,9 @@ package org.apache.spark.util.random
1919
2020import java .util .Random
2121
22+ import scala .reflect .ClassTag
23+ import scala .collection .mutable .ArrayBuffer
24+
2225import org .apache .commons .math3 .distribution .PoissonDistribution
2326
2427import org .apache .spark .annotation .DeveloperApi
@@ -38,71 +41,310 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable
3841 /** take a random sample */
3942 def sample (items : Iterator [T ]): Iterator [U ]
4043
44+ /** return a copy of the RandomSampler object */
4145 override def clone : RandomSampler [T , U ] =
4246 throw new NotImplementedError (" clone() is not implemented." )
4347}
4448
49+ private [spark]
50+ object RandomSampler {
51+ /** Default random number generator used by random samplers. */
52+ def newDefaultRNG : Random = new XORShiftRandom
53+
54+ /**
55+ * Default maximum gap-sampling fraction.
56+ * For sampling fractions <= this value, the gap sampling optimization will be applied.
57+ * Above this value, it is assumed that "traditional" Bernoulli sampling is faster. The
58+ * optimal value for this will depend on the RNG. More expensive RNGs will tend to make
59+ * the optimal value higher. The most reliable way to determine this value for a new RNG
60+ * is to experiment. When tuning for a new RNG, I would expect a value of 0.5 to be close
61+ * in most cases, as an initial guess.
62+ */
63+ val defaultMaxGapSamplingFraction = 0.4
64+
65+ /**
66+ * Default epsilon for floating point numbers sampled from the RNG.
67+ * The gap-sampling compute logic requires taking log(x), where x is sampled from an RNG.
68+ * To guard against errors from taking log(0), a positive epsilon lower bound is applied.
69+ * A good value for this parameter is at or near the minimum positive floating
70+ * point value returned by "nextDouble()" (or equivalent), for the RNG being used.
71+ */
72+ val rngEpsilon = 5e-11
73+
74+ /**
75+ * Sampling fraction arguments may be results of computation, and subject to floating
76+ * point jitter. I check the arguments with this epsilon slop factor to prevent spurious
77+ * warnings for cases such as summing some numbers to get a sampling fraction of 1.000000001
78+ */
79+ val roundingEpsilon = 1e-6
80+ }
81+
4582/**
4683 * :: DeveloperApi ::
47- * A sampler based on Bernoulli trials.
84+ * A sampler based on Bernoulli trials for partitioning a data sequence .
4885 *
4986 * @param lb lower bound of the acceptance range
5087 * @param ub upper bound of the acceptance range
5188 * @param complement whether to use the complement of the range specified, default to false
5289 * @tparam T item type
5390 */
5491@ DeveloperApi
55- class BernoulliSampler [T ](lb : Double , ub : Double , complement : Boolean = false )
92+ class BernoulliCellSampler [T ](lb : Double , ub : Double , complement : Boolean = false )
5693 extends RandomSampler [T , T ] {
5794
58- private [random] var rng : Random = new XORShiftRandom
95+ /** epsilon slop to avoid failure from floating point jitter. */
96+ require(
97+ lb <= (ub + RandomSampler .roundingEpsilon),
98+ s " Lower bound ( $lb) must be <= upper bound ( $ub) " )
99+ require(
100+ lb >= (0.0 - RandomSampler .roundingEpsilon),
101+ s " Lower bound ( $lb) must be >= 0.0 " )
102+ require(
103+ ub <= (1.0 + RandomSampler .roundingEpsilon),
104+ s " Upper bound ( $ub) must be <= 1.0 " )
59105
60- def this ( ratio : Double ) = this ( 0.0d , ratio)
106+ private val rng : Random = new XORShiftRandom
61107
62108 override def setSeed (seed : Long ) = rng.setSeed(seed)
63109
64110 override def sample (items : Iterator [T ]): Iterator [T ] = {
65- items.filter { item =>
66- val x = rng.nextDouble()
67- (x >= lb && x < ub) ^ complement
111+ if (ub - lb <= 0.0 ) {
112+ if (complement) items else Iterator .empty
113+ } else {
114+ if (complement) {
115+ items.filter { item => {
116+ val x = rng.nextDouble()
117+ (x < lb) || (x >= ub)
118+ }}
119+ } else {
120+ items.filter { item => {
121+ val x = rng.nextDouble()
122+ (x >= lb) && (x < ub)
123+ }}
124+ }
68125 }
69126 }
70127
71128 /**
72129 * Return a sampler that is the complement of the range specified of the current sampler.
73130 */
74- def cloneComplement (): BernoulliSampler [T ] = new BernoulliSampler [T ](lb, ub, ! complement)
131+ def cloneComplement (): BernoulliCellSampler [T ] =
132+ new BernoulliCellSampler [T ](lb, ub, ! complement)
133+
134+ override def clone = new BernoulliCellSampler [T ](lb, ub, complement)
135+ }
136+
137+
138+ /**
139+ * :: DeveloperApi ::
140+ * A sampler based on Bernoulli trials.
141+ *
142+ * @param fraction the sampling fraction, aka Bernoulli sampling probability
143+ * @tparam T item type
144+ */
145+ @ DeveloperApi
146+ class BernoulliSampler [T : ClassTag ](fraction : Double ) extends RandomSampler [T , T ] {
147+
148+ /** epsilon slop to avoid failure from floating point jitter */
149+ require(
150+ fraction >= (0.0 - RandomSampler .roundingEpsilon)
151+ && fraction <= (1.0 + RandomSampler .roundingEpsilon),
152+ s " Sampling fraction ( $fraction) must be on interval [0, 1] " )
75153
76- override def clone = new BernoulliSampler [T ](lb, ub, complement)
154+ private val rng : Random = RandomSampler .newDefaultRNG
155+
156+ override def setSeed (seed : Long ) = rng.setSeed(seed)
157+
158+ override def sample (items : Iterator [T ]): Iterator [T ] = {
159+ if (fraction <= 0.0 ) {
160+ Iterator .empty
161+ } else if (fraction >= 1.0 ) {
162+ items
163+ } else if (fraction <= RandomSampler .defaultMaxGapSamplingFraction) {
164+ new GapSamplingIterator (items, fraction, rng, RandomSampler .rngEpsilon)
165+ } else {
166+ items.filter { _ => rng.nextDouble() <= fraction }
167+ }
168+ }
169+
170+ override def clone = new BernoulliSampler [T ](fraction)
77171}
78172
173+
79174/**
80175 * :: DeveloperApi ::
81- * A sampler based on values drawn from Poisson distribution.
176+ * A sampler for sampling with replacement, based on values drawn from Poisson distribution.
82177 *
83- * @param mean Poisson mean
178+ * @param fraction the sampling fraction (with replacement)
84179 * @tparam T item type
85180 */
86181@ DeveloperApi
87- class PoissonSampler [T ](mean : Double ) extends RandomSampler [T , T ] {
182+ class PoissonSampler [T : ClassTag ](fraction : Double ) extends RandomSampler [T , T ] {
183+
184+ /** Epsilon slop to avoid failure from floating point jitter. */
185+ require(
186+ fraction >= (0.0 - RandomSampler .roundingEpsilon),
187+ s " Sampling fraction ( $fraction) must be >= 0 " )
88188
89- private [random] var rng = new PoissonDistribution (mean)
189+ // PoissonDistribution throws an exception when fraction <= 0
190+ // If fraction is <= 0, Iterator.empty is used below, so we can use any placeholder value.
191+ private val rng = new PoissonDistribution (if (fraction > 0.0 ) fraction else 1.0 )
192+ private val rngGap = RandomSampler .newDefaultRNG
90193
91194 override def setSeed (seed : Long ) {
92- rng = new PoissonDistribution (mean)
93195 rng.reseedRandomGenerator(seed)
196+ rngGap.setSeed(seed)
94197 }
95198
96199 override def sample (items : Iterator [T ]): Iterator [T ] = {
97- items.flatMap { item =>
98- val count = rng.sample()
99- if (count == 0 ) {
100- Iterator .empty
101- } else {
102- Iterator .fill(count)(item)
103- }
200+ if (fraction <= 0.0 ) {
201+ Iterator .empty
202+ } else if (fraction <= RandomSampler .defaultMaxGapSamplingFraction) {
203+ new GapSamplingReplacementIterator (items, fraction, rngGap, RandomSampler .rngEpsilon)
204+ } else {
205+ items.flatMap { item => {
206+ val count = rng.sample()
207+ if (count == 0 ) Iterator .empty else Iterator .fill(count)(item)
208+ }}
209+ }
210+ }
211+
212+ override def clone = new PoissonSampler [T ](fraction)
213+ }
214+
215+
216+ private [spark]
217+ class GapSamplingIterator [T : ClassTag ](
218+ var data : Iterator [T ],
219+ f : Double ,
220+ rng : Random = RandomSampler .newDefaultRNG,
221+ epsilon : Double = RandomSampler .rngEpsilon) extends Iterator [T ] {
222+
223+ require(f > 0.0 && f < 1.0 , s " Sampling fraction ( $f) must reside on open interval (0, 1) " )
224+ require(epsilon > 0.0 , s " epsilon ( $epsilon) must be > 0 " )
225+
226+ /** implement efficient linear-sequence drop until Scala includes fix for jira SI-8835. */
227+ private val iterDrop : Int => Unit = {
228+ val arrayClass = Array .empty[T ].iterator.getClass
229+ val arrayBufferClass = ArrayBuffer .empty[T ].iterator.getClass
230+ data.getClass match {
231+ case `arrayClass` => ((n : Int ) => { data = data.drop(n) })
232+ case `arrayBufferClass` => ((n : Int ) => { data = data.drop(n) })
233+ case _ => ((n : Int ) => {
234+ var j = 0
235+ while (j < n && data.hasNext) {
236+ data.next()
237+ j += 1
238+ }
239+ })
240+ }
241+ }
242+
243+ override def hasNext : Boolean = data.hasNext
244+
245+ override def next (): T = {
246+ val r = data.next()
247+ advance
248+ r
249+ }
250+
251+ private val lnq = math.log1p(- f)
252+
253+ /** skip elements that won't be sampled, according to geometric dist P(k) = (f)(1-f)^k. */
254+ private def advance : Unit = {
255+ val u = math.max(rng.nextDouble(), epsilon)
256+ val k = (math.log(u) / lnq).toInt
257+ iterDrop(k)
258+ }
259+
260+ /** advance to first sample as part of object construction. */
261+ advance
262+ // Attempting to invoke this closer to the top with other object initialization
263+ // was causing it to break in strange ways, so I'm invoking it last, which seems to
264+ // work reliably.
265+ }
266+
267+ private [spark]
268+ class GapSamplingReplacementIterator [T : ClassTag ](
269+ var data : Iterator [T ],
270+ f : Double ,
271+ rng : Random = RandomSampler .newDefaultRNG,
272+ epsilon : Double = RandomSampler .rngEpsilon) extends Iterator [T ] {
273+
274+ require(f > 0.0 , s " Sampling fraction ( $f) must be > 0 " )
275+ require(epsilon > 0.0 , s " epsilon ( $epsilon) must be > 0 " )
276+
277+ /** implement efficient linear-sequence drop until Scala includes fix for jira SI-8835. */
278+ private val iterDrop : Int => Unit = {
279+ val arrayClass = Array .empty[T ].iterator.getClass
280+ val arrayBufferClass = ArrayBuffer .empty[T ].iterator.getClass
281+ data.getClass match {
282+ case `arrayClass` => ((n : Int ) => { data = data.drop(n) })
283+ case `arrayBufferClass` => ((n : Int ) => { data = data.drop(n) })
284+ case _ => ((n : Int ) => {
285+ var j = 0
286+ while (j < n && data.hasNext) {
287+ data.next()
288+ j += 1
289+ }
290+ })
291+ }
292+ }
293+
294+ /** current sampling value, and its replication factor, as we are sampling with replacement. */
295+ private var v : T = _
296+ private var rep : Int = 0
297+
298+ override def hasNext : Boolean = data.hasNext || rep > 0
299+
300+ override def next (): T = {
301+ val r = v
302+ rep -= 1
303+ if (rep <= 0 ) advance
304+ r
305+ }
306+
307+ /**
308+ * Skip elements with replication factor zero (i.e. elements that won't be sampled).
309+ * Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is
310+ * q is the probability of Poisson(0; f)
311+ */
312+ private def advance : Unit = {
313+ val u = math.max(rng.nextDouble(), epsilon)
314+ val k = (math.log(u) / (- f)).toInt
315+ iterDrop(k)
316+ // set the value and replication factor for the next value
317+ if (data.hasNext) {
318+ v = data.next()
319+ rep = poissonGE1
320+ }
321+ }
322+
323+ private val q = math.exp(- f)
324+
325+ /**
326+ * Sample from Poisson distribution, conditioned such that the sampled value is >= 1.
327+ * This is an adaptation from the algorithm for Generating Poisson distributed random variables:
328+ * http://en.wikipedia.org/wiki/Poisson_distribution
329+ */
330+ private def poissonGE1 : Int = {
331+ // simulate that the standard poisson sampling
332+ // gave us at least one iteration, for a sample of >= 1
333+ var pp = q + ((1.0 - q) * rng.nextDouble())
334+ var r = 1
335+
336+ // now continue with standard poisson sampling algorithm
337+ pp *= rng.nextDouble()
338+ while (pp > q) {
339+ r += 1
340+ pp *= rng.nextDouble()
104341 }
342+ r
105343 }
106344
107- override def clone = new PoissonSampler [T ](mean)
345+ /** advance to first sample as part of object construction. */
346+ advance
347+ // Attempting to invoke this closer to the top with other object initialization
348+ // was causing it to break in strange ways, so I'm invoking it last, which seems to
349+ // work reliably.
108350}
0 commit comments