diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala
index 60255fc655e5f..778c06ea16a2b 100644
--- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala
@@ -104,6 +104,8 @@ private case class Subscribe[K, V](
       toSeek.asScala.foreach { case (topicPartition, offset) =>
           consumer.seek(topicPartition, offset)
       }
+      // we've called poll, we must pause or next poll may consume messages and set position
+      consumer.pause(consumer.assignment())
     }
 
     consumer
@@ -154,6 +156,8 @@ private case class SubscribePattern[K, V](
       toSeek.asScala.foreach { case (topicPartition, offset) =>
           consumer.seek(topicPartition, offset)
       }
+      // we've called poll, we must pause or next poll may consume messages and set position
+      consumer.pause(consumer.assignment())
     }
 
     consumer
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala
index 13827f68f2cb5..432537ebf05b2 100644
--- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala
@@ -161,12 +161,31 @@ private[spark] class DirectKafkaInputDStream[K, V](
     }
   }
 
+  /**
+   * The concern here is that poll might consume messages despite being paused,
+   * which would throw off consumer position. Fix position if this happens.
+   */
+  private def paranoidPoll(c: Consumer[K, V]): Unit = {
+    val msgs = c.poll(0)
+    if (!msgs.isEmpty) {
+      // position should be minimum offset per topicpartition
+      msgs.asScala.foldLeft(Map[TopicPartition, Long]()) { (acc, m) =>
+        val tp = new TopicPartition(m.topic, m.partition)
+        val off = acc.get(tp).map(o => Math.min(o, m.offset)).getOrElse(m.offset)
+        acc + (tp -> off)
+      }.foreach { case (tp, off) =>
+        logInfo(s"poll(0) returned messages, seeking $tp to $off to compensate")
+        c.seek(tp, off)
+      }
+    }
+  }
+
   /**
    * Returns the latest (highest) available offsets, taking new partitions into account.
    */
   protected def latestOffsets(): Map[TopicPartition, Long] = {
     val c = consumer
-    c.poll(0)
+    paranoidPoll(c)
     val parts = c.assignment().asScala
 
     // make sure new partitions are reflected in currentOffsets
@@ -223,7 +242,7 @@ private[spark] class DirectKafkaInputDStream[K, V](
 
   override def start(): Unit = {
     val c = consumer
-    c.poll(0)
+    paranoidPoll(c)
     if (currentOffsets.isEmpty) {
       currentOffsets = c.assignment().asScala.map { tp =>
         tp -> c.position(tp)
diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala
index e04f35eceb1b4..02aec43c3b34f 100644
--- a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala
+++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala
@@ -159,17 +159,19 @@ class DirectKafkaStreamSuite
   }
 
   test("pattern based subscription") {
-    val topics = List("pat1", "pat2", "advanced3")
-    // Should match 2 out of 3 topics
+    val topics = List("pat1", "pat2", "pat3", "advanced3")
+    // Should match 3 out of 4 topics
     val pat = """pat\d""".r.pattern
    val data = Map("a" -> 7, "b" -> 9)
     topics.foreach { t =>
       kafkaTestUtils.createTopic(t)
       kafkaTestUtils.sendMessages(t, data)
     }
-    val offsets = Map(new TopicPartition("pat2", 0) -> 3L)
-    // 2 matching topics, one of which starts 3 messages later
-    val expectedTotal = (data.values.sum * 2) - 3
+    val offsets = Map(
+      new TopicPartition("pat2", 0) -> 3L,
+      new TopicPartition("pat3", 0) -> 4L)
+    // 3 matching topics, two of which start a total of 7 messages later
+    val expectedTotal = (data.values.sum * 3) - 7
     val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest")
 
     ssc = new StreamingContext(sparkConf, Milliseconds(1000))
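
For context on the pattern this patch enforces: the driver uses poll(0) only to obtain a partition assignment, so any seek done afterward can be silently undone by the next poll(0) unless the assigned partitions are paused first. The sketch below (not part of the patch) is a minimal standalone illustration of that seek-then-pause sequence, plus the compensating re-seek that paranoidPoll performs; the broker address, group id, and topic name are hypothetical placeholders.

import java.util.{Collections, Properties}

import scala.collection.JavaConverters._

import org.apache.kafka.clients.consumer.KafkaConsumer
import org.apache.kafka.common.TopicPartition

object SeekThenPauseSketch {
  def main(args: Array[String]): Unit = {
    val props = new Properties()
    props.put("bootstrap.servers", "localhost:9092") // placeholder broker
    props.put("group.id", "sketch-group")            // placeholder group
    props.put("key.deserializer",
      "org.apache.kafka.common.serialization.StringDeserializer")
    props.put("value.deserializer",
      "org.apache.kafka.common.serialization.StringDeserializer")
    props.put("enable.auto.commit", "false")

    val consumer = new KafkaConsumer[String, String](props)
    consumer.subscribe(Collections.singletonList("pat1")) // placeholder topic

    // poll(0) is what triggers partition assignment, even though
    // no records are wanted yet
    consumer.poll(0)

    // seek every assigned partition to the desired starting offset...
    consumer.assignment().asScala.foreach(tp => consumer.seek(tp, 0L))

    // ...then pause, so a later metadata-only poll(0) cannot fetch records
    // and silently advance position past the offsets just seeked to
    consumer.pause(consumer.assignment())

    // defensive check mirroring paranoidPoll: if poll(0) returned records
    // anyway, seek back to the smallest returned offset per partition
    val msgs = consumer.poll(0)
    if (!msgs.isEmpty) {
      msgs.asScala.foldLeft(Map.empty[TopicPartition, Long]) { (acc, m) =>
        val tp = new TopicPartition(m.topic, m.partition)
        acc + (tp -> Math.min(acc.getOrElse(tp, Long.MaxValue), m.offset))
      }.foreach { case (tp, off) => consumer.seek(tp, off) }
    }

    consumer.assignment().asScala.foreach { tp =>
      println(s"$tp position: ${consumer.position(tp)}")
    }
    consumer.close()
  }
}

The Subscribe and SubscribePattern hunks add exactly the pause step after seeking, and paranoidPoll adds the compensating seek: pause is documented to suspend fetching from the paused partitions, but the patch treats a non-empty poll(0) result defensively anyway, which is what the name "paranoid" reflects.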