@@ -839,9 +839,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = {
expr.foreach {
case a: AggregateExpression if containsOuter(a) =>
val outer = a.collect { case OuterReference(e) => e.toAttribute }
val local = a.references -- outer
Contributor Author

@cloud-fan, Jun 10, 2021

Expression.references doesn't count outer references, so subtracting them here is unnecessary. It can also miss some cases when an outer reference is the same attribute as a local reference (e.g. the outer query and the inner query select from the same table).

Contributor

Good catch!

if (local.nonEmpty) {
if (a.references.nonEmpty) {
throw QueryCompilationErrors.mixedRefsInAggFunc(a.sql)
}
case _ =>
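
The review point on this hunk can be illustrated with a small, self-contained sketch. It assumes a toy expression tree (`Expr`, `Attr`, `OuterRef`, `Add`, `Sum` and both check functions are hypothetical stand-ins, not Spark's classes) in which, as the comment describes for Expression.references, an outer reference contributes nothing to `references`. Under that assumption, subtracting the outer attributes hides a local reference that names the same attribute as an outer one:

```scala
object MixedRefsSketch {
  // Hypothetical toy expression tree; none of these types are Spark's.
  sealed trait Expr {
    def children: Seq[Expr]
    // References come from children only, so leaves contribute nothing by default.
    def references: Set[String] = children.flatMap(_.references).toSet
    // Pre-order collection of every node matched by the partial function.
    def collect[T](pf: PartialFunction[Expr, T]): Seq[T] =
      pf.lift(this).toSeq ++ children.flatMap(_.collect(pf))
  }
  case class Attr(name: String) extends Expr {
    val children: Seq[Expr] = Nil
    override def references: Set[String] = Set(name)
  }
  // A leaf wrapper: the wrapped attribute is not a child, so it is not counted in references.
  case class OuterRef(e: Attr) extends Expr { val children: Seq[Expr] = Nil }
  case class Add(left: Expr, right: Expr) extends Expr {
    val children: Seq[Expr] = Seq(left, right)
  }
  case class Sum(child: Expr) extends Expr { val children: Seq[Expr] = Seq(child) }

  // Old check: subtract the outer attributes from references before testing.
  def oldCheckFlags(agg: Expr): Boolean = {
    val outer = agg.collect { case OuterRef(e) => e.name }.toSet
    (agg.references -- outer).nonEmpty
  }

  // New check: anything left in references is already a local reference.
  def newCheckFlags(agg: Expr): Boolean = agg.references.nonEmpty

  def main(args: Array[String]): Unit = {
    // sum(outer(a) + a): the outer and local references name the same attribute.
    val mixed = Sum(Add(OuterRef(Attr("a")), Attr("a")))
    // sum(outer(a)): outer references only, which is allowed.
    val outerOnly = Sum(OuterRef(Attr("a")))
    println(s"mixed:     old=${oldCheckFlags(mixed)} new=${newCheckFlags(mixed)}")
    println(s"outerOnly: old=${oldCheckFlags(outerOnly)} new=${newCheckFlags(outerOnly)}")
  }
}
```

In this toy model the old check lets `mixed` through because the subtraction removes the local `a` along with the outer one, while the new check flags it. Both checks accept `outerOnly`.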
@@ -652,3 +652,15 @@ case object UnresolvedSeed extends LeafExpression with Unevaluable {
override def dataType: DataType = throw new UnresolvedException("dataType")
override lazy val resolved = false
}

/**
* An intermediate expression to hold a resolved (nested) column. Some rules may need to undo the
* column resolution and use this expression to keep the original column name.
*/
case class TempResolvedColumn(child: Expression, nameParts: Seq[String]) extends UnaryExpression
Member

Is its child always a named expression?

Contributor Author

Yes, I'll refine the type.

Contributor Author

Actually, we don't require it to be a named expression, so it's more robust to keep the type as Expression.

with Unevaluable {
override lazy val canonicalized = child.canonicalized
override def dataType: DataType = child.dataType
override protected def withNewChildInternal(newChild: Expression): Expression =
copy(child = newChild)
}
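
The doc comment on TempResolvedColumn describes a wrapper that lets a later rule either keep a column resolution or undo it and recover the original name. A minimal sketch of that idea follows, using a toy expression tree. Every type here, the `finish` function and its `keepResolution` flag are hypothetical; how Spark's rules actually decide is not shown in this diff.

```scala
object TempResolvedSketch {
  // Hypothetical toy expressions; not Spark's classes.
  sealed trait Expr
  case class UnresolvedName(nameParts: Seq[String]) extends Expr
  case class ResolvedColumn(id: Int) extends Expr
  // Wrapper that remembers the original name parts next to the resolved column.
  case class TempResolved(child: Expr, nameParts: Seq[String]) extends Expr
  case class GreaterThan(left: Expr, right: Expr) extends Expr
  case class Lit(value: Int) extends Expr

  // Strip every wrapper: keep the resolution (unwrap to the child) or undo it
  // (restore an unresolved name built from the saved name parts).
  def finish(e: Expr, keepResolution: Boolean): Expr = e match {
    case TempResolved(child, parts) =>
      if (keepResolution) child else UnresolvedName(parts)
    case GreaterThan(l, r) =>
      GreaterThan(finish(l, keepResolution), finish(r, keepResolution))
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val cond = GreaterThan(TempResolved(ResolvedColumn(1), Seq("t", "b")), Lit(0))
    println(finish(cond, keepResolution = true))
    println(finish(cond, keepResolution = false))
  }
}
```

Keeping `child` typed as a general expression, as the last review reply settles on, means the wrapper in this sketch never needs to inspect what it wraps.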
@@ -192,9 +192,7 @@ object SubExprUtils extends PredicateHelper {
val outerExpressions = ArrayBuffer.empty[Expression]
def collectOutRefs(input: Expression): Unit = input match {
case a: AggregateExpression if containsOuter(a) =>
val outer = a.collect { case OuterReference(e) => e.toAttribute }
val local = a.references -- outer
if (local.nonEmpty) {
if (a.references.nonEmpty) {
throw QueryCompilationErrors.mixedRefsInAggFunc(a.sql)
} else {
// Collect and update the sub-tree so that outer references inside this aggregate
@@ -325,8 +325,6 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
*/
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
_.containsPattern(PLAN_EXPRESSION)) {
case f @ Filter(_, a: Aggregate) =>
rewriteSubQueries(f, Seq(a, a.child))
case j: LateralJoin =>
val newPlan = rewriteSubQueries(j, j.children)
// Since a lateral join's output depends on its left child output and its lateral subquery's
@@ -170,10 +170,9 @@ class AnalysisSuite extends AnalysisTest with Matchers {
val b = testRelation2.output(1)
val c = testRelation2.output(2)
val alias_a3 = count(a).as("a3")
val alias_b = b.as("aggOrder")

// Case 1: when the child of Sort is not Aggregate,
// the sort reference is handled by the rule ResolveSortReferences
// the sort reference is handled by the rule ResolveMissingReferences
val plan1 = testRelation2
.groupBy($"a", $"c", $"b")($"a", $"c", count($"a").as("a3"))
.select($"a", $"c", $"a3")
@@ -194,8 +193,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {
.orderBy($"b".asc)

val expected2 = testRelation2
.groupBy(a, c, b)(a, c, alias_a3, alias_b)
.orderBy(alias_b.toAttribute.asc)
.groupBy(a, c, b)(a, c, alias_a3, b)
.orderBy(b.asc)
.select(a, c, alias_a3.toAttribute)

checkAnalysis(plan2, expected2)
@@ -415,7 +414,6 @@ class AnalysisSuite extends AnalysisTest with Matchers {
val expected = testRelation2
.groupBy(a, c)(alias1, alias2, alias3)
.orderBy(alias1.toAttribute.asc, alias2.toAttribute.asc)
.select(alias1.toAttribute, alias2.toAttribute, alias3.toAttribute)
checkAnalysis(plan, expected)
}

@@ -287,9 +287,9 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
Seq(GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)))),
Seq(unresolved_a, unresolved_b), r1))
val expected = Project(Seq(a, b), Sort(
Seq(SortOrder('aggOrder.byte.withNullability(false), Ascending)), true,
Seq(SortOrder(grouping_a, Ascending)), true,
Aggregate(Seq(a, b, gid),
Seq(a, b, grouping_a.as("aggOrder")),
Seq(a, b, gid),
Contributor Author

This is a more correct plan. The other tests in this file, test("grouping function") and test("filter with grouping function"), both expect plans that group by gid first and then calculate the grouping function (here grouping_a) in the aggregate list or in the filter node above.

Expand(
Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L),
Seq(a, b, c, a, b, 0L)),
@@ -309,9 +309,9 @@ class ResolveGroupingAnalyticsSuite extends AnalysisTest {
Seq(GroupingSets(Seq(Seq(), Seq(unresolved_a), Seq(unresolved_a, unresolved_b)))),
Seq(unresolved_a, unresolved_b), r1))
val expected3 = Project(Seq(a, b), Sort(
Seq(SortOrder('aggOrder.long.withNullability(false), Ascending)), true,
Seq(SortOrder(gid, Ascending)), true,
Aggregate(Seq(a, b, gid),
Seq(a, b, gid.as("aggOrder")),
Seq(a, b, gid),
Expand(
Seq(Seq(a, b, c, nulInt, nulStr, 3L), Seq(a, b, c, a, nulStr, 1L),
Seq(a, b, c, a, b, 0L)),
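
The comment on the first hunk relies on the grouping function being derivable from gid after the aggregate. The sketch below shows one way that derivation can work. It assumes the bit layout implied by the gid literals 3L, 1L and 0L in the Expand rows above: with n group-by columns, the column at index idx owns bit n - 1 - idx, and a set bit means the column is nulled out in that grouping set. The `grouping` helper is hypothetical and not taken from this diff.

```scala
object GroupingIdSketch {
  // Assumed layout: column idx (0-based, out of n group-by columns) owns bit
  // (n - 1 - idx) of gid; a set bit means the column is not in the grouping set.
  def grouping(gid: Long, idx: Int, n: Int): Byte =
    ((gid >> (n - 1 - idx)) & 1L).toByte

  def main(args: Array[String]): Unit = {
    val n = 2 // group-by columns a (idx 0) and b (idx 1)
    for (gid <- Seq(3L, 1L, 0L)) {
      println(s"gid=$gid grouping(a)=${grouping(gid, 0, n)} grouping(b)=${grouping(gid, 1, n)}")
    }
  }
}
```

Under this layout, gid 3 marks both columns as aggregated away, gid 1 marks only b, and gid 0 marks neither, which matches the null placement in the three Expand projections.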
27 changes: 11 additions & 16 deletions sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out
@@ -149,14 +149,13 @@ EXPLAIN FORMATTED
struct<plan:string>
-- !query output
== Physical Plan ==
AdaptiveSparkPlan (8)
+- Project (7)
+- Filter (6)
+- HashAggregate (5)
+- Exchange (4)
+- HashAggregate (3)
+- Filter (2)
+- Scan parquet default.explain_temp1 (1)
AdaptiveSparkPlan (7)
+- Filter (6)
+- HashAggregate (5)
+- Exchange (4)
+- HashAggregate (3)
+- Filter (2)
+- Scan parquet default.explain_temp1 (1)


(1) Scan parquet default.explain_temp1
@@ -186,17 +185,13 @@ Input [2]: [key#x, max#x]
Keys [1]: [key#x]
Functions [1]: [max(val#x)]
Aggregate Attributes [1]: [max(val#x)#x]
Results [3]: [key#x, max(val#x)#x AS max(val)#x, max(val#x)#x AS max(val#x)#x]
Results [2]: [key#x, max(val#x)#x AS max(val)#x]
Contributor Author

Removes a duplicated max function.


(6) Filter
Input [3]: [key#x, max(val)#x, max(val#x)#x]
Condition : (isnotnull(max(val#x)#x) AND (max(val#x)#x > 0))

(7) Project
Output [2]: [key#x, max(val)#x]
Input [3]: [key#x, max(val)#x, max(val#x)#x]
Input [2]: [key#x, max(val)#x]
Condition : (isnotnull(max(val)#x) AND (max(val)#x > 0))

(8) AdaptiveSparkPlan
(7) AdaptiveSparkPlan
Output [2]: [key#x, max(val)#x]
Arguments: isFinalPlan=false

25 changes: 10 additions & 15 deletions sql/core/src/test/resources/sql-tests/results/explain.sql.out
@@ -148,14 +148,13 @@ EXPLAIN FORMATTED
struct<plan:string>
-- !query output
== Physical Plan ==
* Project (8)
+- * Filter (7)
+- * HashAggregate (6)
+- Exchange (5)
+- * HashAggregate (4)
+- * Filter (3)
+- * ColumnarToRow (2)
+- Scan parquet default.explain_temp1 (1)
* Filter (7)
+- * HashAggregate (6)
+- Exchange (5)
+- * HashAggregate (4)
+- * Filter (3)
+- * ColumnarToRow (2)
+- Scan parquet default.explain_temp1 (1)


(1) Scan parquet default.explain_temp1
@@ -188,15 +187,11 @@ Input [2]: [key#x, max#x]
Keys [1]: [key#x]
Functions [1]: [max(val#x)]
Aggregate Attributes [1]: [max(val#x)#x]
Results [3]: [key#x, max(val#x)#x AS max(val)#x, max(val#x)#x AS max(val#x)#x]
Results [2]: [key#x, max(val#x)#x AS max(val)#x]
Contributor Author

ditto


(7) Filter [codegen id : 2]
Input [3]: [key#x, max(val)#x, max(val#x)#x]
Condition : (isnotnull(max(val#x)#x) AND (max(val#x)#x > 0))

(8) Project [codegen id : 2]
Output [2]: [key#x, max(val)#x]
Input [3]: [key#x, max(val)#x, max(val#x)#x]
Input [2]: [key#x, max(val)#x]
Condition : (isnotnull(max(val)#x) AND (max(val)#x > 0))


-- !query