diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
index 4980b0cd41f81..bcca821755950 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
@@ -36,18 +36,43 @@ class ContinuousSuiteBase extends StreamTest {
       "continuous-stream-test-sql-context",
       sparkConf.set("spark.sql.testkey", "true")))
 
-  protected def waitForRateSourceTriggers(query: StreamExecution, numTriggers: Int): Unit = {
-    query match {
-      case s: ContinuousExecution =>
-        assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized")
-        val reader = s.lastExecution.executedPlan.collectFirst {
-          case DataSourceV2ScanExec(_, _, _, _, r: RateStreamContinuousReader) => r
-        }.get
-
-        val deltaMs = numTriggers * 1000 + 300
-        while (System.currentTimeMillis < reader.creationTime + deltaMs) {
-          Thread.sleep(reader.creationTime + deltaMs - System.currentTimeMillis)
+  protected def waitForRateSourceTriggers(query: ContinuousExecution, numTriggers: Int): Unit = {
+    query.awaitEpoch(0)
+
+    // This is called after the first epoch has been committed, so we can assume the
+    // partition readers for the rate source have already been initialized.
+    val firstCommittedTime = System.nanoTime()
+    val deltaNs = (numTriggers * 1000 + 300) * 1000000L
+    var toWaitNs = firstCommittedTime + deltaNs - System.nanoTime()
+    while (toWaitNs > 0) {
+      Thread.sleep(toWaitNs / 1000000)
+      toWaitNs = firstCommittedTime + deltaNs - System.nanoTime()
+    }
+  }
+
+  protected def waitForRateSourceCommittedValue(
+      query: ContinuousExecution,
+      desiredValue: Long,
+      maxWaitTimeMs: Long): Unit = {
+    def readHighestCommittedValue(c: ContinuousExecution): Option[Long] = {
+      c.committedOffsets.lastOption.map { case (_, offset) =>
+        offset match {
+          case o: RateStreamOffset =>
+            o.partitionToValueAndRunTimeMs.map {
+              case (_, ValueRunTimeMsPair(value, _)) => value
+            }.max
         }
+      }
+    }
+
+    val maxWait = System.currentTimeMillis() + maxWaitTimeMs
+    while (System.currentTimeMillis() < maxWait &&
+      readHighestCommittedValue(query).getOrElse(Long.MinValue) < desiredValue) {
+      Thread.sleep(100)
+    }
+    if (System.currentTimeMillis() > maxWait) {
+      logWarning(s"Couldn't reach desired value in $maxWaitTimeMs milliseconds! " +
+        s"Current highest committed value is ${readHighestCommittedValue(query)}")
     }
   }
 
@@ -216,14 +241,16 @@ class ContinuousSuite extends ContinuousSuiteBase {
       .queryName("noharness")
       .trigger(Trigger.Continuous(100))
       .start()
+
+    val expected = Set(0, 1, 2, 3)
     val continuousExecution =
       query.asInstanceOf[StreamingQueryWrapper].streamingQuery.asInstanceOf[ContinuousExecution]
-    continuousExecution.awaitEpoch(0)
-    waitForRateSourceTriggers(continuousExecution, 2)
+    waitForRateSourceCommittedValue(continuousExecution, expected.max, 20 * 1000)
     query.stop()
 
    val results = spark.read.table("noharness").collect()
-    assert(Set(0, 1, 2, 3).map(Row(_)).subsetOf(results.toSet))
+    assert(expected.map(Row(_)).subsetOf(results.toSet),
+      s"Result set ${results.toSet} is not a superset of $expected!")
   }
 }
 
@@ -241,7 +268,9 @@ class ContinuousStressSuite extends ContinuousSuiteBase {
     testStream(df, useV2Sink = true)(
       StartStream(longContinuousTrigger),
       AwaitEpoch(0),
-      Execute(waitForRateSourceTriggers(_, 201)),
+      Execute { exec =>
+        waitForRateSourceTriggers(exec.asInstanceOf[ContinuousExecution], 50)
+      },
       IncrementEpoch(),
       StopStream,
       CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_)))
@@ -259,7 +288,9 @@ class ContinuousStressSuite extends ContinuousSuiteBase {
     testStream(df, useV2Sink = true)(
       StartStream(Trigger.Continuous(2012)),
       AwaitEpoch(0),
-      Execute(waitForRateSourceTriggers(_, 201)),
+      Execute { exec =>
+        waitForRateSourceTriggers(exec.asInstanceOf[ContinuousExecution], 50)
+      },
       IncrementEpoch(),
       StopStream,
       CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_))))