diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index d45dc937910d..99b4e894bf0a 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -162,9 +162,9 @@ private[spark] object ThreadUtils {
   /**
    * Wrapper over newSingleThreadExecutor.
    */
-  def newDaemonSingleThreadExecutor(threadName: String): ExecutorService = {
+  def newDaemonSingleThreadExecutor(threadName: String): ThreadPoolExecutor = {
     val threadFactory = new ThreadFactoryBuilder().setDaemon(true).setNameFormat(threadName).build()
-    Executors.newSingleThreadExecutor(threadFactory)
+    Executors.newFixedThreadPool(1, threadFactory).asInstanceOf[ThreadPoolExecutor]
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 72eb420de374..ebff9ce546d0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1982,6 +1982,15 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val ASYNC_LOG_PURGE =
+    buildConf("spark.sql.streaming.asyncLogPurge.enabled")
+      .internal()
+      .doc("When true, purging the offset log and " +
+        "commit log of old entries will be done asynchronously.")
+      .version("3.4.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val VARIABLE_SUBSTITUTE_ENABLED =
     buildConf("spark.sql.variable.substitute")
       .doc("This enables substitution using syntax like `${var}`, `${system:var}`, " +
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala
new file mode 100644
index 000000000000..b3729dbc7b45
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/AsyncLogPurge.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent.atomic.AtomicBoolean
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * Used to enable the capability to allow log purges to be done asynchronously
+ */
+trait AsyncLogPurge extends Logging {
+
+  protected var currentBatchId: Long
+
+  protected val minLogEntriesToMaintain: Int
+
+
+  protected[sql] val errorNotifier: ErrorNotifier
+
+  protected val sparkSession: SparkSession
+
+  private val asyncPurgeExecutorService
+    = ThreadUtils.newDaemonSingleThreadExecutor("async-log-purge")
+
+  private val purgeRunning = new AtomicBoolean(false)
+
+  protected def purge(threshold: Long): Unit
+
+  protected lazy val useAsyncPurge: Boolean = sparkSession.conf.get(SQLConf.ASYNC_LOG_PURGE)
+
+  protected def purgeAsync(): Unit = {
+    if (purgeRunning.compareAndSet(false, true)) {
+      // save local copy because currentBatchId may get updated. There are not really
+      // any concurrency issues here in regards to calculating the purge threshold
+      // but for the sake of defensive coding lets make a copy
+      val currentBatchIdCopy: Long = currentBatchId
+      asyncPurgeExecutorService.execute(() => {
+        try {
+          purge(currentBatchIdCopy - minLogEntriesToMaintain)
+        } catch {
+          case throwable: Throwable =>
+            logError("Encountered error while performing async log purge", throwable)
+            errorNotifier.markError(throwable)
+        } finally {
+          purgeRunning.set(false)
+        }
+      })
+    } else {
+      log.debug("Skipped log purging since there is already one in progress.")
+    }
+  }
+
+  protected def asyncLogPurgeShutdown(): Unit = {
+    ThreadUtils.shutdown(asyncPurgeExecutorService)
+  }
+
+  // used for testing
+  private[sql] def arePendingAsyncPurge: Boolean = {
+    purgeRunning.get() ||
+      asyncPurgeExecutorService.getQueue.size() > 0 ||
+      asyncPurgeExecutorService.getActiveCount > 0
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ErrorNotifier.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ErrorNotifier.scala
new file mode 100644
index 000000000000..0f25d0667a0e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ErrorNotifier.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming
+
+import java.util.concurrent.atomic.AtomicReference
+
+import org.apache.spark.internal.Logging
+
+/**
+ * Class to notify of any errors that might have occurred out of band
+ */
+class ErrorNotifier extends Logging {
+
+  private val error = new AtomicReference[Throwable]
+
+  /** To indicate any errors that have occurred */
+  def markError(th: Throwable): Unit = {
+    logError("A fatal error has occurred.", th)
+    error.set(th)
+  }
+
+  /** Get any errors that have occurred */
+  def getError(): Option[Throwable] = {
+    Option(error.get())
+  }
+
+  /** Throw errors that have occurred */
+  def throwErrorIfExists(): Unit = {
+    getError().foreach({th => throw th})
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index 153bc82f8928..5f8fb93827b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -46,7 +46,9 @@ class MicroBatchExecution(
     plan: WriteToStream)
   extends StreamExecution(
     sparkSession, plan.name, plan.resolvedCheckpointLocation, plan.inputQuery, plan.sink, trigger,
-    triggerClock, plan.outputMode, plan.deleteCheckpointOnStop) {
+    triggerClock, plan.outputMode, plan.deleteCheckpointOnStop) with AsyncLogPurge {
+
+  protected[sql] val errorNotifier = new ErrorNotifier()
 
   @volatile protected var sources: Seq[SparkDataStream] = Seq.empty
 
@@ -210,6 +212,14 @@ class MicroBatchExecution(
     logInfo(s"Query $prettyIdString was stopped")
   }
 
+  override def cleanup(): Unit = {
+    super.cleanup()
+
+    // shutdown and cleanup required for async log purge mechanism
+    asyncLogPurgeShutdown()
+    logInfo(s"Async log purge executor pool for query ${prettyIdString} has been shutdown")
+  }
+
   /** Begins recording statistics about query progress for a given trigger. */
   override protected def startTrigger(): Unit = {
     super.startTrigger()
@@ -226,6 +236,10 @@ class MicroBatchExecution(
 
     triggerExecutor.execute(() => {
       if (isActive) {
+
+        // check if there are any previous errors and bubble up any existing async operations
+        errorNotifier.throwErrorIfExists
+
         var currentBatchHasNewData = false // Whether the current batch had new data
 
         startTrigger()
@@ -536,7 +550,11 @@ class MicroBatchExecution(
       // It is now safe to discard the metadata beyond the minimum number to retain.
       // Note that purge is exclusive, i.e. it purges everything before the target ID.
       if (minLogEntriesToMaintain < currentBatchId) {
-        purge(currentBatchId - minLogEntriesToMaintain)
+        if (useAsyncPurge) {
+          purgeAsync()
+        } else {
+          purge(currentBatchId - minLogEntriesToMaintain)
+        }
       }
     }
     noNewData = false
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index eeaa37aa7ffb..5afd744f5e9b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -347,6 +347,7 @@ abstract class StreamExecution(
 
       try {
         stopSources()
+        cleanup()
         state.set(TERMINATED)
         currentStatus = status.copy(isTriggerActive = false, isDataAvailable = false)
 
@@ -410,6 +411,12 @@ abstract class StreamExecution(
     }
   }
 
+
+  /**
+   * Any clean up that needs to happen when the query is stopped or exits
+   */
+  protected def cleanup(): Unit = {}
+
   /**
    * Interrupts the query execution thread and awaits its termination until until it exceeds the
    * timeout. The timeout can be set on "spark.sql.streaming.stopTimeout".
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala
index 749ca9d06eaf..0ddd48420ef3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala
@@ -21,17 +21,20 @@ import java.io.File
 
 import org.apache.commons.io.FileUtils
 import org.scalatest.BeforeAndAfter
+import org.scalatest.matchers.should._
+import org.scalatest.time.{Seconds, Span}
 
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.plans.logical.Range
 import org.apache.spark.sql.connector.read.streaming
 import org.apache.spark.sql.connector.read.streaming.SparkDataStream
 import org.apache.spark.sql.functions.{count, timestamp_seconds, window}
-import org.apache.spark.sql.streaming.{StreamTest, Trigger}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest, Trigger}
 import org.apache.spark.sql.types.{LongType, StructType}
 import org.apache.spark.util.Utils
 
-class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter {
+class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter with Matchers {
 
   import testImplicits._
 
@@ -39,6 +42,84 @@ class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter {
     sqlContext.streams.active.foreach(_.stop())
   }
 
+  def getListOfFiles(dir: String): List[File] = {
+    val d = new File(dir)
+    if (d.exists && d.isDirectory) {
+      d.listFiles.filter(_.isFile).toList
+    } else {
+      List[File]()
+    }
+  }
+
+  test("async log purging") {
+    withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "true") {
+      withTempDir { checkpointLocation =>
+        val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+        val ds = inputData.toDS()
+        testStream(ds)(
+          StartStream(checkpointLocation = checkpointLocation.getCanonicalPath),
+          AddData(inputData, 0),
+          CheckNewAnswer(0),
+          AddData(inputData, 1),
+          CheckNewAnswer(1),
+          Execute { q =>
+            getListOfFiles(checkpointLocation + "/offsets")
+              .filter(file => !file.isHidden)
+              .map(file => file.getName.toInt)
+              .sorted should equal(Array(0, 1))
+            getListOfFiles(checkpointLocation + "/commits")
+              .filter(file => !file.isHidden)
+              .map(file => file.getName.toInt)
+              .sorted should equal(Array(0, 1))
+          },
+          AddData(inputData, 2),
+          CheckNewAnswer(2),
+          AddData(inputData, 3),
+          CheckNewAnswer(3),
+          Execute { q =>
+            eventually(timeout(Span(5, Seconds))) {
+              q.asInstanceOf[MicroBatchExecution].arePendingAsyncPurge should be(false)
+            }
+
+            getListOfFiles(checkpointLocation + "/offsets")
+              .filter(file => !file.isHidden)
+              .map(file => file.getName.toInt)
+              .sorted should equal(Array(1, 2, 3))
+            getListOfFiles(checkpointLocation + "/commits")
+              .filter(file => !file.isHidden)
+              .map(file => file.getName.toInt)
+              .sorted should equal(Array(1, 2, 3))
+          },
+          StopStream
+        )
+      }
+    }
+  }
+
+  test("error notifier test") {
+    withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "true") {
+      withTempDir { checkpointLocation =>
+        val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
+        val ds = inputData.toDS()
+        val e = intercept[StreamingQueryException] {
+
+          testStream(ds)(
+            StartStream(checkpointLocation = checkpointLocation.getCanonicalPath),
+            AddData(inputData, 0),
+            CheckNewAnswer(0),
+            AddData(inputData, 1),
+            CheckNewAnswer(1),
+            Execute { q =>
+              q.asInstanceOf[MicroBatchExecution].errorNotifier.markError(new Exception("test"))
+            },
+            AddData(inputData, 2),
+            CheckNewAnswer(2))
+        }
+        e.getCause.getMessage should include("test")
+      }
+    }
+  }
+
   test("SPARK-24156: do not plan a no-data batch again after it has already been planned") {
     val inputData = MemoryStream[Int]
     val df = inputData.toDF()