
Commit 2a7921a

ueshin authored and cloud-fan committed
[SPARK-18939][SQL] Timezone support in partition values.
## What changes were proposed in this pull request?

This is a follow-up PR of apache#16308 and apache#16750.

This PR enables timezone support in partition values. We should use the `timeZone` option introduced in apache#16750 to parse/format partition values of the `TimestampType`.

For example, if you have timestamp `"2016-01-01 00:00:00"` in `GMT` which will be used for partition values, the values written with the default timezone option, which is `"GMT"` because the session local timezone is `"GMT"` here, are:

```scala
scala> spark.conf.set("spark.sql.session.timeZone", "GMT")

scala> val df = Seq((1, new java.sql.Timestamp(1451606400000L))).toDF("i", "ts")
df: org.apache.spark.sql.DataFrame = [i: int, ts: timestamp]

scala> df.show()
+---+-------------------+
|  i|                 ts|
+---+-------------------+
|  1|2016-01-01 00:00:00|
+---+-------------------+

scala> df.write.partitionBy("ts").save("/path/to/gmtpartition")
```

```sh
$ ls /path/to/gmtpartition/
_SUCCESS  ts=2016-01-01 00%3A00%3A00
```

whereas setting the option to `"PST"`, they are:

```scala
scala> df.write.option("timeZone", "PST").partitionBy("ts").save("/path/to/pstpartition")
```

```sh
$ ls /path/to/pstpartition/
_SUCCESS  ts=2015-12-31 16%3A00%3A00
```

We can properly read back the partition values if the session local timezone and the timezone of the partition values are the same:

```scala
scala> spark.read.load("/path/to/gmtpartition").show()
+---+-------------------+
|  i|                 ts|
+---+-------------------+
|  1|2016-01-01 00:00:00|
+---+-------------------+
```

And even if the timezones are different, we can properly read the values by setting the correct timezone option:

```scala
// wrong result
scala> spark.read.load("/path/to/pstpartition").show()
+---+-------------------+
|  i|                 ts|
+---+-------------------+
|  1|2015-12-31 16:00:00|
+---+-------------------+

// correct result
scala> spark.read.option("timeZone", "PST").load("/path/to/pstpartition").show()
+---+-------------------+
|  i|                 ts|
+---+-------------------+
|  1|2016-01-01 00:00:00|
+---+-------------------+
```

## How was this patch tested?

Existing tests and added some tests.

Author: Takuya UESHIN <[email protected]>

Closes apache#17053 from ueshin/issues/SPARK-18939.
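For reference (not part of the patch), a minimal `java.time` sketch of why the single instant, 1451606400000 ms since the epoch, surfaces as the two different directory names shown above depending on the timezone used to render it:

```scala
import java.time.{Instant, ZoneId}

// The same instant rendered in two zones; PST is UTC-8 at the end of December.
val ts = Instant.ofEpochMilli(1451606400000L)
ts.atZone(ZoneId.of("GMT"))                  // 2016-01-01T00:00Z      -> ts=2016-01-01 00%3A00%3A00
ts.atZone(ZoneId.of("America/Los_Angeles"))  // 2015-12-31T16:00-08:00 -> ts=2015-12-31 16%3A00%3A00
```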
1 parent: ba186a8

File tree: 15 files changed (+175 −59 lines)


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala

Lines changed: 3 additions & 1 deletion
@@ -244,11 +244,13 @@ abstract class ExternalCatalog {
    * @param db database name
    * @param table table name
    * @param predicates partition-pruning predicates
+   * @param defaultTimeZoneId default timezone id to parse partition values of TimestampType
    */
   def listPartitionsByFilter(
       db: String,
       table: String,
-      predicates: Seq[Expression]): Seq[CatalogTablePartition]
+      predicates: Seq[Expression],
+      defaultTimeZoneId: String): Seq[CatalogTablePartition]

   // --------------------------------------------------------------------------
   // Functions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala

Lines changed: 2 additions & 1 deletion
@@ -544,7 +544,8 @@ class InMemoryCatalog(
   override def listPartitionsByFilter(
       db: String,
       table: String,
-      predicates: Seq[Expression]): Seq[CatalogTablePartition] = {
+      predicates: Seq[Expression],
+      defaultTimeZoneId: String): Seq[CatalogTablePartition] = {
     // TODO: Provide an implementation
     throw new UnsupportedOperationException(
       "listPartitionsByFilter is not implemented for InMemoryCatalog")

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala

Lines changed: 1 addition & 1 deletion
@@ -841,7 +841,7 @@ class SessionCatalog(
     val table = formatTableName(tableName.table)
     requireDbExists(db)
     requireTableExists(TableIdentifier(table, Option(db)))
-    externalCatalog.listPartitionsByFilter(db, table, predicates)
+    externalCatalog.listPartitionsByFilter(db, table, predicates, conf.sessionLocalTimeZone)
   }

   /**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala

Lines changed: 5 additions & 5 deletions
@@ -26,8 +26,8 @@ import org.apache.spark.sql.catalyst.{CatalystConf, FunctionIdentifier, Internal
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Cast, Literal}
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.catalyst.util.quoteIdentifier
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types.StructType


@@ -113,11 +113,11 @@ case class CatalogTablePartition(
   /**
    * Given the partition schema, returns a row with that schema holding the partition values.
    */
-  def toRow(partitionSchema: StructType): InternalRow = {
+  def toRow(partitionSchema: StructType, defaultTimeZondId: String): InternalRow = {
+    val caseInsensitiveProperties = CaseInsensitiveMap(storage.properties)
+    val timeZoneId = caseInsensitiveProperties.getOrElse("timeZone", defaultTimeZondId)
     InternalRow.fromSeq(partitionSchema.map { field =>
-      // TODO: use correct timezone for partition values.
-      Cast(Literal(spec(field.name)), field.dataType,
-        Option(DateTimeUtils.defaultTimeZone().getID)).eval()
+      Cast(Literal(spec(field.name)), field.dataType, Option(timeZoneId)).eval()
     })
   }
 }
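The `toRow` change prefers a per-partition `timeZone` storage property and only falls back to the caller-supplied session default. A standalone sketch of that lookup pattern in plain Scala (not Spark's `CaseInsensitiveMap`; `resolveTimeZoneId` is an illustrative name, not part of the patch):

```scala
// Prefer an explicit "timeZone" property, matched case-insensitively, over the session default.
def resolveTimeZoneId(storageProperties: Map[String, String], defaultTimeZoneId: String): String =
  storageProperties.collectFirst {
    case (key, value) if key.equalsIgnoreCase("timeZone") => value
  }.getOrElse(defaultTimeZoneId)

resolveTimeZoneId(Map("timezone" -> "PST"), "GMT")  // "PST": the stored property wins
resolveTimeZoneId(Map.empty, "GMT")                 // "GMT": falls back to the session-local default
```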

sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala

Lines changed: 6 additions & 4 deletions
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
 import org.apache.spark.sql.internal.SQLConf

@@ -103,11 +103,13 @@ case class OptimizeMetadataOnlyQuery(

         case relation: CatalogRelation =>
           val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation)
+          val caseInsensitiveProperties =
+            CaseInsensitiveMap(relation.tableMeta.storage.properties)
+          val timeZoneId = caseInsensitiveProperties.get("timeZone")
+            .getOrElse(conf.sessionLocalTimeZone)
           val partitionData = catalog.listPartitions(relation.tableMeta.identifier).map { p =>
             InternalRow.fromSeq(partAttrs.map { attr =>
-              // TODO: use correct timezone for partition values.
-              Cast(Literal(p.spec(attr.name)), attr.dataType,
-                Option(DateTimeUtils.defaultTimeZone().getID)).eval()
+              Cast(Literal(p.spec(attr.name)), attr.dataType, Option(timeZoneId)).eval()
             })
           }
           LocalRelation(partAttrs, partitionData)

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala

Lines changed: 2 additions & 1 deletion
@@ -72,7 +72,8 @@ class CatalogFileIndex(
         val path = new Path(p.location)
         val fs = path.getFileSystem(hadoopConf)
         PartitionPath(
-          p.toRow(partitionSchema), path.makeQualified(fs.getUri, fs.getWorkingDirectory))
+          p.toRow(partitionSchema, sparkSession.sessionState.conf.sessionLocalTimeZone),
+          path.makeQualified(fs.getUri, fs.getWorkingDirectory))
       }
       val partitionSpec = PartitionSpec(partitionSchema, partitions)
       new PrunedInMemoryFileIndex(

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala

Lines changed: 11 additions & 7 deletions
@@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution}
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.{SerializableConfiguration, Utils}

@@ -68,7 +68,8 @@ object FileFormatWriter extends Logging {
       val bucketIdExpression: Option[Expression],
       val path: String,
       val customPartitionLocations: Map[TablePartitionSpec, String],
-      val maxRecordsPerFile: Long)
+      val maxRecordsPerFile: Long,
+      val timeZoneId: String)
     extends Serializable {

     assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns),

@@ -122,9 +123,11 @@ object FileFormatWriter extends Logging {
       spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
     }

+    val caseInsensitiveOptions = CaseInsensitiveMap(options)
+
     // Note: prepareWrite has side effect. It sets "job".
     val outputWriterFactory =
-      fileFormat.prepareWrite(sparkSession, job, options, dataColumns.toStructType)
+      fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataColumns.toStructType)

     val description = new WriteJobDescription(
       uuid = UUID.randomUUID().toString,

@@ -136,8 +139,10 @@ object FileFormatWriter extends Logging {
       bucketIdExpression = bucketIdExpression,
       path = outputSpec.outputPath,
       customPartitionLocations = outputSpec.customPartitionLocations,
-      maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong)
-        .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile)
+      maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong)
+        .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile),
+      timeZoneId = caseInsensitiveOptions.get("timeZone")
+        .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone)
     )

     // We should first sort by partition columns, then bucket id, and finally sorting columns.

@@ -330,11 +335,10 @@ object FileFormatWriter extends Logging {
     /** Expressions that given partition columns build a path string like: col1=val/col2=val/... */
     private def partitionPathExpression: Seq[Expression] = {
       desc.partitionColumns.zipWithIndex.flatMap { case (c, i) =>
-        // TODO: use correct timezone for partition values.
         val escaped = ScalaUDF(
           ExternalCatalogUtils.escapePathName _,
           StringType,
-          Seq(Cast(c, StringType, Option(DateTimeUtils.defaultTimeZone().getID))),
+          Seq(Cast(c, StringType, Option(desc.timeZoneId))),
           Seq(StringType))
         val str = If(IsNull(c), Literal(ExternalCatalogUtils.DEFAULT_PARTITION_NAME), escaped)
         val partitionName = Literal(c.name + "=") :: str :: Nil
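On the write side, `partitionPathExpression` now casts partition values to strings with the job's `timeZoneId` instead of the JVM default timezone. A minimal sketch of the effect using plain `SimpleDateFormat` rather than Spark's `Cast` (method name and format are illustrative, not the internal implementation):

```scala
import java.text.SimpleDateFormat
import java.util.TimeZone

// Render the partition value in the requested timezone, as it appears in the
// directory name before path escaping (":" later becomes "%3A").
def formatPartitionValue(epochMillis: Long, timeZoneId: String): String = {
  val fmt = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
  fmt.setTimeZone(TimeZone.getTimeZone(timeZoneId))
  fmt.format(new java.util.Date(epochMillis))
}

formatPartitionValue(1451606400000L, "GMT")  // "2016-01-01 00:00:00"
formatPartitionValue(1451606400000L, "PST")  // "2015-12-31 16:00:00"
```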

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala

Lines changed: 11 additions & 5 deletions
@@ -30,7 +30,7 @@ import org.apache.spark.metrics.source.HiveCatalogMetrics
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.{expressions, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.SerializableConfiguration

@@ -125,22 +125,27 @@ abstract class PartitioningAwareFileIndex(
     val leafDirs = leafDirToChildrenFiles.filter { case (_, files) =>
       files.exists(f => isDataPath(f.getPath))
     }.keys.toSeq
+
+    val caseInsensitiveOptions = CaseInsensitiveMap(parameters)
+    val timeZoneId = caseInsensitiveOptions.get("timeZone")
+      .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone)
+
     userPartitionSchema match {
       case Some(userProvidedSchema) if userProvidedSchema.nonEmpty =>
         val spec = PartitioningUtils.parsePartitions(
           leafDirs,
           typeInference = false,
-          basePaths = basePaths)
+          basePaths = basePaths,
+          timeZoneId = timeZoneId)

         // Without auto inference, all of value in the `row` should be null or in StringType,
         // we need to cast into the data type that user specified.
         def castPartitionValuesToUserSchema(row: InternalRow) = {
           InternalRow((0 until row.numFields).map { i =>
-            // TODO: use correct timezone for partition values.
             Cast(
               Literal.create(row.getUTF8String(i), StringType),
               userProvidedSchema.fields(i).dataType,
-              Option(DateTimeUtils.defaultTimeZone().getID)).eval()
+              Option(timeZoneId)).eval()
           }: _*)
         }

@@ -151,7 +156,8 @@ abstract class PartitioningAwareFileIndex(
         PartitioningUtils.parsePartitions(
           leafDirs,
           typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled,
-          basePaths = basePaths)
+          basePaths = basePaths,
+          timeZoneId = timeZoneId)
     }
   }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources
1919

2020
import java.lang.{Double => JDouble, Long => JLong}
2121
import java.math.{BigDecimal => JBigDecimal}
22-
import java.sql.{Date => JDate, Timestamp => JTimestamp}
22+
import java.util.TimeZone
2323

2424
import scala.collection.mutable.ArrayBuffer
2525
import scala.util.Try
@@ -31,7 +31,9 @@ import org.apache.spark.sql.catalyst.InternalRow
3131
import org.apache.spark.sql.catalyst.analysis.Resolver
3232
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
3333
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
34+
import org.apache.spark.sql.catalyst.util.DateTimeUtils
3435
import org.apache.spark.sql.types._
36+
import org.apache.spark.unsafe.types.UTF8String
3537

3638
// TODO: We should tighten up visibility of the classes here once we clean up Hive coupling.
3739

@@ -91,10 +93,19 @@ object PartitioningUtils {
9193
private[datasources] def parsePartitions(
9294
paths: Seq[Path],
9395
typeInference: Boolean,
94-
basePaths: Set[Path]): PartitionSpec = {
96+
basePaths: Set[Path],
97+
timeZoneId: String): PartitionSpec = {
98+
parsePartitions(paths, typeInference, basePaths, TimeZone.getTimeZone(timeZoneId))
99+
}
100+
101+
private[datasources] def parsePartitions(
102+
paths: Seq[Path],
103+
typeInference: Boolean,
104+
basePaths: Set[Path],
105+
timeZone: TimeZone): PartitionSpec = {
95106
// First, we need to parse every partition's path and see if we can find partition values.
96107
val (partitionValues, optDiscoveredBasePaths) = paths.map { path =>
97-
parsePartition(path, typeInference, basePaths)
108+
parsePartition(path, typeInference, basePaths, timeZone)
98109
}.unzip
99110

100111
// We create pairs of (path -> path's partition value) here
@@ -173,7 +184,8 @@ object PartitioningUtils {
173184
private[datasources] def parsePartition(
174185
path: Path,
175186
typeInference: Boolean,
176-
basePaths: Set[Path]): (Option[PartitionValues], Option[Path]) = {
187+
basePaths: Set[Path],
188+
timeZone: TimeZone): (Option[PartitionValues], Option[Path]) = {
177189
val columns = ArrayBuffer.empty[(String, Literal)]
178190
// Old Hadoop versions don't have `Path.isRoot`
179191
var finished = path.getParent == null
@@ -194,7 +206,7 @@ object PartitioningUtils {
194206
// Let's say currentPath is a path of "/table/a=1/", currentPath.getName will give us a=1.
195207
// Once we get the string, we try to parse it and find the partition column and value.
196208
val maybeColumn =
197-
parsePartitionColumn(currentPath.getName, typeInference)
209+
parsePartitionColumn(currentPath.getName, typeInference, timeZone)
198210
maybeColumn.foreach(columns += _)
199211

200212
// Now, we determine if we should stop.
@@ -226,7 +238,8 @@ object PartitioningUtils {
226238

227239
private def parsePartitionColumn(
228240
columnSpec: String,
229-
typeInference: Boolean): Option[(String, Literal)] = {
241+
typeInference: Boolean,
242+
timeZone: TimeZone): Option[(String, Literal)] = {
230243
val equalSignIndex = columnSpec.indexOf('=')
231244
if (equalSignIndex == -1) {
232245
None
@@ -237,7 +250,7 @@ object PartitioningUtils {
237250
val rawColumnValue = columnSpec.drop(equalSignIndex + 1)
238251
assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'")
239252

240-
val literal = inferPartitionColumnValue(rawColumnValue, typeInference)
253+
val literal = inferPartitionColumnValue(rawColumnValue, typeInference, timeZone)
241254
Some(columnName -> literal)
242255
}
243256
}
@@ -370,7 +383,8 @@ object PartitioningUtils {
370383
*/
371384
private[datasources] def inferPartitionColumnValue(
372385
raw: String,
373-
typeInference: Boolean): Literal = {
386+
typeInference: Boolean,
387+
timeZone: TimeZone): Literal = {
374388
val decimalTry = Try {
375389
// `BigDecimal` conversion can fail when the `field` is not a form of number.
376390
val bigDecimal = new JBigDecimal(raw)
@@ -390,8 +404,16 @@ object PartitioningUtils {
390404
// Then falls back to fractional types
391405
.orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType)))
392406
// Then falls back to date/timestamp types
393-
.orElse(Try(Literal(JDate.valueOf(raw))))
394-
.orElse(Try(Literal(JTimestamp.valueOf(unescapePathName(raw)))))
407+
.orElse(Try(
408+
Literal.create(
409+
DateTimeUtils.getThreadLocalTimestampFormat(timeZone)
410+
.parse(unescapePathName(raw)).getTime * 1000L,
411+
TimestampType)))
412+
.orElse(Try(
413+
Literal.create(
414+
DateTimeUtils.millisToDays(
415+
DateTimeUtils.getThreadLocalDateFormat.parse(raw).getTime),
416+
DateType)))
395417
// Then falls back to string
396418
.getOrElse {
397419
if (raw == DEFAULT_PARTITION_NAME) {
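The inference change above parses timestamp partition values against an explicit `TimeZone` instead of `JTimestamp.valueOf`, which always used the JVM default. A standalone sketch of the resulting round trip, assuming the same `yyyy-MM-dd HH:mm:ss` layout and using plain `SimpleDateFormat` rather than Spark's internal helpers (`parsePartitionTimestamp` is an illustrative name):

```scala
import java.text.SimpleDateFormat
import java.util.TimeZone

// Parse an unescaped partition value in the directory's timezone; the same instant is
// recovered regardless of which timezone wrote it.
def parsePartitionTimestamp(raw: String, timeZone: TimeZone): Long = {
  val fmt = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss")
  fmt.setTimeZone(timeZone)
  fmt.parse(raw).getTime  // epoch millis; Spark's TimestampType stores micros (hence * 1000L above)
}

parsePartitionTimestamp("2016-01-01 00:00:00", TimeZone.getTimeZone("GMT"))  // 1451606400000L
parsePartitionTimestamp("2015-12-31 16:00:00", TimeZone.getTimeZone("PST"))  // 1451606400000L
```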

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala

Lines changed: 10 additions & 5 deletions
@@ -742,10 +742,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
         .save(iso8601timestampsPath)

       // This will load back the timestamps as string.
+      val stringSchema = StructType(StructField("date", StringType, true) :: Nil)
       val iso8601Timestamps = spark.read
         .format("csv")
+        .schema(stringSchema)
         .option("header", "true")
-        .option("inferSchema", "false")
         .load(iso8601timestampsPath)

       val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ", Locale.US)

@@ -775,10 +776,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
         .save(iso8601datesPath)

       // This will load back the dates as string.
+      val stringSchema = StructType(StructField("date", StringType, true) :: Nil)
       val iso8601dates = spark.read
         .format("csv")
+        .schema(stringSchema)
         .option("header", "true")
-        .option("inferSchema", "false")
         .load(iso8601datesPath)

       val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd", Locale.US)

@@ -833,10 +835,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
         .save(datesWithFormatPath)

       // This will load back the dates as string.
+      val stringSchema = StructType(StructField("date", StringType, true) :: Nil)
       val stringDatesWithFormat = spark.read
         .format("csv")
+        .schema(stringSchema)
         .option("header", "true")
-        .option("inferSchema", "false")
         .load(datesWithFormatPath)
       val expectedStringDatesWithFormat = Seq(
         Row("2015/08/26"),

@@ -864,10 +867,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
         .save(timestampsWithFormatPath)

       // This will load back the timestamps as string.
+      val stringSchema = StructType(StructField("date", StringType, true) :: Nil)
       val stringTimestampsWithFormat = spark.read
         .format("csv")
+        .schema(stringSchema)
         .option("header", "true")
-        .option("inferSchema", "false")
         .load(timestampsWithFormatPath)
       val expectedStringTimestampsWithFormat = Seq(
         Row("2015/08/26 18:00"),

@@ -896,10 +900,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
         .save(timestampsWithFormatPath)

       // This will load back the timestamps as string.
+      val stringSchema = StructType(StructField("date", StringType, true) :: Nil)
       val stringTimestampsWithFormat = spark.read
         .format("csv")
+        .schema(stringSchema)
         .option("header", "true")
-        .option("inferSchema", "false")
         .load(timestampsWithFormatPath)
       val expectedStringTimestampsWithFormat = Seq(
         Row("2015/08/27 01:00"),
