diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 216d99d01f2f..4eedfc4dc0ef 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -236,6 +236,18 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
   private def getDistinctSplits(splits: Array[Double]): Array[Double] = {
     splits(0) = Double.NegativeInfinity
     splits(splits.length - 1) = Double.PositiveInfinity
+
+    // 0.0 and -0.0 are distinct values, so Array.distinct will preserve both of them,
+    // but 0.0 > -0.0 is false, which would break the parameter validation check.
+    // Additionally, in Scala <= 2.12 there is a bug that causes Array.distinct to
+    // produce non-deterministic results when the array contains both 0.0 and -0.0.
+    // Therefore we first normalize every -0.0 to 0.0 before deduplicating.
+    // See https://github.com/scala/bug/issues/11995
+    for (i <- 0 until splits.length) {
+      if (splits(i) == -0.0) {
+        splits(i) = 0.0
+      }
+    }
     val distinctSplits = splits.distinct
     if (splits.length != distinctSplits.length) {
       log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" +
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
index 6f6ab26cbac4..682b87a0f68d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala
@@ -512,4 +512,22 @@ class QuantileDiscretizerSuite extends MLTest with DefaultReadWriteTest {
     assert(observedNumBuckets === numBuckets,
       "Observed number of buckets does not equal expected number of buckets.")
   }
+
+  test("[SPARK-31676] QuantileDiscretizer raise error parameter splits given invalid value") {
+    import scala.util.Random
+    val rng = new Random(3)
+
+    val a1 = Array.tabulate(200)(_ => rng.nextDouble * 2.0 - 1.0) ++
+      Array.fill(20)(0.0) ++ Array.fill(20)(-0.0)
+
+    val df1 = sc.parallelize(a1, 2).toDF("id")
+
+    val qd = new QuantileDiscretizer()
+      .setInputCol("id")
+      .setOutputCol("out")
+      .setNumBuckets(200)
+      .setRelativeError(0.0)
+
+    qd.fit(df1) // assert no exception raised here.
+  }
 }