diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index d35099b0642e..71d36733464f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -38,7 +38,7 @@ import org.apache.spark.metrics.source.CodegenMetrics
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, SQLOrderingUtil}
 import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
 import org.apache.spark.sql.internal.SQLConf
@@ -624,8 +624,12 @@ class CodegenContext extends Logging {
   def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
     // java boolean doesn't support > or < operator
     case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))"
-    case DoubleType => s"java.lang.Double.compare($c1, $c2)"
-    case FloatType => s"java.lang.Float.compare($c1, $c2)"
+    case DoubleType =>
+      val clsName = SQLOrderingUtil.getClass.getName.stripSuffix("$")
+      s"$clsName.compareDoubles($c1, $c2)"
+    case FloatType =>
+      val clsName = SQLOrderingUtil.getClass.getName.stripSuffix("$")
+      s"$clsName.compareFloats($c1, $c2)"
     // use c1 - c2 may overflow
     case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
     case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/SQLOrderingUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/SQLOrderingUtil.scala
new file mode 100644
index 000000000000..3b7f748c2817
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/SQLOrderingUtil.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+object SQLOrderingUtil {
+
+  /**
+   * A special version of double comparison that follows SQL semantic:
+   *  1. NaN == NaN
+   *  2. NaN is greater than any non-NaN double
+   *  3. -0.0 == 0.0
+   */
+  def compareDoubles(x: Double, y: Double): Int = {
+    if (x == y) 0 else java.lang.Double.compare(x, y)
+  }
+
+  /**
+   * A special version of float comparison that follows SQL semantic:
+   *  1. NaN == NaN
+   *  2. NaN is greater than any non-NaN float
+   *  3. -0.0 == 0.0
+   */
+  def compareFloats(x: Float, y: Float): Int = {
+    if (x == y) 0 else java.lang.Float.compare(x, y)
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
index 01268a9ff166..ea4f39d4b19d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
@@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.typeTag
 import scala.util.Try
 
 import org.apache.spark.annotation.Stable
+import org.apache.spark.sql.catalyst.util.SQLOrderingUtil
 
 /**
  * The data type representing `Double` values. Please use the singleton `DataTypes.DoubleType`.
@@ -38,7 +39,7 @@ class DoubleType private() extends FractionalType {
   private[sql] val numeric = implicitly[Numeric[Double]]
   private[sql] val fractional = implicitly[Fractional[Double]]
   private[sql] val ordering =
-    (x: Double, y: Double) => java.lang.Double.compare(x, y)
+    (x: Double, y: Double) => SQLOrderingUtil.compareDoubles(x, y)
   private[sql] val asIntegral = DoubleType.DoubleAsIfIntegral
 
   override private[sql] def exactNumeric = DoubleExactNumeric
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala
index 1491f5904bae..f00046facf69 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala
@@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.typeTag
 import scala.util.Try
 
 import org.apache.spark.annotation.Stable
+import org.apache.spark.sql.catalyst.util.SQLOrderingUtil
 
 /**
  * The data type representing `Float` values. Please use the singleton `DataTypes.FloatType`.
@@ -38,7 +39,7 @@ class FloatType private() extends FractionalType {
   private[sql] val numeric = implicitly[Numeric[Float]]
   private[sql] val fractional = implicitly[Fractional[Float]]
   private[sql] val ordering =
-    (x: Float, y: Float) => java.lang.Float.compare(x, y)
+    (x: Float, y: Float) => SQLOrderingUtil.compareFloats(x, y)
   private[sql] val asIntegral = FloatType.FloatAsIfIntegral
 
   override private[sql] def exactNumeric = FloatExactNumeric
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
index 3956629cf6a5..7026ff7de2e0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/numerics.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.types
 import scala.math.Numeric._
 import scala.math.Ordering
 
+import org.apache.spark.sql.catalyst.util.SQLOrderingUtil
 import org.apache.spark.sql.types.Decimal.DecimalIsConflicted
 
 private[sql] object ByteExactNumeric extends ByteIsIntegral with Ordering.ByteOrdering {
@@ -148,7 +149,7 @@ private[sql] object FloatExactNumeric extends FloatIsFractional {
     }
   }
 
-  override def compare(x: Float, y: Float): Int = java.lang.Float.compare(x, y)
+  override def compare(x: Float, y: Float): Int = SQLOrderingUtil.compareFloats(x, y)
 }
 
 private[sql] object DoubleExactNumeric extends DoubleIsFractional {
@@ -176,7 +177,7 @@ private[sql] object DoubleExactNumeric extends DoubleIsFractional {
     }
   }
 
-  override def compare(x: Double, y: Double): Int = java.lang.Double.compare(x, y)
+  override def compare(x: Double, y: Double): Int = SQLOrderingUtil.compareDoubles(x, y)
 }
 
 private[sql] object DecimalExactNumeric extends DecimalIsConflicted {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 1ad0a8ed758f..a36baec1a0b9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -538,4 +538,20 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     val inSet = InSet(BoundReference(0, IntegerType, true), Set.empty)
     checkEvaluation(inSet, false, row)
   }
+
+  test("SPARK-32764: compare special double/float values") {
+    checkEvaluation(EqualTo(Literal(Double.NaN), Literal(Double.NaN)), true)
+    checkEvaluation(EqualTo(Literal(Double.NaN), Literal(Double.PositiveInfinity)), false)
+    checkEvaluation(EqualTo(Literal(0.0D), Literal(-0.0D)), true)
+    checkEvaluation(GreaterThan(Literal(Double.NaN), Literal(Double.PositiveInfinity)), true)
+    checkEvaluation(GreaterThan(Literal(Double.NaN), Literal(Double.NaN)), false)
+    checkEvaluation(GreaterThan(Literal(0.0D), Literal(-0.0D)), false)
+
+    checkEvaluation(EqualTo(Literal(Float.NaN), Literal(Float.NaN)), true)
+    checkEvaluation(EqualTo(Literal(Float.NaN), Literal(Float.PositiveInfinity)), false)
+    checkEvaluation(EqualTo(Literal(0.0F), Literal(-0.0F)), true)
+    checkEvaluation(GreaterThan(Literal(Float.NaN), Literal(Float.PositiveInfinity)), true)
+    checkEvaluation(GreaterThan(Literal(Float.NaN), Literal(Float.NaN)), false)
+    checkEvaluation(GreaterThan(Literal(0.0F), Literal(-0.0F)), false)
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/SQLOrderingUtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/SQLOrderingUtilSuite.scala
new file mode 100644
index 000000000000..6fe774e8afcb
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/SQLOrderingUtilSuite.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import java.lang.{Double => JDouble, Float => JFloat}
+
+import org.apache.spark.SparkFunSuite
+
+class SQLOrderingUtilSuite extends SparkFunSuite {
+
+  test("compareDoublesSQL") {
+    def shouldMatchDefaultOrder(a: Double, b: Double): Unit = {
+      assert(SQLOrderingUtil.compareDoubles(a, b) === JDouble.compare(a, b))
+      assert(SQLOrderingUtil.compareDoubles(b, a) === JDouble.compare(b, a))
+    }
+    shouldMatchDefaultOrder(0d, 0d)
+    shouldMatchDefaultOrder(0d, 1d)
+    shouldMatchDefaultOrder(-1d, 1d)
+    shouldMatchDefaultOrder(Double.MinValue, Double.MaxValue)
+
+    val specialNaN = JDouble.longBitsToDouble(0x7ff1234512345678L)
+    assert(JDouble.isNaN(specialNaN))
+    assert(JDouble.doubleToRawLongBits(Double.NaN) != JDouble.doubleToRawLongBits(specialNaN))
+
+    assert(SQLOrderingUtil.compareDoubles(Double.NaN, Double.NaN) === 0)
+    assert(SQLOrderingUtil.compareDoubles(Double.NaN, specialNaN) === 0)
+    assert(SQLOrderingUtil.compareDoubles(Double.NaN, Double.PositiveInfinity) > 0)
+    assert(SQLOrderingUtil.compareDoubles(specialNaN, Double.PositiveInfinity) > 0)
+    assert(SQLOrderingUtil.compareDoubles(Double.NaN, Double.NegativeInfinity) > 0)
+    assert(SQLOrderingUtil.compareDoubles(Double.PositiveInfinity, Double.NaN) < 0)
+    assert(SQLOrderingUtil.compareDoubles(Double.NegativeInfinity, Double.NaN) < 0)
+    assert(SQLOrderingUtil.compareDoubles(0.0d, -0.0d) === 0)
+    assert(SQLOrderingUtil.compareDoubles(-0.0d, 0.0d) === 0)
+  }
+
+  test("compareFloatsSQL") {
+    def shouldMatchDefaultOrder(a: Float, b: Float): Unit = {
+      assert(SQLOrderingUtil.compareFloats(a, b) === JFloat.compare(a, b))
+      assert(SQLOrderingUtil.compareFloats(b, a) === JFloat.compare(b, a))
+    }
+    shouldMatchDefaultOrder(0f, 0f)
+    shouldMatchDefaultOrder(0f, 1f)
+    shouldMatchDefaultOrder(-1f, 1f)
+    shouldMatchDefaultOrder(Float.MinValue, Float.MaxValue)
+
+    val specialNaN = JFloat.intBitsToFloat(-6966608)
+    assert(JFloat.isNaN(specialNaN))
+    assert(JFloat.floatToRawIntBits(Float.NaN) != JFloat.floatToRawIntBits(specialNaN))
+
+    assert(SQLOrderingUtil.compareFloats(Float.NaN, Float.NaN) === 0)
+    assert(SQLOrderingUtil.compareFloats(Float.NaN, specialNaN) === 0)
+    assert(SQLOrderingUtil.compareFloats(Float.NaN, Float.PositiveInfinity) > 0)
+    assert(SQLOrderingUtil.compareFloats(specialNaN, Float.PositiveInfinity) > 0)
+    assert(SQLOrderingUtil.compareFloats(Float.NaN, Float.NegativeInfinity) > 0)
+    assert(SQLOrderingUtil.compareFloats(Float.PositiveInfinity, Float.NaN) < 0)
+    assert(SQLOrderingUtil.compareFloats(Float.NegativeInfinity, Float.NaN) < 0)
+    assert(SQLOrderingUtil.compareFloats(0.0f, -0.0f) === 0)
+    assert(SQLOrderingUtil.compareFloats(-0.0f, 0.0f) === 0)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 155035c3fa05..d95f09a4cc83 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -2550,6 +2550,11 @@ class DataFrameSuite extends QueryTest
   test("SPARK-32761: aggregating multiple distinct CONSTANT columns") {
     checkAnswer(sql("select count(distinct 2), count(distinct 2,3)"), Row(1, 1))
   }
+
+  test("SPARK-32764: -0.0 and 0.0 should be equal") {
+    val df = Seq(0.0 -> -0.0).toDF("pos", "neg")
+    checkAnswer(df.select($"pos" > $"neg"), Row(false))
+  }
 }
 
 case class GroupByKey(a: Int, b: Int)