Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
package org.apache.spark.streaming

import java.io.{File, BufferedWriter, OutputStreamWriter}
import java.net.{SocketException, ServerSocket}
import java.net.{Socket, SocketException, ServerSocket}
import java.nio.charset.Charset
import java.util.concurrent.{Executors, TimeUnit, ArrayBlockingQueue}
import java.util.concurrent.{CountDownLatch, Executors, TimeUnit, ArrayBlockingQueue}
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer, SynchronizedQueue}
Expand All @@ -36,58 +36,65 @@ import org.scalatest.concurrent.Eventually._
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.scheduler.{StreamingListenerBatchCompleted, StreamingListener}
import org.apache.spark.util.{ManualClock, Utils}
import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream}
import org.apache.spark.streaming.receiver.Receiver

class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {

test("socket input stream") {
// Start the server
val testServer = new TestServer()
testServer.start()
withTestServer(new TestServer()) { testServer =>
// Start the server
testServer.start()

// Set up the streaming context and input streams
val ssc = new StreamingContext(conf, batchDuration)
val networkStream = ssc.socketTextStream(
"localhost", testServer.port, StorageLevel.MEMORY_AND_DISK)
val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
val outputStream = new TestOutputStream(networkStream, outputBuffer)
def output: ArrayBuffer[String] = outputBuffer.flatMap(x => x)
outputStream.register()
ssc.start()

// Feed data to the server to send to the network receiver
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
val input = Seq(1, 2, 3, 4, 5)
val expectedOutput = input.map(_.toString)
Thread.sleep(1000)
for (i <- 0 until input.size) {
testServer.send(input(i).toString + "\n")
Thread.sleep(500)
clock.advance(batchDuration.milliseconds)
}
Thread.sleep(1000)
logInfo("Stopping server")
testServer.stop()
logInfo("Stopping context")
ssc.stop()

// Verify whether data received was as expected
logInfo("--------------------------------")
logInfo("output.size = " + outputBuffer.size)
logInfo("output")
outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]"))
logInfo("expected output.size = " + expectedOutput.size)
logInfo("expected output")
expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]"))
logInfo("--------------------------------")
// Set up the streaming context and input streams
withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc =>
val input = Seq(1, 2, 3, 4, 5)
// Use "batchCount" to make sure we check the result after all batches finish
val batchCounter = new BatchCounter(ssc)
val networkStream = ssc.socketTextStream(
"localhost", testServer.port, StorageLevel.MEMORY_AND_DISK)
val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]]
val outputStream = new TestOutputStream(networkStream, outputBuffer)
outputStream.register()
ssc.start()

// Verify whether all the elements received are as expected
// (whether the elements were received one in each interval is not verified)
assert(output.size === expectedOutput.size)
for (i <- 0 until output.size) {
assert(output(i) === expectedOutput(i))
// Feed data to the server to send to the network receiver
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
val expectedOutput = input.map(_.toString)
for (i <- 0 until input.size) {
testServer.send(input(i).toString + "\n")
Thread.sleep(500)
clock.advance(batchDuration.milliseconds)
}
// Make sure we finish all batches before "stop"
if (!batchCounter.waitUntilBatchesCompleted(input.size, 30000)) {
fail("Timeout: cannot finish all batches in 30 seconds")
}
logInfo("Stopping server")
testServer.stop()
logInfo("Stopping context")
ssc.stop()

// Verify whether data received was as expected
logInfo("--------------------------------")
logInfo("output.size = " + outputBuffer.size)
logInfo("output")
outputBuffer.foreach(x => logInfo("[" + x.mkString(",") + "]"))
logInfo("expected output.size = " + expectedOutput.size)
logInfo("expected output")
expectedOutput.foreach(x => logInfo("[" + x.mkString(",") + "]"))
logInfo("--------------------------------")

// Verify whether all the elements received are as expected
// (whether the elements were received one in each interval is not verified)
val output: ArrayBuffer[String] = outputBuffer.flatMap(x => x)
assert(output.size === expectedOutput.size)
for (i <- 0 until output.size) {
assert(output(i) === expectedOutput(i))
}
}
}
}

Expand Down Expand Up @@ -368,31 +375,45 @@ class TestServer(portToBind: Int = 0) extends Logging {

val serverSocket = new ServerSocket(portToBind)

private val startLatch = new CountDownLatch(1)

val servingThread = new Thread() {
override def run() {
try {
while(true) {
logInfo("Accepting connections on port " + port)
val clientSocket = serverSocket.accept()
logInfo("New connection")
try {
clientSocket.setTcpNoDelay(true)
val outputStream = new BufferedWriter(
new OutputStreamWriter(clientSocket.getOutputStream))

while(clientSocket.isConnected) {
val msg = queue.poll(100, TimeUnit.MILLISECONDS)
if (msg != null) {
outputStream.write(msg)
outputStream.flush()
logInfo("Message '" + msg + "' sent")
if (startLatch.getCount == 1) {
// The first connection is a test connection to implement "waitForStart", so skip it
// and send a signal
if (!clientSocket.isClosed) {
clientSocket.close()
}
startLatch.countDown()
} else {
// Real connections
logInfo("New connection")
try {
clientSocket.setTcpNoDelay(true)
val outputStream = new BufferedWriter(
new OutputStreamWriter(clientSocket.getOutputStream))

while (clientSocket.isConnected) {
val msg = queue.poll(100, TimeUnit.MILLISECONDS)
if (msg != null) {
outputStream.write(msg)
outputStream.flush()
logInfo("Message '" + msg + "' sent")
}
}
} catch {
case e: SocketException => logError("TestServer error", e)
} finally {
logInfo("Connection closed")
if (!clientSocket.isClosed) {
clientSocket.close()
}
}
} catch {
case e: SocketException => logError("TestServer error", e)
} finally {
logInfo("Connection closed")
if (!clientSocket.isClosed) clientSocket.close()
}
}
} catch {
Expand All @@ -404,7 +425,29 @@ class TestServer(portToBind: Int = 0) extends Logging {
}
}

def start() { servingThread.start() }
def start(): Unit = {
servingThread.start()
if (!waitForStart(10000)) {
stop()
throw new AssertionError("Timeout: TestServer cannot start in 10 seconds")
}
}

/**
* Wait until the server starts. Return true if the server starts in "millis" milliseconds.
* Otherwise, return false to indicate it's timeout.
*/
private def waitForStart(millis: Long): Boolean = {
// We will create a test connection to the server so that we can make sure it has started.
val socket = new Socket("localhost", port)
try {
startLatch.await(millis, TimeUnit.MILLISECONDS)
} finally {
if (!socket.isClosed) {
socket.close()
}
}
}

def send(msg: String) { queue.put(msg) }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,40 @@ class BatchCounter(ssc: StreamingContext) {
def getNumStartedBatches: Int = this.synchronized {
numStartedBatches
}

/**
* Wait until `expectedNumCompletedBatches` batches are completed, or timeout. Return true if
* `expectedNumCompletedBatches` batches are completed. Otherwise, return false to indicate it's
* timeout.
*
* @param expectedNumCompletedBatches the `expectedNumCompletedBatches` batches to wait
* @param timeout the maximum time to wait in milliseconds.
*/
def waitUntilBatchesCompleted(expectedNumCompletedBatches: Int, timeout: Long): Boolean =
waitUntilConditionBecomeTrue(numCompletedBatches >= expectedNumCompletedBatches, timeout)

/**
* Wait until `expectedNumStartedBatches` batches are completed, or timeout. Return true if
* `expectedNumStartedBatches` batches are completed. Otherwise, return false to indicate it's
* timeout.
*
* @param expectedNumStartedBatches the `expectedNumStartedBatches` batches to wait
* @param timeout the maximum time to wait in milliseconds.
*/
def waitUntilBatchesStarted(expectedNumStartedBatches: Int, timeout: Long): Boolean =
waitUntilConditionBecomeTrue(numStartedBatches >= expectedNumStartedBatches, timeout)

private def waitUntilConditionBecomeTrue(condition: => Boolean, timeout: Long): Boolean = {
synchronized {
var now = System.currentTimeMillis()
val timeoutTick = now + timeout
while (!condition && timeoutTick > now) {
wait(timeoutTick - now)
now = System.currentTimeMillis()
}
condition
}
}
}

/**
Expand Down