Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,44 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
* An [[ExpressionSet]] that contains invariants about the rows output by this operator. For
* example, if this set contains the expression `a = 2` then that expression is guaranteed to
* evaluate to `true` for all rows produced.
*
* Note: This set does not contain every possible invariant, because it does not expand
* constraints over aliased attributes. It holds the effective expressions that are useful for
* data filtering: if any one invariant in this set is false, then the conjunction of the
* complete invariants in [[completeConstraints]] of this node is guaranteed to be false too.
*/
lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints))
lazy val constraints: ExpressionSet = {
  // Aliased source expressions whose references are no longer all part of this plan's output.
  val droppedKeys = aliasedExpressionsInConstraints.keys.filterNot { key =>
    key.references.subsetOf(outputSet)
  }

  // For every dropped expression, choose the alias attribute with the smallest expression id as
  // its stand-in and add a rewritten copy of each constraint collected so far. When nothing was
  // dropped, the fold simply yields `validConstraints`.
  val rewritten = droppedKeys.foldLeft(validConstraints) { (collected, droppedKey) =>
    val replacement = aliasedExpressionsInConstraints(droppedKey).toSeq.sortBy(_.exprId.id).head
    collected ++ collected.map { constraint =>
      constraint.transform {
        case e if e.semanticEquals(droppedKey) => replacement
      }
    }
  }

  ExpressionSet(getRelevantConstraints(rewritten))
}

/**
 * An [[ExpressionSet]] of invariants about the rows output by this operator. Unlike
 * `constraints`, `completeConstraints` also includes the constraints obtained by rewriting the
 * valid constraints in terms of their aliases.
 */
lazy val completeConstraints: ExpressionSet = {
  val direct = validConstraints
  val aliased = getAliasedValidConstraints(validConstraints)
  ExpressionSet(getRelevantConstraints(direct.union(aliased)))
}

/**
* This method can be overridden by any child class of QueryPlan to specify a set of constraints
Expand All @@ -196,6 +232,96 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
*/
protected def validConstraints: Set[Expression] = Set.empty

/**
* Maps each aliased expression that appears in constraints to the attributes it is aliased to.
* A key is the expression being aliased; its value is the set of [[Attribute]]s that alias it.
* A constraint that refers to a key can be turned into further valid constraints by replacing
* the key with each of its values. For example, given the constraint "a > b" and an entry
* "a" -> ["c", "d"], the constraints "c > b" and "d > b" are valid as well.
*
* The map combines the entries propagated from every child with the aliases introduced by this
* plan (`aliasedConstraintExprs`), then keeps only the alias attributes that are in this plan's
* `outputSet`; entries left without any attribute are dropped.
*/
lazy val aliasedExpressionsInConstraints: Map[Expression, AttributeSet] = {
// Aliases introduced by this plan. Entries that get merged into a child entry below are removed
// from this var, so only the unmerged ones are appended at the end.
var aliasedMap = aliasedConstraintExprs
val combinedMap = children.map { child =>
// Find the aliased expressions whose references are not all in the child's `outputSet`.
// For example, if the child has the aliased map Map("a" -> Set("c", "d")) but does not
// output "a", the entry is replaced with Map("c" -> Set("d")).
val obsoleteKeys = child.aliasedExpressionsInConstraints.keys.filterNot { key =>
key.references.subsetOf(child.outputSet)
}
obsoleteKeys.flatMap { obsoleteKey =>
// All attributes the obsolete expression is aliased to, ordered by expression id.
val attrs = child.aliasedExpressionsInConstraints(obsoleteKey).toSeq.sortWith { (x, y) =>
x.exprId.id < y.exprId.id
}
if (attrs.length > 1) {
// Take the first attribute as the new aliased expression for the other attributes.
Some(attrs.head -> AttributeSet(attrs.tail))
} else {
// Only one attribute remains; it will be used in [[constraints]] to replace the obsolete key.
None
}
}.toMap ++ child.aliasedExpressionsInConstraints.filter {
// Keep the child's entries whose key is still covered by the child's output.
case (keyExpr, attrs) => keyExpr.references.subsetOf(child.outputSet)
}
}.flatten.toMap.map { case (keyExpr, attrs) =>
// Merge this plan's own aliases for the same key into the entry coming from the children.
if (aliasedMap.contains(keyExpr)) {
val addedAttrs = aliasedMap(keyExpr)
aliasedMap = aliasedMap - keyExpr
keyExpr -> (attrs ++ addedAttrs)
} else {
keyExpr -> attrs
}
} ++ aliasedMap

// Restrict every entry to the attributes this plan outputs and drop entries left empty.
combinedMap.flatMap {
case (keyExpr, attrs) =>
val newAtts = attrs.intersect(outputSet)
if (newAtts.isEmpty) {
None
} else {
Some(keyExpr -> newAtts)
}
}
}

/**
* Maps each aliased expression to the attributes newly introduced for it by the current
* QueryPlan. A child class of QueryPlan should override this to specify the alias relations in
* its output. For example, if an output "a" of a child plan of this plan is aliased to "c" and
* "d" in this plan, it should record the map as Map("a" -> Set("c", "d")).
* Defaults to an empty map.
*/
protected lazy val aliasedConstraintExprs: Map[Expression, AttributeSet] = Map.empty

/**
 * Produces the extra constraints implied by aliasing: every retained constraint is rewritten
 * with each alias of the expressions it mentions, and an `EqualNullSafe` constraint is added for
 * each (expression, alias) pair. Constraints already present in `constraints` are excluded from
 * the result.
 */
private def getAliasedValidConstraints(constraints: Set[Expression]): Set[Expression] = {
  // Only constraints whose references are all either output attributes or referenced by an
  // aliased expression are worth rewriting. For example, for a constraint 'a > b' where 'a' is
  // aliased to 'c', the aliased constraint 'c > b' is only produced if 'b' is in the output.
  val relevantRefs = AttributeSet(
    aliasedExpressionsInConstraints.keys.flatMap(_.references) ++ outputSet)
  val seed: Set[Expression] = constraints.filter(_.references.subsetOf(relevantRefs))

  // Walk every (aliased expression, alias attribute) pair, growing the accumulated set each time.
  val expanded = aliasedExpressionsInConstraints.foldLeft(seed) {
    case (outer, (aliasedExpr, attrs)) =>
      attrs.foldLeft(outer) { (acc, attr) =>
        val substituted = acc.map(_ transform {
          case expr: Expression if expr.semanticEquals(aliasedExpr) =>
            attr
        })
        (acc ++ substituted) + EqualNullSafe(aliasedExpr, attr)
      }
  }

  expanded -- constraints
}

/**
* Returns the set of attributes that are output by this node.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,26 +309,6 @@ abstract class UnaryNode extends LogicalPlan {

// A unary node's children are exactly its single `child`.
override final def children: Seq[LogicalPlan] = child :: Nil

/**
 * Derives extra constraints for the aliases in `projectList`: each constraint gathered so far is
 * rewritten with the alias attribute in place of the aliased expression, and an `EqualNullSafe`
 * constraint linking the expression to its alias is added. The child's own constraints are
 * excluded from the result.
 */
protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = {
  val inherited = child.constraints.asInstanceOf[Set[Expression]]
  val expanded = projectList.foldLeft(inherited) {
    case (acc, a @ Alias(e, _)) =>
      // Replace references to the aliased expression with the alias attribute.
      val substituted = acc.map(_ transform {
        case expr: Expression if expr.semanticEquals(e) =>
          a.toAttribute
      })
      (acc ++ substituted) + EqualNullSafe(e, a.toAttribute)
    case (acc, _) =>
      // Not an alias: nothing to add.
      acc
  }

  expanded -- child.constraints
}

// By default, the valid constraints are exactly the child's constraints.
override protected def validConstraints: Set[Expression] = child.constraints

override def computeStats(conf: CatalystConf): Statistics = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,14 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
!expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions
}

// The child's constraints plus the constraints derived from the aliases in `projectList`.
override def validConstraints: Set[Expression] =
child.constraints.union(getAliasedConstraints(projectList))
// Groups the aliases in the project list by the expression they alias, mapping each such
// expression to the set of attributes of its aliases.
override lazy val aliasedConstraintExprs: Map[Expression, AttributeSet] = {
  val aliases = projectList.collect { case alias: Alias => alias }
  aliases.groupBy(_.child).map { case (aliased, group) =>
    aliased -> AttributeSet(group.map(_.toAttribute))
  }
}

// Forwards the child's constraints unchanged.
override def validConstraints: Set[Expression] = child.constraints

override def computeStats(conf: CatalystConf): Statistics = {
if (conf.cboEnabled) {
Expand Down Expand Up @@ -550,10 +556,14 @@ case class Aggregate(
// The output attributes are the attributes of the aggregate expressions, in order.
override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
// Delegates `maxRows` to the child.
override def maxRows: Option[Long] = child.maxRows

// The child's constraints plus those derived from the aliases among the aggregate expressions
// that contain no `AggregateExpression`.
override def validConstraints: Set[Expression] = {
  val aggregateFree = aggregateExpressions.filterNot { expr =>
    expr.find(_.isInstanceOf[AggregateExpression]).isDefined
  }
  child.constraints.union(getAliasedConstraints(aggregateFree))
}
// Groups the aliases that contain no `AggregateExpression` by the expression they alias, mapping
// each such expression to the set of attributes of its aliases.
override lazy val aliasedConstraintExprs: Map[Expression, AttributeSet] = {
  val aggregateFreeAliases = aggregateExpressions.collect {
    case alias: Alias if alias.find(_.isInstanceOf[AggregateExpression]).isEmpty => alias
  }
  aggregateFreeAliases.groupBy(_.child).map { case (aliased, group) =>
    aliased -> AttributeSet(group.map(_.toAttribute))
  }
}

// Forwards the child's constraints unchanged.
override def validConstraints: Set[Expression] = child.constraints

override def computeStats(conf: CatalystConf): Statistics = {
def simpleEstimation: Statistics = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
.join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
.analyze
val correctAnswer = t1
.where(IsNotNull('a) && IsNotNull('b) && 'a <=> 'a && 'b <=> 'b &&'a === 'b)
.where(IsNotNull('a) && IsNotNull('b) && 'a === 'b)
.select('a, 'b.as('d)).as("t")
.join(t2.where(IsNotNull('a) && 'a <=> 'a), Inner,
.join(t2.where(IsNotNull('a)), Inner,
Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr))
.analyze
val optimized = Optimize.execute(originalQuery)
Expand All @@ -168,24 +168,19 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
&& "t.d".attr === "t2.a".attr
&& "t.int_col".attr === "t2.a".attr))
.analyze

val correctAnswer = t1
.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
&& 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a
&& Coalesce(Seq('a, 'a)) <=> 'b && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a))
&& 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b))
&& Coalesce(Seq('a, 'b)) <=> Coalesce(Seq('b, 'b)) && Coalesce(Seq('a, 'b)) === 'b
&& IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b)))
&& 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b))
&& Coalesce(Seq('b, 'b)) <=> Coalesce(Seq('b, 'b)) && 'b <=> 'b)
.where(IsNotNull('a) && 'b === Coalesce(Seq('b, 'b)) && Coalesce(Seq('a, 'b)) === 'b
&& IsNotNull('b) && 'a === Coalesce(Seq('a, 'b)) && IsNotNull(Coalesce(Seq('a, 'a)))
&& 'a === Coalesce(Seq('a, 'a)) && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === 'b
&& IsNotNull(Coalesce(Seq('b, 'b))))
.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t")
.join(t2
.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a)))
&& 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a
&& Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a))), Inner,
&& 'a === Coalesce(Seq('a, 'a))), Inner,
Some("t.a".attr === "t2.a".attr
&& "t.d".attr === "t2.a".attr
&& "t.int_col".attr === "t2.a".attr
&& Coalesce(Seq("t.d".attr, "t.d".attr)) <=> "t.int_col".attr))
&& "t.int_col".attr === "t2.a".attr))
.analyze
val optimized = Optimize.execute(originalQuery)
comparePlans(optimized, correctAnswer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,20 +130,46 @@ class ConstraintPropagationSuite extends SparkFunSuite {

val aliasedRelation = tr.where('a.attr > 10).select('a.as('x), 'b, 'b.as('y), 'a.as('z))

verifyConstraints(aliasedRelation.analyze.constraints,
// `completeConstraints` contains all constraints including fully aliased ones.
verifyConstraints(aliasedRelation.analyze.completeConstraints,
ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10,
IsNotNull(resolveColumn(aliasedRelation.analyze, "x")),
resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"),
resolveColumn(aliasedRelation.analyze, "z") <=> resolveColumn(aliasedRelation.analyze, "x"),
resolveColumn(aliasedRelation.analyze, "z") > 10,
IsNotNull(resolveColumn(aliasedRelation.analyze, "z")))))

// `constraints` doesn't contain fully aliased constraints. It only contains the minimal
// constraints that can effectively determine whether a row is produced by the relation.
verifyConstraints(aliasedRelation.analyze.constraints,
ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10,
IsNotNull(resolveColumn(aliasedRelation.analyze, "x")))))

val aliasedMap = aliasedRelation.analyze.aliasedExpressionsInConstraints
// 'a aliased to 'x and 'z.
assert(aliasedMap.contains(resolveColumn(tr.analyze, "a")))
assert(aliasedMap(resolveColumn(tr.analyze, "a")).equals(
AttributeSet(
Seq(resolveColumn(aliasedRelation.analyze, "x"),
resolveColumn(aliasedRelation.analyze, "z")))))

// 'b aliased to 'y.
assert(aliasedMap.contains(resolveColumn(tr.analyze, "b")))
assert(aliasedMap(resolveColumn(tr.analyze, "b")).equals(
AttributeSet(
Seq(resolveColumn(aliasedRelation.analyze, "y")))))

val multiAlias = tr.where('a === 'c + 10).select('a.as('x), 'c.as('y))
verifyConstraints(multiAlias.analyze.constraints,
verifyConstraints(multiAlias.analyze.completeConstraints,
ExpressionSet(Seq(IsNotNull(resolveColumn(multiAlias.analyze, "x")),
IsNotNull(resolveColumn(multiAlias.analyze, "y")),
resolveColumn(multiAlias.analyze, "x") === resolveColumn(multiAlias.analyze, "y") + 10))
)

// Because 'a and 'c are not in the outputSet of multiAlias, their occurrences in `constraints`
// are replaced with the aliasing attributes 'x and 'y.
// So the `completeConstraints` of multiAlias is the same as its `constraints`.
assert((multiAlias.analyze.completeConstraints -- multiAlias.analyze.constraints).isEmpty)
}

test("propagating constraints in union") {
Expand Down