diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 5bdaa504a3beb..7586bdf4392f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -79,7 +79,6 @@ abstract class Optimizer(catalogManager: CatalogManager)
         PushLeftSemiLeftAntiThroughJoin,
         LimitPushDown,
         ColumnPruning,
-        InferFiltersFromConstraints,
         // Operator combine
         CollapseRepartition,
         CollapseProject,
@@ -117,14 +116,13 @@ abstract class Optimizer(catalogManager: CatalogManager)
         extendedOperatorOptimizationRules
 
     val operatorOptimizationBatch: Seq[Batch] = {
-      val rulesWithoutInferFiltersFromConstraints =
-        operatorOptimizationRuleSet.filterNot(_ == InferFiltersFromConstraints)
       Batch("Operator Optimization before Inferring Filters", fixedPoint,
-        rulesWithoutInferFiltersFromConstraints: _*) ::
+        operatorOptimizationRuleSet: _*) ::
       Batch("Infer Filters", Once,
+        InferFiltersFromGenerate,
         InferFiltersFromConstraints) ::
       Batch("Operator Optimization after Inferring Filters", fixedPoint,
-        rulesWithoutInferFiltersFromConstraints: _*) ::
+        operatorOptimizationRuleSet: _*) ::
       // Set strategy to Once to avoid pushing filter every time because we do not change the
       // join condition.
       Batch("Push extra predicate through join", fixedPoint,
@@ -868,6 +866,41 @@ object TransposeWindow extends Rule[LogicalPlan] {
   }
 }
 
+/**
+ * Infers non-null and non-empty filters on the input of a non-outer explode/inline [[Generate]],
+ * so rows this [[Generate]] would remove are removed earlier - before joins and in data sources.
+ */
+object InferFiltersFromGenerate extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    // Skip non-deterministic generators, and generators whose inputs are all foldable: the latter
+    // would yield constant filters like 'size([1, 2, 3]) > 0', which do not show up in the child's
+    // constraints, so the exclusion below could not remove them and idempotence would break.
+    case generate @ Generate(e, _, _, _, _, _)
+      if !e.deterministic || e.children.forall(_.foldable) => generate
+
+    case generate @ Generate(g, _, false, _, _, _) if canInferFilters(g) =>
+      // Exclude filters already present in the child's constraints to keep the rule idempotent
+      val inferredFilters = ExpressionSet(
+        Seq(
+          GreaterThan(Size(g.children.head), Literal(0)),
+          IsNotNull(g.children.head)
+        )
+      ) -- generate.child.constraints
+
+      if (inferredFilters.nonEmpty) {
+        generate.copy(child = Filter(inferredFilters.reduce(And), generate.child))
+      } else {
+        generate
+      }
+  }
+
+  private def canInferFilters(g: Generator): Boolean = g match {
+    case _: ExplodeBase => true
+    case _: Inline => true
+    case _ => false
+  }
+}
+
 /**
  * Generate a list of additional filters from an operator's existing constraint but remove those
  * that are either already part of the operator's condition or are part of the operator's child
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala
new file mode 100644
index 0000000000000..3f83971aa9821
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromGenerateSuite.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+
+class InferFiltersFromGenerateSuite extends PlanTest {
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("Infer Filters", Once, InferFiltersFromGenerate) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.array(StructType(Seq(
+    StructField("x", IntegerType),
+    StructField("y", IntegerType)
+  ))))
+
+  Seq(Explode(_), PosExplode(_), Inline(_)).foreach { f =>
+    val generator = f('a)
+    test("Infer filters from " + generator) {
+      val originalQuery = testRelation.generate(generator).analyze
+      val correctAnswer = testRelation
+        .where(IsNotNull('a) && Size('a) > 0)
+        .generate(generator)
+        .analyze
+      val optimized = Optimize.execute(originalQuery)
+      comparePlans(optimized, correctAnswer)
+    }
+
+    test("Don't infer duplicate filters from " + generator) {
+      val originalQuery = testRelation
+        .where(IsNotNull('a) && Size('a) > 0)
+        .generate(generator)
+        .analyze
+      val optimized = Optimize.execute(originalQuery)
+      comparePlans(optimized, originalQuery)
+    }
+
+    test("Don't infer filters from outer " + generator) {
+      val originalQuery = testRelation.generate(generator, outer = true).analyze
+      val optimized = Optimize.execute(originalQuery)
+      comparePlans(optimized, originalQuery)
+    }
+
+    val foldableExplode = f(CreateArray(Seq(
+      CreateStruct(Seq(Literal(0), Literal(1))),
+      CreateStruct(Seq(Literal(2), Literal(3)))
+    )))
+    test("Don't infer filters from " + foldableExplode) {
+      val originalQuery = testRelation.generate(foldableExplode).analyze
+      val optimized = Optimize.execute(originalQuery)
+      comparePlans(optimized, originalQuery)
+    }
+  }
+}