diff --git a/lib/trino-filesystem/src/main/java/io/trino/filesystem/encryption/EncryptionKey.java b/lib/trino-filesystem/src/main/java/io/trino/filesystem/encryption/EncryptionKey.java index 474262df09a2..b56e1c151955 100644 --- a/lib/trino-filesystem/src/main/java/io/trino/filesystem/encryption/EncryptionKey.java +++ b/lib/trino-filesystem/src/main/java/io/trino/filesystem/encryption/EncryptionKey.java @@ -13,6 +13,8 @@ */ package io.trino.filesystem.encryption; +import java.util.Arrays; +import java.util.Objects; import java.util.concurrent.ThreadLocalRandom; import static java.util.Objects.requireNonNull; @@ -38,4 +40,24 @@ public String toString() // We intentionally overwrite toString to hide a key return algorithm; } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + + if (!(o instanceof EncryptionKey that)) { + return false; + } + return Objects.deepEquals(key, that.key) + && Objects.equals(algorithm, that.algorithm); + } + + @Override + public int hashCode() + { + return Objects.hash(Arrays.hashCode(key), algorithm); + } } diff --git a/plugin/trino-spooling-filesystem/pom.xml b/plugin/trino-spooling-filesystem/pom.xml index a2f0617d4955..a737edeb578a 100644 --- a/plugin/trino-spooling-filesystem/pom.xml +++ b/plugin/trino-spooling-filesystem/pom.xml @@ -17,11 +17,6 @@ - - com.google.crypto.tink - tink - - com.google.guava guava @@ -143,6 +138,36 @@ provided + + software.amazon.awssdk + auth + runtime + + + + software.amazon.awssdk + aws-core + runtime + + + + software.amazon.awssdk + regions + runtime + + + + software.amazon.awssdk + s3 + runtime + + + + software.amazon.awssdk + sdk-core + runtime + + io.airlift junit-extensions @@ -166,5 +191,23 @@ junit-jupiter-api test + + + org.testcontainers + junit-jupiter + test + + + + org.testcontainers + localstack + test + + + + org.testcontainers + testcontainers + test + diff --git a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpooledSegmentHandle.java b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpooledSegmentHandle.java index 83e0651ae238..e6ed1cd65c5c 100644 --- a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpooledSegmentHandle.java +++ b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpooledSegmentHandle.java @@ -13,9 +13,9 @@ */ package io.trino.spooling.filesystem; -import io.airlift.slice.Slice; import io.azam.ulidj.ULID; import io.trino.filesystem.Location; +import io.trino.filesystem.encryption.EncryptionKey; import io.trino.spi.QueryId; import io.trino.spi.protocol.SpooledSegmentHandle; import io.trino.spi.protocol.SpoolingContext; @@ -28,7 +28,7 @@ import static com.google.common.base.Verify.verify; import static java.util.Objects.requireNonNull; -public record FileSystemSpooledSegmentHandle(@Override String encoding, @Override QueryId queryId, byte[] uuid, Optional encryptionKey) +public record FileSystemSpooledSegmentHandle(@Override String encoding, @Override QueryId queryId, byte[] uuid, Optional encryptionKey) implements SpooledSegmentHandle { private static final String OBJECT_NAME_SEPARATOR = "::"; @@ -45,7 +45,7 @@ public static FileSystemSpooledSegmentHandle random(Random random, SpoolingConte return random(random, context, expireAt, Optional.empty()); } - public static FileSystemSpooledSegmentHandle random(Random random, SpoolingContext context, Instant expireAt, Optional encryptionKey) + public static FileSystemSpooledSegmentHandle random(Random random, SpoolingContext context, Instant expireAt, Optional encryptionKey) { return new FileSystemSpooledSegmentHandle( context.encoding(), diff --git a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpoolingManager.java b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpoolingManager.java index 3dc52d29e496..60225bb1a5b2 100644 --- a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpoolingManager.java +++ b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/FileSystemSpoolingManager.java @@ -13,15 +13,17 @@ */ package io.trino.spooling.filesystem; -import com.google.common.hash.Hashing; +import com.google.common.collect.ImmutableMap; import com.google.inject.Inject; import io.airlift.slice.BasicSliceInput; import io.airlift.slice.DynamicSliceOutput; -import io.airlift.slice.Slice; import io.airlift.units.Duration; import io.trino.filesystem.Location; import io.trino.filesystem.TrinoFileSystem; import io.trino.filesystem.TrinoFileSystemFactory; +import io.trino.filesystem.TrinoInputFile; +import io.trino.filesystem.TrinoOutputFile; +import io.trino.filesystem.encryption.EncryptionKey; import io.trino.spi.QueryId; import io.trino.spi.protocol.SpooledLocation; import io.trino.spi.protocol.SpooledLocation.DirectLocation; @@ -29,25 +31,22 @@ import io.trino.spi.protocol.SpoolingContext; import io.trino.spi.protocol.SpoolingManager; import io.trino.spi.security.ConnectorIdentity; +import io.trino.spooling.filesystem.encryption.EncryptionHeadersTranslator; import io.trino.spooling.filesystem.encryption.ExceptionMappingInputStream; -import java.io.FileNotFoundException; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.time.Instant; -import java.util.Base64; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Random; import java.util.concurrent.ThreadLocalRandom; -import static io.airlift.slice.Slices.wrappedBuffer; +import static io.trino.filesystem.encryption.EncryptionKey.randomAes256; import static io.trino.spi.protocol.SpooledLocation.coordinatorLocation; -import static io.trino.spooling.filesystem.encryption.EncryptionUtils.decryptingInputStream; -import static io.trino.spooling.filesystem.encryption.EncryptionUtils.encryptingOutputStream; -import static io.trino.spooling.filesystem.encryption.EncryptionUtils.generateRandomKey; +import static io.trino.spooling.filesystem.encryption.EncryptionHeadersTranslator.encryptionHeadersTranslator; import static java.nio.charset.StandardCharsets.UTF_8; import static java.time.Duration.between; import static java.util.Objects.requireNonNull; @@ -56,14 +55,8 @@ public class FileSystemSpoolingManager implements SpoolingManager { - private static final String ENCRYPTION_KEY_HEADER_PREFIX = "X-Trino-SSE-C-"; - private static final String ENCRYPTION_KEY_HEADER = ENCRYPTION_KEY_HEADER_PREFIX + "Key"; - private static final String ENCRYPTION_KEY_CHECKSUM_HEADER = ENCRYPTION_KEY_HEADER_PREFIX + "SHA256"; - private static final String ENCRYPTION_KEY_CIPHER = ENCRYPTION_KEY_HEADER_PREFIX + "Cipher"; - - private static final String ENCRYPTION_CIPHER_NAME = "AES256"; - - private final String location; + private final Location location; + private final EncryptionHeadersTranslator encryptionHeadersTranslator; private final TrinoFileSystem fileSystem; private final Duration ttl; private final boolean encryptionEnabled; @@ -73,9 +66,10 @@ public class FileSystemSpoolingManager public FileSystemSpoolingManager(FileSystemSpoolingConfig config, TrinoFileSystemFactory fileSystemFactory) { requireNonNull(config, "config is null"); - this.location = config.getLocation(); + this.location = Location.of(config.getLocation()); this.fileSystem = requireNonNull(fileSystemFactory, "fileSystemFactory is null") .create(ConnectorIdentity.ofUser("ignored")); + this.encryptionHeadersTranslator = encryptionHeadersTranslator(location); this.ttl = config.getTtl(); this.encryptionEnabled = config.isEncryptionEnabled(); } @@ -84,12 +78,19 @@ public FileSystemSpoolingManager(FileSystemSpoolingConfig config, TrinoFileSyste public OutputStream createOutputStream(SpooledSegmentHandle handle) throws IOException { - FileSystemSpooledSegmentHandle filesystemHandle = (FileSystemSpooledSegmentHandle) handle; - OutputStream stream = fileSystem.newOutputFile(location(filesystemHandle)).create(); - if (filesystemHandle.encryptionKey().isPresent()) { - return encryptingOutputStream(stream, filesystemHandle.encryptionKey().get()); + FileSystemSpooledSegmentHandle fileHandle = (FileSystemSpooledSegmentHandle) handle; + Location storageLocation = location(fileHandle); + Optional encryption = fileHandle.encryptionKey(); + + TrinoOutputFile outputFile; + if (encryptionEnabled) { + outputFile = fileSystem.newEncryptedOutputFile(storageLocation, encryption.orElseThrow()); + } + else { + outputFile = fileSystem.newOutputFile(storageLocation); } - return stream; + + return outputFile.create(); } @Override @@ -97,7 +98,7 @@ public FileSystemSpooledSegmentHandle create(SpoolingContext context) { Instant expireAt = Instant.now().plusMillis(ttl.toMillis()); if (encryptionEnabled) { - return FileSystemSpooledSegmentHandle.random(random, context, expireAt, Optional.of(generateRandomKey())); + return FileSystemSpooledSegmentHandle.random(random, context, expireAt, Optional.of(randomAes256())); } return FileSystemSpooledSegmentHandle.random(random, context, expireAt); } @@ -106,22 +107,22 @@ public FileSystemSpooledSegmentHandle create(SpoolingContext context) public InputStream openInputStream(SpooledSegmentHandle handle) throws IOException { - FileSystemSpooledSegmentHandle segmentHandle = (FileSystemSpooledSegmentHandle) handle; - checkExpiration(segmentHandle); - try { - if (!fileSystem.newInputFile(location(segmentHandle)).exists()) { - throw new IOException("Segment not found or expired"); - } + FileSystemSpooledSegmentHandle fileHandle = (FileSystemSpooledSegmentHandle) handle; + checkExpiration(fileHandle); + Optional encryption = fileHandle.encryptionKey(); + Location storageLocation = location(fileHandle); + + TrinoInputFile inputFile; - InputStream stream = fileSystem.newInputFile(location(segmentHandle)).newStream(); - if (segmentHandle.encryptionKey().isPresent()) { - return new ExceptionMappingInputStream(decryptingInputStream(stream, segmentHandle.encryptionKey().get())); - } - return stream; + if (encryptionEnabled) { + inputFile = fileSystem.newEncryptedInputFile(storageLocation, encryption.orElseThrow()); } - catch (FileNotFoundException e) { - throw new IOException("Segment not found or expired", e); + else { + inputFile = fileSystem.newInputFile(storageLocation); } + + checkFileExists(inputFile); + return new ExceptionMappingInputStream(inputFile.newStream()); } @Override @@ -135,14 +136,20 @@ public void acknowledge(SpooledSegmentHandle handle) public Optional directLocation(SpooledSegmentHandle handle) throws IOException { - // TODO: implement SSE-C support in TrinoFileSystems + FileSystemSpooledSegmentHandle fileHandle = (FileSystemSpooledSegmentHandle) handle; + Location storageLocation = location(fileHandle); + Duration ttl = remainingTtl(fileHandle.expirationTime()); + Optional key = fileHandle.encryptionKey(); + + Optional directLocation; if (encryptionEnabled) { - throw new UnsupportedOperationException("Direct access not supported when encryption is enabled"); + directLocation = fileSystem.encryptedPreSignedUri(storageLocation, ttl, key.orElseThrow()) + .map(uri -> new DirectLocation(uri.uri(), uri.headers())); + } + else { + directLocation = fileSystem.preSignedUri(storageLocation, ttl) + .map(uri -> new DirectLocation(uri.uri(), uri.headers())); } - FileSystemSpooledSegmentHandle fileHandle = (FileSystemSpooledSegmentHandle) handle; - Optional directLocation = fileSystem - .preSignedUri(location(fileHandle), remainingTtl(fileHandle.expirationTime())) - .map(location -> SpooledLocation.directLocation(location.uri(), headers(fileHandle))); if (directLocation.isEmpty()) { throw new IOException("Failed to generate pre-signed URI for query %s and segment %s".formatted(fileHandle.queryId(), fileHandle.identifier())); @@ -169,9 +176,17 @@ public SpooledLocation location(SpooledSegmentHandle handle) output.writeBytes(fileHandle.queryId().toString().getBytes(UTF_8)); output.writeBytes(fileHandle.encoding().getBytes(UTF_8)); output.writeBoolean(fileHandle.encryptionKey().isPresent()); + return coordinatorLocation(output.slice(), headers(fileHandle)); } + private Map> headers(FileSystemSpooledSegmentHandle fileHandle) + { + return fileHandle.encryptionKey() + .map(encryptionHeadersTranslator::createHeaders) + .orElse(ImmutableMap.of()); + } + @Override public SpooledSegmentHandle handle(SpooledLocation location) { @@ -191,58 +206,14 @@ public SpooledSegmentHandle handle(SpooledLocation location) if (!input.readBoolean()) { return new FileSystemSpooledSegmentHandle(encoding, queryId, uuid, Optional.empty()); } - - Slice key = getEncryptionKey(location.headers()); - return new FileSystemSpooledSegmentHandle(encoding, queryId, uuid, Optional.of(key)); - } - - private static Slice getEncryptionKey(Map> headers) - { - String encryptionCipher = getOnlyHeader(headers, ENCRYPTION_KEY_CIPHER); - if (!encryptionCipher.contentEquals(ENCRYPTION_CIPHER_NAME)) { - throw new IllegalArgumentException("Unsupported encryption cipher %s".formatted(encryptionCipher)); - } - - String encryptionKey = getOnlyHeader(headers, ENCRYPTION_KEY_HEADER); - String keyChecksum = getOnlyHeader(headers, ENCRYPTION_KEY_CHECKSUM_HEADER); - if (!sha256Checksum(base64Decode(encryptionKey)).contentEquals(keyChecksum)) { - throw new IllegalArgumentException("Encryption key checksum mismatch"); - } - return base64Decode(encryptionKey); - } - - private static String getOnlyHeader(Map> headers, String headerName) - { - List values = headers.get(headerName); - if (values == null || values.isEmpty()) { - throw new IllegalArgumentException("Header %s is missing".formatted(headerName)); - } - - if (values.size() > 1) { - throw new IllegalArgumentException("Header %s has multiple values".formatted(headerName)); - } - - return values.getFirst(); - } - - private Map> headers(SpooledSegmentHandle handle) - { - FileSystemSpooledSegmentHandle fileHandle = (FileSystemSpooledSegmentHandle) handle; - if (encryptionEnabled) { - return Map.of( - ENCRYPTION_KEY_CIPHER, List.of(ENCRYPTION_CIPHER_NAME), - ENCRYPTION_KEY_HEADER, List.of(base64Encode(fileHandle.encryptionKey().orElseThrow())), - ENCRYPTION_KEY_CHECKSUM_HEADER, List.of(sha256Checksum(fileHandle.encryptionKey().orElseThrow()))); - } - return Map.of(); + return new FileSystemSpooledSegmentHandle(encoding, queryId, uuid, Optional.of(encryptionHeadersTranslator.extractKey(location.headers()))); } private Location location(FileSystemSpooledSegmentHandle handle) throws IOException { checkExpiration(handle); - return Location.of(location) - .appendPath(handle.storageObjectName()); + return location.appendPath(handle.storageObjectName()); } private Duration remainingTtl(Instant expiresAt) @@ -258,18 +229,11 @@ private void checkExpiration(FileSystemSpooledSegmentHandle handle) } } - private static String base64Encode(Slice slice) - { - return Base64.getEncoder().encodeToString(slice.getBytes()); - } - - private static Slice base64Decode(String base64) - { - return wrappedBuffer(Base64.getDecoder().decode(base64)); - } - - private static String sha256Checksum(Slice slice) + private static void checkFileExists(TrinoInputFile inputFile) + throws IOException { - return Hashing.sha256().hashBytes(slice.getBytes()).toString(); + if (!inputFile.exists()) { + throw new IOException("Segment not found or expired"); + } } } diff --git a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/AzureEncryptionHeadersTranslator.java b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/AzureEncryptionHeadersTranslator.java new file mode 100644 index 000000000000..ec638e0ecf41 --- /dev/null +++ b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/AzureEncryptionHeadersTranslator.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spooling.filesystem.encryption; + +import com.google.common.collect.ImmutableMap; +import com.google.common.hash.Hashing; +import io.trino.filesystem.encryption.EncryptionKey; + +import java.util.Base64; +import java.util.List; +import java.util.Map; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.spooling.filesystem.encryption.HeadersUtils.getOnlyHeader; + +public class AzureEncryptionHeadersTranslator + implements EncryptionHeadersTranslator +{ + @Override + public EncryptionKey extractKey(Map> headers) + { + byte[] key = Base64.getDecoder().decode(getOnlyHeader(headers, "x-ms-encryption-key")); + String sha256Checksum = getOnlyHeader(headers, "x-ms-encryption-key-sha256"); + EncryptionKey encryption = new EncryptionKey(key, getOnlyHeader(headers, "x-ms-encryption-algorithm")); + checkArgument(sha256(encryption).equals(sha256Checksum), "Key SHA256 checksum does not match"); + return encryption; + } + + @Override + public Map> createHeaders(EncryptionKey key) + { + return ImmutableMap.of( + "x-ms-encryption-key", List.of(encoded(key)), + "x-ms-encryption-key-sha256", List.of(sha256(key)), + "x-ms-encryption-algorithm", List.of(key.algorithm())); + } + + private static String sha256(EncryptionKey key) + { + return Base64.getEncoder().encodeToString(Hashing.sha256().hashBytes(key.key()).asBytes()); + } + + private static String encoded(EncryptionKey key) + { + return Base64.getEncoder().encodeToString(key.key()); + } +} diff --git a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/EncryptionHeadersTranslator.java b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/EncryptionHeadersTranslator.java new file mode 100644 index 000000000000..854089641fcb --- /dev/null +++ b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/EncryptionHeadersTranslator.java @@ -0,0 +1,48 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spooling.filesystem.encryption; + +import io.trino.filesystem.Location; +import io.trino.filesystem.encryption.EncryptionKey; + +import java.util.List; +import java.util.Map; + +import static java.util.Objects.requireNonNull; + +public interface EncryptionHeadersTranslator +{ + EncryptionKey extractKey(Map> headers); + + Map> createHeaders(EncryptionKey encryption); + + static EncryptionHeadersTranslator encryptionHeadersTranslator(Location location) + { + requireNonNull(location, "location is null"); + return location.scheme() + .map(EncryptionHeadersTranslator::forScheme) + .orElseThrow(() -> new IllegalArgumentException("Unknown location scheme: " + location)); + } + + private static EncryptionHeadersTranslator forScheme(String scheme) + { + // These should match schemes supported in the FileSystemSpoolingModule + return switch (scheme) { + case "s3" -> new S3EncryptionHeadersTranslator(); + case "gs" -> new GcsEncryptionHeadersTranslator(); + case "abfs" -> new AzureEncryptionHeadersTranslator(); + default -> throw new IllegalArgumentException("Unknown file system scheme: " + scheme); + }; + } +} diff --git a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/EncryptionUtils.java b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/EncryptionUtils.java deleted file mode 100644 index bffd24b07b0f..000000000000 --- a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/EncryptionUtils.java +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.spooling.filesystem.encryption; - -import com.google.crypto.tink.BinaryKeysetReader; -import com.google.crypto.tink.CleartextKeysetHandle; -import com.google.crypto.tink.KeyTemplates; -import com.google.crypto.tink.KeysetHandle; -import com.google.crypto.tink.StreamingAead; -import com.google.crypto.tink.mac.MacConfig; -import com.google.crypto.tink.streamingaead.StreamingAeadConfig; -import io.airlift.slice.Slice; - -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.InputStream; -import java.io.OutputStream; -import java.security.GeneralSecurityException; - -import static com.google.crypto.tink.BinaryKeysetWriter.withOutputStream; -import static io.airlift.slice.Slices.wrappedBuffer; - -public class EncryptionUtils -{ - private static volatile boolean initialized; - - private EncryptionUtils() {} - - public static Slice generateRandomKey() - { - ensureInitialized(); - - try (ByteArrayOutputStream outputStream = new ByteArrayOutputStream()) { - KeysetHandle aes256KeySet = KeysetHandle.generateNew(KeyTemplates.get("AES256_GCM_HKDF_1MB")); - CleartextKeysetHandle.write(aes256KeySet, withOutputStream(outputStream)); - return wrappedBuffer(outputStream.toByteArray()); - } - catch (IOException | GeneralSecurityException e) { - throw new RuntimeException(e); - } - } - - public static KeysetHandle readKey(Slice key) - { - try { - return CleartextKeysetHandle.read(BinaryKeysetReader.withBytes(key.getBytes())); - } - catch (IOException | GeneralSecurityException e) { - throw new RuntimeException(e); - } - } - - public static OutputStream encryptingOutputStream(OutputStream output, Slice key) - throws IOException - { - ensureInitialized(); - - try { - return readKey(key) - .getPrimitive(StreamingAead.class) - .newEncryptingStream(output, new byte[0]); - } - catch (GeneralSecurityException e) { - throw new IOException("Could not initialize encryption", e); - } - } - - public static InputStream decryptingInputStream(InputStream input, Slice key) - throws IOException - { - ensureInitialized(); - - try { - return readKey(key) - .getPrimitive(StreamingAead.class) - .newDecryptingStream(input, new byte[0]); - } - catch (GeneralSecurityException e) { - throw new IOException("Could not initialize decryption", e); - } - } - - private static void ensureInitialized() - { - if (initialized) { - return; - } - - synchronized (EncryptionUtils.class) { - if (initialized) { - return; - } - - try { - StreamingAeadConfig.register(); - - MacConfig.register(); - initialized = true; - } - catch (GeneralSecurityException e) { - throw new RuntimeException("Could not initialize encryption", e); - } - } - } -} diff --git a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/GcsEncryptionHeadersTranslator.java b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/GcsEncryptionHeadersTranslator.java new file mode 100644 index 000000000000..f4fa875b3791 --- /dev/null +++ b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/GcsEncryptionHeadersTranslator.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spooling.filesystem.encryption; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.hash.Hashing; +import io.trino.filesystem.encryption.EncryptionKey; + +import java.util.Base64; +import java.util.List; +import java.util.Map; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.spooling.filesystem.encryption.HeadersUtils.getOnlyHeader; + +public class GcsEncryptionHeadersTranslator + implements EncryptionHeadersTranslator +{ + @Override + public EncryptionKey extractKey(Map> headers) + { + byte[] key = Base64.getDecoder().decode(getOnlyHeader(headers, "x-goog-encryption-key")); + String sha256Checksum = getOnlyHeader(headers, "x-goog-encryption-key-sha256"); + EncryptionKey encryption = new EncryptionKey(key, getOnlyHeader(headers, "x-goog-encryption-algorithm")); + checkArgument(sha256(encryption).equals(sha256Checksum), "Key SHA256 checksum does not match"); + return encryption; + } + + @Override + public Map> createHeaders(EncryptionKey encryption) + { + return ImmutableMap.of( + "x-goog-encryption-algorithm", ImmutableList.of(encryption.algorithm()), + "x-goog-encryption-key", ImmutableList.of(encoded(encryption)), + "x-goog-encryption-key-sha256", ImmutableList.of(sha256(encryption))); + } + + private static String sha256(EncryptionKey key) + { + return Base64.getEncoder().encodeToString(Hashing.sha256().hashBytes(key.key()).asBytes()); + } + + private static String encoded(EncryptionKey key) + { + return Base64.getEncoder().encodeToString(key.key()); + } +} diff --git a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/HeadersUtils.java b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/HeadersUtils.java new file mode 100644 index 000000000000..40532a76f1c7 --- /dev/null +++ b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/HeadersUtils.java @@ -0,0 +1,32 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spooling.filesystem.encryption; + +import java.util.List; +import java.util.Map; + +import static com.google.common.base.Preconditions.checkArgument; + +public class HeadersUtils +{ + private HeadersUtils() {} + + public static String getOnlyHeader(Map> headers, String name) + { + List values = headers.get(name); + checkArgument(values != null && !values.isEmpty(), "Required header " + name + " was not found"); + checkArgument(values.size() == 1, "Required header " + name + " contains more than one value"); + return values.getFirst(); + } +} diff --git a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/S3EncryptionHeadersTranslator.java b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/S3EncryptionHeadersTranslator.java new file mode 100644 index 000000000000..de902cabe4e2 --- /dev/null +++ b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/S3EncryptionHeadersTranslator.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spooling.filesystem.encryption; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.hash.Hashing; +import io.trino.filesystem.encryption.EncryptionKey; + +import java.util.Base64; +import java.util.List; +import java.util.Map; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.spooling.filesystem.encryption.HeadersUtils.getOnlyHeader; + +public class S3EncryptionHeadersTranslator + implements EncryptionHeadersTranslator +{ + @Override + public EncryptionKey extractKey(Map> headers) + { + byte[] key = Base64.getDecoder().decode(getOnlyHeader(headers, "x-amz-server-side-encryption-customer-key")); + String md5Checksum = getOnlyHeader(headers, "x-amz-server-side-encryption-customer-key-MD5"); + EncryptionKey encryption = new EncryptionKey(key, getOnlyHeader(headers, "x-amz-server-side-encryption-customer-algorithm")); + checkArgument(md5(encryption).equals(md5Checksum), "Key MD5 checksum does not match"); + return encryption; + } + + @Override + public Map> createHeaders(EncryptionKey encryption) + { + return ImmutableMap.of( + "x-amz-server-side-encryption-customer-algorithm", ImmutableList.of(encryption.algorithm()), + "x-amz-server-side-encryption-customer-key", ImmutableList.of(encoded(encryption)), + "x-amz-server-side-encryption-customer-key-MD5", ImmutableList.of(md5(encryption))); + } + + public static String encoded(EncryptionKey key) + { + return Base64.getEncoder().encodeToString(key.key()); + } + + @SuppressWarnings("deprecation") // AWS SSE-C requires MD5 checksum + public static String md5(EncryptionKey key) + { + return Base64.getEncoder().encodeToString(Hashing.md5().hashBytes(key.key()).asBytes()); + } +} diff --git a/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/TestFileSystemSpoolingManager.java b/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/AbstractFileSystemSpoolingManagerTest.java similarity index 66% rename from plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/TestFileSystemSpoolingManager.java rename to plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/AbstractFileSystemSpoolingManagerTest.java index bc0c83e3f7c8..2398d2568e27 100644 --- a/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/TestFileSystemSpoolingManager.java +++ b/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/AbstractFileSystemSpoolingManagerTest.java @@ -13,19 +13,13 @@ */ package io.trino.spooling.filesystem; -import io.airlift.units.DataSize; import io.azam.ulidj.ULID; -import io.trino.filesystem.s3.S3FileSystemConfig; -import io.trino.filesystem.s3.S3FileSystemFactory; -import io.trino.filesystem.s3.S3FileSystemStats; +import io.trino.filesystem.encryption.EncryptionKey; import io.trino.spi.QueryId; import io.trino.spi.protocol.SpooledLocation; import io.trino.spi.protocol.SpooledSegmentHandle; import io.trino.spi.protocol.SpoolingContext; import io.trino.spi.protocol.SpoolingManager; -import io.trino.testing.containers.Minio; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; @@ -33,38 +27,16 @@ import java.io.InputStream; import java.io.OutputStream; import java.util.Optional; -import java.util.UUID; -import static io.opentelemetry.api.OpenTelemetry.noop; -import static io.trino.spooling.filesystem.encryption.EncryptionUtils.generateRandomKey; -import static io.trino.testing.containers.Minio.MINIO_REGION; +import static io.trino.filesystem.encryption.EncryptionKey.randomAes256; import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; @TestInstance(PER_CLASS) -public class TestFileSystemSpoolingManager +public abstract class AbstractFileSystemSpoolingManagerTest { - private static final String BUCKET_NAME = "spooling" + UUID.randomUUID().toString() - .replace("-", ""); - - private Minio minio; - - @BeforeAll - public void setup() - { - minio = Minio.builder().build(); - minio.start(); - minio.createBucket(BUCKET_NAME); - } - - @AfterAll - public void teardown() - { - minio.stop(); - } - @Test public void testRetrieveSpooledSegment() throws Exception @@ -109,7 +81,8 @@ public void testAcknowledgedSegmentCantBeRetrievedAgain() @Test public void testHandleRoundTrip() { - FileSystemSpooledSegmentHandle handle = new FileSystemSpooledSegmentHandle("json", QueryId.valueOf("a"), ULID.randomBinary(), Optional.of(generateRandomKey())); + EncryptionKey key = randomAes256(); + FileSystemSpooledSegmentHandle handle = new FileSystemSpooledSegmentHandle("json", QueryId.valueOf("a"), ULID.randomBinary(), Optional.of(key)); SpooledLocation location = getSpoolingManager().location(handle); FileSystemSpooledSegmentHandle handle2 = (FileSystemSpooledSegmentHandle) getSpoolingManager().handle(location); @@ -117,21 +90,8 @@ public void testHandleRoundTrip() assertThat(handle.storageObjectName()).isEqualTo(handle2.storageObjectName()); assertThat(handle.uuid()).isEqualTo(handle2.uuid()); assertThat(handle.expirationTime()).isEqualTo(handle2.expirationTime()); - assertThat(handle.encryptionKey()).isEqualTo(handle2.encryptionKey()); + assertThat(handle2.encryptionKey()).isPresent().hasValue(key); } - private SpoolingManager getSpoolingManager() - { - FileSystemSpoolingConfig spoolingConfig = new FileSystemSpoolingConfig(); - spoolingConfig.setS3Enabled(true); - spoolingConfig.setLocation("s3://%s/".formatted(BUCKET_NAME)); - S3FileSystemConfig filesystemConfig = new S3FileSystemConfig() - .setEndpoint(minio.getMinioAddress()) - .setRegion(MINIO_REGION) - .setPathStyleAccess(true) - .setAwsAccessKey(Minio.MINIO_ACCESS_KEY) - .setAwsSecretKey(Minio.MINIO_SECRET_KEY) - .setStreamingPartSize(DataSize.valueOf("5.5MB")); - return new FileSystemSpoolingManager(spoolingConfig, new S3FileSystemFactory(noop(), filesystemConfig, new S3FileSystemStats())); - } + protected abstract SpoolingManager getSpoolingManager(); } diff --git a/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/TestFileSystemSpoolingManagerLocalStack.java b/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/TestFileSystemSpoolingManagerLocalStack.java new file mode 100644 index 000000000000..8be8f4fbe889 --- /dev/null +++ b/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/TestFileSystemSpoolingManagerLocalStack.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spooling.filesystem; + +import io.airlift.units.DataSize; +import io.trino.filesystem.s3.S3FileSystemConfig; +import io.trino.filesystem.s3.S3FileSystemFactory; +import io.trino.filesystem.s3.S3FileSystemStats; +import io.trino.spi.protocol.SpoolingManager; +import org.junit.jupiter.api.BeforeAll; +import org.testcontainers.containers.localstack.LocalStackContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.utility.DockerImageName; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3Client; + +import static io.opentelemetry.api.OpenTelemetry.noop; +import static org.testcontainers.containers.localstack.LocalStackContainer.Service.S3; + +@Testcontainers +public class TestFileSystemSpoolingManagerLocalStack + extends AbstractFileSystemSpoolingManagerTest +{ + private static final String BUCKET_NAME = "test-bucket"; + + @Container + private static final LocalStackContainer LOCALSTACK = new LocalStackContainer(DockerImageName.parse("localstack/localstack:3.7.0")) + .withServices(S3); + + @BeforeAll + public void setup() + { + try (S3Client s3Client = createS3Client()) { + s3Client.createBucket(builder -> builder.bucket(BUCKET_NAME).build()); + } + } + + @Override + protected SpoolingManager getSpoolingManager() + { + FileSystemSpoolingConfig spoolingConfig = new FileSystemSpoolingConfig(); + spoolingConfig.setS3Enabled(true); + spoolingConfig.setLocation("s3://%s/".formatted(BUCKET_NAME)); + spoolingConfig.setEncryptionEnabled(true); // Localstack supports SSE-C so we can test it + S3FileSystemConfig filesystemConfig = new S3FileSystemConfig() + .setEndpoint(LOCALSTACK.getEndpointOverride(LocalStackContainer.Service.S3).toString()) + .setRegion(LOCALSTACK.getRegion()) + .setAwsAccessKey(LOCALSTACK.getAccessKey()) + .setAwsSecretKey(LOCALSTACK.getSecretKey()) + .setStreamingPartSize(DataSize.valueOf("5.5MB")); + return new FileSystemSpoolingManager(spoolingConfig, new S3FileSystemFactory(noop(), filesystemConfig, new S3FileSystemStats())); + } + + protected S3Client createS3Client() + { + return S3Client.builder() + .endpointOverride(LOCALSTACK.getEndpointOverride(LocalStackContainer.Service.S3)) + .region(Region.of(LOCALSTACK.getRegion())) + .credentialsProvider(StaticCredentialsProvider.create( + AwsBasicCredentials.create(LOCALSTACK.getAccessKey(), LOCALSTACK.getSecretKey()))) + .build(); + } +} diff --git a/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/TestFileSystemSpoolingManagerMinio.java b/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/TestFileSystemSpoolingManagerMinio.java new file mode 100644 index 000000000000..1960377ee006 --- /dev/null +++ b/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/TestFileSystemSpoolingManagerMinio.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spooling.filesystem; + +import io.airlift.units.DataSize; +import io.trino.filesystem.s3.S3FileSystemConfig; +import io.trino.filesystem.s3.S3FileSystemFactory; +import io.trino.filesystem.s3.S3FileSystemStats; +import io.trino.spi.protocol.SpoolingManager; +import io.trino.testing.containers.Minio; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; + +import java.util.UUID; + +import static io.opentelemetry.api.OpenTelemetry.noop; +import static io.trino.testing.containers.Minio.MINIO_REGION; + +public class TestFileSystemSpoolingManagerMinio + extends AbstractFileSystemSpoolingManagerTest +{ + private static final String BUCKET_NAME = "spooling" + UUID.randomUUID().toString() + .replace("-", ""); + + private Minio minio; + + @BeforeAll + public void setup() + { + minio = Minio.builder().build(); + minio.start(); + minio.createBucket(BUCKET_NAME); + } + + @AfterAll + public void teardown() + { + minio.stop(); + } + + @Override + protected SpoolingManager getSpoolingManager() + { + FileSystemSpoolingConfig spoolingConfig = new FileSystemSpoolingConfig(); + spoolingConfig.setS3Enabled(true); + spoolingConfig.setLocation("s3://%s/".formatted(BUCKET_NAME)); + spoolingConfig.setEncryptionEnabled(false); // Minio doesn't support SSE-C without TLS + S3FileSystemConfig filesystemConfig = new S3FileSystemConfig() + .setEndpoint(minio.getMinioAddress()) + .setRegion(MINIO_REGION) + .setPathStyleAccess(true) + .setAwsAccessKey(Minio.MINIO_ACCESS_KEY) + .setAwsSecretKey(Minio.MINIO_SECRET_KEY) + .setStreamingPartSize(DataSize.valueOf("5.5MB")); + return new FileSystemSpoolingManager(spoolingConfig, new S3FileSystemFactory(noop(), filesystemConfig, new S3FileSystemStats())); + } +} diff --git a/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/encryption/TestAzureEncryptionHeadersTranslator.java b/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/encryption/TestAzureEncryptionHeadersTranslator.java new file mode 100644 index 000000000000..a7fbd295d9ed --- /dev/null +++ b/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/encryption/TestAzureEncryptionHeadersTranslator.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spooling.filesystem.encryption; + +import com.google.common.collect.ImmutableMap; +import io.trino.filesystem.encryption.EncryptionKey; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class TestAzureEncryptionHeadersTranslator +{ + private static final EncryptionHeadersTranslator SSE = new AzureEncryptionHeadersTranslator(); + + @Test + public void testKnownKey() + { + byte[] key = "TrinoWillFlyWithSpooledProtocol!".getBytes(UTF_8); + EncryptionKey encryption = new EncryptionKey(key, "AES256"); + + Map> headers = SSE.createHeaders(encryption); + assertThat(headers) + .hasSize(3) + .containsEntry("x-ms-encryption-key", List.of("VHJpbm9XaWxsRmx5V2l0aFNwb29sZWRQcm90b2NvbCE=")) + .containsEntry("x-ms-encryption-key-sha256", List.of("bXwXXQkzTJYdEN+cDvfUtOobMCc1kKoPVD6aVi1wb9A=")) + .containsEntry("x-ms-encryption-algorithm", List.of("AES256")); + } + + @Test + public void testRoundTrip() + { + EncryptionKey key = EncryptionKey.randomAes256(); + assertThat(SSE.extractKey(SSE.createHeaders(key))).isEqualTo(key); + } + + @Test + public void testThrowsOnInvalidChecksum() + { + Map> headers = ImmutableMap.of( + "x-ms-encryption-key", List.of("VHJpbm9XaWxsRmx5V2l0aFNwb29sZWRQcm90b2NvbCE="), + "x-ms-encryption-key-sha256", List.of("brokenchecksum"), + "x-ms-encryption-algorithm", List.of("AES256")); + + assertThatThrownBy(() -> SSE.extractKey(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Key SHA256 checksum does not match"); + } +} diff --git a/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/encryption/TestGcsEncryptionHeadersTranslator.java b/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/encryption/TestGcsEncryptionHeadersTranslator.java new file mode 100644 index 000000000000..2b58a85527ad --- /dev/null +++ b/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/encryption/TestGcsEncryptionHeadersTranslator.java @@ -0,0 +1,65 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spooling.filesystem.encryption; + +import com.google.common.collect.ImmutableMap; +import io.trino.filesystem.encryption.EncryptionKey; +import org.assertj.core.api.AssertionsForClassTypes; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class TestGcsEncryptionHeadersTranslator +{ + private static final EncryptionHeadersTranslator SSE = new GcsEncryptionHeadersTranslator(); + + @Test + public void testKnownKey() + { + byte[] key = "TrinoWillFlyWithSpooledProtocol!".getBytes(UTF_8); + EncryptionKey encryption = new EncryptionKey(key, "AES256"); + + Map> headers = SSE.createHeaders(encryption); + assertThat(headers) + .hasSize(3) + .containsEntry("x-goog-encryption-key", List.of("VHJpbm9XaWxsRmx5V2l0aFNwb29sZWRQcm90b2NvbCE=")) + .containsEntry("x-goog-encryption-key-sha256", List.of("bXwXXQkzTJYdEN+cDvfUtOobMCc1kKoPVD6aVi1wb9A=")) + .containsEntry("x-goog-encryption-algorithm", List.of("AES256")); + } + + @Test + public void testRoundTrip() + { + EncryptionKey key = EncryptionKey.randomAes256(); + AssertionsForClassTypes.assertThat(SSE.extractKey(SSE.createHeaders(key))).isEqualTo(key); + } + + @Test + public void testThrowsOnInvalidChecksum() + { + Map> headers = ImmutableMap.of( + "x-goog-encryption-key", List.of("VHJpbm9XaWxsRmx5V2l0aFNwb29sZWRQcm90b2NvbCE="), + "x-goog-encryption-key-sha256", List.of("brokenchecksum"), + "x-goog-encryption-algorithm", List.of("AES256")); + + assertThatThrownBy(() -> SSE.extractKey(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Key SHA256 checksum does not match"); + } +} diff --git a/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/encryption/TestS3EncryptionHeadersTranslator.java b/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/encryption/TestS3EncryptionHeadersTranslator.java new file mode 100644 index 000000000000..71267cd4cb33 --- /dev/null +++ b/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/encryption/TestS3EncryptionHeadersTranslator.java @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.spooling.filesystem.encryption; + +import com.google.common.collect.ImmutableMap; +import io.trino.filesystem.encryption.EncryptionKey; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class TestS3EncryptionHeadersTranslator +{ + private static final EncryptionHeadersTranslator SSE = new S3EncryptionHeadersTranslator(); + + @Test + public void testKnownKey() + { + byte[] key = "TrinoWillFlyWithSpooledProtocol!".getBytes(UTF_8); + EncryptionKey encryption = new EncryptionKey(key, "AES256"); + + Map> headers = SSE.createHeaders(encryption); + assertThat(headers) + .hasSize(3) + .containsEntry("x-amz-server-side-encryption-customer-key", List.of("VHJpbm9XaWxsRmx5V2l0aFNwb29sZWRQcm90b2NvbCE=")) + .containsEntry("x-amz-server-side-encryption-customer-key-MD5", List.of("CX3f4fSIpiyVyQDCzuhDWg==")) + .containsEntry("x-amz-server-side-encryption-customer-algorithm", List.of("AES256")); + } + + @Test + public void testRoundTrip() + { + EncryptionKey key = EncryptionKey.randomAes256(); + assertThat(SSE.extractKey(SSE.createHeaders(key))).isEqualTo(key); + } + + @Test + public void testThrowsOnInvalidChecksum() + { + Map> headers = ImmutableMap.of( + "x-amz-server-side-encryption-customer-key", List.of("VHJpbm9XaWxsRmx5V2l0aFNwb29sZWRQcm90b2NvbCE="), + "x-amz-server-side-encryption-customer-key-MD5", List.of("brokenchecksum"), + "x-amz-server-side-encryption-customer-algorithm", List.of("AES256")); + + assertThatThrownBy(() -> SSE.extractKey(headers)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Key MD5 checksum does not match"); + } +} diff --git a/testing/trino-tests/pom.xml b/testing/trino-tests/pom.xml index 866ad75b9cb3..679c1b8a40d7 100644 --- a/testing/trino-tests/pom.xml +++ b/testing/trino-tests/pom.xml @@ -305,12 +305,53 @@ jmh-core test + + org.testcontainers + localstack + test + org.yaml snakeyaml test + + + software.amazon.awssdk + auth + test + + + + software.amazon.awssdk + aws-core + test + + + + software.amazon.awssdk + regions + test + + + + software.amazon.awssdk + s3 + test + + + commons-logging + commons-logging + + + + + + software.amazon.awssdk + sdk-core + test + diff --git a/testing/trino-tests/src/test/java/io/trino/server/protocol/AbstractSpooledQueryDataDistributedQueries.java b/testing/trino-tests/src/test/java/io/trino/server/protocol/AbstractSpooledQueryDataDistributedQueries.java index e9995f520ea9..07b842603641 100644 --- a/testing/trino-tests/src/test/java/io/trino/server/protocol/AbstractSpooledQueryDataDistributedQueries.java +++ b/testing/trino-tests/src/test/java/io/trino/server/protocol/AbstractSpooledQueryDataDistributedQueries.java @@ -27,9 +27,15 @@ import io.trino.testing.QueryRunner; import io.trino.testing.TestingStatementClientFactory; import io.trino.testing.TestingTrinoClient; -import io.trino.testing.containers.Minio; import io.trino.tpch.TpchTable; import okhttp3.OkHttpClient; +import org.testcontainers.containers.localstack.LocalStackContainer; +import org.testcontainers.containers.localstack.LocalStackContainer.Service; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.CreateBucketRequest; import java.util.Map; import java.util.Optional; @@ -37,16 +43,13 @@ import static io.airlift.testing.Closeables.closeAllSuppress; import static io.trino.client.StatementClientFactory.newStatementClient; -import static io.trino.testing.containers.Minio.MINIO_ACCESS_KEY; -import static io.trino.testing.containers.Minio.MINIO_REGION; -import static io.trino.testing.containers.Minio.MINIO_SECRET_KEY; import static io.trino.util.Ciphers.createRandomAesEncryptionKey; import static java.util.Base64.getEncoder; public abstract class AbstractSpooledQueryDataDistributedQueries extends AbstractTestEngineOnlyQueries { - private Minio minio; + private LocalStackContainer localstack; protected abstract String encoding(); @@ -59,11 +62,14 @@ protected Map spoolingConfig() protected QueryRunner createQueryRunner() throws Exception { - minio = closeAfterClass(Minio.builder().build()); - minio.start(); + localstack = closeAfterClass(new LocalStackContainer("s3-latest")); + localstack.start(); String bucketName = "segments" + UUID.randomUUID(); - minio.createBucket(bucketName, true); + + try (S3Client client = createS3Client(localstack)) { + client.createBucket(CreateBucketRequest.builder().bucket(bucketName).build()); + } DistributedQueryRunner queryRunner = MemoryQueryRunner.builder() .setInitialTables(TpchTable.getTables()) @@ -75,13 +81,11 @@ protected QueryRunner createQueryRunner() Map spoolingConfig = ImmutableMap.builder() .put("fs.s3.enabled", "true") .put("fs.location", "s3://" + bucketName + "/") - // Direct storage access with encryption requires SSE-c which is not yet implemented - .put("fs.segment.encryption", "false") - .put("s3.endpoint", minio.getMinioAddress()) - .put("s3.region", MINIO_REGION) - .put("s3.aws-access-key", MINIO_ACCESS_KEY) - .put("s3.aws-secret-key", MINIO_SECRET_KEY) - .put("s3.path-style-access", "true") + .put("fs.segment.encryption", "true") + .put("s3.endpoint", localstack.getEndpointOverride(Service.S3).toString()) + .put("s3.region", localstack.getRegion()) + .put("s3.aws-access-key", localstack.getAccessKey()) + .put("s3.aws-secret-key", localstack.getSecretKey()) .putAll(spoolingConfig()) .buildKeepingLast(); runner.loadSpoolingManager("filesystem", spoolingConfig); @@ -119,4 +123,14 @@ private static String randomAES256Key() { return getEncoder().encodeToString(createRandomAesEncryptionKey().getEncoded()); } + + protected S3Client createS3Client(LocalStackContainer localstack) + { + return S3Client.builder() + .endpointOverride(localstack.getEndpointOverride(Service.S3)) + .region(Region.of(localstack.getRegion())) + .credentialsProvider(StaticCredentialsProvider.create( + AwsBasicCredentials.create(localstack.getAccessKey(), localstack.getSecretKey()))) + .build(); + } }