Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.{util => ju}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
import org.apache.spark.sql.kafka010.producer.{CachedKafkaProducer, InternalKafkaProducerPool}

/**
* Dummy commit message. The DataSourceV2 framework requires a commit message implementation but we
Expand All @@ -44,34 +45,30 @@ private[kafka010] class KafkaDataWriter(
inputSchema: Seq[Attribute])
extends KafkaRowWriter(inputSchema, targetTopic) with DataWriter[InternalRow] {

private lazy val producer = CachedKafkaProducer.getOrCreate(producerParams)
private var producer: Option[CachedKafkaProducer] = None

def write(row: InternalRow): Unit = {
checkForErrors()
sendRow(row, producer)
if (producer.isEmpty) {
producer = Some(InternalKafkaProducerPool.acquire(producerParams))
}
producer.foreach { p => sendRow(row, p.producer) }
}

def commit(): WriterCommitMessage = {
// Send is asynchronous, but we can't commit until all rows are actually in Kafka.
// This requires flushing and then checking that no callbacks produced errors.
// We also check for errors before to fail as soon as possible - the check is cheap.
checkForErrors()
producer.flush()
producer.foreach(_.producer.flush())
checkForErrors()
KafkaDataWriterCommitMessage
}

def abort(): Unit = {}

def close(): Unit = {}

/** explicitly invalidate producer from pool. only for testing. */
private[kafka010] def invalidateProducer(): Unit = {
checkForErrors()
if (producer != null) {
producer.flush()
checkForErrors()
CachedKafkaProducer.close(producerParams)
}
def close(): Unit = {
producer.foreach(InternalKafkaProducerPool.release)
producer = None
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.kafka.common.header.internals.RecordHeader

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, UnsafeProjection}
import org.apache.spark.sql.kafka010.producer.{CachedKafkaProducer, InternalKafkaProducerPool}
import org.apache.spark.sql.types.BinaryType

/**
Expand All @@ -39,25 +40,30 @@ private[kafka010] class KafkaWriteTask(
inputSchema: Seq[Attribute],
topic: Option[String]) extends KafkaRowWriter(inputSchema, topic) {
// used to synchronize with Kafka callbacks
private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _
private var producer: Option[CachedKafkaProducer] = None

/**
* Writes key value data out to topics.
*/
def execute(iterator: Iterator[InternalRow]): Unit = {
producer = CachedKafkaProducer.getOrCreate(producerConfiguration)
producer = Some(InternalKafkaProducerPool.acquire(producerConfiguration))
val internalProducer = producer.get.producer
while (iterator.hasNext && failedWrite == null) {
val currentRow = iterator.next()
sendRow(currentRow, producer)
sendRow(currentRow, internalProducer)
}
}

def close(): Unit = {
checkForErrors()
if (producer != null) {
producer.flush()
try {
checkForErrors()
producer = null
producer.foreach { p =>
p.producer.flush()
checkForErrors()
}
} finally {
producer.foreach(InternalKafkaProducerPool.release)
producer = None
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ package object kafka010 { // scalastyle:ignore
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("10m")

private[kafka010] val PRODUCER_CACHE_EVICTOR_THREAD_RUN_INTERVAL =
ConfigBuilder("spark.kafka.producer.cache.evictorThreadRunInterval")
.doc("The interval of time between runs of the idle evictor thread for producer pool. " +
"When non-positive, no idle evictor thread will be run.")
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("1m")

private[kafka010] val CONSUMER_CACHE_CAPACITY =
ConfigBuilder("spark.kafka.consumer.cache.capacity")
.doc("The maximum number of consumers cached. Please note it's a soft limit" +
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.kafka010.producer

import java.{util => ju}

import scala.util.control.NonFatal

import org.apache.kafka.clients.producer.KafkaProducer

import org.apache.spark.internal.Logging

private[kafka010] class CachedKafkaProducer(
val cacheKey: Seq[(String, Object)],
val producer: KafkaProducer[Array[Byte], Array[Byte]]) extends Logging {
val id: String = ju.UUID.randomUUID().toString

private[producer] def close(): Unit = {
try {
logInfo(s"Closing the KafkaProducer with id: $id.")
producer.close()
} catch {
case NonFatal(e) => logWarning("Error while closing kafka producer.", e)
}
}
}
Loading