diff --git a/.github/workflows/gradle-integration-test.yml b/.github/workflows/gradle-integration-test.yml index a6b2e23c..c47ac9db 100644 --- a/.github/workflows/gradle-integration-test.yml +++ b/.github/workflows/gradle-integration-test.yml @@ -20,6 +20,7 @@ env: S3_TEST_BUCKET : ${{ vars.S3_TEST_BUCKET }} S3_TEST_PREFIX : ${{ vars.S3_TEST_PREFIX }} ROLE_TO_ASSUME: ${{ secrets.S3_TEST_ASSUME_ROLE_ARN }} + CUSTOMER_KEY: ${{ secrets.CUSTOMER_KEY }} jobs: build: diff --git a/README.md b/README.md index 8b05c2a2..22f0d5c6 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,21 @@ When the `S3SeekableInputStreamFactory` is no longer required to create new stre s3SeekableInputStreamFactory.close(); ``` +### Accessing SSE_C encrypted objects + +To access SSE_C encrypted objects using AAL, set the customer key which was used to encrypt the object in the ```OpenStreamInformation``` object and pass the openStreamInformation object in the stream. The customer key must be base64 encoded. + +``` + OpenStreamInformation openStreamInformation = + OpenStreamInformation.builder() + .encryptionSecrets( + EncryptionSecrets.builder().sseCustomerKey(Optional.of(base64EncodedCustomerKey)).build()) + .build(); + + S3SeekableInputStream s3SeekableInputStream = s3SeekableInputStreamFactory.createStream(S3URI.of(bucket, key), openStreamInformation); + +``` + ### Using with Hadoop If you are using Analytics Accelerator Library for Amazon S3 with Hadoop, you need to set the stream type to `analytics` in the Hadoop configuration. An example configuration is as follows: diff --git a/common/src/main/java/software/amazon/s3/analyticsaccelerator/request/EncryptionSecrets.java b/common/src/main/java/software/amazon/s3/analyticsaccelerator/request/EncryptionSecrets.java new file mode 100644 index 00000000..e784eb9f --- /dev/null +++ b/common/src/main/java/software/amazon/s3/analyticsaccelerator/request/EncryptionSecrets.java @@ -0,0 +1,79 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package software.amazon.s3.analyticsaccelerator.request; + +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.util.Base64; +import java.util.Optional; +import lombok.Builder; +import lombok.Getter; + +/** + * Contains encryption secrets for Server-Side Encryption with Customer-Provided Keys (SSE-C). This + * class manages the customer-provided encryption key used for SSE-C operations with Amazon S3. + */ +@Getter +public class EncryptionSecrets { + + /** + * The customer-provided encryption key for SSE-C operations. When present, this key will be used + * for server-side encryption. The key must be Base64 encoded and exactly 256 bits (32 bytes) when + * decoded. + */ + private final Optional ssecCustomerKey; + + /** + * The Base64-encoded MD5 hash of the customer key. This hash is automatically calculated from the + * customer key and is used by Amazon S3 to verify the integrity of the encryption key during + * transmission. Will be null if no customer key is provided. + */ + private final String ssecCustomerKeyMd5; + + /** + * Constructs an EncryptionSecrets instance with the specified SSE-C customer key. + * + *

This constructor processes the SSE-C (Server-Side Encryption with Customer-Provided Keys) + * encryption key and calculates its MD5 hash as required by Amazon S3. The process involves: + * + *

    + *
  1. Accepting a Base64-encoded encryption key + *
  2. Decoding the Base64 key back to bytes + *
  3. Computing the MD5 hash of these bytes + *
  4. Encoding the MD5 hash in Base64 format + *
+ * + * @param sseCustomerKey An Optional containing the Base64-encoded encryption key, or empty if no + * encryption is needed + */ + @Builder + public EncryptionSecrets(Optional sseCustomerKey) { + this.ssecCustomerKey = sseCustomerKey; + this.ssecCustomerKeyMd5 = + sseCustomerKey + .map( + key -> { + try { + MessageDigest md = MessageDigest.getInstance("MD5"); + return Base64.getEncoder() + .encodeToString(md.digest(Base64.getDecoder().decode(key))); + } catch (NoSuchAlgorithmException e) { + throw new IllegalStateException("MD5 algorithm not available", e); + } + }) + .orElse(null); + } +} diff --git a/common/src/main/java/software/amazon/s3/analyticsaccelerator/util/OpenStreamInformation.java b/common/src/main/java/software/amazon/s3/analyticsaccelerator/util/OpenStreamInformation.java index b2a42e07..8d2b7565 100644 --- a/common/src/main/java/software/amazon/s3/analyticsaccelerator/util/OpenStreamInformation.java +++ b/common/src/main/java/software/amazon/s3/analyticsaccelerator/util/OpenStreamInformation.java @@ -18,6 +18,7 @@ import lombok.AccessLevel; import lombok.Builder; import lombok.Getter; +import software.amazon.s3.analyticsaccelerator.request.EncryptionSecrets; import software.amazon.s3.analyticsaccelerator.request.ObjectMetadata; import software.amazon.s3.analyticsaccelerator.request.StreamAuditContext; @@ -41,6 +42,7 @@ public class OpenStreamInformation { private final StreamAuditContext streamAuditContext; private final ObjectMetadata objectMetadata; private final InputPolicy inputPolicy; + private final EncryptionSecrets encryptionSecrets; /** Default set of settings for {@link OpenStreamInformation} */ public static final OpenStreamInformation DEFAULT = OpenStreamInformation.builder().build(); diff --git a/common/src/test/java/software/amazon/s3/analyticsaccelerator/util/OpenStreamInformationTest.java b/common/src/test/java/software/amazon/s3/analyticsaccelerator/util/OpenStreamInformationTest.java index dbce0252..7871775c 100644 --- a/common/src/test/java/software/amazon/s3/analyticsaccelerator/util/OpenStreamInformationTest.java +++ b/common/src/test/java/software/amazon/s3/analyticsaccelerator/util/OpenStreamInformationTest.java @@ -17,13 +17,25 @@ import static org.junit.jupiter.api.Assertions.*; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.Optional; import org.junit.jupiter.api.Test; import org.mockito.Mockito; +import software.amazon.s3.analyticsaccelerator.request.EncryptionSecrets; import software.amazon.s3.analyticsaccelerator.request.ObjectMetadata; import software.amazon.s3.analyticsaccelerator.request.StreamAuditContext; public class OpenStreamInformationTest { + private static final String CUSTOMER_KEY = "32-bytes-long-key-for-testing-123"; + + /** + * To generate the base64 encoded md5 value for a customer key use the cli command echo -n + * "customer_key" | base64 | base64 -d | openssl md5 -binary | base64 + */ + private static final String EXPECTED_BASE64_MD5 = "R+k8pqEVUmkxDfaH5MqIdw=="; + @Test public void testDefaultInstance() { OpenStreamInformation info = OpenStreamInformation.DEFAULT; @@ -32,6 +44,7 @@ public void testDefaultInstance() { assertNull(info.getStreamAuditContext(), "Default streamContext should be null"); assertNull(info.getObjectMetadata(), "Default objectMetadata should be null"); assertNull(info.getInputPolicy(), "Default inputPolicy should be null"); + assertNull(info.getEncryptionSecrets(), "Default encryptionSecrets should be null"); } @Test @@ -39,17 +52,28 @@ public void testBuilderWithAllFields() { StreamAuditContext mockContext = Mockito.mock(StreamAuditContext.class); ObjectMetadata mockMetadata = Mockito.mock(ObjectMetadata.class); InputPolicy mockPolicy = Mockito.mock(InputPolicy.class); + String base64Key = + Base64.getEncoder().encodeToString(CUSTOMER_KEY.getBytes(StandardCharsets.UTF_8)); + EncryptionSecrets secrets = + EncryptionSecrets.builder().sseCustomerKey(Optional.of(base64Key)).build(); OpenStreamInformation info = OpenStreamInformation.builder() .streamAuditContext(mockContext) .objectMetadata(mockMetadata) .inputPolicy(mockPolicy) + .encryptionSecrets(secrets) .build(); assertSame(mockContext, info.getStreamAuditContext(), "StreamContext should match"); assertSame(mockMetadata, info.getObjectMetadata(), "ObjectMetadata should match"); assertSame(mockPolicy, info.getInputPolicy(), "InputPolicy should match"); + assertEquals( + base64Key, + info.getEncryptionSecrets().getSsecCustomerKey().get(), + "Customer key should match"); + assertNotNull(info.getEncryptionSecrets().getSsecCustomerKeyMd5(), "MD5 should not be null"); + assertEquals(EXPECTED_BASE64_MD5, info.getEncryptionSecrets().getSsecCustomerKeyMd5()); } @Test @@ -103,4 +127,47 @@ public void testNullFields() { assertNull(info.getObjectMetadata(), "ObjectMetadata should be null"); assertNull(info.getInputPolicy(), "InputPolicy should be null"); } + + @Test + public void testDefaultInstanceEncryptionSecrets() { + OpenStreamInformation info = OpenStreamInformation.DEFAULT; + assertNull(info.getEncryptionSecrets(), "Default encryptionSecrets should be null"); + } + + @Test + public void testBuilderWithEncryptionSecrets() { + // Create a sample base64 encoded key + String base64Key = + Base64.getEncoder().encodeToString(CUSTOMER_KEY.getBytes(StandardCharsets.UTF_8)); + EncryptionSecrets secrets = + EncryptionSecrets.builder().sseCustomerKey(Optional.of(base64Key)).build(); + + OpenStreamInformation info = OpenStreamInformation.builder().encryptionSecrets(secrets).build(); + + assertNotNull(info.getEncryptionSecrets(), "EncryptionSecrets should not be null"); + assertTrue( + info.getEncryptionSecrets().getSsecCustomerKey().isPresent(), + "Customer key should be present"); + assertEquals( + base64Key, + info.getEncryptionSecrets().getSsecCustomerKey().get(), + "Customer key should match"); + assertNotNull(info.getEncryptionSecrets().getSsecCustomerKeyMd5(), "MD5 should not be null"); + assertEquals(EXPECTED_BASE64_MD5, info.getEncryptionSecrets().getSsecCustomerKeyMd5()); + } + + @Test + public void testBuilderWithEmptyEncryptionSecrets() { + EncryptionSecrets secrets = + EncryptionSecrets.builder().sseCustomerKey(Optional.empty()).build(); + + OpenStreamInformation info = OpenStreamInformation.builder().encryptionSecrets(secrets).build(); + + assertNotNull(info.getEncryptionSecrets(), "EncryptionSecrets should not be null"); + assertFalse( + info.getEncryptionSecrets().getSsecCustomerKey().isPresent(), + "Customer key should be empty"); + assertNull( + info.getEncryptionSecrets().getSsecCustomerKeyMd5(), "MD5 should be null for empty key"); + } } diff --git a/input-stream/src/integrationTest/java/software/amazon/s3/analyticsaccelerator/access/IntegrationTestBase.java b/input-stream/src/integrationTest/java/software/amazon/s3/analyticsaccelerator/access/IntegrationTestBase.java index 406f7c0c..4d851b91 100644 --- a/input-stream/src/integrationTest/java/software/amazon/s3/analyticsaccelerator/access/IntegrationTestBase.java +++ b/input-stream/src/integrationTest/java/software/amazon/s3/analyticsaccelerator/access/IntegrationTestBase.java @@ -46,6 +46,7 @@ import software.amazon.awssdk.services.s3.model.S3Exception; import software.amazon.s3.analyticsaccelerator.S3SeekableInputStream; import software.amazon.s3.analyticsaccelerator.common.ObjectRange; +import software.amazon.s3.analyticsaccelerator.util.OpenStreamInformation; import software.amazon.s3.analyticsaccelerator.util.S3URI; /** Base class for the integration tests */ @@ -99,7 +100,11 @@ protected void testAndCompareStreamReadPattern( // Read using the standard S3 async client Crc32CChecksum directChecksum = new Crc32CChecksum(); executeReadPatternDirectly( - s3ClientKind, s3Object, streamReadPattern, Optional.of(directChecksum)); + s3ClientKind, + s3Object, + streamReadPattern, + Optional.of(directChecksum), + OpenStreamInformation.DEFAULT); // Read using the AAL S3 Crc32CChecksum aalChecksum = new Crc32CChecksum(); @@ -108,7 +113,8 @@ protected void testAndCompareStreamReadPattern( s3Object, streamReadPattern, AALInputStreamConfigurationKind, - Optional.of(aalChecksum)); + Optional.of(aalChecksum), + OpenStreamInformation.DEFAULT); // Assert checksums assertChecksums(directChecksum, aalChecksum); @@ -140,7 +146,8 @@ protected void testChangingEtagMidStream( S3URI s3URI = s3Object.getObjectUri(this.getS3ExecutionContext().getConfiguration().getBaseUri()); S3AsyncClient s3Client = this.getS3ExecutionContext().getS3Client(); - S3SeekableInputStream stream = s3AALClientStreamReader.createReadStream(s3Object); + S3SeekableInputStream stream = + s3AALClientStreamReader.createReadStream(s3Object, OpenStreamInformation.DEFAULT); // Read first 100 bytes readAndAssert(stream, buffer, 0, 100); @@ -171,7 +178,11 @@ protected void testChangingEtagMidStream( assertDoesNotThrow( () -> executeReadPatternOnAAL( - s3Object, s3AALClientStreamReader, streamReadPattern, Optional.of(datChecksum))); + s3Object, + s3AALClientStreamReader, + streamReadPattern, + Optional.of(datChecksum), + OpenStreamInformation.DEFAULT)); assert (datChecksum.getChecksumBytes().length > 0); } } @@ -199,7 +210,7 @@ protected void testReadVectored( this.createS3AALClientStreamReader(s3ClientKind, AALInputStreamConfigurationKind)) { S3SeekableInputStream s3SeekableInputStream = - s3AALClientStreamReader.createReadStream(s3Object); + s3AALClientStreamReader.createReadStream(s3Object, OpenStreamInformation.DEFAULT); List objectRanges = new ArrayList<>(); objectRanges.add(new ObjectRange(new CompletableFuture<>(), 50, 500)); @@ -217,7 +228,7 @@ protected void testReadVectored( ByteBuffer byteBuffer = objectRange.getByteBuffer().join(); S3SeekableInputStream verificationStream = - s3AALClientStreamReader.createReadStream(s3Object); + s3AALClientStreamReader.createReadStream(s3Object, OpenStreamInformation.DEFAULT); verificationStream.seek(objectRange.getOffset()); byte[] buffer = new byte[objectRange.getLength()]; int readBytes = verificationStream.read(buffer, 0, buffer.length); @@ -273,7 +284,8 @@ protected void testChangingEtagAfterStreamPassesAndReturnsCachedObject( // Create the s3DATClientStreamReader - that creates the shared state try (S3AALClientStreamReader s3AALClientStreamReader = this.createS3AALClientStreamReader(s3ClientKind, AALInputStreamConfigurationKind)) { - S3SeekableInputStream stream = s3AALClientStreamReader.createReadStream(s3Object); + S3SeekableInputStream stream = + s3AALClientStreamReader.createReadStream(s3Object, OpenStreamInformation.DEFAULT); Crc32CChecksum datChecksum = calculateCRC32C(stream, bufferSize); S3URI s3URI = @@ -287,7 +299,8 @@ protected void testChangingEtagAfterStreamPassesAndReturnsCachedObject( AsyncRequestBody.fromBytes(generateRandomBytes(bufferSize))) .join(); - S3SeekableInputStream cacheStream = s3AALClientStreamReader.createReadStream(s3Object); + S3SeekableInputStream cacheStream = + s3AALClientStreamReader.createReadStream(s3Object, OpenStreamInformation.DEFAULT); Crc32CChecksum cachedChecksum = calculateCRC32C(cacheStream, bufferSize); // Assert checksums @@ -351,7 +364,11 @@ protected void testAALReadConcurrency( // Read using the standard S3 async client. We do this once, to calculate the checksums Crc32CChecksum directChecksum = new Crc32CChecksum(); executeReadPatternDirectly( - s3ClientKind, s3Object, streamReadPattern, Optional.of(directChecksum)); + s3ClientKind, + s3Object, + streamReadPattern, + Optional.of(directChecksum), + OpenStreamInformation.DEFAULT); // Create the s3DATClientStreamReader - that creates the shared state try (S3AALClientStreamReader s3AALClientStreamReader = @@ -374,7 +391,8 @@ protected void testAALReadConcurrency( s3Object, s3AALClientStreamReader, streamReadPattern, - Optional.of(datChecksum)); + Optional.of(datChecksum), + OpenStreamInformation.DEFAULT); // Assert checksums assertChecksums(directChecksum, datChecksum); @@ -418,7 +436,8 @@ protected void testSmallObjectPrefetching( this.createS3AALClientStreamReader(s3ClientKind, AALInputStreamConfigurationKind)) { // First stream - S3SeekableInputStream stream = s3AALClientStreamReader.createReadStream(s3Object); + S3SeekableInputStream stream = + s3AALClientStreamReader.createReadStream(s3Object, OpenStreamInformation.DEFAULT); Crc32CChecksum firstChecksum = calculateCRC32C(stream, (int) s3Object.getSize()); S3URI s3URI = @@ -433,7 +452,8 @@ protected void testSmallObjectPrefetching( .join(); // Create second stream - S3SeekableInputStream secondStream = s3AALClientStreamReader.createReadStream(s3Object); + S3SeekableInputStream secondStream = + s3AALClientStreamReader.createReadStream(s3Object, OpenStreamInformation.DEFAULT); Crc32CChecksum secondChecksum = calculateCRC32C(secondStream, (int) s3Object.getSize()); if (s3Object.getSize() < 8 * ONE_MB) { diff --git a/input-stream/src/integrationTest/java/software/amazon/s3/analyticsaccelerator/access/SSECEncryptionTest.java b/input-stream/src/integrationTest/java/software/amazon/s3/analyticsaccelerator/access/SSECEncryptionTest.java new file mode 100644 index 00000000..364dbb5d --- /dev/null +++ b/input-stream/src/integrationTest/java/software/amazon/s3/analyticsaccelerator/access/SSECEncryptionTest.java @@ -0,0 +1,224 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package software.amazon.s3.analyticsaccelerator.access; + +import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assumptions.assumeTrue; +import static software.amazon.s3.analyticsaccelerator.access.ChecksumAssertions.assertChecksums; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.stream.Stream; +import lombok.NonNull; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import software.amazon.awssdk.core.checksums.Crc32CChecksum; +import software.amazon.awssdk.services.s3.model.S3Exception; +import software.amazon.s3.analyticsaccelerator.request.EncryptionSecrets; +import software.amazon.s3.analyticsaccelerator.util.OpenStreamInformation; + +public class SSECEncryptionTest extends IntegrationTestBase { + private static final Logger LOG = LoggerFactory.getLogger(SSECEncryptionTest.class); + + private static final String CUSTOMER_KEY = System.getenv("CUSTOMER_KEY"); + + private void checkPrerequisites() { + String skipMessage = "Skipping tests: CUSTOMER_KEY environment variable is not set"; + if (CUSTOMER_KEY == null || CUSTOMER_KEY.trim().isEmpty()) { + LOG.info(skipMessage); + } + assumeTrue(CUSTOMER_KEY != null && !CUSTOMER_KEY.trim().isEmpty(), skipMessage); + } + + @ParameterizedTest + @MethodSource("encryptedSequentialReads") + void testEncryptedSequentialReads( + S3ClientKind s3ClientKind, + S3Object s3Object, + StreamReadPatternKind streamReadPattern, + AALInputStreamConfigurationKind configuration) + throws IOException { + checkPrerequisites(); + testReadPatternUsingSSECEncryption( + s3ClientKind, s3Object, streamReadPattern, configuration, CUSTOMER_KEY); + } + + @ParameterizedTest + @MethodSource("encryptedParquetReads") + void testEncryptedParquetReads( + S3ClientKind s3ClientKind, + S3Object s3Object, + StreamReadPatternKind streamReadPattern, + AALInputStreamConfigurationKind configuration) + throws IOException { + checkPrerequisites(); + testReadPatternUsingSSECEncryption( + s3ClientKind, s3Object, streamReadPattern, configuration, CUSTOMER_KEY); + } + + @ParameterizedTest + @MethodSource("encryptedReadsWithWrongKey") + void testEncryptedReadsWithWrongKey( + S3ClientKind s3ClientKind, + S3Object s3Object, + StreamReadPatternKind streamReadPattern, + AALInputStreamConfigurationKind configuration) { + + IOException exception = + assertThrows( + IOException.class, + () -> { + testReadPatternUsingWrongKeyOrEmptyKey( + s3ClientKind, s3Object, streamReadPattern, configuration, "wrongkey"); + }); + + Throwable cause = exception.getCause(); + assertTrue(cause instanceof S3Exception); + S3Exception s3Exception = (S3Exception) cause; + assertEquals(403, s3Exception.statusCode()); + } + + @ParameterizedTest + @MethodSource("encryptedReadsWithWrongKey") + void testEncryptedReadsWithEmptyKey( + S3ClientKind s3ClientKind, + S3Object s3Object, + StreamReadPatternKind streamReadPattern, + AALInputStreamConfigurationKind configuration) { + + IOException exception = + assertThrows( + IOException.class, + () -> { + testReadPatternUsingWrongKeyOrEmptyKey( + s3ClientKind, s3Object, streamReadPattern, configuration, null); + }); + + Throwable cause = exception.getCause(); + assertTrue(cause instanceof S3Exception); + S3Exception s3Exception = (S3Exception) cause; + assertEquals(400, s3Exception.statusCode()); + } + + protected void testReadPatternUsingSSECEncryption( + @NonNull S3ClientKind s3ClientKind, + @NonNull S3Object s3Object, + @NonNull StreamReadPatternKind streamReadPatternKind, + @NonNull AALInputStreamConfigurationKind AALInputStreamConfigurationKind, + String customerKey) + throws IOException { + StreamReadPattern streamReadPattern = streamReadPatternKind.getStreamReadPattern(s3Object); + OpenStreamInformation openStreamInformation = + OpenStreamInformation.builder() + .encryptionSecrets( + EncryptionSecrets.builder().sseCustomerKey(Optional.of(customerKey)).build()) + .build(); + + // Read using the standard S3 async client + Crc32CChecksum directChecksum = new Crc32CChecksum(); + executeReadPatternDirectly( + s3ClientKind, + s3Object, + streamReadPattern, + Optional.of(directChecksum), + openStreamInformation); + + // Read using the AAL S3 + Crc32CChecksum aalChecksum = new Crc32CChecksum(); + executeReadPatternOnAAL( + s3ClientKind, + s3Object, + streamReadPattern, + AALInputStreamConfigurationKind, + Optional.of(aalChecksum), + openStreamInformation); + + // Assert checksums + assertChecksums(directChecksum, aalChecksum); + } + + protected void testReadPatternUsingWrongKeyOrEmptyKey( + @NonNull S3ClientKind s3ClientKind, + @NonNull S3Object s3Object, + @NonNull StreamReadPatternKind streamReadPatternKind, + @NonNull AALInputStreamConfigurationKind AALInputStreamConfigurationKind, + String customerKey) + throws IOException { + StreamReadPattern streamReadPattern = streamReadPatternKind.getStreamReadPattern(s3Object); + + OpenStreamInformation openStreamInformation = + customerKey == null + ? OpenStreamInformation.DEFAULT + : OpenStreamInformation.builder() + .encryptionSecrets( + EncryptionSecrets.builder().sseCustomerKey(Optional.of(customerKey)).build()) + .build(); + + // Read using the AAL S3 + Crc32CChecksum aalChecksum = new Crc32CChecksum(); + executeReadPatternOnAAL( + s3ClientKind, + s3Object, + streamReadPattern, + AALInputStreamConfigurationKind, + Optional.of(aalChecksum), + openStreamInformation); + } + + static Stream encryptedSequentialReads() { + List readEncryptedObjects = new ArrayList<>(); + readEncryptedObjects.add(S3Object.RANDOM_SSEC_ENCRYPTED_SEQUENTIAL_1MB); + + return argumentsFor( + getS3ClientKinds(), + readEncryptedObjects, + sequentialPatterns(), + readCorrectnessConfigurationKind()); + } + + static Stream encryptedParquetReads() { + List readEncryptedObjects = new ArrayList<>(); + readEncryptedObjects.add(S3Object.RANDOM_SSEC_ENCRYPTED_PARQUET_1MB); + readEncryptedObjects.add(S3Object.RANDOM_SSEC_ENCRYPTED_PARQUET_64MB); + + return argumentsFor( + getS3ClientKinds(), + readEncryptedObjects, + parquetPatterns(), + readCorrectnessConfigurationKind()); + } + + static Stream encryptedReadsWithWrongKey() { + List readEncryptedObjects = new ArrayList<>(); + readEncryptedObjects.add(S3Object.RANDOM_SSEC_ENCRYPTED_PARQUET_64MB); + + return argumentsFor( + getS3ClientKinds(), + readEncryptedObjects, + sequentialPatterns(), + readCorrectnessConfigurationKind()); + } + + private static List readCorrectnessConfigurationKind() { + return Arrays.asList(AALInputStreamConfigurationKind.READ_CORRECTNESS); + } +} diff --git a/input-stream/src/jmh/java/software/amazon/s3/analyticsaccelerator/benchmarks/BenchmarkBase.java b/input-stream/src/jmh/java/software/amazon/s3/analyticsaccelerator/benchmarks/BenchmarkBase.java index d30a8650..2dd256bf 100644 --- a/input-stream/src/jmh/java/software/amazon/s3/analyticsaccelerator/benchmarks/BenchmarkBase.java +++ b/input-stream/src/jmh/java/software/amazon/s3/analyticsaccelerator/benchmarks/BenchmarkBase.java @@ -21,6 +21,7 @@ import lombok.NonNull; import org.openjdk.jmh.annotations.*; import software.amazon.s3.analyticsaccelerator.access.*; +import software.amazon.s3.analyticsaccelerator.util.OpenStreamInformation; /** * Base class for benchmarks that iterate through the client types and stream types All derived @@ -128,7 +129,8 @@ protected void executeReadPatternOnDAT() throws IOException { this.getReadPatternKind().getStreamReadPattern(s3Object), // Use default configuration this.getDATInputStreamConfigurationKind(), - Optional.empty()); + Optional.empty(), + OpenStreamInformation.DEFAULT); } /** @@ -142,6 +144,7 @@ protected void executeReadPatternDirectly() throws IOException { this.getClientKind(), s3Object, this.getReadPatternKind().getStreamReadPattern(s3Object), - Optional.empty()); + Optional.empty(), + OpenStreamInformation.DEFAULT); } } diff --git a/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/ExecutionBase.java b/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/ExecutionBase.java index b81a4171..43e2b2b8 100644 --- a/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/ExecutionBase.java +++ b/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/ExecutionBase.java @@ -20,6 +20,7 @@ import lombok.NonNull; import software.amazon.awssdk.core.checksums.Crc32CChecksum; import software.amazon.s3.analyticsaccelerator.S3SeekableInputStreamConfiguration; +import software.amazon.s3.analyticsaccelerator.util.OpenStreamInformation; /** * This class is a base for performance (JMH) and integration tests and contains common @@ -89,13 +90,15 @@ protected S3AALClientStreamReader createS3AALClientStreamReader( * @param s3Object {@link } S3 Object to run the pattern on * @param streamReadPattern the read pattern * @param checksum checksum to update, if specified + * @param openStreamInformation contains the open stream information * @throws IOException IO error, if thrown */ protected void executeReadPatternDirectly( S3ClientKind s3ClientKind, S3Object s3Object, StreamReadPattern streamReadPattern, - Optional checksum) + Optional checksum, + OpenStreamInformation openStreamInformation) throws IOException { // Direct Read Pattern execution shouldn't read using the faulty client but it should use a // trusted client. @@ -105,7 +108,8 @@ protected void executeReadPatternDirectly( : s3ClientKind; try (S3AsyncClientStreamReader s3AsyncClientStreamReader = this.createS3AsyncClientStreamReader(s3ClientKind)) { - s3AsyncClientStreamReader.readPattern(s3Object, streamReadPattern, checksum); + s3AsyncClientStreamReader.readPattern( + s3Object, streamReadPattern, checksum, openStreamInformation); } } @@ -117,6 +121,7 @@ protected void executeReadPatternDirectly( * @param AALInputStreamConfigurationKind DAT configuration * @param streamReadPattern the read pattern * @param checksum checksum to update, if specified + * @param openStreamInformation contains the open stream information * @throws IOException IO error, if thrown */ protected void executeReadPatternOnAAL( @@ -124,11 +129,13 @@ protected void executeReadPatternOnAAL( S3Object s3Object, StreamReadPattern streamReadPattern, AALInputStreamConfigurationKind AALInputStreamConfigurationKind, - Optional checksum) + Optional checksum, + OpenStreamInformation openStreamInformation) throws IOException { try (S3AALClientStreamReader s3AALClientStreamReader = this.createS3AALClientStreamReader(s3ClientKind, AALInputStreamConfigurationKind)) { - executeReadPatternOnAAL(s3Object, s3AALClientStreamReader, streamReadPattern, checksum); + executeReadPatternOnAAL( + s3Object, s3AALClientStreamReader, streamReadPattern, checksum, openStreamInformation); } } @@ -139,14 +146,17 @@ protected void executeReadPatternOnAAL( * @param s3AALClientStreamReader DAT stream reader * @param streamReadPattern the read pattern * @param checksum checksum to update, if specified + * @param openStreamInformation contains the open stream information * @throws IOException IO error, if thrown */ protected void executeReadPatternOnAAL( S3Object s3Object, S3AALClientStreamReader s3AALClientStreamReader, StreamReadPattern streamReadPattern, - Optional checksum) + Optional checksum, + OpenStreamInformation openStreamInformation) throws IOException { - s3AALClientStreamReader.readPattern(s3Object, streamReadPattern, checksum); + s3AALClientStreamReader.readPattern( + s3Object, streamReadPattern, checksum, openStreamInformation); } } diff --git a/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/S3AALClientStreamReader.java b/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/S3AALClientStreamReader.java index f6243af1..1ccb5160 100644 --- a/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/S3AALClientStreamReader.java +++ b/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/S3AALClientStreamReader.java @@ -25,6 +25,7 @@ import software.amazon.s3.analyticsaccelerator.S3SeekableInputStream; import software.amazon.s3.analyticsaccelerator.S3SeekableInputStreamConfiguration; import software.amazon.s3.analyticsaccelerator.S3SeekableInputStreamFactory; +import software.amazon.s3.analyticsaccelerator.util.OpenStreamInformation; import software.amazon.s3.analyticsaccelerator.util.S3URI; /** Client stream reader based on DAT */ @@ -55,11 +56,14 @@ public S3AALClientStreamReader( * Creates the read stream for a given object * * @param s3Object {@link S3Object} to create the stream for + * @param openStreamInformation contains the open stream information * @return read stream */ - public S3SeekableInputStream createReadStream(@NonNull S3Object s3Object) throws IOException { + public S3SeekableInputStream createReadStream( + @NonNull S3Object s3Object, @NonNull OpenStreamInformation openStreamInformation) + throws IOException { S3URI s3URI = s3Object.getObjectUri(this.getBaseUri()); - return this.getS3SeekableInputStreamFactory().createStream(s3URI); + return this.getS3SeekableInputStreamFactory().createStream(s3URI, openStreamInformation); } /** @@ -68,14 +72,17 @@ public S3SeekableInputStream createReadStream(@NonNull S3Object s3Object) throws * @param s3Object S3 Object to read * @param streamReadPattern Stream read pattern * @param checksum optional checksum, to update + * @param openStreamInformation contains the open stream information */ @Override public void readPattern( @NonNull S3Object s3Object, @NonNull StreamReadPattern streamReadPattern, - @NonNull Optional checksum) + @NonNull Optional checksum, + @NonNull OpenStreamInformation openStreamInformation) throws IOException { - try (S3SeekableInputStream inputStream = this.createReadStream(s3Object)) { + try (S3SeekableInputStream inputStream = + this.createReadStream(s3Object, openStreamInformation)) { readPattern(s3Object, inputStream, streamReadPattern, checksum); } } diff --git a/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/S3AsyncClientStreamReader.java b/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/S3AsyncClientStreamReader.java index a7bd0e6a..c8c2c295 100644 --- a/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/S3AsyncClientStreamReader.java +++ b/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/S3AsyncClientStreamReader.java @@ -24,6 +24,8 @@ import software.amazon.awssdk.core.checksums.Crc32CChecksum; import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.ServerSideEncryption; +import software.amazon.s3.analyticsaccelerator.util.OpenStreamInformation; import software.amazon.s3.analyticsaccelerator.util.S3URI; /** A naive stream reader based on the {@link S3AsyncClient} */ @@ -54,32 +56,49 @@ public S3AsyncClientStreamReader( public void readPattern( @NonNull S3Object s3Object, @NonNull StreamReadPattern streamReadPattern, - @NonNull Optional checksum) + @NonNull Optional checksum, + @NonNull OpenStreamInformation openStreamInformation) throws IOException { S3URI s3URI = s3Object.getObjectUri(this.getBaseUri()); // Replay the pattern through series of GETs for (StreamRead streamRead : streamReadPattern.getStreamReads()) { + // Build base request with common parameters + GetObjectRequest.Builder requestBuilder = + GetObjectRequest.builder() + .bucket(s3URI.getBucket()) + .key(s3URI.getKey()) + .range( + String.format( + "bytes=%s-%s", + streamRead.getStart(), streamRead.getStart() + streamRead.getLength() - 1)); + + // Add encryption parameters if present + addEncryptionSecrets(requestBuilder, openStreamInformation); + // Issue a ranged GET and get InputStream InputStream inputStream = s3AsyncClient - .getObject( - GetObjectRequest.builder() - .bucket(s3URI.getBucket()) - .key(s3URI.getKey()) - .range( - String.format( - "bytes=%s-%s", - streamRead.getStart(), - streamRead.getStart() + streamRead.getLength() - 1)) - .build(), - AsyncResponseTransformer.toBlockingInputStream()) + .getObject(requestBuilder.build(), AsyncResponseTransformer.toBlockingInputStream()) .join(); // drain bytes drainStream(inputStream, s3Object, checksum, streamRead.getLength()); } } + private void addEncryptionSecrets( + GetObjectRequest.Builder requestBuilder, OpenStreamInformation openStreamInformation) { + if (openStreamInformation.getEncryptionSecrets() != null + && openStreamInformation.getEncryptionSecrets().getSsecCustomerKey().isPresent()) { + String customerKey = openStreamInformation.getEncryptionSecrets().getSsecCustomerKey().get(); + String customerKeyMd5 = openStreamInformation.getEncryptionSecrets().getSsecCustomerKeyMd5(); + requestBuilder + .sseCustomerAlgorithm(ServerSideEncryption.AES256.name()) + .sseCustomerKey(customerKey) + .sseCustomerKeyMD5(customerKeyMd5); + } + } + /** * Closes the reader * diff --git a/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/S3Object.java b/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/S3Object.java index b7a97682..90d7133f 100644 --- a/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/S3Object.java +++ b/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/S3Object.java @@ -52,7 +52,19 @@ public enum S3Object { CSV_20MB( "sequential-20mb.csv", 20 * SizeConstants.ONE_MB_IN_BYTES, S3ObjectKind.RANDOM_SEQUENTIAL), TXT_16MB( - "sequential-16mb.txt", 16 * SizeConstants.ONE_MB_IN_BYTES, S3ObjectKind.RANDOM_SEQUENTIAL); + "sequential-16mb.txt", 16 * SizeConstants.ONE_MB_IN_BYTES, S3ObjectKind.RANDOM_SEQUENTIAL), + RANDOM_SSEC_ENCRYPTED_SEQUENTIAL_1MB( + "random-encrypted-1mb.bin", + SizeConstants.ONE_MB_IN_BYTES, + S3ObjectKind.RANDOM_SEQUENTIAL_ENCRYPTED), + RANDOM_SSEC_ENCRYPTED_PARQUET_1MB( + "random-encrypted-1mb.parquet", + SizeConstants.ONE_MB_IN_BYTES, + S3ObjectKind.RANDOM_PARQUET_ENCRYPTED), + RANDOM_SSEC_ENCRYPTED_PARQUET_64MB( + "random-encrypted-64mb.parquet", + 64 * SizeConstants.ONE_MB_IN_BYTES, + S3ObjectKind.RANDOM_PARQUET_ENCRYPTED); private final String name; private final long size; @@ -61,6 +73,9 @@ public enum S3Object { private static final long SMALL_BINARY_OBJECTS_LOWER_LIMIT = 8 * SizeConstants.ONE_MB_IN_BYTES; private static final long MEDIUM_SIZE_THRESHOLD = 50 * SizeConstants.ONE_MB_IN_BYTES; private static final long LARGE_SIZE_THRESHOLD = 500 * SizeConstants.ONE_MB_IN_BYTES; + private static final List ENCRYPTED_OBJECT_KINDS = + Arrays.asList( + S3ObjectKind.RANDOM_SEQUENTIAL_ENCRYPTED, S3ObjectKind.RANDOM_PARQUET_ENCRYPTED); /** * Get S3 Object Uri based on the content @@ -96,7 +111,8 @@ public static List filter(@NonNull Predicate predicate) { * @return small objects */ public static List smallObjects() { - return filter(o -> o.size < MEDIUM_SIZE_THRESHOLD); + return filter( + o -> o.size < MEDIUM_SIZE_THRESHOLD && !ENCRYPTED_OBJECT_KINDS.contains(o.getKind())); } /** * Returns list of small binary objects (between 8 MB and 50MB, .bin files only). @@ -108,7 +124,8 @@ public static List smallBinaryObjects() { o -> o.size >= SMALL_BINARY_OBJECTS_LOWER_LIMIT && o.size < MEDIUM_SIZE_THRESHOLD - && o.getName().endsWith(".bin")); + && o.getName().endsWith(".bin") + && !ENCRYPTED_OBJECT_KINDS.contains(o.getKind())); } /** @@ -117,7 +134,11 @@ public static List smallBinaryObjects() { * @return medium objects */ public static List mediumObjects() { - return filter(o -> o.size >= MEDIUM_SIZE_THRESHOLD && o.size < LARGE_SIZE_THRESHOLD); + return filter( + o -> + o.size >= MEDIUM_SIZE_THRESHOLD + && o.size < LARGE_SIZE_THRESHOLD + && !ENCRYPTED_OBJECT_KINDS.contains(o.getKind())); } /** @@ -126,7 +147,8 @@ public static List mediumObjects() { * @return small and medium objects */ public static List smallAndMediumObjects() { - return filter(o -> o.size < LARGE_SIZE_THRESHOLD); + return filter( + o -> o.size < LARGE_SIZE_THRESHOLD && !ENCRYPTED_OBJECT_KINDS.contains(o.getKind())); } /** @@ -135,7 +157,8 @@ public static List smallAndMediumObjects() { * @return medium and large objects */ public static List mediumAndLargeObjects() { - return filter(o -> o.size >= MEDIUM_SIZE_THRESHOLD); + return filter( + o -> o.size >= MEDIUM_SIZE_THRESHOLD && !ENCRYPTED_OBJECT_KINDS.contains(o.getKind())); } /** @@ -144,7 +167,8 @@ public static List mediumAndLargeObjects() { * @return large objects */ public static List largeObjects() { - return filter(o -> o.size >= LARGE_SIZE_THRESHOLD); + return filter( + o -> o.size >= LARGE_SIZE_THRESHOLD && !ENCRYPTED_OBJECT_KINDS.contains(o.getKind())); } /** diff --git a/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/S3ObjectKind.java b/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/S3ObjectKind.java index 2ab81a0b..7281e342 100644 --- a/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/S3ObjectKind.java +++ b/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/S3ObjectKind.java @@ -23,7 +23,9 @@ @Getter public enum S3ObjectKind { RANDOM_SEQUENTIAL("sequential"), - RANDOM_PARQUET("parquet"); + RANDOM_PARQUET("parquet"), + RANDOM_SEQUENTIAL_ENCRYPTED("sequential_encrypted"), + RANDOM_PARQUET_ENCRYPTED("parquet_encrypted"); private final String value; } diff --git a/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/S3StreamReaderBase.java b/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/S3StreamReaderBase.java index cc8c0e52..69285a47 100644 --- a/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/S3StreamReaderBase.java +++ b/input-stream/src/testFixtures/java/software/amazon/s3/analyticsaccelerator/access/S3StreamReaderBase.java @@ -22,6 +22,7 @@ import lombok.Getter; import lombok.NonNull; import software.amazon.awssdk.core.checksums.Crc32CChecksum; +import software.amazon.s3.analyticsaccelerator.util.OpenStreamInformation; import software.amazon.s3.analyticsaccelerator.util.S3URI; /** Base class for all readers from S3 */ @@ -46,11 +47,13 @@ protected S3StreamReaderBase(@NonNull S3URI baseUri, int bufferSize) { * @param s3Object S3 Object to read * @param streamReadPattern Stream read pattern * @param checksum optional checksum, to update + * @param openStreamInformation contains the open stream information */ public abstract void readPattern( @NonNull S3Object s3Object, @NonNull StreamReadPattern streamReadPattern, - @NonNull Optional checksum) + @NonNull Optional checksum, + @NonNull OpenStreamInformation openStreamInformation) throws IOException; /** diff --git a/object-client/src/main/java/software/amazon/s3/analyticsaccelerator/S3SdkObjectClient.java b/object-client/src/main/java/software/amazon/s3/analyticsaccelerator/S3SdkObjectClient.java index c01c549f..ac37ac4f 100644 --- a/object-client/src/main/java/software/amazon/s3/analyticsaccelerator/S3SdkObjectClient.java +++ b/object-client/src/main/java/software/amazon/s3/analyticsaccelerator/S3SdkObjectClient.java @@ -33,6 +33,7 @@ import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.HeadObjectRequest; +import software.amazon.awssdk.services.s3.model.ServerSideEncryption; import software.amazon.s3.analyticsaccelerator.common.telemetry.ConfigurableTelemetry; import software.amazon.s3.analyticsaccelerator.common.telemetry.Operation; import software.amazon.s3.analyticsaccelerator.common.telemetry.Telemetry; @@ -132,6 +133,16 @@ public CompletableFuture headObject( builder.overrideConfiguration(requestOverrideConfigurationBuilder.build()); + if (openStreamInformation.getEncryptionSecrets() != null + && openStreamInformation.getEncryptionSecrets().getSsecCustomerKey().isPresent()) { + String customerKey = openStreamInformation.getEncryptionSecrets().getSsecCustomerKey().get(); + String customerKeyMd5 = openStreamInformation.getEncryptionSecrets().getSsecCustomerKeyMd5(); + builder + .sseCustomerAlgorithm(ServerSideEncryption.AES256.name()) + .sseCustomerKey(customerKey) + .sseCustomerKeyMD5(customerKeyMd5); + } + return this.telemetry .measureCritical( () -> @@ -175,6 +186,16 @@ public CompletableFuture getObject( builder.overrideConfiguration(requestOverrideConfigurationBuilder.build()); + if (openStreamInformation.getEncryptionSecrets() != null + && openStreamInformation.getEncryptionSecrets().getSsecCustomerKey().isPresent()) { + String customerKey = openStreamInformation.getEncryptionSecrets().getSsecCustomerKey().get(); + String customerKeyMd5 = openStreamInformation.getEncryptionSecrets().getSsecCustomerKeyMd5(); + builder + .sseCustomerAlgorithm(ServerSideEncryption.AES256.name()) + .sseCustomerKey(customerKey) + .sseCustomerKeyMD5(customerKeyMd5); + } + return this.telemetry.measureCritical( () -> Operation.builder() diff --git a/object-client/src/test/java/software/amazon/s3/analyticsaccelerator/S3SdkObjectClientTest.java b/object-client/src/test/java/software/amazon/s3/analyticsaccelerator/S3SdkObjectClientTest.java index 63d128fc..fdd2507e 100644 --- a/object-client/src/test/java/software/amazon/s3/analyticsaccelerator/S3SdkObjectClientTest.java +++ b/object-client/src/test/java/software/amazon/s3/analyticsaccelerator/S3SdkObjectClientTest.java @@ -28,6 +28,8 @@ import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import java.io.IOException; import java.io.UncheckedIOException; +import java.nio.charset.StandardCharsets; +import java.util.Base64; import java.util.Optional; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; @@ -49,6 +51,7 @@ import software.amazon.awssdk.services.s3.model.HeadObjectRequest; import software.amazon.awssdk.services.s3.model.HeadObjectResponse; import software.amazon.awssdk.services.s3.model.S3Exception; +import software.amazon.awssdk.services.s3.model.ServerSideEncryption; import software.amazon.s3.analyticsaccelerator.exceptions.ExceptionHandler; import software.amazon.s3.analyticsaccelerator.request.*; import software.amazon.s3.analyticsaccelerator.request.GetRequest; @@ -516,4 +519,132 @@ private static void assertObjectClientExceptions( private static Exception[] exceptions() { return ExceptionHandler.getSampleExceptions(); } + + @Test + void testGetObjectWithEncryption() { + S3AsyncClient mockS3AsyncClient = createMockClient(); + S3SdkObjectClient client = new S3SdkObjectClient(mockS3AsyncClient); + + // Create encryption secrets + String base64Key = + Base64.getEncoder() + .encodeToString("32-bytes-long-key-for-testing-123".getBytes(StandardCharsets.UTF_8)); + EncryptionSecrets secrets = + EncryptionSecrets.builder().sseCustomerKey(Optional.of(base64Key)).build(); + + // Create OpenStreamInformation with encryption + OpenStreamInformation openStreamInformation = + OpenStreamInformation.builder().encryptionSecrets(secrets).build(); + + GetRequest getRequest = + GetRequest.builder() + .s3Uri(S3URI.of("bucket", "key")) + .range(new Range(0, 20)) + .etag(ETAG) + .referrer(new Referrer("bytes=0-20", ReadMode.SYNC)) + .build(); + + client.getObject(getRequest, openStreamInformation); + + // Verify the encryption parameters + ArgumentCaptor requestCaptor = + ArgumentCaptor.forClass(GetObjectRequest.class); + verify(mockS3AsyncClient) + .getObject( + requestCaptor.capture(), + ArgumentMatchers + .>> + any()); + + GetObjectRequest capturedRequest = requestCaptor.getValue(); + assertEquals(ServerSideEncryption.AES256.name(), capturedRequest.sseCustomerAlgorithm()); + assertEquals(base64Key, capturedRequest.sseCustomerKey()); + assertNotNull(capturedRequest.sseCustomerKeyMD5()); + } + + @Test + void testHeadObjectWithEncryption() { + S3AsyncClient mockS3AsyncClient = createMockClient(); + S3SdkObjectClient client = new S3SdkObjectClient(mockS3AsyncClient); + + // Create encryption secrets + String base64Key = + Base64.getEncoder() + .encodeToString("32-bytes-long-key-for-testing-123".getBytes(StandardCharsets.UTF_8)); + EncryptionSecrets secrets = + EncryptionSecrets.builder().sseCustomerKey(Optional.of(base64Key)).build(); + + // Create OpenStreamInformation with encryption + OpenStreamInformation openStreamInformation = + OpenStreamInformation.builder().encryptionSecrets(secrets).build(); + + HeadRequest headRequest = HeadRequest.builder().s3Uri(S3URI.of("bucket", "key")).build(); + + client.headObject(headRequest, openStreamInformation); + + // Verify the encryption parameters + ArgumentCaptor requestCaptor = + ArgumentCaptor.forClass(HeadObjectRequest.class); + verify(mockS3AsyncClient).headObject(requestCaptor.capture()); + + HeadObjectRequest capturedRequest = requestCaptor.getValue(); + assertEquals(ServerSideEncryption.AES256.name(), capturedRequest.sseCustomerAlgorithm()); + assertEquals(base64Key, capturedRequest.sseCustomerKey()); + assertNotNull(capturedRequest.sseCustomerKeyMD5()); + } + + @Test + void testGetObjectWithoutEncryption() { + S3AsyncClient mockS3AsyncClient = createMockClient(); + S3SdkObjectClient client = new S3SdkObjectClient(mockS3AsyncClient); + + OpenStreamInformation openStreamInformation = OpenStreamInformation.builder().build(); + + GetRequest getRequest = + GetRequest.builder() + .s3Uri(S3URI.of("bucket", "key")) + .range(new Range(0, 20)) + .etag(ETAG) + .referrer(new Referrer("bytes=0-20", ReadMode.SYNC)) + .build(); + + client.getObject(getRequest, openStreamInformation); + + ArgumentCaptor requestCaptor = + ArgumentCaptor.forClass(GetObjectRequest.class); + verify(mockS3AsyncClient) + .getObject( + requestCaptor.capture(), + ArgumentMatchers + .>> + any()); + + GetObjectRequest capturedRequest = requestCaptor.getValue(); + assertNull(capturedRequest.sseCustomerAlgorithm()); + assertNull(capturedRequest.sseCustomerKey()); + assertNull(capturedRequest.sseCustomerKeyMD5()); + } + + @Test + void testHeadObjectWithoutEncryption() { + S3AsyncClient mockS3AsyncClient = createMockClient(); + S3SdkObjectClient client = new S3SdkObjectClient(mockS3AsyncClient); + + OpenStreamInformation openStreamInformation = OpenStreamInformation.builder().build(); + + HeadRequest headRequest = HeadRequest.builder().s3Uri(S3URI.of("bucket", "key")).build(); + + client.headObject(headRequest, openStreamInformation); + + ArgumentCaptor requestCaptor = + ArgumentCaptor.forClass(HeadObjectRequest.class); + verify(mockS3AsyncClient).headObject(requestCaptor.capture()); + + HeadObjectRequest capturedRequest = requestCaptor.getValue(); + assertNull(capturedRequest.sseCustomerAlgorithm()); + assertNull(capturedRequest.sseCustomerKey()); + assertNull(capturedRequest.sseCustomerKeyMD5()); + } }