diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index d36a04f1fff8e..cbe8ce421f92b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -96,6 +96,24 @@ case class DataSource(
       bucket.sortColumnNames, "in the sort definition", equality)
   }
 
+  /**
+   * In the read path, only managed tables by Hive provide the partition columns properly when
+   * initializing this class. All other file based data sources will try to infer the partitioning,
+   * and then cast the inferred types to user specified dataTypes if the partition columns exist
+   * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510, or
+   * inconsistent data types as reported in SPARK-21463.
+   * @param fileIndex A FileIndex that will perform partition inference
+   * @return The PartitionSchema resolved from inference and cast according to `userSpecifiedSchema`
+   */
+  private def combineInferredAndUserSpecifiedPartitionSchema(fileIndex: FileIndex): StructType = {
+    val resolved = fileIndex.partitionSchema.map { partitionField =>
+      // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred
+      userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse(
+        partitionField)
+    }
+    StructType(resolved)
+  }
+
   /**
    * Get the schema of the given FileFormat, if provided by `userSpecifiedSchema`, or try to infer
    * it. In the read path, only managed tables by Hive provide the partition columns properly when
@@ -139,12 +157,7 @@ case class DataSource(
     val partitionSchema = if (partitionColumns.isEmpty) {
       // Try to infer partitioning, because no DataSource in the read path provides the partitioning
       // columns properly unless it is a Hive DataSource
-      val resolved = tempFileIndex.partitionSchema.map { partitionField =>
-        // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred
-        userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse(
-          partitionField)
-      }
-      StructType(resolved)
+      combineInferredAndUserSpecifiedPartitionSchema(tempFileIndex)
     } else {
       // maintain old behavior before SPARK-18510. If userSpecifiedSchema is empty used inferred
       // partitioning
@@ -336,7 +349,13 @@ case class DataSource(
           caseInsensitiveOptions.get("path").toSeq ++ paths,
           sparkSession.sessionState.newHadoopConf()) =>
         val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head)
-        val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath)
+        val tempFileCatalog = new MetadataLogFileIndex(sparkSession, basePath, None)
+        val fileCatalog = if (userSpecifiedSchema.nonEmpty) {
+          val partitionSchema = combineInferredAndUserSpecifiedPartitionSchema(tempFileCatalog)
+          new MetadataLogFileIndex(sparkSession, basePath, Option(partitionSchema))
+        } else {
+          tempFileCatalog
+        }
         val dataSchema = userSpecifiedSchema.orElse {
           format.inferSchema(
             sparkSession,
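For reviewers, here is a minimal self-contained sketch of the merge rule that the new `combineInferredAndUserSpecifiedPartitionSchema` helper factors out: for each inferred partition field, prefer the user-specified field of the same name, otherwise fall back to the inferred one. `combinePartitionSchema` and the hard-coded case-insensitive `equality` below are illustrative stand-ins for the private method and the session's resolver, not code from this patch.

```scala
import org.apache.spark.sql.types._

// Stand-in for the private helper: user-specified partition fields win by name,
// inferred fields are the fallback (SPARK-18510, SPARK-21463).
def combinePartitionSchema(
    inferred: StructType,
    userSpecified: Option[StructType],
    equality: (String, String) => Boolean): StructType = {
  StructType(inferred.map { partitionField =>
    userSpecified
      .flatMap(_.find(f => equality(f.name, partitionField.name)))
      .getOrElse(partitionField)
  })
}

// Example: inference typed the partition column `time` as timestamp, but the
// user declared it string, so the user-specified type is kept.
val merged = combinePartitionSchema(
  inferred = new StructType().add("time", TimestampType),
  userSpecified = Some(new StructType().add("time", StringType).add("value", IntegerType)),
  equality = _.equalsIgnoreCase(_))
// merged contains a single StructField("time", StringType, nullable = true)
```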
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
index a9e64c640042a..4b1b2520390ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
@@ -195,7 +195,7 @@ class FileStreamSource(
   private def allFilesUsingMetadataLogFileIndex() = {
     // Note if `sourceHasMetadata` holds, then `qualifiedBasePath` is guaranteed to be a
     // non-glob path
-    new MetadataLogFileIndex(sparkSession, qualifiedBasePath).allFiles()
+    new MetadataLogFileIndex(sparkSession, qualifiedBasePath, None).allFiles()
   }
 
   /**
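The user-visible effect of the `DataSource` wiring above, sketched from the public API. This is illustrative only: the session setup and the `/tmp/stream-output` path are assumptions, and the directory is presumed to have been written by a `FileStreamSink`, so the read goes through `MetadataLogFileIndex` rather than plain directory listing.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._

val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()

// The partition column `time` is declared as string in the user schema.
val userSchema = new StructType()
  .add("time", StringType)
  .add("value", IntegerType)

// Assumed to be a FileStreamSink output directory (contains a _spark_metadata log).
val df = spark.read.schema(userSchema).parquet("/tmp/stream-output")

// Before this patch, `time` came back with the inferred partition type,
// ignoring the schema above; with the patch, the declared type is respected.
assert(df.schema("time").dataType == StringType)
```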
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala
index aeaa134736937..1da703cefd8ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala
@@ -23,14 +23,21 @@ import org.apache.hadoop.fs.{FileStatus, Path}
 
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.types.StructType
 
 /**
  * A [[FileIndex]] that generates the list of files to processing by reading them from the
  * metadata log files generated by the [[FileStreamSink]].
+ *
+ * @param userPartitionSchema an optional partition schema that will be used to provide types for
+ *                            the discovered partitions
  */
-class MetadataLogFileIndex(sparkSession: SparkSession, path: Path)
-  extends PartitioningAwareFileIndex(sparkSession, Map.empty, None) {
+class MetadataLogFileIndex(
+    sparkSession: SparkSession,
+    path: Path,
+    userPartitionSchema: Option[StructType])
+  extends PartitioningAwareFileIndex(sparkSession, Map.empty, userPartitionSchema) {
 
   private val metadataDirectory = new Path(path, FileStreamSink.metadataDir)
   logInfo(s"Reading streaming file log from $metadataDirectory")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
index 84b34d5ad26d1..2f5fd8438f682 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.{PartitionPath => Partition}
+import org.apache.spark.sql.execution.streaming.MemoryStream
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -1022,4 +1023,36 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
       }
     }
   }
+
+  test("SPARK-21463: MetadataLogFileIndex should respect userSpecifiedSchema for partition cols") {
+    withTempDir { tempDir =>
+      val output = new File(tempDir, "output").toString
+      val checkpoint = new File(tempDir, "chkpoint").toString
+      try {
+        val stream = MemoryStream[(String, Int)]
+        val df = stream.toDS().toDF("time", "value")
+        val sq = df.writeStream
+          .option("checkpointLocation", checkpoint)
+          .format("parquet")
+          .partitionBy("time")
+          .start(output)
+
+        stream.addData(("2017-01-01-00", 1), ("2017-01-01-01", 2))
+        sq.processAllAvailable()
+
+        val schema = new StructType()
+          .add("time", StringType)
+          .add("value", IntegerType)
+        val readBack = spark.read.schema(schema).parquet(output)
+        assert(readBack.schema.toSet === schema.toSet)
+
+        checkAnswer(
+          readBack,
+          Seq(Row("2017-01-01-00", 1), Row("2017-01-01-01", 2))
+        )
+      } finally {
+        spark.streams.active.foreach(_.stop())
+      }
+    }
+  }
 }
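A note for other callers of the changed constructor: the new third argument has no default, so every call site opts in explicitly. `None` preserves the old inference-only behavior (as in `FileStreamSource` above), while `Some(schema)` pins the partition column types. A sketch under the same assumptions as before (illustrative session and path):

```scala
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.MetadataLogFileIndex
import org.apache.spark.sql.types._

val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()

// Illustrative path; assumed to point at a FileStreamSink output directory.
val basePath = new Path("/tmp/stream-output")

// Inference-only behavior, equivalent to the pre-patch two-argument constructor:
val inferredIndex = new MetadataLogFileIndex(spark, basePath, None)

// Pin the partition column types from a user-provided schema instead:
val partitionSchema = new StructType().add("time", StringType)
val pinnedIndex = new MetadataLogFileIndex(spark, basePath, Some(partitionSchema))
```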