@@ -29,26 +29,13 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation}
import org.apache.spark.sql.internal.SQLConf

/**
* A pattern that matches any number of project or filter operations even if they are
* non-deterministic, as long as they satisfy the requirement of CollapseProject and CombineFilters.
* All filter operators are collected and their conditions are broken up and returned
* together with the top project operator. [[Alias Aliases]] are in-lined/substituted if
* necessary.
*/
object PhysicalOperation extends AliasHelper with PredicateHelper {
trait OperationHelper extends AliasHelper with PredicateHelper {
import org.apache.spark.sql.catalyst.optimizer.CollapseProject.canCollapseExpressions

type ReturnType =
(Seq[NamedExpression], Seq[Expression], LogicalPlan)
type IntermediateType =
(Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, AttributeMap[Alias])

def unapply(plan: LogicalPlan): Option[ReturnType] = {
val alwaysInline = SQLConf.get.getConf(SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE)
val (fields, filters, child, _) = collectProjectsAndFilters(plan, alwaysInline)
Some((fields.getOrElse(child.output), filters, child))
}
protected def collectAllFilters: Boolean

/**
* Collects all adjacent projects and filters, in-lining/substituting aliases if necessary.
@@ -64,7 +51,7 @@ object PhysicalOperation extends AliasHelper with PredicateHelper {
* SELECT key AS c2 FROM t1 WHERE key > 10
* }}}
*/
private def collectProjectsAndFilters(
protected def collectProjectsAndFilters(
plan: LogicalPlan,
alwaysInline: Boolean): IntermediateType = {
def empty: IntermediateType = (None, Nil, plan, AttributeMap.empty)
@@ -84,16 +71,21 @@ object PhysicalOperation extends AliasHelper with PredicateHelper {
// When collecting projects and filters, we effectively push down filters through
// projects. We need to meet the following conditions to do so:
// 1) no Project collected so far or the collected Projects are all deterministic
// 2) the collected filters and this filter are all deterministic, or this is the
// first collected filter.
// 3) this filter does not repeat any expensive expressions from the collected
// 2) this filter does not repeat any expensive expressions from the collected
// projects.
val canIncludeThisFilter = fields.forall(_.forall(_.deterministic)) && {
filters.isEmpty || (filters.forall(_.deterministic) && condition.deterministic)
} && canCollapseExpressions(Seq(condition), aliases, alwaysInline)
if (canIncludeThisFilter) {
val replaced = replaceAlias(condition, aliases)
(fields, filters ++ splitConjunctivePredicates(replaced), other, aliases)
val canPushFilterThroughProject = fields.forall(_.forall(_.deterministic)) &&
canCollapseExpressions(Seq(condition), aliases, alwaysInline)
if (canPushFilterThroughProject) {
// Normally we can't combine non-deterministic filters, but if `collectAllFilters` is true,
// we relax this restriction and assume the caller will take care of it.
val canIncludeThisFilter = filters.isEmpty || {
filters.last.deterministic && condition.deterministic
}
if (canIncludeThisFilter || collectAllFilters) {
(fields, filters :+ replaceAlias(condition, aliases), other, aliases)
} else {
empty
}
} else {
empty
}
@@ -105,6 +97,54 @@ object PhysicalOperation extends AliasHelper with PredicateHelper {
}
}
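The inclusion check above can be summarized in isolation. Below is a minimal standalone sketch of just that rule, ignoring the project-collapsing conditions; the `Cond` class and the SQL strings are made up for illustration and are not part of the patch:

```scala
object CollectRuleSketch extends App {
  final case class Cond(sql: String, deterministic: Boolean)

  // Mirrors `canIncludeThisFilter || collectAllFilters` from the code above.
  def canCollect(collected: Seq[Cond], next: Cond, collectAllFilters: Boolean): Boolean =
    collected.isEmpty || (collected.last.deterministic && next.deterministic) || collectAllFilters

  val rand1 = Cond("rand() > 0.5", deterministic = false)
  val rand2 = Cond("rand() < 0.8", deterministic = false)

  println(canCollect(Nil, rand2, collectAllFilters = false))        // true: the first filter may be non-deterministic
  println(canCollect(Seq(rand2), rand1, collectAllFilters = false)) // false: PhysicalOperation resets and gives up here
  println(canCollect(Seq(rand2), rand1, collectAllFilters = true))  // true: ScanOperation keeps collecting
}
```

In other words, with `collectAllFilters = false` a non-deterministic filter can only ever be the first (bottom-most) collected filter, while with `collectAllFilters = true` every adjacent filter that can be pushed through the collected projects is kept and the caller decides what to do with it.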

/**
* A pattern that matches any number of project or filter operations even if they are
* non-deterministic, as long as they satisfy the requirement of CollapseProject and CombineFilters.
* All filter operators are collected and their conditions are broken up and returned
* together with the top project operator. [[Alias Aliases]] are in-lined/substituted if
* necessary.
*/
object PhysicalOperation extends OperationHelper {
// Returns: (the final project list, filters to push down, relation)
type ReturnType = (Seq[NamedExpression], Seq[Expression], LogicalPlan)
override protected def collectAllFilters: Boolean = false

def unapply(plan: LogicalPlan): Option[ReturnType] = {
val alwaysInline = SQLConf.get.getConf(SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE)
val (fields, filters, child, _) = collectProjectsAndFilters(plan, alwaysInline)
// If more than one filter is collected, they must all be deterministic.
if (filters.length > 1) assert(filters.forall(_.deterministic))
Some((
fields.getOrElse(child.output),
filters.flatMap(splitConjunctivePredicates),
child))
}
}

/**
* A variant of [[PhysicalOperation]] which can match multiple Filters that are not combinable due
* to non-deterministic predicates. This is useful for scan operations, as we need to match a chain
* of adjacent Projects/Filters to apply column pruning even if the Filters can't be combined. For
* example, given `Project(a, Filter(rand() > 0.5, Filter(rand() < 0.8, TableScan)))`, we should
* still read only column `a` from the relation.
*/
object ScanOperation extends OperationHelper {
// Returns: (the final project list, filters to stay up, filters to push down, relation)
type ReturnType = (Seq[NamedExpression], Seq[Expression], Seq[Expression], LogicalPlan)
override protected def collectAllFilters: Boolean = true

def unapply(plan: LogicalPlan): Option[ReturnType] = {
val alwaysInline = SQLConf.get.getConf(SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE)
val (fields, filters, child, _) = collectProjectsAndFilters(plan, alwaysInline)
// `collectProjectsAndFilters` collects filters bottom-up, so the bottom-most filters are
// placed at the beginning of the `filters` list. According to SQL semantics, we can only
// push down the deterministic filters that sit below every non-deterministic filter.
val filtersCanPushDown = filters.takeWhile(_.deterministic).flatMap(splitConjunctivePredicates)
val filtersStayUp = filters.dropWhile(_.deterministic)
Some((fields.getOrElse(child.output), filtersStayUp, filtersCanPushDown, child))
}
}

object NodeWithOnlyDeterministicProjectAndFilter {
def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match {
case Project(projectList, child) if projectList.forall(_.deterministic) => unapply(child)
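For context, here is a hedged, self-contained sketch of the query shape the new extractor targets; the local parquet path, session setup, and the data are illustrative assumptions, not part of this patch:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.rand

object TwoFilterPruningDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
    import spark.implicits._

    // A table with a nested struct column, written to a throwaway location.
    Seq((1, ("John", "Doe"), "x"), (2, ("Jane", "Roe"), "y"))
      .toDF("id", "name", "payload")
      .write.mode("overwrite").parquet("/tmp/two_filter_demo")

    // Project(id, name._1, Filter(rand() > 0.5, Filter(rand() < 0.8, relation))):
    // the two rand() predicates are non-deterministic, so CombineFilters can't merge them.
    val query = spark.read.parquet("/tmp/two_filter_demo")
      .filter(rand() < 0.8)
      .filter(rand() > 0.5)
      .select($"id", $"name._1")

    // With ScanOperation, schema pruning and the file scan should only need `id` and
    // `name._1`; both rand() filters remain as separate Filter nodes above the scan.
    query.explain(true)

    spark.stop()
  }
}
```

The plan printed by `explain` should show the two `rand()` predicates as separate Filter nodes above a parquet scan whose read schema contains only the projected fields.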
@@ -24,7 +24,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{FileSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.FileFormat.METADATA_NAME
@@ -146,7 +146,7 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
}

def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case PhysicalOperation(projects, filters,
case ScanOperation(projects, stayUpFilters, filters,
l @ LogicalRelation(fsRelation: HadoopFsRelation, _, table, _)) =>
// Filters on this relation fall into four categories based on where we can use them to avoid
// reading unneeded data:
@@ -204,7 +204,7 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
val afterScanFilters = filterSet -- partitionKeyFilters.filter(_.references.nonEmpty)
logInfo(s"Post-Scan Filters: ${afterScanFilters.mkString(",")}")

val filterAttributes = AttributeSet(afterScanFilters)
val filterAttributes = AttributeSet(afterScanFilters ++ stayUpFilters)
val requiredExpressions: Seq[NamedExpression] = filterAttributes.toSeq ++ projects
val requiredAttributes = AttributeSet(requiredExpressions)

@@ -222,8 +222,8 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
metadataColumns.filter(_.name != FileFormat.ROW_INDEX)

val readDataColumns = dataColumns
.filter(requiredAttributes.contains)
.filterNot(partitionColumns.contains)
.filter(requiredAttributes.contains)
.filterNot(partitionColumns.contains)

val fileFormatReaderGeneratedMetadataColumns: Seq[Attribute] =
metadataColumns.map(_.name).flatMap {
@@ -281,10 +281,11 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
readDataColumns ++ partitionColumns :+ metadataAlias, scan)
}.getOrElse(scan)

val afterScanFilter = afterScanFilters.toSeq.reduceOption(expressions.And)
val withFilter = afterScanFilter
.map(execution.FilterExec(_, withMetadataProjections))
.getOrElse(withMetadataProjections)
// Bottom-most filters are at the front of the list.
val finalFilters = afterScanFilters.toSeq.reduceOption(expressions.And).toSeq ++ stayUpFilters
val withFilter = finalFilters.foldLeft(withMetadataProjections)((plan, cond) => {
execution.FilterExec(cond, plan)
})
val withProjections = if (projects == withFilter.output) {
withFilter
} else {
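The comment above relies on the list order. A quick standalone sketch (toy classes, not Spark API) of why folding left keeps the first, bottom-most condition closest to the scan:

```scala
// Toy plan nodes, for illustration only.
sealed trait Plan
case class Scan(table: String) extends Plan
case class FilterNode(cond: String, child: Plan) extends Plan

object FoldLeftOrderDemo extends App {
  // Mirrors `finalFilters`: the combined after-scan condition first, then the
  // stay-up filters in bottom-to-top order.
  val finalFilters = Seq("afterScanCond", "stayUp1", "stayUp2")

  val plan = finalFilters.foldLeft[Plan](Scan("t"))((p, cond) => FilterNode(cond, p))

  // FilterNode(stayUp2, FilterNode(stayUp1, FilterNode(afterScanCond, Scan(t))))
  println(plan)
}
```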
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
@@ -27,7 +27,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
import org.apache.spark.sql.util.SchemaUtils._

/**
* Prunes unnecessary physical columns given a [[PhysicalOperation]] over a data source relation.
* Prunes unnecessary physical columns given a [[ScanOperation]] over a data source relation.
* By "physical column", we mean a column as defined in the data source format like Parquet format
* or ORC format. For example, in Spark SQL, a root-level Parquet column corresponds to a SQL
* column, and a nested Parquet column corresponds to a [[StructField]].
@@ -39,9 +39,10 @@ object SchemaPruning extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan =
plan transformDown {
case op @ PhysicalOperation(projects, filters,
case op @ ScanOperation(projects, filtersStayUp, filtersPushDown,
l @ LogicalRelation(hadoopFsRelation: HadoopFsRelation, _, _, _)) =>
prunePhysicalColumns(l, projects, filters, hadoopFsRelation,
val allFilters = filtersPushDown.reduceOption(And).toSeq ++ filtersStayUp
prunePhysicalColumns(l, projects, allFilters, hadoopFsRelation,
(prunedDataSchema, prunedMetadataSchema) => {
val prunedHadoopRelation =
hadoopFsRelation.copy(dataSchema = prunedDataSchema)(hadoopFsRelation.sparkSession)
@@ -61,9 +62,10 @@ object SchemaPruning extends Rule[LogicalPlan] {
filters: Seq[Expression],
hadoopFsRelation: HadoopFsRelation,
leafNodeBuilder: (StructType, StructType) => LeafNode): Option[LogicalPlan] = {

val (normalizedProjects, normalizedFilters) =
normalizeAttributeRefNames(relation.output, projects, filters)
val attrNameMap = relation.output.map(att => (att.exprId, att.name)).toMap
val normalizedProjects = normalizeAttributeRefNames(attrNameMap, projects)
.asInstanceOf[Seq[NamedExpression]]
val normalizedFilters = normalizeAttributeRefNames(attrNameMap, filters)
val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters)

// If requestedRootFields includes a nested field, continue. Otherwise,
@@ -112,24 +114,17 @@ object SchemaPruning extends Rule[LogicalPlan] {
fsRelation.fileFormat.isInstanceOf[OrcFileFormat])

/**
* Normalizes the names of the attribute references in the given projects and filters to reflect
* Normalizes the names of the attribute references in the given expressions to reflect
* the names in the given logical relation. This makes it possible to compare attributes and
* fields by name. Returns the expressions with normalized attribute names.
*/
private def normalizeAttributeRefNames(
output: Seq[AttributeReference],
projects: Seq[NamedExpression],
filters: Seq[Expression]): (Seq[NamedExpression], Seq[Expression]) = {
val normalizedAttNameMap = output.map(att => (att.exprId, att.name)).toMap
val normalizedProjects = projects.map(_.transform {
case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) =>
att.withName(normalizedAttNameMap(att.exprId))
}).map { case expr: NamedExpression => expr }
val normalizedFilters = filters.map(_.transform {
case att: AttributeReference if normalizedAttNameMap.contains(att.exprId) =>
att.withName(normalizedAttNameMap(att.exprId))
attrNameMap: Map[ExprId, String],
exprs: Seq[Expression]): Seq[Expression] = {
exprs.map(_.transform {
case att: AttributeReference if attrNameMap.contains(att.exprId) =>
att.withName(attrNameMap(att.exprId))
})
(normalizedProjects, normalizedFilters)
}

/**
@@ -148,8 +143,8 @@ object SchemaPruning extends Rule[LogicalPlan] {
val projectedFilters = filters.map(_.transformDown {
case projectionOverSchema(expr) => expr
})
val newFilterCondition = projectedFilters.reduce(And)
Filter(newFilterCondition, leafNode)
// Bottom-most filters are at the front of the list.
projectedFilters.foldLeft[LogicalPlan](leafNode)((plan, cond) => Filter(cond, plan))
} else {
leafNode
}
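The refactored `normalizeAttributeRefNames` now takes a prebuilt `ExprId`-to-name map and works on arbitrary expressions. A small hedged sketch of the renaming idea (the attribute names are invented):

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
import org.apache.spark.sql.types.IntegerType

object NormalizeNamesSketch extends App {
  // The query referenced the column in upper case...
  val userAttr = AttributeReference("ID", IntegerType)()
  // ...but the relation defines it in lower case; the map is keyed by exprId, not by name.
  val attrNameMap = Map(userAttr.exprId -> "id")

  val filter = EqualTo(userAttr, Literal(1))
  val normalized = filter.transform {
    case a: AttributeReference if attrNameMap.contains(a.exprId) =>
      a.withName(attrNameMap(a.exprId))
  }
  // The attribute is renamed to "id" while keeping its exprId, so it can be
  // compared against schema fields by name.
  println(normalized)
}
```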
@@ -22,7 +22,7 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, And, Attribute, AttributeMap, AttributeReference, AttributeSet, Cast, Expression, IntegerLiteral, Literal, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.optimizer.CollapseProject
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.planning.{PhysicalOperation, ScanOperation}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, OffsetAndLimit, Project, Sample, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder}
@@ -345,13 +345,15 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}

def pruneColumns(plan: LogicalPlan): LogicalPlan = plan.transform {
case PhysicalOperation(project, filters, sHolder: ScanBuilderHolder) =>
case ScanOperation(project, filtersStayUp, filtersPushDown, sHolder: ScanBuilderHolder) =>
// column pruning
val normalizedProjects = DataSourceStrategy
.normalizeExprs(project, sHolder.output)
.asInstanceOf[Seq[NamedExpression]]
val allFilters = filtersStayUp ++ filtersPushDown.reduceOption(And)
val normalizedFilters = DataSourceStrategy.normalizeExprs(allFilters, sHolder.output)
val (scan, output) = PushDownUtils.pruneColumns(
sHolder.builder, sHolder.relation, normalizedProjects, filters)
sHolder.builder, sHolder.relation, normalizedProjects, normalizedFilters)

logInfo(
s"""
@@ -368,24 +370,24 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
case projectionOverSchema(newExpr) => newExpr
}

val filterCondition = filters.reduceLeftOption(And)
val newFilterCondition = filterCondition.map(projectionFunc)
val withFilter = newFilterCondition.map(Filter(_, scanRelation)).getOrElse(scanRelation)
val finalFilters = normalizedFilters.map(projectionFunc)
val withFilter = finalFilters.foldRight[LogicalPlan](scanRelation)((cond, plan) => {
Filter(cond, plan)
})

val withProjection = if (withFilter.output != project) {
if (withFilter.output != project) {
val newProjects = normalizedProjects
.map(projectionFunc)
.asInstanceOf[Seq[NamedExpression]]
Project(restoreOriginalOutputNames(newProjects, project.map(_.name)), withFilter)
} else {
withFilter
}
withProjection
}

def pushDownSample(plan: LogicalPlan): LogicalPlan = plan.transform {
case sample: Sample => sample.child match {
case PhysicalOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty =>
case PhysicalOperation(_, Nil, sHolder: ScanBuilderHolder) =>
val tableSample = TableSampleInfo(
sample.lowerBound,
sample.upperBound,
@@ -404,7 +406,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}

private def pushDownLimit(plan: LogicalPlan, limit: Int): (LogicalPlan, Boolean) = plan match {
case operation @ PhysicalOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty =>
case operation @ PhysicalOperation(_, Nil, sHolder: ScanBuilderHolder) =>
val (isPushed, isPartiallyPushed) = PushDownUtils.pushLimit(sHolder.builder, limit)
if (isPushed) {
sHolder.pushedLimit = Some(limit)
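Since `pruneColumns` now sees both groups of filters, the rule re-stacks them above the pruned scan relation afterwards. A hedged sketch of how the stay-up and pushed-down filters are combined and why `foldRight` keeps the stay-up filters on top (attribute names and predicates are invented):

```scala
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, Expression, GreaterThan, LessThan, Literal, Rand}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.IntegerType

object RestackFiltersSketch extends App {
  val a = AttributeReference("a", IntegerType)()
  val relation = LocalRelation(a)

  // Filters that must stay above the scan (non-deterministic) and filters that
  // were eligible for push-down.
  val filtersStayUp: Seq[Expression] = Seq(GreaterThan(Rand(Literal(0L)), Literal(0.5)))
  val filtersPushDown: Seq[Expression] = Seq(GreaterThan(a, Literal(1)), LessThan(a, Literal(10)))

  // Pushed-down predicates are AND-ed into one condition; stay-up filters come first.
  val allFilters = filtersStayUp ++ filtersPushDown.reduceOption(And)

  // foldRight stacks the first element outermost, so the non-deterministic filter
  // ends up above the combined deterministic condition, preserving SQL semantics.
  val restacked = allFilters.foldRight[LogicalPlan](relation)((cond, plan) => Filter(cond, plan))
  println(restacked.treeString)
}
```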
@@ -1104,4 +1104,11 @@ abstract class SchemaPruningSuite
checkAnswer(query2.orderBy("id"),
Row("John", "Y."))
}

testSchemaPruning("SPARK-41017: column pruning through 2 filters") {
import testImplicits._
val query = spark.table("contacts").filter(rand() > 0.5).filter(rand() < 0.8)
.select($"id", $"name.first")
checkScan(query, "struct<id:int, name:struct<first:string>>")
}
}