@@ -21,15 +21,21 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, FalseLiteral}
 import org.apache.spark.sql.types.DataType
 
-case class KnownNotNull(child: Expression) extends UnaryExpression {
-  override def nullable: Boolean = false
-  override def dataType: DataType = child.dataType
-
-  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    child.genCode(ctx).copy(isNull = FalseLiteral)
-  }
-
-  override def eval(input: InternalRow): Any = {
-    child.eval(input)
-  }
-}
+trait TaggingExpression extends UnaryExpression {
+  override def nullable: Boolean = child.nullable
+  override def dataType: DataType = child.dataType
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = child.genCode(ctx)
+
+  override def eval(input: InternalRow): Any = child.eval(input)
+}
+
+case class KnownNotNull(child: Expression) extends TaggingExpression {
+  override def nullable: Boolean = false
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    child.genCode(ctx).copy(isNull = FalseLiteral)
+  }
+
+  override def eval(input: InternalRow): Any = {
+    child.eval(input)
+  }
+}
+
+case class KnownFloatingPointNormalized(child: Expression) extends TaggingExpression

cloud-fan (Contributor), Jul 9, 2019:
Shall we override toString here, so that it's invisible to end users when running EXPLAIN?

Contributor (Author):
I think it's already handled in Expression::toString?

Contributor:
@cloud-fan should it be invisible, though? I'd rather leave a trace of the marker in the plan, but we could make it less verbose by rendering a short prefix on the child instead of the regular toString, e.g. print
    normalizing-transform(...)
instead of
    knownfloatingpointnormalized(transform(...))

WDYT?

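(For illustration: a self-contained toy sketch, plain Scala rather than Catalyst, of the two rendering options discussed above. Expr, Leaf, Transform, Tagged, and PrefixTagged are hypothetical stand-ins, not Spark classes.)

object ToStringSketch extends App {
  sealed trait Expr
  case class Leaf(name: String) extends Expr {
    override def toString: String = name
  }
  case class Transform(child: Expr) extends Expr {
    override def toString: String = s"transform($child)"
  }
  // Option A: default-style rendering; the tag is visible in the plan string.
  case class Tagged(child: Expr) extends Expr {
    override def toString: String = s"knownfloatingpointnormalized($child)"
  }
  // Option B: render the tag as a short prefix on the child.
  case class PrefixTagged(child: Expr) extends Expr {
    override def toString: String = s"normalizing-$child"
  }

  val e = Transform(Leaf("x"))
  println(Tagged(e))       // knownfloatingpointnormalized(transform(x))
  println(PrefixTagged(e)) // normalizing-transform(x)
}
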
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, LambdaFunction, NamedLambdaVariable, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window}
@@ -61,7 +61,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
     case _: Subquery => plan
 
     case _ => plan transform {
-      case w: Window if w.partitionSpec.exists(p => needNormalize(p.dataType)) =>
+      case w: Window if w.partitionSpec.exists(p => needNormalize(p)) =>
         // Although the `windowExpressions` may refer to `partitionSpec` expressions, we don't need
         // to normalize the `windowExpressions`, as they are executed per input row and should take
         // the input row as it is.
@@ -73,7 +73,7 @@
       case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _)
           // The analyzer guarantees left and right join keys are of the same data type. Here we
           // only need to check join keys of one side.
-          if leftKeys.exists(k => needNormalize(k.dataType)) =>
+          if leftKeys.exists(k => needNormalize(k)) =>
         val newLeftJoinKeys = leftKeys.map(normalize)
         val newRightJoinKeys = rightKeys.map(normalize)
        val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map {
@@ -87,6 +87,14 @@
     }
   }
 
+  /**
+   * Short circuit if the underlying expression is already normalized.
+   */
+  private def needNormalize(expr: Expression): Boolean = expr match {
+    case KnownFloatingPointNormalized(_) => false
+    case _ => needNormalize(expr.dataType)
+  }
+
   private def needNormalize(dt: DataType): Boolean = dt match {
     case FloatType | DoubleType => true
     case StructType(fields) => fields.exists(f => needNormalize(f.dataType))
@@ -98,7 +106,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
   }
 
   private[sql] def normalize(expr: Expression): Expression = expr match {
-    case _ if !needNormalize(expr.dataType) => expr
+    case _ if !needNormalize(expr) => expr
 
     case a: Alias =>
       a.withNewChildren(Seq(normalize(a.child)))
@@ -116,7 +124,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
       CreateMap(children.map(normalize))
 
     case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
-      NormalizeNaNAndZero(expr)
+      KnownFloatingPointNormalized(NormalizeNaNAndZero(expr))

Member:
Hm, from my understanding, we didn't much like this kind of approach (e.g. AnalysisBarrier). The scope here is small, so it might be fine, but this doesn't particularly look like a good fix.

Contributor (Author):
The problem comes from ArrayTransform, since we can't easily tell whether an ArrayTransform is for FP normalization or not. Otherwise we could just check for NormalizeNaNAndZero.

Contributor (Author):
And we don't want to add a new kind of ArrayTransform node to the final logical plan either (along with the related logic)... I can't really think of a more elegant approach.

Contributor:
This has much less impact than AnalysisBarrier -- it only applies to expressions, whereas AnalysisBarrier applied to whole plans. We need to leave markers in place in case a plan gets re-optimized after the initial optimization, so we have to have something in the plan that persists this information.

The alternative for providing this information would be a new dedicated expression type for floating-point array normalization, which would also be disruptive to the expression tree structure. In terms of code reuse and semantic clarity, I'd say Yesheng's current design strikes the best balance.

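(A minimal, self-contained sketch of the idempotence argument above, in plain Scala with hypothetical stand-in names; it mirrors the short circuit in needNormalize rather than Spark's actual classes.)

object MarkerIdempotenceSketch extends App {
  sealed trait Expr
  case class Column(name: String) extends Expr
  case class Normalize(child: Expr) extends Expr
  case class KnownNormalized(child: Expr) extends Expr // the persisted marker

  // Mirrors the short circuit above: a tagged subtree is never re-normalized.
  def needNormalize(e: Expr): Boolean = e match {
    case KnownNormalized(_) => false
    case _                  => true
  }

  def normalize(e: Expr): Expr =
    if (needNormalize(e)) KnownNormalized(Normalize(e)) else e

  val once  = normalize(Column("a"))
  val twice = normalize(once) // re-optimization is a no-op thanks to the marker
  println(once)          // KnownNormalized(Normalize(Column(a)))
  println(once == twice) // true
}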

     case _ if expr.dataType.isInstanceOf[StructType] =>
       val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i =>
@@ -128,7 +136,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
       val ArrayType(et, containsNull) = expr.dataType
       val lv = NamedLambdaVariable("arg", et, containsNull)
       val function = normalize(lv)
-      ArrayTransform(expr, LambdaFunction(function, Seq(lv)))
+      KnownFloatingPointNormalized(ArrayTransform(expr, LambdaFunction(function, Seq(lv))))
 
     case _ => throw new IllegalStateException(s"fail to normalize $expr")
   }
@@ -0,0 +1,82 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.KnownFloatingPointNormalized
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class NormalizeFloatingPointNumbersSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("NormalizeFloatingPointNumbers", Once, NormalizeFloatingNumbers) :: Nil
  }

  val testRelation1 = LocalRelation('a.double)
  val a = testRelation1.output(0)
  val testRelation2 = LocalRelation('a.double)
  val b = testRelation2.output(0)

  test("normalize floating points in window function expressions") {
    val query = testRelation1.window(Seq(sum(a).as("sum")), Seq(a), Seq(a.asc))

    val optimized = Optimize.execute(query)
    val correctAnswer = testRelation1.window(Seq(sum(a).as("sum")),
      Seq(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))), Seq(a.asc))

    comparePlans(optimized, correctAnswer)
  }

  test("normalize floating points in window function expressions - idempotence") {

Contributor:
So we can remove this test after we add an idempotence policy and change the Once policy in this test suite to idempotence?

Contributor (Author):
Yep. Do we have to add a mark here?

Contributor:
Not necessary, I just wanted to confirm.

    val query = testRelation1.window(Seq(sum(a).as("sum")), Seq(a), Seq(a.asc))

    val optimized = Optimize.execute(query)
    val doubleOptimized = Optimize.execute(optimized)
    val correctAnswer = testRelation1.window(Seq(sum(a).as("sum")),
      Seq(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))), Seq(a.asc))

    comparePlans(doubleOptimized, correctAnswer)
  }

  test("normalize floating points in join keys") {
    val query = testRelation1.join(testRelation2, condition = Some(a === b))

    val optimized = Optimize.execute(query)
    val joinCond = Some(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))
      === KnownFloatingPointNormalized(NormalizeNaNAndZero(b)))
    val correctAnswer = testRelation1.join(testRelation2, condition = joinCond)

    comparePlans(optimized, correctAnswer)
  }

  test("normalize floating points in join keys - idempotence") {
    val query = testRelation1.join(testRelation2, condition = Some(a === b))

    val optimized = Optimize.execute(query)
    val doubleOptimized = Optimize.execute(optimized)
    val joinCond = Some(KnownFloatingPointNormalized(NormalizeNaNAndZero(a))
      === KnownFloatingPointNormalized(NormalizeNaNAndZero(b)))
    val correctAnswer = testRelation1.join(testRelation2, condition = joinCond)

    comparePlans(doubleOptimized, correctAnswer)
  }
}
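
(Background for why normalization is needed at all: a plain-Scala sketch of the IEEE 754 pitfalls that NormalizeNaNAndZero addresses. The normalize helper below is a hypothetical stand-in for illustration, not Spark's implementation.)

object FloatingPointPitfalls extends App {
  import java.lang.Double.{doubleToRawLongBits, longBitsToDouble}

  // 0.0 and -0.0 compare equal but have different bit patterns,
  // so they hash differently in grouping and join data structures.
  println(0.0 == -0.0)                                            // true
  println(doubleToRawLongBits(0.0) == doubleToRawLongBits(-0.0))  // false

  // Many distinct bit patterns all decode to NaN.
  val nan1 = longBitsToDouble(0x7ff0000000000001L)
  val nan2 = longBitsToDouble(0x7ff8000000000000L)
  println(nan1.isNaN && nan2.isNaN)                               // true
  println(doubleToRawLongBits(nan1) == doubleToRawLongBits(nan2)) // false

  // One way to normalize: canonical NaN, and -0.0 folded into 0.0
  // (adding positive zero maps -0.0 to +0.0 under round-to-nearest).
  def normalize(d: Double): Double = if (d.isNaN) Double.NaN else d + 0.0

  println(doubleToRawLongBits(normalize(nan1)) ==
          doubleToRawLongBits(normalize(nan2)))                   // true
  println(doubleToRawLongBits(normalize(-0.0)) ==
          doubleToRawLongBits(normalize(0.0)))                    // true
}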