
Commit 4d29867

zsxwing authored and tdas committed
[SPARK-7341] [STREAMING] [TESTS] Fix the flaky test: org.apache.spark.streaming.InputStreamsSuite.socket input stream

Remove non-deterministic "Thread.sleep" and use deterministic strategies to fix the flaky failure:
https://amplab.cs.berkeley.edu/jenkins/job/Spark-Master-Maven-pre-YARN/hadoop.version=1.0.4,label=centos/2127/testReport/junit/org.apache.spark.streaming/InputStreamsSuite/socket_input_stream/

Author: zsxwing <[email protected]>

Closes #5891 from zsxwing/SPARK-7341 and squashes the following commits:

611157a [zsxwing] Add wait methods to BatchCounter and use BatchCounter in InputStreamsSuite
014b58f [zsxwing] Use withXXX to clean up the resources
c9bf746 [zsxwing] Move 'waitForStart' into the 'start' method and fix the code style
9d0de6d [zsxwing] [SPARK-7341][Streaming][Tests] Fix the flaky test: org.apache.spark.streaming.InputStreamsSuite.socket input stream
1 parent 8436f7e commit 4d29867
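
In outline, the fix replaces sleep-and-hope timing with an explicit wait: the test advances the ManualClock to trigger each batch and then blocks on a listener-backed BatchCounter until the expected number of batches has completed. A minimal sketch of that usage pattern, using the names introduced in the diff below (the surrounding test scaffolding, such as ssc, batchDuration, and testServer, is assumed):

    // Sketch only: assumes the TestSuiteBase scaffolding used in the test below
    // (ssc, batchDuration, a ManualClock, and a running TestServer).
    val batchCounter = new BatchCounter(ssc)   // registers a listener that counts batches
    ssc.start()

    val input = Seq(1, 2, 3, 4, 5)
    val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
    input.foreach { v =>
      testServer.send(v.toString + "\n")
      clock.advance(batchDuration.milliseconds)   // deterministically trigger the next batch
    }

    // Block until every batch has completed (or fail) instead of calling Thread.sleep.
    if (!batchCounter.waitUntilBatchesCompleted(input.size, 30000)) {
      fail("Timeout: cannot finish all batches in 30 seconds")
    }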

File tree

2 files changed: +140 -63 lines changed


streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala

Lines changed: 106 additions & 63 deletions
@@ -18,9 +18,9 @@
 package org.apache.spark.streaming

 import java.io.{File, BufferedWriter, OutputStreamWriter}
-import java.net.{SocketException, ServerSocket}
+import java.net.{Socket, SocketException, ServerSocket}
 import java.nio.charset.Charset
-import java.util.concurrent.{Executors, TimeUnit, ArrayBlockingQueue}
+import java.util.concurrent.{CountDownLatch, Executors, TimeUnit, ArrayBlockingQueue}
 import java.util.concurrent.atomic.AtomicInteger

 import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer, SynchronizedQueue}
@@ -36,58 +36,65 @@ import org.scalatest.concurrent.Eventually._
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.scheduler.{StreamingListenerBatchCompleted, StreamingListener}
 import org.apache.spark.util.{ManualClock, Utils}
 import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream}
 import org.apache.spark.streaming.receiver.Receiver

 class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {

   test("socket input stream") {
-    // Start the server
-    val testServer = new TestServer()
-    testServer.start()
+    withTestServer(new TestServer()) { testServer =>
+      // Start the server
+      testServer.start()

-    // Set up the streaming context and input streams
-    val ssc = new StreamingContext(conf, batchDuration)
-    val networkStream = ssc.socketTextStream(
-      "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK)
-    val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
-    val outputStream = new TestOutputStream(networkStream, outputBuffer)
-    def output: ArrayBuffer[String] = outputBuffer.flatMap(x => x)
-    outputStream.register()
-    ssc.start()
-
-    // Feed data to the server to send to the network receiver
-    val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
-    val input = Seq(1, 2, 3, 4, 5)
-    val expectedOutput = input.map(_.toString)
-    Thread.sleep(1000)
-    for (i <- 0 until input.size) {
-      testServer.send(input(i).toString + "\n")
-      Thread.sleep(500)
-      clock.advance(batchDuration.milliseconds)
-    }
-    Thread.sleep(1000)
-    logInfo("Stopping server")
-    testServer.stop()
-    logInfo("Stopping context")
-    ssc.stop()
-
-    // Verify whether data received was as expected
-    logInfo("--------------------------------")
-    logInfo("output.size = " + outputBuffer.size)
-    logInfo("output")
-    outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]"))
-    logInfo("expected output.size = " + expectedOutput.size)
-    logInfo("expected output")
-    expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]"))
-    logInfo("--------------------------------")
+      // Set up the streaming context and input streams
+      withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
+        val input = Seq(1, 2, 3, 4, 5)
+        // Use "batchCount" to make sure we check the result after all batches finish
+        val batchCounter = new BatchCounter(ssc)
+        val networkStream = ssc.socketTextStream(
+          "localhost", testServer.port, StorageLevel.MEMORY_AND_DISK)
+        val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
+        val outputStream = new TestOutputStream(networkStream, outputBuffer)
+        outputStream.register()
+        ssc.start()

-    // Verify whether all the elements received are as expected
-    // (whether the elements were received one in each interval is not verified)
-    assert(output.size === expectedOutput.size)
-    for (i <- 0 until output.size) {
-      assert(output(i) === expectedOutput(i))
+        // Feed data to the server to send to the network receiver
+        val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+        val expectedOutput = input.map(_.toString)
+        for (i <- 0 until input.size) {
+          testServer.send(input(i).toString + "\n")
+          Thread.sleep(500)
+          clock.advance(batchDuration.milliseconds)
+        }
+        // Make sure we finish all batches before "stop"
+        if (!batchCounter.waitUntilBatchesCompleted(input.size, 30000)) {
+          fail("Timeout: cannot finish all batches in 30 seconds")
+        }
+        logInfo("Stopping server")
+        testServer.stop()
+        logInfo("Stopping context")
+        ssc.stop()
+
+        // Verify whether data received was as expected
+        logInfo("--------------------------------")
+        logInfo("output.size = " + outputBuffer.size)
+        logInfo("output")
+        outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+        logInfo("expected output.size = " + expectedOutput.size)
+        logInfo("expected output")
+        expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]"))
+        logInfo("--------------------------------")
+
+        // Verify whether all the elements received are as expected
+        // (whether the elements were received one in each interval is not verified)
+        val output: ArrayBuffer[String] = outputBuffer.flatMap(x => x)
+        assert(output.size === expectedOutput.size)
+        for (i <- 0 until output.size) {
+          assert(output(i) === expectedOutput(i))
+        }
+      }
     }
   }

@@ -368,31 +375,45 @@ class TestServer(portToBind: Int = 0) extends Logging {

   val serverSocket = new ServerSocket(portToBind)

+  private val startLatch = new CountDownLatch(1)
+
   val servingThread = new Thread() {
     override def run() {
       try {
         while(true) {
           logInfo("Accepting connections on port " + port)
           val clientSocket = serverSocket.accept()
-          logInfo("New connection")
-          try {
-            clientSocket.setTcpNoDelay(true)
-            val outputStream = new BufferedWriter(
-              new OutputStreamWriter(clientSocket.getOutputStream))
-
-            while(clientSocket.isConnected) {
-              val msg = queue.poll(100, TimeUnit.MILLISECONDS)
-              if (msg != null) {
-                outputStream.write(msg)
-                outputStream.flush()
-                logInfo("Message '" + msg + "' sent")
+          if (startLatch.getCount == 1) {
+            // The first connection is a test connection to implement "waitForStart", so skip it
+            // and send a signal
+            if (!clientSocket.isClosed) {
+              clientSocket.close()
+            }
+            startLatch.countDown()
+          } else {
+            // Real connections
+            logInfo("New connection")
+            try {
+              clientSocket.setTcpNoDelay(true)
+              val outputStream = new BufferedWriter(
+                new OutputStreamWriter(clientSocket.getOutputStream))
+
+              while (clientSocket.isConnected) {
+                val msg = queue.poll(100, TimeUnit.MILLISECONDS)
+                if (msg != null) {
+                  outputStream.write(msg)
+                  outputStream.flush()
+                  logInfo("Message '" + msg + "' sent")
+                }
+              }
+            } catch {
+              case e: SocketException => logError("TestServer error", e)
+            } finally {
+              logInfo("Connection closed")
+              if (!clientSocket.isClosed) {
+                clientSocket.close()
               }
             }
-          } catch {
-            case e: SocketException => logError("TestServer error", e)
-          } finally {
-            logInfo("Connection closed")
-            if (!clientSocket.isClosed) clientSocket.close()
           }
         }
       } catch {
@@ -404,7 +425,29 @@ class TestServer(portToBind: Int = 0) extends Logging {
     }
   }

-  def start() { servingThread.start() }
+  def start(): Unit = {
+    servingThread.start()
+    if (!waitForStart(10000)) {
+      stop()
+      throw new AssertionError("Timeout: TestServer cannot start in 10 seconds")
+    }
+  }
+
+  /**
+   * Wait until the server starts. Return true if the server starts in "millis" milliseconds.
+   * Otherwise, return false to indicate it's timeout.
+   */
+  private def waitForStart(millis: Long): Boolean = {
+    // We will create a test connection to the server so that we can make sure it has started.
+    val socket = new Socket("localhost", port)
+    try {
+      startLatch.await(millis, TimeUnit.MILLISECONDS)
+    } finally {
+      if (!socket.isClosed) {
+        socket.close()
+      }
+    }
+  }

   def send(msg: String) { queue.put(msg) }
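
The rewritten test wraps everything in withTestServer and withStreamingContext, helpers from TestSuiteBase that are not part of this diff. Presumably they follow the usual loan pattern, stopping the resource even when the test body throws; a minimal sketch of that shape (signatures assumed, not copied from the Spark sources):

    // Assumed shape of the TestSuiteBase helpers used in the test above.
    def withTestServer[R](server: TestServer)(body: TestServer => R): R = {
      try {
        body(server)
      } finally {
        server.stop()   // release the socket even if an assertion failed
      }
    }

    def withStreamingContext[R](ssc: StreamingContext)(body: StreamingContext => R): R = {
      try {
        body(ssc)
      } finally {
        ssc.stop()      // stop the context so the next test starts from a clean slate
      }
    }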

streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala

Lines changed: 34 additions & 0 deletions
@@ -146,6 +146,40 @@ class BatchCounter(ssc: StreamingContext) {
   def getNumStartedBatches: Int = this.synchronized {
     numStartedBatches
   }
+
+  /**
+   * Wait until `expectedNumCompletedBatches` batches are completed, or timeout. Return true if
+   * `expectedNumCompletedBatches` batches are completed. Otherwise, return false to indicate it's
+   * timeout.
+   *
+   * @param expectedNumCompletedBatches the `expectedNumCompletedBatches` batches to wait
+   * @param timeout the maximum time to wait in milliseconds.
+   */
+  def waitUntilBatchesCompleted(expectedNumCompletedBatches: Int, timeout: Long): Boolean =
+    waitUntilConditionBecomeTrue(numCompletedBatches >= expectedNumCompletedBatches, timeout)
+
+  /**
+   * Wait until `expectedNumStartedBatches` batches are completed, or timeout. Return true if
+   * `expectedNumStartedBatches` batches are completed. Otherwise, return false to indicate it's
+   * timeout.
+   *
+   * @param expectedNumStartedBatches the `expectedNumStartedBatches` batches to wait
+   * @param timeout the maximum time to wait in milliseconds.
+   */
+  def waitUntilBatchesStarted(expectedNumStartedBatches: Int, timeout: Long): Boolean =
+    waitUntilConditionBecomeTrue(numStartedBatches >= expectedNumStartedBatches, timeout)
+
+  private def waitUntilConditionBecomeTrue(condition: => Boolean, timeout: Long): Boolean = {
+    synchronized {
+      var now = System.currentTimeMillis()
+      val timeoutTick = now + timeout
+      while (!condition && timeoutTick > now) {
+        wait(timeoutTick - now)
+        now = System.currentTimeMillis()
+      }
+      condition
+    }
+  }
 }

 /**
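
The wait(timeoutTick - now) call above returns early only if another thread calls notifyAll() on the BatchCounter; otherwise the loop simply re-checks the condition once the timeout expires. The counter-updating half of BatchCounter sits outside this hunk, but for the new wait methods to be prompt it presumably increments and notifies under the same lock, along these lines (a hypothetical sketch, not the actual Spark source):

    // Hypothetical listener half of BatchCounter: increment and notify under the
    // same lock that waitUntilConditionBecomeTrue waits on.
    private val listener = new StreamingListener {
      override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit =
        BatchCounter.this.synchronized {
          numCompletedBatches += 1
          BatchCounter.this.notifyAll()   // wakes waitUntilBatchesCompleted
        }
    }
    ssc.addStreamingListener(listener)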
