Skip to content

Commit da18da0

Browse files
author
Budde
committed
[SPARK-19405][STREAMING] Add support to KinesisUtils for cross-account Kinesis reads via STS
- Add dependency on aws-java-sdk-sts
- Replace SerializableAWSCredentials with new SerializableCredentialsProvider interface
- Make KinesisReceiver take SerializableCredentialsProvider as argument and pass credential provider to KCL
- Add new implementations of KinesisUtils.createStream() that take STS arguments
- Make JavaKinesisStreamSuite test the entire KinesisUtils Java API
- Update KCL/AWS SDK dependencies to 1.7.x/1.11.x
- Make SerializableCredentialsProvider a sealed trait and move its implementing classes to their own file
1 parent 15627ac commit da18da0

File tree

17 files changed

+407
-83
lines changed

17 files changed

+407
-83
lines changed

external/kinesis-asl/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@
5858
<artifactId>amazon-kinesis-client</artifactId>
5959
<version>${aws.kinesis.client.version}</version>
6060
</dependency>
61+
<dependency>
62+
<groupId>com.amazonaws</groupId>
63+
<artifactId>aws-java-sdk-sts</artifactId>
64+
<version>${aws.java.sdk.version}</version>
65+
</dependency>
6166
<dependency>
6267
<groupId>com.amazonaws</groupId>
6368
<artifactId>amazon-kinesis-producer</artifactId>

external/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ public static void main(String[] args) throws Exception {
127127

128128
// Get the region name from the endpoint URL to save Kinesis Client Library metadata in
129129
// DynamoDB of the same region as the Kinesis stream
130-
String regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName();
130+
String regionName = KinesisExampleUtils.getRegionNameByEndpoint(endpointUrl);
131131

132132
// Setup the Spark config and StreamingContext
133133
SparkConf sparkConfig = new SparkConf().setAppName("JavaKinesisWordCountASL");

external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisExampleUtils.scala

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.examples.streaming
19+
20+
import scala.collection.JavaConverters._
21+
22+
import com.amazonaws.regions.RegionUtils
23+
import com.amazonaws.services.kinesis.AmazonKinesis
24+
25+
private[streaming] object KinesisExampleUtils {
26+
def getRegionNameByEndpoint(endpoint: String): String = {
27+
val uri = new java.net.URI(endpoint)
28+
RegionUtils.getRegionsForService(AmazonKinesis.ENDPOINT_PREFIX)
29+
.asScala
30+
.find(_.getAvailableEndpoints.asScala.toSeq.contains(uri.getHost))
31+
.map(_.getName)
32+
.getOrElse(
33+
throw new IllegalArgumentException(s"Could not resolve region for endpoint: $endpoint"))
34+
}
35+
}

external/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ object KinesisWordCountASL extends Logging {
127127

128128
// Get the region name from the endpoint URL to save Kinesis Client Library metadata in
129129
// DynamoDB of the same region as the Kinesis stream
130-
val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName()
130+
val regionName = KinesisExampleUtils.getRegionNameByEndpoint(endpointUrl)
131131

132132
// Setup the SparkConfig and StreamingContext
133133
val sparkConfig = new SparkConf().setAppName("KinesisWordCountASL")

external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ class KinesisBackedBlockRDD[T: ClassTag](
7979
@transient private val isBlockIdValid: Array[Boolean] = Array.empty,
8080
val retryTimeoutMs: Int = 10000,
8181
val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _,
82-
val awsCredentialsOption: Option[SerializableAWSCredentials] = None
82+
val kinesisCredsProvider: SerializableCredentialsProvider = DefaultCredentialsProvider
8383
) extends BlockRDD[T](sc, _blockIds) {
8484

8585
require(_blockIds.length == arrayOfseqNumberRanges.length,
@@ -105,9 +105,7 @@ class KinesisBackedBlockRDD[T: ClassTag](
105105
}
106106

107107
def getBlockFromKinesis(): Iterator[T] = {
108-
val credentials = awsCredentialsOption.getOrElse {
109-
new DefaultAWSCredentialsProviderChain().getCredentials()
110-
}
108+
val credentials = kinesisCredsProvider.provider.getCredentials
111109
partition.seqNumberRanges.ranges.iterator.flatMap { range =>
112110
new KinesisSequenceRangeIterator(credentials, endpointUrl, regionName,
113111
range, retryTimeoutMs).map(messageHandler)
@@ -143,7 +141,7 @@ class KinesisSequenceRangeIterator(
143141
private var lastSeqNumber: String = null
144142
private var internalIterator: Iterator[Record] = null
145143

146-
client.setEndpoint(endpointUrl, "kinesis", regionId)
144+
client.setEndpoint(endpointUrl)
147145

148146
override protected def getNext(): Record = {
149147
var nextRecord: Record = null

external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import java.util.concurrent._
2121
import scala.util.control.NonFatal
2222

2323
import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer
24-
import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason
24+
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason
2525

2626
import org.apache.spark.internal.Logging
2727
import org.apache.spark.streaming.Duration

external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag](
3939
checkpointInterval: Duration,
4040
storageLevel: StorageLevel,
4141
messageHandler: Record => T,
42-
awsCredentialsOption: Option[SerializableAWSCredentials]
42+
kinesisCredsProvider: SerializableCredentialsProvider
4343
) extends ReceiverInputDStream[T](_ssc) {
4444

4545
private[streaming]
@@ -61,7 +61,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag](
6161
isBlockIdValid = isBlockIdValid,
6262
retryTimeoutMs = ssc.graph.batchDuration.milliseconds.toInt,
6363
messageHandler = messageHandler,
64-
awsCredentialsOption = awsCredentialsOption)
64+
kinesisCredsProvider = kinesisCredsProvider)
6565
} else {
6666
logWarning("Kinesis sequence number information was not present with some block metadata," +
6767
" it may not be possible to recover from failures")
@@ -71,6 +71,7 @@ private[kinesis] class KinesisInputDStream[T: ClassTag](
7171

7272
override def getReceiver(): Receiver[T] = {
7373
new KinesisReceiver(streamName, endpointUrl, regionName, initialPositionInStream,
74-
checkpointAppName, checkpointInterval, storageLevel, messageHandler, awsCredentialsOption)
74+
checkpointAppName, checkpointInterval, storageLevel, messageHandler,
75+
kinesisCredsProvider)
7576
}
7677
}

external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala

Lines changed: 13 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import scala.collection.JavaConverters._
2323
import scala.collection.mutable
2424
import scala.util.control.NonFatal
2525

26-
import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, DefaultAWSCredentialsProviderChain}
2726
import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer, IRecordProcessorFactory}
2827
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream, KinesisClientLibConfiguration, Worker}
2928
import com.amazonaws.services.kinesis.model.Record
@@ -34,13 +33,6 @@ import org.apache.spark.streaming.Duration
3433
import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver}
3534
import org.apache.spark.util.Utils
3635

37-
private[kinesis]
38-
case class SerializableAWSCredentials(accessKeyId: String, secretKey: String)
39-
extends AWSCredentials {
40-
override def getAWSAccessKeyId: String = accessKeyId
41-
override def getAWSSecretKey: String = secretKey
42-
}
43-
4436
/**
4537
* Custom AWS Kinesis-specific implementation of Spark Streaming's Receiver.
4638
* This implementation relies on the Kinesis Client Library (KCL) Worker as described here:
@@ -78,8 +70,9 @@ case class SerializableAWSCredentials(accessKeyId: String, secretKey: String)
7870
* See the Kinesis Spark Streaming documentation for more
7971
* details on the different types of checkpoints.
8072
* @param storageLevel Storage level to use for storing the received objects
81-
* @param awsCredentialsOption Optional AWS credentials, used when user directly specifies
82-
* the credentials
73+
* @param kinesisCredsProvider SerializableCredentialsProvider instance that will be used to
74+
* generate the AWSCredentialsProvider instance used for KCL
75+
* authorization.
8376
*/
8477
private[kinesis] class KinesisReceiver[T](
8578
val streamName: String,
@@ -90,7 +83,7 @@ private[kinesis] class KinesisReceiver[T](
9083
checkpointInterval: Duration,
9184
storageLevel: StorageLevel,
9285
messageHandler: Record => T,
93-
awsCredentialsOption: Option[SerializableAWSCredentials])
86+
kinesisCredsProvider: SerializableCredentialsProvider)
9487
extends Receiver[T](storageLevel) with Logging { receiver =>
9588

9689
/*
@@ -147,14 +140,15 @@ private[kinesis] class KinesisReceiver[T](
147140
workerId = Utils.localHostName() + ":" + UUID.randomUUID()
148141

149142
kinesisCheckpointer = new KinesisCheckpointer(receiver, checkpointInterval, workerId)
150-
// KCL config instance
151-
val awsCredProvider = resolveAWSCredentialsProvider()
152-
val kinesisClientLibConfiguration =
153-
new KinesisClientLibConfiguration(checkpointAppName, streamName, awsCredProvider, workerId)
154-
.withKinesisEndpoint(endpointUrl)
155-
.withInitialPositionInStream(initialPositionInStream)
156-
.withTaskBackoffTimeMillis(500)
157-
.withRegionName(regionName)
143+
val kinesisClientLibConfiguration = new KinesisClientLibConfiguration(
144+
checkpointAppName,
145+
streamName,
146+
kinesisCredsProvider.provider,
147+
workerId)
148+
.withKinesisEndpoint(endpointUrl)
149+
.withInitialPositionInStream(initialPositionInStream)
150+
.withTaskBackoffTimeMillis(500)
151+
.withRegionName(regionName)
158152

159153
/*
160154
* RecordProcessorFactory creates impls of IRecordProcessor.
@@ -305,25 +299,6 @@ private[kinesis] class KinesisReceiver[T](
305299
}
306300
}
307301

308-
/**
309-
* If AWS credential is provided, return a AWSCredentialProvider returning that credential.
310-
* Otherwise, return the DefaultAWSCredentialsProviderChain.
311-
*/
312-
private def resolveAWSCredentialsProvider(): AWSCredentialsProvider = {
313-
awsCredentialsOption match {
314-
case Some(awsCredentials) =>
315-
logInfo("Using provided AWS credentials")
316-
new AWSCredentialsProvider {
317-
override def getCredentials: AWSCredentials = awsCredentials
318-
override def refresh(): Unit = { }
319-
}
320-
case None =>
321-
logInfo("Using DefaultAWSCredentialsProviderChain")
322-
new DefaultAWSCredentialsProviderChain()
323-
}
324-
}
325-
326-
327302
/**
328303
* Class to handle blocks generated by this receiver's block generator. Specifically, in
329304
* the context of the Kinesis Receiver, this handler does the following.

external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import scala.util.control.NonFatal
2323

2424
import com.amazonaws.services.kinesis.clientlibrary.exceptions.{InvalidStateException, KinesisClientLibDependencyException, ShutdownException, ThrottlingException}
2525
import com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, IRecordProcessorCheckpointer}
26-
import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason
26+
import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason
2727
import com.amazonaws.services.kinesis.model.Record
2828

2929
import org.apache.spark.internal.Logging

external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisTestUtils.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import com.amazonaws.auth.{AWSCredentials, DefaultAWSCredentialsProviderChain}
3030
import com.amazonaws.regions.RegionUtils
3131
import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClient
3232
import com.amazonaws.services.dynamodbv2.document.DynamoDB
33-
import com.amazonaws.services.kinesis.AmazonKinesisClient
33+
import com.amazonaws.services.kinesis.{AmazonKinesis, AmazonKinesisClient}
3434
import com.amazonaws.services.kinesis.model._
3535

3636
import org.apache.spark.internal.Logging
@@ -43,7 +43,7 @@ import org.apache.spark.internal.Logging
4343
private[kinesis] class KinesisTestUtils(streamShardCount: Int = 2) extends Logging {
4444

4545
val endpointUrl = KinesisTestUtils.endpointUrl
46-
val regionName = RegionUtils.getRegionByEndpoint(endpointUrl).getName()
46+
val regionName = KinesisTestUtils.getRegionNameByEndpoint(endpointUrl)
4747

4848
private val createStreamTimeoutSeconds = 300
4949
private val describeStreamPollTimeSeconds = 1
@@ -205,6 +205,16 @@ private[kinesis] object KinesisTestUtils {
205205
val endVarNameForEndpoint = "KINESIS_TEST_ENDPOINT_URL"
206206
val defaultEndpointUrl = "https://kinesis.us-west-2.amazonaws.com"
207207

208+
def getRegionNameByEndpoint(endpoint: String): String = {
209+
val uri = new java.net.URI(endpoint)
210+
RegionUtils.getRegionsForService(AmazonKinesis.ENDPOINT_PREFIX)
211+
.asScala
212+
.find(_.getAvailableEndpoints.asScala.toSeq.contains(uri.getHost))
213+
.map(_.getName)
214+
.getOrElse(
215+
throw new IllegalArgumentException(s"Could not resolve region for endpoint: $endpoint"))
216+
}
217+
208218
lazy val shouldRunTests = {
209219
val isEnvSet = sys.env.get(envVarNameForEnablingTests) == Some("1")
210220
if (isEnvSet) {

0 commit comments

Comments
 (0)