From 545407353964e78ee4e0a6f13760685727edeb59 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 11 Jan 2023 19:29:25 +0800 Subject: [PATCH 1/6] Centralize more column resolution rules --- .../sql/catalyst/analysis/Analyzer.scala | 458 ++---------------- .../sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../analysis/ColumnResolutionHelper.scala | 365 ++++++++++++++ .../catalyst/analysis/ResolveGroupByAll.scala | 119 ----- .../catalyst/analysis/ResolveOrderByAll.scala | 81 ---- .../ResolveReferencesInAggregate.scala | 202 ++++++++ .../analysis/ResolveReferencesInSort.scala | 81 ++++ .../resources/sql-tests/inputs/group-by.sql | 3 + .../sql-tests/results/group-by.sql.out | 9 + 9 files changed, 691 insertions(+), 629 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveOrderByAll.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ce273f01c7aa..7330d7de7423 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -185,8 +185,8 @@ object AnalysisContext { * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and * [[UnresolvedRelation]]s into fully typed objects using information in a [[SessionCatalog]]. */ -class Analyzer(override val catalogManager: CatalogManager) - extends RuleExecutor[LogicalPlan] with CheckAnalysis with SQLConfHelper { +class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor[LogicalPlan] + with CheckAnalysis with SQLConfHelper with ColumnResolutionHelper { private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog @@ -295,10 +295,7 @@ class Analyzer(override val catalogManager: CatalogManager) ResolveGroupingAnalytics :: ResolvePivot :: ResolveUnpivot :: - ResolveOrderByAll :: - ResolveGroupByAll :: ResolveOrdinalInOrderByAndGroupBy :: - ResolveAggAliasInGroupBy :: ExtractGenerator :: ResolveGenerate :: ResolveFunctions :: @@ -1489,25 +1486,34 @@ class Analyzer(override val catalogManager: CatalogManager) } /** - * Resolves [[UnresolvedAttribute]]s with the following precedence: - * 1. Resolves it to [[AttributeReference]] with the output of the children plans. This includes - * metadata columns as well. - * 2. If the plan is Project/Aggregate, resolves it to lateral column alias, which is the alias - * defined previously in the SELECT list. - * 3. If the plan is UnresolvedHaving/Filter/Sort + Aggregate, resolves it to - * [[TempResolvedColumn]] with the output of Aggregate's child plan. This is to allow - * UnresolvedHaving/Filter/Sort to host grouping expressions and aggregate functions, which - * can be pushed down to the Aggregate later. - * 4. If the plan is Sort/Filter/RepartitionByExpression, resolves it to [[AttributeReference]] - * with the output of a descendant plan node. Spark will propagate the missing attributes from - * the descendant plan node to the Sort/Filter/RepartitionByExpression node. This is to allow - * users to filter/order/repartition by columns that are not in the SELECT clause, which is - * widely supported in other SQL dialects. - * 5. Resolves it to [[OuterReference]] with the outer plan if this is a subquery plan. + * Resolves column references in the query plan. Basically it transform the query plan tree bottom + * up, and only try to resolve references for a plan node if all its children nodes are resolved, + * and there is no conflicting attributes between the children nodes (see `hasConflictingAttrs` + * for details). + * + * The general workflow to resolve references: + * 1. Expands the star in Project/Aggregate/Generate. + * 2. Resolves the column to [[AttributeReference]] with the output of the children plans. This + * includes metadata columns as well. + * 3. Resolves the column to a literal function which is allowed to be invoked without braces, + * e.g. `SELECT col, current_date FROM t`. + * 4. Resolves the column to outer references with the outer plan if we are resolving subquery + * expressions. + * + * Some plan nodes have special column reference resolution logic, please read these sub-rules for + * details: + * - [[ResolveReferencesInAggregate]] + * - [[ResolveReferencesInSort]] */ - object ResolveReferences extends Rule[LogicalPlan] { + object ResolveReferences extends Rule[LogicalPlan] with ColumnResolutionHelper { - /** Return true if there're conflicting attributes among children's outputs of a plan */ + /** + * Return true if there're conflicting attributes among children's outputs of a plan + * + * The children logical plans may output columns with conflicting attribute IDs. This may happen + * in cases such as self-join. We should wait for the rule [[DeduplicateRelations]] to eliminate + * conflicting attribute IDs, otherwise we can't resolve columns correctly due to ambiguity. + */ def hasConflictingAttrs(p: LogicalPlan): Boolean = { p.children.length > 1 && { // Note that duplicated attributes are allowed within a single node, @@ -1628,31 +1634,7 @@ class Analyzer(override val catalogManager: CatalogManager) // rule: ResolveDeserializer. case plan if containsDeserializer(plan.expressions) => plan - // SPARK-31670: Resolve Struct field in groupByExpressions and aggregateExpressions - // with CUBE/ROLLUP will be wrapped with alias like Alias(GetStructField, name) with - // different ExprId. This cause aggregateExpressions can't be replaced by expanded - // groupByExpressions in `ResolveGroupingAnalytics.constructAggregateExprs()`, we trim - // unnecessary alias of GetStructField here. - case a: Aggregate => - val planForResolve = a.child match { - // SPARK-25942: Resolves aggregate expressions with `AppendColumns`'s children, instead of - // `AppendColumns`, because `AppendColumns`'s serializer might produce conflict attribute - // names leading to ambiguous references exception. - case appendColumns: AppendColumns => appendColumns - case _ => a - } - - val resolvedGroupingExprs = a.groupingExpressions - .map(resolveExpressionByPlanChildren(_, planForResolve, allowOuter = true)) - .map(trimTopLevelGetStructFieldAlias) - - val resolvedAggExprsNoOuter = a.aggregateExpressions - .map(resolveExpressionByPlanChildren(_, planForResolve, allowOuter = false)) - // Aggregate supports Lateral column alias, which has higher priority than outer reference. - val resolvedAggExprsWithLCA = resolveLateralColumnAlias(resolvedAggExprsNoOuter) - val resolvedAggExprsWithOuter = resolvedAggExprsWithLCA.map(resolveOuterRef) - .map(_.asInstanceOf[NamedExpression]) - a.copy(resolvedGroupingExprs, resolvedAggExprsWithOuter, a.child) + case a: Aggregate => ResolveReferencesInAggregate(a) // Special case for Project as it supports lateral column alias. case p: Project => @@ -1790,82 +1772,13 @@ class Analyzer(override val catalogManager: CatalogManager) Project(child.output, newFilter) } - // Same as Filter, Sort can host both grouping expressions/aggregate functions and missing - // attributes as well. - case s @ Sort(orders, _, child) if !s.resolved || s.missingInput.nonEmpty => - val resolvedNoOuter = orders.map(resolveExpressionByPlanOutput(_, child)) - val resolvedWithAgg = resolvedNoOuter.map(resolveColWithAgg(_, child)) - val (newOrder, newChild) = resolveExprsAndAddMissingAttrs(resolvedWithAgg, child) - // Outer reference has lowermost priority. See the doc of `ResolveReferences`. - val ordering = newOrder.map(e => resolveOuterRef(e).asInstanceOf[SortOrder]) - if (child.output == newChild.output) { - s.copy(order = ordering) - } else { - // Add missing attributes and then project them away. - val newSort = s.copy(order = ordering, child = newChild) - Project(child.output, newSort) - } + case s: Sort if !s.resolved || s.missingInput.nonEmpty => ResolveReferencesInSort(s) case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString(conf.maxToStringFields)}") q.mapExpressions(resolveExpressionByPlanChildren(_, q, allowOuter = true)) } - /** - * This method tries to resolve expressions and find missing attributes recursively. - * Specifically, when the expressions used in `Sort` or `Filter` contain unresolved attributes - * or resolved attributes which are missing from child output. This method tries to find the - * missing attributes and add them into the projection. - */ - private def resolveExprsAndAddMissingAttrs( - exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = { - // Missing attributes can be unresolved attributes or resolved attributes which are not in - // the output attributes of the plan. - if (exprs.forall(e => e.resolved && e.references.subsetOf(plan.outputSet))) { - (exprs, plan) - } else { - plan match { - case p: Project => - // Resolving expressions against current plan. - val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, p)) - // Recursively resolving expressions on the child of current plan. - val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child) - // If some attributes used by expressions are resolvable only on the rewritten child - // plan, we need to add them into original projection. - val missingAttrs = (AttributeSet(newExprs) -- p.outputSet).intersect(newChild.outputSet) - (newExprs, Project(p.projectList ++ missingAttrs, newChild)) - - case a @ Aggregate(groupExprs, aggExprs, child) => - val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, a)) - val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child) - val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet) - if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) { - // All the missing attributes are grouping expressions, valid case. - (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild)) - } else { - // Need to add non-grouping attributes, invalid case. - (exprs, a) - } - - case g: Generate => - val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, g)) - val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, g.child) - (newExprs, g.copy(unrequiredChildIndex = Nil, child = newChild)) - - // For `Distinct` and `SubqueryAlias`, we can't recursively resolve and add attributes - // via its children. - case u: UnaryNode if !u.isInstanceOf[Distinct] && !u.isInstanceOf[SubqueryAlias] => - val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, u)) - val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, u.child) - (newExprs, u.withNewChildren(Seq(newChild))) - - // For other operators, we can't recursively resolve and add attributes via its children. - case other => - (exprs.map(resolveExpressionByPlanOutput(_, other)), other) - } - } - } - private object MergeResolvePolicy extends Enumeration { val BOTH, SOURCE, TARGET = Value } @@ -1916,16 +1829,6 @@ class Analyzer(override val catalogManager: CatalogManager) resolved } - // This method is used to trim groupByExpressions/selectedGroupByExpressions's top-level - // GetStructField Alias. Since these expression are not NamedExpression originally, - // we are safe to trim top-level GetStructField Alias. - def trimTopLevelGetStructFieldAlias(e: Expression): Expression = { - e match { - case Alias(s: GetStructField, _) => s - case other => other - } - } - // Expand the star expression using the input plan first. If failed, try resolve // the star expression using the outer query plan and wrap the resolved attributes // in outer references. Otherwise throw the original exception. @@ -2037,277 +1940,6 @@ class Analyzer(override val catalogManager: CatalogManager) exprs.exists(_.exists(_.isInstanceOf[UnresolvedDeserializer])) } - // support CURRENT_DATE, CURRENT_TIMESTAMP, and grouping__id - private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq( - (CurrentDate().prettyName, () => CurrentDate(), toPrettySQL(_)), - (CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL(_)), - (CurrentUser().prettyName, () => CurrentUser(), toPrettySQL), - ("user", () => CurrentUser(), toPrettySQL), - (VirtualColumn.hiveGroupingIdName, () => GroupingID(Nil), _ => VirtualColumn.hiveGroupingIdName) - ) - - /** - * Literal functions do not require the user to specify braces when calling them - * When an attributes is not resolvable, we try to resolve it as a literal function. - */ - private def resolveLiteralFunction(nameParts: Seq[String]): Option[NamedExpression] = { - if (nameParts.length != 1) return None - val name = nameParts.head - literalFunctions.find(func => caseInsensitiveResolution(func._1, name)).map { - case (_, getFuncExpr, getAliasName) => - val funcExpr = getFuncExpr() - Alias(funcExpr, getAliasName(funcExpr))() - } - } - - /** - * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by - * traversing the input expression in top-down manner. It must be top-down because we need to - * skip over unbound lambda function expression. The lambda expressions are resolved in a - * different place [[ResolveLambdaVariables]]. - * - * Example : - * SELECT transform(array(1, 2, 3), (x, i) -> x + i)" - * - * In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]]. - */ - private def resolveExpression( - expr: Expression, - resolveColumnByName: Seq[String] => Option[Expression], - getAttrCandidates: () => Seq[Attribute], - throws: Boolean, - allowOuter: Boolean): Expression = { - def innerResolve(e: Expression, isTopLevel: Boolean): Expression = withOrigin(e.origin) { - if (e.resolved) return e - val resolved = e match { - case f: LambdaFunction if !f.bound => f - - case GetColumnByOrdinal(ordinal, _) => - val attrCandidates = getAttrCandidates() - assert(ordinal >= 0 && ordinal < attrCandidates.length) - attrCandidates(ordinal) - - case GetViewColumnByNameAndOrdinal( - viewName, colName, ordinal, expectedNumCandidates, viewDDL) => - val attrCandidates = getAttrCandidates() - val matched = attrCandidates.filter(a => resolver(a.name, colName)) - if (matched.length != expectedNumCandidates) { - throw QueryCompilationErrors.incompatibleViewSchemaChangeError( - viewName, colName, expectedNumCandidates, matched, viewDDL) - } - matched(ordinal) - - case u @ UnresolvedAttribute(nameParts) => - val result = withPosition(u) { - resolveColumnByName(nameParts).orElse(resolveLiteralFunction(nameParts)).map { - // We trim unnecessary alias here. Note that, we cannot trim the alias at top-level, - // as we should resolve `UnresolvedAttribute` to a named expression. The caller side - // can trim the top-level alias if it's safe to do so. Since we will call - // CleanupAliases later in Analyzer, trim non top-level unnecessary alias is safe. - case Alias(child, _) if !isTopLevel => child - case other => other - }.getOrElse(u) - } - logDebug(s"Resolving $u to $result") - result - - // Re-resolves `TempResolvedColumn` if it has tried to be resolved with Aggregate - // but failed. If we still can't resolve it, we should keep it as `TempResolvedColumn`, - // so that it won't become a fresh `TempResolvedColumn` again. - case t: TempResolvedColumn if t.hasTried => withPosition(t) { - innerResolve(UnresolvedAttribute(t.nameParts), isTopLevel) match { - case _: UnresolvedAttribute => t - case other => other - } - } - - case u @ UnresolvedExtractValue(child, fieldName) => - val newChild = innerResolve(child, isTopLevel = false) - if (newChild.resolved) { - ExtractValue(newChild, fieldName, resolver) - } else { - u.copy(child = newChild) - } - - case _ => e.mapChildren(innerResolve(_, isTopLevel = false)) - } - resolved.copyTagsFrom(e) - resolved - } - - try { - val resolved = innerResolve(expr, isTopLevel = true) - if (allowOuter) resolveOuterRef(resolved) else resolved - } catch { - case ae: AnalysisException if !throws => - logDebug(ae.getMessage) - expr - } - } - - // Resolves `UnresolvedAttribute` to `OuterReference`. - private def resolveOuterRef(e: Expression): Expression = { - val outerPlan = AnalysisContext.get.outerPlan - if (outerPlan.isEmpty) return e - - def resolve(nameParts: Seq[String]): Option[Expression] = try { - outerPlan.get match { - // Subqueries in UnresolvedHaving can host grouping expressions and aggregate functions. - // We should resolve columns with `agg.output` and the rule `ResolveAggregateFunctions` will - // push them down to Aggregate later. This is similar to what we do in `resolveColumns`. - case u @ UnresolvedHaving(_, agg: Aggregate) => - agg.resolveChildren(nameParts, resolver).orElse(u.resolveChildren(nameParts, resolver)) - .map(wrapOuterReference) - case other => - other.resolveChildren(nameParts, resolver).map(wrapOuterReference) - } - } catch { - case ae: AnalysisException => - logDebug(ae.getMessage) - None - } - - e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) { - case u: UnresolvedAttribute => - resolve(u.nameParts).getOrElse(u) - // Re-resolves `TempResolvedColumn` as outer references if it has tried to be resolved with - // Aggregate but failed. - case t: TempResolvedColumn if t.hasTried => - resolve(t.nameParts).getOrElse(t) - } - } - - // Resolves `UnresolvedAttribute` to `TempResolvedColumn` via `plan.child.output` if plan is an - // `Aggregate`. If `TempResolvedColumn` doesn't end up as aggregate function input or grouping - // column, we will undo the column resolution later to avoid confusing error message. E,g,, if - // a table `t` has columns `c1` and `c2`, for query `SELECT ... FROM t GROUP BY c1 HAVING c2 = 0`, - // even though we can resolve column `c2` here, we should undo it and fail with - // "Column c2 not found". - private def resolveColWithAgg(e: Expression, plan: LogicalPlan): Expression = plan match { - case agg: Aggregate => - e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE)) { - case u: UnresolvedAttribute => - try { - agg.child.resolve(u.nameParts, resolver).map({ - case a: Alias => TempResolvedColumn(a.child, u.nameParts) - case o => TempResolvedColumn(o, u.nameParts) - }).getOrElse(u) - } catch { - case ae: AnalysisException => - logDebug(ae.getMessage) - u - } - } - case _ => e - } - - private def resolveLateralColumnAlias(selectList: Seq[Expression]): Seq[Expression] = { - if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) return selectList - - // A mapping from lower-cased alias name to either the Alias itself, or the count of aliases - // that have the same lower-cased name. If the count is larger than 1, we won't use it to - // resolve lateral column aliases. - val aliasMap = mutable.HashMap.empty[String, Either[Alias, Int]] - - def resolve(e: Expression): Expression = { - e.transformWithPruning( - _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, LATERAL_COLUMN_ALIAS_REFERENCE)) { - case u: UnresolvedAttribute => - // Lateral column alias does not have qualifiers. We always use the first name part to - // look up lateral column aliases. - val lowerCasedName = u.nameParts.head.toLowerCase(Locale.ROOT) - aliasMap.get(lowerCasedName).map { - case scala.util.Left(alias) => - if (alias.resolved) { - val resolvedAttr = resolveExpressionByPlanOutput( - u, LocalRelation(Seq(alias.toAttribute)), throws = true - ).asInstanceOf[NamedExpression] - assert(resolvedAttr.resolved) - LateralColumnAliasReference(resolvedAttr, u.nameParts, alias.toAttribute) - } else { - // Still returns a `LateralColumnAliasReference` even if the lateral column alias - // is not resolved yet. This is to make sure we won't mistakenly resolve it to - // outer references. - LateralColumnAliasReference(u, u.nameParts, alias.toAttribute) - } - case scala.util.Right(count) => - throw QueryCompilationErrors.ambiguousLateralColumnAliasError(u.name, count) - }.getOrElse(u) - - case LateralColumnAliasReference(u: UnresolvedAttribute, _, _) => - resolve(u) - } - } - - selectList.map { - case a: Alias => - val result = resolve(a) - val lowerCasedName = a.name.toLowerCase(Locale.ROOT) - aliasMap.get(lowerCasedName) match { - case Some(scala.util.Left(_)) => - aliasMap(lowerCasedName) = scala.util.Right(2) - case Some(scala.util.Right(count)) => - aliasMap(lowerCasedName) = scala.util.Right(count + 1) - case None => - aliasMap += lowerCasedName -> scala.util.Left(a) - } - result - case other => resolve(other) - } - } - - /** - * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the - * input plan's output attributes. In order to resolve the nested fields correctly, this function - * makes use of `throws` parameter to control when to raise an AnalysisException. - * - * Example : - * SELECT * FROM t ORDER BY a.b - * - * In the above example, after `a` is resolved to a struct-type column, we may fail to resolve `b` - * if there is no such nested field named "b". We should not fail and wait for other rules to - * resolve it if possible. - */ - def resolveExpressionByPlanOutput( - expr: Expression, - plan: LogicalPlan, - throws: Boolean = false, - allowOuter: Boolean = false): Expression = { - resolveExpression( - expr, - resolveColumnByName = nameParts => { - plan.resolve(nameParts, resolver) - }, - getAttrCandidates = () => plan.output, - throws = throws, - allowOuter = allowOuter) - } - - /** - * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the - * input plan's children output attributes. - * - * @param e The expression need to be resolved. - * @param q The LogicalPlan whose children are used to resolve expression's attribute. - * @return resolved Expression. - */ - def resolveExpressionByPlanChildren( - e: Expression, - q: LogicalPlan, - allowOuter: Boolean = false): Expression = { - resolveExpression( - e, - resolveColumnByName = nameParts => { - q.resolveChildren(nameParts, resolver) - }, - getAttrCandidates = () => { - assert(q.children.length == 1) - q.children.head.output - }, - throws = true, - allowOuter = allowOuter) - } - /** * In many dialects of SQL it is valid to use ordinal positions in order/sort by and group by * clauses. This rule is to convert ordinal positions to the corresponding expressions in the @@ -2377,36 +2009,6 @@ class Analyzer(override val catalogManager: CatalogManager) } } - /** - * Replace unresolved expressions in grouping keys with resolved ones in SELECT clauses. - * This rule is expected to run after [[ResolveReferences]] applied. - */ - object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] { - - // This is a strict check though, we put this to apply the rule only if the expression is not - // resolvable by child. - private def notResolvableByChild(attrName: String, child: LogicalPlan): Boolean = { - !child.output.exists(a => resolver(a.name, attrName)) - } - - private def mayResolveAttrByAggregateExprs( - exprs: Seq[Expression], aggs: Seq[NamedExpression], child: LogicalPlan): Seq[Expression] = { - exprs.map { _.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { - case u: UnresolvedAttribute if notResolvableByChild(u.name, child) => - aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u) - }} - } - - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( - // mayResolveAttrByAggregateExprs requires the TreePattern UNRESOLVED_ATTRIBUTE. - _.containsAllPatterns(AGGREGATE, UNRESOLVED_ATTRIBUTE), ruleId) { - case agg @ Aggregate(groups, aggs, child) - if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && - groups.exists(!_.resolved) => - agg.copy(groupingExpressions = mayResolveAttrByAggregateExprs(groups, aggs, child)) - } - } - /** * Checks whether a function identifier referenced by an [[UnresolvedFunction]] is defined in the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index bc7b031a7382..c020cf727b08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -238,7 +238,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB // Fail if we still have an unresolved all in group by. This needs to run before the // general unresolved check below to throw a more tailored error message. - ResolveGroupByAll.checkAnalysis(operator) + ResolveReferencesInAggregate.checkUnresolvedGroupByAll(operator) getAllExpressions(operator).foreach(_.foreachUp { case a: Attribute if !a.resolved => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala new file mode 100644 index 000000000000..d6026133f851 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -0,0 +1,365 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import java.util.Locale + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.SubExprUtils.wrapOuterReference +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin +import org.apache.spark.sql.catalyst.trees.TreePattern._ +import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.SQLConf + +trait ColumnResolutionHelper extends Logging { + + def conf: SQLConf + + /** + * This method tries to resolve expressions and find missing attributes recursively. + * Specifically, when the expressions used in `Sort` or `Filter` contain unresolved attributes + * or resolved attributes which are missing from child output. This method tries to find the + * missing attributes and add them into the projection. + */ + protected def resolveExprsAndAddMissingAttrs( + exprs: Seq[Expression], plan: LogicalPlan): (Seq[Expression], LogicalPlan) = { + // Missing attributes can be unresolved attributes or resolved attributes which are not in + // the output attributes of the plan. + if (exprs.forall(e => e.resolved && e.references.subsetOf(plan.outputSet))) { + (exprs, plan) + } else { + plan match { + case p: Project => + // Resolving expressions against current plan. + val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, p)) + // Recursively resolving expressions on the child of current plan. + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, p.child) + // If some attributes used by expressions are resolvable only on the rewritten child + // plan, we need to add them into original projection. + val missingAttrs = (AttributeSet(newExprs) -- p.outputSet).intersect(newChild.outputSet) + (newExprs, Project(p.projectList ++ missingAttrs, newChild)) + + case a @ Aggregate(groupExprs, aggExprs, child) => + val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, a)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, child) + val missingAttrs = (AttributeSet(newExprs) -- a.outputSet).intersect(newChild.outputSet) + if (missingAttrs.forall(attr => groupExprs.exists(_.semanticEquals(attr)))) { + // All the missing attributes are grouping expressions, valid case. + (newExprs, a.copy(aggregateExpressions = aggExprs ++ missingAttrs, child = newChild)) + } else { + // Need to add non-grouping attributes, invalid case. + (exprs, a) + } + + case g: Generate => + val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, g)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, g.child) + (newExprs, g.copy(unrequiredChildIndex = Nil, child = newChild)) + + // For `Distinct` and `SubqueryAlias`, we can't recursively resolve and add attributes + // via its children. + case u: UnaryNode if !u.isInstanceOf[Distinct] && !u.isInstanceOf[SubqueryAlias] => + val maybeResolvedExprs = exprs.map(resolveExpressionByPlanOutput(_, u)) + val (newExprs, newChild) = resolveExprsAndAddMissingAttrs(maybeResolvedExprs, u.child) + (newExprs, u.withNewChildren(Seq(newChild))) + + // For other operators, we can't recursively resolve and add attributes via its children. + case other => + (exprs.map(resolveExpressionByPlanOutput(_, other)), other) + } + } + } + + // support CURRENT_DATE, CURRENT_TIMESTAMP, and grouping__id + private val literalFunctions: Seq[(String, () => Expression, Expression => String)] = Seq( + (CurrentDate().prettyName, () => CurrentDate(), toPrettySQL(_)), + (CurrentTimestamp().prettyName, () => CurrentTimestamp(), toPrettySQL(_)), + (CurrentUser().prettyName, () => CurrentUser(), toPrettySQL), + ("user", () => CurrentUser(), toPrettySQL), + (VirtualColumn.hiveGroupingIdName, () => GroupingID(Nil), _ => VirtualColumn.hiveGroupingIdName) + ) + + /** + * Literal functions do not require the user to specify braces when calling them + * When an attributes is not resolvable, we try to resolve it as a literal function. + */ + private def resolveLiteralFunction(nameParts: Seq[String]): Option[NamedExpression] = { + if (nameParts.length != 1) return None + val name = nameParts.head + literalFunctions.find(func => caseInsensitiveResolution(func._1, name)).map { + case (_, getFuncExpr, getAliasName) => + val funcExpr = getFuncExpr() + Alias(funcExpr, getAliasName(funcExpr))() + } + } + + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by + * traversing the input expression in top-down manner. It must be top-down because we need to + * skip over unbound lambda function expression. The lambda expressions are resolved in a + * different place [[ResolveLambdaVariables]]. + * + * Example : + * SELECT transform(array(1, 2, 3), (x, i) -> x + i)" + * + * In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]]. + */ + private def resolveExpression( + expr: Expression, + resolveColumnByName: Seq[String] => Option[Expression], + getAttrCandidates: () => Seq[Attribute], + throws: Boolean, + allowOuter: Boolean): Expression = { + def innerResolve(e: Expression, isTopLevel: Boolean): Expression = withOrigin(e.origin) { + if (e.resolved) return e + val resolved = e match { + case f: LambdaFunction if !f.bound => f + + case GetColumnByOrdinal(ordinal, _) => + val attrCandidates = getAttrCandidates() + assert(ordinal >= 0 && ordinal < attrCandidates.length) + attrCandidates(ordinal) + + case GetViewColumnByNameAndOrdinal( + viewName, colName, ordinal, expectedNumCandidates, viewDDL) => + val attrCandidates = getAttrCandidates() + val matched = attrCandidates.filter(a => conf.resolver(a.name, colName)) + if (matched.length != expectedNumCandidates) { + throw QueryCompilationErrors.incompatibleViewSchemaChange( + viewName, colName, expectedNumCandidates, matched, viewDDL) + } + matched(ordinal) + + case u @ UnresolvedAttribute(nameParts) => + val result = withPosition(u) { + resolveColumnByName(nameParts).orElse(resolveLiteralFunction(nameParts)).map { + // We trim unnecessary alias here. Note that, we cannot trim the alias at top-level, + // as we should resolve `UnresolvedAttribute` to a named expression. The caller side + // can trim the top-level alias if it's safe to do so. Since we will call + // CleanupAliases later in Analyzer, trim non top-level unnecessary alias is safe. + case Alias(child, _) if !isTopLevel => child + case other => other + }.getOrElse(u) + } + logDebug(s"Resolving $u to $result") + result + + // Re-resolves `TempResolvedColumn` if it has tried to be resolved with Aggregate + // but failed. If we still can't resolve it, we should keep it as `TempResolvedColumn`, + // so that it won't become a fresh `TempResolvedColumn` again. + case t: TempResolvedColumn if t.hasTried => withPosition(t) { + innerResolve(UnresolvedAttribute(t.nameParts), isTopLevel) match { + case _: UnresolvedAttribute => t + case other => other + } + } + + case u @ UnresolvedExtractValue(child, fieldName) => + val newChild = innerResolve(child, isTopLevel = false) + if (newChild.resolved) { + ExtractValue(newChild, fieldName, conf.resolver) + } else { + u.copy(child = newChild) + } + + case _ => e.mapChildren(innerResolve(_, isTopLevel = false)) + } + resolved.copyTagsFrom(e) + resolved + } + + try { + val resolved = innerResolve(expr, isTopLevel = true) + if (allowOuter) resolveOuterRef(resolved) else resolved + } catch { + case ae: AnalysisException if !throws => + logDebug(ae.getMessage) + expr + } + } + + // Resolves `UnresolvedAttribute` to `OuterReference`. + protected def resolveOuterRef(e: Expression): Expression = { + val outerPlan = AnalysisContext.get.outerPlan + if (outerPlan.isEmpty) return e + + def resolve(nameParts: Seq[String]): Option[Expression] = try { + outerPlan.get match { + // Subqueries in UnresolvedHaving can host grouping expressions and aggregate functions. + // We should resolve columns with `agg.output` and the rule `ResolveAggregateFunctions` will + // push them down to Aggregate later. This is similar to what we do in `resolveColumns`. + case u @ UnresolvedHaving(_, agg: Aggregate) => + agg.resolveChildren(nameParts, conf.resolver) + .orElse(u.resolveChildren(nameParts, conf.resolver)) + .map(wrapOuterReference) + case other => + other.resolveChildren(nameParts, conf.resolver).map(wrapOuterReference) + } + } catch { + case ae: AnalysisException => + logDebug(ae.getMessage) + None + } + + e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE, TEMP_RESOLVED_COLUMN)) { + case u: UnresolvedAttribute => + resolve(u.nameParts).getOrElse(u) + // Re-resolves `TempResolvedColumn` as outer references if it has tried to be resolved with + // Aggregate but failed. + case t: TempResolvedColumn if t.hasTried => + resolve(t.nameParts).getOrElse(t) + } + } + + // Resolves `UnresolvedAttribute` to `TempResolvedColumn` via `plan.child.output` if plan is an + // `Aggregate`. If `TempResolvedColumn` doesn't end up as aggregate function input or grouping + // column, we will undo the column resolution later to avoid confusing error message. E,g,, if + // a table `t` has columns `c1` and `c2`, for query `SELECT ... FROM t GROUP BY c1 HAVING c2 = 0`, + // even though we can resolve column `c2` here, we should undo it and fail with + // "Column c2 not found". + protected def resolveColWithAgg(e: Expression, plan: LogicalPlan): Expression = plan match { + case agg: Aggregate => + e.transformWithPruning(_.containsAnyPattern(UNRESOLVED_ATTRIBUTE)) { + case u: UnresolvedAttribute => + try { + agg.child.resolve(u.nameParts, conf.resolver).map({ + case a: Alias => TempResolvedColumn(a.child, u.nameParts) + case o => TempResolvedColumn(o, u.nameParts) + }).getOrElse(u) + } catch { + case ae: AnalysisException => + logDebug(ae.getMessage) + u + } + } + case _ => e + } + + protected def resolveLateralColumnAlias(selectList: Seq[Expression]): Seq[Expression] = { + if (!conf.getConf(SQLConf.LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED)) return selectList + + // A mapping from lower-cased alias name to either the Alias itself, or the count of aliases + // that have the same lower-cased name. If the count is larger than 1, we won't use it to + // resolve lateral column aliases. + val aliasMap = mutable.HashMap.empty[String, Either[Alias, Int]] + + def resolve(e: Expression): Expression = { + e.transformWithPruning( + _.containsAnyPattern(UNRESOLVED_ATTRIBUTE, LATERAL_COLUMN_ALIAS_REFERENCE)) { + case u: UnresolvedAttribute => + // Lateral column alias does not have qualifiers. We always use the first name part to + // look up lateral column aliases. + val lowerCasedName = u.nameParts.head.toLowerCase(Locale.ROOT) + aliasMap.get(lowerCasedName).map { + case scala.util.Left(alias) => + if (alias.resolved) { + val resolvedAttr = resolveExpressionByPlanOutput( + u, LocalRelation(Seq(alias.toAttribute)), throws = true + ).asInstanceOf[NamedExpression] + assert(resolvedAttr.resolved) + LateralColumnAliasReference(resolvedAttr, u.nameParts, alias.toAttribute) + } else { + // Still returns a `LateralColumnAliasReference` even if the lateral column alias + // is not resolved yet. This is to make sure we won't mistakenly resolve it to + // outer references. + LateralColumnAliasReference(u, u.nameParts, alias.toAttribute) + } + case scala.util.Right(count) => + throw QueryCompilationErrors.ambiguousLateralColumnAliasError(u.name, count) + }.getOrElse(u) + + case LateralColumnAliasReference(u: UnresolvedAttribute, _, _) => + resolve(u) + } + } + + selectList.map { + case a: Alias => + val result = resolve(a) + val lowerCasedName = a.name.toLowerCase(Locale.ROOT) + aliasMap.get(lowerCasedName) match { + case Some(scala.util.Left(_)) => + aliasMap(lowerCasedName) = scala.util.Right(2) + case Some(scala.util.Right(count)) => + aliasMap(lowerCasedName) = scala.util.Right(count + 1) + case None => + aliasMap += lowerCasedName -> scala.util.Left(a) + } + result + case other => resolve(other) + } + } + + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the + * input plan's output attributes. In order to resolve the nested fields correctly, this function + * makes use of `throws` parameter to control when to raise an AnalysisException. + * + * Example : + * SELECT * FROM t ORDER BY a.b + * + * In the above example, after `a` is resolved to a struct-type column, we may fail to resolve `b` + * if there is no such nested field named "b". We should not fail and wait for other rules to + * resolve it if possible. + */ + def resolveExpressionByPlanOutput( + expr: Expression, + plan: LogicalPlan, + throws: Boolean = false, + allowOuter: Boolean = false): Expression = { + resolveExpression( + expr, + resolveColumnByName = nameParts => { + plan.resolve(nameParts, conf.resolver) + }, + getAttrCandidates = () => plan.output, + throws = throws, + allowOuter = allowOuter) + } + + /** + * Resolves `UnresolvedAttribute`, `GetColumnByOrdinal` and extract value expressions(s) by the + * input plan's children output attributes. + * + * @param e The expression need to be resolved. + * @param q The LogicalPlan whose children are used to resolve expression's attribute. + * @return resolved Expression. + */ + def resolveExpressionByPlanChildren( + e: Expression, + q: LogicalPlan, + allowOuter: Boolean = false): Expression = { + resolveExpression( + e, + resolveColumnByName = nameParts => { + q.resolveChildren(nameParts, conf.resolver) + }, + getAttrCandidates = () => { + assert(q.children.length == 1) + q.children.head.output + }, + throws = true, + allowOuter = allowOuter) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala deleted file mode 100644 index 8c6ba20cd1af..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveGroupByAll.scala +++ /dev/null @@ -1,119 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.analysis - -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, UNRESOLVED_ATTRIBUTE} - -/** - * Resolve "group by all" in the following SQL pattern: - * `select col1, col2, agg_expr(...) from table group by all`. - * - * The all is expanded to include all non-aggregate columns in the select clause. - */ -object ResolveGroupByAll extends Rule[LogicalPlan] { - - val ALL = "ALL" - - /** - * Returns true iff this is a GROUP BY ALL aggregate. i.e. an Aggregate expression that has - * a single unresolved all grouping expression. - */ - private def matchToken(a: Aggregate): Boolean = { - if (a.groupingExpressions.size != 1) { - return false - } - a.groupingExpressions.head match { - case a: UnresolvedAttribute => a.equalsIgnoreCase(ALL) - case _ => false - } - } - - /** - * Returns all the grouping expressions inferred from a GROUP BY ALL aggregate. - * The result is optional. If Spark fails to infer the grouping columns, it is None. - * Otherwise, it contains all the non-aggregate expressions from the project list of the input - * Aggregate. - */ - private def getGroupingExpressions(a: Aggregate): Option[Seq[Expression]] = { - val groupingExprs = a.aggregateExpressions.filter(!_.exists(AggregateExpression.isAggregate)) - // If the grouping exprs are empty, this could either be (1) a valid global aggregate, or - // (2) we simply fail to infer the grouping columns. As an example, in "i + sum(j)", we will - // not automatically infer the grouping column to be "i". - if (groupingExprs.isEmpty && a.aggregateExpressions.exists(containsAttribute)) { - None - } else { - Some(groupingExprs) - } - } - - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( - _.containsAllPatterns(UNRESOLVED_ATTRIBUTE, AGGREGATE), ruleId) { - case a: Aggregate - if a.child.resolved && a.aggregateExpressions.forall(_.resolved) && matchToken(a) => - // Only makes sense to do the rewrite once all the aggregate expressions have been resolved. - // Otherwise, we might incorrectly pull an actual aggregate expression over to the grouping - // expression list (because we don't know they would be aggregate expressions until resolved). - val groupingExprs = getGroupingExpressions(a) - - if (groupingExprs.isEmpty) { - // Don't replace the ALL when we fail to infer the grouping columns. We will eventually - // tell the user in checkAnalysis that we cannot resolve the all in group by. - a - } else { - // This is a valid GROUP BY ALL aggregate. - a.copy(groupingExpressions = groupingExprs.get) - } - } - - /** - * Returns true if the expression includes an Attribute outside the aggregate expression part. - * For example: - * "i" -> true - * "i + 2" -> true - * "i + sum(j)" -> true - * "sum(j)" -> false - * "sum(j) / 2" -> false - */ - private def containsAttribute(expr: Expression): Boolean = expr match { - case _ if AggregateExpression.isAggregate(expr) => - // Don't recurse into AggregateExpressions - false - case _: Attribute => - true - case e => - e.children.exists(containsAttribute) - } - - /** - * A check to be used in [[CheckAnalysis]] to see if we have any unresolved group by at the - * end of analysis, so we can tell users that we fail to infer the grouping columns. - */ - def checkAnalysis(operator: LogicalPlan): Unit = operator match { - case a: Aggregate if a.aggregateExpressions.forall(_.resolved) && matchToken(a) => - if (getGroupingExpressions(a).isEmpty) { - operator.failAnalysis( - errorClass = "UNRESOLVED_ALL_IN_GROUP_BY", - messageParameters = Map.empty) - } - case _ => - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveOrderByAll.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveOrderByAll.scala deleted file mode 100644 index 7cf584dadcf3..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveOrderByAll.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.analysis - -import org.apache.spark.sql.catalyst.expressions.SortOrder -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sort} -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.trees.TreePattern.{SORT, UNRESOLVED_ATTRIBUTE} - -/** - * Resolve "order by all" in the following SQL pattern: - * `select col1, col2 from table order by all`. - * - * It orders the query result by all columns, from left to right. The query above becomes: - * - * `select col1, col2 from table order by col1, col2` - * - * This should also support specifying asc/desc, and nulls first/last. - */ -object ResolveOrderByAll extends Rule[LogicalPlan] { - - val ALL = "ALL" - - /** - * An extractor to pull out the SortOrder field in the ORDER BY ALL clause. We pull out that - * SortOrder object so we can pass its direction and null ordering. - */ - object OrderByAll { - def unapply(s: Sort): Option[SortOrder] = { - // This only applies to global ordering. - if (!s.global) { - return None - } - // Don't do this if we have more than one order field. That means it's not order by all. - if (s.order.size != 1) { - return None - } - // Don't do this if there's a child field called ALL. That should take precedence. - if (s.child.output.exists(_.name.toUpperCase() == ALL)) { - return None - } - - s.order.find { so => - so.child match { - case a: UnresolvedAttribute => a.name.toUpperCase() == ALL - case _ => false - } - } - } - } - - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning( - _.containsAllPatterns(UNRESOLVED_ATTRIBUTE, SORT), ruleId) { - // This only makes sense if the child is resolved. - case s: Sort if s.child.resolved => - s match { - case OrderByAll(sortOrder) => - // Replace a single order by all with N fields, where N = child's output, while - // retaining the same asc/desc and nulls ordering. - val order = s.child.output.map(a => sortOrder.copy(child = a)) - s.copy(order = order) - case _ => - s - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala new file mode 100644 index 000000000000..8f5193ad73f8 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, GetStructField, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, AppendColumns, LogicalPlan} +import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_ATTRIBUTE + +/** + * A virtual rule to resolve [[UnresolvedAttribute]] in [[Aggregate]]. It's only used by the real + * rule `ResolveReferences`. The column resolution order for [[Aggregate]] is: + * 1. Resolves the column to [[AttributeReference]] with the output of the child plan. This + * includes metadata columns as well. + * 2. Resolves the column to a literal function which is allowed to be invoked without braces, e.g. + * `SELECT col, current_date FROM t`. + * 3. If `Aggregate.aggregateExpressions` are all resolved, resolve GROUP BY alias and GROUP BY ALL + * for `Aggregate.groupingExpressions`: + * 3.1. If the grouping expressions contain an unresolved column whose name matches an alias in the + * SELECT list, resolves that unresolved column to the alias. This is to support SQL pattern + * like `SELECT a + b AS c, max(col) FROM t GROUP BY c`. + * 3.2. If the grouping expressions only have one single unresolved column named 'ALL', expanded it + * to include all non-aggregate columns in the SELECT list. This is to support SQL pattern like + * `SELECT col1, col2, agg_expr(...) FROM t GROUP BY ALL`. + * 4. Resolves the column in `Aggregate.aggregateExpressions` to [[LateralColumnAliasReference]] if + * it references the alias defined previously in the SELECT list. The rule + * `ResolveLateralColumnAliasReference` will further resolve [[LateralColumnAliasReference]] and + * rewrite the plan. This is to support SQL pattern like + * `SELECT col1 + 1 AS x, x + 1 AS y, y + 1 AS z FROM t`. + * 5. Resolves the column to outer references with the outer plan if we are resolving subquery + * expressions. + */ +object ResolveReferencesInAggregate extends SQLConfHelper with ColumnResolutionHelper { + def apply(a: Aggregate): Aggregate = { + val planForResolve = a.child match { + // SPARK-25942: Resolves aggregate expressions with `AppendColumns`'s children, instead of + // `AppendColumns`, because `AppendColumns`'s serializer might produce conflict attribute + // names leading to ambiguous references exception. + case appendColumns: AppendColumns => appendColumns + case _ => a + } + + val resolvedAggExprsNoOuter = a.aggregateExpressions + .map(resolveExpressionByPlanChildren(_, planForResolve, allowOuter = false) + .asInstanceOf[NamedExpression]) + + val resolvedGroupingExprs = a.groupingExpressions + .map(resolveExpressionByPlanChildren(_, planForResolve, allowOuter = false)) + // SPARK-31670: Resolve Struct field in groupByExpressions and aggregateExpressions + // with CUBE/ROLLUP will be wrapped with alias like Alias(GetStructField, name) with + // different ExprId. This cause aggregateExpressions can't be replaced by expanded + // groupByExpressions in `ResolveGroupingAnalytics.constructAggregateExprs()`, we trim + // unnecessary alias of GetStructField here. + .map(trimTopLevelGetStructFieldAlias) + + // Only makes sense to do the rewrite once all the aggregate expressions have been resolved. + // Otherwise, we might incorrectly pull an actual aggregate expression over to the grouping + // expression list (because we don't know they would be aggregate expressions until resolved). + if (resolvedAggExprsNoOuter.forall(_.resolved)) { + val finalGroupExprs = resolveGroupByAll( + resolvedAggExprsNoOuter, + resolveGroupByAlias(resolvedAggExprsNoOuter, resolvedGroupingExprs) + ).map(resolveOuterRef) + a.copy(finalGroupExprs, resolvedAggExprsNoOuter, a.child) + } else { + // If the SELECT list is not full resolved at this point, we need to apply lateral column + // alias and outer reference resolution, which are not supported in GROUP BY. We can't + // resolve group by alias and group by all here. + // Aggregate supports Lateral column alias, which has higher priority than outer reference. + val resolvedAggExprsWithLCA = resolveLateralColumnAlias(resolvedAggExprsNoOuter) + val resolvedAggExprsWithOuter = resolvedAggExprsWithLCA.map(resolveOuterRef) + .map(_.asInstanceOf[NamedExpression]) + a.copy(resolvedGroupingExprs.map(resolveOuterRef), resolvedAggExprsWithOuter, a.child) + } + } + + private def resolveGroupByAlias( + selectList: Seq[NamedExpression], + groupExprs: Seq[Expression]): Seq[Expression] = { + assert(selectList.forall(_.resolved)) + if (conf.groupByAliases) { + groupExprs.map { g => + g.transformWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) { + case u: UnresolvedAttribute => + selectList.find(ne => conf.resolver(ne.name, u.name)).getOrElse(u) + } + } + } else { + groupExprs + } + } + + private def resolveGroupByAll( + selectList: Seq[NamedExpression], + groupExprs: Seq[Expression]): Seq[Expression] = { + assert(selectList.forall(_.resolved)) + if (isGroupByAll(groupExprs)) { + val expandedGroupExprs = expandGroupByAll(selectList) + if (expandedGroupExprs.isEmpty) { + // Don't replace the ALL when we fail to infer the grouping columns. We will eventually + // tell the user in checkAnalysis that we cannot resolve the all in group by. + groupExprs + } else { + // This is a valid GROUP BY ALL aggregate. + expandedGroupExprs.get + } + } else { + groupExprs + } + } + + /** + * Returns all the grouping expressions inferred from a GROUP BY ALL aggregate. + * The result is optional. If Spark fails to infer the grouping columns, it is None. + * Otherwise, it contains all the non-aggregate expressions from the project list of the input + * Aggregate. + */ + private def expandGroupByAll(selectList: Seq[NamedExpression]): Option[Seq[Expression]] = { + val groupingExprs = selectList.filter(!_.exists(AggregateExpression.isAggregate)) + // If the grouping exprs are empty, this could either be (1) a valid global aggregate, or + // (2) we simply fail to infer the grouping columns. As an example, in "i + sum(j)", we will + // not automatically infer the grouping column to be "i". + if (groupingExprs.isEmpty && selectList.exists(containsAttribute)) { + None + } else { + Some(groupingExprs) + } + } + + /** + * Trim groupByExpression's top-level GetStructField Alias. Since these expressions are not + * NamedExpression originally, we are safe to trim top-level GetStructField Alias. + */ + private def trimTopLevelGetStructFieldAlias(e: Expression): Expression = { + e match { + case Alias(s: GetStructField, _) => s + case other => other + } + } + + /** + * Returns true iff this is a GROUP BY ALL: the grouping expressions only have a single column, + * which is an unresolved column named ALL. + */ + private def isGroupByAll(exprs: Seq[Expression]): Boolean = { + if (exprs.length != 1) return false + exprs.head match { + case a: UnresolvedAttribute => a.equalsIgnoreCase("ALL") + case _ => false + } + } + + /** + * Returns true if the expression includes an Attribute outside the aggregate expression part. + * For example: + * "i" -> true + * "i + 2" -> true + * "i + sum(j)" -> true + * "sum(j)" -> false + * "sum(j) / 2" -> false + */ + private def containsAttribute(expr: Expression): Boolean = expr match { + case _ if AggregateExpression.isAggregate(expr) => + // Don't recurse into AggregateExpressions + false + case _: Attribute => + true + case e => + e.children.exists(containsAttribute) + } + + /** + * A check to be used in [[CheckAnalysis]] to see if we have any unresolved group by at the + * end of analysis, so we can tell users that we fail to infer the grouping columns. + */ + def checkUnresolvedGroupByAll(operator: LogicalPlan): Unit = operator match { + case a: Aggregate if a.aggregateExpressions.forall(_.resolved) && + isGroupByAll(a.groupingExpressions) => + if (expandGroupByAll(a.aggregateExpressions).isEmpty) { + operator.failAnalysis( + errorClass = "UNRESOLVED_ALL_IN_GROUP_BY", + messageParameters = Map.empty) + } + case _ => + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala new file mode 100644 index 000000000000..5a5e636253df --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Sort} + +/** + * A virtual rule to resolve [[UnresolvedAttribute]] in [[Sort]]. It's only used by the real + * rule `ResolveReferences`. The column resolution order for [[Sort]] is: + * 1. Resolves the column to [[AttributeReference]] with the output of the child plan. This + * includes metadata columns as well. + * 2. Resolves the column to a literal function which is allowed to be invoked without braces, e.g. + * `SELECT col, current_date FROM t`. + * 3. If the child plan is Aggregate, resolves the column to [[TempResolvedColumn]] with the output + * of Aggregate's child plan. This is to allow Sort to host grouping expressions and aggregate + * functions, which can be pushed down to the Aggregate later. For example, + * `SELECT max(a) FROM t GROUP BY b ORDER BY min(a)`. + * 4. Resolves the column to [[AttributeReference]] with the output of a descendant plan node. + * Spark will propagate the missing attributes from the descendant plan node to the Sort node. + * This is to allow users to ORDER BY columns that are not in the SELECT clause, which is + * widely supported in other SQL dialects. For example, `SELECT a FROM t ORDER BY b`. + * 5. If the order by expressions only have one single unresolved column named ALL, expanded it to + * include all columns in the SELECT list. This is to support SQL pattern like + * `SELECT col1, col2 FROM t ORDER BY ALL`. This should also support specifying asc/desc, and + * nulls first/last. + * 6. Resolves the column to outer references with the outer plan if we are resolving subquery + * expressions. + */ +object ResolveReferencesInSort extends SQLConfHelper with ColumnResolutionHelper { + + def apply(s: Sort): LogicalPlan = { + val resolvedNoOuter = s.order.map(resolveExpressionByPlanOutput(_, s.child)) + val resolvedWithAgg = resolvedNoOuter.map(resolveColWithAgg(_, s.child)) + val (missingAttrResolved, newChild) = resolveExprsAndAddMissingAttrs(resolvedWithAgg, s.child) + val orderByAllResolved = resolveOrderByAll( + s.global, newChild, missingAttrResolved.map(_.asInstanceOf[SortOrder])) + val finalOrdering = orderByAllResolved.map(e => resolveOuterRef(e).asInstanceOf[SortOrder]) + if (s.child.output == newChild.output) { + s.copy(order = finalOrdering) + } else { + // Add missing attributes and then project them away. + val newSort = s.copy(order = finalOrdering, child = newChild) + Project(s.child.output, newSort) + } + } + + private def resolveOrderByAll( + globalSort: Boolean, + child: LogicalPlan, + orders: Seq[SortOrder]): Seq[SortOrder] = { + // This only applies to global ordering. + if (!globalSort) return orders + // Don't do this if we have more than one order field. That means it's not order by all. + if (orders.length != 1) return orders + + val order = orders.head + order.child match { + case a: UnresolvedAttribute if a.equalsIgnoreCase("ALL") => + // Replace a single order by all with N fields, where N = child's output, while + // retaining the same asc/desc and nulls ordering. + child.output.map(a => order.copy(child = a)) + case _ => orders + } + } +} diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 1615c43cc7ed..0f3d1a11c567 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -45,6 +45,9 @@ SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS SELECT a AS k, COUNT(b) FROM testData GROUP BY k; SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1; +-- GROUP BY alias inside subquery expression with conflicting outer reference +SELECT * FROM testData WHERE a = 1 AND EXISTS (SELECT 1 AS a GROUP BY a); + -- GROUP BY alias with invalid col in SELECT list SELECT a AS k, COUNT(non_existing) FROM testData GROUP BY k; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 0402039fafac..4faa66f9988b 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -182,6 +182,15 @@ struct 3 2 +-- !query +SELECT * FROM testData WHERE a = 1 AND EXISTS (SELECT 1 AS a GROUP BY a) +-- !query schema +struct +-- !query output +1 1 +1 2 + + -- !query SELECT a AS k, COUNT(non_existing) FROM testData GROUP BY k -- !query schema From a78757a3155e3f27ba19fdc203e8281b3222529b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 13 Jan 2023 23:37:50 +0800 Subject: [PATCH 2/6] address comments --- .../sql/catalyst/analysis/CheckAnalysis.scala | 4 +- .../ResolveReferencesInAggregate.scala | 51 ++++++++++--------- .../sql-tests/inputs/group-by-all.sql | 3 ++ .../resources/sql-tests/inputs/group-by.sql | 6 +++ .../sql-tests/results/group-by-all.sql.out | 16 ++++++ .../sql-tests/results/group-by.sql.out | 44 ++++++++++++++++ 6 files changed, 98 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index c020cf727b08..c66105d1715d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -762,8 +762,8 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB private def getAllExpressions(plan: LogicalPlan): Seq[Expression] = { plan match { - // `groupingExpressions` may rely on `aggregateExpressions`, due to the GROUP BY alias - // feature. We should check errors in `aggregateExpressions` first. + // We only resolve `groupingExpressions` if `aggregateExpressions` is resolved first (See + // `ResolveReferencesInAggregate`). We should check errors in `aggregateExpressions` first. case a: Aggregate => a.aggregateExpressions ++ a.groupingExpressions case _ => plan.expressions } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala index 8f5193ad73f8..77ee12d1e2e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, GetStructField, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, AppendColumns, LogicalPlan} -import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_ATTRIBUTE +import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, UNRESOLVED_ATTRIBUTE} /** * A virtual rule to resolve [[UnresolvedAttribute]] in [[Aggregate]]. It's only used by the real @@ -56,11 +56,7 @@ object ResolveReferencesInAggregate extends SQLConfHelper with ColumnResolutionH case _ => a } - val resolvedAggExprsNoOuter = a.aggregateExpressions - .map(resolveExpressionByPlanChildren(_, planForResolve, allowOuter = false) - .asInstanceOf[NamedExpression]) - - val resolvedGroupingExprs = a.groupingExpressions + val resolvedGroupExprsNoOuter = a.groupingExpressions .map(resolveExpressionByPlanChildren(_, planForResolve, allowOuter = false)) // SPARK-31670: Resolve Struct field in groupByExpressions and aggregateExpressions // with CUBE/ROLLUP will be wrapped with alias like Alias(GetStructField, name) with @@ -68,25 +64,32 @@ object ResolveReferencesInAggregate extends SQLConfHelper with ColumnResolutionH // groupByExpressions in `ResolveGroupingAnalytics.constructAggregateExprs()`, we trim // unnecessary alias of GetStructField here. .map(trimTopLevelGetStructFieldAlias) - - // Only makes sense to do the rewrite once all the aggregate expressions have been resolved. - // Otherwise, we might incorrectly pull an actual aggregate expression over to the grouping - // expression list (because we don't know they would be aggregate expressions until resolved). - if (resolvedAggExprsNoOuter.forall(_.resolved)) { - val finalGroupExprs = resolveGroupByAll( - resolvedAggExprsNoOuter, - resolveGroupByAlias(resolvedAggExprsNoOuter, resolvedGroupingExprs) - ).map(resolveOuterRef) - a.copy(finalGroupExprs, resolvedAggExprsNoOuter, a.child) + val resolvedAggExprsNoOuter = a.aggregateExpressions.map( + resolveExpressionByPlanChildren(_, planForResolve, allowOuter = false)) + val resolvedAggExprsWithLCA = resolveLateralColumnAlias(resolvedAggExprsNoOuter) + val resolvedAggExprsWithOuter = resolvedAggExprsWithLCA.map(resolveOuterRef) + .map(_.asInstanceOf[NamedExpression]) + // `groupingExpressions` may rely on `aggregateExpressions`, due to features like GROUP BY alias + // and GROUP BY ALL. We only do basic resolution for `groupingExpressions`, and will further + // resolve it after `aggregateExpressions` are all resolved. Note: the basic resolution is + // needed as `aggregateExpressions` may rely on `groupingExpressions` as well, for the session + // window feature. See the rule `SessionWindowing` for more details. + if (resolvedAggExprsWithOuter.forall(_.resolved)) { + // TODO: currently we don't support LCA in `groupingExpressions` yet. + if (resolvedAggExprsWithOuter.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE))) { + a.copy(resolvedGroupExprsNoOuter.map(resolveOuterRef), resolvedAggExprsWithOuter, a.child) + } else { + val finalGroupExprs = resolveGroupByAll( + resolvedAggExprsWithOuter, + resolveGroupByAlias( + resolvedAggExprsWithOuter, resolvedGroupExprsNoOuter) + ).map(resolveOuterRef) + a.copy(finalGroupExprs, resolvedAggExprsWithOuter, a.child) + } } else { - // If the SELECT list is not full resolved at this point, we need to apply lateral column - // alias and outer reference resolution, which are not supported in GROUP BY. We can't - // resolve group by alias and group by all here. - // Aggregate supports Lateral column alias, which has higher priority than outer reference. - val resolvedAggExprsWithLCA = resolveLateralColumnAlias(resolvedAggExprsNoOuter) - val resolvedAggExprsWithOuter = resolvedAggExprsWithLCA.map(resolveOuterRef) - .map(_.asInstanceOf[NamedExpression]) - a.copy(resolvedGroupingExprs.map(resolveOuterRef), resolvedAggExprsWithOuter, a.child) + a.copy( + groupingExpressions = resolvedGroupExprsNoOuter, + aggregateExpressions = resolvedAggExprsWithOuter) } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-all.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-all.sql index 4400c0b57866..6f3f2d640eb9 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-all.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-all.sql @@ -71,6 +71,9 @@ select id + count(*) from data group by all; -- an even more complex case that we choose not to infer; fail with a useful error message select (id + id) / 2 + count(*) * 2 from data group by all; +-- GROUP BY alias has higher priority than GROUP BY all, this query fails as `id` is not in GROUP BY +select country as all, id from data group by all; + -- uncorrelated subquery should work select country, (select count(*) from data) as cnt, count(id) as cnt_id from data group by all; diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 0f3d1a11c567..28025d993f15 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -45,6 +45,12 @@ SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS SELECT a AS k, COUNT(b) FROM testData GROUP BY k; SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1; +-- GROUP BY alias is not triggered if SELECT list has lateral column alias. +SELECT 1 AS x, x + 1 AS k FROM testData GROUP BY k; + +-- GROUP BY alias is not triggered if SELECT list has outer reference. +SELECT * FROM testData WHERE a = 1 AND EXISTS (SELECT a AS k GROUP BY k); + -- GROUP BY alias inside subquery expression with conflicting outer reference SELECT * FROM testData WHERE a = 1 AND EXISTS (SELECT 1 AS a GROUP BY a); diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-all.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-all.sql.out index d8a2e743d6b6..202e4234c61b 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-all.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-all.sql.out @@ -226,6 +226,22 @@ org.apache.spark.sql.AnalysisException } +-- !query +select country as all, id from data group by all +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "MISSING_AGGREGATION", + "sqlState" : "42000", + "messageParameters" : { + "expression" : "\"id\"", + "expressionAnyValue" : "\"any_value(id)\"" + } +} + + -- !query select country, (select count(*) from data) as cnt, count(id) as cnt_id from data group by all -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 4faa66f9988b..06c22ef4ed5d 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -182,6 +182,50 @@ struct 3 2 +-- !query +SELECT 1 AS x, x + 1 AS k FROM testData GROUP BY k +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", + "sqlState" : "42000", + "messageParameters" : { + "objectName" : "`k`", + "proposal" : "`testdata`.`a`, `testdata`.`b`" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 50, + "stopIndex" : 50, + "fragment" : "k" + } ] +} + + +-- !query +SELECT * FROM testData WHERE a = 1 AND EXISTS (SELECT a AS k GROUP BY k) +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_REFERENCE", + "messageParameters" : { + "sqlExprs" : "\"a\",\"a AS k\"" + }, + "queryContext" : [ { + "objectType" : "", + "objectName" : "", + "startIndex" : 62, + "stopIndex" : 71, + "fragment" : "GROUP BY k" + } ] +} + + -- !query SELECT * FROM testData WHERE a = 1 AND EXISTS (SELECT 1 AS a GROUP BY a) -- !query schema From 5e3ee81182054e1781685fd42eb5b95155c162d8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 18 Jan 2023 14:32:45 +0800 Subject: [PATCH 3/6] update golden files --- .../src/test/resources/sql-tests/results/group-by-all.sql.out | 2 +- sql/core/src/test/resources/sql-tests/results/group-by.sql.out | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-all.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-all.sql.out index 202e4234c61b..2b3e8fe9dfd3 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-all.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-all.sql.out @@ -234,7 +234,7 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "MISSING_AGGREGATION", - "sqlState" : "42000", + "sqlState" : "42803", "messageParameters" : { "expression" : "\"id\"", "expressionAnyValue" : "\"any_value(id)\"" diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 06c22ef4ed5d..b52bb7c20399 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -190,7 +190,7 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", - "sqlState" : "42000", + "sqlState" : "42703", "messageParameters" : { "objectName" : "`k`", "proposal" : "`testdata`.`a`, `testdata`.`b`" @@ -213,6 +213,7 @@ struct<> org.apache.spark.sql.AnalysisException { "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_REFERENCE", + "sqlState" : "0A000", "messageParameters" : { "sqlExprs" : "\"a\",\"a AS k\"" }, From a20f82b05d591b9e0ca76a00bc7981636e605b56 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 29 Jan 2023 21:29:10 +0800 Subject: [PATCH 4/6] address comments --- .../main/resources/error/error-classes.json | 5 + .../sql/catalyst/analysis/Analyzer.scala | 12 +- .../ResolveReferencesInAggregate.scala | 44 ++++--- .../analysis/ResolveReferencesInSort.scala | 3 + .../inputs/column-resolution-aggregate.sql | 30 +++++ .../inputs/column-resolution-sort.sql | 20 +++ .../sql-tests/inputs/group-by-all.sql | 3 - .../resources/sql-tests/inputs/group-by.sql | 9 -- .../column-resolution-aggregate.sql.out | 121 ++++++++++++++++++ .../results/column-resolution-sort.sql.out | 42 ++++++ .../sql-tests/results/group-by-all.sql.out | 16 --- .../sql-tests/results/group-by.sql.out | 54 -------- 12 files changed, 255 insertions(+), 104 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/column-resolution-aggregate.sql create mode 100644 sql/core/src/test/resources/sql-tests/inputs/column-resolution-sort.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/column-resolution-aggregate.sql.out create mode 100644 sql/core/src/test/resources/sql-tests/results/column-resolution-sort.sql.out diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index 230b616800fb..84af7b5d64f5 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -1532,6 +1532,11 @@ "Referencing a lateral column alias in the aggregate function ." ] }, + "LATERAL_COLUMN_ALIAS_IN_GROUP_BY" : { + "message" : [ + "Referencing a lateral column alias via GROUP BY alias/ALL is not supported yet." + ] + }, "LATERAL_JOIN_USING" : { "message" : [ "JOIN USING with LATERAL correlation." diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7330d7de7423..28ae09e123cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1493,17 +1493,23 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor * * The general workflow to resolve references: * 1. Expands the star in Project/Aggregate/Generate. - * 2. Resolves the column to [[AttributeReference]] with the output of the children plans. This + * 2. Resolves the columns to [[AttributeReference]] with the output of the children plans. This * includes metadata columns as well. - * 3. Resolves the column to a literal function which is allowed to be invoked without braces, + * 3. Resolves the columns to literal function which is allowed to be invoked without braces, * e.g. `SELECT col, current_date FROM t`. - * 4. Resolves the column to outer references with the outer plan if we are resolving subquery + * 4. Resolves the columns to outer references with the outer plan if we are resolving subquery * expressions. * * Some plan nodes have special column reference resolution logic, please read these sub-rules for * details: * - [[ResolveReferencesInAggregate]] * - [[ResolveReferencesInSort]] + * + * Note: even if we use a single rule to resolve columns, it's still non-trivial to have a + * reliable column resolution order, as the rule will be executed multiple times, with other + * rules in the same batch. We should resolve columns with the next option only if all the + * previous options are permanently not applicable. If the current option can be applicable + * in the next iteration (other rules update the plan), we should not try the next option. */ object ResolveReferences extends Rule[LogicalPlan] with ColumnResolutionHelper { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala index 77ee12d1e2e3..b45799c2283a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, GetStructField, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression @@ -26,24 +27,23 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REF /** * A virtual rule to resolve [[UnresolvedAttribute]] in [[Aggregate]]. It's only used by the real * rule `ResolveReferences`. The column resolution order for [[Aggregate]] is: - * 1. Resolves the column to [[AttributeReference]] with the output of the child plan. This + * 1. Resolves the columns to [[AttributeReference]] with the output of the child plan. This * includes metadata columns as well. - * 2. Resolves the column to a literal function which is allowed to be invoked without braces, e.g. + * 2. Resolves the columns to a literal function which is allowed to be invoked without braces, e.g. * `SELECT col, current_date FROM t`. - * 3. If `Aggregate.aggregateExpressions` are all resolved, resolve GROUP BY alias and GROUP BY ALL - * for `Aggregate.groupingExpressions`: + * 3. If aggregate expressions are all resolved, resolve GROUP BY alias and GROUP BY ALL. * 3.1. If the grouping expressions contain an unresolved column whose name matches an alias in the * SELECT list, resolves that unresolved column to the alias. This is to support SQL pattern * like `SELECT a + b AS c, max(col) FROM t GROUP BY c`. * 3.2. If the grouping expressions only have one single unresolved column named 'ALL', expanded it * to include all non-aggregate columns in the SELECT list. This is to support SQL pattern like * `SELECT col1, col2, agg_expr(...) FROM t GROUP BY ALL`. - * 4. Resolves the column in `Aggregate.aggregateExpressions` to [[LateralColumnAliasReference]] if + * 4. Resolves the columns in aggregate expressions to [[LateralColumnAliasReference]] if * it references the alias defined previously in the SELECT list. The rule * `ResolveLateralColumnAliasReference` will further resolve [[LateralColumnAliasReference]] and * rewrite the plan. This is to support SQL pattern like * `SELECT col1 + 1 AS x, x + 1 AS y, y + 1 AS z FROM t`. - * 5. Resolves the column to outer references with the outer plan if we are resolving subquery + * 5. Resolves the columns to outer references with the outer plan if we are resolving subquery * expressions. */ object ResolveReferencesInAggregate extends SQLConfHelper with ColumnResolutionHelper { @@ -74,23 +74,29 @@ object ResolveReferencesInAggregate extends SQLConfHelper with ColumnResolutionH // resolve it after `aggregateExpressions` are all resolved. Note: the basic resolution is // needed as `aggregateExpressions` may rely on `groupingExpressions` as well, for the session // window feature. See the rule `SessionWindowing` for more details. - if (resolvedAggExprsWithOuter.forall(_.resolved)) { + val resolvedGroupExprs = if (resolvedAggExprsWithOuter.forall(_.resolved)) { + val resolved = resolveGroupByAll( + resolvedAggExprsWithOuter, + resolveGroupByAlias(resolvedAggExprsWithOuter, resolvedGroupExprsNoOuter) + ).map(resolveOuterRef) // TODO: currently we don't support LCA in `groupingExpressions` yet. - if (resolvedAggExprsWithOuter.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE))) { - a.copy(resolvedGroupExprsNoOuter.map(resolveOuterRef), resolvedAggExprsWithOuter, a.child) - } else { - val finalGroupExprs = resolveGroupByAll( - resolvedAggExprsWithOuter, - resolveGroupByAlias( - resolvedAggExprsWithOuter, resolvedGroupExprsNoOuter) - ).map(resolveOuterRef) - a.copy(finalGroupExprs, resolvedAggExprsWithOuter, a.child) + if (resolved.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE))) { + throw new AnalysisException( + errorClass = "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_GROUP_BY", + messageParameters = Map.empty) } + resolved } else { - a.copy( - groupingExpressions = resolvedGroupExprsNoOuter, - aggregateExpressions = resolvedAggExprsWithOuter) + // Do not resolve columns in grouping expressions to outer references here, as the aggregate + // expressions are not fully resolved yet and we still have chances to resolve GROUP BY + // alias/ALL in the next iteration. If aggregate expressions end up as unresolved, we don't + // need to resolve grouping expressions at all, as `CheckAnalysis` will report error for + // aggregate expressions first. + resolvedGroupExprsNoOuter } + a.copy( + groupingExpressions = resolvedGroupExprs, + aggregateExpressions = resolvedAggExprsWithOuter) } private def resolveGroupByAlias( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala index 5a5e636253df..54044932d9e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala @@ -41,6 +41,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Sort} * nulls first/last. * 6. Resolves the column to outer references with the outer plan if we are resolving subquery * expressions. + * + * Note, 3 and 4 are actually orthogonal. If the child plan is Aggregate, 4 can only resolve columns + * as the grouping columns, which is completely covered by 3. */ object ResolveReferencesInSort extends SQLConfHelper with ColumnResolutionHelper { diff --git a/sql/core/src/test/resources/sql-tests/inputs/column-resolution-aggregate.sql b/sql/core/src/test/resources/sql-tests/inputs/column-resolution-aggregate.sql new file mode 100644 index 000000000000..e797b9839460 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/column-resolution-aggregate.sql @@ -0,0 +1,30 @@ +-- Tests covering column resolution priority in Aggregate. + +CREATE TEMPORARY VIEW v1 AS VALUES (1, 1, 1), (2, 2, 1) AS t(a, b, k); +CREATE TEMPORARY VIEW v2 AS VALUES (1, 1, 1), (2, 2, 1) AS t(x, y, all); + +-- Relation output columns have higher priority than lateral column alias. This query +-- should fail as `b` is not in GROUP BY. +SELECT max(a) AS b, b FROM v1 GROUP BY k; + +-- Lateral column alias has higher priority than outer reference. +SELECT a FROM v1 WHERE (12, 13) IN (SELECT max(x + 10) AS a, a + 1 FROM v2); + +-- Relation output columns have higher priority than GROUP BY alias. This query should +-- fail as `a` is not in GROUP BY. +SELECT a AS k FROM v1 GROUP BY k; + +-- Relation output columns have higher priority than GROUP BY ALL. This query should +-- fail as `x` is not in GROUP BY. +SELECT x FROM v2 GROUP BY all; + +-- GROUP BY alias has higher priority than GROUP BY ALL, this query fails as `b` is not in GROUP BY. +SELECT a AS all, b FROM v1 GROUP BY all; + +-- GROUP BY alias/ALL does not support lateral column alias. +SELECT k AS lca, lca + 1 AS col FROM v1 GROUP BY k, col; +SELECT k AS lca, lca + 1 AS col FROM v1 GROUP BY all; + +-- GROUP BY ALL has higher priority than outer reference. This query should run as `a` and `b` are +-- in GROUP BY due to the GROUP BY ALL resolution. +SELECT * FROM v2 WHERE EXISTS (SELECT a, b FROM v1 GROUP BY all); diff --git a/sql/core/src/test/resources/sql-tests/inputs/column-resolution-sort.sql b/sql/core/src/test/resources/sql-tests/inputs/column-resolution-sort.sql new file mode 100644 index 000000000000..da559da8fa07 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/column-resolution-sort.sql @@ -0,0 +1,20 @@ +--SET spark.sql.leafNodeDefaultParallelism=1 +-- Tests covering column resolution priority in Sort. + +CREATE TEMPORARY VIEW v1 AS VALUES (1, 2, 2), (2, 1, 1) AS t(a, b, k); +CREATE TEMPORARY VIEW v2 AS VALUES (1, 2, 2), (2, 1, 1) AS t(a, b, all); + +-- Relation output columns have higher priority than missing reference. +-- Results will be [2, 1] if we order by the column `v1.b`. +-- Actually results are [1, 2] as we order by `max(a) AS b`. +SELECT max(a) AS b FROM v1 GROUP BY k ORDER BY b; + +-- Missing reference has higher priority than ORDER BY ALL. +-- Results will be [1, 2] if we order by `max(a)`. +-- Actually results are [2, 1] as we order by the grouping column `v2.all`. +SELECT max(a) FROM v2 GROUP BY all ORDER BY all; + +-- ORDER BY ALL has higher priority than outer reference. +-- Results will be [1, 1] if we order by outer reference 'v2.all'. +-- Actually results are [2, 2] as we order by column `v1.b` +SELECT (SELECT b FROM v1 ORDER BY all LIMIT 1) FROM v2; diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-all.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-all.sql index 6f3f2d640eb9..4400c0b57866 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-all.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-all.sql @@ -71,9 +71,6 @@ select id + count(*) from data group by all; -- an even more complex case that we choose not to infer; fail with a useful error message select (id + id) / 2 + count(*) * 2 from data group by all; --- GROUP BY alias has higher priority than GROUP BY all, this query fails as `id` is not in GROUP BY -select country as all, id from data group by all; - -- uncorrelated subquery should work select country, (select count(*) from data) as cnt, count(id) as cnt_id from data group by all; diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 28025d993f15..1615c43cc7ed 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -45,15 +45,6 @@ SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS SELECT a AS k, COUNT(b) FROM testData GROUP BY k; SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1; --- GROUP BY alias is not triggered if SELECT list has lateral column alias. -SELECT 1 AS x, x + 1 AS k FROM testData GROUP BY k; - --- GROUP BY alias is not triggered if SELECT list has outer reference. -SELECT * FROM testData WHERE a = 1 AND EXISTS (SELECT a AS k GROUP BY k); - --- GROUP BY alias inside subquery expression with conflicting outer reference -SELECT * FROM testData WHERE a = 1 AND EXISTS (SELECT 1 AS a GROUP BY a); - -- GROUP BY alias with invalid col in SELECT list SELECT a AS k, COUNT(non_existing) FROM testData GROUP BY k; diff --git a/sql/core/src/test/resources/sql-tests/results/column-resolution-aggregate.sql.out b/sql/core/src/test/resources/sql-tests/results/column-resolution-aggregate.sql.out new file mode 100644 index 000000000000..6cc71fbcfb7c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/column-resolution-aggregate.sql.out @@ -0,0 +1,121 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +CREATE TEMPORARY VIEW v1 AS VALUES (1, 1, 1), (2, 2, 1) AS t(a, b, k) +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE TEMPORARY VIEW v2 AS VALUES (1, 1, 1), (2, 2, 1) AS t(x, y, all) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT max(a) AS b, b FROM v1 GROUP BY k +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "MISSING_AGGREGATION", + "sqlState" : "42803", + "messageParameters" : { + "expression" : "\"b\"", + "expressionAnyValue" : "\"any_value(b)\"" + } +} + + +-- !query +SELECT a FROM v1 WHERE (12, 13) IN (SELECT max(x + 10) AS a, a + 1 FROM v2) +-- !query schema +struct +-- !query output +1 +2 + + +-- !query +SELECT a AS k FROM v1 GROUP BY k +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "MISSING_AGGREGATION", + "sqlState" : "42803", + "messageParameters" : { + "expression" : "\"a\"", + "expressionAnyValue" : "\"any_value(a)\"" + } +} + + +-- !query +SELECT x FROM v2 GROUP BY all +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "MISSING_AGGREGATION", + "sqlState" : "42803", + "messageParameters" : { + "expression" : "\"x\"", + "expressionAnyValue" : "\"any_value(x)\"" + } +} + + +-- !query +SELECT a AS all, b FROM v1 GROUP BY all +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "MISSING_AGGREGATION", + "sqlState" : "42803", + "messageParameters" : { + "expression" : "\"b\"", + "expressionAnyValue" : "\"any_value(b)\"" + } +} + + +-- !query +SELECT k AS lca, lca + 1 AS col FROM v1 GROUP BY k, col +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_GROUP_BY", + "sqlState" : "0A000" +} + + +-- !query +SELECT k AS lca, lca + 1 AS col FROM v1 GROUP BY all +-- !query schema +struct<> +-- !query output +org.apache.spark.sql.AnalysisException +{ + "errorClass" : "UNSUPPORTED_FEATURE.LATERAL_COLUMN_ALIAS_IN_GROUP_BY", + "sqlState" : "0A000" +} + + +-- !query +SELECT * FROM v2 WHERE EXISTS (SELECT a, b FROM v1 GROUP BY all) +-- !query schema +struct +-- !query output +1 1 1 +2 2 1 diff --git a/sql/core/src/test/resources/sql-tests/results/column-resolution-sort.sql.out b/sql/core/src/test/resources/sql-tests/results/column-resolution-sort.sql.out new file mode 100644 index 000000000000..67323d734c90 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/column-resolution-sort.sql.out @@ -0,0 +1,42 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +CREATE TEMPORARY VIEW v1 AS VALUES (1, 2, 2), (2, 1, 1) AS t(a, b, k) +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE TEMPORARY VIEW v2 AS VALUES (1, 2, 2), (2, 1, 1) AS t(a, b, all) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT max(a) AS b FROM v1 GROUP BY k ORDER BY b +-- !query schema +struct +-- !query output +1 +2 + + +-- !query +SELECT max(a) FROM v2 GROUP BY all ORDER BY all +-- !query schema +struct +-- !query output +2 +1 + + +-- !query +SELECT (SELECT b FROM v1 ORDER BY all LIMIT 1) FROM v2 +-- !query schema +struct +-- !query output +1 +1 diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-all.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-all.sql.out index 2b3e8fe9dfd3..d8a2e743d6b6 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-all.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-all.sql.out @@ -226,22 +226,6 @@ org.apache.spark.sql.AnalysisException } --- !query -select country as all, id from data group by all --- !query schema -struct<> --- !query output -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "MISSING_AGGREGATION", - "sqlState" : "42803", - "messageParameters" : { - "expression" : "\"id\"", - "expressionAnyValue" : "\"any_value(id)\"" - } -} - - -- !query select country, (select count(*) from data) as cnt, count(id) as cnt_id from data group by all -- !query schema diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index b52bb7c20399..0402039fafac 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -182,60 +182,6 @@ struct 3 2 --- !query -SELECT 1 AS x, x + 1 AS k FROM testData GROUP BY k --- !query schema -struct<> --- !query output -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "UNRESOLVED_COLUMN.WITH_SUGGESTION", - "sqlState" : "42703", - "messageParameters" : { - "objectName" : "`k`", - "proposal" : "`testdata`.`a`, `testdata`.`b`" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 50, - "stopIndex" : 50, - "fragment" : "k" - } ] -} - - --- !query -SELECT * FROM testData WHERE a = 1 AND EXISTS (SELECT a AS k GROUP BY k) --- !query schema -struct<> --- !query output -org.apache.spark.sql.AnalysisException -{ - "errorClass" : "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.CORRELATED_REFERENCE", - "sqlState" : "0A000", - "messageParameters" : { - "sqlExprs" : "\"a\",\"a AS k\"" - }, - "queryContext" : [ { - "objectType" : "", - "objectName" : "", - "startIndex" : 62, - "stopIndex" : 71, - "fragment" : "GROUP BY k" - } ] -} - - --- !query -SELECT * FROM testData WHERE a = 1 AND EXISTS (SELECT 1 AS a GROUP BY a) --- !query schema -struct --- !query output -1 1 -1 2 - - -- !query SELECT a AS k, COUNT(non_existing) FROM testData GROUP BY k -- !query schema From 79a183ad8ee8fab8dc7f79ac097af5b6d4f6d0e0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 31 Jan 2023 16:32:44 +0800 Subject: [PATCH 5/6] address comments --- .../ResolveReferencesInAggregate.scala | 30 +++++++------------ .../inputs/column-resolution-aggregate.sql | 3 ++ .../inputs/column-resolution-sort.sql | 2 +- .../column-resolution-aggregate.sql.out | 8 +++++ 4 files changed, 22 insertions(+), 21 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala index b45799c2283a..4af2ecc91ab5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.SQLConfHelper -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, GetStructField, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{AliasHelper, Attribute, Expression, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, AppendColumns, LogicalPlan} import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REFERENCE, UNRESOLVED_ATTRIBUTE} @@ -46,7 +46,8 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{LATERAL_COLUMN_ALIAS_REF * 5. Resolves the columns to outer references with the outer plan if we are resolving subquery * expressions. */ -object ResolveReferencesInAggregate extends SQLConfHelper with ColumnResolutionHelper { +object ResolveReferencesInAggregate extends SQLConfHelper + with ColumnResolutionHelper with AliasHelper { def apply(a: Aggregate): Aggregate = { val planForResolve = a.child match { // SPARK-25942: Resolves aggregate expressions with `AppendColumns`'s children, instead of @@ -58,12 +59,6 @@ object ResolveReferencesInAggregate extends SQLConfHelper with ColumnResolutionH val resolvedGroupExprsNoOuter = a.groupingExpressions .map(resolveExpressionByPlanChildren(_, planForResolve, allowOuter = false)) - // SPARK-31670: Resolve Struct field in groupByExpressions and aggregateExpressions - // with CUBE/ROLLUP will be wrapped with alias like Alias(GetStructField, name) with - // different ExprId. This cause aggregateExpressions can't be replaced by expanded - // groupByExpressions in `ResolveGroupingAnalytics.constructAggregateExprs()`, we trim - // unnecessary alias of GetStructField here. - .map(trimTopLevelGetStructFieldAlias) val resolvedAggExprsNoOuter = a.aggregateExpressions.map( resolveExpressionByPlanChildren(_, planForResolve, allowOuter = false)) val resolvedAggExprsWithLCA = resolveLateralColumnAlias(resolvedAggExprsNoOuter) @@ -95,7 +90,13 @@ object ResolveReferencesInAggregate extends SQLConfHelper with ColumnResolutionH resolvedGroupExprsNoOuter } a.copy( - groupingExpressions = resolvedGroupExprs, + // The aliases in grouping expressions are useless and will be removed at the end of analysis + // by the rule `CleanupAliases`. However, some rules need to find the grouping expressions + // from aggregate expressions during analysis. If we don't remove alias here, then these rules + // can't find the grouping expressions via `semanticEquals` and the analysis will fail. + // Example rules: ResolveGroupingAnalytics (See SPARK-31670 for more details) and + // ResolveLateralColumnAliasReference. + groupingExpressions = resolvedGroupExprs.map(trimAliases), aggregateExpressions = resolvedAggExprsWithOuter) } @@ -152,17 +153,6 @@ object ResolveReferencesInAggregate extends SQLConfHelper with ColumnResolutionH } } - /** - * Trim groupByExpression's top-level GetStructField Alias. Since these expressions are not - * NamedExpression originally, we are safe to trim top-level GetStructField Alias. - */ - private def trimTopLevelGetStructFieldAlias(e: Expression): Expression = { - e match { - case Alias(s: GetStructField, _) => s - case other => other - } - } - /** * Returns true iff this is a GROUP BY ALL: the grouping expressions only have a single column, * which is an unresolved column named ALL. diff --git a/sql/core/src/test/resources/sql-tests/inputs/column-resolution-aggregate.sql b/sql/core/src/test/resources/sql-tests/inputs/column-resolution-aggregate.sql index e797b9839460..4f879fc809d9 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/column-resolution-aggregate.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/column-resolution-aggregate.sql @@ -25,6 +25,9 @@ SELECT a AS all, b FROM v1 GROUP BY all; SELECT k AS lca, lca + 1 AS col FROM v1 GROUP BY k, col; SELECT k AS lca, lca + 1 AS col FROM v1 GROUP BY all; +-- GROUP BY alias still works if it does not directly reference lateral column alias. +SELECT k AS lca, lca + 1 AS col FROM v1 GROUP BY lca; + -- GROUP BY ALL has higher priority than outer reference. This query should run as `a` and `b` are -- in GROUP BY due to the GROUP BY ALL resolution. SELECT * FROM v2 WHERE EXISTS (SELECT a, b FROM v1 GROUP BY all); diff --git a/sql/core/src/test/resources/sql-tests/inputs/column-resolution-sort.sql b/sql/core/src/test/resources/sql-tests/inputs/column-resolution-sort.sql index da559da8fa07..2c5b9f9e9dfc 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/column-resolution-sort.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/column-resolution-sort.sql @@ -5,7 +5,7 @@ CREATE TEMPORARY VIEW v1 AS VALUES (1, 2, 2), (2, 1, 1) AS t(a, b, k); CREATE TEMPORARY VIEW v2 AS VALUES (1, 2, 2), (2, 1, 1) AS t(a, b, all); -- Relation output columns have higher priority than missing reference. --- Results will be [2, 1] if we order by the column `v1.b`. +-- Query will fail if we order by the column `v1.b`, as it's not in GROUP BY. -- Actually results are [1, 2] as we order by `max(a) AS b`. SELECT max(a) AS b FROM v1 GROUP BY k ORDER BY b; diff --git a/sql/core/src/test/resources/sql-tests/results/column-resolution-aggregate.sql.out b/sql/core/src/test/resources/sql-tests/results/column-resolution-aggregate.sql.out index 6cc71fbcfb7c..e8ab766751c4 100644 --- a/sql/core/src/test/resources/sql-tests/results/column-resolution-aggregate.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/column-resolution-aggregate.sql.out @@ -112,6 +112,14 @@ org.apache.spark.sql.AnalysisException } +-- !query +SELECT k AS lca, lca + 1 AS col FROM v1 GROUP BY lca +-- !query schema +struct +-- !query output +1 2 + + -- !query SELECT * FROM v2 WHERE EXISTS (SELECT a, b FROM v1 GROUP BY all) -- !query schema From 4f09dc9e4aa083984dc45c1cd502ac9a89b1e338 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 31 Jan 2023 18:46:25 +0800 Subject: [PATCH 6/6] fix conflicts --- .../spark/sql/catalyst/analysis/ColumnResolutionHelper.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index d6026133f851..9ac64cf4658d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -146,7 +146,7 @@ trait ColumnResolutionHelper extends Logging { val attrCandidates = getAttrCandidates() val matched = attrCandidates.filter(a => conf.resolver(a.name, colName)) if (matched.length != expectedNumCandidates) { - throw QueryCompilationErrors.incompatibleViewSchemaChange( + throw QueryCompilationErrors.incompatibleViewSchemaChangeError( viewName, colName, expectedNumCandidates, matched, viewDDL) } matched(ordinal)