Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,9 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming._
import org.apache.spark.sql.types.StructType

/**
* A [[ContinuousReadSupport]] for data from kafka.
* A [[ContinuousStream]] for data from kafka.
*
* @param offsetReader a reader used to get kafka offsets. Note that the actual data will be
* read by per-task consumers generated later.
Expand All @@ -46,17 +45,23 @@ import org.apache.spark.sql.types.StructType
* scenarios, where some offsets after the specified initial ones can't be
* properly read.
*/
class KafkaContinuousReadSupport(
class KafkaContinuousStream(
offsetReader: KafkaOffsetReader,
kafkaParams: ju.Map[String, Object],
sourceOptions: Map[String, String],
metadataPath: String,
initialOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean)
extends ContinuousReadSupport with Logging {
extends ContinuousStream with Logging {

private val pollTimeoutMs = sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "512").toLong

// Initialized when creating reader factories. If this diverges from the partitions at the latest
// offsets, we need to reconfigure.
// Exposed outside this object only for unit tests.
@volatile private[sql] var knownPartitions: Set[TopicPartition] = _


override def initialOffset(): Offset = {
val offsets = initialOffsets match {
case EarliestOffsetRangeLimit => KafkaSourceOffset(offsetReader.fetchEarliestOffsets())
Expand All @@ -67,27 +72,40 @@ class KafkaContinuousReadSupport(
offsets
}

override def fullSchema(): StructType = KafkaOffsetReader.kafkaSchema

override def newScanConfigBuilder(start: Offset): ScanConfigBuilder = {
new KafkaContinuousScanConfigBuilder(fullSchema(), start, offsetReader, reportDataLoss)
}

override def deserializeOffset(json: String): Offset = {
KafkaSourceOffset(JsonUtils.partitionOffsets(json))
}

override def planInputPartitions(config: ScanConfig): Array[InputPartition] = {
val startOffsets = config.asInstanceOf[KafkaContinuousScanConfig].startOffsets
override def planInputPartitions(start: Offset): Array[InputPartition] = {
val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(start)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet
val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet)
val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq)

val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet)
if (deletedPartitions.nonEmpty) {
val message = if (
offsetReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) {
s"$deletedPartitions are gone. ${KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE}"
} else {
s"$deletedPartitions are gone. Some data may have been missed."
}
reportDataLoss(message)
}

val startOffsets = newPartitionOffsets ++
oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_))
knownPartitions = startOffsets.keySet

startOffsets.toSeq.map {
case (topicPartition, start) =>
KafkaContinuousInputPartition(
topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss)
}.toArray
}

override def createContinuousReaderFactory(
config: ScanConfig): ContinuousPartitionReaderFactory = {
override def createContinuousReaderFactory(): ContinuousPartitionReaderFactory = {
KafkaContinuousReaderFactory
}

Expand All @@ -105,8 +123,7 @@ class KafkaContinuousReadSupport(
KafkaSourceOffset(mergedMap)
}

override def needsReconfiguration(config: ScanConfig): Boolean = {
val knownPartitions = config.asInstanceOf[KafkaContinuousScanConfig].knownPartitions
override def needsReconfiguration(): Boolean = {
offsetReader.fetchLatestOffsets(None).keySet != knownPartitions
}

Expand Down Expand Up @@ -151,47 +168,6 @@ object KafkaContinuousReaderFactory extends ContinuousPartitionReaderFactory {
}
}

class KafkaContinuousScanConfigBuilder(
schema: StructType,
startOffset: Offset,
offsetReader: KafkaOffsetReader,
reportDataLoss: String => Unit)
extends ScanConfigBuilder {

override def build(): ScanConfig = {
val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(startOffset)

val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet
val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet)
val newPartitionOffsets = offsetReader.fetchEarliestOffsets(newPartitions.toSeq)

val deletedPartitions = oldStartPartitionOffsets.keySet.diff(currentPartitionSet)
if (deletedPartitions.nonEmpty) {
val message = if (
offsetReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) {
s"$deletedPartitions are gone. ${KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE}"
} else {
s"$deletedPartitions are gone. Some data may have been missed."
}
reportDataLoss(message)
}

val startOffsets = newPartitionOffsets ++
oldStartPartitionOffsets.filterKeys(!deletedPartitions.contains(_))
KafkaContinuousScanConfig(schema, startOffsets)
}
}

case class KafkaContinuousScanConfig(
readSchema: StructType,
startOffsets: Map[TopicPartition, Long])
extends ScanConfig {

// Created when building the scan config builder. If this diverges from the partitions at the
// latest offsets, we need to reconfigure the kafka read support.
def knownPartitions: Set[TopicPartition] = startOffsets.keySet
}

/**
* A per-task data reader for continuous Kafka processing.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.streaming.{Sink, Source}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.reader.{Scan, ScanBuilder}
import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchStream
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
Expand All @@ -48,7 +48,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
with RelationProvider
with CreatableRelationProvider
with StreamingWriteSupportProvider
with ContinuousReadSupportProvider
with TableProvider
with Logging {
import KafkaSourceProvider._
Expand Down Expand Up @@ -107,46 +106,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
new KafkaTable(strategy(options.asMap().asScala.toMap))
}

/**
* Creates a [[org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReadSupport]] to read
* Kafka data in a continuous streaming query.
*/
override def createContinuousReadSupport(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

metadataPath: String,
options: DataSourceOptions): KafkaContinuousReadSupport = {
val parameters = options.asMap().asScala.toMap
validateStreamOptions(parameters)
// Each running query should use its own group id. Otherwise, the query may be only assigned
// partial data since Kafka will assign partitions to multiple consumers having the same group
// id. Hence, we should generate a unique id for each query.
val uniqueGroupId = streamingUniqueGroupId(parameters, metadataPath)

val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
val specifiedKafkaParams =
parameters
.keySet
.filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
.map { k => k.drop(6).toString -> parameters(k) }
.toMap

val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams,
STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)

val kafkaOffsetReader = new KafkaOffsetReader(
strategy(caseInsensitiveParams),
kafkaParamsForDriver(specifiedKafkaParams),
parameters,
driverGroupIdPrefix = s"$uniqueGroupId-driver")

new KafkaContinuousReadSupport(
kafkaOffsetReader,
kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
parameters,
metadataPath,
startingStreamOffsets,
failOnDataLoss(caseInsensitiveParams))
}

/**
* Returns a new base relation with the given parameters.
*
Expand Down Expand Up @@ -406,7 +365,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
}

class KafkaTable(strategy: => ConsumerStrategy) extends Table
with SupportsMicroBatchRead {
with SupportsMicroBatchRead with SupportsContinuousRead {

override def name(): String = s"Kafka $strategy"

Expand Down Expand Up @@ -449,6 +408,40 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
startingStreamOffsets,
failOnDataLoss(caseInsensitiveParams))
}

override def toContinuousStream(checkpointLocation: String): ContinuousStream = {
val parameters = options.asMap().asScala.toMap
validateStreamOptions(parameters)
// Each running query should use its own group id. Otherwise, the query may be only assigned
// partial data since Kafka will assign partitions to multiple consumers having the same group
// id. Hence, we should generate a unique id for each query.
val uniqueGroupId = streamingUniqueGroupId(parameters, checkpointLocation)

val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
val specifiedKafkaParams =
parameters
.keySet
.filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
.map { k => k.drop(6).toString -> parameters(k) }
.toMap

val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)

val kafkaOffsetReader = new KafkaOffsetReader(
strategy(caseInsensitiveParams),
kafkaParamsForDriver(specifiedKafkaParams),
parameters,
driverGroupIdPrefix = s"$uniqueGroupId-driver")

new KafkaContinuousStream(
kafkaOffsetReader,
kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
parameters,
checkpointLocation,
startingStreamOffsets,
failOnDataLoss(caseInsensitiveParams))
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,11 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest {
assert(
query.lastExecution.executedPlan.collectFirst {
case scan: ContinuousScanExec
if scan.readSupport.isInstanceOf[KafkaContinuousReadSupport] =>
scan.scanConfig.asInstanceOf[KafkaContinuousScanConfig]
}.exists { config =>
if scan.stream.isInstanceOf[KafkaContinuousStream] =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this logic is correct, but let's keep an eye on the tests after merging since some flakiness slipped through in the last iteration of the refactoring.

scan.stream.asInstanceOf[KafkaContinuousStream]
}.exists { stream =>
// Ensure the new topic is present and the old topic is gone.
config.knownPartitions.exists(_.topic == topic2)
stream.knownPartitions.exists(_.topic == topic2)
},
s"query never reconfigured to new topic $topic2")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ trait KafkaContinuousTest extends KafkaSourceTest {
assert(
query.lastExecution.executedPlan.collectFirst {
case scan: ContinuousScanExec
if scan.readSupport.isInstanceOf[KafkaContinuousReadSupport] =>
scan.scanConfig.asInstanceOf[KafkaContinuousScanConfig]
if scan.stream.isInstanceOf[KafkaContinuousStream] =>
scan.stream.asInstanceOf[KafkaContinuousStream]
}.exists(_.knownPartitions.size == newCount),
s"query never reconfigured to $newCount partitions")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,13 @@ import scala.collection.JavaConverters._
import scala.io.Source
import scala.util.Random

import org.apache.kafka.clients.admin.{AdminClient, ConsumerGroupListing}
import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord, RecordMetadata}
import org.apache.kafka.clients.producer.{ProducerRecord, RecordMetadata}
import org.apache.kafka.common.TopicPartition
import org.scalatest.concurrent.PatienceConfiguration.Timeout
import org.scalatest.time.SpanSugar._

import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession}
import org.apache.spark.sql.execution.datasources.v2.{OldStreamingDataSourceV2Relation, StreamingDataSourceV2Relation}
import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
Expand Down Expand Up @@ -118,17 +117,10 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with Kaf
val sources: Seq[BaseStreamingSource] = {
query.get.logicalPlan.collect {
case StreamingExecutionRelation(source: KafkaSource, _) => source
case r: StreamingDataSourceV2Relation
if r.stream.isInstanceOf[KafkaMicroBatchStream] =>
r.stream.asInstanceOf[KafkaMicroBatchStream]
} ++ (query.get.lastExecution match {
case null => Seq()
case e => e.logical.collect {
case r: OldStreamingDataSourceV2Relation
if r.readSupport.isInstanceOf[KafkaContinuousReadSupport] =>
r.readSupport.asInstanceOf[KafkaContinuousReadSupport]
}
})
case r: StreamingDataSourceV2Relation if r.stream.isInstanceOf[KafkaMicroBatchStream] ||
r.stream.isInstanceOf[KafkaContinuousStream] =>
r.stream
}
}.distinct

if (sources.isEmpty) {
Expand Down

This file was deleted.

Loading