
Commit 3106324

[SPARK-25184][SS] Fixed race condition in StreamExecution that caused flaky test in FlatMapGroupsWithState
## What changes were proposed in this pull request?

The race condition that caused the test failure is between two threads:

- The MicroBatchExecution thread, which processes inputs to produce answers and then generates progress events.
- The test thread, which generates some input data, checks the answer, and then verifies that the query generated a progress event.

The synchronization structure between these threads is as follows:

1. The MicroBatchExecution thread, in every batch, does the following in order:
   a. Processes the batch input to generate the answer.
   b. Signals `awaitProgressLockCondition` to wake up threads waiting for progress via `awaitOffset`.
   c. Generates the progress event.
2. The test thread:
   a. Calls `awaitOffset` to wait for progress, which waits on `awaitProgressLockCondition`.
   b. As soon as `awaitProgressLockCondition` is signaled, moves on in the test to check the answer.
   c. Finally, verifies the last generated progress event.

What can happen is the following sequence of events: 2a -> 1a -> 1b -> 2b -> 2c -> 1c. In other words, the progress event may be generated after the test tries to verify it.

The solution has two steps:

1. Signal the waiting thread after the progress event has been generated, that is, after `finishTrigger()`.
2. Increase the timeout of `awaitProgressLockCondition.await(100 ms)` to a large value.

The latter ensures that the test thread keeps waiting on `awaitProgressLockCondition` until the MicroBatchExecution thread explicitly signals it. With the existing small timeout of 100 ms, the following sequence can occur:

- The MicroBatchExecution thread updates the committed offsets.
- The test thread waiting on `awaitProgressLockCondition` accidentally times out after 100 ms, finds that the committed offsets have been updated, and therefore returns from `awaitOffset` and moves on to the progress event checks.
- The MicroBatchExecution thread then generates the progress event and signals. But the test thread has already attempted to verify the event and failed.

Increasing the timeout to a large value (e.g., `streamingTimeoutMs = 60 seconds`, similar to `awaitInitialization`) avoids this type of race condition as well.

## How was this patch tested?

Ran locally many times.

Closes #22182 from tdas/SPARK-25184.

Authored-by: Tathagata Das <[email protected]>
Signed-off-by: Tathagata Das <[email protected]>
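For illustration only, here is a minimal, self-contained sketch of the fixed synchronization. The names (`ProgressSignalSketch`, `committed`, `eventGenerated`) are hypothetical stand-ins, not Spark classes; the point is that the worker signals only after all per-batch work (including progress-event generation) is done, and the waiter treats the signal, not a short timeout, as its gate.

```scala
import java.util.concurrent.TimeUnit
import java.util.concurrent.locks.ReentrantLock

// Hypothetical stand-in for the MicroBatchExecution/test-thread handshake.
object ProgressSignalSketch {
  private val lock = new ReentrantLock()
  private val progressed = lock.newCondition()
  private var committed = false      // stands in for committedOffsets being updated
  private var eventGenerated = false // stands in for the event from finishTrigger()

  // Worker thread (step 1 of the fix): finish ALL per-batch work, then signal.
  def finishBatch(): Unit = {
    lock.lock()
    try {
      committed = true      // 1a: batch processed, offsets committed
      eventGenerated = true // 1c: progress event generated (now BEFORE the signal)
      progressed.signalAll()
    } finally lock.unlock()
  }

  // Test thread (step 2 of the fix): wait with a large timeout, e.g. 60 s.
  def awaitProgress(timeoutMs: Long): Unit = {
    lock.lock()
    try {
      while (!committed) {
        // With the old 100 ms timeout, this wait could expire just after the
        // offsets were committed but before the event existed; a large timeout
        // means the waiter effectively leaves only via signalAll().
        progressed.await(timeoutMs, TimeUnit.MILLISECONDS)
      }
    } finally lock.unlock()
  }
}
```

Because both updates happen under the lock before `signalAll()`, a waiter woken by the signal observes the progress event as well as the committed offsets.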

5 files changed with 33 additions and 25 deletions.

external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala

Lines changed: 2 additions & 1 deletion

```diff
@@ -970,7 +970,8 @@ abstract class KafkaSourceSuiteBase extends KafkaSourceTest {
       makeSureGetOffsetCalled,
       Execute { q =>
         // wait to reach the last offset in every partition
-        q.awaitOffset(0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L)))
+        q.awaitOffset(
+          0, KafkaSourceOffset(partitionOffsets.mapValues(_ => 3L)), streamingTimeout.toMillis)
       },
       CheckAnswer(-20, -21, -22, 0, 1, 2, 11, 12, 22),
       StopStream,
```

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala

Lines changed: 4 additions & 1 deletion

```diff
@@ -200,6 +200,10 @@ class MicroBatchExecution(

           finishTrigger(currentBatchHasNewData) // Must be outside reportTimeTaken so it is recorded

+          // Signal waiting threads. Note this must be after finishTrigger() to ensure all
+          // activities (progress generation, etc.) have completed before signaling.
+          withProgressLocked { awaitProgressLockCondition.signalAll() }
+
           // If the current batch has been executed, then increment the batch id and reset flag.
           // Otherwise, there was no data to execute the batch and sleep for some time
           if (isCurrentBatchConstructed) {
@@ -538,7 +542,6 @@ class MicroBatchExecution(
       watermarkTracker.updateWatermark(lastExecution.executedPlan)
       commitLog.add(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
       committedOffsets ++= availableOffsets
-      awaitProgressLockCondition.signalAll()
     }
     logDebug(s"Completed batch ${currentBatchId}")
   }
```
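The new signal site goes through `withProgressLocked`, a `StreamExecution` helper. `Condition.signalAll()` is only legal while the associated lock is held, so the helper presumably has roughly this shape; the sketch below is an assumed reconstruction, not the verbatim Spark source.

```scala
import java.util.concurrent.locks.ReentrantLock

// Assumed shape of a withProgressLocked-style helper: run the body while
// holding the progress lock, which makes signalAll() on the associated
// Condition legal inside it.
class ProgressLockSketch {
  private val awaitProgressLock = new ReentrantLock()
  private val awaitProgressLockCondition = awaitProgressLock.newCondition()

  private def withProgressLocked(f: => Unit): Unit = {
    awaitProgressLock.lock()
    try f finally awaitProgressLock.unlock()
  }

  // Usage mirroring the patch: signal only while the lock is held.
  def signalProgress(): Unit =
    withProgressLocked { awaitProgressLockCondition.signalAll() }
}
```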

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala

Lines changed: 2 additions & 2 deletions

```diff
@@ -382,7 +382,7 @@ abstract class StreamExecution(
    * Blocks the current thread until processing for data from the given `source` has reached at
    * least the given `Offset`. This method is intended for use primarily when writing tests.
    */
-  private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = {
+  private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset, timeoutMs: Long): Unit = {
     assertAwaitThread()
     def notDone = {
       val localCommittedOffsets = committedOffsets
@@ -398,7 +398,7 @@ abstract class StreamExecution(
     while (notDone) {
       awaitProgressLock.lock()
       try {
-        awaitProgressLockCondition.await(100, TimeUnit.MILLISECONDS)
+        awaitProgressLockCondition.await(timeoutMs, TimeUnit.MILLISECONDS)
         if (streamDeathCause != null) {
           throw streamDeathCause
         }
```
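A hedged usage sketch of the new signature, mirroring the `StreamTest` call site further down; it assumes a running `StreamExecution` named `query`, an `Offset` named `expectedOffset`, and the test suite's `streamingTimeout` and `failAfter`, so it is a fragment rather than a full program.

```scala
// Hypothetical test snippet: bound the wait by the suite's streaming timeout
// (60 seconds in the commit message's example) rather than 100 ms polling.
failAfter(streamingTimeout) {
  query.awaitOffset(0, expectedOffset, streamingTimeout.toMillis)
}
```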

sql/core/src/test/scala/org/apache/spark/sql/streaming/StateStoreMetricsTest.scala

Lines changed: 24 additions & 20 deletions

```diff
@@ -31,33 +31,37 @@ trait StateStoreMetricsTest extends StreamTest {

   def assertNumStateRows(total: Seq[Long], updated: Seq[Long]): AssertOnQuery =
     AssertOnQuery(s"Check total state rows = $total, updated state rows = $updated") { q =>
-      val recentProgress = q.recentProgress
-      require(recentProgress.nonEmpty, "No progress made, cannot check num state rows")
-      require(recentProgress.length < spark.sessionState.conf.streamingProgressRetention,
-        "This test assumes that all progresses are present in q.recentProgress but " +
-        "some may have been dropped due to retention limits")
+      // This assumes that the streaming query will not make any progress while the eventually
+      // is being executed.
+      eventually(timeout(streamingTimeout)) {
+        val recentProgress = q.recentProgress
+        require(recentProgress.nonEmpty, "No progress made, cannot check num state rows")
+        require(recentProgress.length < spark.sessionState.conf.streamingProgressRetention,
+          "This test assumes that all progresses are present in q.recentProgress but " +
+          "some may have been dropped due to retention limits")

-      if (q.ne(lastQuery)) lastCheckedRecentProgressIndex = -1
-      lastQuery = q
+        if (q.ne(lastQuery)) lastCheckedRecentProgressIndex = -1
+        lastQuery = q

-      val numStateOperators = recentProgress.last.stateOperators.length
-      val progressesSinceLastCheck = recentProgress
-        .slice(lastCheckedRecentProgressIndex + 1, recentProgress.length)
-        .filter(_.stateOperators.length == numStateOperators)
+        val numStateOperators = recentProgress.last.stateOperators.length
+        val progressesSinceLastCheck = recentProgress
+          .slice(lastCheckedRecentProgressIndex + 1, recentProgress.length)
+          .filter(_.stateOperators.length == numStateOperators)

-      val allNumUpdatedRowsSinceLastCheck =
-        progressesSinceLastCheck.map(_.stateOperators.map(_.numRowsUpdated))
+        val allNumUpdatedRowsSinceLastCheck =
+          progressesSinceLastCheck.map(_.stateOperators.map(_.numRowsUpdated))

-      lazy val debugString = "recent progresses:\n" +
-        progressesSinceLastCheck.map(_.prettyJson).mkString("\n\n")
+        lazy val debugString = "recent progresses:\n" +
+          progressesSinceLastCheck.map(_.prettyJson).mkString("\n\n")

-      val numTotalRows = recentProgress.last.stateOperators.map(_.numRowsTotal)
-      assert(numTotalRows === total, s"incorrect total rows, $debugString")
+        val numTotalRows = recentProgress.last.stateOperators.map(_.numRowsTotal)
+        assert(numTotalRows === total, s"incorrect total rows, $debugString")

-      val numUpdatedRows = arraySum(allNumUpdatedRowsSinceLastCheck, numStateOperators)
-      assert(numUpdatedRows === updated, s"incorrect updates rows, $debugString")
+        val numUpdatedRows = arraySum(allNumUpdatedRowsSinceLastCheck, numStateOperators)
+        assert(numUpdatedRows === updated, s"incorrect updates rows, $debugString")

-      lastCheckedRecentProgressIndex = recentProgress.length - 1
+        lastCheckedRecentProgressIndex = recentProgress.length - 1
+      }
       true
     }
```
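The wrapper here is ScalaTest's `Eventually.eventually`, which retries the block until it stops throwing or the patience timeout elapses. Below is a minimal standalone sketch of the retry idiom; the `flag` is a hypothetical stand-in for `q.recentProgress` becoming non-empty.

```scala
import java.util.concurrent.atomic.AtomicBoolean

import org.scalatest.concurrent.Eventually.{eventually, timeout}
import org.scalatest.time.SpanSugar._

object EventuallySketch extends App {
  val flag = new AtomicBoolean(false) // hypothetical stand-in for query progress

  // Flip the flag from another thread after a short delay.
  new Thread(() => { Thread.sleep(500); flag.set(true) }).start()

  // Retries the block (at the default polling interval) until the assertion
  // passes or 60 seconds elapse, at which point eventually fails the test.
  eventually(timeout(60.seconds)) {
    assert(flag.get(), "no progress yet; eventually will retry")
  }
}
```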

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala

Lines changed: 1 addition & 1 deletion

```diff
@@ -467,7 +467,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
         // Block until all data added has been processed for all the source
         awaiting.foreach { case (sourceIndex, offset) =>
           failAfter(streamingTimeout) {
-            currentStream.awaitOffset(sourceIndex, offset)
+            currentStream.awaitOffset(sourceIndex, offset, streamingTimeout.toMillis)
             // Make sure all processing including no-data-batches have been executed
             if (!currentStream.triggerClock.isInstanceOf[StreamManualClock]) {
               currentStream.processAllAvailable()
```
