CachedKafkaProducer.scala (new file)
@@ -0,0 +1,139 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.kafka010

import java.{util => ju}

import scala.collection.immutable.SortedMap
import scala.collection.mutable

import org.apache.kafka.clients.producer.KafkaProducer

import org.apache.spark.internal.Logging

private[kafka010] object CachedKafkaProducer extends Logging {

private type Producer = KafkaProducer[Array[Byte], Array[Byte]]

private val cacheMap = new mutable.HashMap[String, Producer]()

private def createKafkaProducer(
producerConfiguration: ju.Map[String, Object]): Producer = {
Contributor:
nit: indent 2 more spaces

private def createKafkaProducer(
    producerConfiguration: ju.Map[String, Object]): Producer = {
  ...
}

Contributor:
nit: indent 4 here

val uid = producerConfiguration.get(CanonicalizeKafkaParams.sparkKafkaParamsUniqueId)
Member:
Shall we always use getUniqueId to get uid?

val kafkaProducer: Producer = new Producer(producerConfiguration)
cacheMap.put(uid.toString, kafkaProducer)
log.debug(s"Created a new instance of KafkaProducer for $producerConfiguration.")
kafkaProducer
}
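Picking up the comment above: a minimal sketch (hypothetical refactor, not part of this diff) of createKafkaProducer delegating to the getUniqueId helper defined just below, so the null check on the unique id lives in one place:

private def createKafkaProducer(
    producerConfiguration: ju.Map[String, Object]): Producer = {
  // getUniqueId asserts the params were canonicalized, replacing the raw
  // map lookup and uid.toString used above.
  val uid = getUniqueId(producerConfiguration)
  val kafkaProducer: Producer = new Producer(producerConfiguration)
  cacheMap.put(uid, kafkaProducer)
  log.debug(s"Created a new instance of KafkaProducer for $producerConfiguration.")
  kafkaProducer
}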

private def getUniqueId(kafkaParams: ju.Map[String, Object]): String = {
val uid = kafkaParams.get(CanonicalizeKafkaParams.sparkKafkaParamsUniqueId)
assert(uid != null, s"KafkaParams($kafkaParams) not canonicalized")
uid.toString
}

/**
* Get a cached KafkaProducer for a given configuration. If a matching KafkaProducer doesn't
* exist, a new KafkaProducer will be created. KafkaProducer is thread safe, so it is best to
* keep one instance per specified kafkaParams.
*/
private[kafka010] def getOrCreate(
kafkaParams: ju.Map[String, Object]): Producer = synchronized {
Contributor:
ditto

val params = if (!CanonicalizeKafkaParams.isCanonicalized(kafkaParams)) {
CanonicalizeKafkaParams.computeUniqueCanonicalForm(kafkaParams)
} else kafkaParams
val uid = getUniqueId(params)
cacheMap.getOrElse(uid.toString, createKafkaProducer(params))
}

private[kafka010] def close(kafkaParams: ju.Map[String, Object]): Unit = {
val params = if (!CanonicalizeKafkaParams.isCanonicalized(kafkaParams)) {
CanonicalizeKafkaParams.computeUniqueCanonicalForm(kafkaParams)
} else kafkaParams
val uid = getUniqueId(params)

val producer: Option[Producer] = cacheMap.remove(uid)

if (producer.isDefined) {
log.info(s"Closing the KafkaProducer with config: $kafkaParams")
producer.foreach(_.close())
} else {
log.warn(s"No KafkaProducer found in cache for $kafkaParams.")
}
}

// Intended for testing purpose only.
private def clear(): Unit = {
cacheMap.foreach(x => x._2.close())
cacheMap.clear()
}
}

/**
* Canonicalize kafka params, i.e. append a unique internal id to the kafka params if one does
* not already exist. This is done to ensure we have only one set of kafka parameters associated
* with a unique ID.
*/
private[kafka010] object CanonicalizeKafkaParams extends Logging {
Contributor:
This seems kind of complicated also. Since we know these are always coming from Data[Stream/Frame]Writer and that will always give you Map[String, String] (and we expect the number of options here to be small). Could we just make the key for the cache a sorted Seq[(String, String)] rather than invent another GUID?
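For reference, a minimal sketch of the alternative described in this comment (hypothetical; the object name is made up and the params are assumed to carry String values): the cache is keyed directly by the sorted key-value pairs, so no synthetic unique id has to be injected into the kafka params.

import java.{util => ju}

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.apache.kafka.clients.producer.KafkaProducer

private[kafka010] object SeqKeyedProducerCache {
  private type Producer = KafkaProducer[Array[Byte], Array[Byte]]

  // The sorted (key, value) pairs of the kafka params act as the cache key.
  private val cache = new mutable.HashMap[Seq[(String, String)], Producer]()

  def getOrCreate(kafkaParams: ju.Map[String, Object]): Producer = synchronized {
    val key = kafkaParams.asScala.toSeq.map { case (k, v) => (k, v.toString) }.sortBy(_._1)
    cache.getOrElseUpdate(key, new Producer(kafkaParams))
  }
}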


import scala.collection.JavaConverters._

private val registryMap = mutable.HashMap[String, String]()

private[kafka010] val sparkKafkaParamsUniqueId: String =
"spark.internal.sql.kafka.params.uniqueId"

private def generateRandomUUID(kafkaParams: String): String = {
val uuid = ju.UUID.randomUUID().toString
logDebug(s"Generating a new unique id:$uuid for kafka params: $kafkaParams")
registryMap.put(kafkaParams, uuid)
uuid
}

private[kafka010] def isCanonicalized(kafkaParams: ju.Map[String, Object]): Boolean = {
if (kafkaParams.get(sparkKafkaParamsUniqueId) != null) {
Contributor:
nit: you don't need the if-else

true
} else {
false
}
}
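As the nit notes, the predicate can be returned directly; a minimal sketch of the simplified method:

private[kafka010] def isCanonicalized(kafkaParams: ju.Map[String, Object]): Boolean = {
  kafkaParams.get(sparkKafkaParamsUniqueId) != null
}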

private[kafka010] def computeUniqueCanonicalForm(
kafkaParams: ju.Map[String, Object]): ju.Map[String, Object] = synchronized {
Contributor:
ditto on indent

if (isCanonicalized(kafkaParams)) {
logWarning(s"A unique id,$sparkKafkaParamsUniqueId ->" +
Contributor:
nit: space after ,

s" ${kafkaParams.get(sparkKafkaParamsUniqueId)}" +
s" already exists in kafka params, returning Kafka Params:$kafkaParams as is.")
Contributor:
space after :

kafkaParams
} else {
val sortedMap = SortedMap.empty[String, Object] ++ kafkaParams.asScala
val stringRepresentation: String = sortedMap.mkString("\n")
val uuid =
registryMap.getOrElse(stringRepresentation, generateRandomUUID(stringRepresentation))
val newMap = new ju.HashMap[String, Object](kafkaParams)
newMap.put(sparkKafkaParamsUniqueId, uuid)
newMap
}
}

// For testing purpose only.
private[kafka010] def clear(): Unit = {
registryMap.clear()
}

Contributor:
nit: no need for newline

}
KafkaSink.scala
@@ -22,6 +22,7 @@ import java.{util => ju}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.util.Utils
Member:
Do we need to import this?


private[kafka010] class KafkaSink(
sqlContext: SQLContext,
@@ -40,4 +41,8 @@ private[kafka010] class KafkaSink(
latestBatchId = batchId
}
}

override def stop(): Unit = {
KafkaWriter.close(sqlContext.sparkContext, executorKafkaParams)
}
}
KafkaSource.scala
@@ -70,13 +70,13 @@ import org.apache.spark.unsafe.types.UTF8String
* and not use wrong broker addresses.
*/
private[kafka010] class KafkaSource(
sqlContext: SQLContext,
kafkaReader: KafkaOffsetReader,
executorKafkaParams: ju.Map[String, Object],
sourceOptions: Map[String, String],
metadataPath: String,
startingOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean)
sqlContext: SQLContext,
Contributor:
4 spaces please

kafkaReader: KafkaOffsetReader,
executorKafkaParams: ju.Map[String, Object],
sourceOptions: Map[String, String],
metadataPath: String,
startingOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean)
extends Source with Logging {

private val sc = sqlContext.sparkContext
KafkaWriteTask.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.kafka010

import java.{util => ju}

import org.apache.kafka.clients.producer.{KafkaProducer, _}
import org.apache.kafka.common.serialization.ByteArraySerializer
import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection}
@@ -44,7 +43,7 @@ private[kafka010] class KafkaWriteTask(
* Writes key value data out to topics.
*/
def execute(iterator: Iterator[InternalRow]): Unit = {
producer = new KafkaProducer[Array[Byte], Array[Byte]](producerConfiguration)
producer = CachedKafkaProducer.getOrCreate(producerConfiguration)
while (iterator.hasNext && failedWrite == null) {
val currentRow = iterator.next()
val projectedRow = projection(currentRow)
@@ -68,11 +67,10 @@ }
}

def close(): Unit = {
checkForErrors()
if (producer != null) {
checkForErrors
producer.close()
checkForErrors
producer = null
Member:
nit: please keep producer = null for double-close

producer.flush()
checkForErrors()
}
}
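A minimal sketch of how KafkaWriteTask.close() might look with the null-out the reviewer asks for restored (an assumption about the follow-up, not what this revision contains); the shared producer itself is left open because it now lives in CachedKafkaProducer's cache:

def close(): Unit = {
  checkForErrors()
  if (producer != null) {
    producer.flush()
    checkForErrors()
    // Drop the local reference so a second close() call is a no-op; the
    // cached producer instance is not closed here.
    producer = null
  }
}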

@@ -88,7 +86,7 @@ private[kafka010] class KafkaWriteTask(
case t =>
throw new IllegalStateException(s"${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " +
s"attribute unsupported type $t. ${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " +
s"must be a ${StringType}")
"must be a StringType")
}
val keyExpression = inputSchema.find(_.name == KafkaWriter.KEY_ATTRIBUTE_NAME)
.getOrElse(Literal(null, BinaryType))
@@ -100,7 +98,7 @@ private[kafka010] class KafkaWriteTask(
}
val valueExpression = inputSchema
.find(_.name == KafkaWriter.VALUE_ATTRIBUTE_NAME).getOrElse(
throw new IllegalStateException(s"Required attribute " +
throw new IllegalStateException("Required attribute " +
s"'${KafkaWriter.VALUE_ATTRIBUTE_NAME}' not found")
)
valueExpression.dataType match {
@@ -114,7 +112,7 @@
Cast(valueExpression, BinaryType)), inputSchema)
}

private def checkForErrors: Unit = {
private def checkForErrors(): Unit = {
if (failedWrite != null) {
throw failedWrite
}
KafkaWriter.scala
@@ -19,9 +19,9 @@ package org.apache.spark.sql.kafka010

import java.{util => ju}

import org.apache.spark.SparkContext
Contributor:
unnecessary import?

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
import org.apache.spark.sql.types.{BinaryType, StringType}
@@ -94,4 +94,10 @@ private[kafka010] object KafkaWriter extends Logging {
}
}
}

def close(sc: SparkContext, kafkaParams: ju.Map[String, Object]): Unit = {
sc.parallelize(1 to 10000).foreachPartition { iter =>
CachedKafkaProducer.close(kafkaParams)
}
Member Author (@ScrapCodes, May 17, 2017):
This would cause CachedKafkaProducer.close to be executed on each executor. I am thinking of a better way here.
Any help would be appreciated.

Contributor:
AFAIK the KafkaSource also faces the same issue of not being able to close consumers. Can we use a Guava cache with a (configurable) timeout? I guess that's the safest way to make sure that they'll eventually get closed.

Member Author (@ScrapCodes, May 22, 2017):
Using a Guava cache, we can close producers that have not been used for a certain time. Shall we ignore closing them during a shutdown?
In the particular case of the Kafka producer, I do not see a direct problem with that, since we do a producer.flush() on each batch. I was just wondering, for streaming sinks in general, what should our strategy be?

}
}
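A minimal sketch of the Guava-cache approach discussed in the thread above (an assumption about a possible follow-up, not what this diff does; the object name and timeout value are made up): producers are keyed by their sorted params, expire after an idle period, and are closed by a removal listener, so nothing needs to be broadcast to the executors on shutdown.

import java.{util => ju}
import java.util.concurrent.TimeUnit

import scala.collection.JavaConverters._

import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache, RemovalListener, RemovalNotification}
import org.apache.kafka.clients.producer.KafkaProducer

private[kafka010] object ExpiringProducerCache {
  private type Producer = KafkaProducer[Array[Byte], Array[Byte]]

  // Hypothetical idle timeout; in practice this would come from a Spark conf.
  private val expireTimeoutMs: Long = 10 * 60 * 1000L

  private val cacheLoader = new CacheLoader[Seq[(String, Object)], Producer] {
    override def load(params: Seq[(String, Object)]): Producer = {
      // Rebuild the producer configuration from the cache key.
      val config = new ju.HashMap[String, Object]()
      params.foreach { case (k, v) => config.put(k, v) }
      new Producer(config)
    }
  }

  private val removalListener = new RemovalListener[Seq[(String, Object)], Producer] {
    override def onRemoval(
        notification: RemovalNotification[Seq[(String, Object)], Producer]): Unit = {
      // Close the producer once it has been idle long enough to be evicted.
      notification.getValue.close()
    }
  }

  private val cache: LoadingCache[Seq[(String, Object)], Producer] =
    CacheBuilder.newBuilder()
      .expireAfterAccess(expireTimeoutMs, TimeUnit.MILLISECONDS)
      .removalListener(removalListener)
      .build[Seq[(String, Object)], Producer](cacheLoader)

  def getOrCreate(kafkaParams: ju.Map[String, Object]): Producer = {
    cache.get(kafkaParams.asScala.toSeq.sortBy(_._1))
  }
}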
CachedKafkaProducerSuite.scala (new file)
@@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.kafka010

import java.{util => ju}

import scala.collection.mutable

import org.apache.kafka.clients.producer.KafkaProducer
import org.apache.kafka.common.serialization.ByteArraySerializer
import org.scalatest.PrivateMethodTester

import org.apache.spark.sql.test.SharedSQLContext

class CachedKafkaProducerSuite extends SharedSQLContext with PrivateMethodTester {

type KP = KafkaProducer[Array[Byte], Array[Byte]]

protected override def beforeEach(): Unit = {
super.beforeEach()
val clear = PrivateMethod[Unit]('clear)
CachedKafkaProducer.invokePrivate(clear())
CanonicalizeKafkaParams.clear()
}

test("Should return the cached instance on calling getOrCreate with same params.") {
val kafkaParams = new ju.HashMap[String, Object]()
kafkaParams.put("acks", "0")
// Here only host should be resolvable, it does not need a running instance of kafka server.
kafkaParams.put("bootstrap.servers", "127.0.0.1:9022")
kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName)
kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName)
val kafkaParams2 = CanonicalizeKafkaParams.computeUniqueCanonicalForm(kafkaParams)
val producer = CachedKafkaProducer.getOrCreate(kafkaParams2)
val producer2 = CachedKafkaProducer.getOrCreate(kafkaParams2)
assert(producer == producer2)

val cacheMap = PrivateMethod[mutable.HashMap[Int, KP]]('cacheMap)
val map = CachedKafkaProducer.invokePrivate(cacheMap())
assert(map.size == 1)
}

test("Should close the correct kafka producer for the given kafkaPrams.") {
val kafkaParams = new ju.HashMap[String, Object]()
kafkaParams.put("acks", "0")
kafkaParams.put("bootstrap.servers", "127.0.0.1:9022")
kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName)
kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName)
val kafkaParams2 = CanonicalizeKafkaParams.computeUniqueCanonicalForm(kafkaParams)
kafkaParams.put("acks", "1")
val kafkaParams3 = CanonicalizeKafkaParams.computeUniqueCanonicalForm(kafkaParams)
val producer: KP = CachedKafkaProducer.getOrCreate(kafkaParams2)

val producer2: KP = CachedKafkaProducer.getOrCreate(kafkaParams3)
// With updated conf, a new producer instance should be created.
assert(producer != producer2)

val cacheMap = PrivateMethod[mutable.HashMap[Int, KP]]('cacheMap)
val map = CachedKafkaProducer.invokePrivate(cacheMap())
assert(map.size == 2)

CachedKafkaProducer.close(kafkaParams3)
val map2 = CachedKafkaProducer.invokePrivate(cacheMap())
assert(map2.size == 1)
Member:
We just know there is one KP by this assert. Seems we should also verify if we close the correct KP?

}
}
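Responding to the last comment: a small sketch (hypothetical) of extra assertions that could be appended to the test above, verifying that the correct producer was closed, i.e. that the surviving cache entry is the one built from kafkaParams2:

// Only the producer created for kafkaParams2 should remain cached; the one
// for kafkaParams3 was closed and removed.
assert(map2.values.toSeq == Seq(producer))
assert(!map2.values.toSeq.contains(producer2))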