diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 5aea884ad500..051197d4544f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.{DateTimeUtils, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils._
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.{DIALECT, Dialect}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
@@ -1672,33 +1673,59 @@ case class TruncTimestamp(
 }
 
 /**
- * Returns the number of days from startDate to endDate.
+ * Returns the number of days from startDate to endDate or an interval between the dates.
  */
+// scalastyle:off line.size.limit line.contains.tab
 @ExpressionDescription(
-  usage = "_FUNC_(endDate, startDate) - Returns the number of days from `startDate` to `endDate`.",
+  usage = "_FUNC_(endDate, startDate) - Returns the number of days from `startDate` to `endDate`. " +
+    "When `spark.sql.ansi.enabled` is set to `true` and `spark.sql.dialect` is `Spark`, it returns " +
+    "an interval between `startDate` (inclusive) and `endDate` (exclusive).",
   examples = """
     Examples:
       > SELECT _FUNC_('2009-07-31', '2009-07-30');
       1
-
      > SELECT _FUNC_('2009-07-30', '2009-07-31');
       -1
+      > SET spark.sql.ansi.enabled=true;
+      spark.sql.ansi.enabled	true
+      > SET spark.sql.dialect=Spark;
+      spark.sql.dialect	Spark
+      > SELECT _FUNC_(date'tomorrow', date'yesterday');
+      interval 2 days
   """,
   since = "1.5.0")
+// scalastyle:on line.size.limit line.contains.tab
 case class DateDiff(endDate: Expression, startDate: Expression)
   extends BinaryExpression with ImplicitCastInputTypes {
 
   override def left: Expression = endDate
   override def right: Expression = startDate
   override def inputTypes: Seq[AbstractDataType] = Seq(DateType, DateType)
-  override def dataType: DataType = IntegerType
+  private val returnInterval: Boolean = {
+    val isSparkDialect = SQLConf.get.getConf(DIALECT) == Dialect.SPARK.toString()
+    SQLConf.get.ansiEnabled && isSparkDialect
+  }
+  override def dataType: DataType = if (returnInterval) CalendarIntervalType else IntegerType
 
   override def nullSafeEval(end: Any, start: Any): Any = {
-    end.asInstanceOf[Int] - start.asInstanceOf[Int]
+    val startDate = start.asInstanceOf[Int]
+    val endDate = end.asInstanceOf[Int]
+    if (returnInterval) {
+      dateDiff(endDate, startDate)
+    } else {
+      endDate - startDate
+    }
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    defineCodeGen(ctx, ev, (end, start) => s"$end - $start")
+    defineCodeGen(ctx, ev, (end, start) => {
+      if (returnInterval) {
+        val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+        s"$dtu.dateDiff($end, $start)"
+      } else {
+        s"$end - $start"
+      }
+    })
   }
 }
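The semantic difference between the two modes can be reproduced with plain java.time, independent of Spark. A minimal standalone sketch (the object name is illustrative, not part of the patch):

import java.time.{LocalDate, Period}
import java.time.temporal.ChronoUnit

object DateDiffBehaviors extends App {
  val start = LocalDate.parse("2009-07-30")
  val end = LocalDate.parse("2009-07-31")

  // Default mode: datediff is a plain day count, i.e. a subtraction of epoch days.
  val days = ChronoUnit.DAYS.between(start, end)
  println(s"day count: $days") // 1

  // ANSI mode + Spark dialect: a year-month-day Period, which the patched
  // DateDiff surfaces as a CalendarInterval.
  val period = Period.between(start, end)
  println(s"period: $period") // P1D, rendered by Spark as "interval 1 days"
}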
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
index 34e8012106bb..00066a571734 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala
@@ -27,7 +27,7 @@ import java.util.concurrent.TimeUnit._
 import scala.util.control.NonFatal
 
 import org.apache.spark.sql.types.Decimal
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
 /**
  * Helper functions for converting between internal and external date and time representations.
@@ -950,4 +950,20 @@ object DateTimeUtils {
       None
     }
   }
+
+  /**
+   * Gets the difference between two dates.
+   * @param endDate - the end date, exclusive
+   * @param startDate - the start date, inclusive
+   * @return an interval between two dates. The interval can be negative
+   *         if the end date is before the start date.
+   */
+  def dateDiff(endDate: SQLDate, startDate: SQLDate): CalendarInterval = {
+    val period = Period.between(
+      LocalDate.ofEpochDay(startDate),
+      LocalDate.ofEpochDay(endDate))
+    val months = period.getMonths + 12 * period.getYears
+    val microseconds = period.getDays * MICROS_PER_DAY
+    new CalendarInterval(months, microseconds)
+  }
 }
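The month/microsecond decomposition in the new dateDiff helper can be checked against the epoch-to-2019-10-05 case used in the tests below. A standalone sketch that restates the MICROS_PER_DAY constant locally for self-containment:

import java.time.{LocalDate, Period}

object DateDiffArithmetic extends App {
  // Period.between(1970-01-01, 2019-10-05) = 49 years, 9 months, 4 days.
  val period = Period.between(LocalDate.ofEpochDay(0), LocalDate.of(2019, 10, 5))

  // Years and months collapse into one month count: 9 + 12 * 49 = 597.
  val months = period.getMonths + 12 * period.getYears

  // Leftover days become microseconds: 4 * 86,400,000,000L per day.
  val microsPerDay = 24L * 60 * 60 * 1000 * 1000
  val microseconds = period.getDays * microsPerDay

  println(s"months = $months, microseconds = $microseconds")
  // months = 597, microseconds = 345600000000,
  // i.e. the "interval 49 years 9 months 4 days" expected by the tests.
}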
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
index 8680a15ee1cd..dd53ef0a25d5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.sql.{Date, Timestamp}
 import java.text.SimpleDateFormat
-import java.time.{Instant, LocalDateTime, ZoneId, ZoneOffset}
+import java.time.{Instant, LocalDate, LocalDateTime, ZoneId, ZoneOffset}
 import java.util.{Calendar, Locale, TimeZone}
 import java.util.concurrent.TimeUnit
 import java.util.concurrent.TimeUnit._
@@ -836,18 +836,51 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
-  test("datediff") {
-    checkEvaluation(
-      DateDiff(Literal(Date.valueOf("2015-07-24")), Literal(Date.valueOf("2015-07-21"))), 3)
-    checkEvaluation(
-      DateDiff(Literal(Date.valueOf("2015-07-21")), Literal(Date.valueOf("2015-07-24"))), -3)
-    checkEvaluation(DateDiff(Literal.create(null, DateType), Literal(Date.valueOf("2015-07-24"))),
-      null)
-    checkEvaluation(DateDiff(Literal(Date.valueOf("2015-07-24")), Literal.create(null, DateType)),
-      null)
-    checkEvaluation(
-      DateDiff(Literal.create(null, DateType), Literal.create(null, DateType)),
-      null)
+  test("datediff returns an integer") {
+    Seq(
+      (true, "PostgreSQL"),
+      (false, "PostgreSQL"),
+      (false, "Spark")).foreach { case (ansi, dialect) =>
+      withSQLConf(
+        SQLConf.ANSI_ENABLED.key -> ansi.toString,
+        SQLConf.DIALECT.key -> dialect) {
+        checkEvaluation(
+          DateDiff(Literal(Date.valueOf("2015-07-24")), Literal(Date.valueOf("2015-07-21"))), 3)
+        checkEvaluation(
+          DateDiff(Literal(Date.valueOf("2015-07-21")), Literal(Date.valueOf("2015-07-24"))), -3)
+        checkEvaluation(
+          DateDiff(Literal.create(null, DateType), Literal(Date.valueOf("2015-07-24"))),
+          null)
+        checkEvaluation(
+          DateDiff(Literal(Date.valueOf("2015-07-24")), Literal.create(null, DateType)),
+          null)
+        checkEvaluation(
+          DateDiff(Literal.create(null, DateType), Literal.create(null, DateType)),
+          null)
+      }
+    }
+  }
+
+  test("datediff returns an interval") {
+    withSQLConf(SQLConf.ANSI_ENABLED.key -> "true", SQLConf.DIALECT.key -> "Spark") {
+      val end = LocalDate.of(2019, 10, 5)
+      checkEvaluation(DateDiff(Literal(end), Literal(end)),
+        new CalendarInterval(0, 0))
+      checkEvaluation(DateDiff(Literal(end.plusDays(1)), Literal(end)),
+        CalendarInterval.fromString("interval 1 days"))
+      checkEvaluation(DateDiff(Literal(end.minusDays(1)), Literal(end)),
+        CalendarInterval.fromString("interval -1 days"))
+      val epochDate = Literal(LocalDate.ofEpochDay(0))
+      checkEvaluation(DateDiff(Literal(end), epochDate),
+        CalendarInterval.fromString("interval 49 years 9 months 4 days"))
+      checkEvaluation(DateDiff(epochDate, Literal(end)),
+        CalendarInterval.fromString("interval -49 years -9 months -4 days"))
+      checkEvaluation(
+        DateDiff(
+          Literal(LocalDate.of(10000, 1, 1)),
+          Literal(LocalDate.of(1, 1, 1))),
+        CalendarInterval.fromString("interval 9999 years"))
+    }
   }
 
   test("to_utc_timestamp") {
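Assuming both configs are settable at session runtime, as the SET examples in the usage string suggest, the end-to-end behavior of a build with this patch could be spot-checked from spark-shell. A sketch, not a verified session:

// spark-shell, with `spark` as the ambient SparkSession
spark.conf.set("spark.sql.ansi.enabled", "false")
spark.sql("SELECT datediff(date'2009-07-31', date'2009-07-30')").show()
// expected: 1 (IntegerType)

spark.conf.set("spark.sql.ansi.enabled", "true")
spark.conf.set("spark.sql.dialect", "Spark")
spark.sql("SELECT datediff(date'2009-07-31', date'2009-07-30')").show()
// expected: interval 1 days (CalendarIntervalType)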