[SPARK-19968][SS] Use a cached instance of KafkaProducer instead of creating one every batch.
#17308
Changes from 3 commits

@@ -0,0 +1,139 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.kafka010

import java.{util => ju}

import scala.collection.immutable.SortedMap
import scala.collection.mutable

import org.apache.kafka.clients.producer.KafkaProducer

import org.apache.spark.internal.Logging

private[kafka010] object CachedKafkaProducer extends Logging {

  private type Producer = KafkaProducer[Array[Byte], Array[Byte]]

  private val cacheMap = new mutable.HashMap[String, Producer]()

  private def createKafkaProducer(
      producerConfiguration: ju.Map[String, Object]): Producer = {
    val uid = producerConfiguration.get(CanonicalizeKafkaParams.sparkKafkaParamsUniqueId)
    val kafkaProducer: Producer = new Producer(producerConfiguration)
    cacheMap.put(uid.toString, kafkaProducer)
    log.debug(s"Created a new instance of KafkaProducer for $producerConfiguration.")
    kafkaProducer
  }

  private def getUniqueId(kafkaParams: ju.Map[String, Object]): String = {
    val uid = kafkaParams.get(CanonicalizeKafkaParams.sparkKafkaParamsUniqueId)
    assert(uid != null, s"KafkaParams($kafkaParams) not canonicalized")
    uid.toString
  }

  /**
   * Get a cached KafkaProducer for a given configuration. If a matching KafkaProducer doesn't
   * exist, a new KafkaProducer will be created. KafkaProducer is thread safe, so it is best to
   * keep one instance per specified kafkaParams.
   */
  private[kafka010] def getOrCreate(
      kafkaParams: ju.Map[String, Object]): Producer = synchronized {
    val params = if (!CanonicalizeKafkaParams.isCanonicalized(kafkaParams)) {
      CanonicalizeKafkaParams.computeUniqueCanonicalForm(kafkaParams)
    } else kafkaParams
    val uid = getUniqueId(params)
    cacheMap.getOrElse(uid.toString, createKafkaProducer(params))
  }

  private[kafka010] def close(kafkaParams: ju.Map[String, Object]): Unit = {
    val params = if (!CanonicalizeKafkaParams.isCanonicalized(kafkaParams)) {
      CanonicalizeKafkaParams.computeUniqueCanonicalForm(kafkaParams)
    } else kafkaParams
    val uid = getUniqueId(params)
    val producer: Option[Producer] = cacheMap.remove(uid)
    if (producer.isDefined) {
      log.info(s"Closing the KafkaProducer with config: $kafkaParams")
      producer.foreach(_.close())
    } else {
      log.warn(s"No KafkaProducer found in cache for $kafkaParams.")
    }
  }

  // Intended for testing purpose only.
  private def clear(): Unit = {
    cacheMap.foreach(x => x._2.close())
    cacheMap.clear()
  }
}

/**
 * Canonicalize kafka params, i.e. append a unique internal id to the kafka params if one does
 * not already exist. This is done to ensure we have only one set of kafka parameters
 * associated with a unique id.
 */
private[kafka010] object CanonicalizeKafkaParams extends Logging {

  import scala.collection.JavaConverters._

  private val registryMap = mutable.HashMap[String, String]()

  private[kafka010] val sparkKafkaParamsUniqueId: String =
    "spark.internal.sql.kafka.params.uniqueId"

  private def generateRandomUUID(kafkaParams: String): String = {
    val uuid = ju.UUID.randomUUID().toString
    logDebug(s"Generating a new unique id: $uuid for kafka params: $kafkaParams")
    registryMap.put(kafkaParams, uuid)
    uuid
  }

  private[kafka010] def isCanonicalized(kafkaParams: ju.Map[String, Object]): Boolean = {
    if (kafkaParams.get(sparkKafkaParamsUniqueId) != null) {
      true
    } else {
      false
    }
  }

  private[kafka010] def computeUniqueCanonicalForm(
      kafkaParams: ju.Map[String, Object]): ju.Map[String, Object] = synchronized {
    if (isCanonicalized(kafkaParams)) {
      logWarning(s"A unique id, $sparkKafkaParamsUniqueId ->" +
        s" ${kafkaParams.get(sparkKafkaParamsUniqueId)}," +
        s" already exists in kafka params; returning kafka params $kafkaParams as is.")
      kafkaParams
    } else {
      val sortedMap = SortedMap.empty[String, Object] ++ kafkaParams.asScala
      val stringRepresentation: String = sortedMap.mkString("\n")
      val uuid =
        registryMap.getOrElse(stringRepresentation, generateRandomUUID(stringRepresentation))
      val newMap = new ju.HashMap[String, Object](kafkaParams)
      newMap.put(sparkKafkaParamsUniqueId, uuid)
      newMap
    }
  }

  // For testing purpose only.
  private[kafka010] def clear(): Unit = {
    registryMap.clear()
  }

Contributor: nit: no need for newline
}
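For orientation, a minimal usage sketch of the cache introduced above (not part of the diff). It assumes code in the org.apache.spark.sql.kafka010 package, since both objects are private[kafka010]; the broker address and config values below are placeholders.

```scala
import java.{util => ju}

import org.apache.kafka.common.serialization.ByteArraySerializer

// Build a producer config the way a write task would (values are placeholders).
val producerConfiguration = new ju.HashMap[String, Object]()
producerConfiguration.put("bootstrap.servers", "localhost:9092")
producerConfiguration.put("key.serializer", classOf[ByteArraySerializer].getName)
producerConfiguration.put("value.serializer", classOf[ByteArraySerializer].getName)

// The first call canonicalizes the params, creates a KafkaProducer and caches it;
// later calls with the same params return the same cached instance.
val producer = CachedKafkaProducer.getOrCreate(producerConfiguration)
assert(producer eq CachedKafkaProducer.getOrCreate(producerConfiguration))

// Removes the cached instance for these params (if any) and closes it.
CachedKafkaProducer.close(producerConfiguration)
```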

@@ -22,6 +22,7 @@ import java.{util => ju}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.execution.streaming.Sink
+import org.apache.spark.util.Utils

private[kafka010] class KafkaSink(
    sqlContext: SQLContext,
@@ -40,4 +41,8 @@ private[kafka010] class KafkaSink(
      latestBatchId = batchId
    }
  }
+
+  override def stop(): Unit = {
+    KafkaWriter.close(sqlContext.sparkContext, executorKafkaParams)
+  }
}

@@ -70,13 +70,13 @@ import org.apache.spark.unsafe.types.UTF8String
 * and not use wrong broker addresses.
 */
private[kafka010] class KafkaSource(
    sqlContext: SQLContext,
    kafkaReader: KafkaOffsetReader,
    executorKafkaParams: ju.Map[String, Object],
    sourceOptions: Map[String, String],
    metadataPath: String,
    startingOffsets: KafkaOffsetRangeLimit,
    failOnDataLoss: Boolean)
  extends Source with Logging {

  private val sc = sqlContext.sparkContext

@@ -19,8 +19,7 @@ package org.apache.spark.sql.kafka010

import java.{util => ju}

-import org.apache.kafka.clients.producer.{KafkaProducer, _}
-import org.apache.kafka.common.serialization.ByteArraySerializer
+import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection}

@@ -44,7 +43,7 @@ private[kafka010] class KafkaWriteTask(
   * Writes key value data out to topics.
   */
  def execute(iterator: Iterator[InternalRow]): Unit = {
-    producer = new KafkaProducer[Array[Byte], Array[Byte]](producerConfiguration)
+    producer = CachedKafkaProducer.getOrCreate(producerConfiguration)
    while (iterator.hasNext && failedWrite == null) {
      val currentRow = iterator.next()
      val projectedRow = projection(currentRow)

@@ -68,11 +67,10 @@
  }

  def close(): Unit = {
+    checkForErrors()
    if (producer != null) {
-      checkForErrors
-      producer.close()
-      checkForErrors
-      producer = null
Member: nit: please keep
+      producer.flush()
+      checkForErrors()
    }
  }

@@ -88,7 +86,7 @@
      case t =>
        throw new IllegalStateException(s"${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " +
          s"attribute unsupported type $t. ${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " +
-          s"must be a ${StringType}")
+          "must be a StringType")
    }
    val keyExpression = inputSchema.find(_.name == KafkaWriter.KEY_ATTRIBUTE_NAME)
      .getOrElse(Literal(null, BinaryType))

@@ -100,7 +98,7 @@
    }
    val valueExpression = inputSchema
      .find(_.name == KafkaWriter.VALUE_ATTRIBUTE_NAME).getOrElse(
-        throw new IllegalStateException(s"Required attribute " +
+        throw new IllegalStateException("Required attribute " +
          s"'${KafkaWriter.VALUE_ATTRIBUTE_NAME}' not found")
      )
    valueExpression.dataType match {

@@ -114,7 +112,7 @@
        Cast(valueExpression, BinaryType)), inputSchema)
  }

-  private def checkForErrors: Unit = {
+  private def checkForErrors(): Unit = {
    if (failedWrite != null) {
      throw failedWrite
    }

@@ -19,9 +19,9 @@ package org.apache.spark.sql.kafka010

import java.{util => ju}

+import org.apache.spark.SparkContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
import org.apache.spark.sql.types.{BinaryType, StringType}

@@ -94,4 +94,10 @@ private[kafka010] object KafkaWriter extends Logging {
      }
    }
  }
+
+  def close(sc: SparkContext, kafkaParams: ju.Map[String, Object]): Unit = {
+    sc.parallelize(1 to 10000).foreachPartition { iter =>
+      CachedKafkaProducer.close(kafkaParams)
+    }
+  }
}
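A note on the close() helper above, with a sketch: the producers cached by CachedKafkaProducer.getOrCreate are created lazily inside executor JVMs by KafkaWriteTask, so the driver cannot close them directly; running a small throwaway job is a best-effort way to execute the cleanup where the producers live, since task placement is up to the scheduler. The sketch below only illustrates that mechanism — it assumes code in the org.apache.spark.sql.kafka010 package, and the broker address and config values are placeholders, not part of the diff.

```scala
import java.{util => ju}

import org.apache.kafka.common.serialization.ByteArraySerializer

import org.apache.spark.{SparkConf, SparkContext}

// Assumed setup for the sketch; in the patch, `sc` comes from sqlContext.sparkContext.
val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("close-sketch"))

// The same (serializable) params the writer tasks used; values here are placeholders.
val kafkaParams = new ju.HashMap[String, Object]()
kafkaParams.put("bootstrap.servers", "localhost:9092")
kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName)
kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName)

// Each task that lands on an executor removes and closes that executor's
// cached producer for these params, if one exists. An executor that receives
// no task keeps its cached producer until JVM exit.
sc.parallelize(1 to 10000).foreachPartition { _ =>
  CachedKafkaProducer.close(kafkaParams)
}
```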

@@ -0,0 +1,81 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.kafka010

import java.{util => ju}

import scala.collection.mutable

import org.apache.kafka.clients.producer.KafkaProducer
import org.apache.kafka.common.serialization.ByteArraySerializer
import org.scalatest.PrivateMethodTester

import org.apache.spark.sql.test.SharedSQLContext

class CachedKafkaProducerSuite extends SharedSQLContext with PrivateMethodTester {

  type KP = KafkaProducer[Array[Byte], Array[Byte]]

  protected override def beforeEach(): Unit = {
    super.beforeEach()
    val clear = PrivateMethod[Unit]('clear)
    CachedKafkaProducer.invokePrivate(clear())
    CanonicalizeKafkaParams.clear()
  }

  test("Should return the cached instance on calling getOrCreate with same params.") {
    val kafkaParams = new ju.HashMap[String, Object]()
    kafkaParams.put("acks", "0")
    // Here only the host needs to be resolvable; it does not need a running kafka server.
    kafkaParams.put("bootstrap.servers", "127.0.0.1:9022")
    kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName)
    kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName)
    val kafkaParams2 = CanonicalizeKafkaParams.computeUniqueCanonicalForm(kafkaParams)
    val producer = CachedKafkaProducer.getOrCreate(kafkaParams2)
    val producer2 = CachedKafkaProducer.getOrCreate(kafkaParams2)
    assert(producer == producer2)

    val cacheMap = PrivateMethod[mutable.HashMap[Int, KP]]('cacheMap)
    val map = CachedKafkaProducer.invokePrivate(cacheMap())
    assert(map.size == 1)
  }

  test("Should close the correct kafka producer for the given kafkaParams.") {
    val kafkaParams = new ju.HashMap[String, Object]()
    kafkaParams.put("acks", "0")
    kafkaParams.put("bootstrap.servers", "127.0.0.1:9022")
    kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName)
    kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName)
    val kafkaParams2 = CanonicalizeKafkaParams.computeUniqueCanonicalForm(kafkaParams)
    kafkaParams.put("acks", "1")
    val kafkaParams3 = CanonicalizeKafkaParams.computeUniqueCanonicalForm(kafkaParams)
    val producer: KP = CachedKafkaProducer.getOrCreate(kafkaParams2)
    val producer2: KP = CachedKafkaProducer.getOrCreate(kafkaParams3)
    // With updated conf, a new producer instance should be created.
    assert(producer != producer2)

    val cacheMap = PrivateMethod[mutable.HashMap[Int, KP]]('cacheMap)
    val map = CachedKafkaProducer.invokePrivate(cacheMap())
    assert(map.size == 2)

    CachedKafkaProducer.close(kafkaParams3)
    val map2 = CachedKafkaProducer.invokePrivate(cacheMap())
    assert(map2.size == 1)

Member: We just know there is one KP by this assert. Seems we should also verify if we close the correct KP?
  }
}
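One way to act on the review comment above, sketched against the locals already defined in the second test (`map2` and `producer`); this is only an illustrative suggestion, not part of the diff:

```scala
// After closing the producer created for kafkaParams3, the surviving cache
// entry should still be the instance that was created for kafkaParams2.
assert(map2.values.head eq producer)
```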
nit: indent 2 more spaces