diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 70354200c82df..4a993846d6661 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -1212,6 +1212,77 @@ case class ToUTCTimestamp(left: Expression, right: Expression)
   }
 }
 
+/**
+ * This modifies a timestamp to show how the display time changes going from one timezone to
+ * another, for the same instant in time.
+ *
+ * We intentionally do not provide an ExpressionDescription as this is not meant to be exposed to
+ * users, it's only used for internal conversions.
+ */
+private[spark] case class TimestampTimezoneCorrection(
+    time: Expression,
+    from: Expression,
+    to: Expression)
+  extends TernaryExpression with ImplicitCastInputTypes {
+
+  // convertTz() does the *opposite* conversion we want, which is why from & to appear reversed
+  // in all the calls to convertTz. It's used for showing how the display time changes when we go
+  // from one timezone to another. We want to see how the SQLTimestamp value should change to
+  // ensure the display does *not* change, despite going from one TZ to another.
+
+  override def children: Seq[Expression] = Seq(time, from, to)
+  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, StringType, StringType)
+  override def dataType: DataType = TimestampType
+  override def prettyName: String = "timestamp_timezone_correction"
+
+  override def nullSafeEval(time: Any, from: Any, to: Any): Any = {
+    DateTimeUtils.convertTz(
+      time.asInstanceOf[Long],
+      DateTimeUtils.getTimeZone(to.asInstanceOf[UTF8String].toString()),
+      DateTimeUtils.getTimeZone(from.asInstanceOf[UTF8String].toString()))
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+    if (from.foldable && to.foldable) {
+      val fromTz = from.eval()
+      val toTz = to.eval()
+      if (fromTz == null || toTz == null) {
+        ev.copy(code = s"""
+          |boolean ${ev.isNull} = true;
+          |long ${ev.value} = 0;
+        """.stripMargin)
+      } else {
+        val fromTerm = ctx.freshName("from")
+        val toTerm = ctx.freshName("to")
+        val tzClass = classOf[TimeZone].getName
+        val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
+        ctx.addMutableState(tzClass, fromTerm, s"""$fromTerm = $dtu.getTimeZone("$fromTz");""")
+        ctx.addMutableState(tzClass, toTerm, s"""$toTerm = $dtu.getTimeZone("$toTz");""")
+
+        val eval = time.genCode(ctx)
+        ev.copy(code = s"""
+          |${eval.code}
+          |boolean ${ev.isNull} = ${eval.isNull};
+          |long ${ev.value} = 0;
+          |if (!${ev.isNull}) {
+          |  ${ev.value} = $dtu.convertTz(${eval.value}, $toTerm, $fromTerm);
+          |}
+        """.stripMargin)
+      }
+    } else {
+      nullSafeCodeGen(ctx, ev, (time, from, to) =>
+        s"""
+          |${ev.value} = $dtu.convertTz(
+          |  $time,
+          |  $dtu.getTimeZone($to.toString()),
+          |  $dtu.getTimeZone($from.toString()));
+        """.stripMargin
+      )
+    }
+  }
+}
+
 /**
  * Parses a column to a date based on the given format.
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 746c3e8950f7b..173a66c9bbfff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -26,6 +26,8 @@ import javax.xml.bind.DatatypeConverter import scala.annotation.tailrec +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.unsafe.types.UTF8String /** @@ -65,6 +67,13 @@ object DateTimeUtils { val TIMEZONE_OPTION = "timeZone" + /** + * Property that holds the time zone used for adjusting "timestamp without time zone" + * columns to the session's time zone. See SPARK-12297 for more details (including the + * specified name of this property). + */ + val TIMEZONE_PROPERTY = "table.timezone-adjustment" + def defaultTimeZone(): TimeZone = TimeZone.getDefault() // Reuse the Calendar object in each thread as it is expensive to create in each method call. @@ -109,6 +118,12 @@ object DateTimeUtils { computedTimeZones.computeIfAbsent(timeZoneId, computeTimeZone) } + private lazy val validTimezones = TimeZone.getAvailableIDs().toSet + + def isValidTimezone(timezoneId: String): Boolean = { + validTimezones.contains(timezoneId) + } + def newDateFormat(formatString: String, timeZone: TimeZone): DateFormat = { val sdf = new SimpleDateFormat(formatString, Locale.US) sdf.setTimeZone(timeZone) @@ -1065,4 +1080,24 @@ object DateTimeUtils { threadLocalTimestampFormat.remove() threadLocalDateFormat.remove() } + + /** + * Throw an AnalysisException if we're trying to set an invalid timezone for this table. + */ + def checkTableTz(table: TableIdentifier, properties: Map[String, String]): Unit = { + checkTableTz(s"in table ${table.toString}", properties) + } + + /** + * Throw an AnalysisException if we're trying to set an invalid timezone for this table. 
+ */ + def checkTableTz(dest: String, properties: Map[String, String]): Unit = { + properties.get(TIMEZONE_PROPERTY).foreach { tz => + if (!DateTimeUtils.isValidTimezone(tz)) { + throw new AnalysisException(s"Cannot set $TIMEZONE_PROPERTY to invalid " + + s"timezone $tz $dest") + } + } + } + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 89d99f9678cda..0cf448f718ff2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -741,4 +741,32 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("2015-07-24 00:00:00", null, null) test(null, null, null) } + + test("timestamp_timezone_correction") { + def test(t: String, fromTz: String, toTz: String, expected: String): Unit = { + checkEvaluation( + TimestampTimezoneCorrection( + Literal.create(if (t != null) Timestamp.valueOf(t) else null, TimestampType), + Literal.create(fromTz, StringType), + Literal.create(toTz, StringType)), + if (expected != null) Timestamp.valueOf(expected) else null) + checkEvaluation( + TimestampTimezoneCorrection( + Literal.create(if (t != null) Timestamp.valueOf(t) else null, TimestampType), + NonFoldableLiteral.create(fromTz, StringType), + NonFoldableLiteral.create(toTz, StringType)), + if (expected != null) Timestamp.valueOf(expected) else null) + } + // These conversions may look backwards -- but this is *NOT* saying: + // when the clock says 2015-07-24 00:00:00 in PST, what would it say to somebody in UTC? + // Instead, its saying -- suppose somebody stored "2015-07-24 00:00:00" while in PST, but + // as millis-since-epoch. What millis-since-epoch would I need to also see + // "2015-07-24 00:00:00" if my clock were in UTC? Just for testing convenience, we input + // that last value as "what would my clock in PST say for that final millis-since-epoch?" 
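+    // Worked example for the first case below: in July "PST" actually observes PDT (UTC-7), so
+    // the instant that a UTC clock displays as "2015-07-24 00:00:00" is 7 hours earlier than the
+    // instant a PST clock displays that way; rendered back on the test's PST clock it reads
+    // "2015-07-23 17:00:00". In January PST is UTC-8, hence the 8-hour shift in the second case.
+    // A quick offset sanity check (illustrative only, not asserted by the test):
+    //   TimeZone.getTimeZone("PST").getOffset(Timestamp.valueOf("2015-07-24 00:00:00").getTime)
+    //     == -7 * 60 * 60 * 1000
+    //   TimeZone.getTimeZone("PST").getOffset(Timestamp.valueOf("2015-01-24 00:00:00").getTime)
+    //     == -8 * 60 * 60 * 1000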
+ test("2015-07-24 00:00:00", "PST", "UTC", "2015-07-23 17:00:00") + test("2015-01-24 00:00:00", "PST", "UTC", "2015-01-23 16:00:00") + test(null, "UTC", "UTC", null) + test("2015-07-24 00:00:00", null, null, null) + test(null, null, null, null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index c69acc413e87f..cde431f1b9882 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -27,6 +27,7 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{DataSource, FailureSafeParser} import org.apache.spark.sql.execution.datasources.csv._ @@ -179,6 +180,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { throw new AnalysisException("Hive data source can only be used with tables, you can not " + "read files of Hive data source directly.") } + DateTimeUtils.checkTableTz("", extraOptions.toMap) sparkSession.baseRelationToDataFrame( DataSource.apply( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 07347d2748544..133ede5a58d8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} @@ -230,6 +231,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } assertNotBucketed("save") + val dest = extraOptions.get("path") match { + case Some(path) => s"for path $path" + case _ => s"with format $source" + } + DateTimeUtils.checkTableTz(dest, extraOptions.toMap) runCommand(df.sparkSession, "save") { DataSource( @@ -266,6 +272,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * @since 1.4.0 */ def insertInto(tableName: String): Unit = { + extraOptions.get(DateTimeUtils.TIMEZONE_PROPERTY).foreach { tz => + throw new AnalysisException("Cannot provide a table timezone on insert; tried to insert " + + s"$tableName with ${DateTimeUtils.TIMEZONE_PROPERTY}=$tz") + } insertInto(df.sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName)) } @@ -406,6 +416,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { } else { CatalogTableType.MANAGED } + val props = extraOptions.filterKeys(_ == DateTimeUtils.TIMEZONE_PROPERTY).toMap + DateTimeUtils.checkTableTz(tableIdent, props) val tableDesc = CatalogTable( identifier = tableIdent, @@ -414,7 +426,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { schema = new StructType, provider = Some(source), partitionColumnNames = 
partitioningColumns.getOrElse(Nil), - bucketSpec = getBucketSpec) + bucketSpec = getBucketSpec, + properties = props) runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, Some(df.logicalPlan))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 162e1d5be2938..ead034f554846 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, PartitioningUtils} import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaConverter @@ -230,6 +231,13 @@ case class AlterTableSetPropertiesCommand( isView: Boolean) extends RunnableCommand { + if (isView) { + properties.get(DateTimeUtils.TIMEZONE_PROPERTY).foreach { _ => + throw new AnalysisException("Timezone cannot be set for view") + } + } + DateTimeUtils.checkTableTz(tableName, properties) + override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 8d95ca6921cf8..c3e6ba281bdf6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTableType._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.util.quoteIdentifier +import org.apache.spark.sql.catalyst.util.{quoteIdentifier, DateTimeUtils} import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils} import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.json.JsonFileFormat @@ -76,6 +76,8 @@ case class CreateTableLikeCommand( // If the location is specified, we create an external table internally. // Otherwise create a managed table. 
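+    // The timezone-adjustment property, if present on the source table, is copied into the new
+    // table's properties below, so a table created with CREATE TABLE LIKE keeps interpreting its
+    // existing timestamp data the same way as the source table does.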
val tblType = if (location.isEmpty) CatalogTableType.MANAGED else CatalogTableType.EXTERNAL + val properties = + sourceTableDesc.properties.filterKeys(_ == DateTimeUtils.TIMEZONE_PROPERTY) val newTableDesc = CatalogTable( @@ -86,7 +88,8 @@ case class CreateTableLikeCommand( schema = sourceTableDesc.schema, provider = newProvider, partitionColumnNames = sourceTableDesc.partitionColumnNames, - bucketSpec = sourceTableDesc.bucketSpec) + bucketSpec = sourceTableDesc.bucketSpec, + properties = properties) catalog.createTable(newTableDesc, ifNotExists) Seq.empty[Row] @@ -126,6 +129,8 @@ case class CreateTableCommand( sparkSession.sessionState.catalog.createTable(table, ignoreIfExists) Seq.empty[Row] } + + DateTimeUtils.checkTableTz(table.identifier, table.properties) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index ffdfd527fa701..ab6f9c237ab01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable import org.apache.spark.sql.catalyst.expressions.{Alias, SubqueryExpression} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, View} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types.MetadataBuilder import org.apache.spark.sql.util.SchemaUtils @@ -123,6 +124,10 @@ case class CreateViewCommand( s"It is not allowed to add database prefix `$database` for the TEMPORARY view name.") } + properties.get(DateTimeUtils.TIMEZONE_PROPERTY).foreach { _ => + throw new AnalysisException("Timezone cannot be set for view") + } + override def run(sparkSession: SparkSession): Seq[Row] = { // If the plan cannot be analyzed, throw an exception and don't proceed. val qe = sparkSession.sessionState.executePlan(child) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AdjustTimestamps.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AdjustTimestamps.scala new file mode 100644 index 0000000000000..2e57c5ea7dc21 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AdjustTimestamps.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.{AnalysisException} +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, HiveTableRelation} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{StringType, TimestampType} + +/** + * Apply a correction to data loaded from, or saved to, tables that have a configured time zone, so + * that timestamps can be read like TIMESTAMP WITHOUT TIMEZONE. This gives correct behavior if you + * process data with machines in different timezones, or if you access the data from multiple SQL + * engines. + */ +case class AdjustTimestamps(conf: SQLConf) extends Rule[LogicalPlan] { + + def apply(plan: LogicalPlan): LogicalPlan = plan match { + case insert: InsertIntoHadoopFsRelationCommand => + val adjusted = adjustTimestampsForWrite(insert.query, insert.catalogTable, insert.options) + insert.copy(query = adjusted) + + case insert @ InsertIntoTable(table: HiveTableRelation, _, query, _, _) => + val adjusted = adjustTimestampsForWrite(insert.query, Some(table.tableMeta), Map()) + insert.copy(query = adjusted) + + case other => + convertInputs(plan) + } + + private def convertInputs(plan: LogicalPlan): LogicalPlan = plan match { + case adjusted @ Project(exprs, _) if hasCorrection(exprs) => + adjusted + + case lr @ LogicalRelation(fsRelation: HadoopFsRelation, _, _, _) => + adjustTimestamps(lr, lr.catalogTable, fsRelation.options, true) + + case hr @ HiveTableRelation(table, _, _) => + adjustTimestamps(hr, Some(table), Map(), true) + + case other => + other.mapChildren { originalPlan => + convertInputs(originalPlan) + } + } + + private def adjustTimestamps( + plan: LogicalPlan, + table: Option[CatalogTable], + options: Map[String, String], + reading: Boolean): LogicalPlan = { + val tableTz = table.flatMap(_.properties.get(DateTimeUtils.TIMEZONE_PROPERTY)) + .orElse(options.get(DateTimeUtils.TIMEZONE_PROPERTY)) + + tableTz.map { tz => + val sessionTz = conf.sessionLocalTimeZone + val toTz = if (reading) sessionTz else tz + val fromTz = if (reading) tz else sessionTz + logDebug( + s"table tz = $tz; converting ${if (reading) "to" else "from"} session tz = $sessionTz\n") + + var hasTimestamp = false + val adjusted = plan.expressions.map { + case e: NamedExpression if e.dataType == TimestampType => + val adjustment = TimestampTimezoneCorrection(e.toAttribute, + Literal.create(fromTz, StringType), Literal.create(toTz, StringType)) + hasTimestamp = true + Alias(adjustment, e.name)() + + case other: NamedExpression => + other.toAttribute + + case unnamed => + throw new AnalysisException(s"Unexpected expr: $unnamed") + }.toList + + if (hasTimestamp) Project(adjusted, plan) else plan + }.getOrElse(plan) + } + + private def adjustTimestampsForWrite( + query: LogicalPlan, + table: Option[CatalogTable], + options: Map[String, String]): LogicalPlan = query match { + case unadjusted if !hasOutputCorrection(unadjusted.expressions) => + // The query might be reading from a table with a configured time zone; this makes sure we + // apply the correct conversions for that data. 
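+      // For example (illustrative values): with a session time zone of America/Los_Angeles and a
+      // destination table whose timezone-adjustment property is UTC, convertInputs() below first
+      // corrects anything read from other timezone-adjusted tables, and then
+      // adjustTimestamps(..., reading = false) wraps the plan in a Project applying
+      // TimestampTimezoneCorrection(ts, "America/Los_Angeles", "UTC") to each timestamp column
+      // before the rows are written out.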
+ val fixedInputs = convertInputs(unadjusted) + adjustTimestamps(fixedInputs, table, options, false) + + case _ => + query + } + + private def hasCorrection(exprs: Seq[Expression]): Boolean = { + exprs.exists { expr => + expr.isInstanceOf[TimestampTimezoneCorrection] || hasCorrection(expr.children) + } + } + + private def hasOutputCorrection(exprs: Seq[Expression]): Boolean = { + // Output correction is any TimestampTimezoneCorrection that converts from the current + // session's time zone. + val sessionTz = conf.sessionLocalTimeZone + exprs.exists { + case TimestampTimezoneCorrection(_, from, _) => from.toString() == sessionTz + case other => hasOutputCorrection(other.children) + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 4e756084bbdbb..beb8db497928e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -158,6 +158,7 @@ abstract class BaseSessionStateBuilder( override val extendedResolutionRules: Seq[Rule[LogicalPlan]] = new FindDataSourceTable(session) +: new ResolveSQLOnFile(session) +: + AdjustTimestamps(conf) +: customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/AdjustTimestampsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/AdjustTimestampsSuite.scala new file mode 100644 index 0000000000000..8a5b7452c8a93 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/AdjustTimestampsSuite.scala @@ -0,0 +1,343 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.datasources + +import java.io.File +import java.sql.Timestamp +import java.util.TimeZone + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.SparkPlanTest +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} + +abstract class BaseAdjustTimestampsSuite extends SparkPlanTest with SQLTestUtils + with BeforeAndAfterAll { + + protected val SESSION_TZ = "America/Los_Angeles" + protected val TABLE_TZ = "Europe/Berlin" + protected val UTC = "UTC" + protected val TZ_KEY = DateTimeUtils.TIMEZONE_PROPERTY + + var originalTz: TimeZone = _ + protected override def beforeAll(): Unit = { + super.beforeAll() + originalTz = TimeZone.getDefault() + TimeZone.setDefault(TimeZone.getTimeZone(SESSION_TZ)) + } + + protected override def afterAll(): Unit = { + TimeZone.setDefault(originalTz) + super.afterAll() + } + + val desiredTimestampStrings = Seq( + "2015-12-31 22:49:59.123", + "2015-12-31 23:50:59.123", + "2016-01-01 00:39:59.123", + "2016-01-01 01:29:59.123" + ) + // We don't want to mess with timezones inside the tests themselves, since we use a shared + // spark context in the hive tests, and then we might be prone to issues from lazy vals for + // timezones. Instead, we manually adjust the timezone just to determine what the desired millis + // (since epoch, in utc) is for various "wall-clock" times in different timezones, and then we can + // compare against those in our tests. + val timestampTimezoneToMillis = { + val originalTz = TimeZone.getDefault + try { + desiredTimestampStrings.flatMap { timestampString => + Seq(SESSION_TZ, TABLE_TZ, UTC).map { tzId => + TimeZone.setDefault(TimeZone.getTimeZone(tzId)) + val timestamp = Timestamp.valueOf(timestampString) + (timestampString, tzId) -> timestamp.getTime() + } + }.toMap + } finally { + TimeZone.setDefault(originalTz) + } + } + + protected def createRawData(spark: SparkSession): Dataset[(String, Timestamp)] = { + import spark.implicits._ + val df = desiredTimestampStrings.toDF("display") + // this will get the millis corresponding to the display time given the current session tz + df.withColumn("ts", expr("cast(display as timestamp)")).as[(String, Timestamp)] + } + + protected def checkHasTz(spark: SparkSession, table: String, tz: Option[String]): Unit = { + val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier(table)) + assert(tableMetadata.properties.get(TZ_KEY) === tz, + s"for table $table") + } + + /** + * This checks that the dataframe contains data that matches our original data set, modified for + * the given timezone, *if* we didn't apply any conversions when reading it back. You should + * pass in a dataframe that reads from the raw files, without any timezone specified. 
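+   * For example, data written while the table (or write option) time zone was Europe/Berlin and
+   * read back here without any correction should contain the Berlin-based epoch millis from
+   * timestampTimezoneToMillis, not the session-based ones.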
+ */ + private def checkRawData(data: DataFrame, tz: String): Unit = { + val rows = data.collect() + assert(rows.size == 4) + rows.foreach { row => + val disp = row.getAs[String]("display") + val ts = row.getAs[Timestamp]("ts") + val expMillis = timestampTimezoneToMillis((disp, tz)) + assert(ts.getTime() === expMillis) + } + } + + private val formats = Seq("parquet", "csv", "json") + + // we want to test that this works w/ hive-only methods as well, so provide a few extension + // points so we can also easily re-use this with hive support. + protected def createAndSaveTableFunctions(): Map[String, CreateAndSaveTable] = { + formats.map { f => (f, new CreateAndSaveDatasourceTable(f)) }.toMap + } + + protected def ctasFunctions(): Map[String, CTAS] = { + formats.map { f => (f, new DatasourceCTAS(f)) }.toMap + } + + trait CreateAndSaveTable { + /** Create the table and save the contents of the dataset into it. */ + def createAndSave(df: DataFrame, table: String, tz: Option[String]): Unit + + /** The target table's format. */ + val format: String + } + + class CreateAndSaveDatasourceTable(override val format: String) extends CreateAndSaveTable { + override def createAndSave(df: DataFrame, table: String, tz: Option[String]): Unit = { + val writer = df.write.format(format) + tz.foreach(writer.option(TZ_KEY, _)) + writer.saveAsTable(table) + } + } + + trait CTAS { + /** + * Create a table with the given time zone, and copy the entire contents of the source table + * into it. + */ + def createFromSource(source: String, dest: String, destTz: Option[String]): Unit + + /** The target table's format. */ + val format: String + } + + class DatasourceCTAS(override val format: String) extends CTAS { + override def createFromSource(source: String, dest: String, destTz: Option[String]): Unit = { + val writer = spark.sql(s"select * from $source").write.format(format) + destTz.foreach { writer.option(TZ_KEY, _)} + writer.saveAsTable(dest) + } + } + + createAndSaveTableFunctions().foreach { case (fmt, createFn) => + ctasFunctions().foreach { case (destFmt, ctasFn) => + test(s"timestamp adjustment: in=$fmt, out=$destFmt") { + testTimestampAdjustment(fmt, destFmt, createFn, ctasFn) + } + } + } + + private def testTimestampAdjustment( + format: String, + destFormat: String, + createFn: CreateAndSaveTable, + ctasFn: CTAS): Unit = { + assert(TimeZone.getDefault.getID() === SESSION_TZ) + val originalData = createRawData(spark) + withTempPath { basePath => + val dsPath = new File(basePath, "dsFlat").getAbsolutePath + originalData.write + .option(TZ_KEY, TABLE_TZ) + .parquet(dsPath) + + /** + * Reads the raw data underlying the table, and assuming the data came from + * [[createRawData()]], make sure the values are correct. + */ + def checkTableData(table: String, format: String): Unit = { + // These queries should return the entire dataset, but if the predicates were + // applied to the raw values in parquet, they would incorrectly filter data out. + Seq( + "ts > '2015-12-31 22:00:00'", + "ts < '2016-01-01 02:00:00'" + ).foreach { filter => + val query = + s"select ts from $table where $filter" + val countWithFilter = spark.sql(query).count() + assert(countWithFilter === desiredTimestampStrings.size, query) + } + + // also, read the raw table data, without any TZ correction, and make sure the raw + // values have been adjusted as we expect. 
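+        // (When the table has no timezone property at all, the raw values simply reflect the
+        // session time zone, which is why the call below falls back to SESSION_TZ.)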
+ val tableMeta = spark.sessionState.catalog.getTableMetadata(TableIdentifier(table)) + val location = tableMeta.location.toString() + val tz = tableMeta.properties.get(TZ_KEY) + // some formats need the schema specified + val df = spark.read.schema(originalData.schema).format(format).load(location) + checkRawData(df, tz.getOrElse(SESSION_TZ)) + } + + // read it back, without supplying the right timezone. Won't match the original, but we + // expect specific values. + val readNoCorrection = spark.read.parquet(dsPath) + checkRawData(readNoCorrection, TABLE_TZ) + + // now read it back *with* the right timezone -- everything should match. + val readWithCorrection = spark.read.option(TZ_KEY, TABLE_TZ).parquet(dsPath) + + readWithCorrection.collect().foreach { row => + assert(row.getAs[String]("display") === row.getAs[Timestamp]("ts").toString()) + } + + // save to tables, and read the data back -- this time, the timezone conversion should be + // automatic from the table metadata, we don't need to supply any options when reading the + // data. Works across different ways of creating the tables and different data formats. + val tblName = s"save_$format" + withTable(tblName) { + // create the table (if we can -- not all createAndSave() methods support all formats, + // eg. hive tables don't support json) + createFn.createAndSave(readWithCorrection, tblName, Some(UTC)) + // make sure it has the right timezone, and the data is correct. + checkHasTz(spark, tblName, Some(UTC)) + checkTableData(tblName, createFn.format) + + // also try to copy this table directly into another table with a different timezone + // setting, for all formats. + val destTableUTC = s"copy_to_utc_$destFormat" + val destTableNoTZ = s"copy_to_no_tz_$destFormat" + withTable(destTableUTC, destTableNoTZ) { + ctasFn.createFromSource(tblName, destTableUTC, Some(UTC)) + checkHasTz(spark, destTableUTC, Some(UTC)) + checkTableData(destTableUTC, ctasFn.format) + + ctasFn.createFromSource(tblName, destTableNoTZ, None) + checkHasTz(spark, destTableNoTZ, None) + checkTableData(destTableNoTZ, ctasFn.format) + + // By now, we've checked that the data in both tables is different in terms + // of the raw values on disk, but they are the same after we apply the + // timezone conversions from the table properties. Just to be extra-sure, + // we join the tables and make sure its OK. + val joinedRows = spark.sql( + s"""SELECT a.display, a.ts + |FROM $tblName AS a + |JOIN $destTableUTC AS b + |ON (a.ts = b.ts)""".stripMargin).collect() + assert(joinedRows.size === 4) + joinedRows.foreach { row => + assert(row.getAs[String]("display") === + row.getAs[Timestamp]("ts").toString()) + } + } + + // Finally, try changing the tbl timezone. 
This destroys integrity + // of the existing data, but at this point we're just checking we can change + // the metadata + spark.sql( + s"""ALTER TABLE $tblName SET TBLPROPERTIES ("$TZ_KEY"="$SESSION_TZ")""") + checkHasTz(spark, tblName, Some(SESSION_TZ)) + + spark.sql(s"""ALTER TABLE $tblName UNSET TBLPROPERTIES ("$TZ_KEY")""") + checkHasTz(spark, tblName, None) + + spark.sql(s"""ALTER TABLE $tblName SET TBLPROPERTIES ("$TZ_KEY"="$UTC")""") + checkHasTz(spark, tblName, Some(UTC)) + } + } + } + + test("exception on bad timezone") { + // make sure there is an exception anytime we try to read or write with a bad timezone + val badVal = "Blart Versenwald III" + val data = createRawData(spark) + def hasBadTzException(command: => Unit): Unit = { + withTable("bad_tz_table") { + val badTzException = intercept[AnalysisException] { command } + assert(badTzException.getMessage.contains(badVal)) + } + } + + withTempPath { p => + hasBadTzException { + data.write.option(TZ_KEY, badVal).parquet(p.getAbsolutePath) + } + + data.write.parquet(p.getAbsolutePath) + hasBadTzException { + spark.read.option(TZ_KEY, badVal).parquet(p.getAbsolutePath) + } + } + + createAndSaveTableFunctions().foreach { case (_, createFn) => + hasBadTzException{ + createFn.createAndSave(data.toDF(), "bad_tz_table", Some(badVal)) + } + + createFn.createAndSave(data.toDF(), "bad_tz_table", None) + hasBadTzException { + spark.sql(s"""ALTER TABLE bad_tz_table SET TBLPROPERTIES("$TZ_KEY"="$badVal")""") + } + } + } + + test("insertInto must not specify timezone") { + // You can't specify the timezone for just a portion of inserted data. You can only specify + // the timezone for the *entire* table (data previously in the table and any future data) so + // complain loudly if the user tries to set the timezone on an insert. 
+ withTable("some_table") { + val origData = createRawData(spark) + origData.write.saveAsTable("some_table") + val exc = intercept[AnalysisException]{ + createRawData(spark).write.option(TZ_KEY, UTC) + .insertInto("some_table") + } + assert(exc.getMessage.contains("Cannot provide a table timezone on insert")) + } + } + + test("disallow table timezone on views") { + val originalData = createRawData(spark) + + withTable("ok_table") { + originalData.write.saveAsTable("ok_table") + withView("view_with_tz") { + val exc1 = intercept[AnalysisException]{ + spark.sql(s"""CREATE VIEW view_with_tz + |TBLPROPERTIES ("$TZ_KEY"="$UTC") + |AS SELECT * FROM ok_table + """.stripMargin) + } + assert(exc1.getMessage.contains("Timezone cannot be set for view")) + spark.sql("CREATE VIEW view_with_tz AS SELECT * FROM ok_table") + val exc2 = intercept[AnalysisException]{ + spark.sql(s"""ALTER VIEW view_with_tz SET TBLPROPERTIES("$TZ_KEY"="$UTC")""") + } + assert(exc2.getMessage.contains("Timezone cannot be set for view")) + } + } + } +} + +class AdjustTimestampsSuite extends BaseAdjustTimestampsSuite with SharedSQLContext diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 92cb4ef11c9e3..218926385c650 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -71,6 +71,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session new ResolveHiveSerdeTable(session) +: new FindDataSourceTable(session) +: new ResolveSQLOnFile(session) +: + AdjustTimestamps(conf) +: customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 65e8b4e3c725c..55901a1fd5126 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -92,7 +92,7 @@ case class CreateHiveTableAsSelectCommand( } override def argString: String = { - s"[Database:${tableDesc.database}}, " + + s"[Database:${tableDesc.database}, " + s"TableName: ${tableDesc.identifier.table}, " + s"InsertIntoHiveTable]" } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveAdjustTimestampsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveAdjustTimestampsSuite.scala new file mode 100644 index 0000000000000..2fd35eb01b886 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveAdjustTimestampsSuite.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.hive + +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.execution.datasources.BaseAdjustTimestampsSuite +import org.apache.spark.sql.hive.test.TestHiveSingleton + +class HiveAdjustTimestampsSuite extends BaseAdjustTimestampsSuite with TestHiveSingleton { + + override protected def createAndSaveTableFunctions(): Map[String, CreateAndSaveTable] = { + val hiveFns = Map( + "hive_parquet" -> new CreateHiveTableAndInsert()) + + super.createAndSaveTableFunctions() ++ hiveFns + } + + override protected def ctasFunctions(): Map[String, CTAS] = { + // Disabling metastore conversion will also modify how data is read when the the CTAS query is + // run if the source is a Hive table; so, the test that uses "hive_parquet" as the source and + // "hive_parquet_no_conversion" as the target is actually using the "no metastore conversion" + // path for both, making it unnecessary to also have the "no conversion" case in the save + // functions. + val hiveFns = Map( + "hive_parquet" -> new CreateHiveTableWithTimezoneAndInsert(true), + "hive_parquet_no_conversion" -> new CreateHiveTableWithTimezoneAndInsert(false)) + + super.ctasFunctions() ++ hiveFns + } + + class CreateHiveTableAndInsert extends CreateAndSaveTable { + override def createAndSave(df: DataFrame, table: String, tzOpt: Option[String]): Unit = { + val tblProperties = tzOpt.map { tz => + s"""TBLPROPERTIES ("$TZ_KEY"="$tz")""" + }.getOrElse("") + spark.sql( + s"""CREATE TABLE $table ( + | display string, + | ts timestamp + |) + |STORED AS parquet + |$tblProperties + |""".stripMargin) + df.write.insertInto(table) + } + + override val format: String = "parquet" + } + + class CreateHiveTableWithTimezoneAndInsert(convertMetastore: Boolean) extends CTAS { + override def createFromSource(source: String, dest: String, destTz: Option[String]): Unit = { + withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> convertMetastore.toString) { + val tblProperties = destTz.map { tz => + s"""TBLPROPERTIES ("$TZ_KEY"="$tz")""" + }.getOrElse("") + // this isn't just a "ctas" sql statement b/c that doesn't let us specify the table tz + spark.sql( + s"""CREATE TABLE $dest ( + | display string, + | ts timestamp + |) + |STORED AS parquet + |$tblProperties + |""".stripMargin) + spark.sql(s"insert into $dest select * from $source") + } + } + + override val format: String = "parquet" + } + + test("copy table timezone in CREATE TABLE LIKE") { + withTable("orig_hive", "copy_hive", "orig_ds", "copy_ds") { + spark.sql( + s"""CREATE TABLE orig_hive ( + | display string, + | ts timestamp + |) + |STORED AS parquet + |TBLPROPERTIES ("$TZ_KEY"="$UTC") + |""". + stripMargin) + spark.sql("CREATE TABLE copy_hive LIKE orig_hive") + checkHasTz(spark, "copy_hive", Some(UTC)) + + createRawData(spark).write.option(TZ_KEY, TABLE_TZ).saveAsTable("orig_ds") + spark.sql("CREATE TABLE copy_ds LIKE orig_ds") + checkHasTz(spark, "copy_ds", Some(TABLE_TZ)) + } + + } +}
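
As a quick end-to-end illustration of how the pieces above fit together, here is a minimal usage
sketch, assuming this patch is applied; the table name and data values are illustrative only:

import java.sql.Timestamp
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()
import spark.implicits._

// Write a table whose stored timestamps should read as Europe/Berlin wall-clock times, no matter
// which session time zone later readers use.
Seq(("2016-01-01 00:39:59.123", Timestamp.valueOf("2016-01-01 00:39:59.123")))
  .toDF("display", "ts")
  .write
  .format("parquet")
  .option("table.timezone-adjustment", "Europe/Berlin")
  .saveAsTable("events")

// Readers do not need to repeat the option: the table property is picked up by the
// AdjustTimestamps rule and stored values are corrected into the current session time zone.
spark.table("events").show(truncate = false)

// The property can be changed or removed later; this only changes how existing raw values are
// interpreted, and an invalid zone id is rejected with an AnalysisException.
spark.sql("""ALTER TABLE events SET TBLPROPERTIES ("table.timezone-adjustment"="UTC")""")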