diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b8fa6e421ca6..817543bb7a09 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -28,7 +28,6 @@ import scala.util.{Failure, Random, Success, Try}
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst._
-import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer.DATA_TYPE_MISMATCH_ERROR_MESSAGE
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.encoders.OuterScopes
 import org.apache.spark.sql.catalyst.expressions.{Expression, FrameLessOffsetWindowFunction, _}
@@ -2647,10 +2646,6 @@ class Analyzer(override val catalogManager: CatalogManager)
       (extraAggExprs.toSeq, transformed)
     }
 
-    private def trimTempResolvedField(input: Expression): Expression = input.transform {
-      case t: TempResolvedColumn => t.child
-    }
-
     private def buildAggExprList(
         expr: Expression,
         agg: Aggregate,
@@ -2666,12 +2661,12 @@ class Analyzer(override val catalogManager: CatalogManager)
       } else {
         expr match {
           case ae: AggregateExpression =>
-            val cleaned = trimTempResolvedField(ae)
+            val cleaned = RemoveTempResolvedColumn.trimTempResolvedColumn(ae)
             val alias = Alias(cleaned, cleaned.toString)()
             aggExprList += alias
             alias.toAttribute
           case grouping: Expression if agg.groupingExpressions.exists(grouping.semanticEquals) =>
-            trimTempResolvedField(grouping) match {
+            RemoveTempResolvedColumn.trimTempResolvedColumn(grouping) match {
               case ne: NamedExpression =>
                 aggExprList += ne
                 ne.toAttribute
@@ -2683,7 +2678,7 @@ class Analyzer(override val catalogManager: CatalogManager)
           case t: TempResolvedColumn =>
             // Undo the resolution as this column is neither inside aggregate functions nor a
             // grouping column. It shouldn't be resolved with `agg.child.output`.
-            CurrentOrigin.withOrigin(t.origin)(UnresolvedAttribute(t.nameParts))
+            RemoveTempResolvedColumn.restoreTempResolvedColumn(t)
           case other =>
             other.withNewChildren(other.children.map(buildAggExprList(_, agg, aggExprList)))
         }
@@ -4345,32 +4340,42 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
 }
 
 /**
- * Removes all [[TempResolvedColumn]]s in the query plan. This is the last resort, in case some
- * rules in the main resolution batch miss to remove [[TempResolvedColumn]]s. We should run this
- * rule right after the main resolution batch.
+ * The rule `ResolveAggregationFunctions` in the main resolution batch creates
+ * [[TempResolvedColumn]] in filter conditions and sort expressions, to hold columns that are
+ * temporarily resolved with `agg.child`. When filter conditions or sort expressions are resolved,
+ * `ResolveAggregationFunctions` replaces [[TempResolvedColumn]] with [[AttributeReference]] if it
+ * is inside aggregate functions or grouping expressions, or with [[UnresolvedAttribute]]
+ * otherwise, hoping other rules can resolve it.
+ *
+ * This rule runs after the main resolution batch, and can still hit [[TempResolvedColumn]] if
+ * filter conditions or sort expressions were not resolved. When this happens, there is no point
+ * in turning [[TempResolvedColumn]] into [[UnresolvedAttribute]], as we can't resolve the column
+ * differently and the query will fail. This rule instead strips all [[TempResolvedColumn]]s in
+ * Filter/Sort and turns them into [[AttributeReference]]s, so that the error message can tell
+ * users why the filter conditions or sort expressions were not resolved.
  */
 object RemoveTempResolvedColumn extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = {
-    plan.foreachUp {
-      // HAVING clause will be resolved as a Filter. When having func(column with wrong data type),
-      // the column could be wrapped by a TempResolvedColumn, e.g. mean(tempresolvedcolumn(t.c)).
-      // Because TempResolvedColumn can still preserve column data type, here is a chance to check
-      // if the data type matches with the required data type of the function. We can throw an error
-      // when data types mismatches.
-      case operator: Filter =>
-        operator.expressions.foreach(_.foreachUp {
-          case e: Expression if e.childrenResolved && e.checkInputDataTypes().isFailure =>
-            e.checkInputDataTypes() match {
-              case TypeCheckResult.TypeCheckFailure(message) =>
-                e.setTagValue(DATA_TYPE_MISMATCH_ERROR_MESSAGE, message)
-            }
-          case _ =>
-        })
-      case _ =>
+    plan.resolveOperatorsUp {
+      case f @ Filter(cond, agg: Aggregate) if agg.resolved =>
+        withOrigin(f.origin)(f.copy(condition = trimTempResolvedColumn(cond)))
+      case s @ Sort(sortOrder, _, agg: Aggregate) if agg.resolved =>
+        val newSortOrder = sortOrder.map { order =>
+          trimTempResolvedColumn(order).asInstanceOf[SortOrder]
+        }
+        withOrigin(s.origin)(s.copy(order = newSortOrder))
+      case other => other.transformExpressionsUp {
+        // This should not happen. We restore TempResolvedColumn to UnresolvedAttribute to be safe.
+        case t: TempResolvedColumn => restoreTempResolvedColumn(t)
+      }
     }
+  }
 
-    plan.resolveExpressions {
-      case t: TempResolvedColumn => UnresolvedAttribute(t.nameParts)
-    }
+  def trimTempResolvedColumn(input: Expression): Expression = input.transform {
+    case t: TempResolvedColumn => t.child
+  }
+
+  def restoreTempResolvedColumn(t: TempResolvedColumn): Expression = {
+    CurrentOrigin.withOrigin(t.origin)(UnresolvedAttribute(t.nameParts))
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 7635918279a2..83246406a8c4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -50,8 +50,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
 
   val DATA_TYPE_MISMATCH_ERROR = TreeNodeTag[Boolean]("dataTypeMismatchError")
 
-  val DATA_TYPE_MISMATCH_ERROR_MESSAGE = TreeNodeTag[String]("dataTypeMismatchError")
-
   protected def failAnalysis(msg: String): Nothing = {
     throw new AnalysisException(msg)
   }
@@ -176,20 +174,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
           }
         }
 
-        val expressions = getAllExpressions(operator)
-
-        expressions.foreach(_.foreachUp {
-          case e: Expression =>
-            e.getTagValue(DATA_TYPE_MISMATCH_ERROR_MESSAGE) match {
-              case Some(message) =>
-                e.failAnalysis(s"cannot resolve '${e.sql}' due to data type mismatch: $message" +
-                  extraHintForAnsiTypeCoercionExpression(operator))
-              case _ =>
-            }
-          case _ =>
-        })
-
-        expressions.foreach(_.foreachUp {
+        getAllExpressions(operator).foreach(_.foreachUp {
           case a: Attribute if !a.resolved =>
             val missingCol = a.sql
             val candidates = operator.inputSet.toSeq.map(_.qualifiedName)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index a6e952fd8657..5c3f4b5f5585 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -1172,7 +1172,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
            |FROM t
            |GROUP BY t.c, t.d
            |HAVING ${func}(c) > 0d""".stripMargin),
-        Seq(s"cannot resolve '$func(c)' due to data type mismatch"),
+        Seq(s"cannot resolve '$func(t.c)' due to data type mismatch"),
         false)
     }
   }
diff --git a/sql/core/src/test/resources/sql-tests/inputs/having.sql b/sql/core/src/test/resources/sql-tests/inputs/having.sql
index 2799b1a94d08..056b99e363d2 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/having.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/having.sql
@@ -11,6 +11,9 @@ SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2;
 -- having condition contains grouping column
 SELECT count(k) FROM hav GROUP BY v + 1 HAVING v + 1 = 2;
 
+-- invalid having condition contains grouping column
+SELECT count(k) FROM hav GROUP BY v HAVING v = array(1);
+
 -- SPARK-11032: resolve having correctly
 SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0);
diff --git a/sql/core/src/test/resources/sql-tests/results/having.sql.out b/sql/core/src/test/resources/sql-tests/results/having.sql.out
index fff470b3d81d..e9e24562d1ba 100644
--- a/sql/core/src/test/resources/sql-tests/results/having.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/having.sql.out
@@ -29,6 +29,15 @@ struct
 1
 
 
+-- !query
+SELECT count(k) FROM hav GROUP BY v HAVING v = array(1)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+cannot resolve '(hav.v = array(1))' due to data type mismatch: differing types in '(hav.v = array(1))' (int and array<int>).; line 1 pos 43
+
+
 -- !query
 SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0)
 -- !query schema
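---

Illustration (not part of the patch): a minimal, self-contained sketch of what the two helpers now
exposed on RemoveTempResolvedColumn do to an expression tree. The attribute, the name parts, and
the object name below are made up, and the sketch assumes the TempResolvedColumn(child, nameParts)
shape that the patch itself relies on via `t.child` and `t.nameParts`.

import org.apache.spark.sql.catalyst.analysis.{TempResolvedColumn, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, Literal}
import org.apache.spark.sql.types.IntegerType

object TrimRestoreSketch extends App {
  // `v` as ResolveAggregationFunctions would temporarily resolve it against agg.child.output.
  val v = AttributeReference("v", IntegerType)()
  val cond: Expression = EqualTo(TempResolvedColumn(v, Seq("v")), Literal(1))

  // Trimming keeps the resolved child, so the condition stays resolved and CheckAnalysis can
  // report a precise data type mismatch if the input types disagree.
  val trimmed = cond.transform { case t: TempResolvedColumn => t.child }
  println(trimmed.resolved) // true: (v = 1) is fully resolved

  // Restoring falls back to an unresolved attribute, leaving the column for other rules to
  // resolve differently (the old behavior, now kept only as a safety net).
  val restored = cond.transform {
    case t: TempResolvedColumn => UnresolvedAttribute(t.nameParts)
  }
  println(restored.resolved) // false: ('v = 1) still needs resolution
}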
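Where the improved message comes from: once [[TempResolvedColumn]] is trimmed to the underlying
attribute, the mismatched HAVING condition reaches CheckAnalysis as an ordinary expression whose
type check fails, and the generic data-type-mismatch path formats the error with the proper
qualified name and origin. A hedged sketch of that mechanism, mirroring the hav.v = array(1)
golden-file test above (names are again illustrative only):

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, CreateArray, EqualTo, Literal}
import org.apache.spark.sql.types.IntegerType

object MismatchMessageSketch extends App {
  // The trimmed HAVING condition from the new test: v = array(1), with v resolved as int.
  val v = AttributeReference("v", IntegerType)()
  val cond = EqualTo(v, CreateArray(Seq(Literal(1))))

  // The comparison never becomes resolved, because its input types disagree.
  println(cond.resolved) // false

  // checkInputDataTypes() carries the "differing types in ... (int and array<int>)" detail that
  // CheckAnalysis wraps into the AnalysisException recorded in having.sql.out.
  cond.checkInputDataTypes() match {
    case TypeCheckResult.TypeCheckFailure(message) => println(message)
    case _ => println("type check passed")
  }
}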