1818package org .apache .spark .streaming .kafka
1919
2020import java .io .File
21+ import java .util .concurrent .atomic .AtomicLong
2122
2223import scala .collection .mutable
2324import scala .collection .mutable .ArrayBuffer
@@ -330,18 +331,13 @@ class DirectKafkaStreamSuite
330331 eventually(timeout(20000 .milliseconds), interval(200 .milliseconds)) {
331332 assert(allReceived.size === totalSent,
332333 " didn't get expected number of messages, messages:\n " + allReceived.mkString(" \n " ))
334+
335+ // Calculate all the record number collected in the StreamingListener.
336+ assert(collector.numRecordsSubmitted.get() === totalSent)
337+ assert(collector.numRecordsStarted.get() === totalSent)
338+ assert(collector.numRecordsCompleted.get() === totalSent)
333339 }
334340 ssc.stop()
335-
336- // Calculate all the record number collected in the StreamingListener.
337- val numRecordsSubmitted = collector.streamIdToNumRecordsSubmitted.map(_.values.sum).sum
338- assert(numRecordsSubmitted === totalSent)
339-
340- val numRecordsStarted = collector.streamIdToNumRecordsStarted.map(_.values.sum).sum
341- assert(numRecordsStarted === totalSent)
342-
343- val numRecordsCompleted = collector.streamIdToNumRecordsCompleted.map(_.values.sum).sum
344- assert(numRecordsCompleted === totalSent)
345341 }
346342
347343 /** Get the generated offset ranges from the DirectKafkaStream */
@@ -358,22 +354,20 @@ object DirectKafkaStreamSuite {
358354 var total = - 1L
359355
360356 class InputInfoCollector extends StreamingListener {
361- val streamIdToNumRecordsSubmitted = new ArrayBuffer [ Map [ Int , Long ]]( )
362- val streamIdToNumRecordsStarted = new ArrayBuffer [ Map [ Int , Long ]]( )
363- val streamIdToNumRecordsCompleted = new ArrayBuffer [ Map [ Int , Long ]]( )
357+ val numRecordsSubmitted = new AtomicLong ( 0L )
358+ val numRecordsStarted = new AtomicLong ( 0L )
359+ val numRecordsCompleted = new AtomicLong ( 0L )
364360
365- override def onBatchSubmitted (batchSubmitted : StreamingListenerBatchSubmitted ): Unit =
366- synchronized {
367- streamIdToNumRecordsSubmitted += batchSubmitted.batchInfo.streamIdToNumRecords
361+ override def onBatchSubmitted (batchSubmitted : StreamingListenerBatchSubmitted ): Unit = {
362+ numRecordsSubmitted.addAndGet(batchSubmitted.batchInfo.numRecords)
368363 }
369364
370- override def onBatchStarted (batchStarted : StreamingListenerBatchStarted ): Unit = synchronized {
371- streamIdToNumRecordsStarted += batchStarted.batchInfo.streamIdToNumRecords
365+ override def onBatchStarted (batchStarted : StreamingListenerBatchStarted ): Unit = {
366+ numRecordsStarted.addAndGet( batchStarted.batchInfo.numRecords)
372367 }
373368
374- override def onBatchCompleted (batchCompleted : StreamingListenerBatchCompleted ): Unit =
375- synchronized {
376- streamIdToNumRecordsCompleted += batchCompleted.batchInfo.streamIdToNumRecords
369+ override def onBatchCompleted (batchCompleted : StreamingListenerBatchCompleted ): Unit = {
370+ numRecordsCompleted.addAndGet(batchCompleted.batchInfo.numRecords)
377371 }
378372 }
379373}
0 commit comments