-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-41631][SQL] Support implicit lateral column alias resolution on Aggregate #39040
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 30 commits
04959c2
6f44c85
725e5ac
660e1d2
fd06094
7d4f80f
b9704d5
777f13a
09480ea
c972738
97ee293
5785943
757cffb
29de892
72991c6
d45fe31
1f55f73
f753529
b9f706f
94d5c9e
d2e75fd
edde37c
fb7b18c
3698cff
e700d6a
8d20986
d952aa7
44d5a3d
ccebc1c
5540b70
338ba11
136a930
5076ad2
2f2dee5
3a5509a
a23debb
b200da0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,11 +17,14 @@ | |
|
|
||
| package org.apache.spark.sql.catalyst.analysis | ||
|
|
||
| import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, LateralColumnAliasReference, NamedExpression} | ||
| import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} | ||
| import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, Expression, LateralColumnAliasReference, NamedExpression} | ||
| import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression | ||
| import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} | ||
| import org.apache.spark.sql.catalyst.rules.Rule | ||
| import org.apache.spark.sql.catalyst.trees.TreeNodeTag | ||
| import org.apache.spark.sql.catalyst.trees.TreePattern.LATERAL_COLUMN_ALIAS_REFERENCE | ||
| import org.apache.spark.sql.catalyst.util.toPrettySQL | ||
| import org.apache.spark.sql.errors.QueryCompilationErrors | ||
| import org.apache.spark.sql.internal.SQLConf | ||
|
|
||
| /** | ||
|
|
@@ -31,30 +34,54 @@ import org.apache.spark.sql.internal.SQLConf | |
| * Plan-wise, it handles two types of operators: Project and Aggregate. | ||
| * - in Project, pushing down the referenced lateral alias into a newly created Project, resolve | ||
| * the attributes referencing these aliases | ||
| * - in Aggregate TODO. | ||
| * - in Aggregate, inserting the Project node above and falling back to the resolution of Project. | ||
| * | ||
| * The whole process is generally divided into two phases: | ||
| * 1) recognize resolved lateral alias, wrap the attributes referencing them with | ||
| * [[LateralColumnAliasReference]] | ||
| * 2) when the whole operator is resolved, unwrap [[LateralColumnAliasReference]]. | ||
| * For Project, it further resolves the attributes and push down the referenced lateral aliases. | ||
| * For Aggregate, TODO | ||
| * 2) when the whole operator is resolved, | ||
| * For Project, it unwrap [[LateralColumnAliasReference]], further resolves the attributes and | ||
| * push down the referenced lateral aliases. | ||
| * For Aggregate, it goes through the whole aggregation list, extracts the aggregation | ||
| * expressions and grouping expressions to keep them in this Aggregate node, and add a Project | ||
| * above with the original output. It doesn't do anything on [[LateralColumnAliasReference]], but | ||
| * completely leave it to the Project in the future turns of this rule. | ||
| * | ||
| * Example for Project: | ||
| * ** Example for Project: | ||
| * Before rewrite: | ||
| * Project [age AS a, 'a + 1] | ||
| * +- Child | ||
| * | ||
| * After phase 1: | ||
| * Project [age AS a, lateralalias(a) + 1] | ||
| * Project [age AS a, lca(a) + 1] | ||
| * +- Child | ||
| * | ||
| * After phase 2: | ||
| * Project [a, a + 1] | ||
| * +- Project [child output, age AS a] | ||
| * +- Child | ||
| * | ||
| * Example for Aggregate TODO | ||
| * ** Example for Aggregate: | ||
| * Before rewrite: | ||
| * Aggregate [dept#14] [dept#14 AS a#12, 'a + 1, avg(salary#16) AS b#13, 'b + avg(bonus#17)] | ||
| * +- Child [dept#14,name#15,salary#16,bonus#17] | ||
| * | ||
| * After phase 1: | ||
| * Aggregate [dept#14] [dept#14 AS a#12, lca(a) + 1, avg(salary#16) AS b#13, lca(b) + avg(bonus#17)] | ||
| * +- Child [dept#14,name#15,salary#16,bonus#17] | ||
| * | ||
| * After phase 2: | ||
| * Project [dept#14 AS a#12, lca(a) + 1, avg(salary)#26 AS b#13, lca(b) + avg(bonus)#27] | ||
| * +- Aggregate [dept#14] [avg(salary#16) AS avg(salary)#26, avg(bonus#17) AS avg(bonus)#27,dept#14] | ||
| * +- Child [dept#14,name#15,salary#16,bonus#17] | ||
| * | ||
| * Now the problem falls back to the lateral alias resolution in Project. | ||
| * After future rounds of this rule: | ||
| * Project [a#12, a#12 + 1, b#13, b#13 + avg(bonus)#27] | ||
| * +- Project [dept#14 AS a#12, avg(salary)#26 AS b#13] | ||
| * +- Aggregate [dept#14] [avg(salary#16) AS avg(salary)#26, avg(bonus#17) AS avg(bonus)#27, | ||
| * dept#14] | ||
| * +- Child [dept#14,name#15,salary#16,bonus#17] | ||
| * | ||
| * | ||
| * The name resolution priority: | ||
|
|
@@ -75,6 +102,13 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { | |
| */ | ||
| val NAME_PARTS_FROM_UNRESOLVED_ATTR = TreeNodeTag[Seq[String]]("name_parts_from_unresolved_attr") | ||
|
|
||
| private def assignAlias(expr: Expression): NamedExpression = { | ||
| expr match { | ||
| case ne: NamedExpression => ne | ||
| case e => Alias(e, toPrettySQL(e))() | ||
| } | ||
| } | ||
|
|
||
| override def apply(plan: LogicalPlan): LogicalPlan = { | ||
| if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { | ||
| plan | ||
|
|
@@ -129,6 +163,45 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { | |
| child = Project(innerProjectList.toSeq, child) | ||
| ) | ||
| } | ||
|
|
||
| case agg @ Aggregate(groupingExpressions, aggregateExpressions, _) if agg.resolved | ||
| && aggregateExpressions.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) => | ||
|
|
||
| val newAggExprs = collection.mutable.Set.empty[NamedExpression] | ||
| val expressionMap = collection.mutable.LinkedHashMap.empty[Expression, NamedExpression] | ||
| val projectExprs = aggregateExpressions.map { exp => | ||
| exp.transformDown { | ||
| case aggExpr: AggregateExpression => | ||
| // Doesn't support referencing a lateral alias in aggregate function | ||
| if (aggExpr.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE)) { | ||
| aggExpr.collectFirst { | ||
| case lcaRef: LateralColumnAliasReference => | ||
| throw QueryCompilationErrors.lateralColumnAliasInAggFuncUnsupportedError( | ||
| lcaRef.nameParts, aggExpr) | ||
| } | ||
| } | ||
| val ne = expressionMap.getOrElseUpdate(aggExpr.canonicalized, assignAlias(aggExpr)) | ||
| newAggExprs += ne | ||
| ne.toAttribute | ||
| case e if groupingExpressions.exists(_.semanticEquals(e)) => | ||
| // TODO one concern here, is condition here be able to match all grouping | ||
|
||
| // expressions? For example, Agg [age + 10] [1 + age + 10], when transforming down, | ||
| // is it possible that (1 + age) + 10, so that it won't be able to match (age + 10) | ||
| // add a test. | ||
| val ne = expressionMap.getOrElseUpdate(e.canonicalized, assignAlias(e)) | ||
| newAggExprs += ne | ||
| ne.toAttribute | ||
| }.asInstanceOf[NamedExpression] | ||
| } | ||
| if (newAggExprs.isEmpty) { | ||
| agg | ||
| } else { | ||
| Project( | ||
| projectList = projectExprs, | ||
| child = agg.copy(aggregateExpressions = newAggExprs.toSeq) | ||
| ) | ||
| } | ||
| // TODO withOrigin? | ||
|
||
| } | ||
| } | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.