diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index e50abebe57987..3ee177f90be4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -1546,15 +1546,15 @@ case class TruncTimestamp( override def eval(input: InternalRow): Any = { evalHelper(input, maxLevel = DateTimeUtils.TRUNC_TO_SECOND) { (t: Any, level: Int) => - DateTimeUtils.truncTimestamp(t.asInstanceOf[Long], level, timeZone) + DateTimeUtils.truncTimestamp(t.asInstanceOf[Long], level, zoneId) } } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val tz = ctx.addReferenceObj("timeZone", timeZone) + val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) codeGenHelper(ctx, ev, maxLevel = DateTimeUtils.TRUNC_TO_SECOND, true) { (date: String, fmt: String) => - s"truncTimestamp($date, $fmt, $tz);" + s"truncTimestamp($date, $fmt, $zid);" } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 10a7f9bd550e2..56ebc5a1963c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.time._ import java.time.temporal.{ChronoUnit, IsoFields} +import java.time.temporal.TemporalAdjusters._ import java.util.{Locale, TimeZone} import java.util.concurrent.TimeUnit._ @@ -647,41 +648,30 @@ object DateTimeUtils { * Returns the trunc date time from original date time and trunc level. * Trunc level should be generated using `parseTruncLevel()`, should be between 1 and 8 */ - def truncTimestamp(t: SQLTimestamp, level: Int, timeZone: TimeZone): SQLTimestamp = { - var millis = MICROSECONDS.toMillis(t) + def truncTimestamp(t: SQLTimestamp, level: Int, zoneId: ZoneId): SQLTimestamp = { + val zonedDateTime = microsToInstant(t).atZone(zoneId) val truncated = level match { case TRUNC_TO_YEAR => - val dDays = millisToDays(millis, timeZone) - daysToMillis(truncDate(dDays, level), timeZone) + zonedDateTime.`with`(firstDayOfYear()).truncatedTo(ChronoUnit.DAYS) case TRUNC_TO_MONTH => - val dDays = millisToDays(millis, timeZone) - daysToMillis(truncDate(dDays, level), timeZone) + zonedDateTime.`with`(firstDayOfMonth()).truncatedTo(ChronoUnit.DAYS) + case TRUNC_TO_QUARTER => + zonedDateTime.`with`(IsoFields.DAY_OF_QUARTER, 1L).truncatedTo(ChronoUnit.DAYS) + case TRUNC_TO_WEEK => + zonedDateTime.`with`(DayOfWeek.MONDAY).truncatedTo(ChronoUnit.DAYS) case TRUNC_TO_DAY => - val offset = timeZone.getOffset(millis) - millis += offset - millis - millis % MILLIS_PER_DAY - offset + zonedDateTime.truncatedTo(ChronoUnit.DAYS) case TRUNC_TO_HOUR => - val offset = timeZone.getOffset(millis) - millis += offset - millis - millis % MILLIS_PER_HOUR - offset + zonedDateTime.truncatedTo(ChronoUnit.HOURS) case TRUNC_TO_MINUTE => - millis - millis % MILLIS_PER_MINUTE + zonedDateTime.truncatedTo(ChronoUnit.MINUTES) case TRUNC_TO_SECOND => - millis - millis % MILLIS_PER_SECOND - case TRUNC_TO_WEEK => - val dDays = millisToDays(millis, timeZone) - val prevMonday = getNextDateForDayOfWeek(dDays - 7, MONDAY) - daysToMillis(prevMonday, timeZone) - case TRUNC_TO_QUARTER => - val dDays = millisToDays(millis, timeZone) - val daysOfQuarter = LocalDate.ofEpochDay(dDays) - .`with`(IsoFields.DAY_OF_QUARTER, 1L).toEpochDay.toInt - daysToMillis(daysOfQuarter, timeZone) + zonedDateTime.truncatedTo(ChronoUnit.SECONDS) case _ => // caller make sure that this should never be reached sys.error(s"Invalid trunc level: $level") } - truncated * MICROS_PER_MILLIS + instantToMicros(truncated.toInstant) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 8ff691fb17f27..ea8346f64f056 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -453,9 +453,9 @@ class DateTimeUtilsSuite extends SparkFunSuite { level: Int, expected: String, inputTS: SQLTimestamp, - timezone: TimeZone = DateTimeUtils.defaultTimeZone()): Unit = { + zoneId: ZoneId = defaultZoneId): Unit = { val truncated = - DateTimeUtils.truncTimestamp(inputTS, level, timezone) + DateTimeUtils.truncTimestamp(inputTS, level, zoneId) val expectedTS = DateTimeUtils.stringToTimestamp(UTF8String.fromString(expected), defaultZoneId) assert(truncated === expectedTS.get) @@ -499,21 +499,21 @@ class DateTimeUtilsSuite extends SparkFunSuite { UTF8String.fromString("2015-03-30T02:32:05.359"), defaultZoneId) val inputTS4 = DateTimeUtils.stringToTimestamp( UTF8String.fromString("2015-03-29T02:32:05.359"), defaultZoneId) - - testTrunc(DateTimeUtils.TRUNC_TO_YEAR, "2015-01-01T00:00:00", inputTS.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_MONTH, "2015-03-01T00:00:00", inputTS.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_DAY, "2015-03-05T00:00:00", inputTS.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_HOUR, "2015-03-05T09:00:00", inputTS.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_MINUTE, "2015-03-05T09:32:00", inputTS.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_SECOND, "2015-03-05T09:32:05", inputTS.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-02T00:00:00", inputTS.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS1.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS2.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS3.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-23T00:00:00", inputTS4.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", inputTS.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", inputTS1.get, tz) - testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-04-01T00:00:00", inputTS2.get, tz) + val zid = tz.toZoneId + testTrunc(DateTimeUtils.TRUNC_TO_YEAR, "2015-01-01T00:00:00", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_MONTH, "2015-03-01T00:00:00", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_DAY, "2015-03-05T00:00:00", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_HOUR, "2015-03-05T09:00:00", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_MINUTE, "2015-03-05T09:32:00", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_SECOND, "2015-03-05T09:32:05", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-02T00:00:00", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS1.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS2.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-30T00:00:00", inputTS3.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_WEEK, "2015-03-23T00:00:00", inputTS4.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", inputTS.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-01-01T00:00:00", inputTS1.get, zid) + testTrunc(DateTimeUtils.TRUNC_TO_QUARTER, "2015-04-01T00:00:00", inputTS2.get, zid) } } }