diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala
index aec756c0eb2a..14046f6a99c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala
@@ -110,8 +110,9 @@ class ContinuousCoalesceRDD(
               context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong)
             while (!context.isInterrupted() && !context.isCompleted()) {
               writer.write(prev.compute(prevSplit, context).asInstanceOf[Iterator[UnsafeRow]])
-              // Note that current epoch is a non-inheritable thread local, so each writer thread
-              // can properly increment its own epoch without affecting the main task thread.
+              // Note that the current epoch is an inheritable thread local, but childValue makes
+              // a fresh copy per child, so each writer thread can properly increment its own
+              // epoch without affecting the main task thread.
               EpochTracker.incrementCurrentEpoch()
             }
           }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala
index bc0ae428d452..631ae4806d2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochTracker.scala
@@ -26,8 +26,15 @@ import java.util.concurrent.atomic.AtomicLong
 object EpochTracker {
   // The current epoch. Note that this is a shared reference; ContinuousWriteRDD.compute() will
   // update the underlying AtomicLong as it finishes epochs. Other code should only read the value.
-  private val currentEpoch: ThreadLocal[AtomicLong] = new ThreadLocal[AtomicLong] {
-    override def initialValue() = new AtomicLong(-1)
+  private val currentEpoch: InheritableThreadLocal[AtomicLong] = {
+    new InheritableThreadLocal[AtomicLong] {
+      override protected def childValue(parent: AtomicLong): AtomicLong = {
+        // Note: make a new instance so that changes to the parent thread's epoch aren't
+        // reflected in its children threads. This is required by `ContinuousCoalesceRDD`.
+        new AtomicLong(parent.get)
+      }
+      override def initialValue() = new AtomicLong(-1)
+    }
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
index c6921010a002..5bd75c850fe7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
@@ -85,6 +85,7 @@ class ContinuousSuiteBase extends StreamTest {
 }
 
 class ContinuousSuite extends ContinuousSuiteBase {
+  import IntegratedUDFTestUtils._
   import testImplicits._
 
   test("basic") {
@@ -252,6 +253,26 @@ class ContinuousSuite extends ContinuousSuiteBase {
     assert(expected.map(Row(_)).subsetOf(results.toSet),
       s"Result set ${results.toSet} are not a superset of $expected!")
   }
+
+  Seq(TestScalaUDF("udf"), TestPythonUDF("udf"), TestScalarPandasUDF("udf")).foreach { udf =>
+    test(s"continuous mode with various UDFs - ${udf.prettyName}") {
+      assume(
+        shouldTestScalarPandasUDFs && udf.isInstanceOf[TestScalarPandasUDF] ||
+          shouldTestPythonUDFs && udf.isInstanceOf[TestPythonUDF] ||
+          udf.isInstanceOf[TestScalaUDF])
+
+      val input = ContinuousMemoryStream[Int]
+      val df = input.toDF()
+
+      testStream(df.select(udf(df("value")).cast("int")))(
+        AddData(input, 0, 1, 2),
+        CheckAnswer(0, 1, 2),
+        StopStream,
+        AddData(input, 3, 4, 5),
+        StartStream(),
+        CheckAnswer(0, 1, 2, 3, 4, 5))
+    }
+  }
 }
 
 class ContinuousStressSuite extends ContinuousSuiteBase {
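
A minimal standalone sketch of why the `childValue` override matters (illustrative only; `ChildValueDemo` and the epoch values are made up, not part of the patch): each child thread starts from the parent's current value but owns a separate AtomicLong, so increments in the child are invisible to the parent. This is the isolation the writer threads in `ContinuousCoalesceRDD` rely on.

import java.util.concurrent.atomic.AtomicLong

object ChildValueDemo {
  private val current = new InheritableThreadLocal[AtomicLong] {
    // The copy is made when the child Thread object is constructed, not when it starts.
    override protected def childValue(parent: AtomicLong): AtomicLong =
      new AtomicLong(parent.get)
    override def initialValue(): AtomicLong = new AtomicLong(-1)
  }

  def main(args: Array[String]): Unit = {
    current.get.set(5)  // the "main task thread" sets its epoch

    val child = new Thread(() => {
      current.get.incrementAndGet()  // bumps only this thread's copy
      println(s"child epoch:  ${current.get.get}")  // prints 6
    })
    child.start()
    child.join()

    println(s"parent epoch: ${current.get.get}")  // still 5
  }
}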