diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index b1cafd67820c..2cac86599ef1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -511,6 +511,8 @@ class MicroBatchExecution( sparkSessionToRunBatch.sparkContext.setLocalProperty( MicroBatchExecution.BATCH_ID_KEY, currentBatchId.toString) + sparkSessionToRunBatch.sparkContext.setLocalProperty( + StreamExecution.IS_CONTINUOUS_PROCESSING, false.toString) reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index a39bb715c991..f6c60c1c9212 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -529,6 +529,7 @@ abstract class StreamExecution( object StreamExecution { val QUERY_ID_KEY = "sql.streaming.queryId" + val IS_CONTINUOUS_PROCESSING = "__is_continuous_processing" def isInterruptionException(e: Throwable): Boolean = e match { // InterruptedIOException - thrown when an I/O operation is interrupted diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 4ddebb33b79d..ccca72667a21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -209,6 +209,8 @@ class ContinuousExecution( scan.readSupport.asInstanceOf[ContinuousReadSupport] -> scan.scanConfig }.head + sparkSessionForQuery.sparkContext.setLocalProperty( + StreamExecution.IS_CONTINUOUS_PROCESSING, true.toString) sparkSessionForQuery.sparkContext.setLocalProperty( ContinuousExecution.START_EPOCH_KEY, currentBatchId.toString) // Add another random ID on top of the run ID, to distinguish epoch coordinators across diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 3f11b8f79943..4a69a48fed75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -23,6 +23,7 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, TaskContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.execution.streaming.StreamExecution import org.apache.spark.sql.execution.streaming.continuous.EpochTracker import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.StructType @@ -74,9 +75,14 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( // If we're in continuous processing mode, we should get the store version for the current // epoch rather than the one at planning time. - val currentVersion = EpochTracker.getCurrentEpoch match { - case None => storeVersion - case Some(value) => value + val isContinuous = Option(ctxt.getLocalProperty(StreamExecution.IS_CONTINUOUS_PROCESSING)) + .map(_.toBoolean).getOrElse(false) + val currentVersion = if (isContinuous) { + val epoch = EpochTracker.getCurrentEpoch + assert(epoch.isDefined, "Current epoch must be defined for continuous processing streams.") + epoch.get + } else { + storeVersion } store = StateStore.get( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index bf509b1976ed..f55ddb5419d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -29,13 +29,14 @@ import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SparkConf, SparkContext, TaskContext} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.Range import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider} import org.apache.spark.sql.functions._ @@ -788,7 +789,7 @@ class StreamSuite extends StreamTest { val query = input .toDS() .map { i => - while (!org.apache.spark.TaskContext.get().isInterrupted()) { + while (!TaskContext.get().isInterrupted()) { // keep looping till interrupted by query.stop() Thread.sleep(100) } @@ -1029,6 +1030,34 @@ class StreamSuite extends StreamTest { false)) } + test("is_continuous_processing property should be false for microbatch processing") { + val input = MemoryStream[Int] + val df = input.toDS() + .map(i => TaskContext.get().getLocalProperty(StreamExecution.IS_CONTINUOUS_PROCESSING)) + testStream(df) ( + AddData(input, 1), + CheckAnswer("false") + ) + } + + test("is_continuous_processing property should be true for continuous processing") { + val input = ContinuousMemoryStream[Int] + val stream = input.toDS() + .map(i => TaskContext.get().getLocalProperty(StreamExecution.IS_CONTINUOUS_PROCESSING)) + .writeStream.format("memory") + .queryName("output") + .trigger(Trigger.Continuous("1 seconds")) + .start() + try { + input.addData(1) + stream.processAllAvailable() + } finally { + stream.stop() + } + + checkAnswer(spark.sql("select * from output"), Row("true")) + } + for (e <- Seq( new InterruptedException, new InterruptedIOException,