-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-28306][SQL] Make NormalizeFloatingNumbers rule idempotent #25080
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,7 @@ | |
|
|
||
| package org.apache.spark.sql.catalyst.optimizer | ||
|
|
||
| import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, LambdaFunction, NamedLambdaVariable, UnaryExpression} | ||
| import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, UnaryExpression} | ||
| import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} | ||
| import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys | ||
| import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window} | ||
|
|
@@ -61,7 +61,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { | |
| case _: Subquery => plan | ||
|
|
||
| case _ => plan transform { | ||
| case w: Window if w.partitionSpec.exists(p => needNormalize(p.dataType)) => | ||
| case w: Window if w.partitionSpec.exists(p => needNormalize(p)) => | ||
| // Although the `windowExpressions` may refer to `partitionSpec` expressions, we don't need | ||
| // to normalize the `windowExpressions`, as they are executed per input row and should take | ||
| // the input row as it is. | ||
|
|
@@ -73,7 +73,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { | |
| case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _) | ||
| // The analyzer guarantees left and right joins keys are of the same data type. Here we | ||
| // only need to check join keys of one side. | ||
| if leftKeys.exists(k => needNormalize(k.dataType)) => | ||
| if leftKeys.exists(k => needNormalize(k)) => | ||
| val newLeftJoinKeys = leftKeys.map(normalize) | ||
| val newRightJoinKeys = rightKeys.map(normalize) | ||
| val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map { | ||
|
|
@@ -87,6 +87,14 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Short circuit if the underlying expression is already normalized | ||
| */ | ||
| private def needNormalize(expr: Expression): Boolean = expr match { | ||
| case KnownFloatingPointNormalized(_) => false | ||
| case _ => needNormalize(expr.dataType) | ||
| } | ||
|
|
||
| private def needNormalize(dt: DataType): Boolean = dt match { | ||
| case FloatType | DoubleType => true | ||
| case StructType(fields) => fields.exists(f => needNormalize(f.dataType)) | ||
|
|
@@ -98,7 +106,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { | |
| } | ||
|
|
||
| private[sql] def normalize(expr: Expression): Expression = expr match { | ||
| case _ if !needNormalize(expr.dataType) => expr | ||
| case _ if !needNormalize(expr) => expr | ||
|
|
||
| case a: Alias => | ||
| a.withNewChildren(Seq(normalize(a.child))) | ||
|
|
@@ -116,7 +124,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { | |
| CreateMap(children.map(normalize)) | ||
|
|
||
| case _ if expr.dataType == FloatType || expr.dataType == DoubleType => | ||
| NormalizeNaNAndZero(expr) | ||
| KnownFloatingPointNormalized(NormalizeNaNAndZero(expr)) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hm, from my understanding, we didn't quite like such approach though like analysis barrier. Scope here is small so might be fine but this doesn't particularly look like a good fix.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The problem is from
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And we don't want to add a new kind of
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This has a much less impact than the The alternative for providing this information would be something like having a new dedicated expression type for floating point array normalization, which would also be disruptive to the expression tree structure. In terms of code reuse and semantic clarity, I'd say Yesheng's current design strikes the best balance. |
||
|
|
||
| case _ if expr.dataType.isInstanceOf[StructType] => | ||
| val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i => | ||
|
|
@@ -128,7 +136,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { | |
| val ArrayType(et, containsNull) = expr.dataType | ||
| val lv = NamedLambdaVariable("arg", et, containsNull) | ||
| val function = normalize(lv) | ||
| ArrayTransform(expr, LambdaFunction(function, Seq(lv))) | ||
| KnownFloatingPointNormalized(ArrayTransform(expr, LambdaFunction(function, Seq(lv)))) | ||
|
|
||
| case _ => throw new IllegalStateException(s"fail to normalize $expr") | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.sql.catalyst.optimizer | ||
|
|
||
| import org.apache.spark.sql.catalyst.dsl.expressions._ | ||
| import org.apache.spark.sql.catalyst.dsl.plans._ | ||
| import org.apache.spark.sql.catalyst.expressions.KnownFloatingPointNormalized | ||
| import org.apache.spark.sql.catalyst.plans.PlanTest | ||
| import org.apache.spark.sql.catalyst.plans.logical._ | ||
| import org.apache.spark.sql.catalyst.rules.RuleExecutor | ||
|
|
||
| class NormalizeFloatingPointNumbersSuite extends PlanTest { | ||
|
|
||
| object Optimize extends RuleExecutor[LogicalPlan] { | ||
| val batches = Batch("NormalizeFloatingPointNumbers", Once, NormalizeFloatingNumbers) :: Nil | ||
| } | ||
|
|
||
| val testRelation1 = LocalRelation('a.double) | ||
| val a = testRelation1.output(0) | ||
| val testRelation2 = LocalRelation('a.double) | ||
| val b = testRelation2.output(0) | ||
|
|
||
| test("normalize floating points in window function expressions") { | ||
| val query = testRelation1.window(Seq(sum(a).as("sum")), Seq(a), Seq(a.asc)) | ||
|
|
||
| val optimized = Optimize.execute(query) | ||
| val correctAnswer = testRelation1.window(Seq(sum(a).as("sum")), | ||
| Seq(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))), Seq(a.asc)) | ||
|
|
||
| comparePlans(optimized, correctAnswer) | ||
| } | ||
|
|
||
| test("normalize floating points in window function expressions - idempotence") { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so we can remove this test after we add idempotence policy and change the once policy in this test suite to idempotence?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep. Do we have to add a mark here?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not necessary, I just want to confirm it. |
||
| val query = testRelation1.window(Seq(sum(a).as("sum")), Seq(a), Seq(a.asc)) | ||
|
|
||
| val optimized = Optimize.execute(query) | ||
| val doubleOptimized = Optimize.execute(optimized) | ||
| val correctAnswer = testRelation1.window(Seq(sum(a).as("sum")), | ||
| Seq(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))), Seq(a.asc)) | ||
|
|
||
| comparePlans(doubleOptimized, correctAnswer) | ||
| } | ||
|
|
||
| test("normalize floating points in join keys") { | ||
| val query = testRelation1.join(testRelation2, condition = Some(a === b)) | ||
|
|
||
| val optimized = Optimize.execute(query) | ||
| val joinCond = Some(KnownFloatingPointNormalized(NormalizeNaNAndZero(a)) | ||
| === KnownFloatingPointNormalized(NormalizeNaNAndZero(b))) | ||
| val correctAnswer = testRelation1.join(testRelation2, condition = joinCond) | ||
|
|
||
| comparePlans(optimized, correctAnswer) | ||
| } | ||
|
|
||
| test("normalize floating points in join keys - idempotence") { | ||
| val query = testRelation1.join(testRelation2, condition = Some(a === b)) | ||
|
|
||
| val optimized = Optimize.execute(query) | ||
| val doubleOptimized = Optimize.execute(optimized) | ||
| val joinCond = Some(KnownFloatingPointNormalized(NormalizeNaNAndZero(a)) | ||
| === KnownFloatingPointNormalized(NormalizeNaNAndZero(b))) | ||
| val correctAnswer = testRelation1.join(testRelation2, condition = joinCond) | ||
|
|
||
| comparePlans(doubleOptimized, correctAnswer) | ||
| } | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shall we override
toStringhere, so that it's invisible to end users when running EXPLAIN?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's already handled in
Expression::toString?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@cloud-fan should it be invisible though? I'd rather leave a trace of the marker in the plan, but we could make it less verbose by making it something like adding a prefix to the child instead of the regular tostring, e.g. print
normalizing-transform(...)instead of
knownfloatingpointnormalized(transform(...))WDYT?