diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala
index 5f0802b466039..9d9bded8acaf4 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamWriter.scala
@@ -19,18 +19,23 @@ package org.apache.spark.sql.kafka010
 
 import scala.collection.JavaConverters._
 
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery
+import org.apache.spark.sql.sources.v2.CustomMetrics
 import org.apache.spark.sql.sources.v2.writer._
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamWriter, SupportsCustomWriterMetrics}
 import org.apache.spark.sql.types.StructType
 
 /**
  * Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we
  * don't need to really send one.
  */
-case object KafkaWriterCommitMessage extends WriterCommitMessage
+case class KafkaWriterCommitMessage(minOffset: KafkaSourceOffset, maxOffset: KafkaSourceOffset)
+  extends WriterCommitMessage
 
 /**
  * A [[StreamWriter]] for Kafka writing. Responsible for generating the writer factory.
@@ -42,15 +47,25 @@ case object KafkaWriterCommitMessage extends WriterCommitMessage
  */
 class KafkaStreamWriter(
     topic: Option[String], producerParams: Map[String, String], schema: StructType)
-  extends StreamWriter {
+  extends StreamWriter with SupportsCustomWriterMetrics {
+
+  private var customMetrics: KafkaWriterCustomMetrics = _
 
   validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic)
 
   override def createWriterFactory(): KafkaStreamWriterFactory =
     KafkaStreamWriterFactory(topic, producerParams, schema)
 
-  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
+    customMetrics = KafkaWriterCustomMetrics(messages)
+  }
+
   override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+
+  override def getCustomMetrics: KafkaWriterCustomMetrics = {
+    customMetrics
+  }
+
 }
 
 /**
@@ -102,7 +117,9 @@ class KafkaStreamDataWriter(
     checkForErrors()
     producer.flush()
     checkForErrors()
-    KafkaWriterCommitMessage
+    val minOffset: KafkaSourceOffset = KafkaSourceOffset(minOffsetAccumulator.toMap)
+    val maxOffset: KafkaSourceOffset = KafkaSourceOffset(maxOffsetAccumulator.toMap)
+    KafkaWriterCommitMessage(minOffset, maxOffset)
   }
 
   def abort(): Unit = {}
@@ -116,3 +133,66 @@ class KafkaStreamDataWriter(
     }
   }
 }
+
+private[kafka010] case class KafkaWriterCustomMetrics(
+    minOffset: KafkaSourceOffset,
+    maxOffset: KafkaSourceOffset) extends CustomMetrics {
+  override def json(): String = {
+    val jsonVal = ("minOffset" -> parse(minOffset.json)) ~
+      ("maxOffset" -> parse(maxOffset.json))
+    compact(render(jsonVal))
+  }
+
+  override def toString: String = json()
+}
+
+private[kafka010] object KafkaWriterCustomMetrics {
+
+  import Math.{min, max}
+
+  def apply(messages: Array[WriterCommitMessage]): KafkaWriterCustomMetrics = {
+    val minMax = collate(messages)
+    KafkaWriterCustomMetrics(minMax._1, minMax._2)
+  }
+
+  private def collate(messages: Array[WriterCommitMessage]):
+      (KafkaSourceOffset, KafkaSourceOffset) = {
+
+    messages.headOption.flatMap {
+      case x: KafkaWriterCommitMessage =>
+        val lower = messages.map(_.asInstanceOf[KafkaWriterCommitMessage])
+          .map(_.minOffset).reduce(collateLower)
+        val higher = messages.map(_.asInstanceOf[KafkaWriterCommitMessage])
+          .map(_.maxOffset).reduce(collateHigher)
+        Some((lower, higher))
+      case _ => throw new IllegalArgumentException()
+    }.getOrElse((KafkaSourceOffset(), KafkaSourceOffset()))
+  }
+
+  private def collateHigher(o1: KafkaSourceOffset, o2: KafkaSourceOffset): KafkaSourceOffset = {
+    collate(o1, o2, max)
+  }
+
+  private def collateLower(o1: KafkaSourceOffset, o2: KafkaSourceOffset): KafkaSourceOffset = {
+    collate(o1, o2, min)
+  }
+
+  private def collate(
+      o1: KafkaSourceOffset,
+      o2: KafkaSourceOffset,
+      collator: (Long, Long) => Long): KafkaSourceOffset = {
+    val thisOffsets = o1.partitionToOffsets
+    val thatOffsets = o2.partitionToOffsets
+    val collated = (thisOffsets.keySet ++ thatOffsets.keySet)
+      .map(key =>
+        if (!thatOffsets.contains(key)) {
+          key -> thisOffsets(key)
+        } else if (!thisOffsets.contains(key)) {
+          key -> thatOffsets(key)
+        } else {
+          key -> collator(thisOffsets(key), thatOffsets(key))
+        }
+      ).toMap
+    new KafkaSourceOffset(collated)
+  }
+}
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
index 041fac7717635..12cf3706f450b 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
@@ -18,8 +18,10 @@ package org.apache.spark.sql.kafka010
 
 import java.{util => ju}
+import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}
+import org.apache.kafka.common.TopicPartition
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection}
@@ -61,12 +63,30 @@ private[kafka010] class KafkaWriteTask(
 private[kafka010] abstract class KafkaRowWriter(
     inputSchema: Seq[Attribute], topic: Option[String]) {
 
+  import scala.collection.JavaConverters._
+
+  protected val minOffsetAccumulator: collection.concurrent.Map[TopicPartition, Long] =
+    new ConcurrentHashMap[TopicPartition, Long]().asScala
+
+  protected val maxOffsetAccumulator: collection.concurrent.Map[TopicPartition, Long] =
+    new ConcurrentHashMap[TopicPartition, Long]().asScala
+
   // used to synchronize with Kafka callbacks
   @volatile protected var failedWrite: Exception = _
   protected val projection = createProjection
 
   private val callback = new Callback() {
+    import Math.{min, max}
+
     override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = {
+      if (recordMetadata != null) {
+        val topicPartition = new TopicPartition(recordMetadata.topic(), recordMetadata.partition())
+        val next = recordMetadata.offset()
+        val currentMin = minOffsetAccumulator.getOrElse(topicPartition, next)
+        minOffsetAccumulator.put(topicPartition, min(currentMin, next))
+        val currentMax = maxOffsetAccumulator.getOrElse(topicPartition, next)
+        maxOffsetAccumulator.put(topicPartition, max(currentMax, next))
+      }
       if (failedWrite == null && e != null) {
         failedWrite = e
       }
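A minimal standalone sketch (not part of the patch) of the bookkeeping the producer callback above performs: every acknowledged record folds its offset into per-partition minimum and maximum maps, which KafkaStreamDataWriter.commit() then wraps into a KafkaWriterCommitMessage. Plain (topic, partition) tuples and an immutable fold stand in for Kafka's TopicPartition and the concurrent maps used in the real code.

object OffsetRangeSketch {
  // Stand-in key for org.apache.kafka.common.TopicPartition.
  type Partition = (String, Int)

  // Fold acknowledged (partition, offset) pairs into per-partition min/max maps,
  // mirroring what onCompletion does one record at a time.
  def offsetRanges(acks: Seq[(Partition, Long)]): (Map[Partition, Long], Map[Partition, Long]) =
    acks.foldLeft((Map.empty[Partition, Long], Map.empty[Partition, Long])) {
      case ((mins, maxs), (tp, offset)) =>
        (mins.updated(tp, math.min(mins.getOrElse(tp, offset), offset)),
          maxs.updated(tp, math.max(maxs.getOrElse(tp, offset), offset)))
    }

  def main(args: Array[String]): Unit = {
    // Three records acknowledged on partition 0 of "topic1" at offsets 0, 1 and 2.
    val acks = Seq((("topic1", 0), 0L), (("topic1", 0), 1L), (("topic1", 0), 2L))
    val (mins, maxs) = offsetRanges(acks)
    println(mins) // Map((topic1,0) -> 0)
    println(maxs) // Map((topic1,0) -> 2)
  }
}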
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
index a2213e024bd98..72b1d95afcbaf 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
@@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.kafka.clients.producer.ProducerConfig
 import org.apache.kafka.common.serialization.ByteArraySerializer
+import org.scalatest.time.Span
 import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.SparkException
@@ -38,7 +39,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext with KafkaTest {
 
   protected var testUtils: KafkaTestUtils = _
 
-  override val streamingTimeout = 30.seconds
+  override val streamingTimeout: Span = 30.seconds
 
   override def beforeAll(): Unit = {
     super.beforeAll()
@@ -229,6 +230,30 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext with KafkaTest {
     }
   }
 
+  test("streaming - sink progress is produced") {
+    /* ensure sink progress is correctly produced. */
+    val input = MemoryStream[String]
+    val topic = newTopic()
+    testUtils.createTopic(topic)
+
+    val writer = createKafkaWriter(
+      input.toDF(),
+      withTopic = Some(topic),
+      withOutputMode = Some(OutputMode.Update()))()
+
+    try {
+      input.addData("1", "2", "3")
+      failAfter(streamingTimeout) {
+        writer.processAllAvailable()
+      }
+      val topicName = topic.toString
+      val expected = "{\"minOffset\":{\"" + topicName + "\":{\"0\":0}}," +
+        "\"maxOffset\":{\"" + topicName + "\":{\"0\":2}}}"
+      assert(writer.lastProgress.sink.customMetrics == expected)
+    } finally {
+      writer.stop()
+    }
+  }
 
   test("streaming - write data with bad schema") {
     val input = MemoryStream[String]
@@ -417,7 +442,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext with KafkaTest {
     var stream: DataStreamWriter[Row] = null
     withTempDir { checkpointDir =>
       var df = input.toDF()
-      if (withSelectExpr.length > 0) {
+      if (withSelectExpr.nonEmpty) {
         df = df.selectExpr(withSelectExpr: _*)
       }
       stream = df.writeStream
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaWriterCustomMetricsSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaWriterCustomMetricsSuite.scala
new file mode 100644
index 0000000000000..11a3680a71d58
--- /dev/null
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaWriterCustomMetricsSuite.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage
+
+class KafkaWriterCustomMetricsSuite extends SparkFunSuite {
+
+  test("collate messages") {
+    val minOffset1 = KafkaSourceOffset(("topic1", 1, 2), ("topic1", 2, 3))
+    val maxOffset1 = KafkaSourceOffset(("topic1", 1, 2), ("topic1", 2, 5))
+    val minOffset2 = KafkaSourceOffset(("topic1", 1, 0), ("topic1", 2, 3))
+    val maxOffset2 = KafkaSourceOffset(("topic1", 1, 0), ("topic1", 2, 7))
+    val messages: Array[WriterCommitMessage] = Array(
+      KafkaWriterCommitMessage(minOffset1, maxOffset1),
+      KafkaWriterCommitMessage(minOffset2, maxOffset2))
+    val metrics = KafkaWriterCustomMetrics(messages)
+    assert(metrics.minOffset === KafkaSourceOffset(("topic1", 1, 0), ("topic1", 2, 3)))
+    assert(metrics.maxOffset === KafkaSourceOffset(("topic1", 1, 2), ("topic1", 2, 7)))
+  }
+}
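A second standalone sketch (again not part of the patch): the JSON shape that KafkaWriterCustomMetrics.json() reports and that the "streaming - sink progress is produced" test asserts on, rebuilt with the same json4s DSL the patch imports. The topic name and offsets below are illustrative; the nested {"topic":{"partition":offset}} layout matches the expected string in that test.

import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

object CustomMetricsJsonSketch {
  def main(args: Array[String]): Unit = {
    // Stand-ins for minOffset.json / maxOffset.json; "topic-1" and the offsets are made up.
    val minOffsetJson = """{"topic-1":{"0":0}}"""
    val maxOffsetJson = """{"topic-1":{"0":2}}"""
    // Same combination KafkaWriterCustomMetrics.json() performs on the two offsets.
    val combined = ("minOffset" -> parse(minOffsetJson)) ~
      ("maxOffset" -> parse(maxOffsetJson))
    // Prints {"minOffset":{"topic-1":{"0":0}},"maxOffset":{"topic-1":{"0":2}}}
    println(compact(render(combined)))
  }
}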