Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import scala.util.{Failure, Random, Success, Try}

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer.DATA_TYPE_MISMATCH_ERROR_MESSAGE
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.OuterScopes
import org.apache.spark.sql.catalyst.expressions.{Expression, FrameLessOffsetWindowFunction, _}
Expand Down Expand Up @@ -2647,10 +2646,6 @@ class Analyzer(override val catalogManager: CatalogManager)
(extraAggExprs.toSeq, transformed)
}

/**
 * Returns a copy of `input` in which every [[TempResolvedColumn]] node has been replaced by the
 * expression it wraps (its `child`).
 */
private def trimTempResolvedField(input: Expression): Expression = {
  input.transform {
    case wrapped: TempResolvedColumn => wrapped.child
  }
}

private def buildAggExprList(
expr: Expression,
agg: Aggregate,
Expand All @@ -2666,12 +2661,12 @@ class Analyzer(override val catalogManager: CatalogManager)
} else {
expr match {
case ae: AggregateExpression =>
val cleaned = trimTempResolvedField(ae)
val cleaned = RemoveTempResolvedColumn.trimTempResolvedColumn(ae)
val alias = Alias(cleaned, cleaned.toString)()
aggExprList += alias
alias.toAttribute
case grouping: Expression if agg.groupingExpressions.exists(grouping.semanticEquals) =>
trimTempResolvedField(grouping) match {
RemoveTempResolvedColumn.trimTempResolvedColumn(grouping) match {
case ne: NamedExpression =>
aggExprList += ne
ne.toAttribute
Expand All @@ -2683,7 +2678,7 @@ class Analyzer(override val catalogManager: CatalogManager)
case t: TempResolvedColumn =>
// Undo the resolution as this column is neither inside aggregate functions nor a
// grouping column. It shouldn't be resolved with `agg.child.output`.
CurrentOrigin.withOrigin(t.origin)(UnresolvedAttribute(t.nameParts))
RemoveTempResolvedColumn.restoreTempResolvedColumn(t)
case other =>
other.withNewChildren(other.children.map(buildAggExprList(_, agg, aggExprList)))
}
Expand Down Expand Up @@ -4345,32 +4340,42 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] {
}

/**
* Removes all [[TempResolvedColumn]]s in the query plan. This is the last resort, in case some
* rules in the main resolution batch fail to remove [[TempResolvedColumn]]s. We should run this
* rule right after the main resolution batch.
* The rule `ResolveAggregationFunctions` in the main resolution batch creates
* [[TempResolvedColumn]] in filter conditions and sort expressions to hold the temporarily resolved
* column with `agg.child`. When filter conditions or sort expressions are resolved,
* `ResolveAggregationFunctions` will replace [[TempResolvedColumn]] with [[AttributeReference]]
* if it's inside aggregate functions or group expressions, or with [[UnresolvedAttribute]]
* otherwise, hoping other rules can resolve it.
*
* This rule runs after the main resolution batch, and can still hit [[TempResolvedColumn]] if
* filter conditions or sort expressions are not resolved. When this happens, there is no point in
* turning [[TempResolvedColumn]] into [[UnresolvedAttribute]], as we can't resolve the column
* differently, and the query will fail. This rule strips all [[TempResolvedColumn]]s in
* Filter/Sort and turns them into [[AttributeReference]] so that the error message can tell users
* why the filter conditions or sort expressions were not resolved.
*/
object RemoveTempResolvedColumn extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = {
plan.foreachUp {
// A HAVING clause is resolved as a Filter. When it contains func(column with wrong data type),
// the column could be wrapped by a TempResolvedColumn, e.g. mean(tempresolvedcolumn(t.c)).
// Because TempResolvedColumn still preserves the column data type, this is a chance to check
// whether the data type matches the one required by the function. We can throw an error
// when the data types mismatch.
case operator: Filter =>
operator.expressions.foreach(_.foreachUp {
case e: Expression if e.childrenResolved && e.checkInputDataTypes().isFailure =>
e.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckFailure(message) =>
e.setTagValue(DATA_TYPE_MISMATCH_ERROR_MESSAGE, message)
}
case _ =>
})
case _ =>
plan.resolveOperatorsUp {
case f @ Filter(cond, agg: Aggregate) if agg.resolved =>
withOrigin(f.origin)(f.copy(condition = trimTempResolvedColumn(cond)))
case s @ Sort(sortOrder, _, agg: Aggregate) if agg.resolved =>
val newSortOrder = sortOrder.map { order =>
trimTempResolvedColumn(order).asInstanceOf[SortOrder]
}
withOrigin(s.origin)(s.copy(order = newSortOrder))
case other => other.transformExpressionsUp {
// This should not happen. We restore TempResolvedColumn to UnresolvedAttribute to be safe.
case t: TempResolvedColumn => restoreTempResolvedColumn(t)
}
}
}

plan.resolveExpressions {
case t: TempResolvedColumn => UnresolvedAttribute(t.nameParts)
}
/**
 * Returns a copy of `input` in which every [[TempResolvedColumn]] node has been replaced by the
 * expression it wraps (its `child`).
 */
def trimTempResolvedColumn(input: Expression): Expression = {
  input.transform {
    case wrapped: TempResolvedColumn => wrapped.child
  }
}

/**
 * Turns a [[TempResolvedColumn]] back into an [[UnresolvedAttribute]] built from its name parts.
 * The new attribute is created under the origin of `t`.
 */
def restoreTempResolvedColumn(t: TempResolvedColumn): Expression =
  CurrentOrigin.withOrigin(t.origin) {
    UnresolvedAttribute(t.nameParts)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {

val DATA_TYPE_MISMATCH_ERROR = TreeNodeTag[Boolean]("dataTypeMismatchError")

val DATA_TYPE_MISMATCH_ERROR_MESSAGE = TreeNodeTag[String]("dataTypeMismatchError")

Comment on lines -53 to -54
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just to make sure this isn't related to TempResolvedColumn, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was used to do error handling of TempResolvedColumn, but we don't need it now as the logic is simplified.

Copy link
Contributor

@LuciferYang LuciferYang Jun 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just to make sure this isn't related to TempResolvedColumn, right?

Yes, I added this in #36746. It is not needed after this pr

/**
 * Aborts analysis by throwing an [[AnalysisException]] that carries `msg`.
 * This method never returns normally.
 */
protected def failAnalysis(msg: String): Nothing =
  throw new AnalysisException(msg)
Expand Down Expand Up @@ -176,20 +174,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
}
}

val expressions = getAllExpressions(operator)

expressions.foreach(_.foreachUp {
case e: Expression =>
e.getTagValue(DATA_TYPE_MISMATCH_ERROR_MESSAGE) match {
case Some(message) =>
e.failAnalysis(s"cannot resolve '${e.sql}' due to data type mismatch: $message" +
extraHintForAnsiTypeCoercionExpression(operator))
case _ =>
}
case _ =>
})

expressions.foreach(_.foreachUp {
getAllExpressions(operator).foreach(_.foreachUp {
case a: Attribute if !a.resolved =>
val missingCol = a.sql
val candidates = operator.inputSet.toSeq.map(_.qualifiedName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1172,7 +1172,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
|FROM t
|GROUP BY t.c, t.d
|HAVING ${func}(c) > 0d""".stripMargin),
Seq(s"cannot resolve '$func(c)' due to data type mismatch"),
Seq(s"cannot resolve '$func(t.c)' due to data type mismatch"),
false)
}
}
Expand Down
3 changes: 3 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/having.sql
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ SELECT k, sum(v) FROM hav GROUP BY k HAVING sum(v) > 2;
-- having condition contains grouping column
SELECT count(k) FROM hav GROUP BY v + 1 HAVING v + 1 = 2;

-- invalid having condition contains grouping column
SELECT count(k) FROM hav GROUP BY v HAVING v = array(1);

-- SPARK-11032: resolve having correctly
SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0);

Expand Down
9 changes: 9 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/having.sql.out
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ struct<count(k):bigint>
1


-- !query
SELECT count(k) FROM hav GROUP BY v HAVING v = array(1)
-- !query schema
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
cannot resolve '(hav.v = array(1))' due to data type mismatch: differing types in '(hav.v = array(1))' (int and array<int>).; line 1 pos 43


-- !query
SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0)
-- !query schema
Expand Down