Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@

import java.util.List;
import java.util.stream.Collectors;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.http.urlconnection.UrlConnectionHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.glue.GlueClient;
import software.amazon.awssdk.services.glue.model.DeleteDatabaseRequest;
import software.amazon.awssdk.services.s3.S3Client;
Expand All @@ -31,6 +34,9 @@
import software.amazon.awssdk.services.s3.model.ListObjectsV2Request;
import software.amazon.awssdk.services.s3.model.ListObjectsV2Response;
import software.amazon.awssdk.services.s3.model.ObjectIdentifier;
import software.amazon.awssdk.services.s3control.S3ControlClient;
import software.amazon.awssdk.services.s3control.model.CreateAccessPointRequest;
import software.amazon.awssdk.services.s3control.model.DeleteAccessPointRequest;

public class AwsIntegTestUtil {

Expand All @@ -39,6 +45,25 @@ public class AwsIntegTestUtil {
private AwsIntegTestUtil() {
}

/**
* Get the environment variable AWS_REGION to use for testing
* @return region
*/
public static String testRegion() {
return System.getenv("AWS_REGION");
}

/**
* Get the environment variable AWS_CROSS_REGION to use for testing
* @return region
*/
public static String testCrossRegion() {
String crossRegion = System.getenv("AWS_CROSS_REGION");
Preconditions.checkArgument(!testRegion().equals(crossRegion), "AWS_REGION should not be equal to " +
"AWS_CROSS_REGION");
return crossRegion;
}

/**
* Set the environment variable AWS_TEST_BUCKET for a default bucket to use for testing
* @return bucket name
Expand All @@ -47,6 +72,14 @@ public static String testBucketName() {
return System.getenv("AWS_TEST_BUCKET");
}

/**
* Set the environment variable AWS_TEST_CROSS_REGION_BUCKET for a default bucket to use for testing
* @return bucket name
*/
public static String testCrossRegionBucketName() {
return System.getenv("AWS_TEST_CROSS_REGION_BUCKET");
}

/**
* Set the environment variable AWS_TEST_ACCOUNT_ID for a default account to use for testing
* @return account id
Expand Down Expand Up @@ -81,4 +114,38 @@ public static void cleanGlueCatalog(GlueClient glue, List<String> namespaces) {
}
}
}

public static S3ControlClient createS3ControlClient(String region) {
return S3ControlClient.builder()
.httpClientBuilder(UrlConnectionHttpClient.builder())
.region(Region.of(region))
.build();
}

public static void createAccessPoint(S3ControlClient s3ControlClient, String accessPointName, String bucketName) {
try {
s3ControlClient.createAccessPoint(CreateAccessPointRequest
.builder()
.name(accessPointName)
.bucket(bucketName)
.accountId(testAccountId())
.build()
);
} catch (Exception e) {
LOG.error("Cannot create access point {}", accessPointName, e);
}
}

public static void deleteAccessPoint(S3ControlClient s3ControlClient, String accessPointName) {
try {
s3ControlClient.deleteAccessPoint(DeleteAccessPointRequest
.builder()
.name(accessPointName)
.accountId(testAccountId())
.build()
);
} catch (Exception e) {
LOG.error("Cannot delete access point {}", accessPointName, e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,16 @@
import org.apache.iceberg.io.InputFile;
import org.apache.iceberg.io.OutputFile;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.regions.PartitionMetadata;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.kms.KmsClient;
import software.amazon.awssdk.services.kms.model.ListAliasesRequest;
import software.amazon.awssdk.services.kms.model.ListAliasesResponse;
Expand All @@ -57,14 +61,21 @@
import software.amazon.awssdk.services.s3.model.Permission;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.ServerSideEncryption;
import software.amazon.awssdk.services.s3control.S3ControlClient;
import software.amazon.awssdk.utils.ImmutableMap;
import software.amazon.awssdk.utils.IoUtils;

public class TestS3FileIOIntegration {

private static AwsClientFactory clientFactory;
private static S3Client s3;
private static S3ControlClient s3Control;
private static S3ControlClient crossRegionS3Control;
private static KmsClient kms;
private static String bucketName;
private static String crossRegionBucketName;
private static String accessPointName;
private static String crossRegionAccessPointName;
private static String prefix;
private static byte[] contentBytes;
private static String content;
Expand All @@ -78,17 +89,27 @@ public static void beforeClass() {
clientFactory = AwsClientFactories.defaultFactory();
s3 = clientFactory.s3();
kms = clientFactory.kms();
s3Control = AwsIntegTestUtil.createS3ControlClient(AwsIntegTestUtil.testRegion());
crossRegionS3Control = AwsIntegTestUtil.createS3ControlClient(AwsIntegTestUtil.testCrossRegion());
bucketName = AwsIntegTestUtil.testBucketName();
crossRegionBucketName = AwsIntegTestUtil.testCrossRegionBucketName();
accessPointName = UUID.randomUUID().toString();
crossRegionAccessPointName = UUID.randomUUID().toString();
prefix = UUID.randomUUID().toString();
contentBytes = new byte[1024 * 1024 * 10];
deletionBatchSize = 3;
content = new String(contentBytes, StandardCharsets.UTF_8);
kmsKeyArn = kms.createKey().keyMetadata().arn();

AwsIntegTestUtil.createAccessPoint(s3Control, accessPointName, bucketName);
AwsIntegTestUtil.createAccessPoint(crossRegionS3Control, crossRegionAccessPointName, crossRegionBucketName);
}

@AfterClass
public static void afterClass() {
AwsIntegTestUtil.cleanS3Bucket(s3, bucketName, prefix);
AwsIntegTestUtil.deleteAccessPoint(s3Control, accessPointName);
AwsIntegTestUtil.deleteAccessPoint(crossRegionS3Control, crossRegionAccessPointName);
kms.scheduleKeyDeletion(ScheduleKeyDeletionRequest.builder().keyId(kmsKeyArn).pendingWindowInDays(7).build());
}

Expand All @@ -98,6 +119,11 @@ public void before() {
objectUri = String.format("s3://%s/%s", bucketName, objectKey);
}

@BeforeEach
public void beforeEach() {
clientFactory.initialize(Maps.newHashMap());
}

@Test
public void testNewInputStream() throws Exception {
s3.putObject(PutObjectRequest.builder().bucket(bucketName).key(objectKey).build(),
Expand All @@ -106,6 +132,33 @@ public void testNewInputStream() throws Exception {
validateRead(s3FileIO);
}

@Test
public void testNewInputStreamWithAccessPoint() throws Exception {
s3.putObject(PutObjectRequest.builder().bucket(bucketName).key(objectKey).build(),
RequestBody.fromBytes(contentBytes));
S3FileIO s3FileIO = new S3FileIO(clientFactory::s3);
s3FileIO.initialize(ImmutableMap.of(AwsProperties.S3_ACCESS_POINTS_PREFIX + bucketName,
testAccessPointARN(AwsIntegTestUtil.testRegion(), accessPointName)));
validateRead(s3FileIO);
}

@Test
public void testNewInputStreamWithCrossRegionAccessPoint() throws Exception {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a side note that in aws/aws-sdk-java-v2@51632cb, a new config multiRegionEnabled is added, which defaults to true and controls cross-region access when using MRAP. The default works for us, so we probably don't need to add it for now.

clientFactory.initialize(ImmutableMap.of(AwsProperties.S3_USE_ARN_REGION_ENABLED, "true"));
S3Client s3Client = clientFactory.s3();
s3Client.putObject(PutObjectRequest.builder().bucket(bucketName).key(objectKey).build(),
RequestBody.fromBytes(contentBytes));
// make a copy in cross-region bucket
s3Client.putObject(PutObjectRequest.builder()
.bucket(testAccessPointARN(AwsIntegTestUtil.testCrossRegion(), crossRegionAccessPointName))
.key(objectKey).build(),
RequestBody.fromBytes(contentBytes));
S3FileIO s3FileIO = new S3FileIO(clientFactory::s3);
s3FileIO.initialize(ImmutableMap.of(AwsProperties.S3_ACCESS_POINTS_PREFIX + bucketName,
testAccessPointARN(AwsIntegTestUtil.testCrossRegion(), crossRegionAccessPointName)));
validateRead(s3FileIO);
}

@Test
public void testNewOutputStream() throws Exception {
S3FileIO s3FileIO = new S3FileIO(clientFactory::s3);
Expand All @@ -116,6 +169,34 @@ public void testNewOutputStream() throws Exception {
Assert.assertEquals(content, result);
}

@Test
public void testNewOutputStreamWithAccessPoint() throws Exception {
S3FileIO s3FileIO = new S3FileIO(clientFactory::s3);
s3FileIO.initialize(ImmutableMap.of(AwsProperties.S3_ACCESS_POINTS_PREFIX + bucketName,
testAccessPointARN(AwsIntegTestUtil.testRegion(), accessPointName)));
write(s3FileIO);
InputStream stream = s3.getObject(GetObjectRequest.builder().bucket(bucketName).key(objectKey).build());
String result = IoUtils.toUtf8String(stream);
stream.close();
Assert.assertEquals(content, result);
}

@Test
public void testNewOutputStreamWithCrossRegionAccessPoint() throws Exception {
clientFactory.initialize(ImmutableMap.of(AwsProperties.S3_USE_ARN_REGION_ENABLED, "true"));
S3Client s3Client = clientFactory.s3();
S3FileIO s3FileIO = new S3FileIO(clientFactory::s3);
s3FileIO.initialize(ImmutableMap.of(AwsProperties.S3_ACCESS_POINTS_PREFIX + bucketName,
testAccessPointARN(AwsIntegTestUtil.testCrossRegion(), crossRegionAccessPointName)));
write(s3FileIO);
InputStream stream = s3Client.getObject(GetObjectRequest.builder()
.bucket(testAccessPointARN(AwsIntegTestUtil.testCrossRegion(), crossRegionAccessPointName))
.key(objectKey).build());
String result = IoUtils.toUtf8String(stream);
stream.close();
Assert.assertEquals(content, result);
}

@Test
public void testServerSideS3Encryption() throws Exception {
AwsProperties properties = new AwsProperties();
Expand Down Expand Up @@ -218,6 +299,23 @@ public void testDeleteFilesMultipleBatches() throws Exception {
testDeleteFiles(deletionBatchSize * 2, s3FileIO);
}

@Test
public void testDeleteFilesMultipleBatchesWithAccessPoints() throws Exception {
S3FileIO s3FileIO = new S3FileIO(clientFactory::s3, getDeletionTestProperties());
s3FileIO.initialize(ImmutableMap.of(AwsProperties.S3_ACCESS_POINTS_PREFIX + bucketName,
testAccessPointARN(AwsIntegTestUtil.testRegion(), accessPointName)));
testDeleteFiles(deletionBatchSize * 2, s3FileIO);
}

@Test
public void testDeleteFilesMultipleBatchesWithCrossRegionAccessPoints() throws Exception {
clientFactory.initialize(ImmutableMap.of(AwsProperties.S3_USE_ARN_REGION_ENABLED, "true"));
S3FileIO s3FileIO = new S3FileIO(clientFactory::s3, getDeletionTestProperties());
s3FileIO.initialize(ImmutableMap.of(AwsProperties.S3_ACCESS_POINTS_PREFIX + bucketName,
testAccessPointARN(AwsIntegTestUtil.testCrossRegion(), crossRegionAccessPointName)));
testDeleteFiles(deletionBatchSize * 2, s3FileIO);
}

@Test
public void testDeleteFilesLessThanBatchSize() throws Exception {
S3FileIO s3FileIO = new S3FileIO(clientFactory::s3, getDeletionTestProperties());
Expand Down Expand Up @@ -268,4 +366,10 @@ private void validateRead(S3FileIO s3FileIO) throws Exception {
stream.close();
Assert.assertEquals(content, result);
}

private String testAccessPointARN(String region, String accessPoint) {
// format: arn:aws:s3:region:account-id:accesspoint/resource
return String.format("arn:%s:s3:%s:%s:accesspoint/%s",
PartitionMetadata.of(Region.of(region)).id(), region, AwsIntegTestUtil.testAccountId(), accessPoint);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,15 @@ public class AssumeRoleAwsClientFactory implements AwsClientFactory {
private int timeout;
private String region;
private String s3Endpoint;
private boolean s3UseArnRegionEnabled;
private String httpClientType;

@Override
public S3Client s3() {
return S3Client.builder()
.applyMutation(this::configure)
.applyMutation(builder -> AwsClientFactories.configureEndpoint(builder, s3Endpoint))
.serviceConfiguration(s -> s.useArnRegionEnabled(s3UseArnRegionEnabled).build())
.build();
}

Expand Down Expand Up @@ -84,6 +86,8 @@ public void initialize(Map<String, String> properties) {

this.s3Endpoint = properties.get(AwsProperties.S3FILEIO_ENDPOINT);
this.tags = toTags(properties);
this.s3UseArnRegionEnabled = PropertyUtil.propertyAsBoolean(properties, AwsProperties.S3_ACCESS_POINTS_PREFIX,
AwsProperties.S3_USE_ARN_REGION_ENABLED_DEFAULT);
this.httpClientType = PropertyUtil.propertyAsString(properties,
AwsProperties.HTTP_CLIENT_TYPE, AwsProperties.HTTP_CLIENT_TYPE_DEFAULT);
}
Expand Down Expand Up @@ -125,6 +129,10 @@ protected String httpClientType() {
return httpClientType;
}

protected boolean s3UseArnRegionEnabled() {
return s3UseArnRegionEnabled;
}

private StsClient sts() {
return StsClient.builder()
.httpClientBuilder(AwsClientFactories.configureHttpClientBuilder(httpClientType))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ static class DefaultAwsClientFactory implements AwsClientFactory {
private String s3AccessKeyId;
private String s3SecretAccessKey;
private String s3SessionToken;
private Boolean s3UseArnRegionEnabled;
private String httpClientType;

DefaultAwsClientFactory() {
Expand All @@ -93,6 +94,7 @@ public S3Client s3() {
return S3Client.builder()
.httpClientBuilder(configureHttpClientBuilder(httpClientType))
.applyMutation(builder -> configureEndpoint(builder, s3Endpoint))
.serviceConfiguration(s -> s.useArnRegionEnabled(s3UseArnRegionEnabled).build())
.credentialsProvider(credentialsProvider(s3AccessKeyId, s3SecretAccessKey, s3SessionToken))
.build();
}
Expand All @@ -118,6 +120,8 @@ public void initialize(Map<String, String> properties) {
this.s3AccessKeyId = properties.get(AwsProperties.S3FILEIO_ACCESS_KEY_ID);
this.s3SecretAccessKey = properties.get(AwsProperties.S3FILEIO_SECRET_ACCESS_KEY);
this.s3SessionToken = properties.get(AwsProperties.S3FILEIO_SESSION_TOKEN);
this.s3UseArnRegionEnabled = PropertyUtil.propertyAsBoolean(properties, AwsProperties.S3_USE_ARN_REGION_ENABLED,
AwsProperties.S3_USE_ARN_REGION_ENABLED_DEFAULT);

ValidationException.check((s3AccessKeyId == null && s3SecretAccessKey == null) ||
(s3AccessKeyId != null && s3SecretAccessKey != null),
Expand Down
Loading