Skip to content

Commit 17bcbf3

Browse files
committed
Added seed.
1 parent badf20d commit 17bcbf3

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

core/src/main/scala/org/apache/spark/util/random/SamplingUtils.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.util.random
1919

2020
import scala.reflect.ClassTag
21+
import scala.util.Random
2122

2223
private[spark] object SamplingUtils {
2324

@@ -26,9 +27,14 @@ private[spark] object SamplingUtils {
2627
*
2728
* @param input iterator over the input elements to sample from
2829
* @param k reservoir size
30+
* @param seed random seed
2931
* @return (samples, input size)
3032
*/
31-
def reservoirSampleAndCount[T: ClassTag](input: Iterator[T], k: Int): (Array[T], Int) = {
33+
def reservoirSampleAndCount[T: ClassTag](
34+
input: Iterator[T],
35+
k: Int,
36+
seed: Long = Random.nextLong())
37+
: (Array[T], Int) = {
3238
val reservoir = new Array[T](k)
3339
// Put the first k elements in the reservoir.
3440
var i = 0
@@ -46,7 +52,7 @@ private[spark] object SamplingUtils {
4652
(trimReservoir, i)
4753
} else {
4854
// If input size > k, continue the sampling process.
49-
val rand = new XORShiftRandom
55+
val rand = new XORShiftRandom(seed)
5056
while (input.hasNext) {
5157
val item = input.next()
5258
val replacementIndex = rand.nextInt(i)

0 commit comments

Comments
 (0)