diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 43ab651881487..afe7b4f2c9d48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -681,9 +681,9 @@ class Analyzer( // AggregateFunction's with the exception of First and Last in their default mode // (which we handle) and possibly some Hive UDAF's. case First(expr, _) => - First(ifExpr(expr), Literal(true)) + First(ifExpr(expr), true) case Last(expr, _) => - Last(ifExpr(expr), Literal(true)) + Last(ifExpr(expr), true) case a: AggregateFunction => a.withNewChildren(a.children.map(ifExpr)) }.transform { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index f51bfd591204a..d2032772d0519 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -17,8 +17,9 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckSuccess import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -35,12 +36,16 @@ import org.apache.spark.sql.types._ _FUNC_(expr[, isIgnoreNull]) - Returns the first value of `expr` for a group of rows. If `isIgnoreNull` is true, returns only non-null values. """) -case class First(child: Expression, ignoreNullsExpr: Expression) +case class First(child: Expression, ignoreNulls: Boolean) extends DeclarativeAggregate with ExpectsInputTypes { - def this(child: Expression) = this(child, Literal.create(false, BooleanType)) + def this(child: Expression) = this(child, false) - override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil + def this(child: Expression, ignoreNullsExpr: Expression) = { + this(child, FirstLast.validateIgnoreNullExpr(ignoreNullsExpr, "first")) + } + + override def children: Seq[Expression] = child :: Nil override def nullable: Boolean = true @@ -57,16 +62,11 @@ case class First(child: Expression, ignoreNullsExpr: Expression) val defaultCheck = super.checkInputDataTypes() if (defaultCheck.isFailure) { defaultCheck - } else if (!ignoreNullsExpr.foldable) { - TypeCheckFailure( - s"The second argument of First must be a boolean literal, but got: ${ignoreNullsExpr.sql}") } else { TypeCheckSuccess } } - private def ignoreNulls: Boolean = ignoreNullsExpr.eval().asInstanceOf[Boolean] - private lazy val first = AttributeReference("first", child.dataType)() private lazy val valueSet = AttributeReference("valueSet", BooleanType)() @@ -106,3 +106,11 @@ case class First(child: Expression, ignoreNullsExpr: Expression) override def toString: String = s"first($child)${if (ignoreNulls) " ignore nulls"}" } + +object FirstLast { + def validateIgnoreNullExpr(exp: Expression, funcName: String): Boolean = exp match { + case Literal(b: Boolean, BooleanType) => b + case _ => throw new AnalysisException( + s"The second argument in $funcName should be a boolean literal.") + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index 2650d7b5908fd..57a62a0383637 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckSuccess import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -35,12 +35,16 @@ import org.apache.spark.sql.types._ _FUNC_(expr[, isIgnoreNull]) - Returns the last value of `expr` for a group of rows. If `isIgnoreNull` is true, returns only non-null values. """) -case class Last(child: Expression, ignoreNullsExpr: Expression) +case class Last(child: Expression, ignoreNulls: Boolean) extends DeclarativeAggregate with ExpectsInputTypes { - def this(child: Expression) = this(child, Literal.create(false, BooleanType)) + def this(child: Expression) = this(child, false) - override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil + def this(child: Expression, ignoreNullsExpr: Expression) = { + this(child, FirstLast.validateIgnoreNullExpr(ignoreNullsExpr, "last")) + } + + override def children: Seq[Expression] = child :: Nil override def nullable: Boolean = true @@ -57,16 +61,11 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) val defaultCheck = super.checkInputDataTypes() if (defaultCheck.isFailure) { defaultCheck - } else if (!ignoreNullsExpr.foldable) { - TypeCheckFailure( - s"The second argument of Last must be a boolean literal, but got: ${ignoreNullsExpr.sql}") } else { TypeCheckSuccess } } - private def ignoreNulls: Boolean = ignoreNullsExpr.eval().asInstanceOf[Boolean] - private lazy val last = AttributeReference("last", child.dataType)() private lazy val valueSet = AttributeReference("valueSet", BooleanType)() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index b9468007cac61..22eb604390d39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -198,7 +198,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // Select the result of the first aggregate in the last aggregate. val result = AggregateExpression( - aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)), + aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), true), mode = Complete, isDistinct = false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index e2e8a45976d75..90e7d1c3917e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1225,7 +1225,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging */ override def visitFirst(ctx: FirstContext): Expression = withOrigin(ctx) { val ignoreNullsExpr = ctx.IGNORE != null - First(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression() + First(expression(ctx.expression), ignoreNullsExpr).toAggregateExpression() } /** @@ -1233,7 +1233,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging */ override def visitLast(ctx: LastContext): Expression = withOrigin(ctx) { val ignoreNullsExpr = ctx.IGNORE != null - Last(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression() + Last(expression(ctx.expression), ignoreNullsExpr).toAggregateExpression() } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/FirstLastTestSuite.scala similarity index 84% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/FirstLastTestSuite.scala index ba36bc074e154..bb6672e1046da 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/LastTestSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/FirstLastTestSuite.scala @@ -17,14 +17,15 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal} import org.apache.spark.sql.types.IntegerType -class LastTestSuite extends SparkFunSuite { +class FirstLastTestSuite extends SparkFunSuite { val input = AttributeReference("input", IntegerType, nullable = true)() - val evaluator = DeclarativeAggregateEvaluator(Last(input, Literal(false)), Seq(input)) - val evaluatorIgnoreNulls = DeclarativeAggregateEvaluator(Last(input, Literal(true)), Seq(input)) + val evaluator = DeclarativeAggregateEvaluator(Last(input, false), Seq(input)) + val evaluatorIgnoreNulls = DeclarativeAggregateEvaluator(Last(input, true), Seq(input)) test("empty buffer") { assert(evaluator.initialize() === InternalRow(null, false)) @@ -106,4 +107,15 @@ class LastTestSuite extends SparkFunSuite { val m1 = evaluatorIgnoreNulls.merge(p1, p2) assert(evaluatorIgnoreNulls.eval(m1) === InternalRow(1)) } + + test("SPARK-32344: correct error handling for a type mismatch") { + val msg1 = intercept[AnalysisException] { + new First(input, Literal(1, IntegerType)) + }.getMessage + assert(msg1.contains("The second argument in first should be a boolean literal")) + val msg2 = intercept[AnalysisException] { + new Last(input, Literal(1, IntegerType)) + }.getMessage + assert(msg2.contains("The second argument in last should be a boolean literal")) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 2bc41dad9d46d..a72342633f051 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -688,9 +688,9 @@ class ExpressionParserSuite extends PlanTest { } test("SPARK-19526 Support ignore nulls keywords for first and last") { - assertEqual("first(a ignore nulls)", First('a, Literal(true)).toAggregateExpression()) - assertEqual("first(a)", First('a, Literal(false)).toAggregateExpression()) - assertEqual("last(a ignore nulls)", Last('a, Literal(true)).toAggregateExpression()) - assertEqual("last(a)", Last('a, Literal(false)).toAggregateExpression()) + assertEqual("first(a ignore nulls)", First('a, true).toAggregateExpression()) + assertEqual("first(a)", First('a, false).toAggregateExpression()) + assertEqual("last(a ignore nulls)", Last('a, true).toAggregateExpression()) + assertEqual("last(a)", Last('a, false).toAggregateExpression()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 21ad1fd0ad395..427017117769e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -455,7 +455,7 @@ object functions { * @since 2.0.0 */ def first(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction { - new First(e.expr, Literal(ignoreNulls)) + new First(e.expr, ignoreNulls) } /** @@ -580,7 +580,7 @@ object functions { * @since 2.0.0 */ def last(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction { - new Last(e.expr, Literal(ignoreNulls)) + new Last(e.expr, ignoreNulls) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index bb7c68abb168d..86a1086efbc8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -743,4 +743,11 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { "grouping expressions: [current_date(None)], value: [key: int, value: string], " + "type: GroupBy]")) } + + test("SPARK-32344: Unevaluable's set to FIRST/LAST ignoreNullsExpr in distinct aggregates") { + val queryTemplate = (agg: String) => + s"SELECT $agg(DISTINCT v) FROM (SELECT v FROM VALUES 1, 2, 3 t(v) ORDER BY v)" + checkAnswer(sql(queryTemplate("FIRST")), Row(1)) + checkAnswer(sql(queryTemplate("LAST")), Row(3)) + } }