Skip to content

Commit 8a22e1d

Browse files
committed
Fix bug of Constraint
1 parent 7bcc266 commit 8a22e1d

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,10 +96,15 @@ trait QueryPlanConstraints { self: LogicalPlan =>
9696

9797
// Collect aliases from expressions of the whole tree rooted by the current QueryPlan node, so
9898
// we may avoid producing recursive constraints.
99-
private lazy val aliasMap: AttributeMap[Expression] = AttributeMap(
100-
expressions.collect {
99+
private lazy val aliasMap: AttributeMap[Expression] = {
100+
val aliases = expressions.collect {
101101
case a: Alias if !a.child.isInstanceOf[Literal] => (a.toAttribute, a.child)
102-
} ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints].aliasMap))
102+
} ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints].aliasMap)
103+
AttributeMap(aliases.filter {
104+
case (_, child) => child.references.nonEmpty && child.references.subsetOf(outputSet)
105+
})
106+
}
107+
103108
// Note: the explicit cast is necessary, since Scala compiler fails to infer the type.
104109

105110
/**

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2717,6 +2717,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
27172717
}
27182718
}
27192719

2720+
test("SPARK-23079: constraints should be inferred correctly with aliases") {
2721+
withTable("t") {
2722+
spark.range(5).write.saveAsTable("t")
2723+
val t = spark.read.table("t")
2724+
val left = t.withColumn("xid", $"id" + lit(1)).as("x")
2725+
val right = t.withColumnRenamed("id", "xid").as("y")
2726+
val df = left.join(right, "xid").filter("id = 3").toDF()
2727+
checkAnswer(df, Row(4, 3))
2728+
}
2729+
}
2730+
27202731
test("SRARK-22266: the same aggregate function was calculated multiple times") {
27212732
val query = "SELECT a, max(b+1), max(b+1) + 1 FROM testData2 GROUP BY a"
27222733
val df = sql(query)

0 commit comments

Comments
 (0)