
Commit b2300fc

WeichenXu123 authored and srowen committed
[SPARK-31676][ML] QuantileDiscretizer raise error parameter splits given invalid value (splits array includes -0.0 and 0.0)
### What changes were proposed in this pull request?

In `QuantileDiscretizer.getDistinctSplits`, before invoking `distinct`, normalize all -0.0 values to 0.0:

```scala
for (i <- 0 until splits.length) {
  if (splits(i) == -0.0) {
    splits(i) = 0.0
  }
}
```

### Why are the changes needed?

Fixes a bug.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit test.

#### Manual test:

```scala
import scala.util.Random
val rng = new Random(3)
val a1 = Array.tabulate(200)(_ => rng.nextDouble * 2.0 - 1.0) ++
  Array.fill(20)(0.0) ++ Array.fill(20)(-0.0)

import spark.implicits._
val df1 = sc.parallelize(a1, 2).toDF("id")

import org.apache.spark.ml.feature.QuantileDiscretizer
val qd = new QuantileDiscretizer()
  .setInputCol("id")
  .setOutputCol("out")
  .setNumBuckets(200)
  .setRelativeError(0.0)
val model = qd.fit(df1) // raises an error on current Spark master.
```

### Explanation

In Scala, `0.0 == -0.0` is true, but `0.0.hashCode == -0.0.hashCode()` is false. This breaks the contract between `equals()` and `hashCode()`: if two objects are equal, they must have the same hash code. `array.distinct` relies on the elements' `hashCode`, which leads to this error.

Test code for `distinct`:

```scala
import scala.util.Random
val rng = new Random(3)
val a1 = Array.tabulate(200)(_ => rng.nextDouble * 2.0 - 1.0) ++
  Array.fill(20)(0.0) ++ Array.fill(20)(-0.0)
a1.distinct.sorted.foreach(x => print(x.toString + "\n"))
```

The output then contains both zeros:

```
...
-0.009292684662246975
-0.0033280686465135823
-0.0
0.0
0.0022219556032221366
0.02217419561977274
...
```

Closes #28498 from WeichenXu123/SPARK-31676.

Authored-by: Weichen Xu <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
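The equals/hashCode mismatch above can be reproduced in a few lines. This is a minimal sketch (the object name `ZeroHashDemo` is hypothetical); it uses `java.lang.Double.hashCode`, which is what boxed Scala doubles delegate to:

```scala
// Sketch: 0.0 and -0.0 compare equal under IEEE 754, but their boxed
// hash codes differ because the sign bit survives doubleToLongBits.
object ZeroHashDemo {
  def main(args: Array[String]): Unit = {
    println(0.0 == -0.0)                      // true
    println(java.lang.Double.hashCode(0.0))   // 0
    println(java.lang.Double.hashCode(-0.0))  // Int.MinValue (only the sign bit is set)
    // An equal pair with unequal hash codes breaks hash-based deduplication,
    // which is what array.distinct relies on in Scala <= 2.12.
  }
}
```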
1 parent ddbce4e commit b2300fc

File tree

2 files changed: +30 −0 lines changed


mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala

Lines changed: 12 additions & 0 deletions

```diff
@@ -236,6 +236,18 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
   private def getDistinctSplits(splits: Array[Double]): Array[Double] = {
     splits(0) = Double.NegativeInfinity
     splits(splits.length - 1) = Double.PositiveInfinity
+
+    // 0.0 and -0.0 are distinct values, and array.distinct will preserve both of them,
+    // but 0.0 > -0.0 is false, which will break the parameter validation checking.
+    // Also, in Scala <= 2.12, there is a bug which causes array.distinct to generate
+    // non-deterministic results when the array contains both 0.0 and -0.0.
+    // So here we first normalize all -0.0 values to be 0.0.
+    // See https://github.com/scala/bug/issues/11995
+    for (i <- 0 until splits.length) {
+      if (splits(i) == -0.0) {
+        splits(i) = 0.0
+      }
+    }
     val distinctSplits = splits.distinct
     if (splits.length != distinctSplits.length) {
       log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" +
```
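The normalization loop added above can be sketched standalone. The object and helper names below are hypothetical, chosen for illustration; the loop body mirrors the committed change:

```scala
// Sketch of the fix: rewrite every -0.0 to 0.0 before calling distinct,
// so deduplication no longer depends on the sign bit of zero.
object NormalizeSplitsDemo {
  // Hypothetical helper mirroring the loop added in this commit.
  def normalize(splits: Array[Double]): Array[Double] = {
    for (i <- splits.indices) {
      // Note: 0.0 == -0.0 is true, so this branch also matches 0.0,
      // where the assignment is simply a no-op.
      if (splits(i) == -0.0) splits(i) = 0.0
    }
    splits
  }

  def main(args: Array[String]): Unit = {
    val raw = Array(-1.0, -0.0, 0.0, 1.0)
    // Both zeros collapse into a single 0.0 after normalization.
    println(normalize(raw).distinct.mkString(", "))  // -1.0, 0.0, 1.0
  }
}
```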

mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala

Lines changed: 18 additions & 0 deletions

```diff
@@ -512,4 +512,22 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest {
     assert(observedNumBuckets === numBuckets,
       "Observed number of buckets does not equal expected number of buckets.")
   }
+
+  test("[SPARK-31676] QuantileDiscretizer raise error parameter splits given invalid value") {
+    import scala.util.Random
+    val rng = new Random(3)
+
+    val a1 = Array.tabulate(200)(_ => rng.nextDouble * 2.0 - 1.0) ++
+      Array.fill(20)(0.0) ++ Array.fill(20)(-0.0)
+
+    val df1 = sc.parallelize(a1, 2).toDF("id")
+
+    val qd = new QuantileDiscretizer()
+      .setInputCol("id")
+      .setOutputCol("out")
+      .setNumBuckets(200)
+      .setRelativeError(0.0)
+
+    qd.fit(df1) // assert no exception raised here.
+  }
 }
```
