From 41a318eb5808e3f3838c12301330ea9c1b3a351f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 30 Jun 2020 22:57:13 -0700 Subject: [PATCH 1/2] NormalizeFloatingNumbers should work on null struct. --- .../optimizer/NormalizeFloatingNumbers.scala | 5 +++-- .../apache/spark/sql/DataFrameAggregateSuite.scala | 12 ++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index 43738204c6704..230d990be1824 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window} @@ -123,7 +123,8 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i => normalize(GetStructField(expr, i)) } - CreateStruct(fields) + val struct = CreateStruct(fields) + If(IsNull(expr), Literal(null, struct.dataType), struct) case _ if expr.dataType.isInstanceOf[ArrayType] => val ArrayType(et, containsNull) = expr.dataType diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index f7438f3ffec04..09f30bb5e2c77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -1028,4 +1028,16 @@ class DataFrameAggregateSuite extends QueryTest checkAnswer(df, Row("abellina", 2) :: Row("mithunr", 1) :: Nil) } } + + test("SPARK-32136: NormalizeFloatingNumbers should work on null struct") { + val df = Seq( + A(None), + A(Some(B(None))), + A(Some(B(Some(1.0))))).toDF + val groupBy = df.groupBy("b").agg(count("*")) + checkAnswer(groupBy, Row(null, 1) :: Row(Row(null), 1) :: Row(Row(1.0), 1) :: Nil) + } } + +case class B(c: Option[Double]) +case class A(b: Option[B]) From d0586c94512ce1def1fdcdca37dc70663346334b Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 1 Jul 2020 09:49:17 -0700 Subject: [PATCH 2/2] Fix test. --- .../spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index 230d990be1824..8d5dbc7dc90eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -124,7 +124,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { normalize(GetStructField(expr, i)) } val struct = CreateStruct(fields) - If(IsNull(expr), Literal(null, struct.dataType), struct) + KnownFloatingPointNormalized(If(IsNull(expr), Literal(null, struct.dataType), struct)) case _ if expr.dataType.isInstanceOf[ArrayType] => val ArrayType(et, containsNull) = expr.dataType