From 2d7e36c27998c3bc91938fb624aaa8e9079db805 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 13 Dec 2022 13:10:40 -0800 Subject: [PATCH] refactor --- .../sql/catalyst/analysis/Analyzer.scala | 113 +--------------- ....scala => ResolveLateralColumnAlias.scala} | 125 +++++++++++++++++- .../sql/catalyst/rules/RuleIdCollection.scala | 2 +- 3 files changed, 122 insertions(+), 118 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/{ResolveLateralColumnAliasReference.scala => ResolveLateralColumnAlias.scala} (50%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 042342044453..177b08e94914 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.trees.{AlwaysProcess, CurrentOrigin} import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.catalyst.trees.TreePattern._ -import org.apache.spark.sql.catalyst.util.{toPrettySQL, CaseInsensitiveMap, CharVarcharUtils, StringUtils} +import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils, StringUtils} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.connector.catalog.{View => _, _} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -1763,117 +1763,6 @@ class Analyzer(override val catalogManager: CatalogManager) } } - /** - * The first phase to resolve lateral column alias. See comments in - * [[ResolveLateralColumnAliasReference]] for more detailed explanation. - */ - object WrapLateralColumnAliasReference extends Rule[LogicalPlan] { - import ResolveLateralColumnAliasReference.AliasEntry - - private def insertIntoAliasMap( - a: Alias, - idx: Int, - aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): CaseInsensitiveMap[Seq[AliasEntry]] = { - val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) - aliasMap + (a.name -> (prevAliases :+ AliasEntry(a, idx))) - } - - /** - * Use the given lateral alias to resolve the unresolved attribute with the name parts. - * - * Construct a dummy plan with the given lateral alias as project list, use the output of the - * plan to resolve. - * @return The resolved [[LateralColumnAliasReference]] if succeeds. None if fails to resolve. - */ - private def resolveByLateralAlias( - nameParts: Seq[String], lateralAlias: Alias): Option[LateralColumnAliasReference] = { - val resolvedAttr = resolveExpressionByPlanOutput( - expr = UnresolvedAttribute(nameParts), - plan = LocalRelation(Seq(lateralAlias.toAttribute)), - throws = false - ).asInstanceOf[NamedExpression] - if (resolvedAttr.resolved) { - Some(LateralColumnAliasReference(resolvedAttr, nameParts, lateralAlias.toAttribute)) - } else { - None - } - } - - /** - * Recognize all the attributes in the given expression that reference lateral column aliases - * by looking up the alias map. Resolve these attributes and replace by wrapping with - * [[LateralColumnAliasReference]]. - * - * @param currentPlan Because lateral alias has lower resolution priority than table columns, - * the current plan is needed to first try resolving the attribute by its - * children - */ - private def wrapLCARef( - e: NamedExpression, - currentPlan: LogicalPlan, - aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): NamedExpression = { - e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) { - case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && - resolveExpressionByPlanChildren(u, currentPlan).isInstanceOf[UnresolvedAttribute] => - val aliases = aliasMap.get(u.nameParts.head).get - aliases.size match { - case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) - case n if n == 1 && aliases.head.alias.resolved => - // Only resolved alias can be the lateral column alias - // The lateral alias can be a struct and have nested field, need to construct - // a dummy plan to resolve the expression - resolveByLateralAlias(u.nameParts, aliases.head.alias).getOrElse(u) - case _ => u - } - case o: OuterReference - if aliasMap.contains( - o.getTagValue(ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR) - .map(_.head) - .getOrElse(o.name)) => - // handle OuterReference exactly same as UnresolvedAttribute - val nameParts = o - .getTagValue(ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR) - .getOrElse(Seq(o.name)) - val aliases = aliasMap.get(nameParts.head).get - aliases.size match { - case n if n > 1 => - throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n) - case n if n == 1 && aliases.head.alias.resolved => - resolveByLateralAlias(nameParts, aliases.head.alias).getOrElse(o) - case _ => o - } - }.asInstanceOf[NamedExpression] - } - - override def apply(plan: LogicalPlan): LogicalPlan = { - if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { - plan - } else { - plan.resolveOperatorsUpWithPruning( - _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE), ruleId) { - case p @ Project(projectList, _) if p.childrenResolved - && !ResolveReferences.containsStar(projectList) - && projectList.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) => - var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) - val newProjectList = projectList.zipWithIndex.map { - case (a: Alias, idx) => - val lcaWrapped = wrapLCARef(a, p, aliasMap).asInstanceOf[Alias] - // Insert the LCA-resolved alias instead of the unresolved one into map. If it is - // resolved, it can be referenced as LCA by later expressions (chaining). - // Unresolved Alias is also added to the map to perform ambiguous name check, but - // only resolved alias can be LCA. - aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap) - lcaWrapped - case (e, _) => - wrapLCARef(e, p, aliasMap) - } - p.copy(projectList = newProjectList) - } - } - } - } - private def containsDeserializer(exprs: Seq[Expression]): Boolean = { exprs.exists(_.exists(_.isInstanceOf[UnresolvedDeserializer])) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala similarity index 50% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala index 2ca187b95ffd..93859cb86e02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAliasReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveLateralColumnAlias.scala @@ -17,15 +17,17 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, LateralColumnAliasReference, NamedExpression} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeMap, LateralColumnAliasReference, NamedExpression, OuterReference} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag -import org.apache.spark.sql.catalyst.trees.TreePattern.LATERAL_COLUMN_ALIAS_REFERENCE +import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, OUTER_REFERENCE, UNRESOLVED_ATTRIBUTE} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf /** - * This rule is the second phase to resolve lateral column alias. + * The first phase to resolve lateral column alias. * * Resolve lateral column alias, which references the alias defined previously in the SELECT list. * Plan-wise, it handles two types of operators: Project and Aggregate. @@ -65,6 +67,118 @@ import org.apache.spark.sql.internal.SQLConf * [[UnresolvedAttribute]]. If success, it strips [[OuterReference]] and also wraps it with * [[LateralColumnAliasReference]]. */ +object WrapLateralColumnAliasReference extends Rule[LogicalPlan] { + import ResolveLateralColumnAliasReference.AliasEntry + + private def insertIntoAliasMap( + a: Alias, + idx: Int, + aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): CaseInsensitiveMap[Seq[AliasEntry]] = { + val prevAliases = aliasMap.getOrElse(a.name, Seq.empty[AliasEntry]) + aliasMap + (a.name -> (prevAliases :+ AliasEntry(a, idx))) + } + + /** + * Use the given lateral alias to resolve the unresolved attribute with the name parts. + * + * Construct a dummy plan with the given lateral alias as project list, use the output of the + * plan to resolve. + * @return The resolved [[LateralColumnAliasReference]] if succeeds. None if fails to resolve. + */ + private def resolveByLateralAlias( + nameParts: Seq[String], lateralAlias: Alias): Option[LateralColumnAliasReference] = { + val resolvedAttr = SimpleAnalyzer.resolveExpressionByPlanOutput( + expr = UnresolvedAttribute(nameParts), + plan = LocalRelation(Seq(lateralAlias.toAttribute)), + throws = false + ).asInstanceOf[NamedExpression] + if (resolvedAttr.resolved) { + Some(LateralColumnAliasReference(resolvedAttr, nameParts, lateralAlias.toAttribute)) + } else { + None + } + } + + /** + * Recognize all the attributes in the given expression that reference lateral column aliases + * by looking up the alias map. Resolve these attributes and replace by wrapping with + * [[LateralColumnAliasReference]]. + * + * @param currentPlan Because lateral alias has lower resolution priority than table columns, + * the current plan is needed to first try resolving the attribute by its + * children + */ + private def wrapLCARef( + e: NamedExpression, + currentPlan: LogicalPlan, + aliasMap: CaseInsensitiveMap[Seq[AliasEntry]]): NamedExpression = { + e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) { + case u: UnresolvedAttribute if aliasMap.contains(u.nameParts.head) && + SimpleAnalyzer.resolveExpressionByPlanChildren( + u, currentPlan).isInstanceOf[UnresolvedAttribute] => + val aliases = aliasMap.get(u.nameParts.head).get + aliases.size match { + case n if n > 1 => + throw QueryCompilationErrors.ambiguousLateralColumnAlias(u.name, n) + case n if n == 1 && aliases.head.alias.resolved => + // Only resolved alias can be the lateral column alias + // The lateral alias can be a struct and have nested field, need to construct + // a dummy plan to resolve the expression + resolveByLateralAlias(u.nameParts, aliases.head.alias).getOrElse(u) + case _ => u + } + case o: OuterReference + if aliasMap.contains( + o.getTagValue(ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR) + .map(_.head) + .getOrElse(o.name)) => + // handle OuterReference exactly same as UnresolvedAttribute + val nameParts = o + .getTagValue(ResolveLateralColumnAliasReference.NAME_PARTS_FROM_UNRESOLVED_ATTR) + .getOrElse(Seq(o.name)) + val aliases = aliasMap.get(nameParts.head).get + aliases.size match { + case n if n > 1 => + throw QueryCompilationErrors.ambiguousLateralColumnAlias(nameParts, n) + case n if n == 1 && aliases.head.alias.resolved => + resolveByLateralAlias(nameParts, aliases.head.alias).getOrElse(o) + case _ => o + } + }.asInstanceOf[NamedExpression] + } + + override def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { + plan + } else { + plan.resolveOperatorsUpWithPruning( + _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE), ruleId) { + case p @ Project(projectList, _) if p.childrenResolved + && !SimpleAnalyzer.ResolveReferences.containsStar(projectList) + && projectList.exists(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, OUTER_REFERENCE)) => + var aliasMap = CaseInsensitiveMap(Map[String, Seq[AliasEntry]]()) + val newProjectList = projectList.zipWithIndex.map { + case (a: Alias, idx) => + val lcaWrapped = wrapLCARef(a, p, aliasMap).asInstanceOf[Alias] + // Insert the LCA-resolved alias instead of the unresolved one into map. If it is + // resolved, it can be referenced as LCA by later expressions (chaining). + // Unresolved Alias is also added to the map to perform ambiguous name check, but + // only resolved alias can be LCA. + aliasMap = insertIntoAliasMap(lcaWrapped, idx, aliasMap) + lcaWrapped + case (e, _) => + wrapLCARef(e, p, aliasMap) + } + p.copy(projectList = newProjectList) + } + } + } +} + +/** + * This rule is the second phase to resolve lateral column alias. + * See comments in [[WrapLateralColumnAliasReference]] for more detailed explanation. + */ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { case class AliasEntry(alias: Alias, index: Int) @@ -73,7 +187,8 @@ object ResolveLateralColumnAliasReference extends Rule[LogicalPlan] { * It is set for [[OuterReference]], used in the current rule to convert [[OuterReference]] back * to [[LateralColumnAliasReference]]. */ - val NAME_PARTS_FROM_UNRESOLVED_ATTR = TreeNodeTag[Seq[String]]("name_parts_from_unresolved_attr") + val NAME_PARTS_FROM_UNRESOLVED_ATTR: TreeNodeTag[Seq[String]] = + TreeNodeTag[Seq[String]]("name_parts_from_unresolved_attr") override def apply(plan: LogicalPlan): LogicalPlan = { if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala index efafd3cfbcde..3b44a9cf30cb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala @@ -77,7 +77,6 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowFrame" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowOrder" :: "org.apache.spark.sql.catalyst.analysis.Analyzer$WindowsSubstitution" :: - "org.apache.spark.sql.catalyst.analysis.Analyzer$WrapLateralColumnAliasReference" :: "org.apache.spark.sql.catalyst.analysis.AnsiTypeCoercion$AnsiCombinedTypeCoercionRule" :: "org.apache.spark.sql.catalyst.analysis.ApplyCharTypePadding" :: "org.apache.spark.sql.catalyst.analysis.DeduplicateRelations" :: @@ -97,6 +96,7 @@ object RuleIdCollection { "org.apache.spark.sql.catalyst.analysis.TypeCoercionBase$CombinedTypeCoercionRule" :: "org.apache.spark.sql.catalyst.analysis.UpdateOuterReferences" :: "org.apache.spark.sql.catalyst.analysis.UpdateAttributeNullability" :: + "org.apache.spark.sql.catalyst.analysis.WrapLateralColumnAliasReference" :: // Catalyst Optimizer rules "org.apache.spark.sql.catalyst.optimizer.BooleanSimplification" :: "org.apache.spark.sql.catalyst.optimizer.CollapseProject" ::