Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -681,9 +681,9 @@ class Analyzer(
// AggregateFunction's with the exception of First and Last in their default mode
// (which we handle) and possibly some Hive UDAF's.
case First(expr, _) =>
First(ifExpr(expr), Literal(true))
First(ifExpr(expr), true)
case Last(expr, _) =>
Last(ifExpr(expr), Literal(true))
Last(ifExpr(expr), true)
case a: AggregateFunction =>
a.withNewChildren(a.children.map(ifExpr))
}.transform {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckSuccess
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
Expand All @@ -35,12 +36,16 @@ import org.apache.spark.sql.types._
_FUNC_(expr[, isIgnoreNull]) - Returns the first value of `expr` for a group of rows.
If `isIgnoreNull` is true, returns only non-null values.
""")
case class First(child: Expression, ignoreNullsExpr: Expression)
case class First(child: Expression, ignoreNulls: Boolean)
extends DeclarativeAggregate with ExpectsInputTypes {

def this(child: Expression) = this(child, Literal.create(false, BooleanType))
def this(child: Expression) = this(child, false)

override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
def this(child: Expression, ignoreNullsExpr: Expression) = {
this(child, FirstLast.validateIgnoreNullExpr(ignoreNullsExpr, "first"))
}

override def children: Seq[Expression] = child :: Nil

override def nullable: Boolean = true

Expand All @@ -57,16 +62,11 @@ case class First(child: Expression, ignoreNullsExpr: Expression)
val defaultCheck = super.checkInputDataTypes()
if (defaultCheck.isFailure) {
defaultCheck
} else if (!ignoreNullsExpr.foldable) {
TypeCheckFailure(
s"The second argument of First must be a boolean literal, but got: ${ignoreNullsExpr.sql}")
} else {
TypeCheckSuccess
}
}

private def ignoreNulls: Boolean = ignoreNullsExpr.eval().asInstanceOf[Boolean]

private lazy val first = AttributeReference("first", child.dataType)()

private lazy val valueSet = AttributeReference("valueSet", BooleanType)()
Expand Down Expand Up @@ -106,3 +106,11 @@ case class First(child: Expression, ignoreNullsExpr: Expression)

/**
 * Renders this aggregate as `first(<child>)`, followed by ` ignore nulls` only when
 * `ignoreNulls` is true.
 */
override def toString: String = {
  // An `if` without an `else` evaluates to `()` when the condition is false, and that unit
  // value would be interpolated as the literal text "()". Supply the empty string explicitly.
  val ignoreNullsSuffix = if (ignoreNulls) " ignore nulls" else ""
  s"first($child)$ignoreNullsSuffix"
}
}

/** Helpers shared by the `First` and `Last` aggregate expressions. */
object FirstLast {

  /**
   * Extracts the boolean carried by the "ignore nulls" argument expression.
   *
   * @param exp      expression passed as the second argument of the aggregate
   * @param funcName function name used in the error message (callers pass "first" or "last")
   * @return the boolean value when `exp` is a boolean literal
   * @throws AnalysisException when `exp` is anything other than a boolean literal
   */
  def validateIgnoreNullExpr(exp: Expression, funcName: String): Boolean = {
    exp match {
      case Literal(flag: Boolean, BooleanType) =>
        flag
      case _ =>
        val message = s"The second argument in $funcName should be a boolean literal."
        throw new AnalysisException(message)
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckSuccess
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
Expand All @@ -35,12 +35,16 @@ import org.apache.spark.sql.types._
_FUNC_(expr[, isIgnoreNull]) - Returns the last value of `expr` for a group of rows.
If `isIgnoreNull` is true, returns only non-null values.
""")
case class Last(child: Expression, ignoreNullsExpr: Expression)
case class Last(child: Expression, ignoreNulls: Boolean)
extends DeclarativeAggregate with ExpectsInputTypes {

def this(child: Expression) = this(child, Literal.create(false, BooleanType))
def this(child: Expression) = this(child, false)

override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
def this(child: Expression, ignoreNullsExpr: Expression) = {
this(child, FirstLast.validateIgnoreNullExpr(ignoreNullsExpr, "last"))
}

override def children: Seq[Expression] = child :: Nil

override def nullable: Boolean = true

Expand All @@ -57,16 +61,11 @@ case class Last(child: Expression, ignoreNullsExpr: Expression)
val defaultCheck = super.checkInputDataTypes()
if (defaultCheck.isFailure) {
defaultCheck
} else if (!ignoreNullsExpr.foldable) {
TypeCheckFailure(
s"The second argument of Last must be a boolean literal, but got: ${ignoreNullsExpr.sql}")
} else {
TypeCheckSuccess
}
}

private def ignoreNulls: Boolean = ignoreNullsExpr.eval().asInstanceOf[Boolean]

private lazy val last = AttributeReference("last", child.dataType)()

private lazy val valueSet = AttributeReference("valueSet", BooleanType)()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {

// Select the result of the first aggregate in the last aggregate.
val result = AggregateExpression(
aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)),
aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), true),
mode = Complete,
isDistinct = false)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1225,15 +1225,15 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
*/
override def visitFirst(ctx: FirstContext): Expression = withOrigin(ctx) {
val ignoreNullsExpr = ctx.IGNORE != null
First(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression()
First(expression(ctx.expression), ignoreNullsExpr).toAggregateExpression()
}

/**
* Create a [[Last]] expression.
*/
override def visitLast(ctx: LastContext): Expression = withOrigin(ctx) {
val ignoreNullsExpr = ctx.IGNORE != null
Last(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression()
Last(expression(ctx.expression), ignoreNullsExpr).toAggregateExpression()
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@
package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal}
import org.apache.spark.sql.types.IntegerType

class LastTestSuite extends SparkFunSuite {
class FirstLastTestSuite extends SparkFunSuite {
val input = AttributeReference("input", IntegerType, nullable = true)()
val evaluator = DeclarativeAggregateEvaluator(Last(input, Literal(false)), Seq(input))
val evaluatorIgnoreNulls = DeclarativeAggregateEvaluator(Last(input, Literal(true)), Seq(input))
val evaluator = DeclarativeAggregateEvaluator(Last(input, false), Seq(input))
val evaluatorIgnoreNulls = DeclarativeAggregateEvaluator(Last(input, true), Seq(input))

test("empty buffer") {
assert(evaluator.initialize() === InternalRow(null, false))
Expand Down Expand Up @@ -106,4 +107,15 @@ class LastTestSuite extends SparkFunSuite {
val m1 = evaluatorIgnoreNulls.merge(p1, p2)
assert(evaluatorIgnoreNulls.eval(m1) === InternalRow(1))
}

test("SPARK-32344: correct error handling for a type mismatch") {
  // An integer literal is not an acceptable ignore-nulls flag for the Expression-based
  // constructors, so building either aggregate with it must fail.
  val notABoolean = Literal(1, IntegerType)

  val firstError = intercept[AnalysisException](new First(input, notABoolean))
  assert(firstError.getMessage.contains(
    "The second argument in first should be a boolean literal"))

  val lastError = intercept[AnalysisException](new Last(input, notABoolean))
  assert(lastError.getMessage.contains(
    "The second argument in last should be a boolean literal"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -688,9 +688,9 @@ class ExpressionParserSuite extends PlanTest {
}

test("SPARK-19526 Support ignore nulls keywords for first and last") {
assertEqual("first(a ignore nulls)", First('a, Literal(true)).toAggregateExpression())
assertEqual("first(a)", First('a, Literal(false)).toAggregateExpression())
assertEqual("last(a ignore nulls)", Last('a, Literal(true)).toAggregateExpression())
assertEqual("last(a)", Last('a, Literal(false)).toAggregateExpression())
assertEqual("first(a ignore nulls)", First('a, true).toAggregateExpression())
assertEqual("first(a)", First('a, false).toAggregateExpression())
assertEqual("last(a ignore nulls)", Last('a, true).toAggregateExpression())
assertEqual("last(a)", Last('a, false).toAggregateExpression())
}
}
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ object functions {
* @since 2.0.0
*/
def first(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction {
new First(e.expr, Literal(ignoreNulls))
new First(e.expr, ignoreNulls)
}

/**
Expand Down Expand Up @@ -580,7 +580,7 @@ object functions {
* @since 2.0.0
*/
def last(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction {
new Last(e.expr, Literal(ignoreNulls))
new Last(e.expr, ignoreNulls)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -743,4 +743,11 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
"grouping expressions: [current_date(None)], value: [key: int, value: string], " +
"type: GroupBy]"))
}

test("SPARK-32344: Unevaluable's set to FIRST/LAST ignoreNullsExpr in distinct aggregates") {
  // Builds the same DISTINCT query for whichever aggregate name is supplied.
  def distinctQuery(agg: String): String =
    s"SELECT $agg(DISTINCT v) FROM (SELECT v FROM VALUES 1, 2, 3 t(v) ORDER BY v)"

  // The inner query orders v ascending: FIRST is expected to give 1 and LAST to give 3.
  checkAnswer(sql(distinctQuery("FIRST")), Row(1))
  checkAnswer(sql(distinctQuery("LAST")), Row(3))
}
}