diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index aad9f20e022b..9a6e6c73001a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -541,7 +541,7 @@ case class DateFormatClass(left: Expression, right: Expression, timeZoneId: Opti override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val tf = TimestampFormatter.getClass.getName.stripSuffix("$") - val zid = ctx.addReferenceObj("zoneId", zoneId, "java.time.ZoneId") + val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) val locale = ctx.addReferenceObj("locale", Locale.US) defineCodeGen(ctx, ev, (timestamp, format) => { s"""UTF8String.fromString($tf$$.MODULE$$.apply($format.toString(), $zid, $locale) @@ -710,13 +710,13 @@ abstract class UnixTime }""") } case StringType => - val tz = ctx.addReferenceObj("zoneId", zoneId) + val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) val locale = ctx.addReferenceObj("locale", Locale.US) val tf = TimestampFormatter.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (string, format) => { s""" try { - ${ev.value} = $tf$$.MODULE$$.apply($format.toString(), $tz, $locale) + ${ev.value} = $tf$$.MODULE$$.apply($format.toString(), $zid, $locale) .parse($string.toString()) / $MICROS_PER_SECOND; } catch (java.lang.IllegalArgumentException e) { ${ev.isNull} = true; @@ -849,13 +849,13 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ }""") } } else { - val tz = ctx.addReferenceObj("zoneId", zoneId) + val zid = ctx.addReferenceObj("zoneId", zoneId, classOf[ZoneId].getName) val locale = ctx.addReferenceObj("locale", Locale.US) val tf = TimestampFormatter.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (seconds, f) => { s""" try { - ${ev.value} = UTF8String.fromString($tf.apply($f.toString(), $tz, $locale). + ${ev.value} = UTF8String.fromString($tf.apply($f.toString(), $zid, $locale). format($seconds * 1000000L)); } catch (java.lang.IllegalArgumentException e) { ${ev.isNull} = true; diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 64bf89926b47..88607d1740b1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -26,13 +26,14 @@ import java.util.concurrent.TimeUnit._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.util.{DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -652,7 +653,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("from_unixtime") { - val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) + val fmt1 = "yyyy-MM-dd HH:mm:ss" + val sdf1 = new SimpleDateFormat(fmt1, Locale.US) val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" val sdf2 = new SimpleDateFormat(fmt2, Locale.US) for (tz <- Seq(TimeZoneGMT, TimeZonePST, TimeZoneJST)) { @@ -661,10 +663,10 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { sdf2.setTimeZone(tz) checkEvaluation( - FromUnixTime(Literal(0L), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + FromUnixTime(Literal(0L), Literal(fmt1), timeZoneId), sdf1.format(new Timestamp(0))) checkEvaluation(FromUnixTime( - Literal(1000L), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + Literal(1000L), Literal(fmt1), timeZoneId), sdf1.format(new Timestamp(1000000))) checkEvaluation( FromUnixTime(Literal(-1000L), Literal(fmt2), timeZoneId), @@ -673,13 +675,22 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { FromUnixTime(Literal.create(null, LongType), Literal.create(null, StringType), timeZoneId), null) checkEvaluation( - FromUnixTime(Literal.create(null, LongType), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + FromUnixTime(Literal.create(null, LongType), Literal(fmt1), timeZoneId), null) checkEvaluation( FromUnixTime(Literal(1000L), Literal.create(null, StringType), timeZoneId), null) checkEvaluation( FromUnixTime(Literal(0L), Literal("not a valid format"), timeZoneId), null) + + // The codegen path for non-literal input should also work + checkEvaluation( + expression = FromUnixTime( + BoundReference(ordinal = 0, dataType = LongType, nullable = true), + BoundReference(ordinal = 1, dataType = StringType, nullable = true), + timeZoneId), + expected = UTF8String.fromString(sdf1.format(new Timestamp(0))), + inputRow = InternalRow(0L, UTF8String.fromString(fmt1))) } } @@ -739,7 +750,8 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("to_unix_timestamp") { - val sdf1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) + val fmt1 = "yyyy-MM-dd HH:mm:ss" + val sdf1 = new SimpleDateFormat(fmt1, Locale.US) val fmt2 = "yyyy-MM-dd HH:mm:ss.SSS" val sdf2 = new SimpleDateFormat(fmt2, Locale.US) val fmt3 = "yy-MM-dd" @@ -754,15 +766,15 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val date1 = Date.valueOf("2015-07-24") checkEvaluation(ToUnixTimestamp( - Literal(sdf1.format(new Timestamp(0))), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), 0L) + Literal(sdf1.format(new Timestamp(0))), Literal(fmt1), timeZoneId), 0L) checkEvaluation(ToUnixTimestamp( - Literal(sdf1.format(new Timestamp(1000000))), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + Literal(sdf1.format(new Timestamp(1000000))), Literal(fmt1), timeZoneId), 1000L) checkEvaluation(ToUnixTimestamp( - Literal(new Timestamp(1000000)), Literal("yyyy-MM-dd HH:mm:ss")), + Literal(new Timestamp(1000000)), Literal(fmt1)), 1000L) checkEvaluation( - ToUnixTimestamp(Literal(date1), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + ToUnixTimestamp(Literal(date1), Literal(fmt1), timeZoneId), MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) checkEvaluation( ToUnixTimestamp(Literal(sdf2.format(new Timestamp(-1000000))), Literal(fmt2), timeZoneId), @@ -772,21 +784,31 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis( DateTimeUtils.fromJavaDate(Date.valueOf("2015-07-24")), tz))) val t1 = ToUnixTimestamp( - CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + CurrentTimestamp(), Literal(fmt1)).eval().asInstanceOf[Long] val t2 = ToUnixTimestamp( - CurrentTimestamp(), Literal("yyyy-MM-dd HH:mm:ss")).eval().asInstanceOf[Long] + CurrentTimestamp(), Literal(fmt1)).eval().asInstanceOf[Long] assert(t2 - t1 <= 1) checkEvaluation(ToUnixTimestamp( Literal.create(null, DateType), Literal.create(null, StringType), timeZoneId), null) checkEvaluation( ToUnixTimestamp( - Literal.create(null, DateType), Literal("yyyy-MM-dd HH:mm:ss"), timeZoneId), + Literal.create(null, DateType), Literal(fmt1), timeZoneId), null) checkEvaluation(ToUnixTimestamp( Literal(date1), Literal.create(null, StringType), timeZoneId), MILLISECONDS.toSeconds(DateTimeUtils.daysToMillis(DateTimeUtils.fromJavaDate(date1), tz))) checkEvaluation( ToUnixTimestamp(Literal("2015-07-24"), Literal("not a valid format"), timeZoneId), null) + + // The codegen path for non-literal input should also work + checkEvaluation( + expression = ToUnixTimestamp( + BoundReference(ordinal = 0, dataType = StringType, nullable = true), + BoundReference(ordinal = 1, dataType = StringType, nullable = true), + timeZoneId), + expected = 0L, + inputRow = InternalRow( + UTF8String.fromString(sdf1.format(new Timestamp(0))), UTF8String.fromString(fmt1))) } } }