diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index 5f94af5ffe636..43738204c6704 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -56,10 +56,6 @@ import org.apache.spark.sql.types._ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan match { - // A subquery will be rewritten into join later, and will go through this rule - // eventually. Here we skip subquery, as we only need to run this rule once. - case _: Subquery => plan - case _ => plan transform { case w: Window if w.partitionSpec.exists(p => needNormalize(p)) => // Although the `windowExpressions` may refer to `partitionSpec` expressions, we don't need diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f7a904169d6c3..a219b91627b2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3503,6 +3503,24 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark checkAnswer(sql("select CAST(-32768 as short) DIV CAST (-1 as short)"), Seq(Row(Short.MinValue.toLong * -1))) } + + test("normalize special floating numbers in subquery") { + withTempView("v1", "v2", "v3") { + Seq(-0.0).toDF("d").createTempView("v1") + Seq(0.0).toDF("d").createTempView("v2") + spark.range(2).createTempView("v3") + + // non-correlated subquery + checkAnswer(sql("SELECT (SELECT v1.d FROM v1 JOIN v2 ON v1.d = v2.d)"), Row(-0.0)) + // correlated subquery + checkAnswer( + sql( + """ + |SELECT id FROM v3 WHERE EXISTS + | (SELECT v1.d FROM v1 JOIN v2 ON v1.d = v2.d WHERE id > 0) + |""".stripMargin), Row(1)) + } + } } case class Foo(bar: Option[String]) diff --git a/sql/hive/src/test/resources/golden/timestamp cast #2-0-732ed232ac592c5e7f7c913a88874fd2 b/sql/hive/src/test/resources/golden/timestamp cast #3-0-732ed232ac592c5e7f7c913a88874fd2 similarity index 100% rename from sql/hive/src/test/resources/golden/timestamp cast #2-0-732ed232ac592c5e7f7c913a88874fd2 rename to sql/hive/src/test/resources/golden/timestamp cast #3-0-732ed232ac592c5e7f7c913a88874fd2 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #6-0-6d2da5cfada03605834e38bc4075bc79 b/sql/hive/src/test/resources/golden/timestamp cast #4-0-6d2da5cfada03605834e38bc4075bc79 similarity index 100% rename from sql/hive/src/test/resources/golden/timestamp cast #6-0-6d2da5cfada03605834e38bc4075bc79 rename to sql/hive/src/test/resources/golden/timestamp cast #4-0-6d2da5cfada03605834e38bc4075bc79 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #4-0-732ed232ac592c5e7f7c913a88874fd2 b/sql/hive/src/test/resources/golden/timestamp cast #4-0-732ed232ac592c5e7f7c913a88874fd2 deleted file mode 100644 index 5625e59da8873..0000000000000 --- a/sql/hive/src/test/resources/golden/timestamp cast #4-0-732ed232ac592c5e7f7c913a88874fd2 +++ /dev/null @@ -1 +0,0 @@ -1.2 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #8-0-6d2da5cfada03605834e38bc4075bc79 b/sql/hive/src/test/resources/golden/timestamp cast #8-0-6d2da5cfada03605834e38bc4075bc79 deleted file mode 100644 index 1d94c8a014fb4..0000000000000 --- a/sql/hive/src/test/resources/golden/timestamp cast #8-0-6d2da5cfada03605834e38bc4075bc79 +++ /dev/null @@ -1 +0,0 @@ --1.2 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 63b985fbe4d32..b10a8cb8bf2bf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -556,33 +556,27 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd assert(1 == res.getDouble(0)) } - createQueryTest("timestamp cast #2", - "SELECT CAST(CAST(1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") - - test("timestamp cast #3") { - val res = sql("SELECT CAST(CAST(1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1").collect().head - assert(1200 == res.getInt(0)) + test("timestamp cast #2") { + val res = sql("SELECT CAST(CAST(-1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1").collect().head + assert(-1 == res.get(0)) } - createQueryTest("timestamp cast #4", + createQueryTest("timestamp cast #3", "SELECT CAST(CAST(1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + createQueryTest("timestamp cast #4", + "SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + test("timestamp cast #5") { - val res = sql("SELECT CAST(CAST(-1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1").collect().head - assert(-1 == res.get(0)) + val res = sql("SELECT CAST(CAST(1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1").collect().head + assert(1200 == res.getInt(0)) } - createQueryTest("timestamp cast #6", - "SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") - - test("timestamp cast #7") { + test("timestamp cast #6") { val res = sql("SELECT CAST(CAST(-1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1").collect().head assert(-1200 == res.getInt(0)) } - createQueryTest("timestamp cast #8", - "SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") - createQueryTest("select null from table", "SELECT null FROM src LIMIT 1")