From d24e7f039e02f045f0cf8eac508aaa164c45eb80 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 13 Jul 2020 12:46:02 -0700 Subject: [PATCH] Do not duplicate normalization on the children of float/double If/CaseWhen/Coalesce. --- .../optimizer/NormalizeFloatingNumbers.scala | 6 ++--- .../NormalizeFloatingPointNumbersSuite.scala | 26 +++++++++---------- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index 98c78c6312222..10f846cf910f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -116,6 +116,9 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { case CreateMap(children, useStringTypeWhenEmpty) => CreateMap(children.map(normalize), useStringTypeWhenEmpty) + case _ if expr.dataType == FloatType || expr.dataType == DoubleType => + KnownFloatingPointNormalized(NormalizeNaNAndZero(expr)) + case If(cond, trueValue, falseValue) => If(cond, normalize(trueValue), normalize(falseValue)) @@ -125,9 +128,6 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { case Coalesce(children) => Coalesce(children.map(normalize)) - case _ if expr.dataType == FloatType || expr.dataType == DoubleType => - KnownFloatingPointNormalized(NormalizeNaNAndZero(expr)) - case _ if expr.dataType.isInstanceOf[StructType] => val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i => normalize(GetStructField(expr, i)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala index 3f6bdd206535b..bb9919f94eef2 
100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala @@ -85,25 +85,23 @@ class NormalizeFloatingPointNumbersSuite extends PlanTest { val optimized = Optimize.execute(query) val doubleOptimized = Optimize.execute(optimized) val joinCond = IsNull(a) === IsNull(b) && - coalesce(KnownFloatingPointNormalized(NormalizeNaNAndZero(a)), - KnownFloatingPointNormalized(NormalizeNaNAndZero(0.0))) === - coalesce(KnownFloatingPointNormalized(NormalizeNaNAndZero(b)), - KnownFloatingPointNormalized(NormalizeNaNAndZero(0.0))) + KnownFloatingPointNormalized(NormalizeNaNAndZero(coalesce(a, 0.0))) === + KnownFloatingPointNormalized(NormalizeNaNAndZero(coalesce(b, 0.0))) val correctAnswer = testRelation1.join(testRelation2, condition = Some(joinCond)) comparePlans(doubleOptimized, correctAnswer) } test("SPARK-32258: normalize the children of If") { - val cond = If(a > 0.1D, a, a + 0.2D) === b + val cond = If(a > 0.1D, namedStruct("a", a), namedStruct("a", a + 0.2D)) === namedStruct("a", b) val query = testRelation1.join(testRelation2, condition = Some(cond)) val optimized = Optimize.execute(query) val doubleOptimized = Optimize.execute(optimized) val joinCond = If(a > 0.1D, - KnownFloatingPointNormalized(NormalizeNaNAndZero(a)), - KnownFloatingPointNormalized(NormalizeNaNAndZero(a + 0.2D))) === - KnownFloatingPointNormalized(NormalizeNaNAndZero(b)) + namedStruct("a", KnownFloatingPointNormalized(NormalizeNaNAndZero(a))), + namedStruct("a", KnownFloatingPointNormalized(NormalizeNaNAndZero(a + 0.2D)))) === + namedStruct("a", KnownFloatingPointNormalized(NormalizeNaNAndZero(b))) val correctAnswer = testRelation1.join(testRelation2, condition = Some(joinCond)) comparePlans(doubleOptimized, correctAnswer) @@ -111,17 +109,17 @@ class NormalizeFloatingPointNumbersSuite extends PlanTest { test("SPARK-32258: 
normalize the children of CaseWhen") { val cond = CaseWhen( - Seq((a > 0.1D, a), (a > 0.2D, a + 0.2D)), - Some(a + 0.3D)) === b + Seq((a > 0.1D, namedStruct("a", a)), (a > 0.2D, namedStruct("a", a + 0.2D))), + Some(namedStruct("a", a + 0.3D))) === namedStruct("a", b) val query = testRelation1.join(testRelation2, condition = Some(cond)) val optimized = Optimize.execute(query) val doubleOptimized = Optimize.execute(optimized) val joinCond = CaseWhen( - Seq((a > 0.1D, KnownFloatingPointNormalized(NormalizeNaNAndZero(a))), - (a > 0.2D, KnownFloatingPointNormalized(NormalizeNaNAndZero(a + 0.2D)))), - Some(KnownFloatingPointNormalized(NormalizeNaNAndZero(a + 0.3D)))) === - KnownFloatingPointNormalized(NormalizeNaNAndZero(b)) + Seq((a > 0.1D, namedStruct("a", KnownFloatingPointNormalized(NormalizeNaNAndZero(a)))), + (a > 0.2D, namedStruct("a", KnownFloatingPointNormalized(NormalizeNaNAndZero(a + 0.2D))))), + Some(namedStruct("a", KnownFloatingPointNormalized(NormalizeNaNAndZero(a + 0.3D))))) === + namedStruct("a", KnownFloatingPointNormalized(NormalizeNaNAndZero(b))) val correctAnswer = testRelation1.join(testRelation2, condition = Some(joinCond)) comparePlans(doubleOptimized, correctAnswer)