Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,18 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
private def getDistinctSplits(splits: Array[Double]): Array[Double] = {
splits(0) = Double.NegativeInfinity
splits(splits.length - 1) = Double.PositiveInfinity

// 0.0 and -0.0 are distinct values, so array.distinct will preserve both of them,
// but 0.0 > -0.0 is false, which will break the parameter validation check.
// In addition, in Scala <= 2.12 there is a bug which causes array.distinct to generate
// non-deterministic results when the array contains both 0.0 and -0.0.
// Therefore we first normalize all 0.0 and -0.0 values to 0.0 here.
// See https://github.com/scala/bug/issues/11995
for (i <- 0 until splits.length) {
if (splits(i) == -0.0) {
splits(i) = 0.0
}
}
val distinctSplits = splits.distinct
if (splits.length != distinctSplits.length) {
log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -512,4 +512,22 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest {
assert(observedNumBuckets === numBuckets,
"Observed number of buckets does not equal expected number of buckets.")
}

// Regression test for SPARK-31676: fitting on a column that contains both 0.0 and -0.0
// must complete without throwing.
test("[SPARK-31676] QuantileDiscretizer raise error parameter splits given invalid value") {
  import scala.util.Random

  // A fixed seed keeps the generated sample reproducible across runs.
  val random = new Random(3)

  // 200 pseudo-random values in [-1.0, 1.0), followed by runs of positive and negative zero.
  val uniformValues = Array.fill(200)(random.nextDouble() * 2.0 - 1.0)
  val positiveZeros = Array.fill(20)(0.0)
  val negativeZeros = Array.fill(20)(-0.0)
  val samples = uniformValues ++ positiveZeros ++ negativeZeros

  val inputDf = sc.parallelize(samples, 2).toDF("id")

  val discretizer = new QuantileDiscretizer()
    .setInputCol("id")
    .setOutputCol("out")
    .setNumBuckets(200)
    .setRelativeError(0.0)

  // The only expectation is that fitting succeeds; any exception fails the test.
  discretizer.fit(inputDf)
}
}