diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 105623c767d66..2ff67689c3492 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -135,7 +135,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) Batch("Subquery", Once, OptimizeSubqueries) :: Batch("Replace Operators", fixedPoint, - RewriteExcepAll, + RewriteExceptAll, RewriteIntersectAll, ReplaceIntersectWithSemiJoin, ReplaceExceptWithFilter, @@ -189,6 +189,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) ReplaceIntersectWithSemiJoin.ruleName :: ReplaceExceptWithFilter.ruleName :: ReplaceExceptWithAntiJoin.ruleName :: + RewriteExceptAll.ruleName :: + RewriteIntersectAll.ruleName :: ReplaceDistinctWithAggregate.ruleName :: PullupCorrelatedPredicates.ruleName :: RewritePredicateSubquery.ruleName :: Nil @@ -1462,7 +1464,7 @@ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] { * }}} */ -object RewriteExcepAll extends Rule[LogicalPlan] { +object RewriteExceptAll extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Except(left, right, true) => assert(left.output.size == right.output.size) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala index 30c80d26b67a1..eee8dc3b76c34 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerRuleExclusionSuite.scala @@ -80,12 +80,14 @@ class OptimizerRuleExclusionSuite extends PlanTest { "DummyRuleName")) } - test("Try to exclude a non-excludable rule") { + test("Try to exclude some non-excludable rules") { verifyExcludedRules( new SimpleTestOptimizer(), Seq( ReplaceIntersectWithSemiJoin.ruleName, - PullupCorrelatedPredicates.ruleName)) + PullupCorrelatedPredicates.ruleName, + RewriteExceptAll.ruleName, + RewriteIntersectAll.ruleName)) } test("Custom optimizer") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala index cb744be400603..da3923f8d6477 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationSuite.scala @@ -148,7 +148,7 @@ class SetOperationSuite extends PlanTest { test("EXCEPT ALL rewrite") { val input = Except(testRelation, testRelation2, isAll = true) - val rewrittenPlan = RewriteExcepAll(input) + val rewrittenPlan = RewriteExceptAll(input) val planFragment = testRelation.select(Literal(1L).as("vcol"), 'a, 'b, 'c) .union(testRelation2.select(Literal(-1L).as("vcol"), 'd, 'e, 'f))