@@ -19,18 +19,23 @@ package org.apache.spark.sql.kafka010

import scala.collection.JavaConverters._

import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery
import org.apache.spark.sql.sources.v2.CustomMetrics
import org.apache.spark.sql.sources.v2.writer._
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.sources.v2.writer.streaming.{StreamWriter, SupportsCustomWriterMetrics}
import org.apache.spark.sql.types.StructType

/**
* Commit message for Kafka writes, carrying the minimum and maximum offsets written to each
* topic partition so they can be reported as custom sink metrics.
*/
case object KafkaWriterCommitMessage extends WriterCommitMessage
case class KafkaWriterCommitMessage(minOffset: KafkaSourceOffset, maxOffset: KafkaSourceOffset)
Contributor

It's kind of odd that the writer commit message includes a source offset. IMO, it would be better to define a KafkaSinkOffset or, if it can be common, something like KafkaOffsets.

Contributor Author

I would have to rename the class itself to avoid adding an additional duplicate class. I would love to do that; I'm just not sure whether it would be accepted.

extends WriterCommitMessage
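
A minimal sketch of the reviewer's suggestion, with hypothetical names not present in this PR, to illustrate a sink-side offset type that the commit message could carry instead of KafkaSourceOffset:

// Illustrative only; TopicPartition is org.apache.kafka.common.TopicPartition.
private[kafka010] case class KafkaOffsets(partitionToOffsets: Map[TopicPartition, Long])

// The commit message would then carry KafkaOffsets rather than KafkaSourceOffset:
// case class KafkaWriterCommitMessage(minOffset: KafkaOffsets, maxOffset: KafkaOffsets)
//   extends WriterCommitMessage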

/**
* A [[StreamWriter]] for Kafka writing. Responsible for generating the writer factory.
@@ -42,15 +47,25 @@ case object KafkaWriterCommitMessage extends WriterCommitMessage
*/
class KafkaStreamWriter(
topic: Option[String], producerParams: Map[String, String], schema: StructType)
extends StreamWriter {
extends StreamWriter with SupportsCustomWriterMetrics {

private var customMetrics: KafkaWriterCustomMetrics = _

validateQuery(schema.toAttributes, producerParams.toMap[String, Object].asJava, topic)

override def createWriterFactory(): KafkaStreamWriterFactory =
KafkaStreamWriterFactory(topic, producerParams, schema)

override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
customMetrics = KafkaWriterCustomMetrics(messages)
}

override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}

override def getCustomMetrics: KafkaWriterCustomMetrics = {
customMetrics
}

}

/**
@@ -102,7 +117,9 @@ class KafkaStreamDataWriter(
checkForErrors()
producer.flush()
checkForErrors()
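// producer.flush() blocks until outstanding sends complete, so the min/max offset
// accumulators should be fully populated by this point.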
KafkaWriterCommitMessage
val minOffset: KafkaSourceOffset = KafkaSourceOffset(minOffsetAccumulator.toMap)
val maxOffset: KafkaSourceOffset = KafkaSourceOffset(maxOffsetAccumulator.toMap)
KafkaWriterCommitMessage(minOffset, maxOffset)
}

def abort(): Unit = {}
@@ -116,3 +133,66 @@ class KafkaStreamDataWriter(
}
}
}

private[kafka010] case class KafkaWriterCustomMetrics(
minOffset: KafkaSourceOffset,
maxOffset: KafkaSourceOffset) extends CustomMetrics {
override def json(): String = {
val jsonVal = ("minOffset" -> parse(minOffset.json)) ~
("maxOffset" -> parse(maxOffset.json))
compact(render(jsonVal))
}

override def toString: String = json()
}
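
For illustration, a hedged sketch of what these metrics render to, assuming a single hypothetical topic "t" with partition 0 and written offsets 0 through 2:

// Illustrative usage only; "t" is a made-up topic name.
val metrics = KafkaWriterCustomMetrics(
  KafkaSourceOffset(("t", 0, 0L)),
  KafkaSourceOffset(("t", 0, 2L)))
// metrics.json() renders roughly as: {"minOffset":{"t":{"0":0}},"maxOffset":{"t":{"0":2}}}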

private[kafka010] object KafkaWriterCustomMetrics {

import Math.{min, max}

def apply(messages: Array[WriterCommitMessage]): KafkaWriterCustomMetrics = {
val minMax = collate(messages)
KafkaWriterCustomMetrics(minMax._1, minMax._2)
}

private def collate(messages: Array[WriterCommitMessage]):
Contributor

It would be good to leave a comment on what this does. It seems to be computing the min/max offset per partition? If so, choosing an apt name for the function would make it clearer.

Contributor Author

Thanks, I will rename to something with minMax.

(KafkaSourceOffset, KafkaSourceOffset) = {
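// Computes, across all commit messages, the per-partition minimum and maximum offsets written.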

messages.headOption.flatMap {
case x: KafkaWriterCommitMessage =>
val lower = messages.map(_.asInstanceOf[KafkaWriterCommitMessage])
.map(_.minOffset).reduce(collateLower)
val higher = messages.map(_.asInstanceOf[KafkaWriterCommitMessage])
.map(_.maxOffset).reduce(collateHigher)
Some((lower, higher))
case _ => throw new IllegalArgumentException()
}.getOrElse((KafkaSourceOffset(), KafkaSourceOffset()))
}

private def collateHigher(o1: KafkaSourceOffset, o2: KafkaSourceOffset): KafkaSourceOffset = {
collate(o1, o2, max)
}

private def collateLower(o1: KafkaSourceOffset, o2: KafkaSourceOffset): KafkaSourceOffset = {
collate(o1, o2, min)
}

private def collate(
o1: KafkaSourceOffset,
o2: KafkaSourceOffset,
collator: (Long, Long) => Long): KafkaSourceOffset = {
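// Merges two offset maps key by key; where a partition appears in both, the supplied
// collator (min or max) decides which offset to keep.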
val thisOffsets = o1.partitionToOffsets
val thatOffsets = o2.partitionToOffsets
val collated = (thisOffsets.keySet ++ thatOffsets.keySet)
.map(key =>
if (!thatOffsets.contains(key)) {
key -> thisOffsets(key)
} else if (!thisOffsets.contains(key)) {
key -> thatOffsets(key)
} else {
key -> collator(thisOffsets(key), thatOffsets(key))
}
).toMap
new KafkaSourceOffset(collated)
}
}
@@ -18,8 +18,10 @@
package org.apache.spark.sql.kafka010

import java.{util => ju}
import java.util.concurrent.ConcurrentHashMap

import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}
import org.apache.kafka.common.TopicPartition

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection}
@@ -61,12 +63,30 @@ private[kafka010] class KafkaWriteTask(
private[kafka010] abstract class KafkaRowWriter(
inputSchema: Seq[Attribute], topic: Option[String]) {

import scala.collection.JavaConverters._
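
// Per-partition offset bounds observed in producer callbacks; kept in concurrent maps
// because callbacks may run concurrently for different partitions (see the discussion below).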

protected val minOffsetAccumulator: collection.concurrent.Map[TopicPartition, Long] =
new ConcurrentHashMap[TopicPartition, Long]().asScala
Contributor

Why is this a concurrent map?

Contributor Author

This map is accessed from callbacks concurrently with respect to different partitions. This can be seen from the call hierarchy and the docs of Kafka's send method.


protected val maxOffsetAccumulator: collection.concurrent.Map[TopicPartition, Long] =
new ConcurrentHashMap[TopicPartition, Long]().asScala

// used to synchronize with Kafka callbacks
@volatile protected var failedWrite: Exception = _
protected val projection = createProjection

private val callback = new Callback() {
import Math.{min, max}

override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = {
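// On successful sends, track the smallest and largest offset observed per topic
// partition so they can be reported through the writer commit message.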
if (recordMetadata != null) {
val topicPartition = new TopicPartition(recordMetadata.topic(), recordMetadata.partition())
val next = recordMetadata.offset()
val currentMin = minOffsetAccumulator.getOrElse(topicPartition, next)
minOffsetAccumulator.put(topicPartition, min(currentMin, next))
val currentMax = maxOffsetAccumulator.getOrElse(topicPartition, next)
maxOffsetAccumulator.put(topicPartition, max(currentMax, next))
}
if (failedWrite == null && e != null) {
failedWrite = e
}
@@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicInteger

import org.apache.kafka.clients.producer.ProducerConfig
import org.apache.kafka.common.serialization.ByteArraySerializer
import org.scalatest.time.Span
import org.scalatest.time.SpanSugar._

import org.apache.spark.SparkException
@@ -38,7 +39,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext with KafkaTest {

protected var testUtils: KafkaTestUtils = _

override val streamingTimeout = 30.seconds
override val streamingTimeout: Span = 30.seconds

override def beforeAll(): Unit = {
super.beforeAll()
@@ -229,6 +230,30 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext with KafkaTest {
}
}

test("streaming - sink progress is produced") {
/* ensure sink progress is correctly produced. */
val input = MemoryStream[String]
val topic = newTopic()
testUtils.createTopic(topic)

val writer = createKafkaWriter(
input.toDF(),
withTopic = Some(topic),
withOutputMode = Some(OutputMode.Update()))()

try {
input.addData("1", "2", "3")
failAfter(streamingTimeout) {
writer.processAllAvailable()
}
val topicName = topic.toString
val expected = "{\"minOffset\":{\"" + topicName + "\":{\"0\":0}}," +
"\"maxOffset\":{\"" + topicName + "\":{\"0\":2}}}"
assert(writer.lastProgress.sink.customMetrics == expected)
} finally {
writer.stop()
}
}

test("streaming - write data with bad schema") {
val input = MemoryStream[String]
@@ -417,7 +442,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext with KafkaTest {
var stream: DataStreamWriter[Row] = null
withTempDir { checkpointDir =>
var df = input.toDF()
if (withSelectExpr.length > 0) {
if (withSelectExpr.nonEmpty) {
df = df.selectExpr(withSelectExpr: _*)
}
stream = df.writeStream
@@ -0,0 +1,37 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.kafka010

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage

class KafkaWriterCustomMetricsSuite extends SparkFunSuite {

test("collate messages") {
val minOffset1 = KafkaSourceOffset(("topic1", 1, 2), ("topic1", 2, 3))
val maxOffset1 = KafkaSourceOffset(("topic1", 1, 2), ("topic1", 2, 5))
val minOffset2 = KafkaSourceOffset(("topic1", 1, 0), ("topic1", 2, 3))
val maxOffset2 = KafkaSourceOffset(("topic1", 1, 0), ("topic1", 2, 7))
val messages: Array[WriterCommitMessage] = Array(
KafkaWriterCommitMessage(minOffset1, maxOffset1),
KafkaWriterCommitMessage(minOffset2, maxOffset2))
val metrics = KafkaWriterCustomMetrics(messages)
assert(metrics.minOffset === KafkaSourceOffset(("topic1", 1, 0), ("topic1", 2, 3)))
assert(metrics.maxOffset === KafkaSourceOffset(("topic1", 1, 2), ("topic1", 2, 7)))
}
}