@@ -1420,7 +1420,7 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
conditionOpt: Option[Expression]): ExpressionSet = {
val baseConstraints = left.constraints.union(right.constraints)
.union(ExpressionSet(conditionOpt.map(splitConjunctivePredicates).getOrElse(Nil)))
baseConstraints.union(inferAdditionalConstraints(baseConstraints))
inferConstraints(baseConstraints)
}

private def inferNewFilter(plan: LogicalPlan, constraints: ExpressionSet): LogicalPlan = {
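The hunk above reflects the new contract: `inferConstraints` returns the original constraints together with the inferred ones, so call sites no longer union in the result of `inferAdditionalConstraints` themselves. A minimal sketch of that contract change, using plain string sets in place of Catalyst's `ExpressionSet`; the names mirror the PR, but nothing below is a Spark API.

```scala
object ContractSketch {
  // Stand-in for Catalyst's ExpressionSet; real constraints are expression trees.
  type Constraint = String

  // Old shape: returns only the newly inferred constraints, so call sites had
  // to union them back in themselves.
  def inferAdditionalConstraints(cs: Set[Constraint]): Set[Constraint] =
    if (cs("a = b") && cs("a = 5")) Set("b = 5") else Set.empty

  // New shape: returns the original constraints together with the inferred
  // ones, so the explicit union at the call site disappears.
  def inferConstraints(cs: Set[Constraint]): Set[Constraint] =
    cs ++ inferAdditionalConstraints(cs)

  def main(args: Array[String]): Unit = {
    val base = Set("a = b", "a = 5")
    assert(base.union(inferAdditionalConstraints(base)) == inferConstraints(base))
  }
}
```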
@@ -479,7 +479,7 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper {
* 4) Unwrap '=', '<=>' if one side is a boolean literal
*/
object SimplifyBinaryComparison
extends Rule[LogicalPlan] with PredicateHelper with ConstraintHelper {
extends Rule[LogicalPlan] with PredicateHelper {

private def canSimplifyComparison(
left: Expression,
@@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.plans.logical

import scala.collection.mutable

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
@@ -25,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.LogicalPlanSt
import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TreeNodeTag, UnaryLike}
import org.apache.spark.sql.catalyst.util.MetadataColumnHelper
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}


@@ -183,26 +186,95 @@ trait LeafNode extends LogicalPlan with LeafLike[LogicalPlan] {
* A logical plan node with single child.
*/
trait UnaryNode extends LogicalPlan with UnaryLike[LogicalPlan] {
val constraintProjectionLimit = conf.getConf(SQLConf.CONSTRAINT_PROJECTION_LIMIT)

/**
* Generates all valid constraints including an set of aliased constraints by replacing the
* original constraint expressions with the corresponding alias
* Generates all valid constraints including a set of aliased constraints by replacing the
* original constraint expressions with the corresponding alias.
* This method only returns constraints whose referenced attributes are subset of `outputSet`.
*/
protected def getAllValidConstraints(projectList: Seq[NamedExpression]): ExpressionSet = {
var allConstraints = child.constraints
val newLiteralConstraints = mutable.ArrayBuffer.empty[EqualNullSafe]
val newConstraints = mutable.ArrayBuffer.empty[EqualNullSafe]
val aliasMap = mutable.Map[Expression, mutable.ArrayBuffer[Attribute]]()
projectList.foreach {
// These new literal constraints don't need any projection, nor can they project any other
// constraint.
case a @ Alias(l: Literal, _) =>
allConstraints += EqualNullSafe(a.toAttribute, l)
case a @ Alias(e, _) if e.deterministic =>
// For every alias in `projectList`, replace the reference in constraints by its attribute.
allConstraints ++= allConstraints.map(_ transform {
case expr: Expression if expr.semanticEquals(e) =>
a.toAttribute
})
allConstraints += EqualNullSafe(e, a.toAttribute)
case _ => // Don't change.
newLiteralConstraints += EqualNullSafe(a.toAttribute, l)

// We need to add simple attributes to the alias map as those attributes can be aliased as
// well. Technically, we don't need to add attributes that are not otherwise aliased to the
// map, but adding them does no harm.
case a: Attribute =>
aliasMap.getOrElseUpdate(a.canonicalized, mutable.ArrayBuffer.empty) += a

// If we have an alias in the projection then we need to:
// - add it to the alias map as it can project child's constraints
// - and add it to the new constraints and let it be projected or pruned based on other
// aliases and attributes in the project list.
// E.g. `a + b <=> x` constraint can "survive" the projection if
// - `a + b` is aliased (like `a + b AS y` and so the projected constraint is `y <=> x`)
// - or both `a` and `b` are aliased or included in the output set
case a @ Alias(child, _) if child.deterministic =>
val attr = a.toAttribute
aliasMap.getOrElseUpdate(child.canonicalized, mutable.ArrayBuffer.empty) += attr
newConstraints += EqualNullSafe(child, attr)
case _ =>
}

def projectConstraint(expr: Expression) = {
// The current constraint projection doesn't do a full-blown projection which means that when
[Inline review comment from the contributor author: I wanted to keep the current behaviour in this PR.]

// - a constraint contains an expression multiple times (E.g. `c + c > 1`)
// - and we have a projection where an expression is aliased as multiple different attributes
// (E.g. `c AS c1`, `c AS c2`)
// then we return only `c1 + c1 > 1` and `c2 + c2 > 1` but not `c1 + c2 > 1`.
val currentAlias = mutable.Map.empty[Expression, Seq[Expression]]
expr.multiTransformDown {
// Mapping with aliases
case e: Expression if aliasMap.contains(e.canonicalized) =>
// When we encounter an expression for the first time in the tree, we set up a cache to track
// the current attribute alias and return that cached attribute when we encounter the same
// expression elsewhere.
currentAlias.getOrElse(e.canonicalized, {
// If a parent expression can be transformed, return the original expression too so that
// its children can be transformed as well.
val alternatives = if (e.containsChild.nonEmpty) {
e +: aliasMap(e.canonicalized).toSeq
} else {
aliasMap(e.canonicalized).toSeq
}

// While iterating through the alternatives on the first encounter we also update the
// cache.
alternatives.toStream.map { a =>
currentAlias += e.canonicalized -> Seq(a)
a
}.append {
currentAlias -= e.canonicalized
Seq.empty
}
})
// Prune if we encounter an attribute that we can't map and it is not in the output set.
case a: Attribute if !outputSet.contains(a) => Seq.empty
}.filter {
case EqualNullSafe(a1: Attribute, a2: Attribute) => a1.canonicalized != a2.canonicalized
case _ => true
}
}

allConstraints
val projectedConstraints =
// Transform child's constraints according to alias map
child.constraints.toStream.flatMap(projectConstraint) ++
// Transform child expressions of new constraints according to alias map
newConstraints.toStream.flatMap(projectConstraint)

ExpressionSet(
constraintProjectionLimit.map(l => projectedConstraints.take(l))
.getOrElse(projectedConstraints) ++
newLiteralConstraints.toSeq)
}

override protected lazy val validConstraints: ExpressionSet = child.constraints
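To make the projection above concrete: every child constraint is rewritten through `aliasMap`, and the alternative rewrites are enumerated lazily so that `constraintProjectionLimit` can stop the enumeration early. The following self-contained toy is a sketch of that idea on a hypothetical expression tree; the `Expr` hierarchy and the `alternatives` helper are illustrative stand-ins and do not model Catalyst's `multiTransformDown` exactly (in particular, unlike the PR, which pins one alias per distinct expression within a constraint, this toy enumerates the full cross product of substitutions).

```scala
object ProjectionSketch {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Add(left: Expr, right: Expr) extends Expr
  case class Gt(left: Expr, bound: Int) extends Expr

  // Rewrite `e` through the alias map, lazily branching into one result per
  // alternative. A Stream keeps the enumeration lazy, so taking a prefix of
  // the result bounds the substitution work.
  def alternatives(e: Expr, aliasMap: Map[Expr, Seq[Expr]]): Stream[Expr] =
    aliasMap.get(e) match {
      case Some(alts) => alts.toStream
      case None => e match {
        case Add(l, r) =>
          for {
            la <- alternatives(l, aliasMap)
            ra <- alternatives(r, aliasMap)
          } yield Add(la, ra)
        case Gt(l, n) => alternatives(l, aliasMap).map(Gt(_, n))
        case other    => Stream(other)
      }
    }

  def main(args: Array[String]): Unit = {
    // Project `a + b > 1` through `a AS x, b AS y`: the surviving constraint
    // is `x + y > 1`, mirroring the `a + b <=> x` example in the comments above.
    val aliasMap = Map[Expr, Seq[Expr]](Attr("a") -> Seq(Attr("x")), Attr("b") -> Seq(Attr("y")))
    println(alternatives(Gt(Add(Attr("a"), Attr("b")), 1), aliasMap).toList)
  }
}
```

Because the result is a `Stream`, `constraintProjectionLimit.map(l => projectedConstraints.take(l))` in the real code only forces a prefix of the enumeration instead of materializing every combination first.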
@@ -18,9 +18,11 @@
package org.apache.spark.sql.catalyst.plans.logical

import scala.annotation.tailrec
import scala.collection.mutable

import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.expressions._

import org.apache.spark.sql.internal.SQLConf

trait QueryPlanConstraints extends ConstraintHelper { self: LogicalPlan =>

@@ -31,8 +33,7 @@ trait QueryPlanConstraints extends ConstraintHelper { self: LogicalPlan =>
*/
lazy val constraints: ExpressionSet = {
if (conf.constraintPropagationEnabled) {
validConstraints
.union(inferAdditionalConstraints(validConstraints))
inferConstraints(validConstraints)
.union(constructIsNotNullConstraints(validConstraints, output))
.filter { c =>
c.references.nonEmpty && c.references.subsetOf(outputSet) && c.deterministic
@@ -53,37 +54,69 @@ trait QueryPlanConstraints extends ConstraintHelper { self: LogicalPlan =>
protected lazy val validConstraints: ExpressionSet = ExpressionSet()
}

trait ConstraintHelper {
trait ConstraintHelper extends SQLConfHelper {

lazy val constraintInferenceLimit = conf.getConf(SQLConf.CONSTRAINT_INFERENCE_LIMIT)

/**
* Infers an additional set of constraints from a given set of equality constraints.
* Infers an additional set of constraints from a given set of equality constraints and returns
* them with the original constraint set.
* For example, if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
* additional constraint of the form `b = 5`.
*/
def inferAdditionalConstraints(constraints: ExpressionSet): ExpressionSet = {
var inferredConstraints = ExpressionSet()
def inferConstraints(constraints: ExpressionSet): ExpressionSet = {
// IsNotNull should be constructed by `constructIsNotNullConstraints`.
val predicates = constraints.filterNot(_.isInstanceOf[IsNotNull])
val (notNullConstraints, predicates) = constraints.partition(_.isInstanceOf[IsNotNull])

val equivalenceMap = mutable.Map.empty[Expression, mutable.ArrayBuffer[Expression]]
predicates.foreach {
case eq @ EqualTo(l: Attribute, r: Attribute) =>
val candidateConstraints = predicates - eq
inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
case eq @ EqualTo(l @ Cast(_: Attribute, _, _, _), r: Attribute) =>
inferredConstraints ++= replaceConstraints(predicates - eq, r, l)
case eq @ EqualTo(l: Attribute, r @ Cast(_: Attribute, _, _, _)) =>
inferredConstraints ++= replaceConstraints(predicates - eq, l, r)
case EqualTo(l: Attribute, r: Attribute) =>
equivalenceMap.getOrElseUpdate(l.canonicalized, mutable.ArrayBuffer.empty) += r
equivalenceMap.getOrElseUpdate(r.canonicalized, mutable.ArrayBuffer.empty) += l
case EqualTo(l @ Cast(_: Attribute, _, _, _), r: Attribute) =>
equivalenceMap.getOrElseUpdate(r.canonicalized, mutable.ArrayBuffer.empty) += l
case EqualTo(l: Attribute, r @ Cast(_: Attribute, _, _, _)) =>
equivalenceMap.getOrElseUpdate(l.canonicalized, mutable.ArrayBuffer.empty) += r
case _ => // No inference
}
inferredConstraints -- constraints
}

private def replaceConstraints(
constraints: ExpressionSet,
source: Expression,
destination: Expression): ExpressionSet = constraints.map(_ transform {
case e: Expression if e.semanticEquals(source) => destination
})
def inferConstraints(expr: Expression) = {
// The current constraint inference doesn't do a full-blown inference which means that when
// - a constraint contains an attribute multiple times (E.g. `c + c > 1`)
// - and we have multiple equivalences for that attribute (E.g. `c = a`, `c = b`)
// then we return only `a + a > 1` and `b + b > 1` besides the original constraint, but
// not `a + b > 1`.
val currentMapping = mutable.Map.empty[Expression, Seq[Expression]]
expr.multiTransformDown {
case e: Expression if equivalenceMap.contains(e.canonicalized) =>
// When we encounter an attribute for the first time in the tree, we set up a cache to track
// the current equivalence and return that cached equivalence when we encounter the same
// expression elsewhere.
currentMapping.getOrElse(e.canonicalized, {
// Always return the original expression too
val alternatives = e +: equivalenceMap(e.canonicalized).toSeq

// While iterating through the alternatives on the first encounter we also update the
// cache.
alternatives.toStream.map { a =>
currentMapping += e.canonicalized -> Seq(a)
a
}.append {
currentMapping -= e.canonicalized
Seq.empty
}
})
}.filter {
case EqualTo(e1, e2) => e1.canonicalized != e2.canonicalized
case _ => true
}
}

val inferredConstraints = predicates.toStream.flatMap(inferConstraints)

notNullConstraints ++
constraintInferenceLimit.map(l => inferredConstraints.take(l)).getOrElse(inferredConstraints)
}

/**
* Infers a set of `isNotNull` constraints from null intolerant expressions as well as
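To make the inference above concrete: attribute equalities are first collected into `equivalenceMap`, then each remaining predicate is rewritten through that map lazily, and `constraintInferenceLimit` takes a prefix of the resulting stream. The following self-contained toy sketches the same idea; the `Expr` hierarchy and the `infer` helper are hypothetical stand-ins for Catalyst's expression trees and `multiTransformDown`.

```scala
object InferenceSketch {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Lit(value: Int) extends Expr
  case class EqualTo(left: Expr, right: Expr) extends Expr

  def infer(constraints: Seq[EqualTo], limit: Option[Int]): Seq[EqualTo] = {
    // Collect attribute equivalences in both directions: `a = b` yields a -> b and b -> a.
    val equiv: Map[Expr, Seq[Expr]] = constraints
      .collect { case EqualTo(l: Attr, r: Attr) => Seq[(Expr, Expr)](l -> r, r -> l) }
      .flatten
      .groupBy(_._1)
      .map { case (k, vs) => k -> vs.map(_._2) }

    // Rewrite one constraint into all of its variants; keeping the original
    // expression as the first alternative means the input constraint itself
    // is always part of the output.
    def substitutions(c: EqualTo): Stream[EqualTo] = {
      def subs(e: Expr): Stream[Expr] = equiv.get(e) match {
        case Some(alts) => e +: alts.toStream
        case None       => Stream(e)
      }
      for { l <- subs(c.left); r <- subs(c.right) } yield EqualTo(l, r)
    }

    val inferred = constraints.toStream
      .flatMap(substitutions)
      .filter(c => c.left != c.right) // drop trivial variants such as `b = b`
    limit.map(l => inferred.take(l)).getOrElse(inferred).distinct.toList
  }

  def main(args: Array[String]): Unit = {
    // From `a = 5` and `a = b` we also obtain `b = 5`, matching the scaladoc example above.
    println(infer(Seq(EqualTo(Attr("a"), Lit(5)), EqualTo(Attr("a"), Attr("b"))), limit = None))
  }
}
```

Compared to the removed `replaceConstraints` approach, the rewrites here are produced on demand rather than eagerly materialized per equality, which is what lets a small inference limit bound the work even for large equivalence classes.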
@@ -882,6 +882,28 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val CONSTRAINT_PROJECTION_LIMIT =
buildConf("spark.sql.constraintPropagation.projectionLimit")
.doc("If defined then the maximum number of original and projected constraints during " +
"constraint propagation.")
.internal()
.version("3.5.0")
.intConf
.checkValue(_ >= 0,
"The value of spark.sql.constraintPropagation.projectionLimit must not be negative.")
.createOptional

val CONSTRAINT_INFERENCE_LIMIT =
buildConf("spark.sql.constraintPropagation.inferenceLimit")
.doc("If defined then the maximum number of inferred constraints during constraint " +
"propagation.")
.internal()
.version("3.5.0")
.intConf
.checkValue(_ >= 0,
"The value of spark.sql.constraintPropagation.inferenceLimit must not be negative.")
.createOptional

val PROPAGATE_DISTINCT_KEYS_ENABLED =
buildConf("spark.sql.optimizer.propagateDistinctKeys.enabled")
.internal()
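Both configs are internal and optional, so leaving them unset preserves the previous unlimited behaviour. A minimal usage sketch, assuming a Spark 3.5+ `SparkSession` named `spark` (the values are arbitrary examples, not recommendations):

```scala
// Cap the number of projected and inferred constraints kept during
// constraint propagation.
spark.conf.set("spark.sql.constraintPropagation.projectionLimit", 100)
spark.conf.set("spark.sql.constraintPropagation.inferenceLimit", 100)
```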