diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala
deleted file mode 100644
index fc177cdc9037..000000000000
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaProducer.scala
+++ /dev/null
@@ -1,128 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.kafka010
-
-import java.{util => ju}
-import java.util.concurrent.{ConcurrentMap, ExecutionException, TimeUnit}
-
-import com.google.common.cache._
-import com.google.common.util.concurrent.{ExecutionError, UncheckedExecutionException}
-import org.apache.kafka.clients.producer.KafkaProducer
-import scala.collection.JavaConverters._
-import scala.util.control.NonFatal
-
-import org.apache.spark.SparkEnv
-import org.apache.spark.internal.Logging
-import org.apache.spark.kafka010.{KafkaConfigUpdater, KafkaRedactionUtil}
-
-private[kafka010] object CachedKafkaProducer extends Logging {
-
-  private type Producer = KafkaProducer[Array[Byte], Array[Byte]]
-
-  private val defaultCacheExpireTimeout = TimeUnit.MINUTES.toMillis(10)
-
-  private lazy val cacheExpireTimeout: Long = Option(SparkEnv.get)
-    .map(_.conf.get(PRODUCER_CACHE_TIMEOUT))
-    .getOrElse(defaultCacheExpireTimeout)
-
-  private val cacheLoader = new CacheLoader[Seq[(String, Object)], Producer] {
-    override def load(config: Seq[(String, Object)]): Producer = {
-      createKafkaProducer(config)
-    }
-  }
-
-  private val removalListener = new RemovalListener[Seq[(String, Object)], Producer]() {
-    override def onRemoval(
-        notification: RemovalNotification[Seq[(String, Object)], Producer]): Unit = {
-      val paramsSeq: Seq[(String, Object)] = notification.getKey
-      val producer: Producer = notification.getValue
-      if (log.isDebugEnabled()) {
-        val redactedParamsSeq = KafkaRedactionUtil.redactParams(paramsSeq)
-        logDebug(s"Evicting kafka producer $producer params: $redactedParamsSeq, " +
-          s"due to ${notification.getCause}")
-      }
-      close(paramsSeq, producer)
-    }
-  }
-
-  private lazy val guavaCache: LoadingCache[Seq[(String, Object)], Producer] =
-    CacheBuilder.newBuilder().expireAfterAccess(cacheExpireTimeout, TimeUnit.MILLISECONDS)
-      .removalListener(removalListener)
-      .build[Seq[(String, Object)], Producer](cacheLoader)
-
-  private def createKafkaProducer(paramsSeq: Seq[(String, Object)]): Producer = {
-    val kafkaProducer: Producer = new Producer(paramsSeq.toMap.asJava)
-    if (log.isDebugEnabled()) {
-      val redactedParamsSeq = KafkaRedactionUtil.redactParams(paramsSeq)
-      logDebug(s"Created a new instance of KafkaProducer for $redactedParamsSeq.")
-    }
-    kafkaProducer
-  }
-
-  /**
-   * Get a cached KafkaProducer for a given configuration. If matching KafkaProducer doesn't
-   * exist, a new KafkaProducer will be created. KafkaProducer is thread safe, it is best to keep
-   * one instance per specified kafkaParams.
-   */
-  private[kafka010] def getOrCreate(kafkaParams: ju.Map[String, Object]): Producer = {
-    val updatedKafkaProducerConfiguration =
-      KafkaConfigUpdater("executor", kafkaParams.asScala.toMap)
-        .setAuthenticationConfigIfNeeded()
-        .build()
-    val paramsSeq: Seq[(String, Object)] = paramsToSeq(updatedKafkaProducerConfiguration)
-    try {
-      guavaCache.get(paramsSeq)
-    } catch {
-      case e @ (_: ExecutionException | _: UncheckedExecutionException | _: ExecutionError)
-        if e.getCause != null =>
-        throw e.getCause
-    }
-  }
-
-  private def paramsToSeq(kafkaParams: ju.Map[String, Object]): Seq[(String, Object)] = {
-    val paramsSeq: Seq[(String, Object)] = kafkaParams.asScala.toSeq.sortBy(x => x._1)
-    paramsSeq
-  }
-
-  /** For explicitly closing kafka producer */
-  private[kafka010] def close(kafkaParams: ju.Map[String, Object]): Unit = {
-    val paramsSeq = paramsToSeq(kafkaParams)
-    guavaCache.invalidate(paramsSeq)
-  }
-
-  /** Auto close on cache evict */
-  private def close(paramsSeq: Seq[(String, Object)], producer: Producer): Unit = {
-    try {
-      if (log.isInfoEnabled()) {
-        val redactedParamsSeq = KafkaRedactionUtil.redactParams(paramsSeq)
-        logInfo(s"Closing the KafkaProducer with params: ${redactedParamsSeq.mkString("\n")}.")
-      }
-      producer.close()
-    } catch {
-      case NonFatal(e) => logWarning("Error while closing kafka producer.", e)
-    }
-  }
-
-  private[kafka010] def clear(): Unit = {
-    logInfo("Cleaning up guava cache.")
-    guavaCache.invalidateAll()
-  }
-
-  // Intended for testing purpose only.
-  private def getAsMap: ConcurrentMap[Seq[(String, Object)], Producer] = guavaCache.asMap()
-}
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala
index 9a2b36993361..63863a6cc6d6 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaDataWriter.scala
@@ -22,6 +22,7 @@ import java.{util => ju}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
+import org.apache.spark.sql.kafka010.producer.{CachedKafkaProducer, InternalKafkaProducerPool}
 
 /**
  * Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we
@@ -44,11 +45,14 @@ private[kafka010] class KafkaDataWriter(
     inputSchema: Seq[Attribute])
   extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] {
 
-  private lazy val producer = CachedKafkaProducer.getOrCreate(producerParams)
+  private var producer: Option[CachedKafkaProducer] = None
 
   def write(row: InternalRow): Unit = {
     checkForErrors()
-    sendRow(row, producer)
+    if (producer.isEmpty) {
+      producer = Some(InternalKafkaProducerPool.acquire(producerParams))
+    }
+    producer.foreach { p => sendRow(row, p.producer) }
   }
 
   def commit(): WriterCommitMessage = {
@@ -56,22 +60,15 @@ private[kafka010] class KafkaDataWriter(
     // This requires flushing and then checking that no callbacks produced errors.
     // We also check for errors before to fail as soon as possible - the check is cheap.
     checkForErrors()
-    producer.flush()
+    producer.foreach(_.producer.flush())
     checkForErrors()
     KafkaDataWriterCommitMessage
   }
 
   def abort(): Unit = {}
 
-  def close(): Unit = {}
-
-  /** explicitly invalidate producer from pool. only for testing. */
-  private[kafka010] def invalidateProducer(): Unit = {
-    checkForErrors()
-    if (producer != null) {
-      producer.flush()
-      checkForErrors()
-      CachedKafkaProducer.close(producerParams)
-    }
+  def close(): Unit = {
+    producer.foreach(InternalKafkaProducerPool.release)
+    producer = None
   }
 }
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
index 8b907065af1d..fddba3f0f991 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
@@ -27,6 +27,7 @@ import org.apache.kafka.common.header.internals.RecordHeader
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, UnsafeProjection}
+import org.apache.spark.sql.kafka010.producer.{CachedKafkaProducer, InternalKafkaProducerPool}
 import org.apache.spark.sql.types.BinaryType
 
 /**
@@ -39,25 +40,30 @@ private[kafka010] class KafkaWriteTask(
     inputSchema: Seq[Attribute],
     topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) {
   // used to synchronize with Kafka callbacks
-  private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _
+  private var producer: Option[CachedKafkaProducer] = None
 
   /**
    * Writes key value data out to topics.
   */
   def execute(iterator: Iterator[InternalRow]): Unit = {
-    producer = CachedKafkaProducer.getOrCreate(producerConfiguration)
+    producer = Some(InternalKafkaProducerPool.acquire(producerConfiguration))
+    val internalProducer = producer.get.producer
     while (iterator.hasNext && failedWrite == null) {
       val currentRow = iterator.next()
-      sendRow(currentRow, producer)
+      sendRow(currentRow, internalProducer)
     }
   }
 
   def close(): Unit = {
-    checkForErrors()
-    if (producer != null) {
-      producer.flush()
+    try {
       checkForErrors()
-      producer = null
+      producer.foreach { p =>
+        p.producer.flush()
+        checkForErrors()
+      }
+    } finally {
+      producer.foreach(InternalKafkaProducerPool.release)
+      producer = None
     }
   }
 }
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala
index 6f6ae55fc497..460bb8bd34ec 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/package.scala
@@ -32,6 +32,13 @@ package object kafka010 { // scalastyle:ignore
       .timeConf(TimeUnit.MILLISECONDS)
       .createWithDefaultString("10m")
 
+  private[kafka010] val PRODUCER_CACHE_EVICTOR_THREAD_RUN_INTERVAL =
+    ConfigBuilder("spark.kafka.producer.cache.evictorThreadRunInterval")
+      .doc("The interval of time between runs of the idle evictor thread for producer pool. " +
+        "When non-positive, no idle evictor thread will be run.")
+      .timeConf(TimeUnit.MILLISECONDS)
+      .createWithDefaultString("1m")
+
   private[kafka010] val CONSUMER_CACHE_CAPACITY =
     ConfigBuilder("spark.kafka.consumer.cache.capacity")
       .doc("The maximum number of consumers cached. Please note it's a soft limit" +
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/CachedKafkaProducer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/CachedKafkaProducer.scala
new file mode 100644
index 000000000000..83519de0d3b1
--- /dev/null
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/CachedKafkaProducer.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010.producer
+
+import java.{util => ju}
+
+import scala.util.control.NonFatal
+
+import org.apache.kafka.clients.producer.KafkaProducer
+
+import org.apache.spark.internal.Logging
+
+private[kafka010] class CachedKafkaProducer(
+    val cacheKey: Seq[(String, Object)],
+    val producer: KafkaProducer[Array[Byte], Array[Byte]]) extends Logging {
+  val id: String = ju.UUID.randomUUID().toString
+
+  private[producer] def close(): Unit = {
+    try {
+      logInfo(s"Closing the KafkaProducer with id: $id.")
+      producer.close()
+    } catch {
+      case NonFatal(e) => logWarning("Error while closing kafka producer.", e)
+    }
+  }
+}
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala
new file mode 100644
index 000000000000..7a0c68eb74a3
--- /dev/null
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPool.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010.producer
+
+import java.{util => ju}
+import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture, TimeUnit}
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.apache.kafka.clients.producer.KafkaProducer
+
+import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.internal.Logging
+import org.apache.spark.kafka010.{KafkaConfigUpdater, KafkaRedactionUtil}
+import org.apache.spark.sql.kafka010.{PRODUCER_CACHE_EVICTOR_THREAD_RUN_INTERVAL, PRODUCER_CACHE_TIMEOUT}
+import org.apache.spark.util.{Clock, ShutdownHookManager, SystemClock, ThreadUtils, Utils}
+
+/**
+ * Provides an object pool of [[CachedKafkaProducer]] instances, grouped by
+ * [[org.apache.spark.sql.kafka010.producer.InternalKafkaProducerPool.CacheKey]].
+ */
+private[producer] class InternalKafkaProducerPool(
+    executorService: ScheduledExecutorService,
+    val clock: Clock,
+    conf: SparkConf) extends Logging {
+  import InternalKafkaProducerPool._
+
+  def this(sparkConf: SparkConf) = {
+    this(ThreadUtils.newDaemonSingleThreadScheduledExecutor(
+      "kafka-producer-cache-evictor"), new SystemClock, sparkConf)
+  }
+
+  /** exposed for testing */
+  private[producer] val cacheExpireTimeoutMillis: Long = conf.get(PRODUCER_CACHE_TIMEOUT)
+
+  @GuardedBy("this")
+  private val cache = new mutable.HashMap[CacheKey, CachedProducerEntry]
+
+  private def startEvictorThread(): Option[ScheduledFuture[_]] = {
+    val evictorThreadRunIntervalMillis = conf.get(PRODUCER_CACHE_EVICTOR_THREAD_RUN_INTERVAL)
+    if (evictorThreadRunIntervalMillis > 0) {
+      val future = executorService.scheduleAtFixedRate(() => {
+        Utils.tryLogNonFatalError(evictExpired())
+      }, 0, evictorThreadRunIntervalMillis, TimeUnit.MILLISECONDS)
+      Some(future)
+    } else {
+      None
+    }
+  }
+
+  private var scheduled = startEvictorThread()
+
+  /**
+   * Get a cached KafkaProducer for a given configuration. If a matching KafkaProducer doesn't
+   * exist, a new one will be created. KafkaProducer is thread safe, so it is best to keep
+   * one instance per specified kafkaParams.
+   */
+  private[producer] def acquire(kafkaParams: ju.Map[String, Object]): CachedKafkaProducer = {
+    val updatedKafkaProducerConfiguration =
+      KafkaConfigUpdater("executor", kafkaParams.asScala.toMap)
+        .setAuthenticationConfigIfNeeded()
+        .build()
+    val paramsSeq: Seq[(String, Object)] = paramsToSeq(updatedKafkaProducerConfiguration)
+    synchronized {
+      val entry = cache.getOrElseUpdate(paramsSeq, {
+        val producer = createKafkaProducer(paramsSeq)
+        val cachedProducer = new CachedKafkaProducer(paramsSeq, producer)
+        new CachedProducerEntry(cachedProducer,
+          TimeUnit.MILLISECONDS.toNanos(cacheExpireTimeoutMillis))
+      })
+      entry.handleBorrowed()
+      entry.producer
+    }
+  }
+
+  private[producer] def release(producer: CachedKafkaProducer): Unit = {
+    synchronized {
+      cache.get(producer.cacheKey) match {
+        case Some(entry) if entry.producer.id == producer.id =>
+          entry.handleReturned(clock.nanoTime())
+        case _ =>
+          logWarning(s"Released producer ${producer.id} is not a member of the cache. Closing.")
+          producer.close()
+      }
+    }
+  }
+
+  private[producer] def shutdown(): Unit = {
+    scheduled.foreach(_.cancel(false))
+    ThreadUtils.shutdown(executorService)
+  }
+
+  /** exposed for testing. */
+  private[producer] def reset(): Unit = synchronized {
+    cache.foreach { case (_, v) => v.producer.close() }
+    cache.clear()
+  }
+
+  /** exposed for testing */
+  private[producer] def getAsMap: Map[CacheKey, CachedProducerEntry] = cache.toMap
+
+  private def evictExpired(): Unit = {
+    val curTimeNs = clock.nanoTime()
+    val producers = new mutable.ArrayBuffer[CachedProducerEntry]()
+    synchronized {
+      cache.retain { case (_, v) =>
+        if (v.expired(curTimeNs)) {
+          producers += v
+          false
+        } else {
+          true
+        }
+      }
+    }
+    producers.foreach { _.producer.close() }
+  }
+
+  private def createKafkaProducer(paramsSeq: Seq[(String, Object)]): Producer = {
+    val kafkaProducer: Producer = new Producer(paramsSeq.toMap.asJava)
+    if (log.isDebugEnabled()) {
+      val redactedParamsSeq = KafkaRedactionUtil.redactParams(paramsSeq)
+      logDebug(s"Created a new instance of KafkaProducer for $redactedParamsSeq.")
+    }
+    kafkaProducer
+  }
+
+  private def paramsToSeq(kafkaParams: ju.Map[String, Object]): Seq[(String, Object)] = {
+    kafkaParams.asScala.toSeq.sortBy(x => x._1)
+  }
+}
+
+private[kafka010] object InternalKafkaProducerPool extends Logging {
+  private val pool = new InternalKafkaProducerPool(
+    Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()))
+
+  private type CacheKey = Seq[(String, Object)]
+  private type Producer = KafkaProducer[Array[Byte], Array[Byte]]
+
+  ShutdownHookManager.addShutdownHook { () =>
+    try {
+      pool.shutdown()
+    } catch {
+      case e: Throwable =>
+        logWarning("Ignoring Exception while shutting down pools from shutdown hook", e)
+    }
+  }
+
+  /**
+   * This class is used as metadata for the producer pool and shouldn't be exposed to the public.
+   * This class assumes thread-safety is guaranteed by the caller.
+   */
+  private[producer] class CachedProducerEntry(
+      val producer: CachedKafkaProducer,
+      cacheExpireTimeoutNs: Long) {
+    private var _refCount: Long = 0L
+    private var _expireAt: Long = Long.MaxValue
+
+    /** exposed for testing */
+    private[producer] def refCount: Long = _refCount
+    private[producer] def expireAt: Long = _expireAt
+
+    def handleBorrowed(): Unit = {
+      _refCount += 1
+      _expireAt = Long.MaxValue
+    }
+
+    def handleReturned(curTimeNs: Long): Unit = {
+      require(_refCount > 0, "Reference count shouldn't become negative. Returning the same " +
+        "producer multiple times would cause this bug. Check the logic around returning producers.")
+
+      _refCount -= 1
+      if (_refCount == 0) {
+        _expireAt = curTimeNs + cacheExpireTimeoutNs
+      }
+    }
+
+    def expired(curTimeNs: Long): Boolean = _refCount == 0 && _expireAt < curTimeNs
+  }
+
+  def acquire(kafkaParams: ju.Map[String, Object]): CachedKafkaProducer = {
+    pool.acquire(kafkaParams)
+  }
+
+  def release(producer: CachedKafkaProducer): Unit = {
+    pool.release(producer)
+  }
+
+  def reset(): Unit = pool.reset()
+}
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala
deleted file mode 100644
index 7425a74315e1..000000000000
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/CachedKafkaProducerSuite.scala
+++ /dev/null
@@ -1,77 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.kafka010
-
-import java.{util => ju}
-import java.util.concurrent.ConcurrentMap
-
-import org.apache.kafka.clients.producer.KafkaProducer
-import org.apache.kafka.common.serialization.ByteArraySerializer
-import org.scalatest.PrivateMethodTester
-
-import org.apache.spark.sql.test.SharedSparkSession
-
-class CachedKafkaProducerSuite extends SharedSparkSession with PrivateMethodTester with KafkaTest {
-
-  type KP = KafkaProducer[Array[Byte], Array[Byte]]
-
-  protected override def beforeEach(): Unit = {
-    super.beforeEach()
-    CachedKafkaProducer.clear()
-  }
-
-  test("Should return the cached instance on calling getOrCreate with same params.") {
-    val kafkaParams = new ju.HashMap[String, Object]()
-    kafkaParams.put("acks", "0")
-    // Here only host should be resolvable, it does not need a running instance of kafka server.
-    kafkaParams.put("bootstrap.servers", "127.0.0.1:9022")
-    kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName)
-    kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName)
-    val producer = CachedKafkaProducer.getOrCreate(kafkaParams)
-    val producer2 = CachedKafkaProducer.getOrCreate(kafkaParams)
-    assert(producer == producer2)
-
-    val cacheMap = PrivateMethod[ConcurrentMap[Seq[(String, Object)], KP]](Symbol("getAsMap"))
-    val map = CachedKafkaProducer.invokePrivate(cacheMap())
-    assert(map.size == 1)
-  }
-
-  test("Should close the correct kafka producer for the given kafkaPrams.") {
-    val kafkaParams = new ju.HashMap[String, Object]()
-    kafkaParams.put("acks", "0")
-    kafkaParams.put("bootstrap.servers", "127.0.0.1:9022")
-    kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName)
-    kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName)
-    val producer: KP = CachedKafkaProducer.getOrCreate(kafkaParams)
-    kafkaParams.put("acks", "1")
-    val producer2: KP = CachedKafkaProducer.getOrCreate(kafkaParams)
-    // With updated conf, a new producer instance should be created.
-    assert(producer != producer2)
-
-    val cacheMap = PrivateMethod[ConcurrentMap[Seq[(String, Object)], KP]](Symbol("getAsMap"))
-    val map = CachedKafkaProducer.invokePrivate(cacheMap())
-    assert(map.size == 2)
-
-    CachedKafkaProducer.close(kafkaParams)
-    val map2 = CachedKafkaProducer.invokePrivate(cacheMap())
-    assert(map2.size == 1)
-    import scala.collection.JavaConverters._
-    val (seq: Seq[(String, Object)], _producer: KP) = map2.asScala.toArray.apply(0)
-    assert(_producer == producer)
-  }
-}
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
index ac242ba3d135..e2dcd6200531 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
@@ -370,7 +370,7 @@ class KafkaContinuousSinkSuite extends KafkaSinkStreamingSuiteBase {
           iter.foreach(writeTask.write(_))
           writeTask.commit()
         } finally {
-          writeTask.invalidateProducer()
+          writeTask.close()
         }
       }
     }
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTest.scala
index 19acda95c707..087d938f8ed8 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTest.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaTest.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.kafka010
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.kafka010.producer.InternalKafkaProducerPool
 
 /** A trait to clean cached Kafka producers in `afterAll` */
 trait KafkaTest extends BeforeAndAfterAll {
@@ -27,6 +28,6 @@ trait KafkaTest extends BeforeAndAfterAll {
 
   override def afterAll(): Unit = {
     super.afterAll()
-    CachedKafkaProducer.clear()
+    InternalKafkaProducerPool.reset()
   }
 }
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPoolSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPoolSuite.scala
new file mode 100644
index 000000000000..97885754f204
--- /dev/null
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/producer/InternalKafkaProducerPoolSuite.scala
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.kafka010.producer
+
+import java.{util => ju}
+import java.util.concurrent.{Executors, TimeUnit}
+
+import scala.util.Random
+
+import org.apache.kafka.common.serialization.ByteArraySerializer
+import org.jmock.lib.concurrent.DeterministicScheduler
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.kafka010.{PRODUCER_CACHE_EVICTOR_THREAD_RUN_INTERVAL, PRODUCER_CACHE_TIMEOUT}
+import org.apache.spark.sql.kafka010.producer.InternalKafkaProducerPool.CachedProducerEntry
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.util.ManualClock
+
+class InternalKafkaProducerPoolSuite extends SharedSparkSession {
+
+  private var pool: InternalKafkaProducerPool = _
+
+  protected override def afterEach(): Unit = {
+    if (pool != null) {
+      try {
+        pool.shutdown()
+        pool = null
+      } catch {
+        // Ignore, as it's a known issue: DeterministicScheduler doesn't support shutdown.
+        case _: UnsupportedOperationException =>
+      }
+    }
+  }
+
+  test("Should return same cached instance on calling acquire with same params.") {
+    pool = new InternalKafkaProducerPool(new SparkConf())
+
+    val kafkaParams = getTestKafkaParams()
+    val producer = pool.acquire(kafkaParams)
+    val producer2 = pool.acquire(kafkaParams)
+    assert(producer eq producer2)
+
+    val map = pool.getAsMap
+    assert(map.size === 1)
+    val cacheEntry = map.head._2
+    assertCacheEntry(pool, cacheEntry, 2L)
+
+    pool.release(producer)
+    assertCacheEntry(pool, cacheEntry, 1L)
+
+    pool.release(producer2)
+    assertCacheEntry(pool, cacheEntry, 0L)
+
+    val producer3 = pool.acquire(kafkaParams)
+    assertCacheEntry(pool, cacheEntry, 1L)
+    assert(producer eq producer3)
+  }
+
+  test("Should return different cached instances on calling acquire with different params.") {
+    pool = new InternalKafkaProducerPool(new SparkConf())
+
+    val kafkaParams = getTestKafkaParams()
+    val producer = pool.acquire(kafkaParams)
+    kafkaParams.put("acks", "1")
+    val producer2 = pool.acquire(kafkaParams)
+    // With updated conf, a new producer instance should be created.
+    assert(producer ne producer2)
+
+    val map = pool.getAsMap
+    assert(map.size === 2)
+    val cacheEntry = map.find(_._2.producer.id == producer.id).get._2
+    assertCacheEntry(pool, cacheEntry, 1L)
+    val cacheEntry2 = map.find(_._2.producer.id == producer2.id).get._2
+    assertCacheEntry(pool, cacheEntry2, 1L)
+  }
+
+  test("expire instances") {
+    val minEvictableIdleTimeMillis = 2000L
+    val evictorThreadRunIntervalMillis = 500L
+
+    val conf = new SparkConf()
+    conf.set(PRODUCER_CACHE_TIMEOUT, minEvictableIdleTimeMillis)
+    conf.set(PRODUCER_CACHE_EVICTOR_THREAD_RUN_INTERVAL, evictorThreadRunIntervalMillis)
+
+    val scheduler = new DeterministicScheduler()
+    val clock = new ManualClock()
+    pool = new InternalKafkaProducerPool(scheduler, clock, conf)
+
+    val kafkaParams = getTestKafkaParams()
+
+    var map = pool.getAsMap
+    assert(map.isEmpty)
+
+    val producer = pool.acquire(kafkaParams)
+    map = pool.getAsMap
+    assert(map.size === 1)
+
+    clock.advance(minEvictableIdleTimeMillis + 100)
+    scheduler.tick(evictorThreadRunIntervalMillis + 100, TimeUnit.MILLISECONDS)
+    map = pool.getAsMap
+    assert(map.size === 1)
+
+    pool.release(producer)
+
+    // This will clean up the expired instance from the cache.
+    clock.advance(minEvictableIdleTimeMillis + 100)
+    scheduler.tick(evictorThreadRunIntervalMillis + 100, TimeUnit.MILLISECONDS)
+
+    map = pool.getAsMap
+    assert(map.size === 0)
+  }
+
+  test("reference counting with concurrent access") {
+    pool = new InternalKafkaProducerPool(new SparkConf())
+
+    val kafkaParams = getTestKafkaParams()
+
+    val numThreads = 100
+    val numProducerUsages = 500
+
+    def produce(i: Int): Unit = {
+      val producer = pool.acquire(kafkaParams)
+      try {
+        val map = pool.getAsMap
+        assert(map.size === 1)
+        val cacheEntry = map.head._2
+        assert(cacheEntry.refCount > 0L)
+        assert(cacheEntry.expireAt === Long.MaxValue)
+
+        Thread.sleep(Random.nextInt(100))
+      } finally {
+        pool.release(producer)
+      }
+    }
+
+    val threadpool = Executors.newFixedThreadPool(numThreads)
+    try {
+      val futures = (1 to numProducerUsages).map { i =>
+        threadpool.submit(new Runnable {
+          override def run(): Unit = { produce(i) }
+        })
+      }
+      futures.foreach(_.get(1, TimeUnit.MINUTES))
+    } finally {
+      threadpool.shutdown()
+    }
+
+    val map = pool.getAsMap
+    assert(map.size === 1)
+
+    val cacheEntry = map.head._2
+    assertCacheEntry(pool, cacheEntry, 0L)
+  }
+
+  private def getTestKafkaParams(): ju.HashMap[String, Object] = {
+    val kafkaParams = new ju.HashMap[String, Object]()
+    kafkaParams.put("acks", "0")
+    // Here only host should be resolvable, it does not need a running instance of kafka server.
+    kafkaParams.put("bootstrap.servers", "127.0.0.1:9022")
+    kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName)
+    kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName)
+    kafkaParams
+  }
+
+  private def assertCacheEntry(
+      pool: InternalKafkaProducerPool,
+      cacheEntry: CachedProducerEntry,
+      expectedRefCount: Long): Unit = {
+    val timeoutVal = TimeUnit.MILLISECONDS.toNanos(pool.cacheExpireTimeoutMillis)
+    assert(cacheEntry.refCount === expectedRefCount)
+    if (expectedRefCount > 0) {
+      assert(cacheEntry.expireAt === Long.MaxValue)
+    } else {
+      assert(cacheEntry.expireAt <= pool.clock.nanoTime() + timeoutVal)
+    }
+  }
+}
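For reviewers, a minimal usage sketch of the acquire/release discipline this change introduces, modeled on KafkaWriteTask.execute()/close() in the patch above. The parameter values are placeholders, and the snippet assumes it lives under the org.apache.spark.sql.kafka010 package, since the pool object is private[kafka010]; it is not part of the patch.

// Sketch only: borrow a pooled producer, use it, and always return it so the
// idle evictor thread can expire it once its reference count drops to zero.
import java.{util => ju}

import org.apache.kafka.common.serialization.ByteArraySerializer

import org.apache.spark.sql.kafka010.producer.{CachedKafkaProducer, InternalKafkaProducerPool}

val params = new ju.HashMap[String, Object]()
params.put("bootstrap.servers", "127.0.0.1:9092") // placeholder address
params.put("key.serializer", classOf[ByteArraySerializer].getName)
params.put("value.serializer", classOf[ByteArraySerializer].getName)

val cached: CachedKafkaProducer = InternalKafkaProducerPool.acquire(params)
try {
  // cached.producer is the underlying KafkaProducer[Array[Byte], Array[Byte]].
  cached.producer.flush()
} finally {
  InternalKafkaProducerPool.release(cached)
}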