diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala index 2917b0b8c9c5..5bfae7b77e09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/constraintExpressions.scala @@ -21,15 +21,21 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral} import org.apache.spark.sql.types.DataType -case class KnownNotNull(child: Expression) extends UnaryExpression { - override def nullable: Boolean = false +trait TaggingExpression extends UnaryExpression { + override def nullable: Boolean = child.nullable override def dataType: DataType = child.dataType + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = child.genCode(ctx) + + override def eval(input: InternalRow): Any = child.eval(input) +} + +case class KnownNotNull(child: Expression) extends TaggingExpression { + override def nullable: Boolean = false + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.genCode(ctx).copy(isNull = FalseLiteral) } - - override def eval(input: InternalRow): Any = { - child.eval(input) - } } + +case class KnownFloatingPointNormalized(child: Expression) extends TaggingExpression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index a5921ebe7751..b036092cf1fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, LambdaFunction, NamedLambdaVariable, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window} @@ -61,7 +61,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { case _: Subquery => plan case _ => plan transform { - case w: Window if w.partitionSpec.exists(p => needNormalize(p.dataType)) => + case w: Window if w.partitionSpec.exists(p => needNormalize(p)) => // Although the `windowExpressions` may refer to `partitionSpec` expressions, we don't need // to normalize the `windowExpressions`, as they are executed per input row and should take // the input row as it is. 
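The new TaggingExpression trait above factors the pass-through eval/codegen behavior out of KnownNotNull so that KnownFloatingPointNormalized can reuse it as a pure marker: the node evaluates and compiles to exactly its child, and exists only so the optimizer can recognize subtrees it has already normalized. As background for the NormalizeFloatingNumbers changes that follow, here is a minimal, Spark-free sketch (plain Scala on the JDK; not part of this patch, and FloatingPointPitfalls is an illustrative name) of the IEEE-754 behavior that forces the rule to canonicalize floating-point keys before they reach hash-based joins, aggregates, or window partitioning:

    // Why floating-point keys need normalization: values that compare equal
    // can carry different bit patterns, and distinct NaN bit patterns never
    // compare equal, so hashing raw bits breaks SQL equality semantics.
    object FloatingPointPitfalls {
      def main(args: Array[String]): Unit = {
        // 0.0 and -0.0 compare equal, yet their bit patterns differ, so a
        // hash built from the raw bits sends them to different buckets.
        assert(0.0 == -0.0)
        assert(java.lang.Double.doubleToRawLongBits(0.0) !=
          java.lang.Double.doubleToRawLongBits(-0.0))

        // Every NaN is unequal to every NaN under ==, but SQL grouping/join
        // semantics treat NaN as a single value. doubleToLongBits (the
        // non-"Raw" variant) collapses all NaNs to one canonical pattern.
        val nan1 = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
        val nan2 = java.lang.Double.longBitsToDouble(0x7ff0000000000002L)
        assert(nan1 != nan2)
        assert(java.lang.Double.doubleToLongBits(nan1) ==
          java.lang.Double.doubleToLongBits(nan2))
      }
    }

NormalizeNaNAndZero applies this kind of canonicalization inside the engine, mapping -0.0 to 0.0 and every NaN to a single representative, so equal keys always hash identically.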
@@ -73,7 +73,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
       case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _)
           // The analyzer guarantees left and right join keys are of the same data type. Here we
           // only need to check join keys of one side.
-          if leftKeys.exists(k => needNormalize(k.dataType)) =>
+          if leftKeys.exists(k => needNormalize(k)) =>
         val newLeftJoinKeys = leftKeys.map(normalize)
         val newRightJoinKeys = rightKeys.map(normalize)
         val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map {
@@ -87,6 +87,14 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
     }
   }
 
+  /**
+   * Short circuit if the underlying expression is already normalized, keeping the rule idempotent.
+   */
+  private def needNormalize(expr: Expression): Boolean = expr match {
+    case KnownFloatingPointNormalized(_) => false
+    case _ => needNormalize(expr.dataType)
+  }
+
   private def needNormalize(dt: DataType): Boolean = dt match {
     case FloatType | DoubleType => true
     case StructType(fields) => fields.exists(f => needNormalize(f.dataType))
@@ -98,7 +106,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
   }
 
   private[sql] def normalize(expr: Expression): Expression = expr match {
-    case _ if !needNormalize(expr.dataType) => expr
+    case _ if !needNormalize(expr) => expr
 
     case a: Alias =>
       a.withNewChildren(Seq(normalize(a.child)))
@@ -116,7 +124,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
       CreateMap(children.map(normalize))
 
     case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
-      NormalizeNaNAndZero(expr)
+      KnownFloatingPointNormalized(NormalizeNaNAndZero(expr))
 
     case _ if expr.dataType.isInstanceOf[StructType] =>
       val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i =>
@@ -128,7 +136,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
       val ArrayType(et, containsNull) = expr.dataType
       val lv = NamedLambdaVariable("arg", et, containsNull)
       val function = normalize(lv)
-      ArrayTransform(expr, LambdaFunction(function, Seq(lv)))
+      KnownFloatingPointNormalized(ArrayTransform(expr, LambdaFunction(function, Seq(lv))))
 
     case _ => throw new IllegalStateException(s"fail to normalize $expr")
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
new file mode 100644
index 000000000000..5f616da2978b
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingPointNumbersSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.KnownFloatingPointNormalized +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +class NormalizeFloatingPointNumbersSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("NormalizeFloatingPointNumbers", Once, NormalizeFloatingNumbers) :: Nil + } + + val testRelation1 = LocalRelation('a.double) + val a = testRelation1.output(0) + val testRelation2 = LocalRelation('a.double) + val b = testRelation2.output(0) + + test("normalize floating points in window function expressions") { + val query = testRelation1.window(Seq(sum(a).as("sum")), Seq(a), Seq(a.asc)) + + val optimized = Optimize.execute(query) + val correctAnswer = testRelation1.window(Seq(sum(a).as("sum")), + Seq(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))), Seq(a.asc)) + + comparePlans(optimized, correctAnswer) + } + + test("normalize floating points in window function expressions - idempotence") { + val query = testRelation1.window(Seq(sum(a).as("sum")), Seq(a), Seq(a.asc)) + + val optimized = Optimize.execute(query) + val doubleOptimized = Optimize.execute(optimized) + val correctAnswer = testRelation1.window(Seq(sum(a).as("sum")), + Seq(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))), Seq(a.asc)) + + comparePlans(doubleOptimized, correctAnswer) + } + + test("normalize floating points in join keys") { + val query = testRelation1.join(testRelation2, condition = Some(a === b)) + + val optimized = Optimize.execute(query) + val joinCond = Some(KnownFloatingPointNormalized(NormalizeNaNAndZero(a)) + === KnownFloatingPointNormalized(NormalizeNaNAndZero(b))) + val correctAnswer = testRelation1.join(testRelation2, condition = joinCond) + + comparePlans(optimized, correctAnswer) + } + + test("normalize floating points in join keys - idempotence") { + val query = testRelation1.join(testRelation2, condition = Some(a === b)) + + val optimized = Optimize.execute(query) + val doubleOptimized = Optimize.execute(optimized) + val joinCond = Some(KnownFloatingPointNormalized(NormalizeNaNAndZero(a)) + === KnownFloatingPointNormalized(NormalizeNaNAndZero(b))) + val correctAnswer = testRelation1.join(testRelation2, condition = joinCond) + + comparePlans(doubleOptimized, correctAnswer) + } +} +
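The two "- idempotence" tests above pin down the property this patch adds: a second run of the rule leaves the plan untouched, because the new needNormalize(expr) overload short-circuits on the KnownFloatingPointNormalized tag before consulting the data type. Previously the guard saw only expr.dataType, and since NormalizeNaNAndZero(a) is still of DoubleType, every pass wrapped the join keys and partition expressions again, so the plan never reached a fixed point. The toy model below isolates that tag-and-short-circuit pattern; it is a sketch with made-up names (Expr, DoubleAttr, Normalize, Normalized, ToyRule), none of which are Spark's types:

    // Toy model of tag-and-short-circuit: Normalize stands in for
    // NormalizeNaNAndZero, Normalized for KnownFloatingPointNormalized.
    sealed trait Expr
    case class DoubleAttr(name: String) extends Expr
    case class Normalize(child: Expr) extends Expr
    case class Normalized(child: Expr) extends Expr

    object ToyRule {
      // Mirrors needNormalize(expr: Expression): the tag short-circuits
      // the check. (The real rule falls through to a data-type check.)
      private def needNormalize(e: Expr): Boolean = e match {
        case Normalized(_) => false
        case _ => true
      }

      // Mirrors normalize(expr: Expression): the result is tagged so the
      // guard above fires on the next pass instead of wrapping again.
      def normalize(e: Expr): Expr =
        if (needNormalize(e)) Normalized(Normalize(e)) else e

      def main(args: Array[String]): Unit = {
        val once = normalize(DoubleAttr("a"))
        val twice = normalize(once)
        assert(once == twice) // idempotent: the second run is a no-op
      }
    }

The design choice is to mark normalized subtrees in the plan itself rather than track state in the rule, which is why each test simply runs Optimize.execute twice and compares both results against the same expected plan.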