DataSource.scala
@@ -96,6 +96,24 @@ case class DataSource(
      bucket.sortColumnNames, "in the sort definition", equality)
  }

  /**
   * In the read path, only tables managed by Hive provide the partition columns properly when
   * initializing this class. All other file-based data sources will try to infer the partitioning,
   * and then cast the inferred types to the user-specified data types if the partition columns
   * exist inside `userSpecifiedSchema`; otherwise we can hit data corruption bugs like SPARK-18510,
   * or inconsistent data types as reported in SPARK-21463.
   * @param fileIndex A FileIndex that will perform partition inference
   * @return The partition schema resolved from inference and cast according to
   *         `userSpecifiedSchema`
   */
  private def combineInferredAndUserSpecifiedPartitionSchema(fileIndex: FileIndex): StructType = {
    val resolved = fileIndex.partitionSchema.map { partitionField =>
      // SPARK-18510: try to get the schema from userSpecifiedSchema, otherwise fall back to the
      // inferred one
      userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse(
        partitionField)
    }
    StructType(resolved)
  }
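The merge rule is compact enough to illustrate in isolation. Below is a minimal, hypothetical sketch (a standalone `PartitionSchemaMergeSketch` object, with plain case-insensitive name matching standing in for the resolver's `equality` function):

```scala
import org.apache.spark.sql.types._

object PartitionSchemaMergeSketch {
  // Same rule as combineInferredAndUserSpecifiedPartitionSchema: for each
  // inferred partition field, prefer a same-named user-specified field,
  // otherwise keep the inferred one.
  def merge(inferred: StructType, userSpecified: Option[StructType]): StructType = {
    StructType(inferred.map { partitionField =>
      userSpecified
        .flatMap(_.find(f => f.name.equalsIgnoreCase(partitionField.name)))
        .getOrElse(partitionField)
    })
  }

  def main(args: Array[String]): Unit = {
    // A partition value like "00" would be inferred as an integer...
    val inferred = StructType(Seq(StructField("time", IntegerType)))
    // ...but the user declared it a string (extra data columns are simply ignored here).
    val user = Some(new StructType().add("time", StringType).add("value", IntegerType))
    // Prints: StructType(StructField(time,StringType,true))
    println(merge(inferred, user))
  }
}
```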

  /**
   * Get the schema of the given FileFormat, if provided by `userSpecifiedSchema`, or try to infer
   * it. In the read path, only managed tables by Hive provide the partition columns properly when
@@ -139,12 +157,7 @@ case class DataSource(
    val partitionSchema = if (partitionColumns.isEmpty) {
      // Try to infer partitioning, because no DataSource in the read path provides the partitioning
      // columns properly unless it is a Hive DataSource
      val resolved = tempFileIndex.partitionSchema.map { partitionField =>
        // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred
        userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse(
          partitionField)
      }
      StructType(resolved)
      combineInferredAndUserSpecifiedPartitionSchema(tempFileIndex)
    } else {
      // Maintain old behavior before SPARK-18510: if userSpecifiedSchema is empty, use the
      // inferred partitioning
@@ -336,7 +349,13 @@ case class DataSource(
            caseInsensitiveOptions.get("path").toSeq ++ paths,
            sparkSession.sessionState.newHadoopConf()) =>
        val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head)
        val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath)
        val tempFileCatalog = new MetadataLogFileIndex(sparkSession, basePath, None)
        val fileCatalog = if (userSpecifiedSchema.nonEmpty) {
          val partitionSchema = combineInferredAndUserSpecifiedPartitionSchema(tempFileCatalog)
          new MetadataLogFileIndex(sparkSession, basePath, Option(partitionSchema))
        } else {
          tempFileCatalog
        }
        val dataSchema = userSpecifiedSchema.orElse {
          format.inferSchema(
            sparkSession,
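From the public API, this branch is what carries an explicit read schema into the sink's file index. A minimal usage sketch (hypothetical output path, assuming an active SparkSession `spark`; the regression test at the end of this diff exercises the same path end to end):

```scala
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

// Reading a FileStreamSink output directory (one containing _spark_metadata).
// With an explicit schema, the partition column `time` keeps StringType
// instead of whatever partition inference would have produced.
val userSchema = new StructType().add("time", StringType).add("value", IntegerType)
val df = spark.read.schema(userSchema).parquet("/tmp/sink-output")
```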
FileStreamSource.scala
@@ -195,7 +195,7 @@ class FileStreamSource(
  private def allFilesUsingMetadataLogFileIndex() = {
    // Note if `sourceHasMetadata` holds, then `qualifiedBasePath` is guaranteed to be a
    // non-glob path
    new MetadataLogFileIndex(sparkSession, qualifiedBasePath).allFiles()
    new MetadataLogFileIndex(sparkSession, qualifiedBasePath, None).allFiles()
  }

  /**
MetadataLogFileIndex.scala
@@ -23,14 +23,21 @@ import org.apache.hadoop.fs.{FileStatus, Path}

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.types.StructType


/**
 * A [[FileIndex]] that generates the list of files to process by reading them from the
 * metadata log files generated by the [[FileStreamSink]].
 *
 * @param userPartitionSchema an optional partition schema that will be used to provide types for
 *                            the discovered partitions
 */
class MetadataLogFileIndex(sparkSession: SparkSession, path: Path)
  extends PartitioningAwareFileIndex(sparkSession, Map.empty, None) {
class MetadataLogFileIndex(
    sparkSession: SparkSession,
    path: Path,
    userPartitionSchema: Option[StructType])
  extends PartitioningAwareFileIndex(sparkSession, Map.empty, userPartitionSchema) {

  private val metadataDirectory = new Path(path, FileStreamSink.metadataDir)
  logInfo(s"Reading streaming file log from $metadataDirectory")
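The widened constructor can also be exercised directly. A sketch against this internal API (assuming an active SparkSession `spark` and a hypothetical sink directory):

```scala
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.execution.streaming.MetadataLogFileIndex
import org.apache.spark.sql.types.{StringType, StructType}

// Passing Some(schema) pins the partition columns' types; passing None keeps
// the old behavior of inferring them from the partition directory values.
val pinned = new MetadataLogFileIndex(
  spark,
  new Path("/tmp/sink-output"),
  Some(new StructType().add("time", StringType)))
pinned.partitionSchema  // reports `time` as StringType
```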
ParquetPartitionDiscoverySuite.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.{PartitionPath => Partition}
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
@@ -1022,4 +1023,36 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext
      }
    }
  }

test("SPARK-21463: MetadataLogFileIndex should respect userSpecifiedSchema for partition cols") {
withTempDir { tempDir =>
val output = new File(tempDir, "output").toString
val checkpoint = new File(tempDir, "chkpoint").toString
try {
val stream = MemoryStream[(String, Int)]
val df = stream.toDS().toDF("time", "value")
val sq = df.writeStream
.option("checkpointLocation", checkpoint)
.format("parquet")
.partitionBy("time")
.start(output)

stream.addData(("2017-01-01-00", 1), ("2017-01-01-01", 2))
sq.processAllAvailable()

val schema = new StructType()
.add("time", StringType)
.add("value", IntegerType)
val readBack = spark.read.schema(schema).parquet(output)
assert(readBack.schema.toSet === schema.toSet)

checkAnswer(
readBack,
Seq(Row("2017-01-01-00", 1), Row("2017-01-01-01", 2))
)
} finally {
spark.streams.active.foreach(_.stop())
}
}
}
}