Commit 8436f7e

jerryshao authored and tdas committed
[SPARK-7113] [STREAMING] Support input information reporting for Direct Kafka stream
Author: jerryshao <[email protected]>

Closes #5879 from jerryshao/SPARK-7113 and squashes the following commits:

b0b506c [jerryshao] Address the comments
0babe66 [jerryshao] Support input information reporting for Direct Kafka stream
1 parent 8776fe0 commit 8436f7e

3 files changed, +67 -6 lines changed


external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala

Lines changed: 7 additions & 3 deletions
@@ -17,7 +17,6 @@

 package org.apache.spark.streaming.kafka

-
 import scala.annotation.tailrec
 import scala.collection.mutable
 import scala.reflect.{classTag, ClassTag}
@@ -27,10 +26,10 @@ import kafka.message.MessageAndMetadata
 import kafka.serializer.Decoder

 import org.apache.spark.{Logging, SparkException}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset
 import org.apache.spark.streaming.{StreamingContext, Time}
 import org.apache.spark.streaming.dstream._
+import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset
+import org.apache.spark.streaming.scheduler.InputInfo

 /**
  * A stream of {@link org.apache.spark.streaming.kafka.KafkaRDD} where
@@ -117,6 +116,11 @@ class DirectKafkaInputDStream[
     val rdd = KafkaRDD[K, V, U, T, R](
       context.sparkContext, kafkaParams, currentOffsets, untilOffsets, messageHandler)

+    // Report the record number of this batch interval to InputInfoTracker.
+    val numRecords = rdd.offsetRanges.map(r => r.untilOffset - r.fromOffset).sum
+    val inputInfo = InputInfo(id, numRecords)
+    ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo)
+
     currentOffsets = untilOffsets.map(kv => kv._1 -> kv._2.offset)
     Some(rdd)
   }
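For context, the count reported above is just the sum of the per-partition offset spans of the batch. The snippet below is a minimal, self-contained sketch of that arithmetic only; the OffsetRange case class and RecordCountSketch object are simplified stand-ins written for illustration, not Spark's org.apache.spark.streaming.kafka.OffsetRange or any code from this commit.

// Simplified stand-in for org.apache.spark.streaming.kafka.OffsetRange (illustration only).
case class OffsetRange(topic: String, partition: Int, fromOffset: Long, untilOffset: Long)

object RecordCountSketch {
  def main(args: Array[String]): Unit = {
    // One offset range per Kafka partition consumed in this batch.
    val offsetRanges = Seq(
      OffsetRange("report-test", partition = 0, fromOffset = 0L, untilOffset = 7L),
      OffsetRange("report-test", partition = 1, fromOffset = 0L, untilOffset = 9L))

    // Same arithmetic as the line added in DirectKafkaInputDStream.compute above:
    // records in this batch = sum of (untilOffset - fromOffset) over all partitions.
    val numRecords = offsetRanges.map(r => r.untilOffset - r.fromOffset).sum
    println(s"numRecords = $numRecords") // prints: numRecords = 16
  }
}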

external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala

Lines changed: 58 additions & 1 deletion
@@ -18,6 +18,7 @@
 package org.apache.spark.streaming.kafka

 import java.io.File
+import java.util.concurrent.atomic.AtomicLong

 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
@@ -34,6 +35,7 @@ import org.apache.spark.{Logging, SparkConf, SparkContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time}
 import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.streaming.scheduler._
 import org.apache.spark.util.Utils

 class DirectKafkaStreamSuite
@@ -290,7 +292,6 @@ class DirectKafkaStreamSuite
       },
       "Recovered ranges are not the same as the ones generated"
     )
-
     // Restart context, give more data and verify the total at the end
     // If the total is write that means each records has been received only once
     ssc.start()
@@ -301,6 +302,44 @@ class DirectKafkaStreamSuite
     ssc.stop()
   }

+  test("Direct Kafka stream report input information") {
+    val topic = "report-test"
+    val data = Map("a" -> 7, "b" -> 9)
+    kafkaTestUtils.createTopic(topic)
+    kafkaTestUtils.sendMessages(topic, data)
+
+    val totalSent = data.values.sum
+    val kafkaParams = Map(
+      "metadata.broker.list" -> kafkaTestUtils.brokerAddress,
+      "auto.offset.reset" -> "smallest"
+    )
+
+    import DirectKafkaStreamSuite._
+    ssc = new StreamingContext(sparkConf, Milliseconds(200))
+    val collector = new InputInfoCollector
+    ssc.addStreamingListener(collector)
+
+    val stream = withClue("Error creating direct stream") {
+      KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
+        ssc, kafkaParams, Set(topic))
+    }
+
+    val allReceived = new ArrayBuffer[(String, String)]
+
+    stream.foreachRDD { rdd => allReceived ++= rdd.collect() }
+    ssc.start()
+    eventually(timeout(20000.milliseconds), interval(200.milliseconds)) {
+      assert(allReceived.size === totalSent,
+        "didn't get expected number of messages, messages:\n" + allReceived.mkString("\n"))
+
+      // Calculate all the record number collected in the StreamingListener.
+      assert(collector.numRecordsSubmitted.get() === totalSent)
+      assert(collector.numRecordsStarted.get() === totalSent)
+      assert(collector.numRecordsCompleted.get() === totalSent)
+    }
+    ssc.stop()
+  }
+
   /** Get the generated offset ranges from the DirectKafkaStream */
   private def getOffsetRanges[K, V](
       kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = {
@@ -313,4 +352,22 @@ class DirectKafkaStreamSuite
 object DirectKafkaStreamSuite {
   val collectedData = new mutable.ArrayBuffer[String]()
   var total = -1L
+
+  class InputInfoCollector extends StreamingListener {
+    val numRecordsSubmitted = new AtomicLong(0L)
+    val numRecordsStarted = new AtomicLong(0L)
+    val numRecordsCompleted = new AtomicLong(0L)
+
+    override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = {
+      numRecordsSubmitted.addAndGet(batchSubmitted.batchInfo.numRecords)
+    }
+
+    override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = {
+      numRecordsStarted.addAndGet(batchStarted.batchInfo.numRecords)
+    }
+
+    override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
+      numRecordsCompleted.addAndGet(batchCompleted.batchInfo.numRecords)
+    }
+  }
 }
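Outside the test suite, an application can observe the same per-batch counts through the public StreamingListener API once this reporting is in place, much like the InputInfoCollector above. The following is a hedged sketch, not code from this commit: the application name, batch interval, broker address "localhost:9092", and reuse of the "report-test" topic are placeholder assumptions; only the listener callback and the batchInfo.numRecords field come from the diff above, and it assumes the Spark 1.x Kafka 0.8 direct API shown there.

import java.util.concurrent.atomic.AtomicLong

import kafka.serializer.StringDecoder

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kafka.KafkaUtils
import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerBatchCompleted}

object DirectKafkaRecordCountApp {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("DirectKafkaRecordCountApp")
    val ssc = new StreamingContext(conf, Seconds(2))

    // Accumulate the record count of every completed batch, mirroring the
    // InputInfoCollector used in DirectKafkaStreamSuite.
    val totalRecords = new AtomicLong(0L)
    ssc.addStreamingListener(new StreamingListener {
      override def onBatchCompleted(batch: StreamingListenerBatchCompleted): Unit = {
        totalRecords.addAndGet(batch.batchInfo.numRecords)
      }
    })

    // Placeholder broker address and topic; replace with real values.
    val kafkaParams = Map("metadata.broker.list" -> "localhost:9092")
    val stream = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
      ssc, kafkaParams, Set("report-test"))

    // A trivial output action so each batch actually runs.
    stream.foreachRDD { rdd =>
      println(s"batch size = ${rdd.count()}, total reported so far = ${totalRecords.get()}")
    }

    ssc.start()
    ssc.awaitTermination()
  }
}

With the reporting added in DirectKafkaInputDStream, batchInfo.numRecords now reflects the records consumed by the direct Kafka stream instead of staying at zero.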

streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala

Lines changed: 2 additions & 2 deletions
@@ -192,8 +192,8 @@ private[streaming] class StreamingJobProgressListener(ssc: StreamingContext)
     val latestReceiverNumRecords = latestBatchInfos.map(_.receiverNumRecords)
     val streamIds = ssc.graph.getInputStreams().map(_.id)
     streamIds.map { id =>
-      val recordsOfParticularReceiver =
-        latestReceiverNumRecords.map(v => v.getOrElse(id, 0L).toDouble * 1000 / batchDuration)
+      val recordsOfParticularReceiver =
+        latestReceiverNumRecords.map(v => v.getOrElse(id, 0L).toDouble * 1000 / batchDuration)
       val distribution = Distribution(recordsOfParticularReceiver)
       (id, distribution)
     }.toMap

(The removed and added lines are textually identical in this view; the hunk only adjusts their indentation, which is not preserved here.)
