@@ -19,18 +19,23 @@ package org.apache.spark.sql.kafka010
 
 import scala.collection.JavaConverters._
 
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery
+import org.apache.spark.sql.sources.v2.CustomMetrics
 import org.apache.spark.sql.sources.v2.writer._
-import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
+import org.apache.spark.sql.sources.v2.writer.streaming.{StreamWriter, SupportsCustomWriterMetrics}
 import org.apache.spark.sql.types.StructType
 
 /**
  * Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we
  * don't need to really send one.
  */
-case object KafkaWriterCommitMessage extends WriterCommitMessage
+case class KafkaWriterCommitMessage(minOffset: KafkaSourceOffset, maxOffset: KafkaSourceOffset)
+  extends WriterCommitMessage
 
 /**
  * A [[StreamWriter]] for Kafka writing. Responsible for generating the writer factory.
@@ -42,15 +47,25 @@ case object KafkaWriterCommitMessage extends WriterCommitMessage
  */
 class KafkaStreamWriter(
     topic: Option[String], producerParams: Map[String, String], schema: StructType)
-  extends StreamWriter {
+  extends StreamWriter with SupportsCustomWriterMetrics {
+
+  private var customMetrics: KafkaWriterCustomMetrics = _
 
   validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic)
 
   override def createWriterFactory(): KafkaStreamWriterFactory =
     KafkaStreamWriterFactory(topic, producerParams, schema)
 
-  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
+    customMetrics = KafkaWriterCustomMetrics(messages)
+  }
+
   override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+
+  override def getCustomMetrics: KafkaWriterCustomMetrics = {
+    customMetrics
+  }
+
 }
 
 /**
@@ -102,7 +117,9 @@ class KafkaStreamDataWriter(
     checkForErrors()
     producer.flush()
     checkForErrors()
-    KafkaWriterCommitMessage
+    val minOffset: KafkaSourceOffset = KafkaSourceOffset(minOffsetAccumulator.toMap)
+    val maxOffset: KafkaSourceOffset = KafkaSourceOffset(maxOffsetAccumulator.toMap)
+    KafkaWriterCommitMessage(minOffset, maxOffset)
   }
 
   def abort(): Unit = {}
@@ -116,3 +133,66 @@ class KafkaStreamDataWriter(
     }
   }
 }
+
+private[kafka010] case class KafkaWriterCustomMetrics(
+    minOffset: KafkaSourceOffset,
+    maxOffset: KafkaSourceOffset) extends CustomMetrics {
+  override def json(): String = {
+    val jsonVal = ("minOffset" -> parse(minOffset.json)) ~
+      ("maxOffset" -> parse(maxOffset.json))
+    compact(render(jsonVal))
+  }
+
+  override def toString: String = json()
+}
+
+private[kafka010] object KafkaWriterCustomMetrics {
+
+  import Math.{min, max}
+
+  def apply(messages: Array[WriterCommitMessage]): KafkaWriterCustomMetrics = {
+    val minMax = collate(messages)
+    KafkaWriterCustomMetrics(minMax._1, minMax._2)
+  }
+
+  private def collate(messages: Array[WriterCommitMessage]):
+      (KafkaSourceOffset, KafkaSourceOffset) = {
+
+    messages.headOption.flatMap {
+      case x: KafkaWriterCommitMessage =>
+        val lower = messages.map(_.asInstanceOf[KafkaWriterCommitMessage])
+          .map(_.minOffset).reduce(collateLower)
+        val higher = messages.map(_.asInstanceOf[KafkaWriterCommitMessage])
+          .map(_.maxOffset).reduce(collateHigher)
+        Some((lower, higher))
+      case _ => throw new IllegalArgumentException()
+    }.getOrElse((KafkaSourceOffset(), KafkaSourceOffset()))
+  }
+
+  private def collateHigher(o1: KafkaSourceOffset, o2: KafkaSourceOffset): KafkaSourceOffset = {
+    collate(o1, o2, max)
+  }
+
+  private def collateLower(o1: KafkaSourceOffset, o2: KafkaSourceOffset): KafkaSourceOffset = {
+    collate(o1, o2, min)
+  }
+
+  private def collate(
+      o1: KafkaSourceOffset,
+      o2: KafkaSourceOffset,
+      collator: (Long, Long) => Long): KafkaSourceOffset = {
+    val thisOffsets = o1.partitionToOffsets
+    val thatOffsets = o2.partitionToOffsets
+    val collated = (thisOffsets.keySet ++ thatOffsets.keySet)
+      .map(key =>
+        if (!thatOffsets.contains(key)) {
+          key -> thisOffsets(key)
+        } else if (!thisOffsets.contains(key)) {
+          key -> thatOffsets(key)
+        } else {
+          key -> collator(thisOffsets(key), thatOffsets(key))
+        }
+      ).toMap
+    new KafkaSourceOffset(collated)
+  }
+}
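
For readers skimming the diff: the three-argument collate above takes the union of the partitions seen by two KafkaSourceOffsets, keeps the offset of any partition reported by only one side, and applies the collator (min or max) where both sides report one, so the companion apply ends up with the lowest and highest offset written per partition across all tasks in the epoch. The snippet below is a minimal, self-contained sketch of that merge, not part of the patch: CollateSketch, the PartitionOffsets alias, the string keys, and the sample offsets are made up for illustration, with plain Map[String, Long] standing in for KafkaSourceOffset.partitionToOffsets.

// Simplified stand-in for the per-partition merge done by
// KafkaWriterCustomMetrics.collate(o1, o2, collator).
object CollateSketch {
  type PartitionOffsets = Map[String, Long]

  def collate(
      a: PartitionOffsets,
      b: PartitionOffsets,
      collator: (Long, Long) => Long): PartitionOffsets =
    (a.keySet ++ b.keySet).map { key =>
      // A partition seen by only one writer keeps its offset; a partition
      // seen by both is merged with the collator (min or max).
      key -> ((a.get(key), b.get(key)) match {
        case (Some(x), Some(y)) => collator(x, y)
        case (Some(x), None) => x
        case (None, Some(y)) => y
        case (None, None) => 0L // unreachable: key came from one of the maps
      })
    }.toMap

  def main(args: Array[String]): Unit = {
    // Two tasks wrote to overlapping partitions; the offsets are made up.
    val task1 = Map("topic-0" -> 5L, "topic-1" -> 7L)
    val task2 = Map("topic-0" -> 9L, "topic-2" -> 3L)
    // Per-partition minimum: topic-0 -> 5, topic-1 -> 7, topic-2 -> 3
    println(collate(task1, task2, math.min))
    // Per-partition maximum: topic-0 -> 9, topic-1 -> 7, topic-2 -> 3
    println(collate(task1, task2, math.max))
  }
}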