diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 80c25d0b0fb7a..fffcc7c9ef53a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -105,12 +105,22 @@ case class AggregateExpression( } // We compute the same thing regardless of our final result. - override lazy val canonicalized: Expression = + override lazy val canonicalized: Expression = { + val normalizedAggFunc = mode match { + // For PartialMerge or Final mode, the input to the `aggregateFunction` is aggregate buffers, + // and the actual children of `aggregateFunction` is not used, here we normalize the expr id. + case PartialMerge | Final => aggregateFunction.transform { + case a: AttributeReference => a.withExprId(ExprId(0)) + } + case Partial | Complete => aggregateFunction + } + AggregateExpression( - aggregateFunction.canonicalized.asInstanceOf[AggregateFunction], + normalizedAggFunc.canonicalized.asInstanceOf[AggregateFunction], mode, isDistinct, ExprId(0)) + } override def children: Seq[Expression] = aggregateFunction :: Nil override def dataType: DataType = aggregateFunction.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index ff6c93ae9815c..4ed6728994193 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.json import java.io.ByteArrayOutputStream -import java.util.Locale import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -126,16 +125,11 @@ class JacksonParser( case VALUE_STRING => // Special case handling for NaN and Infinity. - val value = parser.getText - val lowerCaseValue = value.toLowerCase(Locale.ROOT) - if (lowerCaseValue.equals("nan") || - lowerCaseValue.equals("infinity") || - lowerCaseValue.equals("-infinity") || - lowerCaseValue.equals("inf") || - lowerCaseValue.equals("-inf")) { - value.toFloat - } else { - throw new RuntimeException(s"Cannot parse $value as FloatType.") + parser.getText match { + case "NaN" => Float.NaN + case "Infinity" => Float.PositiveInfinity + case "-Infinity" => Float.NegativeInfinity + case other => throw new RuntimeException(s"Cannot parse $other as FloatType.") } } @@ -146,16 +140,11 @@ class JacksonParser( case VALUE_STRING => // Special case handling for NaN and Infinity. - val value = parser.getText - val lowerCaseValue = value.toLowerCase(Locale.ROOT) - if (lowerCaseValue.equals("nan") || - lowerCaseValue.equals("infinity") || - lowerCaseValue.equals("-infinity") || - lowerCaseValue.equals("inf") || - lowerCaseValue.equals("-inf")) { - value.toDouble - } else { - throw new RuntimeException(s"Cannot parse $value as DoubleType.") + parser.getText match { + case "NaN" => Double.NaN + case "Infinity" => Double.PositiveInfinity + case "-Infinity" => Double.NegativeInfinity + case other => throw new RuntimeException(s"Cannot parse $other as DoubleType.") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 51faa333307b3..959fcf7c7548e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -286,7 +286,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT def recursiveTransform(arg: Any): AnyRef = arg match { case e: Expression => transformExpression(e) - case Some(e: Expression) => Some(transformExpression(e)) + case Some(value) => Some(recursiveTransform(value)) case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map(recursiveTransform) @@ -320,7 +320,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT productIterator.flatMap { case e: Expression => e :: Nil - case Some(e: Expression) => e :: Nil + case s: Some[_] => seqToExpressions(s.toSeq) case seq: Traversable[_] => seqToExpressions(seq) case other => Nil }.toSeq diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala index 25e4ca060ae02..aaf51b5b90111 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SameResultSuite.scala @@ -18,12 +18,14 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext /** * Tests for the sameResult function for [[SparkPlan]]s. */ class SameResultSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("FileSourceScanExec: different orders of data filters and partition filters") { withTempPath { path => @@ -46,4 +48,14 @@ class SameResultSuite extends QueryTest with SharedSQLContext { df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get .asInstanceOf[FileSourceScanExec] } + + test("SPARK-20725: partial aggregate should behave correctly for sameResult") { + val df1 = spark.range(10).agg(sum($"id")) + val df2 = spark.range(10).agg(sum($"id")) + assert(df1.queryExecution.executedPlan.sameResult(df2.queryExecution.executedPlan)) + + val df3 = spark.range(10).agg(sumDistinct($"id")) + val df4 = spark.range(10).agg(sumDistinct($"id")) + assert(df3.queryExecution.executedPlan.sameResult(df4.queryExecution.executedPlan)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 5e7f7944bd845..e66a60d7503f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} +import java.util.Locale import com.fasterxml.jackson.core.JsonFactory import org.apache.hadoop.fs.{Path, PathFilter} @@ -1988,4 +1989,43 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { assert(errMsg.startsWith("The field for corrupt records must be string type and nullable")) } } + + test("SPARK-18772: Parse special floats correctly") { + val jsons = Seq( + """{"a": "NaN"}""", + """{"a": "Infinity"}""", + """{"a": "-Infinity"}""") + + // positive cases + val checks: Seq[Double => Boolean] = Seq( + _.isNaN, + _.isPosInfinity, + _.isNegInfinity) + + Seq(FloatType, DoubleType).foreach { dt => + jsons.zip(checks).foreach { case (json, check) => + val ds = spark.read + .schema(StructType(Seq(StructField("a", dt)))) + .json(Seq(json).toDS()) + .select($"a".cast(DoubleType)).as[Double] + assert(check(ds.first())) + } + } + + // negative cases + Seq(FloatType, DoubleType).foreach { dt => + val lowerCasedJsons = jsons.map(_.toLowerCase(Locale.ROOT)) + // The special floats are case-sensitive so these cases below throw exceptions. + lowerCasedJsons.foreach { lowerCasedJson => + val e = intercept[SparkException] { + spark.read + .option("mode", "FAILFAST") + .schema(StructType(Seq(StructField("a", dt)))) + .json(Seq(lowerCasedJson).toDS()) + .collect() + } + assert(e.getMessage.contains("Cannot parse")) + } + } + } }