FileTable.scala
@@ -20,11 +20,12 @@ import java.util

import scala.collection.JavaConverters._

import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.fs.{FileStatus, Path}

import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalog.v2.expressions.Transform
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.streaming.{FileStreamSink, MetadataLogFileIndex}
import org.apache.spark.sql.sources.v2.{SupportsRead, SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.sources.v2.TableCapability._
import org.apache.spark.sql.types.{DataType, StructType}
@@ -44,23 +45,37 @@ abstract class FileTable(
val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
// Hadoop Configurations are case sensitive.
val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
val rootPathsSpecified = DataSource.checkAndGlobPathIfNecessary(paths, hadoopConf,
checkEmptyGlobPath = true, checkFilesExist = true)
val fileStatusCache = FileStatusCache.getOrCreate(sparkSession)
new InMemoryFileIndex(
sparkSession, rootPathsSpecified, caseSensitiveMap, userSpecifiedSchema, fileStatusCache)
if (FileStreamSink.hasMetadata(paths, hadoopConf, sparkSession.sessionState.conf)) {
// We are reading from the results of a streaming query. We will load files from
// the metadata log instead of listing them using HDFS APIs.
new MetadataLogFileIndex(sparkSession, new Path(paths.head),
options.asScala.toMap, userSpecifiedSchema)
} else {
// This is a non-streaming file based datasource.
val rootPathsSpecified = DataSource.checkAndGlobPathIfNecessary(paths, hadoopConf,
checkEmptyGlobPath = true, checkFilesExist = true)
val fileStatusCache = FileStatusCache.getOrCreate(sparkSession)
new InMemoryFileIndex(
sparkSession, rootPathsSpecified, caseSensitiveMap, userSpecifiedSchema, fileStatusCache)
}
}
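
For orientation (an editorial sketch, not part of the patch): with this branch in place, a plain batch read over a streaming file sink's output directory goes through the metadata log when it resolves to the file source V2 path. A minimal standalone illustration follows; the rate source, local session, and paths are assumptions made for the sketch.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("metadata-log-read").getOrCreate()

// Stream a few rows into a parquet sink; the sink keeps a _spark_metadata log
// alongside the data files it commits.
val query = spark.readStream
  .format("rate")
  .option("rowsPerSecond", "10")
  .load()
  .writeStream
  .format("parquet")
  .option("checkpointLocation", "/tmp/rate-checkpoint")  // hypothetical path
  .start("/tmp/rate-out")                                // hypothetical path

Thread.sleep(3000)
query.stop()

// With this change, FileTable detects /tmp/rate-out/_spark_metadata and builds a
// MetadataLogFileIndex instead of an InMemoryFileIndex, so the batch read below
// sees exactly the files committed by the streaming query.
spark.read.parquet("/tmp/rate-out").show()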

lazy val dataSchema: StructType = userSpecifiedSchema.map { schema =>
val partitionSchema = fileIndex.partitionSchema
val resolver = sparkSession.sessionState.conf.resolver
StructType(schema.filterNot(f => partitionSchema.exists(p => resolver(p.name, f.name))))
}.orElse {
inferSchema(fileIndex.allFiles())
}.getOrElse {
throw new AnalysisException(
s"Unable to infer schema for $formatName. It must be specified manually.")
}.asNullable
lazy val dataSchema: StructType = {
val schema = userSpecifiedSchema.map { schema =>
val partitionSchema = fileIndex.partitionSchema
val resolver = sparkSession.sessionState.conf.resolver
StructType(schema.filterNot(f => partitionSchema.exists(p => resolver(p.name, f.name))))
}.orElse {
inferSchema(fileIndex.allFiles())
}.getOrElse {
throw new AnalysisException(
s"Unable to infer schema for $formatName. It must be specified manually.")
}
fileIndex match {
case _: MetadataLogFileIndex => schema
case _ => schema.asNullable
}
}
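
A note on the asNullable special case (an editorial illustration, not taken from the patch): for ordinary file-based reads Spark forces every column to nullable, while the metadata log records the schema the streaming writer actually used, and this change preserves it for the MetadataLogFileIndex case. A small sketch with hypothetical paths:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("nullability-sketch").getOrCreate()
import spark.implicits._

// A plain (non-streaming) parquet round trip: the batch read reports every column
// as nullable, even though `id` was written from a non-nullable Int column.
Seq(1, 2, 3).toDF("id").write.mode("overwrite").parquet("/tmp/plain-out")  // hypothetical path
assert(spark.read.parquet("/tmp/plain-out").schema.forall(_.nullable))

// A read that resolves to a MetadataLogFileIndex (the output of a streaming file
// sink) keeps the writer's schema as recorded in the log instead, which is why the
// tests below expect `value` to remain nullable = false.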

override lazy val schema: StructType = {
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
FileStreamSinkSuite.scala
@@ -25,17 +25,19 @@ import scala.collection.JavaConverters._

import org.apache.hadoop.fs.Path

import org.apache.spark.SparkConf
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.{AnalysisException, DataFrame}
import org.apache.spark.sql.execution.DataSourceScanExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation, FileScan, FileTable}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.util.Utils

class FileStreamSinkSuite extends StreamTest {
abstract class FileStreamSinkSuite extends StreamTest {
import testImplicits._

override def beforeAll(): Unit = {
@@ -112,91 +114,6 @@ class FileStreamSinkSuite extends StreamTest {
}
}

test("partitioned writing and batch reading") {
val inputData = MemoryStream[Int]
val ds = inputData.toDS()

val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath
val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath

var query: StreamingQuery = null

// TODO: test file source V2 as well.
withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> "parquet") {
try {
query =
ds.map(i => (i, i * 1000))
.toDF("id", "value")
.writeStream
.partitionBy("id")
.option("checkpointLocation", checkpointDir)
.format("parquet")
.start(outputDir)

inputData.addData(1, 2, 3)
failAfter(streamingTimeout) {
query.processAllAvailable()
}

val outputDf = spark.read.parquet(outputDir)
val expectedSchema = new StructType()
.add(StructField("value", IntegerType, nullable = false))
.add(StructField("id", IntegerType))
assert(outputDf.schema === expectedSchema)

// Verify that MetadataLogFileIndex is being used and the correct partitioning schema has
// been inferred
val hadoopdFsRelations = outputDf.queryExecution.analyzed.collect {
case LogicalRelation(baseRelation: HadoopFsRelation, _, _, _) => baseRelation
}
assert(hadoopdFsRelations.size === 1)
assert(hadoopdFsRelations.head.location.isInstanceOf[MetadataLogFileIndex])
assert(hadoopdFsRelations.head.partitionSchema.exists(_.name == "id"))
assert(hadoopdFsRelations.head.dataSchema.exists(_.name == "value"))

// Verify the data is correctly read
checkDatasetUnorderly(
outputDf.as[(Int, Int)],
(1000, 1), (2000, 2), (3000, 3))

/** Check some condition on the partitions of the FileScanRDD generated by a DF */
def checkFileScanPartitions(df: DataFrame)(func: Seq[FilePartition] => Unit): Unit = {
val getFileScanRDD = df.queryExecution.executedPlan.collect {
case scan: DataSourceScanExec if scan.inputRDDs().head.isInstanceOf[FileScanRDD] =>
scan.inputRDDs().head.asInstanceOf[FileScanRDD]
}.headOption.getOrElse {
fail(s"No FileScan in query\n${df.queryExecution}")
}
func(getFileScanRDD.filePartitions)
}

// Read without pruning
checkFileScanPartitions(outputDf) { partitions =>
// There should be as many distinct partition values as there are distinct ids
assert(partitions.flatMap(_.files.map(_.partitionValues)).distinct.size === 3)
}

// Read with pruning, should read only files in partition dir id=1
checkFileScanPartitions(outputDf.filter("id = 1")) { partitions =>
val filesToBeRead = partitions.flatMap(_.files)
assert(filesToBeRead.map(_.filePath).forall(_.contains("/id=1/")))
assert(filesToBeRead.map(_.partitionValues).distinct.size === 1)
}

// Read with pruning, should read only files in partition dir id=1 and id=2
checkFileScanPartitions(outputDf.filter("id in (1,2)")) { partitions =>
val filesToBeRead = partitions.flatMap(_.files)
assert(!filesToBeRead.map(_.filePath).exists(_.contains("/id=3/")))
assert(filesToBeRead.map(_.partitionValues).distinct.size === 2)
}
} finally {
if (query != null) {
query.stop()
}
}
}
}

test("partitioned writing and batch reading with 'basePath'") {
withTempDir { outputDir =>
withTempDir { checkpointDir =>
@@ -512,3 +429,174 @@ class FileStreamSinkSuite extends StreamTest {
}
}
}

class FileStreamSinkV1Suite extends FileStreamSinkSuite {
import testImplicits._

override protected def sparkConf: SparkConf =
super
.sparkConf
.set(SQLConf.USE_V1_SOURCE_READER_LIST, "csv,json,orc,text,parquet")
.set(SQLConf.USE_V1_SOURCE_WRITER_LIST, "csv,json,orc,text,parquet")

test("partitioned writing and batch reading") {
Contributor:

Can you highlight the difference between this test case and the one in FileStreamSinkV2Suite?

Member Author:

For the V1 suite, it uses Parquet V1, and it matches HadoopFsRelation / FileScanRDD to check that the plan is as expected. Partition pruning is also tested.

For the V2 suite, it uses Parquet V2, and it matches FileTable / BatchScanExec to check that the plan is as expected. As partition pruning is not supported in V2 yet, we can't test it.

I can abstract the code if the two cases look duplicated to you.

Member Author:

Or, do you mean changing the test case name?

Contributor:

let's abstract the code, this test case is too long to duplicate.

val inputData = MemoryStream[Int]
val ds = inputData.toDS()

val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath
val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath

var query: StreamingQuery = null

try {
query =
ds.map(i => (i, i * 1000))
.toDF("id", "value")
.writeStream
.partitionBy("id")
.option("checkpointLocation", checkpointDir)
.format("parquet")
.start(outputDir)

inputData.addData(1, 2, 3)
failAfter(streamingTimeout) {
query.processAllAvailable()
}

val outputDf = spark.read.parquet(outputDir)
val expectedSchema = new StructType()
.add(StructField("value", IntegerType, nullable = false))
.add(StructField("id", IntegerType))
assert(outputDf.schema === expectedSchema)

// Verify that MetadataLogFileIndex is being used and the correct partitioning schema has
// been inferred
val hadoopdFsRelations = outputDf.queryExecution.analyzed.collect {
case LogicalRelation(baseRelation: HadoopFsRelation, _, _, _) => baseRelation
}
assert(hadoopdFsRelations.size === 1)
assert(hadoopdFsRelations.head.location.isInstanceOf[MetadataLogFileIndex])
assert(hadoopdFsRelations.head.partitionSchema.exists(_.name == "id"))
assert(hadoopdFsRelations.head.dataSchema.exists(_.name == "value"))

// Verify the data is correctly read
checkDatasetUnorderly(
outputDf.as[(Int, Int)],
(1000, 1), (2000, 2), (3000, 3))

/** Check some condition on the partitions of the FileScanRDD generated by a DF */
def checkFileScanPartitions(df: DataFrame)(func: Seq[FilePartition] => Unit): Unit = {
val getFileScanRDD = df.queryExecution.executedPlan.collect {
case scan: DataSourceScanExec if scan.inputRDDs().head.isInstanceOf[FileScanRDD] =>
scan.inputRDDs().head.asInstanceOf[FileScanRDD]
}.headOption.getOrElse {
fail(s"No FileScan in query\n${df.queryExecution}")
}
func(getFileScanRDD.filePartitions)
}

// Read without pruning
checkFileScanPartitions(outputDf) { partitions =>
// There should be as many distinct partition values as there are distinct ids
assert(partitions.flatMap(_.files.map(_.partitionValues)).distinct.size === 3)
}

// Read with pruning, should read only files in partition dir id=1
checkFileScanPartitions(outputDf.filter("id = 1")) { partitions =>
val filesToBeRead = partitions.flatMap(_.files)
assert(filesToBeRead.map(_.filePath).forall(_.contains("/id=1/")))
assert(filesToBeRead.map(_.partitionValues).distinct.size === 1)
}

// Read with pruning, should read only files in partition dir id=1 and id=2
checkFileScanPartitions(outputDf.filter("id in (1,2)")) { partitions =>
val filesToBeRead = partitions.flatMap(_.files)
assert(!filesToBeRead.map(_.filePath).exists(_.contains("/id=3/")))
assert(filesToBeRead.map(_.partitionValues).distinct.size === 2)
}
} finally {
if (query != null) {
query.stop()
}
}
}
}

class FileStreamSinkV2Suite extends FileStreamSinkSuite {
import testImplicits._

override protected def sparkConf: SparkConf =
super
.sparkConf
.set(SQLConf.USE_V1_SOURCE_READER_LIST, "")
.set(SQLConf.USE_V1_SOURCE_WRITER_LIST, "")

test("partitioned writing and batch reading") {
val inputData = MemoryStream[Int]
val ds = inputData.toDS()

val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath
val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath

var query: StreamingQuery = null

try {
query =
ds.map(i => (i, i * 1000))
.toDF("id", "value")
.writeStream
.partitionBy("id")
.option("checkpointLocation", checkpointDir)
.format("parquet")
.start(outputDir)

inputData.addData(1, 2, 3)
failAfter(streamingTimeout) {
query.processAllAvailable()
}

val outputDf = spark.read.parquet(outputDir)
val expectedSchema = new StructType()
.add(StructField("value", IntegerType, nullable = false))
.add(StructField("id", IntegerType))
assert(outputDf.schema === expectedSchema)

// Verify that MetadataLogFileIndex is being used and the correct partitioning schema has
// been inferred
val table = outputDf.queryExecution.analyzed.collect {
case DataSourceV2Relation(table: FileTable, _, _) => table
}
assert(table.size === 1)
assert(table.head.fileIndex.isInstanceOf[MetadataLogFileIndex])
assert(table.head.fileIndex.partitionSchema.exists(_.name == "id"))
assert(table.head.dataSchema.exists(_.name == "value"))

// Verify the data is correctly read
checkDatasetUnorderly(
outputDf.as[(Int, Int)],
(1000, 1), (2000, 2), (3000, 3))

/** Check some condition on the partitions of the FileScanRDD generated by a DF */
def checkFileScanPartitions(df: DataFrame)(func: Seq[FilePartition] => Unit): Unit = {
val fileScan = df.queryExecution.executedPlan.collect {
case batch: BatchScanExec if batch.scan.isInstanceOf[FileScan] =>
batch.scan.asInstanceOf[FileScan]
}.headOption.getOrElse {
fail(s"No FileScan in query\n${df.queryExecution}")
}
func(fileScan.planInputPartitions().map(_.asInstanceOf[FilePartition]))
}

// Read without pruning
checkFileScanPartitions(outputDf) { partitions =>
// There should be as many distinct partition values as there are distinct ids
assert(partitions.flatMap(_.files.map(_.partitionValues)).distinct.size === 3)
}
// TODO: test partition pruning when file source V2 supports it.
} finally {
if (query != null) {
query.stop()
}
}
}
}
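
Following up on the review thread above ("let's abstract the code, this test case is too long to duplicate"), one possible refactoring is sketched below: the shared write/read logic stays in the abstract suite, and each concrete suite overrides only a plan-checking hook. This is an editorial sketch; the hook name checkQueryExecution is hypothetical and not part of this change.

abstract class FileStreamSinkSuite extends StreamTest {
  import testImplicits._

  // Implemented by the V1/V2 suites: verify the plan (HadoopFsRelation + FileScanRDD
  // for V1, FileTable + BatchScanExec for V2) and any pruning behaviour it supports.
  protected def checkQueryExecution(df: DataFrame): Unit

  test("partitioned writing and batch reading") {
    val inputData = MemoryStream[Int]
    val ds = inputData.toDS()

    val outputDir = Utils.createTempDir(namePrefix = "stream.output").getCanonicalPath
    val checkpointDir = Utils.createTempDir(namePrefix = "stream.checkpoint").getCanonicalPath

    var query: StreamingQuery = null
    try {
      query = ds.map(i => (i, i * 1000))
        .toDF("id", "value")
        .writeStream
        .partitionBy("id")
        .option("checkpointLocation", checkpointDir)
        .format("parquet")
        .start(outputDir)

      inputData.addData(1, 2, 3)
      failAfter(streamingTimeout) {
        query.processAllAvailable()
      }

      val outputDf = spark.read.parquet(outputDir)
      val expectedSchema = new StructType()
        .add(StructField("value", IntegerType, nullable = false))
        .add(StructField("id", IntegerType))
      assert(outputDf.schema === expectedSchema)

      // Verify the data is correctly read.
      checkDatasetUnorderly(
        outputDf.as[(Int, Int)],
        (1000, 1), (2000, 2), (3000, 3))

      // Source-version-specific plan checks (and pruning, where supported).
      checkQueryExecution(outputDf)
    } finally {
      if (query != null) {
        query.stop()
      }
    }
  }
}

Each concrete suite would then keep only its sparkConf override and its checkQueryExecution implementation.
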
StreamSuite.scala
@@ -218,11 +218,12 @@ class StreamSuite extends StreamTest {
}
}

// TODO: fix file source V2 as well.
withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> "parquet") {
val df = spark.readStream.format(classOf[FakeDefaultSource].getName).load()
assertDF(df)
assertDF(df)
val df = spark.readStream.format(classOf[FakeDefaultSource].getName).load()
Seq("", "parquet").foreach { useV1SourceReader =>
withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> useV1SourceReader) {
assertDF(df)
assertDF(df)
}
}
}
