diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a4c61149dd97..681321113d0e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -150,8 +150,8 @@ class Analyzer(
    * @return the attributes of non selected specified via bitmask (with the bit set to 1)
    */
   private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
-    : OpenHashSet[Expression] = {
-    val set = new OpenHashSet[Expression](2)
+    : ExpressionSet = {
+    val set = new ExpressionSet()
     var bit = exprs.length - 1

     while (bit >= 0) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index c8288c676700..7534e0780aa8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -84,16 +84,17 @@ trait CheckAnalysis {
             s"of type ${f.condition.dataType.simpleString} is not a boolean.")

       case Aggregate(groupingExprs, aggregateExprs, child) =>
+        val normalizedGroupingExprs = ExpressionSet(groupingExprs)
         def checkValidAggregateExpression(expr: Expression): Unit = expr match {
           case _: AggregateExpression => // OK
-          case e: Attribute if !groupingExprs.contains(e) =>
+          case e if normalizedGroupingExprs.contains(e) => // OK
+          case e if e.children.nonEmpty => e.children.foreach(checkValidAggregateExpression)
+          case e: NamedExpression =>
             failAnalysis(
-              s"expression '${e.prettyString}' is neither present in the group by, " +
-                s"nor is it an aggregate function. " +
-                "Add to group by or wrap in first() if you don't care which value you get.")
-          case e if groupingExprs.contains(e) => // OK
-          case e if e.references.isEmpty => // OK
-          case e => e.children.foreach(checkValidAggregateExpression)
+              s"""expression '${e.prettyString}' is neither present in the group by,
+                 |nor is it an aggregate function.
+                 |Add to group by or wrap in first() if you don't care which value you get.""".stripMargin)
+          case _ => // OK, e.g. a Literal
         }

         val cleaned = aggregateExprs.map(_.transform {
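Reviewer note (not part of the patch): the intent of the new `ExpressionSet`-based grouping check is easiest to see in isolation. The sketch below is dependency-free Scala, not Catalyst code; `Attr` and `normalize` are made-up stand-ins for `AttributeReference` and `ExpressionEquals.normalize`. It shows why erasing the cosmetic name (while keeping the `exprId`) makes `kEy` and `key` compare equal once both resolve to the same attribute:

```scala
// Toy stand-in for AttributeReference: the name is cosmetic, the exprId is the identity.
case class Attr(name: String, exprId: Long, nullable: Boolean)

// Erase the cosmetic field before comparison, mirroring ExpressionEquals.normalize.
def normalize(a: Attr): Attr = a.copy(name = null)

val key   = Attr("key", exprId = 1L, nullable = true)
val kEy   = Attr("kEy", exprId = 1L, nullable = true)   // same attribute, different spelling
val value = Attr("value", exprId = 2L, nullable = true) // genuinely different attribute

assert(normalize(key) == normalize(kEy))   // cosmetic difference erased
assert(normalize(key) != normalize(value)) // real differences preserved
```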
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionMap.scala
new file mode 100644
index 000000000000..9c3f84661b4e
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionMap.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+/**
+ * A map keyed by normalized expressions. Normalizing the key allows values to be looked up
+ * even when the attributes used differ cosmetically (i.e., in the capitalization of the name,
+ * or the expected nullability).
+ */
+sealed class ExpressionMap[A] extends Serializable {
+  private val baseMap = new collection.mutable.HashMap[Expression, A]()
+  def get(k: Expression): Option[A] = baseMap.get(ExpressionEquals.normalize(k))
+
+  def add(k: Expression, value: A): Unit = {
+    baseMap.put(ExpressionEquals.normalize(k), value)
+  }
+
+  def values: Iterable[A] = baseMap.values
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
new file mode 100644
index 000000000000..cb5011020218
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+private[expressions] object ExpressionEquals {
+  def normalize(expr: Expression): Expression = expr.transformUp {
+    case n: AttributeReference =>
+      // The name of an AttributeReference is cosmetic, so erase it for semantic equality checks
+      new AttributeReference(null, n.dataType, n.nullable, n.metadata)(n.exprId, n.qualifiers)
+  }
+}
+
+object ExpressionSet {
+  def apply(exprs: Iterable[Expression]): ExpressionSet = {
+    val set = new ExpressionSet()
+    exprs.foreach(set.add)
+
+    set
+  }
+}
+
+/**
+ * A set of expressions whose elements can be looked up even when the attributes used differ
+ * cosmetically (i.e., in the capitalization of the name, or the expected nullability).
+ */
+sealed class ExpressionSet extends Serializable {
+  private val baseSet: java.util.Set[Expression] = new java.util.HashSet[Expression]()
+  def contains(expr: Expression): Boolean = {
+    baseSet.contains(ExpressionEquals.normalize(expr))
+  }
+  def add(expr: Expression): Unit = baseSet.add(ExpressionEquals.normalize(expr))
+}
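Reviewer note (not part of the patch): both new collections follow the same "normalize on insert, normalize on lookup" discipline, with `ExpressionEquals.normalize` as the single normalization point. A minimal generic sketch of that pattern, again dependency-free, with case-insensitive strings standing in for expressions:

```scala
// Generic normalized set: the shape shared by ExpressionSet and ExpressionMap.
class NormalizedSet[T](normalize: T => T) extends Serializable {
  private val base = scala.collection.mutable.HashSet.empty[T]
  def add(x: T): Unit = base += normalize(x)
  def contains(x: T): Boolean = base.contains(normalize(x))
}

val exprs = new NormalizedSet[String](_.toLowerCase)
exprs.add("key + 1")
assert(exprs.contains("KEY + 1"))  // matches despite a cosmetic difference
assert(!exprs.contains("key + 3")) // a genuinely different expression still misses
```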
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index cd54d04814ea..3422259f179a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -143,11 +143,22 @@ object PartialAggregation {
         // We need to pass all grouping expressions though so the grouping can happen a second
         // time. However some of them might be unnamed so we alias them allowing them to be
         // referenced in the second aggregation.
-        val namedGroupingExpressions: Map[Expression, NamedExpression] =
-          groupingExpressions.filter(!_.isInstanceOf[Literal]).map {
-            case n: NamedExpression => (n, n)
-            case other => (other, Alias(other, "PartialGroup")())
-          }.toMap
+        val namedGroupingMap = new ExpressionMap[NamedExpression]()
+        // Produces tuples of (raw expression, named expression, the named expression's attribute)
+        val namedGroupingTuples = groupingExpressions.filter(!_.isInstanceOf[Literal]).map {
+          case n: NamedExpression =>
+            (n: Expression, n, n.toAttribute)
+          case other =>
+            val v = Alias(other, "PartialGroup")()
+            (other, v, v.toAttribute)
+        }
+
+        val partialGroupingExprs = namedGroupingTuples.map(_._2)
+        val namedGroupingAttributes = namedGroupingTuples.map(_._3)
+        // Populate the expression map used to substitute grouping expressions in the final
+        // aggregate expressions.
+        namedGroupingTuples.foreach { case (expr, namedExpr, _) =>
+          namedGroupingMap.add(expr, namedExpr)
+        }

         // Replace aggregations with a new expression that computes the result from the already
         // computed partial evaluations and grouping values.
@@ -159,18 +170,16 @@ object PartialAggregation {
           // Should trim aliases around `GetField`s. These aliases are introduced while
           // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
           // (Should we just turn `GetField` into a `NamedExpression`?)
-          namedGroupingExpressions
+          namedGroupingMap
            .get(e.transform { case Alias(g: ExtractValue, _) => g })
            .map(_.toAttribute)
            .getOrElse(e)
         }).asInstanceOf[Seq[NamedExpression]]

         val partialComputation =
-          (namedGroupingExpressions.values ++
+          (partialGroupingExprs ++
             partialEvaluations.values.flatMap(_.partialEvaluations)).toSeq

-        val namedGroupingAttributes = namedGroupingExpressions.values.map(_.toAttribute).toSeq
-
         Some(
           (namedGroupingAttributes,
            rewrittenAggregateExpressions,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 135380260440..c299ff4038d4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -442,6 +442,24 @@ class SQLQuerySuite extends QueryTest {
       sql("SELECT `key` FROM src").collect().toSeq)
   }

+  test("SPARK-7269 check analysis with case-insensitive grouping expressions") {
+    Seq(1, 2, 3).map { i =>
+      (i.toString, i.toString)
+    }.toDF("key", "value").registerTempTable("df_analysis")
+    sql("SELECT kEy from df_analysis group by key").collect()
+    sql("SELECT kEy+3 from df_analysis group by key+3").collect()
+    sql("SELECT kEy+3, a.kEy, A.kEy from df_analysis A group by key").collect()
+    sql("SELECT cast(kEy+1 as Int) from df_analysis A group by cast(key+1 as int)").collect()
+    sql("SELECT cast(kEy+1 as Int) from df_analysis A group by key+1").collect()
+    sql("SELECT 2 from df_analysis A group by key+1").collect()
+    intercept[AnalysisException] {
+      sql("SELECT kEy+1 from df_analysis group by key+3")
+    }
+    intercept[AnalysisException] {
+      sql("SELECT cast(key+2 as Int) from df_analysis A group by cast(key+1 as int)")
+    }
+  }
+
   test("SPARK-3834 Backticks not correctly handled in subquery aliases") {
     checkAnswer(
       sql("SELECT a.key FROM (SELECT key FROM src) `a`"),
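Reviewer note (not part of the patch): the `PartialAggregation` change applies the same idea on the map side — each grouping expression's partial-group alias is stored under its normalized form, so the final-aggregate rewrite finds the attribute however the query spelled it. A dependency-free sketch of that lookup, with strings standing in for expressions and `PartialGroup#42` an invented alias name:

```scala
// Generic normalized map, mirroring how PartialAggregation uses ExpressionMap.
class NormalizedMap[K, V](normalize: K => K) extends Serializable {
  private val base = scala.collection.mutable.HashMap.empty[K, V]
  def add(k: K, v: V): Unit = base.put(normalize(k), v)
  def get(k: K): Option[V] = base.get(normalize(k))
}

val groups = new NormalizedMap[String, String](_.toLowerCase)
groups.add("key + 1", "PartialGroup#42")                  // alias for the unnamed grouping expr
assert(groups.get("KEY + 1").contains("PartialGroup#42")) // found despite different spelling
assert(groups.get("key + 3").isEmpty)                     // a different expression still misses
```

This is also what the new SQLQuerySuite test pins down end to end: `GROUP BY key+1` satisfies `SELECT cast(kEy+1 as Int)`, while `GROUP BY key+3` correctly rejects `SELECT kEy+1`.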