diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 72546ea73dd9..3c35ba9b6004 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -29,26 +29,13 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation}
 import org.apache.spark.sql.internal.SQLConf
 
-/**
- * A pattern that matches any number of project or filter operations even if they are
- * non-deterministic, as long as they satisfy the requirement of CollapseProject and CombineFilters.
- * All filter operators are collected and their conditions are broken up and returned
- * together with the top project operator. [[Alias Aliases]] are in-lined/substituted if
- * necessary.
- */
-object PhysicalOperation extends AliasHelper with PredicateHelper {
+trait OperationHelper extends AliasHelper with PredicateHelper {
   import org.apache.spark.sql.catalyst.optimizer.CollapseProject.canCollapseExpressions
 
-  type ReturnType =
-    (Seq[NamedExpression], Seq[Expression], LogicalPlan)
   type IntermediateType =
     (Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, AttributeMap[Alias])
 
-  def unapply(plan: LogicalPlan): Option[ReturnType] = {
-    val alwaysInline = SQLConf.get.getConf(SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE)
-    val (fields, filters, child, _) = collectProjectsAndFilters(plan, alwaysInline)
-    Some((fields.getOrElse(child.output), filters, child))
-  }
+  protected def collectAllFilters: Boolean
 
   /**
    * Collects all adjacent projects and filters, in-lining/substituting aliases if necessary.
@@ -64,7 +51,7 @@ object PhysicalOperation extends AliasHelper with PredicateHelper {
    *   SELECT key AS c2 FROM t1 WHERE key > 10
    * }}}
    */
-  private def collectProjectsAndFilters(
+  protected def collectProjectsAndFilters(
       plan: LogicalPlan,
       alwaysInline: Boolean): IntermediateType = {
     def empty: IntermediateType = (None, Nil, plan, AttributeMap.empty)
@@ -84,16 +71,21 @@ object PhysicalOperation extends AliasHelper with PredicateHelper {
           // When collecting projects and filters, we effectively push down filters through
           // projects. We need to meet the following conditions to do so:
           //   1) no Project collected so far or the collected Projects are all deterministic
-          //   2) the collected filters and this filter are all deterministic, or this is the
-          //      first collected filter.
-          //   3) this filter does not repeat any expensive expressions from the collected
+          //   2) this filter does not repeat any expensive expressions from the collected
           //      projects.
-          val canIncludeThisFilter = fields.forall(_.forall(_.deterministic)) && {
-            filters.isEmpty || (filters.forall(_.deterministic) && condition.deterministic)
-          } && canCollapseExpressions(Seq(condition), aliases, alwaysInline)
-          if (canIncludeThisFilter) {
-            val replaced = replaceAlias(condition, aliases)
-            (fields, filters ++ splitConjunctivePredicates(replaced), other, aliases)
+          val canPushFilterThroughProject = fields.forall(_.forall(_.deterministic)) &&
+            canCollapseExpressions(Seq(condition), aliases, alwaysInline)
+          if (canPushFilterThroughProject) {
+            // Normally we can't combine non-deterministic filters, but if `collectAllFilters` is
+            // true, we relax this restriction and assume the caller will take care of it.
+            val canIncludeThisFilter = filters.isEmpty || {
+              filters.last.deterministic && condition.deterministic
+            }
+            if (canIncludeThisFilter || collectAllFilters) {
+              (fields, filters :+ replaceAlias(condition, aliases), other, aliases)
+            } else {
+              empty
+            }
           } else {
             empty
           }
@@ -105,6 +97,54 @@ object PhysicalOperation extends AliasHelper with PredicateHelper {
   }
 }
 
+/**
+ * A pattern that matches any number of project or filter operations even if they are
+ * non-deterministic, as long as they satisfy the requirement of CollapseProject and CombineFilters.
+ * All filter operators are collected and their conditions are broken up and returned
+ * together with the top project operator. [[Alias Aliases]] are in-lined/substituted if
+ * necessary.
+ */
+object PhysicalOperation extends OperationHelper {
+  // Returns: (the final project list, filters to push down, relation)
+  type ReturnType = (Seq[NamedExpression], Seq[Expression], LogicalPlan)
+  override protected def collectAllFilters: Boolean = false
+
+  def unapply(plan: LogicalPlan): Option[ReturnType] = {
+    val alwaysInline = SQLConf.get.getConf(SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE)
+    val (fields, filters, child, _) = collectProjectsAndFilters(plan, alwaysInline)
+    // If more than one filter is collected, they must all be deterministic.
+    if (filters.length > 1) assert(filters.forall(_.deterministic))
+    Some((
+      fields.getOrElse(child.output),
+      filters.flatMap(splitConjunctivePredicates),
+      child))
+  }
+}
+
+/**
+ * A variant of [[PhysicalOperation]] which can match multiple Filters that are not combinable due
+ * to non-deterministic predicates. This is useful for scan operations, as we need to match a
+ * bunch of adjacent Projects/Filters to apply column pruning even if the Filters can't be
+ * combined. For example, in `Project(a, Filter(rand() > 0.5, Filter(rand() < 0.8, TableScan)))`
+ * we should still read only column `a` from the relation.
+ */
+object ScanOperation extends OperationHelper {
+  // Returns: (the final project list, filters to stay up, filters to push down, relation)
+  type ReturnType = (Seq[NamedExpression], Seq[Expression], Seq[Expression], LogicalPlan)
+  override protected def collectAllFilters: Boolean = true
+
+  def unapply(plan: LogicalPlan): Option[ReturnType] = {
+    val alwaysInline = SQLConf.get.getConf(SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE)
+    val (fields, filters, child, _) = collectProjectsAndFilters(plan, alwaysInline)
+    // `collectProjectsAndFilters` collects filters bottom-up, so the bottom-most filters are
+    // placed at the beginning of the `filters` list. According to SQL semantics, we can only
+    // push down the bottom-most deterministic filters.
+    val filtersCanPushDown = filters.takeWhile(_.deterministic).flatMap(splitConjunctivePredicates)
+    val filtersStayUp = filters.dropWhile(_.deterministic)
+    Some((fields.getOrElse(child.output), filtersStayUp, filtersCanPushDown, child))
+  }
+}
+
 object NodeWithOnlyDeterministicProjectAndFilter {
   def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
     case Project(projectList, child) if projectList.forall(_.deterministic) => unapply(child)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 72e48842b07d..576801d3dd59 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.catalyst.planning.ScanOperation
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.FileFormat.METADATA_NAME
@@ -146,7 +146,7 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
   }
 
   def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-    case PhysicalOperation(projects, filters,
+    case ScanOperation(projects, stayUpFilters, filters,
       l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) =>
       // Filters on this relation fall into four categories based on where we can use them to avoid
       // reading unneeded data:
@@ -204,7 +204,7 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
       val afterScanFilters = filterSet -- partitionKeyFilters.filter(_.references.nonEmpty)
       logInfo(s"Post-Scan Filters: ${afterScanFilters.mkString(",")}")
 
-      val filterAttributes = AttributeSet(afterScanFilters)
+      val filterAttributes = AttributeSet(afterScanFilters ++ stayUpFilters)
       val requiredExpressions: Seq[NamedExpression] = filterAttributes.toSeq ++ projects
       val requiredAttributes = AttributeSet(requiredExpressions)
 
@@ -222,8 +222,8 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
         metadataColumns.filter(_.name != FileFormat.ROW_INDEX)
 
       val readDataColumns = dataColumns
-          .filter(requiredAttributes.contains)
-          .filterNot(partitionColumns.contains)
+        .filter(requiredAttributes.contains)
+        .filterNot(partitionColumns.contains)
 
       val fileFormatReaderGeneratedMetadataColumns: Seq[Attribute] =
         metadataColumns.map(_.name).flatMap {
@@ -281,10 +281,11 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
             readDataColumns ++ partitionColumns :+ metadataAlias, scan)
       }.getOrElse(scan)
 
-      val afterScanFilter = afterScanFilters.toSeq.reduceOption(expressions.And)
-      val withFilter = afterScanFilter
-        .map(execution.FilterExec(_, withMetadataProjections))
-        .getOrElse(withMetadataProjections)
+      // The bottom-most filters are at the front of the list.
+      val finalFilters = afterScanFilters.toSeq.reduceOption(expressions.And).toSeq ++ stayUpFilters
+      val withFilter = finalFilters.foldLeft(withMetadataProjections)((plan, cond) => {
+        execution.FilterExec(cond, plan)
+      })
       val withProjections = if (projects == withFilter.output) {
         withFilter
       } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
index 26d5d92fecb3..279fea6d64bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.catalyst.planning.ScanOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
@@ -27,7 +27,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
 import org.apache.spark.sql.util.SchemaUtils._
 
 /**
- * Prunes unnecessary physical columns given a [[PhysicalOperation]] over a data source relation.
+ * Prunes unnecessary physical columns given a [[ScanOperation]] over a data source relation.
  * By "physical column", we mean a column as defined in the data source format like Parquet format
  * or ORC format. For example, in Spark SQL, a root-level Parquet column corresponds to a SQL
  * column, and a nested Parquet column corresponds to a [[StructField]].
@@ -39,9 +39,10 @@ object SchemaPruning extends Rule[LogicalPlan] {
 
   override def apply(plan: LogicalPlan): LogicalPlan =
     plan transformDown {
-      case op @ PhysicalOperation(projects, filters,
+      case op @ ScanOperation(projects, filtersStayUp, filtersPushDown,
         l @ LogicalRelation(hadoopFsRelation: HadoopFsRelation, _, _, _)) =>
-        prunePhysicalColumns(l, projects, filters, hadoopFsRelation,
+        val allFilters = filtersPushDown.reduceOption(And).toSeq ++ filtersStayUp
+        prunePhysicalColumns(l, projects, allFilters, hadoopFsRelation,
           (prunedDataSchema, prunedMetadataSchema) => {
             val prunedHadoopRelation =
               hadoopFsRelation.copy(dataSchema = prunedDataSchema)(hadoopFsRelation.sparkSession)
@@ -61,9 +62,10 @@ object SchemaPruning extends Rule[LogicalPlan] {
       filters: Seq[Expression],
       hadoopFsRelation: HadoopFsRelation,
       leafNodeBuilder: (StructType, StructType) => LeafNode): Option[LogicalPlan] = {
-
-    val (normalizedProjects, normalizedFilters) =
-      normalizeAttributeRefNames(relation.output, projects, filters)
+    val attrNameMap = relation.output.map(att => (att.exprId, att.name)).toMap
+    val normalizedProjects = normalizeAttributeRefNames(attrNameMap, projects)
+      .asInstanceOf[Seq[NamedExpression]]
+    val normalizedFilters = normalizeAttributeRefNames(attrNameMap, filters)
     val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters)
 
     // If requestedRootFields includes a nested field, continue. Otherwise,
@@ -112,24 +114,17 @@ object SchemaPruning extends Rule[LogicalPlan] {
       fsRelation.fileFormat.isInstanceOf[OrcFileFormat])
 
   /**
-   * Normalizes the names of the attribute references in the given projects and filters to reflect
+   * Normalizes the names of the attribute references in the given expressions to reflect
    * the names in the given logical relation. This makes it possible to compare attributes and
    * fields by name. Returns a tuple with the normalized projects and filters, respectively.
    */
   private def normalizeAttributeRefNames(
-      output: Seq[AttributeReference],
-      projects: Seq[NamedExpression],
-      filters: Seq[Expression]): (Seq[NamedExpression], Seq[Expression]) = {
-    val normalizedAttNameMap = output.map(att => (att.exprId, att.name)).toMap
-    val normalizedProjects = projects.map(_.transform {
-      case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) =>
-        att.withName(normalizedAttNameMap(att.exprId))
-    }).map { case expr: NamedExpression => expr }
-    val normalizedFilters = filters.map(_.transform {
-      case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) =>
-        att.withName(normalizedAttNameMap(att.exprId))
+      attrNameMap: Map[ExprId, String],
+      exprs: Seq[Expression]): Seq[Expression] = {
+    exprs.map(_.transform {
+      case att: AttributeReference if attrNameMap.contains(att.exprId) =>
+        att.withName(attrNameMap(att.exprId))
     })
-    (normalizedProjects, normalizedFilters)
   }
 
   /**
@@ -148,8 +143,8 @@ object SchemaPruning extends Rule[LogicalPlan] {
       val projectedFilters = filters.map(_.transformDown {
         case projectionOverSchema(expr) => expr
       })
-      val newFilterCondition = projectedFilters.reduce(And)
-      Filter(newFilterCondition, leafNode)
+      // The bottom-most filters are at the front of the list.
+      projectedFilters.foldLeft[LogicalPlan](leafNode)((plan, cond) => Filter(cond, plan))
     } else {
       leafNode
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 27daa899583e..24ffe4b887d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable
 import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.optimizer.CollapseProject
-import org.apache.spark.sql.catalyst.planning.PhysicalOperation
+import org.apache.spark.sql.catalyst.planning.{PhysicalOperation, ScanOperation}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, OffsetAndLimit, Project, Sample, Sort}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder}
@@ -345,13 +345,15 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
   }
 
   def pruneColumns(plan: LogicalPlan): LogicalPlan = plan.transform {
-    case PhysicalOperation(project, filters, sHolder: ScanBuilderHolder) =>
+    case ScanOperation(project, filtersStayUp, filtersPushDown, sHolder: ScanBuilderHolder) =>
       // column pruning
       val normalizedProjects = DataSourceStrategy
         .normalizeExprs(project, sHolder.output)
         .asInstanceOf[Seq[NamedExpression]]
+      val allFilters = filtersStayUp ++ filtersPushDown.reduceOption(And)
+      val normalizedFilters = DataSourceStrategy.normalizeExprs(allFilters, sHolder.output)
       val (scan, output) = PushDownUtils.pruneColumns(
-        sHolder.builder, sHolder.relation, normalizedProjects, filters)
+        sHolder.builder, sHolder.relation, normalizedProjects, normalizedFilters)
 
       logInfo(
         s"""
@@ -368,11 +370,12 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
         case projectionOverSchema(newExpr) => newExpr
       }
 
-      val filterCondition = filters.reduceLeftOption(And)
-      val newFilterCondition = filterCondition.map(projectionFunc)
-      val withFilter = newFilterCondition.map(Filter(_, scanRelation)).getOrElse(scanRelation)
+      val finalFilters = normalizedFilters.map(projectionFunc)
+      val withFilter = finalFilters.foldRight[LogicalPlan](scanRelation)((cond, plan) => {
+        Filter(cond, plan)
+      })
 
-      val withProjection = if (withFilter.output != project) {
+      if (withFilter.output != project) {
         val newProjects = normalizedProjects
           .map(projectionFunc)
           .asInstanceOf[Seq[NamedExpression]]
@@ -380,12 +383,11 @@
       } else {
         withFilter
       }
-      withProjection
   }
 
   def pushDownSample(plan: LogicalPlan): LogicalPlan = plan.transform {
     case sample: Sample => sample.child match {
-      case PhysicalOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty =>
+      case PhysicalOperation(_, Nil, sHolder: ScanBuilderHolder) =>
        val tableSample = TableSampleInfo(
          sample.lowerBound,
          sample.upperBound,
@@ -404,7 +406,7 @@
   }
 
   private def pushDownLimit(plan: LogicalPlan, limit: Int): (LogicalPlan, Boolean) = plan match {
-    case operation @ PhysicalOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty =>
+    case operation @ PhysicalOperation(_, Nil, sHolder: ScanBuilderHolder) =>
       val (isPushed, isPartiallyPushed) = PushDownUtils.pushLimit(sHolder.builder, limit)
       if (isPushed) {
         sHolder.pushedLimit = Some(limit)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index f0d8e5ddaf70..1f6a6bfda66c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -1104,4 +1104,11 @@ abstract class SchemaPruningSuite
     checkAnswer(query2.orderBy("id"), Row("John", "Y."))
   }
+
+  testSchemaPruning("SPARK-41017: column pruning through 2 filters") {
+    import testImplicits._
+    val query = spark.table("contacts").filter(rand() > 0.5).filter(rand() < 0.8)
+      .select($"id", $"name.first")
+    checkScan(query, "struct<id:int,name:struct<first:string>>")
+  }
 }
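
For reference, below is a minimal standalone sketch (not part of the patch) of the query shape this change targets, mirroring the new test above: two non-deterministic filters stacked on top of a file scan, where the scan should still read only the selected columns. The session setup, table path, and column names are illustrative assumptions.

// Minimal sketch, assuming a local Spark session and a throwaway Parquet path.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.rand

object TwoFilterPruningSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    import spark.implicits._

    // Hypothetical table with more columns than the query needs.
    Seq((1, "John", "Doe", 30)).toDF("id", "first", "last", "age")
      .write.mode("overwrite").parquet("/tmp/two_filter_pruning_sketch")

    val query = spark.read.parquet("/tmp/two_filter_pruning_sketch")
      .filter(rand() > 0.5)   // non-deterministic: must stay above the scan
      .filter(rand() < 0.8)   // non-deterministic: cannot be merged with the filter above
      .select($"id", $"first")

    // With ScanOperation-based pruning, the FileScan's ReadSchema should list only `id` and
    // `first`, even though the two Filter nodes cannot be combined into one.
    query.explain(true)
    spark.stop()
  }
}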