Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Assumed Role ARN as a config option #20

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/options.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"aws.region": "eu-west-1",
"aws.access.key.id": "",
"aws.secret.key": "",
"aws.assume.role.arn": "",

"dynamodb.table.env.tag.key": "environment",
"dynamodb.table.env.tag.value": "dev",
Expand All @@ -38,6 +39,8 @@
"connect.dynamodb.rediscovery.period": "60000"
}
```
`aws.assume.role.arn` - ARN identifier of an IAM role that the KCL and Dynamo Clients can assume for cross account access

`dynamodb.table.env.tag.key` - tag key used to define environment. Useful if you have `staging` and `production` under same AWS account. Or if you want to use different Kafka Connect clusters to sync different tables.

`dynamodb.table.env.tag.value` - defines from which environment to ingest tables. For e.g. 'staging' or 'production'...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,14 @@ public void start(Map<String, String> properties) {
AwsClients.buildAWSResourceGroupsTaggingAPIClient(config.getAwsRegion(),
config.getResourceTaggingServiceEndpoint(),
config.getAwsAccessKeyIdValue(),
config.getAwsSecretKeyValue());
config.getAwsSecretKeyValue(),
config.getAwsAssumeRoleArn());

AmazonDynamoDB dynamoDBClient = AwsClients.buildDynamoDbClient(config.getAwsRegion(),
config.getDynamoDBServiceEndpoint(),
config.getAwsAccessKeyIdValue(),
config.getAwsSecretKeyValue());
config.getAwsSecretKeyValue(),
config.getAwsAssumeRoleArn());

if (tablesProvider == null) {
if (config.getWhitelistTables() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ public class DynamoDBSourceConnectorConfig extends AbstractConfig {
public static final String SRC_KCL_TABLE_BILLING_MODE_DISPLAY = "KCL table billing mode";
public static final String SRC_KCL_TABLE_BILLING_MODE_DEFAULT = "PROVISIONED";

public static final String AWS_ASSUME_ROLE_ARN_CONFIG = "aws.assume.role.arn";
public static final String AWS_ASSUME_ROLE_ARN_DOC = "Define which role arn the KCL/Dynamo Client should assume.";
public static final String AWS_ASSUME_ROLE_ARN_DISPLAY = "Assume Role Arn";
public static final String AWS_ASSUME_ROLE_ARN_DEFAULT = null;

public static final String DST_TOPIC_PREFIX_CONFIG = "kafka.topic.prefix";
public static final String DST_TOPIC_PREFIX_DOC = "Define Kafka topic destination prefix. End will be the name of a table.";
public static final String DST_TOPIC_PREFIX_DISPLAY = "Topic prefix";
Expand Down Expand Up @@ -181,6 +186,15 @@ public static ConfigDef baseConfigDef() {
ConfigDef.Width.MEDIUM,
SRC_KCL_TABLE_BILLING_MODE_DISPLAY)

.define(AWS_ASSUME_ROLE_ARN_CONFIG,
ConfigDef.Type.STRING,
AWS_ASSUME_ROLE_ARN_DEFAULT,
ConfigDef.Importance.LOW,
AWS_ASSUME_ROLE_ARN_DOC,
AWS_GROUP, 10,
ConfigDef.Width.LONG,
AWS_ASSUME_ROLE_ARN_DISPLAY)

.define(DST_TOPIC_PREFIX_CONFIG,
ConfigDef.Type.STRING,
DST_TOPIC_PREFIX_DEFAULT,
Expand Down Expand Up @@ -272,4 +286,8 @@ public List<String> getWhitelistTables() {
public BillingMode getKCLTableBillingMode() {
return BillingMode.fromValue(getString(SRC_KCL_TABLE_BILLING_MODE_CONFIG));
}

public String getAwsAssumeRoleArn() {
return getString(AWS_ASSUME_ROLE_ARN_CONFIG);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ public void start(Map<String, String> configProperties) {
config.getAwsRegion(),
config.getDynamoDBServiceEndpoint(),
config.getAwsAccessKeyIdValue(),
config.getAwsSecretKeyValue());
config.getAwsSecretKeyValue(),
config.getAwsAssumeRoleArn());
}
tableDesc = client.describeTable(config.getTableName()).getTable();

Expand All @@ -142,11 +143,12 @@ public void start(Map<String, String> configProperties) {
config.getAwsRegion(),
config.getDynamoDBServiceEndpoint(),
config.getAwsAccessKeyIdValue(),
config.getAwsSecretKeyValue());
config.getAwsSecretKeyValue(),
config.getAwsAssumeRoleArn());

if (kclWorker == null) {
kclWorker = new KclWorkerImpl(
AwsClients.getCredentials(config.getAwsAccessKeyIdValue(), config.getAwsSecretKeyValue()),
AwsClients.getCredentials(config.getAwsAccessKeyIdValue(), config.getAwsSecretKeyValue(), config.getAwsAssumeRoleArn()),
eventsQueue,
shardRegister);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider;
import com.amazonaws.client.builder.AwsClientBuilder;
import com.amazonaws.services.dynamodbv2.AmazonDynamoDB;
import com.amazonaws.services.dynamodbv2.AmazonDynamoDBClientBuilder;
Expand All @@ -22,44 +23,56 @@ public class AwsClients {
public static AmazonDynamoDB buildDynamoDbClient(String awsRegion,
String serviceEndpoint,
String awsAccessKeyID,
String awsSecretKey) {

String awsSecretKey,
String awsAssumeRoleArn) {

return (AmazonDynamoDB) configureBuilder(
AmazonDynamoDBClientBuilder.standard(),
awsRegion, serviceEndpoint,
awsAccessKeyID,
awsSecretKey)
awsSecretKey,
awsAssumeRoleArn)
.build();
}

public static AWSResourceGroupsTaggingAPI buildAWSResourceGroupsTaggingAPIClient(String awsRegion,
String serviceEndpoint,
String awsAccessKeyID,
String awsSecretKey) {
String awsSecretKey,
String awsAssumeRoleArn) {
return (AWSResourceGroupsTaggingAPI) configureBuilder(
AWSResourceGroupsTaggingAPIClientBuilder.standard(),
awsRegion, serviceEndpoint,
awsAccessKeyID,
awsSecretKey)
awsSecretKey,
awsAssumeRoleArn)
.build();
}

public static AmazonDynamoDBStreams buildDynamoDbStreamsClient(String awsRegion,
String serviceEndpoint,
String awsAccessKeyID,
String awsSecretKey) {
String awsSecretKey,
String awsAssumeRoleArn) {
return (AmazonDynamoDBStreams) configureBuilder(
AmazonDynamoDBStreamsClientBuilder.standard(),
awsRegion, serviceEndpoint,
awsAccessKeyID,
awsSecretKey)
awsSecretKey,
awsAssumeRoleArn)
.build();

}

public static AWSCredentialsProvider getCredentials(String awsAccessKeyID, String awsSecretKey) {
if (awsAccessKeyID == null || awsSecretKey == null) {
public static AWSCredentialsProvider getCredentials(String awsAccessKeyID,
String awsSecretKey,
String awsAssumeRoleArn) {
if (awsAssumeRoleArn != null ) {
LOGGER.debug("Using STSAssumeRoleSessionCredentialsProvider");
AWSCredentialsProvider awsCredentialsProviderChain = DefaultAWSCredentialsProviderChain.getInstance();
return new STSAssumeRoleSessionCredentialsProvider(awsCredentialsProviderChain,
awsAssumeRoleArn, "kafkaconnect");
} else if (awsAccessKeyID == null || awsSecretKey == null) {
LOGGER.debug("Using DefaultAWSCredentialsProviderChain");

return DefaultAWSCredentialsProviderChain.getInstance();
Expand All @@ -75,9 +88,10 @@ private static AwsClientBuilder configureBuilder(AwsClientBuilder builder,
String awsRegion,
String serviceEndpoint,
String awsAccessKeyID,
String awsSecretKey) {
String awsSecretKey,
String awsAssumeRoleArn) {

builder.withCredentials(getCredentials(awsAccessKeyID, awsSecretKey))
builder.withCredentials(getCredentials(awsAccessKeyID, awsSecretKey, awsAssumeRoleArn))
.withClientConfiguration(new ClientConfiguration().withThrottledRetries(true));

if(serviceEndpoint != null && !serviceEndpoint.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public class KafkaConnectITBase {
protected static final String AWS_REGION_CONFIG = "eu-west-3";
protected static final String AWS_ACCESS_KEY_ID_CONFIG = "ABCD";
protected static final String AWS_SECRET_KEY_CONFIG = "1234";
protected static final String AWS_ASSUME_ROLE_ARN_CONFIG = null;
protected static final String SRC_DYNAMODB_TABLE_INGESTION_TAG_KEY_CONFIG = "datalake-ingest";

private static Network network;
Expand Down Expand Up @@ -187,7 +188,8 @@ private AmazonDynamoDB getDynamoDBClient() {
AWS_REGION_CONFIG,
dynamodb.getEndpoint(),
AWS_ACCESS_KEY_ID_CONFIG,
AWS_SECRET_KEY_CONFIG
AWS_SECRET_KEY_CONFIG,
AWS_ASSUME_ROLE_ARN_CONFIG
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package com.trustpilot.connector.dynamodb.aws;

import com.amazonaws.auth.AWSCredentialsProvider;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;

import static org.junit.jupiter.api.Assertions.assertEquals;

public class AwsClientsTests {

@Test
public void stsAssumeRoleProviderReturned() {
String testRoleArn = "arn:aws:iam::111111111111:role/unit-test";
AWSCredentialsProvider provider = AwsClients.getCredentials(
null,
null,
testRoleArn
);

DefaultAWSCredentialsProviderChain testChain = Mockito.mock(DefaultAWSCredentialsProviderChain.class);
STSAssumeRoleSessionCredentialsProvider expectedProvider = new STSAssumeRoleSessionCredentialsProvider(
testChain.getInstance(),
testRoleArn,
"kafkaconnect"
);
assertEquals(provider.getClass(), expectedProvider.getClass());
}

@Test
public void defaultProviderReturned() {
AWSCredentialsProvider provider = AwsClients.getCredentials(
null,
null,
null
);

assertEquals(provider.getClass(), DefaultAWSCredentialsProviderChain.class);
}

@Test
public void staticCredentialsReturned() {
AWSCredentialsProvider provider = AwsClients.getCredentials(
"unit-test",
"unit-test",
null
);

assertEquals(provider.getClass(), AWSStaticCredentialsProvider.class);
}
}