diff --git a/connector/kinesis-asl/pom.xml b/connector/kinesis-asl/pom.xml index 8fd1ad1d0eced..de785e5c5db37 100644 --- a/connector/kinesis-asl/pom.xml +++ b/connector/kinesis-asl/pom.xml @@ -54,14 +54,64 @@ test - com.amazonaws + software.amazon.kinesis amazon-kinesis-client ${aws.kinesis.client.version} + + + + com.kjetland + mbknor-jackson-jsonschema_2.12 + + + org.lz4 + lz4-java + + + + + software.amazon.awssdk + auth + ${aws.java.sdk.v2.version} + + + software.amazon.awssdk + sts + ${aws.java.sdk.v2.version} + + + software.amazon.awssdk + apache-client + ${aws.java.sdk.v2.version} + + + software.amazon.awssdk + regions + ${aws.java.sdk.v2.version} + + + software.amazon.awssdk + dynamodb + ${aws.java.sdk.v2.version} + + + software.amazon.awssdk + kinesis + ${aws.java.sdk.v2.version} + + + software.amazon.awssdk + cloudwatch + ${aws.java.sdk.v2.version} - com.amazonaws - aws-java-sdk-sts - ${aws.java.sdk.version} + software.amazon.awssdk + sdk-core + ${aws.java.sdk.v2.version} software.amazon.kinesis diff --git a/connector/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/connector/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java index 636af9f6c6060..98ef98798bb0c 100644 --- a/connector/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java +++ b/connector/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -16,6 +16,7 @@ */ package org.apache.spark.examples.streaming; +import java.net.URI; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; @@ -38,8 +39,10 @@ import scala.Tuple2; import scala.reflect.ClassTag$; -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; -import com.amazonaws.services.kinesis.AmazonKinesisClient; +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider; +import software.amazon.awssdk.http.apache.ApacheHttpClient; +import software.amazon.awssdk.services.kinesis.KinesisClient; +import software.amazon.awssdk.services.kinesis.model.DescribeStreamRequest; /** * Consumes messages from a Amazon Kinesis streams and does wordcount. @@ -66,7 +69,7 @@ * There is a companion helper class called KinesisWordProducerASL which puts dummy data * onto the Kinesis stream. 
* - * This code uses the DefaultAWSCredentialsProviderChain to find credentials + * This code uses the DefaultCredentialsProvider to find credentials * in the following order: * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY * Java System Properties - aws.accessKeyId and aws.secretKey @@ -106,11 +109,19 @@ public static void main(String[] args) throws Exception { String endpointUrl = args[2]; // Create a Kinesis client in order to determine the number of shards for the given stream - AmazonKinesisClient kinesisClient = - new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()); - kinesisClient.setEndpoint(endpointUrl); + KinesisClient kinesisClient = + KinesisClient.builder() + .credentialsProvider(DefaultCredentialsProvider.create()) + .endpointOverride(URI.create(endpointUrl)) + .httpClientBuilder(ApacheHttpClient.builder()) + .build(); + + DescribeStreamRequest describeStreamRequest = + DescribeStreamRequest.builder() + .streamName(streamName) + .build(); int numShards = - kinesisClient.describeStream(streamName).getStreamDescription().getShards().size(); + kinesisClient.describeStream(describeStreamRequest).streamDescription().shards().size(); // In this example, we're going to create 1 Kinesis Receiver/input DStream for each shard. diff --git a/connector/kinesis-asl/src/main/java/org/apache/spark/streaming/kinesis/KinesisInitialPositions.java b/connector/kinesis-asl/src/main/java/org/apache/spark/streaming/kinesis/KinesisInitialPositions.java index b5f5ab0e90540..936044c5297ab 100644 --- a/connector/kinesis-asl/src/main/java/org/apache/spark/streaming/kinesis/KinesisInitialPositions.java +++ b/connector/kinesis-asl/src/main/java/org/apache/spark/streaming/kinesis/KinesisInitialPositions.java @@ -16,7 +16,7 @@ */ package org.apache.spark.streaming.kinesis; -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; +import software.amazon.kinesis.common.InitialPositionInStream; import java.io.Serializable; import java.util.Date; diff --git a/connector/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisExampleUtils.scala b/connector/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisExampleUtils.scala index 737e5199e71a4..ec3bc67f5e2af 100644 --- a/connector/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisExampleUtils.scala +++ b/connector/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisExampleUtils.scala @@ -19,16 +19,16 @@ package org.apache.spark.examples.streaming import scala.jdk.CollectionConverters._ -import com.amazonaws.regions.RegionUtils -import com.amazonaws.services.kinesis.AmazonKinesis +import software.amazon.awssdk.regions.servicemetadata.KinesisServiceMetadata private[streaming] object KinesisExampleUtils { def getRegionNameByEndpoint(endpoint: String): String = { val uri = new java.net.URI(endpoint) - RegionUtils.getRegionsForService(AmazonKinesis.ENDPOINT_PREFIX) + val kinesisServiceMetadata = new KinesisServiceMetadata() + kinesisServiceMetadata.regions .asScala - .find(_.getAvailableEndpoints.asScala.toSeq.contains(uri.getHost)) - .map(_.getName) + .find(r => kinesisServiceMetadata.endpointFor(r).toString.equals(uri.getHost)) + .map(_.id) .getOrElse( throw new IllegalArgumentException(s"Could not resolve region for endpoint: $endpoint")) } diff --git a/connector/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala 
b/connector/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index cc24c378f4cbf..217069bd16c6e 100644 --- a/connector/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/connector/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -18,15 +18,18 @@ // scalastyle:off println package org.apache.spark.examples.streaming +import java.net.URI import java.nio.ByteBuffer import scala.util.Random -import com.amazonaws.auth.DefaultAWSCredentialsProviderChain -import com.amazonaws.services.kinesis.AmazonKinesisClient -import com.amazonaws.services.kinesis.model.PutRecordRequest import org.apache.logging.log4j.Level import org.apache.logging.log4j.core.config.Configurator +import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider +import software.amazon.awssdk.core.SdkBytes +import software.amazon.awssdk.http.apache.ApacheHttpClient +import software.amazon.awssdk.services.kinesis.KinesisClient +import software.amazon.awssdk.services.kinesis.model.{DescribeStreamRequest, PutRecordRequest} import org.apache.spark.SparkConf import org.apache.spark.internal.Logging @@ -101,13 +104,22 @@ object KinesisWordCountASL extends Logging { // Determine the number of shards from the stream using the low-level Kinesis Client // from the AWS Java SDK. - val credentials = new DefaultAWSCredentialsProviderChain().getCredentials() - require(credentials != null, + val credentialsProvider = DefaultCredentialsProvider.create + require(credentialsProvider.resolveCredentials() != null, "No AWS credentials found. Please specify credentials using one of the methods specified " + - "in http://docs.aws.amazon.com/AWSSdkDocsJava/latest/DeveloperGuide/credentials.html") - val kinesisClient = new AmazonKinesisClient(credentials) - kinesisClient.setEndpoint(endpointUrl) - val numShards = kinesisClient.describeStream(streamName).getStreamDescription().getShards().size + "in https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/credentials.html") + val kinesisClient = KinesisClient.builder() + .credentialsProvider(credentialsProvider) + .endpointOverride(URI.create(endpointUrl)) + .httpClientBuilder(ApacheHttpClient.builder()) + .build() + val describeStreamRequest = DescribeStreamRequest.builder() + .streamName(streamName) + .build() + val numShards = kinesisClient.describeStream(describeStreamRequest) + .streamDescription + .shards + .size // In this example, we're going to create 1 Kinesis Receiver/input DStream for each shard. @@ -221,8 +233,11 @@ object KinesisWordProducerASL { val totals = scala.collection.mutable.Map[String, Int]() // Create the low-level Kinesis Client from the AWS Java SDK. 
- val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()) - kinesisClient.setEndpoint(endpoint) + val kinesisClient = KinesisClient.builder() + .credentialsProvider(DefaultCredentialsProvider.create()) + .endpointOverride(URI.create(endpoint)) + .httpClientBuilder(ApacheHttpClient.builder()) + .build() println(s"Putting records onto stream $stream and endpoint $endpoint at a rate of" + s" $recordsPerSecond records per second and $wordsPerRecord words per record") @@ -247,12 +262,14 @@ object KinesisWordProducerASL { val partitionKey = s"partitionKey-$recordNum" // Create a PutRecordRequest with an Array[Byte] version of the data - val putRecordRequest = new PutRecordRequest().withStreamName(stream) - .withPartitionKey(partitionKey) - .withData(ByteBuffer.wrap(data.getBytes())) + val putRecordRequest = PutRecordRequest.builder() + .streamName(stream) + .partitionKey(partitionKey) + .data(SdkBytes.fromByteBuffer(ByteBuffer.wrap(data.getBytes()))) + .build() // Put the record onto the stream and capture the PutRecordResult - val putRecordResult = kinesisClient.putRecord(putRecordRequest) + kinesisClient.putRecord(putRecordRequest) } // Sleep for a second diff --git a/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index ac3622f93321a..a9a51db7abe19 100644 --- a/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -17,16 +17,20 @@ package org.apache.spark.streaming.kinesis +import java.net.URI import java.util.concurrent.TimeUnit +import java.util.stream.Collectors import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag import scala.util.control.NonFatal -import com.amazonaws.auth.AWSCredentials -import com.amazonaws.services.kinesis.AmazonKinesisClient -import com.amazonaws.services.kinesis.clientlibrary.types.UserRecord -import com.amazonaws.services.kinesis.model._ +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider +import software.amazon.awssdk.http.apache.ApacheHttpClient +import software.amazon.awssdk.regions.Region +import software.amazon.awssdk.services.kinesis.KinesisClient +import software.amazon.awssdk.services.kinesis.model.{GetRecordsRequest, GetRecordsResponse, GetShardIteratorRequest, GetShardIteratorResponse, ProvisionedThroughputExceededException, ShardIteratorType} +import software.amazon.kinesis.retrieval.{AggregatorUtil, KinesisClientRecord} import org.apache.spark._ import org.apache.spark.internal.Logging @@ -84,7 +88,7 @@ class KinesisBackedBlockRDD[T: ClassTag]( @transient private val _blockIds: Array[BlockId], @transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges], @transient private val isBlockIdValid: Array[Boolean] = Array.empty, - val messageHandler: Record => T = KinesisInputDStream.defaultMessageHandler _, + val messageHandler: KinesisClientRecord => T = KinesisInputDStream.defaultMessageHandler _, val kinesisCreds: SparkAWSCredentials = DefaultCredentials, val kinesisReadConfigs: KinesisReadConfigurations = KinesisReadConfigurations() ) extends BlockRDD[T](sc, _blockIds) { @@ -112,9 +116,9 @@ class KinesisBackedBlockRDD[T: ClassTag]( } def getBlockFromKinesis(): Iterator[T] = { - val credentials = kinesisCreds.provider.getCredentials + val credentialsProvider = kinesisCreds.provider 
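// For illustration only: the messageHandler above now takes a
// software.amazon.kinesis.retrieval.KinesisClientRecord instead of the SDK v1 Record, so
// user-supplied handlers must switch from Record.getData() to KinesisClientRecord.data().
// A minimal sketch of a caller building the DStream against the new record type; the app
// name, stream name, region, endpoint and UTF-8 decoding below are illustrative assumptions,
// not values taken from this patch.
import java.nio.charset.StandardCharsets

import software.amazon.kinesis.retrieval.KinesisClientRecord

import org.apache.spark.SparkConf
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.kinesis.{KinesisInitialPositions, KinesisInputDStream}

val conf = new SparkConf().setAppName("KinesisExample")
val ssc = new StreamingContext(conf, Seconds(1))

val lines = KinesisInputDStream.builder
  .streamingContext(ssc)
  .streamName("example-stream")
  .endpointUrl("https://kinesis.us-east-1.amazonaws.com")
  .regionName("us-east-1")
  .initialPosition(new KinesisInitialPositions.Latest())
  .checkpointAppName("example-app")
  .checkpointInterval(Seconds(10))
  .storageLevel(StorageLevel.MEMORY_AND_DISK_2)
  .buildWithMessageHandler { record: KinesisClientRecord =>
    // data() returns a read-only ByteBuffer; copy the bytes out before decoding
    val buf = record.data().duplicate()
    val bytes = new Array[Byte](buf.remaining())
    buf.get(bytes)
    new String(bytes, StandardCharsets.UTF_8)
  }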
partition.seqNumberRanges.ranges.iterator.flatMap { range => - new KinesisSequenceRangeIterator(credentials, endpointUrl, regionName, + new KinesisSequenceRangeIterator(credentialsProvider, endpointUrl, regionName, range, kinesisReadConfigs).map(messageHandler) } } @@ -134,13 +138,19 @@ class KinesisBackedBlockRDD[T: ClassTag]( */ private[kinesis] class KinesisSequenceRangeIterator( - credentials: AWSCredentials, + credentialsProvider: AwsCredentialsProvider, endpointUrl: String, regionId: String, range: SequenceNumberRange, - kinesisReadConfigs: KinesisReadConfigurations) extends NextIterator[Record] with Logging { - - private val client = new AmazonKinesisClient(credentials) + kinesisReadConfigs: KinesisReadConfigurations) + extends NextIterator[KinesisClientRecord] with Logging { + + private val client = KinesisClient.builder() + .credentialsProvider(credentialsProvider) + .region(Region.of(regionId)) + .endpointOverride(URI.create(endpointUrl)) + .httpClientBuilder(ApacheHttpClient.builder()) + .build() private val streamName = range.streamName private val shardId = range.shardId // AWS limits to maximum of 10k records per get call @@ -148,12 +158,11 @@ class KinesisSequenceRangeIterator( private var toSeqNumberReceived = false private var lastSeqNumber: String = null - private var internalIterator: Iterator[Record] = null - - client.setEndpoint(endpointUrl) + private var internalIterator: Iterator[KinesisClientRecord] = null + private val aggregatorUtil = new AggregatorUtil() - override protected def getNext(): Record = { - var nextRecord: Record = null + override protected def getNext(): KinesisClientRecord = { + var nextRecord: KinesisClientRecord = null if (toSeqNumberReceived) { finished = true } else { @@ -183,11 +192,11 @@ class KinesisSequenceRangeIterator( // Get the record, copy the data into a byte array and remember its sequence number nextRecord = internalIterator.next() - lastSeqNumber = nextRecord.getSequenceNumber() + lastSeqNumber = nextRecord.sequenceNumber // If the this record's sequence number matches the stopping sequence number, then make sure // the iterator is marked finished next time getNext() is called - if (nextRecord.getSequenceNumber == range.toSeqNumber) { + if (nextRecord.sequenceNumber == range.toSeqNumber) { toSeqNumberReceived = true } } @@ -196,7 +205,7 @@ class KinesisSequenceRangeIterator( } override protected def close(): Unit = { - client.shutdown() + client.close() } /** @@ -205,7 +214,7 @@ class KinesisSequenceRangeIterator( private def getRecords( iteratorType: ShardIteratorType, seqNum: String, - recordCount: Int): Iterator[Record] = { + recordCount: Int): Iterator[KinesisClientRecord] = { val shardIterator = getKinesisIterator(iteratorType, seqNum) val result = getRecordsAndNextKinesisIterator(shardIterator, recordCount) result._1 @@ -217,19 +226,23 @@ class KinesisSequenceRangeIterator( */ private def getRecordsAndNextKinesisIterator( shardIterator: String, - recordCount: Int): (Iterator[Record], String) = { - val getRecordsRequest = new GetRecordsRequest - getRecordsRequest.setRequestCredentials(credentials) - getRecordsRequest.setShardIterator(shardIterator) - getRecordsRequest.setLimit(Math.min(recordCount, this.maxGetRecordsLimit)) - val getRecordsResult = retryOrTimeout[GetRecordsResult]( + recordCount: Int): (Iterator[KinesisClientRecord], String) = { + val getRecordsRequest = GetRecordsRequest.builder() + .shardIterator(shardIterator) + .limit(Math.min(recordCount, this.maxGetRecordsLimit)) + .build() + val getRecordsResponse = 
retryOrTimeout[GetRecordsResponse]( s"getting records using shard iterator") { client.getRecords(getRecordsRequest) } // De-aggregate records, if KPL was used in producing the records. The KCL automatically // handles de-aggregation during regular operation. This code path is used during recovery - val recordIterator = UserRecord.deaggregate(getRecordsResult.getRecords) - (recordIterator.iterator().asScala, getRecordsResult.getNextShardIterator) + val records = getRecordsResponse.records() + .stream() + .map[KinesisClientRecord](r => KinesisClientRecord.fromRecord(r)) + .collect(Collectors.toList[KinesisClientRecord]()) + val recordIterator = aggregatorUtil.deaggregate(records) + (recordIterator.iterator().asScala, getRecordsResponse.nextShardIterator) } /** @@ -239,17 +252,18 @@ class KinesisSequenceRangeIterator( private def getKinesisIterator( iteratorType: ShardIteratorType, sequenceNumber: String): String = { - val getShardIteratorRequest = new GetShardIteratorRequest - getShardIteratorRequest.setRequestCredentials(credentials) - getShardIteratorRequest.setStreamName(streamName) - getShardIteratorRequest.setShardId(shardId) - getShardIteratorRequest.setShardIteratorType(iteratorType.toString) - getShardIteratorRequest.setStartingSequenceNumber(sequenceNumber) - val getShardIteratorResult = retryOrTimeout[GetShardIteratorResult]( + val getShardIteratorRequest = GetShardIteratorRequest.builder() + .streamName(streamName) + .shardId(shardId) + .shardIteratorType(iteratorType) + .startingSequenceNumber(sequenceNumber) + .build() + + val getShardIteratorResponse = retryOrTimeout[GetShardIteratorResponse]( s"getting shard iterator from sequence number $sequenceNumber") { client.getShardIterator(getShardIteratorRequest) } - getShardIteratorResult.getShardIterator + getShardIteratorResponse.shardIterator } /** Helper method to retry Kinesis API request with exponential backoff and timeouts */ diff --git a/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala b/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala index b259a5337f37e..d6ce9c6c4f4c0 100644 --- a/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala +++ b/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala @@ -20,7 +20,7 @@ import java.util.concurrent._ import scala.util.control.NonFatal -import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import software.amazon.kinesis.processor.RecordProcessorCheckpointer import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.{SHARD_ID, WORKER_URL} @@ -33,35 +33,34 @@ import org.apache.spark.util.{Clock, SystemClock} * * @param receiver The receiver that keeps track of which sequence numbers we can checkpoint * @param checkpointInterval How frequently we will checkpoint to DynamoDB - * @param workerId Worker Id of KCL worker for logging purposes + * @param schedulerId Scheduler Id of KCL scheduler for logging purposes * @param clock In order to use ManualClocks for the purpose of testing */ private[kinesis] class KinesisCheckpointer( receiver: KinesisReceiver[_], checkpointInterval: Duration, - workerId: String, + schedulerId: String, clock: Clock = new SystemClock) extends Logging { // a map from shardId's to checkpointers - private val checkpointers = new ConcurrentHashMap[String, IRecordProcessorCheckpointer]() + private val checkpointers = new 
ConcurrentHashMap[String, RecordProcessorCheckpointer]() private val lastCheckpointedSeqNums = new ConcurrentHashMap[String, String]() private val checkpointerThread: RecurringTimer = startCheckpointerThread() /** Update the checkpointer instance to the most recent one for the given shardId. */ - def setCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + def setCheckpointer(shardId: String, checkpointer: RecordProcessorCheckpointer): Unit = { checkpointers.put(shardId, checkpointer) } /** * Stop tracking the specified shardId. * - * If a checkpointer is provided, e.g. on IRecordProcessor.shutdown [[ShutdownReason.TERMINATE]], - * we will use that to make the final checkpoint. If `null` is provided, we will not make the - * checkpoint, e.g. in case of [[ShutdownReason.ZOMBIE]]. + * If a checkpointer is provided, we will use that to make the final checkpoint. If `null` + * is provided, we will not make the checkpoint, e.g. in case of [[ShutdownReason.ZOMBIE]]. */ - def removeCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + def removeCheckpointer(shardId: String, checkpointer: RecordProcessorCheckpointer): Unit = { synchronized { checkpointers.remove(shardId) } @@ -73,7 +72,7 @@ private[kinesis] class KinesisCheckpointer( KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) } catch { case NonFatal(e) => - logError(log"Exception: WorkerId ${MDC(WORKER_URL, workerId)} encountered an " + + logError(log"Exception: SchedulerId ${MDC(WORKER_URL, schedulerId)} encountered an " + log"exception while checkpointing to finish reading a shard of " + log"${MDC(SHARD_ID, shardId)}.", e) // Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor @@ -83,7 +82,7 @@ private[kinesis] class KinesisCheckpointer( } /** Perform the checkpoint. 
*/ - private def checkpoint(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + private def checkpoint(shardId: String, checkpointer: RecordProcessorCheckpointer): Unit = { try { if (checkpointer != null) { receiver.getLatestSeqNumToCheckpoint(shardId).foreach { latestSeqNum => @@ -93,8 +92,8 @@ private[kinesis] class KinesisCheckpointer( if (lastSeqNum == null || latestSeqNum > lastSeqNum) { /* Perform the checkpoint */ KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(latestSeqNum), 4, 100) - logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint at sequence number" + - s" $latestSeqNum for shardId $shardId") + logDebug(s"Checkpoint: schedulerId $schedulerId completed checkpoint at sequence " + + s" number $latestSeqNum for shardId $shardId") lastCheckpointedSeqNums.put(shardId, latestSeqNum) } } @@ -127,7 +126,7 @@ private[kinesis] class KinesisCheckpointer( */ private def startCheckpointerThread(): RecurringTimer = { val period = checkpointInterval.milliseconds - val threadName = s"Kinesis Checkpointer - Worker $workerId" + val threadName = s"Kinesis Checkpointer - scheduler $schedulerId" val timer = new RecurringTimer(clock, period, _ => checkpointAll(), threadName) timer.start() logDebug(s"Started checkpointer thread: $threadName") diff --git a/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index 9e432eda6251b..11b6d28266a09 100644 --- a/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -17,12 +17,11 @@ package org.apache.spark.streaming.kinesis -import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration} -import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel -import com.amazonaws.services.kinesis.model.Record +import software.amazon.kinesis.common.InitialPositionInStream +import software.amazon.kinesis.metrics.{MetricsLevel, MetricsUtil} +import software.amazon.kinesis.retrieval.KinesisClientRecord import org.apache.spark.rdd.RDD import org.apache.spark.storage.{BlockId, StorageLevel} @@ -42,7 +41,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag]( val checkpointAppName: String, val checkpointInterval: Duration, val _storageLevel: StorageLevel, - val messageHandler: Record => T, + val messageHandler: KinesisClientRecord => T, val kinesisCreds: SparkAWSCredentials, val dynamoDBCreds: Option[SparkAWSCredentials], val cloudWatchCreds: Option[SparkAWSCredentials], @@ -275,7 +274,7 @@ object KinesisInputDStream { /** * Sets the CloudWatch metrics level. Defaults to - * [[KinesisClientLibConfiguration.DEFAULT_METRICS_LEVEL]] if no custom value is specified. + * [[MetricsLevel.DETAILED]] if no custom value is specified. * * @param metricsLevel [[MetricsLevel]] to specify the CloudWatch metrics level * @return Reference to this [[KinesisInputDStream.Builder]] @@ -289,8 +288,8 @@ object KinesisInputDStream { /** * Sets the enabled CloudWatch metrics dimensions. Defaults to - * [[KinesisClientLibConfiguration.DEFAULT_METRICS_ENABLED_DIMENSIONS]] - * if no custom value is specified. 
+ * the set of [[MetricsUtil.OPERATION_DIMENSION_NAME]] and + * [[MetricsUtil.SHARD_ID_DIMENSION_NAME]] if no custom value is specified. * * @param metricsEnabledDimensions Set[String] to specify which CloudWatch metrics dimensions * should be enabled @@ -307,11 +306,12 @@ object KinesisInputDStream { * Create a new instance of [[KinesisInputDStream]] with configured parameters and the provided * message handler. * - * @param handler Function converting [[Record]] instances read by the KCL to DStream type [[T]] + * @param handler Function converting [[KinesisClientRecord]] instances read by the KCL to + * DStream type [[T]] * @return Instance of [[KinesisInputDStream]] constructed with configured parameters */ def buildWithMessageHandler[T: ClassTag]( - handler: Record => T): KinesisInputDStream[T] = { + handler: KinesisClientRecord => T): KinesisInputDStream[T] = { val ssc = getRequiredParam(streamingContext, "streamingContext") new KinesisInputDStream( ssc, @@ -351,9 +351,9 @@ object KinesisInputDStream { */ def builder: Builder = new Builder - private[kinesis] def defaultMessageHandler(record: Record): Array[Byte] = { + private[kinesis] def defaultMessageHandler(record: KinesisClientRecord): Array[Byte] = { if (record == null) return null - val byteBuffer = record.getData() + val byteBuffer = record.data val byteArray = new Array[Byte](byteBuffer.remaining()) byteBuffer.get(byteArray) byteArray @@ -365,7 +365,7 @@ object KinesisInputDStream { private[kinesis] val DEFAULT_INITIAL_POSITION: KinesisInitialPosition = new Latest() private[kinesis] val DEFAULT_STORAGE_LEVEL: StorageLevel = StorageLevel.MEMORY_AND_DISK_2 private[kinesis] val DEFAULT_METRICS_LEVEL: MetricsLevel = - KinesisClientLibConfiguration.DEFAULT_METRICS_LEVEL + MetricsLevel.DETAILED private[kinesis] val DEFAULT_METRICS_ENABLED_DIMENSIONS: Set[String] = - KinesisClientLibConfiguration.DEFAULT_METRICS_ENABLED_DIMENSIONS.asScala.toSet + Set(MetricsUtil.OPERATION_DIMENSION_NAME, MetricsUtil.SHARD_ID_DIMENSION_NAME) } diff --git a/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index ab91431035fef..ed5aee83901f6 100644 --- a/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -16,17 +16,24 @@ */ package org.apache.spark.streaming.kinesis -import java.util.UUID +import java.net.URI +import java.util.{HashSet, List => JList, UUID} import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal -import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer, IRecordProcessorFactory} -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{KinesisClientLibConfiguration, Worker} -import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel -import com.amazonaws.services.kinesis.model.Record +import software.amazon.awssdk.regions.Region +import software.amazon.awssdk.services.cloudwatch.CloudWatchAsyncClient +import software.amazon.awssdk.services.dynamodb.DynamoDbAsyncClient +import software.amazon.awssdk.services.kinesis.KinesisAsyncClient +import software.amazon.kinesis.common.{ConfigsBuilder, InitialPositionInStreamExtended, KinesisClientUtil} +import 
software.amazon.kinesis.coordinator.Scheduler +import software.amazon.kinesis.metrics.{MetricsConfig, MetricsLevel} +import software.amazon.kinesis.processor.{RecordProcessorCheckpointer, ShardRecordProcessor, ShardRecordProcessorFactory} +import software.amazon.kinesis.retrieval.KinesisClientRecord +import software.amazon.kinesis.retrieval.polling.PollingConfig import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.WORKER_URL @@ -39,12 +46,12 @@ import org.apache.spark.util.Utils /** * Custom AWS Kinesis-specific implementation of Spark Streaming's Receiver. - * This implementation relies on the Kinesis Client Library (KCL) Worker as described here: + * This implementation relies on the Kinesis Client Library (KCL) Scheduler as described here: * https://github.com/awslabs/amazon-kinesis-client * * The way this Receiver works is as follows: * - * - The receiver starts a KCL Worker, which is essentially runs a threadpool of multiple + * - The receiver starts a KCL Scheduler, which is essentially runs a threadpool of multiple * KinesisRecordProcessor * - Each KinesisRecordProcessor receives data from a Kinesis shard in batches. Each batch is * inserted into a Block Generator, and the corresponding range of sequence numbers is recorded. @@ -62,7 +69,7 @@ import org.apache.spark.util.Utils * DynamoDB (lease coordination and checkpointing) and CloudWatch (metrics) * @param initialPosition Instance of [[KinesisInitialPosition]] * In the absence of Kinesis checkpoint info, this is the - * worker's initial starting position in the stream. + * scheduler's initial starting position in the stream. * The values are either the beginning of the stream * per Kinesis' limit of 24 hours * ([[KinesisInitialPositions.TrimHorizon]]) or @@ -92,7 +99,7 @@ private[kinesis] class KinesisReceiver[T]( checkpointAppName: String, checkpointInterval: Duration, storageLevel: StorageLevel, - messageHandler: Record => T, + messageHandler: KinesisClientRecord => T, kinesisCreds: SparkAWSCredentials, dynamoDBCreds: Option[SparkAWSCredentials], cloudWatchCreds: Option[SparkAWSCredentials], @@ -108,19 +115,19 @@ private[kinesis] class KinesisReceiver[T]( */ /** - * workerId is used by the KCL should be based on the ip address of the actual Spark Worker + * schedulerId is used by the KCL should be based on the ip address of the actual Spark Worker * where this code runs (not the driver's IP address.) */ - @volatile private var workerId: String = null + @volatile private var schedulerId: String = null /** - * Worker is the core client abstraction from the Kinesis Client Library (KCL). - * A worker can process more than one shards from the given stream. - * Each shard is assigned its own IRecordProcessor and the worker run multiple such + * Scheduler is the core client abstraction from the Kinesis Client Library (KCL). + * A Scheduler can process more than one shards from the given stream. + * Each shard is assigned its own ShardRecordProcessor and the scheduler run multiple such * processors. */ - @volatile private var worker: Worker = null - @volatile private var workerThread: Thread = null + @volatile private var scheduler: Scheduler = null + @volatile private var schedulerThread: Thread = null /** BlockGenerator used to generates blocks out of Kinesis data */ @volatile private var blockGenerator: BlockGenerator = null @@ -146,59 +153,71 @@ private[kinesis] class KinesisReceiver[T]( /** * This is called when the KinesisReceiver starts and must be non-blocking. 
- * The KCL creates and manages the receiving/processing thread pool through Worker.run(). + * The KCL creates and manages the receiving/processing thread pool through Scheduler.run(). */ override def onStart(): Unit = { blockGenerator = supervisor.createBlockGenerator(new GeneratedBlockHandler) - workerId = Utils.localHostName() + ":" + UUID.randomUUID() + schedulerId = Utils.localHostName() + ":" + UUID.randomUUID() - kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, workerId) + kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, schedulerId) val kinesisProvider = kinesisCreds.provider - val kinesisClientLibConfiguration = { - val baseClientLibConfiguration = new KinesisClientLibConfiguration( - checkpointAppName, - streamName, - kinesisProvider, - dynamoDBCreds.map(_.provider).getOrElse(kinesisProvider), - cloudWatchCreds.map(_.provider).getOrElse(kinesisProvider), - workerId) - .withKinesisEndpoint(endpointUrl) - .withTaskBackoffTimeMillis(500) - .withRegionName(regionName) - .withMetricsLevel(metricsLevel) - .withMetricsEnabledDimensions(metricsEnabledDimensions.asJava) - - // Update the Kinesis client lib config with timestamp - // if InitialPositionInStream.AT_TIMESTAMP is passed - initialPosition match { - case ts: AtTimestamp => - baseClientLibConfiguration.withTimestampAtInitialPositionInStream(ts.getTimestamp) - case _ => - baseClientLibConfiguration.withInitialPositionInStream(initialPosition.getPosition) + val kinesisClient = KinesisClientUtil.createKinesisAsyncClient( + KinesisAsyncClient.builder + .region(Region.of(regionName)) + .credentialsProvider(kinesisProvider) + .endpointOverride(URI.create(endpointUrl))) + val dynamoClient = DynamoDbAsyncClient.builder + .region(Region.of(regionName)) + .credentialsProvider(dynamoDBCreds.map(_.provider).getOrElse(kinesisProvider)) + .build + val cloudWatchClient = CloudWatchAsyncClient.builder + .region(Region.of(regionName)) + .credentialsProvider(cloudWatchCreds.map(_.provider).getOrElse(kinesisProvider)) + .build + val recordProcessorFactory = new ShardRecordProcessorFactory { + override def shardRecordProcessor(): ShardRecordProcessor = { + new KinesisRecordProcessor(receiver, schedulerId) } } - /* - * RecordProcessorFactory creates impls of IRecordProcessor. - * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the - * IRecordProcessor.processRecords() method. - * We're using our custom KinesisRecordProcessor in this case. - */ - val recordProcessorFactory = new IRecordProcessorFactory { - override def createProcessor: IRecordProcessor = - new KinesisRecordProcessor(receiver, workerId) + val configsBuilder = new ConfigsBuilder(streamName, checkpointAppName, kinesisClient, + dynamoClient, cloudWatchClient, schedulerId, recordProcessorFactory) + val metricsConfig = new MetricsConfig(cloudWatchClient, checkpointAppName) + .metricsLevel(metricsLevel) + .metricsEnabledDimensions(new HashSet(metricsEnabledDimensions.asJava)) + + val initialPositionInStreamExtended = initialPosition match { + case ts: AtTimestamp => + InitialPositionInStreamExtended.newInitialPositionAtTimestamp(ts.getTimestamp) + case _ => + InitialPositionInStreamExtended.newInitialPosition(initialPosition.getPosition) } - worker = new Worker(recordProcessorFactory, kinesisClientLibConfiguration) - workerThread = new Thread() { + val pollingConfig = new PollingConfig(streamName, kinesisClient) + // To maintain the same behavior as SDK v1, set the interval to 1000. 
+ pollingConfig.idleTimeBetweenReadsInMillis(1000) + + scheduler = new Scheduler( + configsBuilder.checkpointConfig(), + configsBuilder.coordinatorConfig(), + configsBuilder.leaseManagementConfig(), + configsBuilder.lifecycleConfig(), + metricsConfig, + configsBuilder.processorConfig(), + configsBuilder.retrievalConfig() + .retrievalSpecificConfig(pollingConfig) + .initialPositionInStreamExtended(initialPositionInStreamExtended) + ) + + schedulerThread = new Thread() { override def run(): Unit = { try { - worker.run() + scheduler.run() } catch { case NonFatal(e) => - restart("Error running the KCL worker in Receiver", e) + restart("Error running the KCL scheduler in Receiver", e) } } } @@ -206,29 +225,29 @@ private[kinesis] class KinesisReceiver[T]( blockIdToSeqNumRanges.clear() blockGenerator.start() - workerThread.setName(s"Kinesis Receiver ${streamId}") - workerThread.setDaemon(true) - workerThread.start() + schedulerThread.setName(s"Kinesis Receiver ${streamId}") + schedulerThread.setDaemon(true) + schedulerThread.start() - logInfo(log"Started receiver with workerId ${MDC(WORKER_URL, workerId)}") + logInfo(log"Started receiver with schedulerId ${MDC(WORKER_URL, schedulerId)}") } /** * This is called when the KinesisReceiver stops. - * The KCL worker.shutdown() method stops the receiving/processing threads. + * The KCL scheduler.shutdown() method stops the receiving/processing threads. * The KCL will do its best to drain and checkpoint any in-flight records upon shutdown. */ override def onStop(): Unit = { - if (workerThread != null) { - if (worker != null) { - worker.shutdown() - worker = null + if (schedulerThread != null) { + if (scheduler != null) { + scheduler.shutdown() + scheduler = null } - workerThread.join() - workerThread = null - logInfo(log"Stopped receiver for workerId ${MDC(WORKER_URL, workerId)}") + schedulerThread.join() + schedulerThread = null + logInfo(log"Stopped receiver for schedulerId ${MDC(WORKER_URL, schedulerId)}") } - workerId = null + schedulerId = null if (kinesisCheckpointer != null) { kinesisCheckpointer.shutdown() kinesisCheckpointer = null @@ -236,11 +255,11 @@ private[kinesis] class KinesisReceiver[T]( } /** Add records of the given shard to the current block being generated */ - private[kinesis] def addRecords(shardId: String, records: java.util.List[Record]): Unit = { + private[kinesis] def addRecords(shardId: String, records: JList[KinesisClientRecord]): Unit = { if (records.size > 0) { val dataIterator = records.iterator().asScala.map(messageHandler) val metadata = SequenceNumberRange(streamName, shardId, - records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber(), + records.get(0).sequenceNumber, records.get(records.size - 1).sequenceNumber, records.size()) blockGenerator.addMultipleDataWithCallback(dataIterator, metadata) } @@ -261,7 +280,7 @@ private[kinesis] class KinesisReceiver[T]( * Set the checkpointer that will be used to checkpoint sequence numbers to DynamoDB for the * given shardId. */ - def setCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + def setCheckpointer(shardId: String, checkpointer: RecordProcessorCheckpointer): Unit = { assert(kinesisCheckpointer != null, "Kinesis Checkpointer not initialized!") kinesisCheckpointer.setCheckpointer(shardId, checkpointer) } @@ -271,7 +290,7 @@ private[kinesis] class KinesisReceiver[T]( * checkpoint one last time for the given shard. If `checkpointer` is `null`, then we will not * checkpoint. 
*/ - def removeCheckpointer(shardId: String, checkpointer: IRecordProcessorCheckpointer): Unit = { + def removeCheckpointer(shardId: String, checkpointer: RecordProcessorCheckpointer): Unit = { assert(kinesisCheckpointer != null, "Kinesis Checkpointer not initialized!") kinesisCheckpointer.removeCheckpointer(shardId, checkpointer) } diff --git a/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index 8304ddda96dfa..15964fcc75d9f 100644 --- a/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -16,71 +16,68 @@ */ package org.apache.spark.streaming.kinesis -import java.util.List - import scala.util.Random import scala.util.control.NonFatal -import com.amazonaws.services.kinesis.clientlibrary.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException} -import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer} -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason -import com.amazonaws.services.kinesis.model.Record +import software.amazon.kinesis.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException} +import software.amazon.kinesis.lifecycle.events.{InitializationInput, LeaseLostInput, ProcessRecordsInput, ShardEndedInput, ShutdownRequestedInput} +import software.amazon.kinesis.processor.ShardRecordProcessor import org.apache.spark.internal.Logging -import org.apache.spark.internal.LogKeys.{KINESIS_REASON, RETRY_INTERVAL, SHARD_ID, WORKER_URL} +import org.apache.spark.internal.LogKeys.{RETRY_INTERVAL, SHARD_ID, WORKER_URL} /** * Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor. * This implementation operates on the Array[Byte] from the KinesisReceiver. - * The Kinesis Worker creates an instance of this KinesisRecordProcessor for each + * The Kinesis scheduler creates an instance of this KinesisRecordProcessor for each * shard in the Kinesis stream upon startup. This is normally done in separate threads, * but the KCLs within the KinesisReceivers will balance themselves out if you create * multiple Receivers. * * @param receiver Kinesis receiver - * @param workerId for logging purposes + * @param schedulerId for logging purposes */ -private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], workerId: String) - extends IRecordProcessor with Logging { +private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], schedulerId: String) + extends ShardRecordProcessor with Logging { // shardId populated during initialize() @volatile private var shardId: String = _ /** - * The Kinesis Client Library calls this method during IRecordProcessor initialization. + * The Kinesis Client Library calls this method during ShardRecordProcessor initialization. * - * @param shardId assigned by the KCL to this particular RecordProcessor. 
+ * @param initializationInput contains parameters to the ShardRecordProcessor initialize method */ - override def initialize(shardId: String): Unit = { - this.shardId = shardId - logInfo(log"Initialized workerId ${MDC(WORKER_URL, workerId)} " + + override def initialize(initializationInput: InitializationInput): Unit = { + this.shardId = initializationInput.shardId + logInfo(log"Initialized schedulerId ${MDC(WORKER_URL, schedulerId)} " + log"with shardId ${MDC(SHARD_ID, shardId)}") } /** * This method is called by the KCL when a batch of records is pulled from the Kinesis stream. - * This is the record-processing bridge between the KCL's IRecordProcessor.processRecords() - * and Spark Streaming's Receiver.store(). + * This is the record-processing bridge between the KCL's ShardRecordProcessor.processRecords() + * and Spark Streaming's Receiver * - * @param batch list of records from the Kinesis stream shard - * @param checkpointer used to update Kinesis when this batch has been processed/stored - * in the DStream + * @param processRecordsInput Provides the records to be processed as well as information and + * capabilities related to them (eg checkpointing). */ - override def processRecords(batch: List[Record], - checkpointer: IRecordProcessorCheckpointer): Unit = { + override def processRecords(processRecordsInput: ProcessRecordsInput): Unit = { + val batch = processRecordsInput.records + val checkpointer = processRecordsInput.checkpointer if (!receiver.isStopped()) { try { // Limit the number of processed records from Kinesis stream. This is because the KCL cannot // control the number of aggregated records to be fetched even if we set `MaxRecords` - // in `KinesisClientLibConfiguration`. For example, if we set 10 to the number of max - // records in a worker and a producer aggregates two records into one message, the worker + // in `PollingConfig`. For example, if we set 10 to the number of max records + // in a scheduler and a producer aggregates two records into one message, the scheduler // possibly 20 records every callback function called. val maxRecords = receiver.getCurrentLimit for (start <- 0 until batch.size by maxRecords) { val miniBatch = batch.subList(start, math.min(start + maxRecords, batch.size)) receiver.addRecords(shardId, miniBatch) - logDebug(s"Stored: Worker $workerId stored ${miniBatch.size} records " + + logDebug(s"Stored: Scheduler $schedulerId stored ${miniBatch.size} records " + s"for shardId $shardId") } receiver.setCheckpointer(shardId, checkpointer) @@ -91,56 +88,68 @@ private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], w * This will potentially cause records since the last checkpoint to be processed * more than once. */ - logError(log"Exception: WorkerId ${MDC(WORKER_URL, workerId)} encountered and " + - log"exception while storing or checkpointing a batch for workerId " + - log"${MDC(WORKER_URL, workerId)} and shardId ${MDC(SHARD_ID, shardId)}.", e) + logError(log"Exception: SchedulerId ${MDC(WORKER_URL, schedulerId)} encountered and " + + log"exception while storing or checkpointing a batch for schedulerId " + + log"${MDC(WORKER_URL, schedulerId)} and shardId ${MDC(SHARD_ID, shardId)}.", e) - /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor. */ + /* Rethrow the exception to the Kinesis scheduler that is managing + this RecordProcessor. */ throw e } } else { /* RecordProcessor has been stopped. 
*/ - logInfo(log"Stopped: KinesisReceiver has stopped for workerId ${MDC(WORKER_URL, workerId)}" + - log" and shardId ${MDC(SHARD_ID, shardId)}. No more records will be processed.") + logInfo(log"Stopped: KinesisReceiver has stopped for schedulerId " + + log"${MDC(WORKER_URL, schedulerId)} and shardId ${MDC(SHARD_ID, shardId)}. " + + log"No more records will be processed.") } } /** - * Kinesis Client Library is shutting down this Worker for 1 of 2 reasons: - * 1) the stream is resharding by splitting or merging adjacent shards - * (ShutdownReason.TERMINATE) - * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason - * (ShutdownReason.ZOMBIE) + * Called when the lease that tied to this Kinesis record processor has been lost. + * Once the lease has been lost the record processor can no longer checkpoint. + * + * @param leaseLostInput gives access to information related to the loss of the lease. + * Currently this has no functionality. + */ + override def leaseLost(leaseLostInput: LeaseLostInput): Unit = { + logInfo(log"The lease for shardId: ${MDC(SHARD_ID, shardId)} is lost.") + receiver.removeCheckpointer(shardId, null) + } + + /** + * Called when the shard that this Kinesis record processor is handling has been completed. + * Once a shard has been completed no further records will ever arrive on that shard. * - * @param checkpointer used to perform a Kinesis checkpoint for ShutdownReason.TERMINATE - * @param reason for shutdown (ShutdownReason.TERMINATE or ShutdownReason.ZOMBIE) + * When this is called the record processor must checkpoint. Otherwise an exception + * will be thrown and the all child shards of this shard will not make progress. + * + * @param shardEndedInput provides access to a checkpointer method for completing processing of + * the shard. */ - override def shutdown( - checkpointer: IRecordProcessorCheckpointer, - reason: ShutdownReason): Unit = { - logInfo(log"Shutdown: Shutting down workerId ${MDC(WORKER_URL, workerId)} " + - log"with reason ${MDC(KINESIS_REASON, reason)}") - // null if not initialized before shutdown: + override def shardEnded(shardEndedInput: ShardEndedInput): Unit = { + logInfo(log"Reached shard end. Checkpointing for shardId: ${MDC(SHARD_ID, shardId)}") if (shardId == null) { - logWarning(log"No shardId for workerId ${MDC(WORKER_URL, workerId)}?") + logWarning(log"No shardId for schedulerId ${MDC(WORKER_URL, schedulerId)}?") } else { - reason match { - /* - * TERMINATE Use Case. Checkpoint. - * Checkpoint to indicate that all records from the shard have been drained and processed. - * It's now OK to read from the new shards that resulted from a resharding event. - */ - case ShutdownReason.TERMINATE => receiver.removeCheckpointer(shardId, checkpointer) + receiver.removeCheckpointer(shardId, shardEndedInput.checkpointer) + } + } - /* - * ZOMBIE Use Case or Unknown reason. NoOp. - * No checkpoint because other workers may have taken over and already started processing - * the same records. - * This may lead to records being processed more than once. - * Return null so that we don't checkpoint - */ - case _ => receiver.removeCheckpointer(shardId, null) - } + /** + * Called when the Scheduler has been requested to shutdown. This is called while the + * Kinesis record processor still holds the lease so checkpointing is possible. Once this method + * has completed the lease for the record processor is released, and + * {@link # leaseLost ( LeaseLostInput )} will be called at a later time. 
+ * + * @param shutdownRequestedInput provides access to a checkpointer allowing a record processor to + * checkpoint before the shutdown is completed. + */ + override def shutdownRequested(shutdownRequestedInput: ShutdownRequestedInput): Unit = { + logInfo(log"Shutdown: Shutting down schedulerId: ${MDC(WORKER_URL, schedulerId)} ") + if (shardId == null) { + logWarning(log"No shardId for schedulerId ${MDC(WORKER_URL, schedulerId)}?") + } else { + receiver.removeCheckpointer(shardId, shutdownRequestedInput.checkpointer) } } } diff --git a/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtilsPythonHelper.scala b/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtilsPythonHelper.scala index 8abaef6b834eb..dc1098a336335 100644 --- a/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtilsPythonHelper.scala +++ b/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtilsPythonHelper.scala @@ -16,8 +16,8 @@ */ package org.apache.spark.streaming.kinesis -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream -import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel +import software.amazon.kinesis.common.InitialPositionInStream +import software.amazon.kinesis.metrics.MetricsLevel import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.Duration diff --git a/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala b/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala index e821adca20d27..e8ccdcd6a99b7 100644 --- a/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala +++ b/connector/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/SparkAWSCredentials.scala @@ -17,49 +17,52 @@ package org.apache.spark.streaming.kinesis -import com.amazonaws.auth._ +import software.amazon.awssdk.auth.credentials.{AwsBasicCredentials, AwsCredentialsProvider, DefaultCredentialsProvider, StaticCredentialsProvider} +import software.amazon.awssdk.services.sts.StsClient +import software.amazon.awssdk.services.sts.auth.StsAssumeRoleCredentialsProvider +import software.amazon.awssdk.services.sts.model.AssumeRoleRequest import org.apache.spark.internal.Logging /** * Serializable interface providing a method executors can call to obtain an - * AWSCredentialsProvider instance for authenticating to AWS services. + * AwsCredentialsProvider instance for authenticating to AWS services. */ private[kinesis] sealed trait SparkAWSCredentials extends Serializable { /** - * Return an AWSCredentialProvider instance that can be used by the Kinesis Client + * Return an AwsCredentialProvider instance that can be used by the Kinesis Client * Library to authenticate to AWS services (Kinesis, CloudWatch and DynamoDB). */ - def provider: AWSCredentialsProvider + def provider: AwsCredentialsProvider } -/** Returns DefaultAWSCredentialsProviderChain for authentication. */ +/** Returns DefaultCredentialsProvider for authentication. */ private[kinesis] final case object DefaultCredentials extends SparkAWSCredentials { - def provider: AWSCredentialsProvider = new DefaultAWSCredentialsProviderChain + def provider: AwsCredentialsProvider = DefaultCredentialsProvider.create() } /** - * Returns AWSStaticCredentialsProvider constructed using basic AWS keypair. 
Falls back to using - * DefaultCredentialsProviderChain if unable to construct a AWSCredentialsProviderChain + * Returns StaticCredentialsProvider constructed using basic AWS keypair. Falls back to using + * DefaultCredentialsProvider if unable to construct a StaticCredentialsProvider * instance with the provided arguments (e.g. if they are null). */ private[kinesis] final case class BasicCredentials( awsAccessKeyId: String, awsSecretKey: String) extends SparkAWSCredentials with Logging { - def provider: AWSCredentialsProvider = try { - new AWSStaticCredentialsProvider(new BasicAWSCredentials(awsAccessKeyId, awsSecretKey)) + def provider: AwsCredentialsProvider = try { + StaticCredentialsProvider.create(AwsBasicCredentials.create(awsAccessKeyId, awsSecretKey)) } catch { case e: IllegalArgumentException => - logWarning("Unable to construct AWSStaticCredentialsProvider with provided keypair; " + - "falling back to DefaultCredentialsProviderChain.", e) - new DefaultAWSCredentialsProviderChain + logWarning("Unable to construct StaticCredentialsProvider with provided keypair; " + + "falling back to DefaultCredentialsProvider.", e) + DefaultCredentialsProvider.create() } } /** - * Returns an STSAssumeRoleSessionCredentialsProvider instance which assumes an IAM + * Returns an StsAssumeRoleCredentialsProvider instance which assumes an IAM * role in order to authenticate against resources in an external account. */ private[kinesis] final case class STSCredentials( @@ -69,16 +72,24 @@ private[kinesis] final case class STSCredentials( longLivedCreds: SparkAWSCredentials = DefaultCredentials) extends SparkAWSCredentials { - def provider: AWSCredentialsProvider = { - val builder = new STSAssumeRoleSessionCredentialsProvider.Builder(stsRoleArn, stsSessionName) - .withLongLivedCredentialsProvider(longLivedCreds.provider) + def provider: AwsCredentialsProvider = { + val stsClient = StsClient.builder() + .credentialsProvider(longLivedCreds.provider) + .build() + + val assumeRoleRequestBuilder = AssumeRoleRequest.builder() + .roleArn(stsRoleArn) + .roleSessionName(stsSessionName) stsExternalId match { case Some(stsExternalId) => - builder.withExternalId(stsExternalId) - .build() + assumeRoleRequestBuilder.externalId(stsExternalId) case None => - builder.build() } + + StsAssumeRoleCredentialsProvider.builder() + .stsClient(stsClient) + .refreshRequest(assumeRoleRequestBuilder.build()) + .build() } } @@ -98,8 +109,8 @@ object SparkAWSCredentials { * * @note The given AWS keypair will be saved in DStream checkpoints if checkpointing is * enabled. Make sure that your checkpoint directory is secure. Prefer using the - * default provider chain instead if possible - * (http://docs.aws.amazon.com/sdk-for-java/v1/developer-guide/credentials.html#credentials-default). + * default credentials provider instead if possible + * (https://docs.aws.amazon.com/sdk-for-java/latest/developer-guide/credentials-chain.html). 
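// For illustration only: with the SDK v2 credential providers above, an application might
// assume an IAM role through the existing SparkAWSCredentials builder roughly as follows.
// The key pair, role ARN, session name and the kinesisCredentials(...) hookup are
// illustrative assumptions, not values introduced by this patch.
import org.apache.spark.streaming.kinesis.SparkAWSCredentials

val roleCredentials = SparkAWSCredentials.builder
  .basicCredentials("EXAMPLE_ACCESS_KEY_ID", "EXAMPLE_SECRET_KEY")   // long-lived credentials
  .stsCredentials("arn:aws:iam::123456789012:role/example-role", "example-session")
  .build()

// The result would then be handed to the stream builder, e.g.
// KinesisInputDStream.builder ... .kinesisCredentials(roleCredentials) ...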
* * @param accessKeyId AWS access key ID * @param secretKey AWS secret key diff --git a/connector/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java b/connector/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java index e64e08a38a4ae..b10b7e04d2b74 100644 --- a/connector/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java +++ b/connector/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisInputDStreamBuilderSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.streaming.kinesis; -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -26,6 +25,7 @@ import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.LocalJavaStreamingContext; import org.apache.spark.streaming.Seconds; +import software.amazon.kinesis.common.InitialPositionInStream; public class JavaKinesisInputDStreamBuilderSuite extends LocalJavaStreamingContext { /** diff --git a/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala b/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala index 87592b6877b33..6dd589fe4d210 100644 --- a/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala +++ b/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala @@ -22,13 +22,13 @@ import java.util.concurrent.TimeoutException import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration._ -import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer import org.mockito.ArgumentMatchers._ import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} import org.scalatest.concurrent.Eventually import org.scalatestplus.mockito.MockitoSugar +import software.amazon.kinesis.processor.RecordProcessorCheckpointer import org.apache.spark.streaming.{Duration, TestSuiteBase} import org.apache.spark.util.ManualClock @@ -39,7 +39,7 @@ class KinesisCheckpointerSuite extends TestSuiteBase with PrivateMethodTester with Eventually { - private val workerId = "dummyWorkerId" + private val schedulerId = "dummySchedulerId" private val shardId = "dummyShardId" private val seqNum = "123" private val otherSeqNum = "245" @@ -48,7 +48,7 @@ class KinesisCheckpointerSuite extends TestSuiteBase private val someOtherSeqNum = Some(otherSeqNum) private var receiverMock: KinesisReceiver[Array[Byte]] = _ - private var checkpointerMock: IRecordProcessorCheckpointer = _ + private var checkpointerMock: RecordProcessorCheckpointer = _ private var kinesisCheckpointer: KinesisCheckpointer = _ private var clock: ManualClock = _ @@ -56,9 +56,13 @@ class KinesisCheckpointerSuite extends TestSuiteBase override def beforeEach(): Unit = { receiverMock = mock[KinesisReceiver[Array[Byte]]] - checkpointerMock = mock[IRecordProcessorCheckpointer] + checkpointerMock = mock[RecordProcessorCheckpointer] clock = new ManualClock() - kinesisCheckpointer = new KinesisCheckpointer(receiverMock, checkpointInterval, workerId, clock) + kinesisCheckpointer = new KinesisCheckpointer( + receiverMock, + checkpointInterval, + schedulerId, + clock) } test("checkpoint is not called twice for 
the same sequence number") { diff --git a/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala b/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala index 9f2e34e2e2f99..2d82282ecff02 100644 --- a/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala +++ b/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisInputDStreamBuilderSuite.scala @@ -21,10 +21,10 @@ import java.util.Calendar import scala.jdk.CollectionConverters._ -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration} -import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel import org.scalatest.BeforeAndAfterEach import org.scalatestplus.mockito.MockitoSugar +import software.amazon.kinesis.common.InitialPositionInStream +import software.amazon.kinesis.metrics.{MetricsConfig, MetricsLevel} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Duration, Seconds, StreamingContext, TestSuiteBase} @@ -101,7 +101,7 @@ class KinesisInputDStreamBuilderSuite extends TestSuiteBase with BeforeAndAfterE val customCloudWatchCreds = mock[SparkAWSCredentials] val customMetricsLevel = MetricsLevel.NONE val customMetricsEnabledDimensions = - KinesisClientLibConfiguration.METRICS_ALWAYS_ENABLED_DIMENSIONS.asScala.toSet + MetricsConfig.METRICS_ALWAYS_ENABLED_DIMENSIONS.asScala.toSet val dstream = builder .endpointUrl(customEndpointUrl) diff --git a/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala index d008de3b3f1c4..c83745f3f7853 100644 --- a/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala +++ b/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -20,16 +20,16 @@ import java.nio.ByteBuffer import java.nio.charset.StandardCharsets import java.util.Arrays -import com.amazonaws.services.kinesis.clientlibrary.exceptions._ -import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer -import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason -import com.amazonaws.services.kinesis.model.Record import org.mockito.ArgumentMatchers.{anyList, anyString, eq => meq} import org.mockito.Mockito.{never, times, verify, when} import org.scalatest.BeforeAndAfter import org.scalatest.matchers.must.Matchers import org.scalatest.matchers.should.Matchers._ import org.scalatestplus.mockito.MockitoSugar +import software.amazon.kinesis.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException} +import software.amazon.kinesis.lifecycle.events.{InitializationInput, LeaseLostInput, ProcessRecordsInput, ShardEndedInput, ShutdownRequestedInput} +import software.amazon.kinesis.processor.RecordProcessorCheckpointer +import software.amazon.kinesis.retrieval.KinesisClientRecord import org.apache.spark.streaming.{Duration, TestSuiteBase} @@ -42,33 +42,43 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft val app = "TestKinesisReceiver" val stream = "mySparkStream" val endpoint = "endpoint-url" - val workerId = "dummyWorkerId" + val schedulerId = "dummySchedulerId" val shardId = "dummyShardId" val seqNum 
= "dummySeqNum" val checkpointInterval = Duration(10) val someSeqNum = Some(seqNum) - val record1 = new Record() - record1.setData(ByteBuffer.wrap("Spark In Action".getBytes(StandardCharsets.UTF_8))) - val record2 = new Record() - record2.setData(ByteBuffer.wrap("Learning Spark".getBytes(StandardCharsets.UTF_8))) + val dummyInitializationInput = InitializationInput.builder() + .shardId(shardId) + .build() + + val record1 = KinesisClientRecord.builder() + .data(ByteBuffer.wrap("Spark In Action".getBytes(StandardCharsets.UTF_8))) + .build() + val record2 = KinesisClientRecord.builder() + .data(ByteBuffer.wrap("Learning Spark".getBytes(StandardCharsets.UTF_8))) + .build() val batch = Arrays.asList(record1, record2) var receiverMock: KinesisReceiver[Array[Byte]] = _ - var checkpointerMock: IRecordProcessorCheckpointer = _ + var checkpointerMock: RecordProcessorCheckpointer = _ override def beforeFunction(): Unit = { receiverMock = mock[KinesisReceiver[Array[Byte]]] - checkpointerMock = mock[IRecordProcessorCheckpointer] + checkpointerMock = mock[RecordProcessorCheckpointer] } test("process records including store and set checkpointer") { when(receiverMock.isStopped()).thenReturn(false) when(receiverMock.getCurrentLimit).thenReturn(Int.MaxValue) - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) - recordProcessor.initialize(shardId) - recordProcessor.processRecords(batch, checkpointerMock) + val recordProcessor = new KinesisRecordProcessor(receiverMock, schedulerId) + recordProcessor.initialize(dummyInitializationInput) + val processRecordsInput = ProcessRecordsInput.builder() + .records(batch) + .checkpointer(checkpointerMock) + .build() + recordProcessor.processRecords(processRecordsInput) verify(receiverMock, times(1)).isStopped() verify(receiverMock, times(1)).addRecords(shardId, batch) @@ -79,9 +89,13 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft when(receiverMock.isStopped()).thenReturn(false) when(receiverMock.getCurrentLimit).thenReturn(1) - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) - recordProcessor.initialize(shardId) - recordProcessor.processRecords(batch, checkpointerMock) + val recordProcessor = new KinesisRecordProcessor(receiverMock, schedulerId) + recordProcessor.initialize(dummyInitializationInput) + val processRecordsInput = ProcessRecordsInput.builder() + .records(batch) + .checkpointer(checkpointerMock) + .build() + recordProcessor.processRecords(processRecordsInput) verify(receiverMock, times(1)).isStopped() verify(receiverMock, times(1)).addRecords(shardId, batch.subList(0, 1)) @@ -93,8 +107,12 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft when(receiverMock.isStopped()).thenReturn(true) when(receiverMock.getCurrentLimit).thenReturn(Int.MaxValue) - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) - recordProcessor.processRecords(batch, checkpointerMock) + val recordProcessor = new KinesisRecordProcessor(receiverMock, schedulerId) + val processRecordsInput = ProcessRecordsInput.builder() + .records(batch) + .checkpointer(checkpointerMock) + .build() + recordProcessor.processRecords(processRecordsInput) verify(receiverMock, times(1)).isStopped() verify(receiverMock, never).addRecords(anyString, anyList()) @@ -109,9 +127,13 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft ).thenThrow(new RuntimeException()) intercept[RuntimeException] { - val recordProcessor = new 
KinesisRecordProcessor(receiverMock, workerId) - recordProcessor.initialize(shardId) - recordProcessor.processRecords(batch, checkpointerMock) + val recordProcessor = new KinesisRecordProcessor(receiverMock, schedulerId) + recordProcessor.initialize(dummyInitializationInput) + val processRecordsInput = ProcessRecordsInput.builder() + .records(batch) + .checkpointer(checkpointerMock) + .build() + recordProcessor.processRecords(processRecordsInput) } verify(receiverMock, times(1)).isStopped() @@ -119,27 +141,42 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft verify(receiverMock, never).setCheckpointer(anyString, meq(checkpointerMock)) } - test("shutdown should checkpoint if the reason is TERMINATE") { + test("SPARK-45720: shutdownRequest should checkpoint") { when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) - recordProcessor.initialize(shardId) - recordProcessor.shutdown(checkpointerMock, ShutdownReason.TERMINATE) + val recordProcessor = new KinesisRecordProcessor(receiverMock, schedulerId) + val shutdownRequestedInput = ShutdownRequestedInput.builder() + .checkpointer(checkpointerMock) + .build() + recordProcessor.initialize(dummyInitializationInput) + recordProcessor.shutdownRequested(shutdownRequestedInput) verify(receiverMock, times(1)).removeCheckpointer(meq(shardId), meq(checkpointerMock)) } + test("SPARK-45720: shardEnded should checkpoint") { + when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) + + val recordProcessor = new KinesisRecordProcessor(receiverMock, schedulerId) + val shardEndedInput = ShardEndedInput.builder() + .checkpointer(checkpointerMock) + .build() + recordProcessor.initialize(dummyInitializationInput) + recordProcessor.shardEnded(shardEndedInput) + + verify(receiverMock, times(1)).removeCheckpointer(meq(shardId), meq(checkpointerMock)) + } - test("shutdown should not checkpoint if the reason is something other than TERMINATE") { + test("SPARK-45720: leaseLost should not checkpoint") { when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum) - val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId) - recordProcessor.initialize(shardId) - recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE) - recordProcessor.shutdown(checkpointerMock, null) + val recordProcessor = new KinesisRecordProcessor(receiverMock, schedulerId) + val leaseLostInput = LeaseLostInput.builder().build() + recordProcessor.initialize(dummyInitializationInput) + recordProcessor.leaseLost(leaseLostInput) - verify(receiverMock, times(2)).removeCheckpointer(meq(shardId), - meq[IRecordProcessorCheckpointer](null)) + verify(receiverMock, times(1)).removeCheckpointer(meq(shardId), + meq[RecordProcessorCheckpointer](null)) } test("retry success on first attempt") { diff --git a/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 43c4118d8f59f..f3b1015df32c2 100644 --- a/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -21,10 +21,10 @@ import scala.collection.mutable import scala.concurrent.duration._ import scala.util.Random -import com.amazonaws.services.kinesis.model.Record import 
org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} import org.scalatest.concurrent.Eventually import org.scalatest.matchers.should.Matchers._ +import software.amazon.kinesis.retrieval.KinesisClientRecord import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.network.util.JavaUtils @@ -195,7 +195,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun } testIfEnabled("custom message handling") { - def addFive(r: Record): Int = JavaUtils.bytesToString(r.getData).toInt + 5 + def addFive(r: KinesisClientRecord): Int = JavaUtils.bytesToString(r.data).toInt + 5 val stream = KinesisInputDStream.builder.streamingContext(ssc) .checkpointAppName(appName) @@ -305,7 +305,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun val testData2 = 11 to 20 val testData3 = 21 to 30 - eventually(timeout(1.minute), interval(10.seconds)) { + eventually(timeout(2.minute), interval(10.seconds)) { localTestUtils.pushData(testData1, aggregateTestData) collected.synchronized { assert(collected === testData1.toSet, "\nData received does not match data sent") @@ -313,9 +313,9 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun } val shardToSplit = localTestUtils.getShards().head - localTestUtils.splitShard(shardToSplit.getShardId) + localTestUtils.splitShard(shardToSplit.shardId) val (splitOpenShards, splitCloseShards) = localTestUtils.getShards().partition { shard => - shard.getSequenceNumberRange.getEndingSequenceNumber == null + shard.sequenceNumberRange.endingSequenceNumber == null } // We should have one closed shard and two open shards @@ -331,9 +331,9 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun } val Seq(shardToMerge, adjShard) = splitOpenShards - localTestUtils.mergeShard(shardToMerge.getShardId, adjShard.getShardId) + localTestUtils.mergeShard(shardToMerge.shardId, adjShard.shardId) val (mergedOpenShards, mergedCloseShards) = localTestUtils.getShards().partition { shard => - shard.getSequenceNumberRange.getEndingSequenceNumber == null + shard.sequenceNumberRange.endingSequenceNumber == null } // We should have three closed shards and one open shard diff --git a/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala b/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala index 7674cef105e71..09d92b718447a 100644 --- a/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala +++ b/connector/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.kinesis -import java.nio.ByteBuffer +import java.net.URI import java.nio.charset.StandardCharsets import java.util.concurrent.TimeUnit @@ -26,13 +26,15 @@ import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ import scala.util.{Failure, Random, Success, Try} -import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain} -import com.amazonaws.regions.RegionUtils -import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient -import com.amazonaws.services.dynamodbv2.document.DynamoDB -import com.amazonaws.services.kinesis.{AmazonKinesis, AmazonKinesisClient} -import com.amazonaws.services.kinesis.model._ -import com.amazonaws.waiters.WaiterParameters +import software.amazon.awssdk.auth.credentials.{AwsCredentials, DefaultCredentialsProvider} +import 
software.amazon.awssdk.core.SdkBytes +import software.amazon.awssdk.http.apache.ApacheHttpClient +import software.amazon.awssdk.regions.Region +import software.amazon.awssdk.regions.servicemetadata.KinesisServiceMetadata +import software.amazon.awssdk.services.dynamodb.DynamoDbClient +import software.amazon.awssdk.services.dynamodb.model.DeleteTableRequest +import software.amazon.awssdk.services.kinesis.KinesisClient +import software.amazon.awssdk.services.kinesis.model.{CreateStreamRequest, DeleteStreamRequest, DescribeStreamRequest, MergeShardsRequest, PutRecordRequest, ResourceNotFoundException, Shard, SplitShardRequest, StreamDescription} import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.{STREAM_NAME, TABLE_NAME} @@ -47,7 +49,6 @@ private[kinesis] class KinesisTestUtils(streamShardCount: Int = 2) extends Loggi val endpointUrl = KinesisTestUtils.endpointUrl val regionName = KinesisTestUtils.getRegionNameByEndpoint(endpointUrl) - private val createStreamTimeoutSeconds = 300 private val describeStreamPollTimeSeconds = 1 @volatile @@ -56,18 +57,23 @@ private[kinesis] class KinesisTestUtils(streamShardCount: Int = 2) extends Loggi @volatile private var _streamName: String = _ - protected lazy val kinesisClient = { - val client = new AmazonKinesisClient(KinesisTestUtils.getAWSCredentials()) - client.setEndpoint(endpointUrl) - client + protected lazy val kinesisClient: KinesisClient = { + KinesisClient.builder() + .credentialsProvider(DefaultCredentialsProvider.create()) + .region(Region.of(regionName)) + .httpClientBuilder(ApacheHttpClient.builder()) + .endpointOverride(URI.create(endpointUrl)) + .build() } - private lazy val streamExistsWaiter = kinesisClient.waiters().streamExists() + private lazy val streamExistsWaiter = kinesisClient.waiter() private lazy val dynamoDB = { - val dynamoDBClient = new AmazonDynamoDBClient(new DefaultAWSCredentialsProviderChain()) - dynamoDBClient.setRegion(RegionUtils.getRegion(regionName)) - new DynamoDB(dynamoDBClient) + DynamoDbClient.builder() + .credentialsProvider(DefaultCredentialsProvider.create()) + .region(Region.of(regionName)) + .httpClientBuilder(ApacheHttpClient.builder()) + .build() } protected def getProducer(aggregate: Boolean): KinesisDataGenerator = { @@ -89,9 +95,10 @@ private[kinesis] class KinesisTestUtils(streamShardCount: Int = 2) extends Loggi // Create a stream. The number of shards determines the provisioned throughput. logInfo(s"Creating stream ${_streamName}") - val createStreamRequest = new CreateStreamRequest() - createStreamRequest.setStreamName(_streamName) - createStreamRequest.setShardCount(streamShardCount) + val createStreamRequest = CreateStreamRequest.builder() + .streamName(_streamName) + .shardCount(streamShardCount) + .build() kinesisClient.createStream(createStreamRequest) // The stream is now being created. Wait for it to become active. 
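For readers less familiar with the SDK v2 style these hunks adopt, here is a minimal, standalone sketch of the same builder/waiter pattern; the endpoint, region, and stream name below are placeholders for illustration, not values used by the suite:

```scala
import java.net.URI

import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider
import software.amazon.awssdk.http.apache.ApacheHttpClient
import software.amazon.awssdk.regions.Region
import software.amazon.awssdk.services.kinesis.KinesisClient
import software.amazon.awssdk.services.kinesis.model.{CreateStreamRequest, DescribeStreamRequest}

object CreateStreamSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder endpoint and stream name, for illustration only.
    val endpointUrl = "https://kinesis.us-west-2.amazonaws.com"
    val streamName = "my-test-stream"

    // SDK v2 clients are immutable and configured entirely through builders;
    // endpointOverride() replaces the v1 setEndpoint() call.
    val client = KinesisClient.builder()
      .credentialsProvider(DefaultCredentialsProvider.create())
      .region(Region.US_WEST_2)
      .httpClientBuilder(ApacheHttpClient.builder())
      .endpointOverride(URI.create(endpointUrl))
      .build()

    // Requests are also built, not mutated with setters.
    client.createStream(
      CreateStreamRequest.builder().streamName(streamName).shardCount(2).build())

    // The v2 waiter replaces the v1 waiters().streamExists() + WaiterParameters pair.
    client.waiter().waitUntilStreamExists(
      DescribeStreamRequest.builder().streamName(streamName).build())
  }
}
```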
@@ -101,25 +108,30 @@ private[kinesis] class KinesisTestUtils(streamShardCount: Int = 2) extends Loggi } def getShards(): Seq[Shard] = { - kinesisClient.describeStream(_streamName).getStreamDescription.getShards.asScala.toSeq + val describeStreamRequest = DescribeStreamRequest.builder() + .streamName(_streamName) + .build() + kinesisClient.describeStream(describeStreamRequest).streamDescription.shards.asScala.toSeq } def splitShard(shardId: String): Unit = { - val splitShardRequest = new SplitShardRequest() - splitShardRequest.withStreamName(_streamName) - splitShardRequest.withShardToSplit(shardId) - // Set a half of the max hash value - splitShardRequest.withNewStartingHashKey("170141183460469231731687303715884105728") + val splitShardRequest = SplitShardRequest.builder() + .streamName(_streamName) + .shardToSplit(shardId) + // Set a half of the max hash value + .newStartingHashKey("170141183460469231731687303715884105728") + .build() kinesisClient.splitShard(splitShardRequest) // Wait for the shards to become active waitForStreamToBeActive(_streamName) } def mergeShard(shardToMerge: String, adjacentShardToMerge: String): Unit = { - val mergeShardRequest = new MergeShardsRequest - mergeShardRequest.withStreamName(_streamName) - mergeShardRequest.withShardToMerge(shardToMerge) - mergeShardRequest.withAdjacentShardToMerge(adjacentShardToMerge) + val mergeShardRequest = MergeShardsRequest.builder() + .streamName(_streamName) + .shardToMerge(shardToMerge) + .adjacentShardToMerge(adjacentShardToMerge) + .build() kinesisClient.mergeShards(mergeShardRequest) // Wait for the shards to become active waitForStreamToBeActive(_streamName) @@ -145,9 +157,12 @@ private[kinesis] class KinesisTestUtils(streamShardCount: Int = 2) extends Loggi } def deleteStream(): Unit = { + val deleteStreamRequest = DeleteStreamRequest.builder() + .streamName(streamName) + .build() try { if (streamCreated) { - kinesisClient.deleteStream(streamName) + kinesisClient.deleteStream(deleteStreamRequest) } } catch { case e: Exception => @@ -156,10 +171,11 @@ private[kinesis] class KinesisTestUtils(streamShardCount: Int = 2) extends Loggi } def deleteDynamoDBTable(tableName: String): Unit = { + val deleteTableRequest = DeleteTableRequest.builder() + .tableName(tableName) + .build() try { - val table = dynamoDB.getTable(tableName) - table.delete() - table.waitForDelete() + dynamoDB.deleteTable(deleteTableRequest) } catch { case e: Exception => logWarning(log"Could not delete DynamoDB table ${MDC(TABLE_NAME, tableName)}", e) @@ -168,11 +184,14 @@ private[kinesis] class KinesisTestUtils(streamShardCount: Int = 2) extends Loggi private def describeStream(streamNameToDescribe: String): Option[StreamDescription] = { try { - val describeStreamRequest = new DescribeStreamRequest().withStreamName(streamNameToDescribe) - val desc = kinesisClient.describeStream(describeStreamRequest).getStreamDescription() + val describeStreamRequest = DescribeStreamRequest.builder() + .streamName(streamNameToDescribe) + .build() + val desc = kinesisClient.describeStream(describeStreamRequest).streamDescription Some(desc) } catch { case rnfe: ResourceNotFoundException => + logWarning(s"Could not describe stream $streamNameToDescribe", rnfe) None } } @@ -187,9 +206,10 @@ private[kinesis] class KinesisTestUtils(streamShardCount: Int = 2) extends Loggi } private def waitForStreamToBeActive(streamNameToWaitFor: String): Unit = { - val describeStreamRequest = new DescribeStreamRequest() - .withStreamName(streamNameToWaitFor) - streamExistsWaiter.run(new 
WaiterParameters(describeStreamRequest)) + val describeStreamRequest = DescribeStreamRequest.builder() + .streamName(streamNameToWaitFor) + .build() + streamExistsWaiter.waitUntilStreamExists(describeStreamRequest) } } @@ -201,10 +221,11 @@ private[kinesis] object KinesisTestUtils { def getRegionNameByEndpoint(endpoint: String): String = { val uri = new java.net.URI(endpoint) - RegionUtils.getRegionsForService(AmazonKinesis.ENDPOINT_PREFIX) + val kinesisServiceMetadata = new KinesisServiceMetadata() + kinesisServiceMetadata.regions .asScala - .find(_.getAvailableEndpoints.asScala.toSeq.contains(uri.getHost)) - .map(_.getName) + .find(r => kinesisServiceMetadata.endpointFor(r).toString.equals(uri.getHost)) + .map(_.id) .getOrElse( throw new IllegalArgumentException(s"Could not resolve region for endpoint: $endpoint")) } @@ -239,20 +260,20 @@ private[kinesis] object KinesisTestUtils { } def isAWSCredentialsPresent: Boolean = { - Try { new DefaultAWSCredentialsProviderChain().getCredentials() }.isSuccess + Try { DefaultCredentialsProvider.create().resolveCredentials() }.isSuccess } - def getAWSCredentials(): AWSCredentials = { + def getAWSCredentials(): AwsCredentials = { assert(shouldRunTests, "Kinesis test not enabled, should not attempt to get AWS credentials") - Try { new DefaultAWSCredentialsProviderChain().getCredentials() } match { + Try { DefaultCredentialsProvider.create().resolveCredentials() } match { case Success(cred) => cred case Failure(e) => throw new Exception( s""" |Kinesis tests enabled using environment variable $envVarNameForEnablingTests |but could not find AWS credentials. Please follow instructions in AWS documentation - |to set the credentials in your system such that the DefaultAWSCredentialsProviderChain + |to set the credentials in your system such that the DefaultCredentialsProvider |can find the credentials. """.stripMargin) } @@ -266,19 +287,21 @@ private[kinesis] trait KinesisDataGenerator { } private[kinesis] class SimpleDataGenerator( - client: AmazonKinesisClient) extends KinesisDataGenerator { + client: KinesisClient) extends KinesisDataGenerator { override def sendData(streamName: String, data: Seq[Int]): Map[String, Seq[(Int, String)]] = { val shardIdToSeqNumbers = new mutable.HashMap[String, ArrayBuffer[(Int, String)]]() data.foreach { num => val str = num.toString - val data = ByteBuffer.wrap(str.getBytes(StandardCharsets.UTF_8)) - val putRecordRequest = new PutRecordRequest().withStreamName(streamName) - .withData(data) - .withPartitionKey(str) - - val putRecordResult = client.putRecord(putRecordRequest) - val shardId = putRecordResult.getShardId - val seqNumber = putRecordResult.getSequenceNumber() + val data = SdkBytes.fromByteArray(str.getBytes(StandardCharsets.UTF_8)) + val putRecordRequest = PutRecordRequest.builder() + .streamName(streamName) + .data(data) + .partitionKey(str) + .build() + + val putRecordResponse = client.putRecord(putRecordRequest) + val shardId = putRecordResponse.shardId + val seqNumber = putRecordResponse.sequenceNumber val sentSeqNumbers = shardIdToSeqNumbers.getOrElseUpdate(shardId, new ArrayBuffer[(Int, String)]()) sentSeqNumbers += ((num, seqNumber)) diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index 0396d3cc64d14..34587d9b17ca5 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -115,18 +115,16 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m You may also provide the following settings. 
This is currently only supported in Scala and Java. - - A "message handler function" that takes a Kinesis `Record` and returns a generic object `T`, in case you would like to use other data included in a `Record` such as partition key. + - A "message handler function" that takes a Kinesis `KinesisClientRecord` and returns a generic object `T`, in case you would like to use other data included in a `KinesisClientRecord`, such as the partition key.
```scala - import collection.JavaConverters._ import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.kinesis.KinesisInputDStream import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.kinesis.KinesisInitialPositions - import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration - import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel + import software.amazon.kinesis.metrics.{MetricsLevel, MetricsUtil} val kinesisStream = KinesisInputDStream.builder .streamingContext(streamingContext) @@ -138,21 +136,23 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m .checkpointInterval([checkpoint interval]) .storageLevel(StorageLevel.MEMORY_AND_DISK_2) .metricsLevel(MetricsLevel.DETAILED) - .metricsEnabledDimensions(KinesisClientLibConfiguration.DEFAULT_METRICS_ENABLED_DIMENSIONS.asScala.toSet) + .metricsEnabledDimensions( + Set(MetricsUtil.OPERATION_DIMENSION_NAME, MetricsUtil.SHARD_ID_DIMENSION_NAME)) .buildWithMessageHandler([message handler]) ```
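As an illustration of the `[message handler]` argument, a minimal Scala handler might look like the following; it assumes the record payload is UTF-8 text, and the function name is only an example, not part of the API:

```scala
import java.nio.charset.StandardCharsets

import software.amazon.kinesis.retrieval.KinesisClientRecord

// Example handler: pair the partition key with the payload decoded as a UTF-8 string.
// KinesisClientRecord.data() may return a read-only ByteBuffer, so copy the bytes out
// rather than calling array() on it.
def keyedMessageHandler(record: KinesisClientRecord): (String, String) = {
  val buf = record.data().duplicate()
  val bytes = new Array[Byte](buf.remaining())
  buf.get(bytes)
  (record.partitionKey(), new String(bytes, StandardCharsets.UTF_8))
}
```

An equivalent handler can be passed as a lambda to the Java builder shown next.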
```java + import java.util.Set; + import scala.jdk.javaapi.CollectionConverters; import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.kinesis.KinesisInputDStream; import org.apache.spark.streaming.Seconds; import org.apache.spark.streaming.StreamingContext; import org.apache.spark.streaming.kinesis.KinesisInitialPositions; - import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration; - import com.amazonaws.services.kinesis.metrics.interfaces.MetricsLevel; - import scala.collection.JavaConverters; + import software.amazon.kinesis.metrics.MetricsLevel; + import software.amazon.kinesis.metrics.MetricsUtil; KinesisInputDStream kinesisStream = KinesisInputDStream.builder() .streamingContext(streamingContext) @@ -165,11 +165,10 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m .storageLevel(StorageLevel.MEMORY_AND_DISK_2) .metricsLevel(MetricsLevel.DETAILED) .metricsEnabledDimensions( - JavaConverters.asScalaSetConverter( - KinesisClientLibConfiguration.DEFAULT_METRICS_ENABLED_DIMENSIONS - ) - .asScala().toSet() - ) + CollectionConverters.asScala( + Set.of( + MetricsUtil.OPERATION_DIMENSION_NAME, + MetricsUtil.SHARD_ID_DIMENSION_NAME)).toSet()) .buildWithMessageHandler([message handler]); ``` @@ -194,7 +193,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m - `[initial position]`: Can be either `KinesisInitialPositions.TrimHorizon` or `KinesisInitialPositions.Latest` or `KinesisInitialPositions.AtTimestamp` (see [`Kinesis Checkpointing`](#kinesis-checkpointing) section and [`Amazon Kinesis API documentation`](http://docs.aws.amazon.com/streams/latest/dev/developing-consumers-with-sdk.html) for more details). - - `[message handler]`: A function that takes a Kinesis `Record` and outputs generic `T`. + - `[message handler]`: A function that takes a Kinesis `KinesisClientRecord` and outputs generic `T`. In other versions of the API, you can also specify the AWS access key and secret key directly. diff --git a/pom.xml b/pom.xml index ae38b87c3f957..143808b488419 100644 --- a/pom.xml +++ b/pom.xml @@ -159,12 +159,11 @@ 4.2.37 1.12.1 - 1.15.3 + 2.7.2 - 1.12.681 2.35.4 - 1.0.5 + 1.0.6 hadoop3-2.2.29 1.3.0