Skip to content

Commit 120b81e

Browse files
Fix complexity in threading model in test
1 parent 4df5be6 commit 120b81e

File tree

1 file changed

+37
-39
lines changed

1 file changed

+37
-39
lines changed

external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala

Lines changed: 37 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ import java.util.concurrent.atomic.AtomicInteger
2121
import java.util.concurrent.{CountDownLatch, Executors}
2222

2323
import scala.collection.JavaConversions._
24-
import scala.concurrent.{Promise, Future}
25-
import scala.util.{Failure, Success, Try}
24+
import scala.concurrent.{ExecutionContext, Future}
25+
import scala.util.{Failure, Success}
2626

2727
import com.google.common.util.concurrent.ThreadFactoryBuilder
2828
import org.apache.avro.ipc.NettyTransceiver
@@ -38,7 +38,7 @@ class SparkSinkSuite extends TestSuiteBase {
3838
val channelCapacity = 5000
3939

4040
test("Success") {
41-
val (channel, sink) = initializeChannelAndSink(None)
41+
val (channel, sink) = initializeChannelAndSink()
4242
channel.start()
4343
sink.start()
4444

@@ -58,7 +58,7 @@ class SparkSinkSuite extends TestSuiteBase {
5858
}
5959

6060
test("Nack") {
61-
val (channel, sink) = initializeChannelAndSink(None)
61+
val (channel, sink) = initializeChannelAndSink()
6262
channel.start()
6363
sink.start()
6464
putEvents(channel, eventsPerBatch)
@@ -77,8 +77,8 @@ class SparkSinkSuite extends TestSuiteBase {
7777
}
7878

7979
test("Timeout") {
80-
val (channel, sink) = initializeChannelAndSink(Option(Map(SparkSinkConfig
81-
.CONF_TRANSACTION_TIMEOUT -> 1.toString)))
80+
val (channel, sink) = initializeChannelAndSink(Map(SparkSinkConfig
81+
.CONF_TRANSACTION_TIMEOUT -> 1.toString))
8282
channel.start()
8383
sink.start()
8484
putEvents(channel, eventsPerBatch)
@@ -96,69 +96,67 @@ class SparkSinkSuite extends TestSuiteBase {
9696
}
9797

9898
test("Multiple consumers") {
99-
multipleClients(failSome = false)
99+
testMultipleConsumers(failSome = false)
100100
}
101101

102-
test("Multiple consumers With Some Failures") {
103-
multipleClients(failSome = true)
102+
test("Multiple consumers with some failures") {
103+
testMultipleConsumers(failSome = true)
104104
}
105105

106-
def multipleClients(failSome: Boolean): Unit = {
107-
import scala.concurrent.ExecutionContext.Implicits.global
108-
val (channel, sink) = initializeChannelAndSink(None)
106+
def testMultipleConsumers(failSome: Boolean): Unit = {
107+
implicit val executorContext = ExecutionContext
108+
.fromExecutorService(Executors.newFixedThreadPool(5))
109+
val (channel, sink) = initializeChannelAndSink()
109110
channel.start()
110111
sink.start()
111-
(1 to 5).map(_ => putEvents(channel, eventsPerBatch))
112+
(1 to 5).foreach(_ => putEvents(channel, eventsPerBatch))
112113
val port = sink.getPort
113114
val address = new InetSocketAddress("0.0.0.0", port)
114-
115-
val transAndClient = getTransceiverAndClient(address, 5)
115+
val transceiversAndClients = getTransceiverAndClient(address, 5)
116116
val batchCounter = new CountDownLatch(5)
117117
val counter = new AtomicInteger(0)
118-
transAndClient.foreach(x => {
119-
val promise = Promise[EventBatch]()
120-
val future = promise.future
118+
transceiversAndClients.foreach(x => {
121119
Future {
122120
val client = x._2
123121
var events: EventBatch = null
124-
Try {
125-
events = client.getEventBatch(1000)
126-
if(!failSome || counter.getAndIncrement() % 2 == 0) {
127-
client.ack(events.getSequenceNumber)
128-
} else {
129-
client.nack(events.getSequenceNumber)
130-
}
131-
}.map(_ => promise.success(events)).recover({
132-
case e => promise.failure(e)
133-
})
134-
}
135-
future.onComplete {
136-
case Success(events) => assert(events.getEvents.size() === 1000)
122+
events = client.getEventBatch(1000)
123+
if (!failSome || counter.getAndIncrement() % 2 == 0) {
124+
client.ack(events.getSequenceNumber)
125+
} else {
126+
client.nack(events.getSequenceNumber)
127+
throw new RuntimeException("Sending NACK for failure!")
128+
}
129+
events
130+
}.onComplete {
131+
case Success(events) =>
132+
assert(events.getEvents.size() === 1000)
133+
batchCounter.countDown()
134+
case Failure(t) =>
135+
// Don't re-throw the exception, causes a nasty unnecessary stack trace on stdout
137136
batchCounter.countDown()
138-
case Failure(t) => batchCounter.countDown()
139-
throw t
140137
}
141138
})
142139
batchCounter.await()
140+
executorContext.shutdown()
143141
if(failSome) {
144142
assert(availableChannelSlots(channel) === 3000)
145143
} else {
146144
assertChannelIsEmpty(channel)
147145
}
148146
sink.stop()
149147
channel.stop()
150-
transAndClient.foreach(x => x._1.close())
148+
transceiversAndClients.foreach(x => x._1.close())
151149
}
152150

153-
def initializeChannelAndSink(overrides: Option[Map[String, String]]): (MemoryChannel,
151+
private def initializeChannelAndSink(overrides: Map[String, String] = Map.empty): (MemoryChannel,
154152
SparkSink) = {
155153
val channel = new MemoryChannel()
156154
val channelContext = new Context()
157155

158156
channelContext.put("capacity", channelCapacity.toString)
159157
channelContext.put("transactionCapacity", 1000.toString)
160158
channelContext.put("keep-alive", 0.toString)
161-
overrides.foreach(channelContext.putAll(_))
159+
channelContext.putAll(overrides)
162160
channel.configure(channelContext)
163161

164162
val sink = new SparkSink()
@@ -173,7 +171,7 @@ class SparkSinkSuite extends TestSuiteBase {
173171
private def putEvents(ch: MemoryChannel, count: Int): Unit = {
174172
val tx = ch.getTransaction
175173
tx.begin()
176-
(1 to count).map(x => ch.put(EventBuilder.withBody(x.toString.getBytes)))
174+
(1 to count).foreach(x => ch.put(EventBuilder.withBody(x.toString.getBytes)))
177175
tx.commit()
178176
tx.close()
179177
}
@@ -193,8 +191,8 @@ class SparkSinkSuite extends TestSuiteBase {
193191
})
194192
}
195193

196-
private def assertChannelIsEmpty(channel: MemoryChannel) = {
197-
assert(availableChannelSlots(channel) === 5000)
194+
private def assertChannelIsEmpty(channel: MemoryChannel): Unit = {
195+
assert(availableChannelSlots(channel) === channelCapacity)
198196
}
199197

200198
private def availableChannelSlots(channel: MemoryChannel): Int = {

0 commit comments

Comments
 (0)