diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md
index 89732d309aa2..badf0429545f 100644
--- a/docs/structured-streaming-kafka-integration.md
+++ b/docs/structured-streaming-kafka-integration.md
@@ -614,6 +614,10 @@ The Dataframe being written to Kafka should have the following columns in schema
   <td>topic (*optional)</td>
   <td>string</td>
 </tr>
+<tr>
+  <td>partition (optional)</td>
+  <td>int</td>
+</tr>
 </table>
 
 \* The topic column is required if the "topic" configuration option is not specified.
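For reference, a minimal sketch of a batch write that exercises the new `partition` column from the table above. The broker address, topic name and data are placeholders, and the topic is assumed to already exist with at least two partitions:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("kafka-partition-column-example").getOrCreate()
import spark.implicits._

// "key", "value" and the new "partition" column; names must match the sink schema above,
// and "partition" must be an int column.
val df = Seq(
  ("k1", "v1", 0),  // goes to partition 0
  ("k2", "v2", 1)   // goes to partition 1
).toDF("key", "value", "partition")

df.write
  .format("kafka")
  .option("kafka.bootstrap.servers", "localhost:9092")  // placeholder broker
  .option("topic", "events")                            // placeholder topic
  .save()
```

Rows with a `null` partition value (or no partition column at all) are partitioned by the producer, as the next hunk documents.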
@@ -622,6 +626,12 @@ a ```null``` valued key column will be automatically added (see Kafka semantics
 how ```null``` valued key values are handled). If a topic column exists then its value
 is used as the topic when writing the given row to Kafka, unless the "topic" configuration
 option is set i.e., the "topic" configuration option overrides the topic column.
+If a "partition" column is not specified (or its value is ```null```)
+then the partition is calculated by the Kafka producer.
+A Kafka partitioner can be specified in Spark by setting the
+```kafka.partitioner.class``` option. If it is not set, the default
+Kafka partitioner will be used.
+
 The following options must be set for the Kafka sink
 for both batch and streaming queries.
 
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
index b423ddc959c1..5bdc1b5fe9f3 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala
@@ -27,7 +27,7 @@ import org.apache.kafka.common.header.internals.RecordHeader
 
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection}
-import org.apache.spark.sql.types.{BinaryType, StringType}
+import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType}
 
 /**
  * Writes out data in a single Spark task, without any concerns about how
@@ -92,8 +92,10 @@ private[kafka010] abstract class KafkaRowWriter(
       throw new NullPointerException(s"null topic present in the data. Use the " +
         s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.")
     }
+    val partition: Integer =
+      if (projectedRow.isNullAt(4)) null else projectedRow.getInt(4)
     val record = if (projectedRow.isNullAt(3)) {
-      new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, null, key, value)
+      new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, partition, key, value)
     } else {
       val headerArray = projectedRow.getArray(3)
       val headers = (0 until headerArray.numElements()).map { i =>
@@ -101,7 +103,8 @@ private[kafka010] abstract class KafkaRowWriter(
         new RecordHeader(struct.getUTF8String(0).toString, struct.getBinary(1))
           .asInstanceOf[Header]
       }
-      new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, null, key, value, headers.asJava)
+      new ProducerRecord[Array[Byte], Array[Byte]](
+        topic.toString, partition, key, value, headers.asJava)
     }
     producer.send(record, callback)
   }
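The hunk above passes the explicit partition into the `ProducerRecord`; a boxed `java.lang.Integer` is used so that a missing value can stay `null` and defer the choice to the producer's partitioner. A standalone illustration of that wiring (the `toRecord` helper and its `Option[Int]` parameter are invented for this sketch, not part of the patch):

```scala
import org.apache.kafka.clients.producer.ProducerRecord

// The writer reads the partition from projected-row index 4; null means
// "let the producer's partitioner decide".
def toRecord(
    topic: String,
    key: Array[Byte],
    value: Array[Byte],
    partition: Option[Int]): ProducerRecord[Array[Byte], Array[Byte]] = {
  val boxedPartition: Integer = partition.map(Int.box).orNull
  new ProducerRecord[Array[Byte], Array[Byte]](topic, boxedPartition, key, value)
}
```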
@@ -156,12 +159,23 @@ private[kafka010] abstract class KafkaRowWriter(
         throw new IllegalStateException(s"${KafkaWriter.HEADERS_ATTRIBUTE_NAME} " +
           s"attribute unsupported type ${t.catalogString}")
     }
+    val partitionExpression =
+      inputSchema.find(_.name == KafkaWriter.PARTITION_ATTRIBUTE_NAME)
+        .getOrElse(Literal(null, IntegerType))
+    partitionExpression.dataType match {
+      case IntegerType => // good
+      case t =>
+        throw new IllegalStateException(s"${KafkaWriter.PARTITION_ATTRIBUTE_NAME} " +
+          s"attribute unsupported type $t. ${KafkaWriter.PARTITION_ATTRIBUTE_NAME} " +
+          s"must be an ${IntegerType.catalogString}")
+    }
     UnsafeProjection.create(
       Seq(
         topicExpression,
         Cast(keyExpression, BinaryType),
         Cast(valueExpression, BinaryType),
-        headersExpression
+        headersExpression,
+        partitionExpression
       ),
       inputSchema
     )
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
index bbb060356f73..9b0d11f137ce 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.QueryExecution
-import org.apache.spark.sql.types.{BinaryType, MapType, StringType}
+import org.apache.spark.sql.types.{BinaryType, IntegerType, MapType, StringType}
 import org.apache.spark.util.Utils
 
 /**
@@ -41,6 +41,7 @@ private[kafka010] object KafkaWriter extends Logging {
   val KEY_ATTRIBUTE_NAME: String = "key"
   val VALUE_ATTRIBUTE_NAME: String = "value"
   val HEADERS_ATTRIBUTE_NAME: String = "headers"
+  val PARTITION_ATTRIBUTE_NAME: String = "partition"
 
   override def toString: String = "KafkaWriter"
 
@@ -86,6 +87,14 @@ private[kafka010] object KafkaWriter extends Logging {
         throw new AnalysisException(s"$HEADERS_ATTRIBUTE_NAME attribute type " +
           s"must be a ${KafkaRecordToRowConverter.headersType.catalogString}")
     }
+    schema.find(_.name == PARTITION_ATTRIBUTE_NAME).getOrElse(
+      Literal(null, IntegerType)
+    ).dataType match {
+      case IntegerType => // good
+      case _ =>
+        throw new AnalysisException(s"$PARTITION_ATTRIBUTE_NAME attribute type " +
+          s"must be an ${IntegerType.catalogString}")
+    }
   }
 
   def write(
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
index 65adbd6b9887..cbf4952406c0 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala
@@ -286,6 +286,15 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest {
     }
     assert(ex3.getMessage.toLowerCase(Locale.ROOT).contains(
       "key attribute type must be a string or binary"))
+
+    val ex4 = intercept[AnalysisException] {
+      /* partition field wrong type */
+      createKafkaWriter(input.toDF())(
+        withSelectExpr = s"'$topic' as topic", "value as partition", "value"
+      )
+    }
+    assert(ex4.getMessage.toLowerCase(Locale.ROOT).contains(
+      "partition attribute type must be an int"))
   }
 
   test("streaming - write to non-existing topic") {
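For context, a rough sketch of the user-facing behaviour these suites assert on: selecting a non-integer column as `partition` is rejected at analysis time, before any record is sent. The `SparkSession`, broker address and topic below are placeholders:

```scala
import org.apache.spark.sql.{AnalysisException, SparkSession}

val spark = SparkSession.builder().appName("kafka-partition-type-check").getOrCreate()
import spark.implicits._

val df = Seq(("k1", "v1")).toDF("key", "value")

try {
  df.selectExpr("key", "value", "value as partition")  // string, not int
    .write
    .format("kafka")
    .option("kafka.bootstrap.servers", "localhost:9092")
    .option("topic", "events")
    .save()
} catch {
  case e: AnalysisException =>
    // Expected message: "partition attribute type must be an int"
    println(e.getMessage)
}
```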
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
index d77b9a3b6a9e..aacb10f5197b 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala
@@ -22,6 +22,8 @@ import java.util.Locale
 import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.kafka.clients.producer.ProducerConfig
+import org.apache.kafka.clients.producer.internals.DefaultPartitioner
+import org.apache.kafka.common.Cluster
 import org.apache.kafka.common.serialization.ByteArraySerializer
 import org.scalatest.time.SpanSugar._
 
@@ -33,7 +35,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming._
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{BinaryType, DataType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, StringType, StructField, StructType}
 
 abstract class KafkaSinkSuiteBase extends QueryTest with SharedSparkSession with KafkaTest {
   protected var testUtils: KafkaTestUtils = _
@@ -293,6 +295,21 @@ class KafkaSinkStreamingSuite extends KafkaSinkSuiteBase with StreamTest {
     }
     assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
       "key attribute type must be a string or binary"))
+
+    try {
+      ex = intercept[StreamingQueryException] {
+        /* partition field wrong type */
+        writer = createKafkaWriter(input.toDF())(
+          withSelectExpr = s"'$topic' as topic", "value", "value as partition"
+        )
+        input.addData("1", "2", "3", "4", "5")
+        writer.processAllAvailable()
+      }
+    } finally {
+      writer.stop()
+    }
+    assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(
+      "partition attribute type must be an int"))
   }
 
   test("streaming - write to non-existing topic") {
@@ -418,6 +435,65 @@ abstract class KafkaSinkBatchSuiteBase extends KafkaSinkSuiteBase {
     )
   }
 
+  def writeToKafka(df: DataFrame, topic: String, options: Map[String, String] = Map.empty): Unit = {
+    df
+      .write
+      .format("kafka")
+      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
+      .option("topic", topic)
+      .options(options)
+      .mode("append")
+      .save()
+  }
+
+  def partitionsInTopic(topic: String): Set[Int] = {
+    createKafkaReader(topic)
+      .select("partition")
+      .map(_.getInt(0))
+      .collect()
+      .toSet
+  }
+
+  test("batch - partition column and partitioner priorities") {
+    val nrPartitions = 4
+    val topic1 = newTopic()
+    val topic2 = newTopic()
+    val topic3 = newTopic()
+    val topic4 = newTopic()
+    testUtils.createTopic(topic1, nrPartitions)
+    testUtils.createTopic(topic2, nrPartitions)
+    testUtils.createTopic(topic3, nrPartitions)
+    testUtils.createTopic(topic4, nrPartitions)
+    val customKafkaPartitionerConf = Map(
+      "kafka.partitioner.class" -> "org.apache.spark.sql.kafka010.TestKafkaPartitioner"
+    )
+
+    val df = (0 until 5).map(n => (topic1, s"$n", s"$n")).toDF("topic", "key", "value")
+
+    // default kafka partitioner
+    writeToKafka(df, topic1)
+    val partitionsInTopic1 = partitionsInTopic(topic1)
+    assert(partitionsInTopic1.size > 1)
+
+    // custom partitioner (always returns 0) overrides default partitioner
+    writeToKafka(df, topic2, customKafkaPartitionerConf)
+    val partitionsInTopic2 = partitionsInTopic(topic2)
+    assert(partitionsInTopic2.size == 1)
+    assert(partitionsInTopic2.head == 0)
+
+    // partition column overrides custom partitioner
+    val dfWithCustomPartition = df.withColumn("partition", lit(2))
+    writeToKafka(dfWithCustomPartition, topic3, customKafkaPartitionerConf)
+    val partitionsInTopic3 = partitionsInTopic(topic3)
+    assert(partitionsInTopic3.size == 1)
+    assert(partitionsInTopic3.head == 2)
+
+    // when the partition column value is null, it is ignored
+    val dfWithNullPartitions = df.withColumn("partition", lit(null).cast(IntegerType))
+    writeToKafka(dfWithNullPartitions, topic4)
+    assert(partitionsInTopic(topic4) == partitionsInTopic1)
+  }
+
   test("batch - null topic field value, and no topic option") {
     val df = Seq[(String, String)](null.asInstanceOf[String] -> "1").toDF("topic", "value")
     val ex = intercept[SparkException] {
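The new test above pins down the precedence order: an explicit `partition` column wins over `kafka.partitioner.class`, which in turn wins over Kafka's default partitioner, and `null` partition values fall back to the configured (or default) partitioner. A hedged usage sketch, with the broker address, topic and partitioner class name as placeholders:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.lit

val spark = SparkSession.builder().appName("kafka-partitioner-precedence").getOrCreate()
import spark.implicits._

val df = Seq(("k1", "v1"), ("k2", "v2")).toDF("key", "value")

// The explicit partition column sends every row to partition 2; the configured
// partitioner would only be consulted for rows whose partition value is null.
df.withColumn("partition", lit(2))
  .write
  .format("kafka")
  .option("kafka.bootstrap.servers", "localhost:9092")
  .option("kafka.partitioner.class", "com.example.MyPartitioner")  // placeholder class
  .option("topic", "events")
  .save()
```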
value, and no topic option") { val df = Seq[(String, String)](null.asInstanceOf[String] -> "1").toDF("topic", "value") val ex = intercept[SparkException] { @@ -515,3 +591,13 @@ class KafkaSinkBatchSuiteV2 extends KafkaSinkBatchSuiteBase { } } } + +class TestKafkaPartitioner extends DefaultPartitioner { + override def partition( + topic: String, + key: Any, + keyBytes: Array[Byte], + value: Any, + valueBytes: Array[Byte], + cluster: Cluster): Int = 0 +}