@@ -791,9 +791,9 @@ class Analyzer(
       // AggregateFunction's with the exception of First and Last in their default mode
       // (which we handle) and possibly some Hive UDAF's.
       case First(expr, _) =>
-        First(ifExpr(expr), Literal(true))
+        First(ifExpr(expr), true)
       case Last(expr, _) =>
-        Last(ifExpr(expr), Literal(true))
+        Last(ifExpr(expr), true)
       case a: AggregateFunction =>
         a.withNewChildren(a.children.map(ifExpr))
     }.transform {
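To make the intent of this rewrite concrete, here is a minimal sketch (the names `RewriteSketch`, `cond`, and `rewriteFirst` are illustrative, not from the patch): the aggregate input is wrapped in an `If` that nulls out non-matching rows, and the flag that makes `First` skip those nulls is now a plain Scala `Boolean` instead of a `Literal` child expression.

```scala
import org.apache.spark.sql.catalyst.expressions.{Expression, If, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.First

object RewriteSketch {
  // Null out rows that do not satisfy `cond`, mirroring what `ifExpr` does above.
  def ifExpr(cond: Expression, e: Expression): Expression = If(cond, e, Literal(null))

  // After this patch the second argument is a Boolean, not Literal(true).
  def rewriteFirst(cond: Expression, expr: Expression): First =
    First(ifExpr(cond, expr), ignoreNulls = true)
}
```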
@@ -17,8 +17,9 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckSuccess
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
@@ -49,12 +50,16 @@ import org.apache.spark.sql.types._
""",
group = "agg_funcs",
since = "2.0.0")
case class First(child: Expression, ignoreNullsExpr: Expression)
case class First(child: Expression, ignoreNulls: Boolean)
extends DeclarativeAggregate with ExpectsInputTypes {

def this(child: Expression) = this(child, Literal.create(false, BooleanType))
def this(child: Expression) = this(child, false)

override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
def this(child: Expression, ignoreNullsExpr: Expression) = {
this(child, FirstLast.validateIgnoreNullExpr(ignoreNullsExpr, "first"))
}

override def children: Seq[Expression] = child :: Nil

override def nullable: Boolean = true

@@ -71,16 +76,11 @@ case class First(child: Expression, ignoreNullsExpr: Expression)
     val defaultCheck = super.checkInputDataTypes()
     if (defaultCheck.isFailure) {
       defaultCheck
-    } else if (!ignoreNullsExpr.foldable) {
-      TypeCheckFailure(
-        s"The second argument of First must be a boolean literal, but got: ${ignoreNullsExpr.sql}")
     } else {
       TypeCheckSuccess
     }
   }
 
-  private def ignoreNulls: Boolean = ignoreNullsExpr.eval().asInstanceOf[Boolean]
-
   private lazy val first = AttributeReference("first", child.dataType)()
 
   private lazy val valueSet = AttributeReference("valueSet", BooleanType)()
@@ -120,3 +120,11 @@ case class First(child: Expression, ignoreNullsExpr: Expression)
 
   override def toString: String = s"$prettyName($child)${if (ignoreNulls) " ignore nulls"}"
 }
+
+object FirstLast {
+  def validateIgnoreNullExpr(exp: Expression, funcName: String): Boolean = exp match {
+    case Literal(b: Boolean, BooleanType) => b
+    case _ => throw new AnalysisException(
+      s"The second argument in $funcName should be a boolean literal.")
+  }
+}

Review comment from a member on `object FirstLast`: "I think this deduplication is a little bit too much but I guess it's fine."
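As a usage sketch (grounded in the test added further below), the `Expression`-based auxiliary constructor now unwraps a boolean literal eagerly and rejects anything else at construction time, rather than deferring to `checkInputDataTypes`:

```scala
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.First
import org.apache.spark.sql.types.IntegerType

object FirstCtorDemo extends App {
  val input = AttributeReference("input", IntegerType, nullable = true)()

  // A boolean literal is unwrapped into the plain Boolean field.
  val f = new First(input, Literal(true)) // f.ignoreNulls == true

  // A non-boolean second argument now fails fast.
  try new First(input, Literal(1, IntegerType))
  catch {
    case e: AnalysisException =>
      // "The second argument in first should be a boolean literal."
      println(e.getMessage)
  }
}
```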
@@ -17,8 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult}
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckSuccess
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types._
@@ -49,12 +49,16 @@ import org.apache.spark.sql.types._
""",
group = "agg_funcs",
since = "2.0.0")
case class Last(child: Expression, ignoreNullsExpr: Expression)
case class Last(child: Expression, ignoreNulls: Boolean)
extends DeclarativeAggregate with ExpectsInputTypes {

def this(child: Expression) = this(child, Literal.create(false, BooleanType))
def this(child: Expression) = this(child, false)

override def children: Seq[Expression] = child :: ignoreNullsExpr :: Nil
def this(child: Expression, ignoreNullsExpr: Expression) = {
this(child, FirstLast.validateIgnoreNullExpr(ignoreNullsExpr, "last"))
}

override def children: Seq[Expression] = child :: Nil

override def nullable: Boolean = true

@@ -71,16 +75,11 @@ case class Last(child: Expression, ignoreNullsExpr: Expression)
     val defaultCheck = super.checkInputDataTypes()
     if (defaultCheck.isFailure) {
       defaultCheck
-    } else if (!ignoreNullsExpr.foldable) {
-      TypeCheckFailure(
-        s"The second argument of Last must be a boolean literal, but got: ${ignoreNullsExpr.sql}")
     } else {
       TypeCheckSuccess
     }
   }
 
-  private def ignoreNulls: Boolean = ignoreNullsExpr.eval().asInstanceOf[Boolean]
-
   private lazy val last = AttributeReference("last", child.dataType)()
 
   private lazy val valueSet = AttributeReference("valueSet", BooleanType)()
@@ -257,7 +257,7 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
 
       // Select the result of the first aggregate in the last aggregate.
       val result = AggregateExpression(
-        aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), Literal(true)),
+        aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute), true),
         mode = Complete,
         isDistinct = false)
 
@@ -1535,15 +1535,15 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
    */
   override def visitFirst(ctx: FirstContext): Expression = withOrigin(ctx) {
     val ignoreNullsExpr = ctx.IGNORE != null
-    First(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression()
+    First(expression(ctx.expression), ignoreNullsExpr).toAggregateExpression()
   }
 
   /**
    * Create a [[Last]] expression.
    */
   override def visitLast(ctx: LastContext): Expression = withOrigin(ctx) {
     val ignoreNullsExpr = ctx.IGNORE != null
-    Last(expression(ctx.expression), Literal(ignoreNullsExpr)).toAggregateExpression()
+    Last(expression(ctx.expression), ignoreNullsExpr).toAggregateExpression()
   }
 
   /**
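A quick way to see this parser change end to end is a sketch using `CatalystSqlParser` (which exercises `visitFirst`/`visitLast`; the exact rendering of the printed plan may vary by version): the `IGNORE NULLS` keyword now flows straight into the Boolean constructor argument.

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

object ParseDemo extends App {
  // Parses to First('a, ignoreNulls = true) wrapped in an AggregateExpression.
  println(CatalystSqlParser.parseExpression("first(a IGNORE NULLS)"))

  // Parses to Last('a, ignoreNulls = false).
  println(CatalystSqlParser.parseExpression("last(a)"))
}
```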
@@ -17,14 +17,15 @@
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal}
 import org.apache.spark.sql.types.IntegerType
 
-class LastTestSuite extends SparkFunSuite {
+class FirstLastTestSuite extends SparkFunSuite {
   val input = AttributeReference("input", IntegerType, nullable = true)()
-  val evaluator = DeclarativeAggregateEvaluator(Last(input, Literal(false)), Seq(input))
-  val evaluatorIgnoreNulls = DeclarativeAggregateEvaluator(Last(input, Literal(true)), Seq(input))
+  val evaluator = DeclarativeAggregateEvaluator(Last(input, false), Seq(input))
+  val evaluatorIgnoreNulls = DeclarativeAggregateEvaluator(Last(input, true), Seq(input))
 
   test("empty buffer") {
     assert(evaluator.initialize() === InternalRow(null, false))
@@ -106,4 +107,15 @@ class LastTestSuite extends SparkFunSuite {
     val m1 = evaluatorIgnoreNulls.merge(p1, p2)
     assert(evaluatorIgnoreNulls.eval(m1) === InternalRow(1))
   }
+
+  test("SPARK-32344: correct error handling for a type mismatch") {
+    val msg1 = intercept[AnalysisException] {
+      new First(input, Literal(1, IntegerType))
+    }.getMessage
+    assert(msg1.contains("The second argument in first should be a boolean literal"))
+    val msg2 = intercept[AnalysisException] {
+      new Last(input, Literal(1, IntegerType))
+    }.getMessage
+    assert(msg2.contains("The second argument in last should be a boolean literal"))
+  }
 }
@@ -785,10 +785,10 @@ class ExpressionParserSuite extends AnalysisTest {
   }
 
   test("SPARK-19526 Support ignore nulls keywords for first and last") {
-    assertEqual("first(a ignore nulls)", First('a, Literal(true)).toAggregateExpression())
-    assertEqual("first(a)", First('a, Literal(false)).toAggregateExpression())
-    assertEqual("last(a ignore nulls)", Last('a, Literal(true)).toAggregateExpression())
-    assertEqual("last(a)", Last('a, Literal(false)).toAggregateExpression())
+    assertEqual("first(a ignore nulls)", First('a, true).toAggregateExpression())
+    assertEqual("first(a)", First('a, false).toAggregateExpression())
+    assertEqual("last(a ignore nulls)", Last('a, true).toAggregateExpression())
+    assertEqual("last(a)", Last('a, false).toAggregateExpression())
   }
 
   test("timestamp literals") {
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -461,7 +461,7 @@ object functions {
    * @since 2.0.0
    */
   def first(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction {
-    new First(e.expr, Literal(ignoreNulls))
+    First(e.expr, ignoreNulls)
   }
 
   /**
@@ -586,7 +586,7 @@ object functions {
    * @since 2.0.0
    */
   def last(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction {
-    new Last(e.expr, Literal(ignoreNulls))
+    new Last(e.expr, ignoreNulls)
   }
 
   /**
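The public `functions.first`/`functions.last` signatures are untouched; only the internal construction changes. A small end-to-end sketch of the ignore-nulls behavior (local SparkSession and data are illustrative):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{first, last}

object IgnoreNullsDemo extends App {
  val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
  import spark.implicits._

  val df = Seq(Option.empty[Int], Some(5), Some(20)).toDF("col")

  // With ignoreNulls = true, first skips the leading null; on this single
  // small partition it returns 5 and last returns 20 (first/last are only
  // deterministic when the input order is fixed, as it is here).
  df.agg(
    first($"col", ignoreNulls = true),
    last($"col", ignoreNulls = true)
  ).show()

  spark.stop()
}
```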
@@ -315,12 +315,12 @@
 | org.apache.spark.sql.catalyst.expressions.aggregate.CountMinSketchAgg | count_min_sketch | N/A | N/A |
 | org.apache.spark.sql.catalyst.expressions.aggregate.CovPopulation | covar_pop | SELECT covar_pop(c1, c2) FROM VALUES (1,1), (2,2), (3,3) AS tab(c1, c2) | struct<covar_pop(CAST(c1 AS DOUBLE), CAST(c2 AS DOUBLE)):double> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.CovSample | covar_samp | SELECT covar_samp(c1, c2) FROM VALUES (1,1), (2,2), (3,3) AS tab(c1, c2) | struct<covar_samp(CAST(c1 AS DOUBLE), CAST(c2 AS DOUBLE)):double> |
-| org.apache.spark.sql.catalyst.expressions.aggregate.First | first_value | SELECT first_value(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<first_value(col, false):int> |
-| org.apache.spark.sql.catalyst.expressions.aggregate.First | first | SELECT first(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<first(col, false):int> |
+| org.apache.spark.sql.catalyst.expressions.aggregate.First | first_value | SELECT first_value(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<first_value(col):int> |
+| org.apache.spark.sql.catalyst.expressions.aggregate.First | first | SELECT first(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<first(col):int> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus | approx_count_distinct | SELECT approx_count_distinct(col1) FROM VALUES (1), (1), (2), (2), (3) tab(col1) | struct<approx_count_distinct(col1):bigint> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.Kurtosis | kurtosis | SELECT kurtosis(col) FROM VALUES (-10), (-20), (100), (1000) AS tab(col) | struct<kurtosis(CAST(col AS DOUBLE)):double> |
-| org.apache.spark.sql.catalyst.expressions.aggregate.Last | last_value | SELECT last_value(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<last_value(col, false):int> |
-| org.apache.spark.sql.catalyst.expressions.aggregate.Last | last | SELECT last(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<last(col, false):int> |
+| org.apache.spark.sql.catalyst.expressions.aggregate.Last | last_value | SELECT last_value(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<last_value(col):int> |
+| org.apache.spark.sql.catalyst.expressions.aggregate.Last | last | SELECT last(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<last(col):int> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.Max | max | SELECT max(col) FROM VALUES (10), (50), (20) AS tab(col) | struct<max(col):int> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.MaxBy | max_by | SELECT max_by(x, y) FROM VALUES (('a', 10)), (('b', 50)), (('c', 20)) AS tab(x, y) | struct<max_by(x, y):string> |
 | org.apache.spark.sql.catalyst.expressions.aggregate.Min | min | SELECT min(col) FROM VALUES (10), (-1), (20) AS tab(col) | struct<min(col):int> |
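These golden-file changes follow mechanically from the constructor change: the flag is no longer a child expression, so it no longer leaks into the generated column name. A sketch of the effect (exact quoting of the `.sql` output may vary):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.aggregate.First
import org.apache.spark.sql.types.IntegerType

object SchemaNameDemo extends App {
  val col = AttributeReference("col", IntegerType)()
  // Before this patch the Literal child leaked into the name: first(col, false).
  // Now the only child is `col`, so the derived name is just first(col).
  println(First(col, false).toAggregateExpression().sql)
}
```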
@@ -270,7 +270,7 @@ struct<lead((ten * 2), 1, -1) OVER (PARTITION BY four ORDER BY ten ASC NULLS FIR
 -- !query
 SELECT first(ten) OVER (PARTITION BY four ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10
 -- !query schema
-struct<first(ten, false) OVER (PARTITION BY four ORDER BY ten ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):int,ten:int,four:int>
+struct<first(ten) OVER (PARTITION BY four ORDER BY ten ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):int,ten:int,four:int>
 -- !query output
 0 0 0
 0 0 0
@@ -287,7 +287,7 @@ struct<first(ten, false) OVER (PARTITION BY four ORDER BY ten ASC NULLS FIRST RA
 -- !query
 SELECT last(four) OVER (ORDER BY ten), ten, four FROM tenk1 WHERE unique2 < 10
 -- !query schema
-struct<last(four, false) OVER (ORDER BY ten ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):int,ten:int,four:int>
+struct<last(four) OVER (ORDER BY ten ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):int,ten:int,four:int>
 -- !query output
 0 4 0
 1 1 1
@@ -306,7 +306,7 @@ SELECT last(ten) OVER (PARTITION BY four), ten, four FROM
 (SELECT * FROM tenk1 WHERE unique2 < 10 ORDER BY four, ten)s
 ORDER BY four, ten
 -- !query schema
-struct<last(ten, false) OVER (PARTITION BY four ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING):int,ten:int,four:int>
+struct<last(ten) OVER (PARTITION BY four ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING):int,ten:int,four:int>
 -- !query output
 4 0 0
 4 0 0
@@ -476,7 +476,7 @@ sum(ten) over (partition by four order by ten),
 last(ten) over (partition by four order by ten)
 FROM (select distinct ten, four from tenk1) ss
 -- !query schema
-struct<four:int,ten:int,sum(CAST(ten AS BIGINT)) OVER (PARTITION BY four ORDER BY ten ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):bigint,last(ten, false) OVER (PARTITION BY four ORDER BY ten ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):int>
+struct<four:int,ten:int,sum(CAST(ten AS BIGINT)) OVER (PARTITION BY four ORDER BY ten ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):bigint,last(ten) OVER (PARTITION BY four ORDER BY ten ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):int>
 -- !query output
 0 0 0 0
 0 2 2 2
@@ -506,7 +506,7 @@ sum(ten) over (partition by four order by ten range between unbounded preceding
 last(ten) over (partition by four order by ten range between unbounded preceding and current row)
 FROM (select distinct ten, four from tenk1) ss
 -- !query schema
-struct<four:int,ten:int,sum(ten) OVER (PARTITION BY four ORDER BY ten ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):bigint,last(ten, false) OVER (PARTITION BY four ORDER BY ten ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):int>
+struct<four:int,ten:int,sum(ten) OVER (PARTITION BY four ORDER BY ten ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):bigint,last(ten) OVER (PARTITION BY four ORDER BY ten ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW):int>
 -- !query output
 0 0 0 0
 0 2 2 2
@@ -536,7 +536,7 @@ sum(ten) over (partition by four order by ten range between unbounded preceding
 last(ten) over (partition by four order by ten range between unbounded preceding and unbounded following)
 FROM (select distinct ten, four from tenk1) ss
 -- !query schema
-struct<four:int,ten:int,sum(ten) OVER (PARTITION BY four ORDER BY ten ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING):bigint,last(ten, false) OVER (PARTITION BY four ORDER BY ten ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING):int>
+struct<four:int,ten:int,sum(ten) OVER (PARTITION BY four ORDER BY ten ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING):bigint,last(ten) OVER (PARTITION BY four ORDER BY ten ASC NULLS FIRST RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING):int>
 -- !query output
 0 0 20 8
 0 2 20 8