1818package org .apache .spark .streaming
1919
2020import java .io .{File , BufferedWriter , OutputStreamWriter }
21- import java .net .{SocketException , ServerSocket }
21+ import java .net .{Socket , SocketException , ServerSocket }
2222import java .nio .charset .Charset
23- import java .util .concurrent .{Executors , TimeUnit , ArrayBlockingQueue }
23+ import java .util .concurrent .{CountDownLatch , Executors , TimeUnit , ArrayBlockingQueue }
2424import java .util .concurrent .atomic .AtomicInteger
2525
2626import scala .collection .mutable .{SynchronizedBuffer , ArrayBuffer , SynchronizedQueue }
@@ -36,58 +36,65 @@ import org.scalatest.concurrent.Eventually._
3636import org .apache .spark .Logging
3737import org .apache .spark .rdd .RDD
3838import org .apache .spark .storage .StorageLevel
39+ import org .apache .spark .streaming .scheduler .{StreamingListenerBatchCompleted , StreamingListener }
3940import org .apache .spark .util .{ManualClock , Utils }
4041import org .apache .spark .streaming .dstream .{InputDStream , ReceiverInputDStream }
4142import org .apache .spark .streaming .receiver .Receiver
4243
4344class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
4445
4546 test(" socket input stream" ) {
46- // Start the server
47- val testServer = new TestServer ()
48- testServer.start()
47+ withTestServer( new TestServer ()) { testServer =>
48+ // Start the server
49+ testServer.start()
4950
50- // Set up the streaming context and input streams
51- val ssc = new StreamingContext (conf, batchDuration)
52- val networkStream = ssc.socketTextStream(
53- " localhost" , testServer.port, StorageLevel .MEMORY_AND_DISK )
54- val outputBuffer = new ArrayBuffer [Seq [String ]] with SynchronizedBuffer [Seq [String ]]
55- val outputStream = new TestOutputStream (networkStream, outputBuffer)
56- def output : ArrayBuffer [String ] = outputBuffer.flatMap(x => x)
57- outputStream.register()
58- ssc.start()
59-
60- // Feed data to the server to send to the network receiver
61- val clock = ssc.scheduler.clock.asInstanceOf [ManualClock ]
62- val input = Seq (1 , 2 , 3 , 4 , 5 )
63- val expectedOutput = input.map(_.toString)
64- Thread .sleep(1000 )
65- for (i <- 0 until input.size) {
66- testServer.send(input(i).toString + " \n " )
67- Thread .sleep(500 )
68- clock.advance(batchDuration.milliseconds)
69- }
70- Thread .sleep(1000 )
71- logInfo(" Stopping server" )
72- testServer.stop()
73- logInfo(" Stopping context" )
74- ssc.stop()
75-
76- // Verify whether data received was as expected
77- logInfo(" --------------------------------" )
78- logInfo(" output.size = " + outputBuffer.size)
79- logInfo(" output" )
80- outputBuffer.foreach(x => logInfo(" [" + x.mkString(" ," ) + " ]" ))
81- logInfo(" expected output.size = " + expectedOutput.size)
82- logInfo(" expected output" )
83- expectedOutput.foreach(x => logInfo(" [" + x.mkString(" ," ) + " ]" ))
84- logInfo(" --------------------------------" )
51+ // Set up the streaming context and input streams
52+ withStreamingContext(new StreamingContext (conf, batchDuration)) { ssc =>
53+ val input = Seq (1 , 2 , 3 , 4 , 5 )
54+ // Use "batchCount" to make sure we check the result after all batches finish
55+ val batchCounter = new BatchCounter (ssc)
56+ val networkStream = ssc.socketTextStream(
57+ " localhost" , testServer.port, StorageLevel .MEMORY_AND_DISK )
58+ val outputBuffer = new ArrayBuffer [Seq [String ]] with SynchronizedBuffer [Seq [String ]]
59+ val outputStream = new TestOutputStream (networkStream, outputBuffer)
60+ outputStream.register()
61+ ssc.start()
8562
86- // Verify whether all the elements received are as expected
87- // (whether the elements were received one in each interval is not verified)
88- assert(output.size === expectedOutput.size)
89- for (i <- 0 until output.size) {
90- assert(output(i) === expectedOutput(i))
63+ // Feed data to the server to send to the network receiver
64+ val clock = ssc.scheduler.clock.asInstanceOf [ManualClock ]
65+ val expectedOutput = input.map(_.toString)
66+ for (i <- 0 until input.size) {
67+ testServer.send(input(i).toString + " \n " )
68+ Thread .sleep(500 )
69+ clock.advance(batchDuration.milliseconds)
70+ }
71+ // Make sure we finish all batches before "stop"
72+ if (! batchCounter.waitUntilBatchesCompleted(input.size, 30000 )) {
73+ fail(" Timeout: cannot finish all batches in 30 seconds" )
74+ }
75+ logInfo(" Stopping server" )
76+ testServer.stop()
77+ logInfo(" Stopping context" )
78+ ssc.stop()
79+
80+ // Verify whether data received was as expected
81+ logInfo(" --------------------------------" )
82+ logInfo(" output.size = " + outputBuffer.size)
83+ logInfo(" output" )
84+ outputBuffer.foreach(x => logInfo(" [" + x.mkString(" ," ) + " ]" ))
85+ logInfo(" expected output.size = " + expectedOutput.size)
86+ logInfo(" expected output" )
87+ expectedOutput.foreach(x => logInfo(" [" + x.mkString(" ," ) + " ]" ))
88+ logInfo(" --------------------------------" )
89+
90+ // Verify whether all the elements received are as expected
91+ // (whether the elements were received one in each interval is not verified)
92+ val output : ArrayBuffer [String ] = outputBuffer.flatMap(x => x)
93+ assert(output.size === expectedOutput.size)
94+ for (i <- 0 until output.size) {
95+ assert(output(i) === expectedOutput(i))
96+ }
97+ }
9198 }
9299 }
93100
@@ -368,31 +375,45 @@ class TestServer(portToBind: Int = 0) extends Logging {
368375
369376 val serverSocket = new ServerSocket (portToBind)
370377
378+ private val startLatch = new CountDownLatch (1 )
379+
371380 val servingThread = new Thread () {
372381 override def run () {
373382 try {
374383 while (true ) {
375384 logInfo(" Accepting connections on port " + port)
376385 val clientSocket = serverSocket.accept()
377- logInfo(" New connection" )
378- try {
379- clientSocket.setTcpNoDelay(true )
380- val outputStream = new BufferedWriter (
381- new OutputStreamWriter (clientSocket.getOutputStream))
382-
383- while (clientSocket.isConnected) {
384- val msg = queue.poll(100 , TimeUnit .MILLISECONDS )
385- if (msg != null ) {
386- outputStream.write(msg)
387- outputStream.flush()
388- logInfo(" Message '" + msg + " ' sent" )
386+ if (startLatch.getCount == 1 ) {
387+ // The first connection is a test connection to implement "waitForStart", so skip it
388+ // and send a signal
389+ if (! clientSocket.isClosed) {
390+ clientSocket.close()
391+ }
392+ startLatch.countDown()
393+ } else {
394+ // Real connections
395+ logInfo(" New connection" )
396+ try {
397+ clientSocket.setTcpNoDelay(true )
398+ val outputStream = new BufferedWriter (
399+ new OutputStreamWriter (clientSocket.getOutputStream))
400+
401+ while (clientSocket.isConnected) {
402+ val msg = queue.poll(100 , TimeUnit .MILLISECONDS )
403+ if (msg != null ) {
404+ outputStream.write(msg)
405+ outputStream.flush()
406+ logInfo(" Message '" + msg + " ' sent" )
407+ }
408+ }
409+ } catch {
410+ case e : SocketException => logError(" TestServer error" , e)
411+ } finally {
412+ logInfo(" Connection closed" )
413+ if (! clientSocket.isClosed) {
414+ clientSocket.close()
389415 }
390416 }
391- } catch {
392- case e : SocketException => logError(" TestServer error" , e)
393- } finally {
394- logInfo(" Connection closed" )
395- if (! clientSocket.isClosed) clientSocket.close()
396417 }
397418 }
398419 } catch {
@@ -404,7 +425,29 @@ class TestServer(portToBind: Int = 0) extends Logging {
404425 }
405426 }
406427
407- def start () { servingThread.start() }
428+ def start (): Unit = {
429+ servingThread.start()
430+ if (! waitForStart(10000 )) {
431+ stop()
432+ throw new AssertionError (" Timeout: TestServer cannot start in 10 seconds" )
433+ }
434+ }
435+
436+ /**
437+ * Wait until the server starts. Return true if the server starts in "millis" milliseconds.
438+ * Otherwise, return false to indicate it's timeout.
439+ */
440+ private def waitForStart (millis : Long ): Boolean = {
441+ // We will create a test connection to the server so that we can make sure it has started.
442+ val socket = new Socket (" localhost" , port)
443+ try {
444+ startLatch.await(millis, TimeUnit .MILLISECONDS )
445+ } finally {
446+ if (! socket.isClosed) {
447+ socket.close()
448+ }
449+ }
450+ }
408451
409452 def send (msg : String ) { queue.put(msg) }
410453
0 commit comments