Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,43 @@ import org.apache.spark.util.Utils
abstract class Optimizer(sessionCatalog: SessionCatalog)
extends RuleExecutor[LogicalPlan] {

// Check for structural integrity of the plan in test mode. Currently we only check if a plan is
// still resolved after the execution of each rule.
// Check for structural integrity of the plan in test mode.
// Currently we check after the execution of each rule if a plan:
// - is still resolved
// - only hosts special expressions in supported operators
override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
  // Outside of test mode the check is skipped entirely and every plan is accepted.
  // In test mode the plan must be resolved and must pass checkSpecialExpressionIntegrity.
  // The former standalone `!Utils.isTesting || plan.resolved` expression is dropped: its
  // value was discarded, and the condition below already subsumes it.
  !Utils.isTesting || (plan.resolved && checkSpecialExpressionIntegrity(plan))
}

/**
 * Verify that no operator in this plan hosts a special expression it is not allowed to host.
 * Returns true when `plan.find` locates no operator for which
 * `specialExpressionInUnsupportedOperator` holds.
 */
private def checkSpecialExpressionIntegrity(plan: LogicalPlan): Boolean = {
  val firstOffender = plan.find(operator => specialExpressionInUnsupportedOperator(operator))
  firstOffender.isEmpty
}

/**
 * Check if there's any expression in this query plan operator that is
 * - A WindowExpression but the plan is not Window
 * - An AggregateExpression but the plan is not Aggregate or Window
 * - A Generator but the plan is not Generate
 * Returns true when this operator breaks structural integrity with one of the cases above.
 */
private def specialExpressionInUnsupportedOperator(plan: LogicalPlan): Boolean = {
  // The operator kind is the same for every expression inspected, so classify it once.
  val isWindow = plan.isInstanceOf[Window]
  val isAggregate = plan.isInstanceOf[Aggregate]
  val isGenerate = plan.isInstanceOf[Generate]

  // `exists` stops at the first offending expression tree instead of collecting every
  // match into an intermediate sequence only to test it for emptiness.
  plan.expressions.exists { root =>
    root.find {
      case _: WindowExpression if !isWindow => true
      case _: AggregateExpression if !(isAggregate || isWindow) => true
      case _: Generator if !isGenerate => true
      case _ => false
    }.isDefined
  }
}

protected def fixedPoint = FixedPoint(SQLConf.get.optimizerMaxIterations)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,13 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
val planChangeLogger = new PlanChangeLogger()
val tracker: Option[QueryPlanningTracker] = QueryPlanningTracker.get

// Run the structural integrity checker against the initial input
if (!isPlanIntegral(plan)) {
val message = "The structural integrity of the input plan is broken in " +
s"${this.getClass.getName.stripSuffix("$")}."
throw new TreeNodeException(plan, message, null)
}

batches.foreach { batch =>
val batchStartPlan = curPlan
var iteration = 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{EmptyFunctionRegistry, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, LogicalPlan, OneRowRelation, Project}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf

Expand All @@ -35,6 +36,9 @@ class OptimizerStructuralIntegrityCheckerSuite extends PlanTest {
case Project(projectList, child) =>
val newAttr = UnresolvedAttribute("unresolvedAttr")
Project(projectList ++ Seq(newAttr), child)
case agg @ Aggregate(Nil, aggregateExpressions, child) =>
// Project cannot host AggregateExpression
Project(aggregateExpressions, child)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment here to explain Project is unable to handle Aggregate expressions.

}
}

Expand All @@ -47,7 +51,7 @@ class OptimizerStructuralIntegrityCheckerSuite extends PlanTest {
override def defaultBatches: Seq[Batch] = Seq(newBatch) ++ super.defaultBatches
}

test("check for invalid plan after execution of rule") {
test("check for invalid plan after execution of rule - unresolved attribute") {
val analyzed = Project(Alias(Literal(10), "attr")() :: Nil, OneRowRelation()).analyze
assert(analyzed.resolved)
val message = intercept[TreeNodeException[LogicalPlan]] {
Expand All @@ -57,4 +61,35 @@ class OptimizerStructuralIntegrityCheckerSuite extends PlanTest {
assert(message.contains(s"After applying rule $ruleName in batch OptimizeRuleBreakSI"))
assert(message.contains("the structural integrity of the plan is broken"))
}

test("check for invalid plan after execution of rule - special expression in wrong operator") {
  // An Aggregate with no grouping expressions that computes max(id) over a one-column relation.
  val aggregateList = Seq[NamedExpression](max('id) as 'm)
  val analyzed = Aggregate(Nil, aggregateList, LocalRelation('id.long)).analyze
  assert(analyzed.resolved)

  // Running the plan through Optimize must trip the integrity check after OptimizeRuleBreakSI.
  val failure = intercept[TreeNodeException[LogicalPlan]] {
    Optimize.execute(analyzed)
  }
  val message = failure.getMessage
  val ruleName = OptimizeRuleBreakSI.ruleName
  assert(message.contains(s"After applying rule $ruleName in batch OptimizeRuleBreakSI"))
  assert(message.contains("the structural integrity of the plan is broken"))

  // The same plan must pass through the regular optimizer without throwing.
  SimpleTestOptimizer.execute(analyzed)
}

test("check for invalid plan before execution of any rule") {
  // Start from a valid Aggregate, then apply OptimizeRuleBreakSI by hand so that the plan
  // handed to the optimizer is the rule's output rather than the analyzed plan.
  val aggregateList = Seq[NamedExpression](max('id) as 'm)
  val analyzed = Aggregate(Nil, aggregateList, LocalRelation('id.long)).analyze
  val invalidPlan = OptimizeRuleBreakSI.apply(analyzed)

  // The executor is expected to reject the input itself, reporting the input-plan message.
  val failure = intercept[TreeNodeException[LogicalPlan]] {
    Optimize.execute(invalidPlan)
  }
  val message = failure.getMessage
  assert(message.contains("The structural integrity of the input plan is broken"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class RuleExecutorSuite extends SparkFunSuite {
assert(message.contains("Max iterations (10) reached for batch fixedPoint"))
}

test("structural integrity checker") {
test("structural integrity checker - verify initial input") {
object WithSIChecker extends RuleExecutor[Expression] {
override protected def isPlanIntegral(expr: Expression): Boolean = expr match {
case IntegerLiteral(_) => true
Expand All @@ -69,8 +69,26 @@ class RuleExecutorSuite extends SparkFunSuite {
assert(WithSIChecker.execute(Literal(10)) === Literal(9))

val message = intercept[TreeNodeException[LogicalPlan]] {
// The input is already invalid as determined by WithSIChecker.isPlanIntegral
WithSIChecker.execute(Literal(10.1))
}.getMessage
assert(message.contains("The structural integrity of the input plan is broken"))
}

test("structural integrity checker - verify rule execution result") {
  // Executor whose integrity check accepts only strictly positive integer literals.
  object WithSICheckerForPositiveLiteral extends RuleExecutor[Expression] {
    override protected def isPlanIntegral(expr: Expression): Boolean = expr match {
      case IntegerLiteral(value) => value > 0
      case _ => false
    }
    val batches = Seq(Batch("once", Once, DecrementLiterals))
  }

  // Input 2 is accepted and the expected result, Literal(1), is still positive.
  val decremented = WithSICheckerForPositiveLiteral.execute(Literal(2))
  assert(decremented === Literal(1))

  // Input 1 is accepted by the check, but execution is expected to fail with the
  // post-rule integrity message.
  val failure = intercept[TreeNodeException[LogicalPlan]] {
    WithSICheckerForPositiveLiteral.execute(Literal(1))
  }
  val message = failure.getMessage
  assert(message.contains("the structural integrity of the plan is broken"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1186,7 +1186,7 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
}
}

test("SPARK-23957 Remove redundant sort from subquery plan(scalar subquery)") {
ignore("SPARK-23957 Remove redundant sort from subquery plan(scalar subquery)") {
withTempView("t1", "t2", "t3") {
Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1")
Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2")
Expand Down