Skip to content

Commit b0b506c

Browse files
committed
Address the comments
1 parent 0babe66 commit b0b506c

File tree

2 files changed

+18
-25
lines changed

2 files changed

+18
-25
lines changed

external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,9 @@ class DirectKafkaInputDStream[
116116
val rdd = KafkaRDD[K, V, U, T, R](
117117
context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler)
118118

119-
// Report the number of records of the batch interval to InputInfoTracker.
120-
val currentNumRecords = currentOffsets.map(_._2).sum
121-
val toBeProcessedNumRecords = untilOffsets.map(_._2.offset).sum
122-
val inputInfo = InputInfo(id, toBeProcessedNumRecords - currentNumRecords)
119+
// Report the number of records in this batch interval to InputInfoTracker.
120+
val numRecords = rdd.offsetRanges.map(r => r.untilOffset - r.fromOffset).sum
121+
val inputInfo = InputInfo(id, numRecords)
123122
ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo)
124123

125124
currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset)

external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.streaming.kafka
1919

2020
import java.io.File
21+
import java.util.concurrent.atomic.AtomicLong
2122

2223
import scala.collection.mutable
2324
import scala.collection.mutable.ArrayBuffer
@@ -330,18 +331,13 @@ class DirectKafkaStreamSuite
330331
eventually(timeout(20000.milliseconds), interval(200.milliseconds)) {
331332
assert(allReceived.size === totalSent,
332333
"didn't get expected number of messages, messages:\n" + allReceived.mkString("\n"))
334+
335+
// Verify the total number of records collected by the StreamingListener.
336+
assert(collector.numRecordsSubmitted.get() === totalSent)
337+
assert(collector.numRecordsStarted.get() === totalSent)
338+
assert(collector.numRecordsCompleted.get() === totalSent)
333339
}
334340
ssc.stop()
335-
336-
// Calculate all the record number collected in the StreamingListener.
337-
val numRecordsSubmitted = collector.streamIdToNumRecordsSubmitted.map(_.values.sum).sum
338-
assert(numRecordsSubmitted === totalSent)
339-
340-
val numRecordsStarted = collector.streamIdToNumRecordsStarted.map(_.values.sum).sum
341-
assert(numRecordsStarted === totalSent)
342-
343-
val numRecordsCompleted = collector.streamIdToNumRecordsCompleted.map(_.values.sum).sum
344-
assert(numRecordsCompleted === totalSent)
345341
}
346342

347343
/** Get the generated offset ranges from the DirectKafkaStream */
@@ -358,22 +354,20 @@ object DirectKafkaStreamSuite {
358354
var total = -1L
359355

360356
class InputInfoCollector extends StreamingListener {
361-
val streamIdToNumRecordsSubmitted = new ArrayBuffer[Map[Int, Long]]()
362-
val streamIdToNumRecordsStarted = new ArrayBuffer[Map[Int, Long]]()
363-
val streamIdToNumRecordsCompleted = new ArrayBuffer[Map[Int, Long]]()
357+
val numRecordsSubmitted = new AtomicLong(0L)
358+
val numRecordsStarted = new AtomicLong(0L)
359+
val numRecordsCompleted = new AtomicLong(0L)
364360

365-
override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit =
366-
synchronized {
367-
streamIdToNumRecordsSubmitted += batchSubmitted.batchInfo.streamIdToNumRecords
361+
override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = {
362+
numRecordsSubmitted.addAndGet(batchSubmitted.batchInfo.numRecords)
368363
}
369364

370-
override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = synchronized {
371-
streamIdToNumRecordsStarted += batchStarted.batchInfo.streamIdToNumRecords
365+
override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = {
366+
numRecordsStarted.addAndGet(batchStarted.batchInfo.numRecords)
372367
}
373368

374-
override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit =
375-
synchronized {
376-
streamIdToNumRecordsCompleted += batchCompleted.batchInfo.streamIdToNumRecords
369+
override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
370+
numRecordsCompleted.addAndGet(batchCompleted.batchInfo.numRecords)
377371
}
378372
}
379373
}

0 commit comments

Comments
 (0)