Skip to content

Commit 8598a98

Browse files
gengliangwangcloud-fan
authored andcommitted
[SPARK-23079][SQL] Fix query constraints propagation with aliases
## What changes were proposed in this pull request? Previously, PR #19201 fix the problem of non-converging constraints. After that PR #19149 improve the loop and constraints is inferred only once. So the problem of non-converging constraints is gone. However, the case below will fail. ``` spark.range(5).write.saveAsTable("t") val t = spark.read.table("t") val left = t.withColumn("xid", $"id" + lit(1)).as("x") val right = t.withColumnRenamed("id", "xid").as("y") val df = left.join(right, "xid").filter("id = 3").toDF() checkAnswer(df, Row(4, 3)) ``` Because `aliasMap` replace all the aliased child. See the test case in PR for details. This PR is to fix this bug by removing useless code for preventing non-converging constraints. It can be also fixed with #20270, but this is much simpler and clean up the code. ## How was this patch tested? Unit test Author: Wang Gengliang <[email protected]> Closes #20278 from gengliangwang/FixConstraintSimple.
1 parent 0f8a286 commit 8598a98

File tree

5 files changed

+17
-93
lines changed

5 files changed

+17
-93
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ abstract class UnaryNode extends LogicalPlan {
255255
case expr: Expression if expr.semanticEquals(e) =>
256256
a.toAttribute
257257
})
258+
allConstraints += EqualNullSafe(e, a.toAttribute)
258259
case _ => // Don't change.
259260
}
260261

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala

Lines changed: 2 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -94,56 +94,23 @@ trait QueryPlanConstraints { self: LogicalPlan =>
9494
case _ => Seq.empty[Attribute]
9595
}
9696

97-
// Collect aliases from expressions of the whole tree rooted by the current QueryPlan node, so
98-
// we may avoid producing recursive constraints.
99-
private lazy val aliasMap: AttributeMap[Expression] = AttributeMap(
100-
expressions.collect {
101-
case a: Alias if !a.child.isInstanceOf[Literal] => (a.toAttribute, a.child)
102-
} ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints].aliasMap))
103-
// Note: the explicit cast is necessary, since Scala compiler fails to infer the type.
104-
10597
/**
10698
* Infers an additional set of constraints from a given set of equality constraints.
10799
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
108100
* additional constraint of the form `b = 5`.
109101
*/
110102
private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
111-
val aliasedConstraints = eliminateAliasedExpressionInConstraints(constraints)
112103
var inferredConstraints = Set.empty[Expression]
113-
aliasedConstraints.foreach {
104+
constraints.foreach {
114105
case eq @ EqualTo(l: Attribute, r: Attribute) =>
115-
val candidateConstraints = aliasedConstraints - eq
106+
val candidateConstraints = constraints - eq
116107
inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
117108
inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
118109
case _ => // No inference
119110
}
120111
inferredConstraints -- constraints
121112
}
122113

123-
/**
124-
* Replace the aliased expression in [[Alias]] with the alias name if both exist in constraints.
125-
* Thus non-converging inference can be prevented.
126-
* E.g. `Alias(b, f(a)), a = b` infers `f(a) = f(f(a))` without eliminating aliased expressions.
127-
* Also, the size of constraints is reduced without losing any information.
128-
* When the inferred filters are pushed down the operators that generate the alias,
129-
* the alias names used in filters are replaced by the aliased expressions.
130-
*/
131-
private def eliminateAliasedExpressionInConstraints(constraints: Set[Expression])
132-
: Set[Expression] = {
133-
val attributesInEqualTo = constraints.flatMap {
134-
case EqualTo(l: Attribute, r: Attribute) => l :: r :: Nil
135-
case _ => Nil
136-
}
137-
var aliasedConstraints = constraints
138-
attributesInEqualTo.foreach { a =>
139-
if (aliasMap.contains(a)) {
140-
val child = aliasMap.get(a).get
141-
aliasedConstraints = replaceConstraints(aliasedConstraints, child, a)
142-
}
143-
}
144-
aliasedConstraints
145-
}
146-
147114
private def replaceConstraints(
148115
constraints: Set[Expression],
149116
source: Expression,

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala

Lines changed: 1 addition & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
3434
PushDownPredicate,
3535
InferFiltersFromConstraints,
3636
CombineFilters,
37+
SimplifyBinaryComparison,
3738
BooleanSimplification) :: Nil
3839
}
3940

@@ -160,64 +161,6 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
160161
comparePlans(optimized, correctAnswer)
161162
}
162163

163-
test("inner join with alias: don't generate constraints for recursive functions") {
164-
val t1 = testRelation.subquery('t1)
165-
val t2 = testRelation.subquery('t2)
166-
167-
// We should prevent `Coalese(a, b)` from recursively creating complicated constraints through
168-
// the constraint inference procedure.
169-
val originalQuery = t1.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col))
170-
// We hide an `Alias` inside the child's child's expressions, to cover the situation reported
171-
// in [SPARK-20700].
172-
.select('int_col, 'd, 'a).as("t")
173-
.join(t2, Inner,
174-
Some("t.a".attr === "t2.a".attr
175-
&& "t.d".attr === "t2.a".attr
176-
&& "t.int_col".attr === "t2.a".attr))
177-
.analyze
178-
val correctAnswer = t1
179-
.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) && IsNotNull(Coalesce(Seq('b, 'a)))
180-
&& IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b))) && IsNotNull(Coalesce(Seq('a, 'b)))
181-
&& 'a === 'b && 'a === Coalesce(Seq('a, 'a)) && 'a === Coalesce(Seq('a, 'b))
182-
&& 'a === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('a, 'b))
183-
&& 'b === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('b, 'b)))
184-
.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col))
185-
.select('int_col, 'd, 'a).as("t")
186-
.join(
187-
t2.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) &&
188-
'a === Coalesce(Seq('a, 'a))),
189-
Inner,
190-
Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr
191-
&& "t.int_col".attr === "t2.a".attr))
192-
.analyze
193-
val optimized = Optimize.execute(originalQuery)
194-
comparePlans(optimized, correctAnswer)
195-
}
196-
197-
test("inner join with EqualTo expressions containing part of each other: don't generate " +
198-
"constraints for recursive functions") {
199-
val t1 = testRelation.subquery('t1)
200-
val t2 = testRelation.subquery('t2)
201-
202-
// We should prevent `c = Coalese(a, b)` and `a = Coalese(b, c)` from recursively creating
203-
// complicated constraints through the constraint inference procedure.
204-
val originalQuery = t1
205-
.select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e))
206-
.where('a === 'd && 'c === 'e)
207-
.join(t2, Inner, Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr))
208-
.analyze
209-
val correctAnswer = t1
210-
.where(IsNotNull('a) && IsNotNull('c) && 'a === Coalesce(Seq('b, 'c)) &&
211-
'c === Coalesce(Seq('a, 'b)))
212-
.select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e))
213-
.join(t2.where(IsNotNull('a) && IsNotNull('c)),
214-
Inner,
215-
Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr))
216-
.analyze
217-
val optimized = Optimize.execute(originalQuery)
218-
comparePlans(optimized, correctAnswer)
219-
}
220-
221164
test("generate correct filters for alias that don't produce recursive constraints") {
222165
val t1 = testRelation.subquery('t1)
223166

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest {
134134
verifyConstraints(aliasedRelation.analyze.constraints,
135135
ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10,
136136
IsNotNull(resolveColumn(aliasedRelation.analyze, "x")),
137+
resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"),
138+
resolveColumn(aliasedRelation.analyze, "z") <=> resolveColumn(aliasedRelation.analyze, "x"),
137139
resolveColumn(aliasedRelation.analyze, "z") > 10,
138140
IsNotNull(resolveColumn(aliasedRelation.analyze, "z")))))
139141

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2717,6 +2717,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
27172717
}
27182718
}
27192719

2720+
test("SPARK-23079: constraints should be inferred correctly with aliases") {
2721+
withTable("t") {
2722+
spark.range(5).write.saveAsTable("t")
2723+
val t = spark.read.table("t")
2724+
val left = t.withColumn("xid", $"id" + lit(1)).as("x")
2725+
val right = t.withColumnRenamed("id", "xid").as("y")
2726+
val df = left.join(right, "xid").filter("id = 3").toDF()
2727+
checkAnswer(df, Row(4, 3))
2728+
}
2729+
}
2730+
27202731
test("SRARK-22266: the same aggregate function was calculated multiple times") {
27212732
val query = "SELECT a, max(b+1), max(b+1) + 1 FROM testData2 GROUP BY a"
27222733
val df = sql(query)

0 commit comments

Comments
 (0)