[SPARK-19968][SS] Use a cached instance of KafkaProducer instead of creating one every batch.
#17308
Changes from 10 commits
CachedKafkaProducer.scala (new file)

@@ -0,0 +1,174 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.kafka010

import java.{util => ju}
import java.util.concurrent.{ConcurrentMap, TimeUnit}
import javax.annotation.concurrent.GuardedBy

import scala.collection.JavaConverters._
import scala.collection.immutable.SortedMap
import scala.collection.mutable
import scala.util.control.NonFatal

import com.google.common.cache.{Cache, CacheBuilder, RemovalListener, RemovalNotification}
import org.apache.kafka.clients.producer.KafkaProducer

import org.apache.spark.internal.Logging
import org.apache.spark.util.ShutdownHookManager

private[kafka010] object CachedKafkaProducer extends Logging {

  private type Producer = KafkaProducer[Array[Byte], Array[Byte]]

  private val cacheExpireTimeout: Long =
    System.getProperty("spark.kafka.guava.cache.timeout.minutes", "10").toLong

  private val removalListener = new RemovalListener[String, Producer]() {
    override def onRemoval(notification: RemovalNotification[String, Producer]): Unit = {
      val uid: String = notification.getKey
      val producer: Producer = notification.getValue
      logDebug(s"Evicting kafka producer $producer uid: $uid, due to ${notification.getCause}")
      close(uid, producer)
    }
  }

  private val guavaCache: Cache[String, Producer] = CacheBuilder.newBuilder()
    .recordStats()
    .expireAfterAccess(cacheExpireTimeout, TimeUnit.MINUTES)
    .removalListener(removalListener)
    .build[String, Producer]()

  ShutdownHookManager.addShutdownHook { () =>
    clear()
  }

  private def createKafkaProducer(
      producerConfiguration: ju.Map[String, Object]): Producer = {
    val uid = getUniqueId(producerConfiguration)
    val kafkaProducer: Producer = new Producer(producerConfiguration)
    guavaCache.put(uid, kafkaProducer)
    logDebug(s"Created a new instance of KafkaProducer for $producerConfiguration.")
    kafkaProducer
  }

  private def getUniqueId(kafkaParams: ju.Map[String, Object]): String = {
    val uid = kafkaParams.get(CanonicalizeKafkaParams.sparkKafkaParamsUniqueId)
    assert(uid != null, s"KafkaParams($kafkaParams) not canonicalized.")
    uid.toString
  }

  /**
   * Get a cached KafkaProducer for a given configuration. If a matching KafkaProducer does not
   * exist, a new one is created. Since KafkaProducer is thread safe, it is best to keep a single
   * instance per distinct set of kafkaParams.
   */
  private[kafka010] def getOrCreate(kafkaParams: ju.Map[String, Object]): Producer = synchronized {
    val params = if (!CanonicalizeKafkaParams.isCanonicalized(kafkaParams)) {
      CanonicalizeKafkaParams.computeUniqueCanonicalForm(kafkaParams)
    } else {
      kafkaParams
    }
    val uid = getUniqueId(params)
    Option(guavaCache.getIfPresent(uid)).getOrElse(createKafkaProducer(params))
  }

  /** For explicitly closing a kafka producer. */
  private[kafka010] def close(kafkaParams: ju.Map[String, Object]): Unit = {
    val params = if (!CanonicalizeKafkaParams.isCanonicalized(kafkaParams)) {
      CanonicalizeKafkaParams.computeUniqueCanonicalForm(kafkaParams)
    } else {
      kafkaParams
    }
    val uid = getUniqueId(params)
    guavaCache.invalidate(uid)
  }

  /** Auto close on cache eviction. */
  private def close(uid: String, producer: Producer): Unit = {
    try {
      val outcome = CanonicalizeKafkaParams.remove(
        new ju.HashMap[String, Object](
          Map(CanonicalizeKafkaParams.sparkKafkaParamsUniqueId -> uid).asJava))
      logDebug(s"Removed kafka params from cache: $outcome.")
      logInfo(s"Closing the KafkaProducer with uid: $uid.")
      producer.close()
    } catch {
      case NonFatal(e) => logWarning("Error while closing kafka producer.", e)
    }
  }

  private def clear(): Unit = {
    logInfo("Cleaning up guava cache.")
    guavaCache.invalidateAll()
  }

  // Intended for testing purposes only.
  private def getAsMap: ConcurrentMap[String, Producer] = guavaCache.asMap()
}
/**
 * Canonicalize kafka params, i.e. append a unique internal id to the kafka params if one does
 * not already exist. This ensures that exactly one set of kafka parameters is associated with
 * each unique id.
 */
private[kafka010] object CanonicalizeKafkaParams extends Logging {

  @GuardedBy("this")
  private val registryMap = mutable.HashMap[String, String]()

  private[kafka010] val sparkKafkaParamsUniqueId: String =
    "spark.internal.sql.kafka.params.uniqueId"

  private def generateRandomUUID(kafkaParams: String): String = {
    val uuid = ju.UUID.randomUUID().toString
    logDebug(s"Generating a new unique id: $uuid for kafka params: $kafkaParams")
    registryMap.put(kafkaParams, uuid)
    uuid
  }

  private[kafka010] def isCanonicalized(kafkaParams: ju.Map[String, Object]): Boolean = {
    kafkaParams.get(sparkKafkaParamsUniqueId) != null
  }

  private[kafka010] def computeUniqueCanonicalForm(
      kafkaParams: ju.Map[String, Object]): ju.Map[String, Object] = synchronized {
    if (isCanonicalized(kafkaParams)) {
      logWarning(s"A unique id, $sparkKafkaParamsUniqueId ->" +
        s" ${kafkaParams.get(sparkKafkaParamsUniqueId)}" +
        s" already exists in kafka params, returning kafka params: $kafkaParams as is.")
      kafkaParams
    } else {
      val sortedMap = SortedMap.empty[String, Object] ++ kafkaParams.asScala
      val stringRepresentation: String = sortedMap.mkString("\n")
      val uuid =
        registryMap.getOrElse(stringRepresentation, generateRandomUUID(stringRepresentation))
      val newMap = new ju.HashMap[String, Object](kafkaParams)
      newMap.put(sparkKafkaParamsUniqueId, uuid)
      newMap
    }
  }

  private[kafka010] def remove(kafkaParams: ju.Map[String, Object]): Boolean = {
    val sortedMap = SortedMap.empty[String, Object] ++ kafkaParams.asScala
    val stringRepresentation: String = sortedMap.mkString("\n")
    registryMap.remove(stringRepresentation).isDefined
  }

  // For testing purposes only.
  private[kafka010] def clear(): Unit = {
    registryMap.clear()
  }
}
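
For orientation, here is a minimal usage sketch of the two objects above. The broker address and serializer settings are illustrative, not taken from the patch; only the shape of the calls matters.

import java.{util => ju}

import org.apache.kafka.common.serialization.ByteArraySerializer

val kafkaParams = new ju.HashMap[String, Object]()
kafkaParams.put("bootstrap.servers", "localhost:9092")  // illustrative broker
kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName)
kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName)

// The first call canonicalizes the params (tags them with a uid in the
// registry) and creates a producer; later calls with equivalent params map
// to the same uid and return the cached instance.
val p1 = CachedKafkaProducer.getOrCreate(kafkaParams)
val p2 = CachedKafkaProducer.getOrCreate(kafkaParams)
assert(p1 eq p2)

// Explicit eviction; the cache's removal listener closes the producer.
CachedKafkaProducer.close(kafkaParams)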
KafkaWriteTask.scala

@@ -19,8 +19,7 @@ package org.apache.spark.sql.kafka010

 import java.{util => ju}

-import org.apache.kafka.clients.producer.{KafkaProducer, _}
-import org.apache.kafka.common.serialization.ByteArraySerializer
+import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}

 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection}

@@ -44,7 +43,7 @@
    * Writes key value data out to topics.
    */
   def execute(iterator: Iterator[InternalRow]): Unit = {
-    producer = new KafkaProducer[Array[Byte], Array[Byte]](producerConfiguration)
+    producer = CachedKafkaProducer.getOrCreate(producerConfiguration)
     while (iterator.hasNext && failedWrite == null) {
       val currentRow = iterator.next()
       val projectedRow = projection(currentRow)

@@ -68,11 +67,10 @@
   def close(): Unit = {
+    checkForErrors()
     if (producer != null) {
-      checkForErrors
-      producer.close()
-      checkForErrors
-      producer = null
+      producer.flush()
+      checkForErrors()
     }
   }

Member review comment (on the removed lines above): nit: please keep …

@@ -88,7 +86,7 @@
       case t =>
         throw new IllegalStateException(s"${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " +
           s"attribute unsupported type $t. ${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " +
-          s"must be a ${StringType}")
+          "must be a StringType")
     }
     val keyExpression = inputSchema.find(_.name == KafkaWriter.KEY_ATTRIBUTE_NAME)
       .getOrElse(Literal(null, BinaryType))

@@ -100,7 +98,7 @@
     }
     val valueExpression = inputSchema
       .find(_.name == KafkaWriter.VALUE_ATTRIBUTE_NAME).getOrElse(
-        throw new IllegalStateException(s"Required attribute " +
+        throw new IllegalStateException("Required attribute " +
           s"'${KafkaWriter.VALUE_ATTRIBUTE_NAME}' not found")
     )
     valueExpression.dataType match {

@@ -114,7 +112,7 @@
       Cast(valueExpression, BinaryType)), inputSchema)
   }

-  private def checkForErrors: Unit = {
+  private def checkForErrors(): Unit = {
     if (failedWrite != null) {
       throw failedWrite
     }
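
Because execute() now borrows a shared, cached producer, the per-task close() flushes instead of closing. A sketch of the resulting per-batch lifecycle follows; KafkaWriteTask's constructor arguments are assumed from context, not quoted from the patch.

// Hypothetical caller-side loop; only execute()/close() appear in the diff.
val writeTask = new KafkaWriteTask(producerConfiguration, inputSchema, topic)
try {
  writeTask.execute(rowIterator)  // getOrCreate(): reuses the cached producer
} finally {
  writeTask.close()               // flush() and surface async send errors; the
                                  // producer itself stays cached for later batches
}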
CachedKafkaProducerSuite.scala (new file)

@@ -0,0 +1,80 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.kafka010

import java.{util => ju}
import java.util.concurrent.ConcurrentMap

import org.apache.kafka.clients.producer.KafkaProducer
import org.apache.kafka.common.serialization.ByteArraySerializer
import org.scalatest.PrivateMethodTester

import org.apache.spark.sql.test.SharedSQLContext

class CachedKafkaProducerSuite extends SharedSQLContext with PrivateMethodTester {

  type KP = KafkaProducer[Array[Byte], Array[Byte]]

  protected override def beforeEach(): Unit = {
    super.beforeEach()
    val clear = PrivateMethod[Unit]('clear)
    CachedKafkaProducer.invokePrivate(clear())
    CanonicalizeKafkaParams.clear()
  }

  test("Should return the cached instance on calling getOrCreate with same params.") {
    val kafkaParams = new ju.HashMap[String, Object]()
    kafkaParams.put("acks", "0")
    // Only the host needs to be resolvable; this does not need a running kafka server.
    kafkaParams.put("bootstrap.servers", "127.0.0.1:9022")
    kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName)
    kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName)
    val kafkaParams2 = CanonicalizeKafkaParams.computeUniqueCanonicalForm(kafkaParams)
    val producer = CachedKafkaProducer.getOrCreate(kafkaParams2)
    val producer2 = CachedKafkaProducer.getOrCreate(kafkaParams2)
    assert(producer == producer2)

    val cacheMap = PrivateMethod[ConcurrentMap[String, Option[KP]]]('getAsMap)
    val map = CachedKafkaProducer.invokePrivate(cacheMap())
    assert(map.size == 1)
  }

  test("Should close the correct kafka producer for the given kafkaParams.") {
    val kafkaParams = new ju.HashMap[String, Object]()
    kafkaParams.put("acks", "0")
    kafkaParams.put("bootstrap.servers", "127.0.0.1:9022")
    kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName)
    kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName)
    val kafkaParams2 = CanonicalizeKafkaParams.computeUniqueCanonicalForm(kafkaParams)
    kafkaParams.put("acks", "1")
    val kafkaParams3 = CanonicalizeKafkaParams.computeUniqueCanonicalForm(kafkaParams)
    val producer: KP = CachedKafkaProducer.getOrCreate(kafkaParams2)

    val producer2: KP = CachedKafkaProducer.getOrCreate(kafkaParams3)
    // With the updated conf, a new producer instance should be created.
    assert(producer != producer2)

    val cacheMap = PrivateMethod[ConcurrentMap[String, Option[KP]]]('getAsMap)
    val map = CachedKafkaProducer.invokePrivate(cacheMap())
    assert(map.size == 2)

    CachedKafkaProducer.close(kafkaParams3)
    val map2 = CachedKafkaProducer.invokePrivate(cacheMap())
    assert(map2.size == 1)
  }
}

Member review comment (on the final assert): We just know there is one KP by this assert. Seems we should also verify if we close the correct KP?
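
One way to address that comment, as a sketch (not part of the patch, and assuming getAsMap keys producers by uid as in the implementation above): after closing the producer for kafkaParams3, the surviving cache entry should be the one created for kafkaParams2.

// Sketch of the extra verification the reviewer asks for.
assert(map2.containsValue(producer))    // kafkaParams2's producer survives
assert(!map2.containsValue(producer2))  // kafkaParams3's producer was evicted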
Review comment (on the `spark.kafka.guava.cache.timeout.minutes` system property): don't we need to get this from `SparkEnv`, by the way? I don't know if the properties get populated properly. Also, adding `minutes` to the conf name makes it kinda long, right? I think we can also replace `guava` with `producer`. I think it may also be better to use a duration-type conf so that we get rid of `minutes` and users can actually provide arbitrary durations (hours if they want). I think that's what we generally use for duration-type confs.

Reply: Thanks, you are right!
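
A sketch of the duration-type conf being suggested, using Spark's internal ConfigBuilder; the conf name and default below are illustrative assumptions, not necessarily what was merged.

import java.util.concurrent.TimeUnit

import org.apache.spark.SparkEnv
import org.apache.spark.internal.config.ConfigBuilder

// Hypothetical key; a duration conf accepts values like "30s", "10m", or "1h".
private[kafka010] val PRODUCER_CACHE_TIMEOUT =
  ConfigBuilder("spark.kafka.producer.cache.timeout")
    .doc("Expire a cached KafkaProducer after this period of inactivity.")
    .timeConf(TimeUnit.MILLISECONDS)
    .createWithDefaultString("10m")

// Read from the Spark conf via SparkEnv instead of a raw system property; the
// cache would then use expireAfterAccess(cacheExpireTimeout, TimeUnit.MILLISECONDS).
private val cacheExpireTimeout: Long = SparkEnv.get.conf.get(PRODUCER_CACHE_TIMEOUT)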