Skip to content
Merged

sync #16

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,6 @@ import org.apache.spark.sql.types._
object NormalizeFloatingNumbers extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan match {
// A subquery will be rewritten into join later, and will go through this rule
// eventually. Here we skip subquery, as we only need to run this rule once.
case _: Subquery => plan

case _ => plan transform {
case w: Window if w.partitionSpec.exists(p => needNormalize(p)) =>
// Although the `windowExpressions` may refer to `partitionSpec` expressions, we don't need
Expand Down
18 changes: 18 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3503,6 +3503,24 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
checkAnswer(sql("select CAST(-32768 as short) DIV CAST (-1 as short)"),
Seq(Row(Short.MinValue.toLong * -1)))
}

// Regression test: special floating-point values (-0.0 vs 0.0) must be normalized
// inside subqueries as well, so that float join keys compare equal. This pairs with
// the optimizer change earlier in this PR, which removes the early
// `case _: Subquery => plan` bail-out from NormalizeFloatingNumbers.
test("normalize special floating numbers in subquery") {
withTempView("v1", "v2", "v3") {
// v1 holds -0.0 and v2 holds 0.0; the equi-join on `d` only matches
// if both sides are normalized before comparison.
Seq(-0.0).toDF("d").createTempView("v1")
Seq(0.0).toDF("d").createTempView("v2")
spark.range(2).createTempView("v3")

// non-correlated subquery: the scalar subquery's join must produce one row (-0.0)
checkAnswer(sql("SELECT (SELECT v1.d FROM v1 JOIN v2 ON v1.d = v2.d)"), Row(-0.0))
// correlated subquery: EXISTS references the outer `id`, so the subquery is
// rewritten into a join and must still go through the normalization rule
checkAnswer(
sql(
"""
|SELECT id FROM v3 WHERE EXISTS
| (SELECT v1.d FROM v1 JOIN v2 ON v1.d = v2.d WHERE id > 0)
|""".stripMargin), Row(1))
}
}
}

case class Foo(bar: Option[String])

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -556,33 +556,27 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd
assert(1 == res.getDouble(0))
}

createQueryTest("timestamp cast #2",
"SELECT CAST(CAST(1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1")

test("timestamp cast #3") {
val res = sql("SELECT CAST(CAST(1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1").collect().head
assert(1200 == res.getInt(0))
test("timestamp cast #2") {
val res = sql("SELECT CAST(CAST(-1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1").collect().head
assert(-1 == res.get(0))
}

createQueryTest("timestamp cast #4",
createQueryTest("timestamp cast #3",
"SELECT CAST(CAST(1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1")

createQueryTest("timestamp cast #4",
"SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1")

test("timestamp cast #5") {
val res = sql("SELECT CAST(CAST(-1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1").collect().head
assert(-1 == res.get(0))
val res = sql("SELECT CAST(CAST(1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1").collect().head
assert(1200 == res.getInt(0))
}

createQueryTest("timestamp cast #6",
"SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1")

test("timestamp cast #7") {
test("timestamp cast #6") {
val res = sql("SELECT CAST(CAST(-1200 AS TIMESTAMP) AS INT) FROM src LIMIT 1").collect().head
assert(-1200 == res.getInt(0))
}

createQueryTest("timestamp cast #8",
"SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1")

createQueryTest("select null from table",
"SELECT null FROM src LIMIT 1")

Expand Down