diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index a5761703fd655..3c66d3a5a7822 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -183,8 +183,44 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT * An [[ExpressionSet]] that contains invariants about the rows output by this operator. For * example, if this set contains the expression `a = 2` then that expression is guaranteed to * evaluate to `true` for all rows produced. + * + * Note: These invariants don't contain all the possible invariants. They don't consider aliased + * attributes. These are the effective expressions which are useful for data filtering. Because if + * any one invariant in this set is false, then we can guarantee the conjunctive predicates from + * the complete invariants [[completeConstraints]] of this node would be false too. */ - lazy val constraints: ExpressionSet = ExpressionSet(getRelevantConstraints(validConstraints)) + lazy val constraints: ExpressionSet = { + // The aliased expressions which are not contained in the outputSet of this plan. + val obsoleteKeys = aliasedExpressionsInConstraints.keys.filterNot { key => + key.references.subsetOf(outputSet) + } + + // If there are obsolete aliased expressions, we need to select a new key. 
+ if (obsoleteKeys.nonEmpty) { + var updatedConstraints = validConstraints + obsoleteKeys.foreach { oldKey => + val newAttr = aliasedExpressionsInConstraints(oldKey).toSeq.sortWith { (x, y) => + x.exprId.id < y.exprId.id + }.head + updatedConstraints ++= updatedConstraints.map { constraint => + constraint.transform { + case e if e.semanticEquals(oldKey) => newAttr + } + } + } + ExpressionSet(getRelevantConstraints(updatedConstraints)) + } else { + ExpressionSet(getRelevantConstraints(validConstraints)) + } + } + + /** + * An [[ExpressionSet]] that contains invariants about the rows output by this operator. + * Different from `constraints`, `completeConstraints` contains all invariants by transforming + * constraints to aliased constraints. + */ + lazy val completeConstraints: ExpressionSet = ExpressionSet(getRelevantConstraints( + validConstraints.union(getAliasedValidConstraints(validConstraints)))) /** * This method can be overridden by any child class of QueryPlan to specify a set of constraints @@ -196,6 +232,96 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT */ protected def validConstraints: Set[Expression] = Set.empty + /** + * A map represents aliased expressions in the constraints. A key is an expression to be aliased. + * A value is a set of [[Attribute]] the corresponding key is aliased to. A constraint which + * refers to a key in this map can be transformed to other constraints by replacing the key + * with the values in this map. For example, if there is a constraint "a > b" and there is a key + * "a" to value ["c", "d"] in this map, we can transform this constraint to other valid + * constraints such as ["c > b", "d > b"]. + */ + lazy val aliasedExpressionsInConstraints: Map[Expression, AttributeSet] = { + var aliasedMap = aliasedConstraintExprs + val combinedMap = children.map { child => + // Get the aliased expressions which are not in the `outputSet` of this plan anymore. 
+ // For example, if in the children we have an aliased map like Map("a" -> Set("c", "d")), but this + // plan doesn't output "a", then we replace the map with Map("c" -> Set("d")). + val obsoleteKeys = child.aliasedExpressionsInConstraints.keys.filterNot { key => + key.references.subsetOf(child.outputSet) + } + obsoleteKeys.flatMap { obsoleteKey => + // Get all the attributes the obsolete expression is aliased to. + val attrs = child.aliasedExpressionsInConstraints(obsoleteKey).toSeq.sortWith { (x, y) => + x.exprId.id < y.exprId.id + } + if (attrs.length > 1) { + // Take the first attribute as the new aliased expression for the other attributes. + Some(attrs.head -> AttributeSet(attrs.tail)) + } else { + // Only one attribute remains; it will be used in [[constraints]] to replace the obsolete key. + None + } + }.toMap ++ child.aliasedExpressionsInConstraints.filter { + case (keyExpr, attrs) => keyExpr.references.subsetOf(child.outputSet) + } + }.flatten.toMap.map { case (keyExpr, attrs) => + if (aliasedMap.contains(keyExpr)) { + val addedAttrs = aliasedMap(keyExpr) + aliasedMap = aliasedMap - keyExpr + keyExpr -> (attrs ++ addedAttrs) + } else { + keyExpr -> attrs + } + } ++ aliasedMap + + combinedMap.flatMap { + case (keyExpr, attrs) => + val newAtts = attrs.intersect(outputSet) + if (newAtts.isEmpty) { + None + } else { + Some(keyExpr -> newAtts) + } + } + } + + /** + * A map represents aliased expressions and attributes newly added in the current QueryPlan. + * A child class of QueryPlan should override this to specify the alias relations in its output. + * For example, if an output "a" of a child plan of this plan is aliased to "c" and "d" in this + * plan, it should record the map as Map("a" -> Set("c", "d")). 
+ */ + protected lazy val aliasedConstraintExprs: Map[Expression, AttributeSet] = Map.empty + + /** + * Generates an additional set of aliased constraints by replacing the original constraint + * expressions with the corresponding alias + */ + private def getAliasedValidConstraints(constraints: Set[Expression]): Set[Expression] = { + // We only care about the constraints which refer to attributes in output and aliased + // expressions. + // For example, for a constraint 'a > b', if 'a' is aliased to 'c', we need to get aliased + // constraint 'c > b' only if 'b' is in output. + val relativeReferences = AttributeSet( + aliasedExpressionsInConstraints.keys.flatMap(_.references) ++ outputSet) + + var allConstraints = constraints.filter { constraint => + constraint.references.subsetOf(relativeReferences) + }.asInstanceOf[Set[Expression]] + + aliasedExpressionsInConstraints.foreach { case (alisedExpr, attrs) => + attrs.foreach { attr => + allConstraints ++= allConstraints.map(_ transform { + case expr: Expression if expr.semanticEquals(alisedExpr) => + attr + }) + allConstraints += EqualNullSafe(alisedExpr, attr) + } + } + + allConstraints -- constraints + } + /** * Returns the set of attributes that are output by this node. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 0937825e273a2..58a93d847b453 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -309,26 +309,6 @@ abstract class UnaryNode extends LogicalPlan { override final def children: Seq[LogicalPlan] = child :: Nil - /** - * Generates an additional set of aliased constraints by replacing the original constraint - * expressions with the corresponding alias - */ - protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = { - var allConstraints = child.constraints.asInstanceOf[Set[Expression]] - projectList.foreach { - case a @ Alias(e, _) => - // For every alias in `projectList`, replace the reference in constraints by its attribute. - allConstraints ++= allConstraints.map(_ transform { - case expr: Expression if expr.semanticEquals(e) => - a.toAttribute - }) - allConstraints += EqualNullSafe(e, a.toAttribute) - case _ => // Don't change. 
- } - - allConstraints -- child.constraints - } - override protected def validConstraints: Set[Expression] = child.constraints override def computeStats(conf: CatalystConf): Statistics = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index af57632516790..840e70b34e86f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -53,8 +53,14 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend !expressions.exists(!_.resolved) && childrenResolved && !hasSpecialExpressions } - override def validConstraints: Set[Expression] = - child.constraints.union(getAliasedConstraints(projectList)) + override lazy val aliasedConstraintExprs: Map[Expression, AttributeSet] = + projectList.collect { + case a: Alias => a + }.groupBy(_.child).map { case (k, v) => + k -> AttributeSet(v.map(_.toAttribute)) + } + + override def validConstraints: Set[Expression] = child.constraints override def computeStats(conf: CatalystConf): Statistics = { if (conf.cboEnabled) { @@ -550,10 +556,14 @@ case class Aggregate( override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) override def maxRows: Option[Long] = child.maxRows - override def validConstraints: Set[Expression] = { - val nonAgg = aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty) - child.constraints.union(getAliasedConstraints(nonAgg)) - } + override lazy val aliasedConstraintExprs: Map[Expression, AttributeSet] = + aggregateExpressions.filter(_.find(_.isInstanceOf[AggregateExpression]).isEmpty).collect { + case a: Alias => a + }.groupBy(_.child).map { case (k, v) => + k -> AttributeSet(v.map(_.toAttribute)) + } + + override def 
validConstraints: Set[Expression] = child.constraints override def computeStats(conf: CatalystConf): Statistics = { def simpleEstimation: Statistics = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 9f57f66a2ea20..8ec2cc663aa4f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -149,9 +149,9 @@ class InferFiltersFromConstraintsSuite extends PlanTest { .join(t2, Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr)) .analyze val correctAnswer = t1 - .where(IsNotNull('a) && IsNotNull('b) && 'a <=> 'a && 'b <=> 'b &&'a === 'b) + .where(IsNotNull('a) && IsNotNull('b) && 'a === 'b) .select('a, 'b.as('d)).as("t") - .join(t2.where(IsNotNull('a) && 'a <=> 'a), Inner, + .join(t2.where(IsNotNull('a)), Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr)) .analyze val optimized = Optimize.execute(originalQuery) @@ -168,24 +168,19 @@ class InferFiltersFromConstraintsSuite extends PlanTest { && "t.d".attr === "t2.a".attr && "t.int_col".attr === "t2.a".attr)) .analyze + val correctAnswer = t1 - .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) - && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a - && Coalesce(Seq('a, 'a)) <=> 'b && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a)) - && 'a === 'b && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === Coalesce(Seq('a, 'b)) - && Coalesce(Seq('a, 'b)) <=> Coalesce(Seq('b, 'b)) && Coalesce(Seq('a, 'b)) === 'b - && IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b))) - && 'b === Coalesce(Seq('b, 'b)) && 'b <=> Coalesce(Seq('b, 'b)) - && Coalesce(Seq('b, 'b)) <=> Coalesce(Seq('b, 'b)) && 'b <=> 'b) + 
.where(IsNotNull('a) && 'b === Coalesce(Seq('b, 'b)) && Coalesce(Seq('a, 'b)) === 'b + && IsNotNull('b) && 'a === Coalesce(Seq('a, 'b)) && IsNotNull(Coalesce(Seq('a, 'a))) + && 'a === Coalesce(Seq('a, 'a)) && IsNotNull(Coalesce(Seq('a, 'b))) && 'a === 'b + && IsNotNull(Coalesce(Seq('b, 'b)))) .select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col)).as("t") .join(t2 .where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) - && 'a === Coalesce(Seq('a, 'a)) && 'a <=> Coalesce(Seq('a, 'a)) && 'a <=> 'a - && Coalesce(Seq('a, 'a)) <=> Coalesce(Seq('a, 'a))), Inner, + && 'a === Coalesce(Seq('a, 'a))), Inner, Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr - && "t.int_col".attr === "t2.a".attr - && Coalesce(Seq("t.d".attr, "t.d".attr)) <=> "t.int_col".attr)) + && "t.int_col".attr === "t2.a".attr)) .analyze val optimized = Optimize.execute(originalQuery) comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 908b370408280..75eb97d984ac1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -130,7 +130,8 @@ class ConstraintPropagationSuite extends SparkFunSuite { val aliasedRelation = tr.where('a.attr > 10).select('a.as('x), 'b, 'b.as('y), 'a.as('z)) - verifyConstraints(aliasedRelation.analyze.constraints, + // `completeConstraints` contains all constraints including fully aliased ones. 
+ verifyConstraints(aliasedRelation.analyze.completeConstraints, ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "x")), resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"), @@ -138,12 +139,37 @@ class ConstraintPropagationSuite extends SparkFunSuite { resolveColumn(aliasedRelation.analyze, "z") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "z"))))) + // `constraints` doesn't contain fully aliased constraints. It only contains the minimal + // constraints that can effectively determine if a row is produced by the relation. + verifyConstraints(aliasedRelation.analyze.constraints, + ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10, + IsNotNull(resolveColumn(aliasedRelation.analyze, "x"))))) + + val aliasedMap = aliasedRelation.analyze.aliasedExpressionsInConstraints + // 'a aliased to 'x and 'z. + assert(aliasedMap.contains(resolveColumn(tr.analyze, "a"))) + assert(aliasedMap(resolveColumn(tr.analyze, "a")).equals( + AttributeSet( + Seq(resolveColumn(aliasedRelation.analyze, "x"), + resolveColumn(aliasedRelation.analyze, "z"))))) + + // 'b aliased to 'y. + assert(aliasedMap.contains(resolveColumn(tr.analyze, "b"))) + assert(aliasedMap(resolveColumn(tr.analyze, "b")).equals( + AttributeSet( + Seq(resolveColumn(aliasedRelation.analyze, "y"))))) + val multiAlias = tr.where('a === 'c + 10).select('a.as('x), 'c.as('y)) - verifyConstraints(multiAlias.analyze.constraints, + verifyConstraints(multiAlias.analyze.completeConstraints, ExpressionSet(Seq(IsNotNull(resolveColumn(multiAlias.analyze, "x")), IsNotNull(resolveColumn(multiAlias.analyze, "y")), resolveColumn(multiAlias.analyze, "x") === resolveColumn(multiAlias.analyze, "y") + 10)) ) + + // Because 'a and 'c are not in the outputSet of multiAlias, their occurrences in `constraints` + // will be transformed with the aliasing attributes 'x and 'y. 
+ // So the `completeConstraints` of multiAlias is the same as its `constraints`. + assert((multiAlias.analyze.completeConstraints -- multiAlias.analyze.constraints).isEmpty) } test("propagating constraints in union") {