@@ -22,12 +22,13 @@ import org.apache.kafka.common.TopicPartition
import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, PartitionReaderFactory}


private[kafka010] class KafkaBatch(
strategy: ConsumerStrategy,
sourceOptions: Map[String, String],
sourceOptions: CaseInsensitiveMap[String],
specifiedKafkaParams: Map[String, String],
failOnDataLoss: Boolean,
startingOffsets: KafkaOffsetRangeLimit,
@@ -38,7 +39,7 @@ private[kafka010] class KafkaBatch(
assert(endingOffsets != EarliestOffsetRangeLimit,
"Ending offset not allowed to be set to earliest offsets.")

private val pollTimeoutMs = sourceOptions.getOrElse(
private[kafka010] val pollTimeoutMs = sourceOptions.getOrElse(
KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
(SparkEnv.get.conf.get(NETWORK_TIMEOUT) * 1000L).toString
).toLong
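Switching sourceOptions from Map[String, String] to CaseInsensitiveMap[String] means lookups such as KafkaSourceProvider.CONSUMER_POLL_TIMEOUT no longer depend on how the caller cased the key. A minimal sketch of the lookup behaviour, assuming Catalyst's CaseInsensitiveMap; the literal option key below is an assumption, not taken from this diff:

```scala
// Sketch only: CaseInsensitiveMap preserves the original keys but matches them
// case-insensitively on lookup. The option key string is assumed here.
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

val sourceOptions = CaseInsensitiveMap(Map("kafkaConsumer.pollTimeoutMs" -> "512"))

sourceOptions.get("kafkaconsumer.polltimeoutms")                        // Some("512")
sourceOptions.get("KAFKACONSUMER.POLLTIMEOUTMS")                        // Some("512")
sourceOptions.getOrElse("kafkaConsumer.pollTimeoutMs", "120000").toLong // 512L
```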
@@ -46,15 +46,15 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
* properly read.
*/
class KafkaContinuousStream(
offsetReader: KafkaOffsetReader,
private[kafka010] val offsetReader: KafkaOffsetReader,
kafkaParams: ju.Map[String, Object],
options: CaseInsensitiveStringMap,
metadataPath: String,
initialOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean)
extends ContinuousStream with Logging {

private val pollTimeoutMs =
private[kafka010] val pollTimeoutMs =
options.getLong(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, 512)

// Initialized when creating reader factories. If this diverges from the partitions at the latest
@@ -56,19 +56,19 @@ import org.apache.spark.util.UninterruptibleThread
* and not use wrong broker addresses.
*/
private[kafka010] class KafkaMicroBatchStream(
kafkaOffsetReader: KafkaOffsetReader,
private[kafka010] val kafkaOffsetReader: KafkaOffsetReader,
executorKafkaParams: ju.Map[String, Object],
options: CaseInsensitiveStringMap,
metadataPath: String,
startingOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean) extends RateControlMicroBatchStream with Logging {

private val pollTimeoutMs = options.getLong(
private[kafka010] val pollTimeoutMs = options.getLong(
KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
SparkEnv.get.conf.get(NETWORK_TIMEOUT) * 1000L)

private val maxOffsetsPerTrigger = Option(options.get(KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER))
.map(_.toLong)
private[kafka010] val maxOffsetsPerTrigger = Option(options.get(
KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER)).map(_.toLong)

private val rangeCalculator = KafkaOffsetRangeCalculator(options)
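The fields widened here from private to private[kafka010] (pollTimeoutMs, maxOffsetsPerTrigger, kafkaOffsetReader) are presumably exposed so that tests in the same package can assert the effective values. A self-contained sketch of what the package-qualified modifier buys; the class and object names are hypothetical, not from this PR:

```scala
// Hypothetical illustration: a private[kafka010] member stays hidden from user code
// but is visible to anything declared under org.apache.spark.sql.kafka010,
// e.g. a test suite asserting the effective poll timeout.
package org.apache.spark.sql.kafka010 {

  class ExampleStream(options: Map[String, String]) {
    private[kafka010] val pollTimeoutMs: Long =
      options.getOrElse("kafkaConsumer.pollTimeoutMs", "120000").toLong
  }

  object ExampleStreamCheck {
    def run(): Unit = {
      val stream = new ExampleStream(Map("kafkaConsumer.pollTimeoutMs" -> "512"))
      assert(stream.pollTimeoutMs == 512L) // compiles because we are in the same package
    }
  }
}
// From any other package, stream.pollTimeoutMs would not compile.
```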

@@ -30,6 +30,7 @@ import org.apache.kafka.clients.consumer.{Consumer, ConsumerConfig, KafkaConsume
import org.apache.kafka.common.TopicPartition

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.types._
import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}

@@ -47,7 +48,7 @@ import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}
private[kafka010] class KafkaOffsetReader(
consumerStrategy: ConsumerStrategy,
val driverKafkaParams: ju.Map[String, Object],
readerOptions: Map[String, String],
readerOptions: CaseInsensitiveMap[String],
driverGroupIdPrefix: String) extends Logging {
/**
* Used to ensure execute fetch operations execute in an UninterruptibleThread
@@ -88,10 +89,10 @@ private[kafka010] class KafkaOffsetReader(
_consumer
}

private val maxOffsetFetchAttempts =
private[kafka010] val maxOffsetFetchAttempts =
readerOptions.getOrElse(KafkaSourceProvider.FETCH_OFFSET_NUM_RETRY, "3").toInt

private val offsetFetchAttemptIntervalMs =
private[kafka010] val offsetFetchAttemptIntervalMs =
readerOptions.getOrElse(KafkaSourceProvider.FETCH_OFFSET_RETRY_INTERVAL_MS, "1000").toLong

private def nextGroupId(): String = {
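Because readerOptions is now a CaseInsensitiveMap, the two retry settings above are picked up however the user cased them. A small sketch of the getOrElse pattern; the literal keys "fetchOffset.numRetries" and "fetchOffset.retryIntervalMs" are assumptions, since the constants themselves are not shown in this diff:

```scala
// Sketch only; the option key strings are assumed, the real constants live in
// KafkaSourceProvider and are not part of this diff.
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap

val readerOptions = CaseInsensitiveMap(Map("FETCHOFFSET.NUMRETRIES" -> "5"))

val maxOffsetFetchAttempts =
  readerOptions.getOrElse("fetchOffset.numRetries", "3").toInt          // 5, despite the casing
val offsetFetchAttemptIntervalMs =
  readerOptions.getOrElse("fetchOffset.retryIntervalMs", "1000").toLong // 1000L, the default
```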
@@ -24,7 +24,7 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String
@@ -33,7 +33,7 @@ import org.apache.spark.unsafe.types.UTF8String
private[kafka010] class KafkaRelation(
override val sqlContext: SQLContext,
strategy: ConsumerStrategy,
sourceOptions: Map[String, String],
sourceOptions: CaseInsensitiveMap[String],
specifiedKafkaParams: Map[String, String],
failOnDataLoss: Boolean,
startingOffsets: KafkaOffsetRangeLimit,
@@ -78,32 +78,32 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
schema: Option[StructType],
providerName: String,
parameters: Map[String, String]): Source = {
validateStreamOptions(parameters)
val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
validateStreamOptions(caseInsensitiveParameters)
// Each running query should use its own group id. Otherwise, the query may be only assigned
// partial data since Kafka will assign partitions to multiple consumers having the same group
// id. Hence, we should generate a unique id for each query.
val uniqueGroupId = streamingUniqueGroupId(parameters, metadataPath)
val uniqueGroupId = streamingUniqueGroupId(caseInsensitiveParameters, metadataPath)

val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
val specifiedKafkaParams = convertToSpecifiedParams(parameters)

val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams,
STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
caseInsensitiveParameters, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)

val kafkaOffsetReader = new KafkaOffsetReader(
strategy(caseInsensitiveParams),
strategy(caseInsensitiveParameters),
kafkaParamsForDriver(specifiedKafkaParams),
parameters,
caseInsensitiveParameters,
driverGroupIdPrefix = s"$uniqueGroupId-driver")

new KafkaSource(
sqlContext,
kafkaOffsetReader,
kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId),
parameters,
caseInsensitiveParameters,
metadataPath,
startingStreamOffsets,
failOnDataLoss(caseInsensitiveParams))
failOnDataLoss(caseInsensitiveParameters))
}

override def getTable(options: CaseInsensitiveStringMap): KafkaTable = {
@@ -119,24 +119,24 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
override def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String]): BaseRelation = {
validateBatchOptions(parameters)
val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
validateBatchOptions(caseInsensitiveParameters)
val specifiedKafkaParams = convertToSpecifiedParams(parameters)

val startingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit)
caseInsensitiveParameters, STARTING_OFFSETS_OPTION_KEY, EarliestOffsetRangeLimit)
assert(startingRelationOffsets != LatestOffsetRangeLimit)

val endingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(caseInsensitiveParams,
ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
val endingRelationOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
caseInsensitiveParameters, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
assert(endingRelationOffsets != EarliestOffsetRangeLimit)

new KafkaRelation(
sqlContext,
strategy(caseInsensitiveParams),
sourceOptions = parameters,
strategy(caseInsensitiveParameters),
sourceOptions = caseInsensitiveParameters,
specifiedKafkaParams = specifiedKafkaParams,
failOnDataLoss = failOnDataLoss(caseInsensitiveParams),
failOnDataLoss = failOnDataLoss(caseInsensitiveParameters),
startingOffsets = startingRelationOffsets,
endingOffsets = endingRelationOffsets)
}
@@ -420,23 +420,22 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
}

override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
val parameters = options.asScala.toMap
validateStreamOptions(parameters)
val caseInsensitiveOptions = CaseInsensitiveMap(options.asScala.toMap)
validateStreamOptions(caseInsensitiveOptions)
// Each running query should use its own group id. Otherwise, the query may be only assigned
// partial data since Kafka will assign partitions to multiple consumers having the same group
// id. Hence, we should generate a unique id for each query.
val uniqueGroupId = streamingUniqueGroupId(parameters, checkpointLocation)
val uniqueGroupId = streamingUniqueGroupId(caseInsensitiveOptions, checkpointLocation)

val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
val specifiedKafkaParams = convertToSpecifiedParams(parameters)
val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveOptions)

val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
caseInsensitiveOptions, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)

val kafkaOffsetReader = new KafkaOffsetReader(
strategy(parameters),
strategy(caseInsensitiveOptions),
kafkaParamsForDriver(specifiedKafkaParams),
parameters,
caseInsensitiveOptions,
driverGroupIdPrefix = s"$uniqueGroupId-driver")

new KafkaMicroBatchStream(
@@ -445,32 +444,26 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
options,
checkpointLocation,
startingStreamOffsets,
failOnDataLoss(caseInsensitiveParams))
failOnDataLoss(caseInsensitiveOptions))
}

override def toContinuousStream(checkpointLocation: String): ContinuousStream = {
val parameters = options.asScala.toMap
validateStreamOptions(parameters)
val caseInsensitiveOptions = CaseInsensitiveMap(options.asScala.toMap)
validateStreamOptions(caseInsensitiveOptions)
// Each running query should use its own group id. Otherwise, the query may be only assigned
// partial data since Kafka will assign partitions to multiple consumers having the same group
// id. Hence, we should generate a unique id for each query.
val uniqueGroupId = streamingUniqueGroupId(parameters, checkpointLocation)
val uniqueGroupId = streamingUniqueGroupId(caseInsensitiveOptions, checkpointLocation)

val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) }
val specifiedKafkaParams =
parameters
.keySet
.filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
.map { k => k.drop(6).toString -> parameters(k) }
.toMap
val specifiedKafkaParams = convertToSpecifiedParams(caseInsensitiveOptions)
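The one-liner above and the removed inline block should be behaviourally equivalent. A sketch of what convertToSpecifiedParams is expected to do, reconstructed from the code it replaces; the helper's actual body is not part of this diff:

```scala
// Reconstructed from the removed inline code: keep only the "kafka."-prefixed
// options and strip the prefix, so the remainder can be handed to the Kafka
// consumer as-is.
import java.util.Locale

def convertToSpecifiedParams(parameters: Map[String, String]): Map[String, String] =
  parameters
    .keySet
    .filter(_.toLowerCase(Locale.ROOT).startsWith("kafka."))
    .map { k => k.drop(6) -> parameters(k) }
    .toMap

// Example: Map("kafka.bootstrap.servers" -> "host1:9092", "subscribe" -> "t")
// becomes   Map("bootstrap.servers" -> "host1:9092")
```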
Author's inline comment: This is not a required change, but I thought it would be good to simplify.

val startingStreamOffsets = KafkaSourceProvider.getKafkaOffsetRangeLimit(
caseInsensitiveParams, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
caseInsensitiveOptions, STARTING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)

val kafkaOffsetReader = new KafkaOffsetReader(
strategy(caseInsensitiveParams),
strategy(caseInsensitiveOptions),
kafkaParamsForDriver(specifiedKafkaParams),
parameters,
caseInsensitiveOptions,
driverGroupIdPrefix = s"$uniqueGroupId-driver")

new KafkaContinuousStream(
@@ -479,7 +472,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
options,
checkpointLocation,
startingStreamOffsets,
failOnDataLoss(caseInsensitiveParams))
failOnDataLoss(caseInsensitiveOptions))
}
}
}
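For context, a usage-level sketch of the behaviour these changes keep consistent across the entry points: Kafka source option names are treated case-insensitively regardless of how the caller spells them. This assumes an active SparkSession named spark; the option names are the documented Kafka source options:

```scala
// Usage sketch, not part of the PR.
val df = spark.read
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:9092")
  .option("subscribe", "topic1")
  .option("STARTINGOFFSETS", "earliest") // resolved the same way as "startingOffsets"
  .load()

df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
```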