diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala
index 905bce4d614e..35ee685a52a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLog.scala
@@ -193,7 +193,7 @@ abstract class CompactibleFileStreamLog[T <: AnyRef : ClassTag](
    * Returns all files except the deleted ones.
    */
   def allFiles(): Array[T] = {
-    var latestId = getLatest().map(_._1).getOrElse(-1L)
+    var latestId = getLatestBatchId().getOrElse(-1L)
     // There is a race condition when `FileStreamSink` is deleting old files and `StreamFileIndex`
     // is calling this method. This loop will retry the reading to deal with the
     // race condition.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
index b679f163fc56..32245470d8f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
@@ -142,7 +142,7 @@ class FileStreamSink(
   }
 
   override def addBatch(batchId: Long, data: DataFrame): Unit = {
-    if (batchId <= fileLog.getLatest().map(_._1).getOrElse(-1L)) {
+    if (batchId <= fileLog.getLatestBatchId().getOrElse(-1L)) {
       logInfo(s"Skipping already committed batch $batchId")
     } else {
       val committer = FileCommitProtocol.instantiate(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala
index 7b2ea9627a98..c43887774c13 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSourceLog.scala
@@ -96,7 +96,7 @@ class FileStreamSourceLog(
     val searchKeys = removedBatches.map(_._1)
     val retrievedBatches = if (searchKeys.nonEmpty) {
       logWarning(s"Get batches from removed files, this is unexpected in the current code path!!!")
-      val latestBatchId = getLatest().map(_._1).getOrElse(-1L)
+      val latestBatchId = getLatestBatchId().getOrElse(-1L)
       if (latestBatchId < 0) {
         Map.empty[Long, Option[Array[FileEntry]]]
       } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
index ed0c44da08c5..5c86f8a50dda 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala
@@ -182,17 +182,26 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path:
     }
   }
 
-  override def getLatest(): Option[(Long, T)] = {
-    val batchIds = fileManager.list(metadataPath, batchFilesFilter)
+  /**
+   * Return the latest batch Id without reading the file. This method only checks for the
+   * existence of the file to avoid the cost of reading and deserializing the log file.
+   */
+  def getLatestBatchId(): Option[Long] = {
+    fileManager.list(metadataPath, batchFilesFilter)
       .map(f => pathToBatchId(f.getPath))
       .sorted(Ordering.Long.reverse)
-    for (batchId <- batchIds) {
-      val batch = get(batchId)
-      if (batch.isDefined) {
-        return Some((batchId, batch.get))
+      .headOption
+  }
+
+  override def getLatest(): Option[(Long, T)] = {
+    getLatestBatchId().map { batchId =>
+      val content = get(batchId).getOrElse {
+        // If we find the last batch file, we must read that file, rather than falling back to
+        // old batches.
+        throw new IllegalStateException(s"failed to read log file for batch $batchId")
       }
+      (batchId, content)
     }
-    None
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala
index f95daafdfe19..6d615b5ef044 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala
@@ -18,7 +18,15 @@
 package org.apache.spark.sql.execution.streaming
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
+import java.lang.{Long => JLong}
+import java.net.URI
 import java.nio.charset.StandardCharsets.UTF_8
+import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.atomic.AtomicLong
+
+import scala.util.Random
+
+import org.apache.hadoop.fs.{FSDataInputStream, Path, RawLocalFileSystem}
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.internal.SQLConf
@@ -240,6 +248,44 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSparkSession {
     ))
   }
 
+  test("getLatestBatchId") {
+    withCountOpenLocalFileSystemAsLocalFileSystem {
+      val scheme = CountOpenLocalFileSystem.scheme
+      withSQLConf(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> "3") {
+        withTempDir { dir =>
+          val sinkLog = new FileStreamSinkLog(FileStreamSinkLog.VERSION, spark,
+            s"$scheme:///${dir.getCanonicalPath}")
+          for (batchId <- 0L to 2L) {
+            sinkLog.add(
+              batchId,
+              Array(newFakeSinkFileStatus("/a/b/" + batchId, FileStreamSinkLog.ADD_ACTION)))
+          }
+
+          def getCountForOpenOnMetadataFile(batchId: Long): Long = {
+            val path = sinkLog.batchIdToPath(batchId).toUri.getPath
+            CountOpenLocalFileSystem.pathToNumOpenCalled.getOrDefault(path, 0L)
+          }
+
+          CountOpenLocalFileSystem.resetCount()
+
+          assert(sinkLog.getLatestBatchId() === Some(2L))
+          // getLatestBatchId doesn't open the latest metadata log file
+          (0L to 2L).foreach { batchId =>
+            assert(getCountForOpenOnMetadataFile(batchId) === 0L)
+          }
+
+          assert(sinkLog.getLatest().map(_._1).getOrElse(-1L) === 2L)
+          (0L to 1L).foreach { batchId =>
+            assert(getCountForOpenOnMetadataFile(batchId) === 0L)
+          }
+          // getLatest opens the latest metadata log file, which explains the need for
+          // having "getLatestBatchId".
+          assert(getCountForOpenOnMetadataFile(2L) === 1L)
+        }
+      }
+    }
+  }
+
   /**
    * Create a fake SinkFileStatus using path and action. Most of tests don't care about other
    * fields in SinkFileStatus.
@@ -267,4 +313,41 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSparkSession {
     val log = new FileStreamSinkLog(FileStreamSinkLog.VERSION, spark, input.toString)
     log.allFiles()
   }
+
+  private def withCountOpenLocalFileSystemAsLocalFileSystem(body: => Unit): Unit = {
+    val optionKey = s"fs.${CountOpenLocalFileSystem.scheme}.impl"
+    val originClassForLocalFileSystem = spark.conf.getOption(optionKey)
+    try {
+      spark.conf.set(optionKey, classOf[CountOpenLocalFileSystem].getName)
+      body
+    } finally {
+      originClassForLocalFileSystem match {
+        case Some(fsClazz) => spark.conf.set(optionKey, fsClazz)
+        case _ => spark.conf.unset(optionKey)
+      }
+    }
+  }
+}
+
+class CountOpenLocalFileSystem extends RawLocalFileSystem {
+  import CountOpenLocalFileSystem._
+
+  override def getUri: URI = {
+    URI.create(s"$scheme:///")
+  }
+
+  override def open(f: Path, bufferSize: Int): FSDataInputStream = {
+    val path = f.toUri.getPath
+    pathToNumOpenCalled.compute(path, (_, v) => {
+      if (v == null) 1L else v + 1
+    })
+    super.open(f, bufferSize)
+  }
+}
+
+object CountOpenLocalFileSystem {
+  val scheme = s"FileStreamSinkLogSuite${math.abs(Random.nextInt)}fs"
+  val pathToNumOpenCalled = new ConcurrentHashMap[String, JLong]
+
+  def resetCount(): Unit = pathToNumOpenCalled.clear()
 }
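
Note (not part of the patch): the sketch below is a minimal, self-contained illustration of why the cheap getLatestBatchId() is sufficient for the sink's "already committed" check. SimpleBatchLog, alreadyCommitted and GetLatestBatchIdSketch are hypothetical names invented for this illustration; only the getLatestBatchId()/getLatest() contract mirrors the code changed above.

// Illustrative sketch only (not Spark code): SimpleBatchLog is a hypothetical in-memory
// stand-in for HDFSMetadataLog, used to show the getLatestBatchId()/getLatest() split.
object GetLatestBatchIdSketch {

  final class SimpleBatchLog[T](entries: Map[Long, T]) {
    // Cheap: inspects only the known batch ids, never the batch contents
    // (in the real log this is a file listing, not a file read).
    def getLatestBatchId(): Option[Long] = entries.keys.reduceOption(_ max _)

    // Expensive in the real log: must read and deserialize the latest batch file.
    def getLatest(): Option[(Long, T)] = getLatestBatchId().map(id => (id, entries(id)))
  }

  // Mirrors the shape of the FileStreamSink.addBatch check in the patch: the
  // "already committed?" decision only needs the id, not the deserialized metadata.
  def alreadyCommitted[T](log: SimpleBatchLog[T], batchId: Long): Boolean =
    batchId <= log.getLatestBatchId().getOrElse(-1L)

  def main(args: Array[String]): Unit = {
    val log = new SimpleBatchLog(Map(0L -> "batch-0", 1L -> "batch-1", 2L -> "batch-2"))
    assert(alreadyCommitted(log, 2L))   // batch 2 is already in the log, so it would be skipped
    assert(!alreadyCommitted(log, 3L))  // batch 3 has not been committed yet
    assert(log.getLatest().map(_._1).contains(2L))
  }
}

The point mirrored from the patch is that the id comparison never needs the batch contents, so the latest metadata log file is only opened when getLatest() itself is required, which is exactly what the new test asserts via CountOpenLocalFileSystem.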