diff --git a/api/src/main/java/org/apache/iceberg/encryption/EncryptedKey.java b/api/src/main/java/org/apache/iceberg/encryption/EncryptedKey.java index c7e594efd147..c7e987e84b7d 100644 --- a/api/src/main/java/org/apache/iceberg/encryption/EncryptedKey.java +++ b/api/src/main/java/org/apache/iceberg/encryption/EncryptedKey.java @@ -18,10 +18,11 @@ */ package org.apache.iceberg.encryption; +import java.io.Serializable; import java.nio.ByteBuffer; import java.util.Map; -public interface EncryptedKey { +public interface EncryptedKey extends Serializable { String keyId(); ByteBuffer encryptedKeyMetadata(); diff --git a/aws/src/integration/java/org/apache/iceberg/aws/TestKeyManagementClient.java b/aws/src/integration/java/org/apache/iceberg/aws/TestKeyManagementClient.java index 83bacf2601cd..ef84b23f8f27 100644 --- a/aws/src/integration/java/org/apache/iceberg/aws/TestKeyManagementClient.java +++ b/aws/src/integration/java/org/apache/iceberg/aws/TestKeyManagementClient.java @@ -22,6 +22,7 @@ import java.nio.ByteBuffer; import java.util.Map; +import org.apache.iceberg.TestHelpers; import org.apache.iceberg.encryption.KeyManagementClient; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.junit.jupiter.api.AfterAll; @@ -31,6 +32,7 @@ import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.EnumSource; +import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.NullSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -38,6 +40,7 @@ import software.amazon.awssdk.services.kms.model.CreateKeyRequest; import software.amazon.awssdk.services.kms.model.CreateKeyResponse; import software.amazon.awssdk.services.kms.model.DataKeySpec; +import software.amazon.awssdk.services.kms.model.EncryptionAlgorithmSpec; import software.amazon.awssdk.services.kms.model.KeySpec; import software.amazon.awssdk.services.kms.model.ScheduleKeyDeletionRequest; import software.amazon.awssdk.services.kms.model.ScheduleKeyDeletionResponse; @@ -91,13 +94,42 @@ public void testKeyWrapping() { try (AwsKeyManagementClient keyManagementClient = new AwsKeyManagementClient()) { keyManagementClient.initialize(ImmutableMap.of()); - ByteBuffer key = ByteBuffer.wrap(new String("super-secret-table-master-key").getBytes()); + ByteBuffer key = ByteBuffer.wrap("super-secret-table-master-key".getBytes()); ByteBuffer encryptedKey = keyManagementClient.wrapKey(key, keyId); assertThat(keyManagementClient.unwrapKey(encryptedKey, keyId)).isEqualTo(key); } } + @ParameterizedTest + @MethodSource("org.apache.iceberg.TestHelpers#serializers") + public void testSerialization( + TestHelpers.RoundTripSerializer roundTripSerializer) + throws Exception { + try (AwsKeyManagementClient keyManagementClient = new AwsKeyManagementClient()) { + keyManagementClient.initialize( + ImmutableMap.of( + AwsProperties.KMS_ENCRYPTION_ALGORITHM_SPEC, + EncryptionAlgorithmSpec.RSAES_OAEP_SHA_256.toString(), + AwsProperties.KMS_DATA_KEY_SPEC, + DataKeySpec.AES_128.toString())); + assertThat(keyManagementClient.encryptionAlgorithmSpec()) + .isEqualTo(EncryptionAlgorithmSpec.RSAES_OAEP_SHA_256); + assertThat(keyManagementClient.dataKeySpec()).isEqualTo(DataKeySpec.AES_128); + + AwsKeyManagementClient result = roundTripSerializer.apply(keyManagementClient); + + ByteBuffer key = ByteBuffer.wrap("super-secret-table-master-key".getBytes()); + ByteBuffer encryptedKey = result.wrapKey(key, keyId); + + assertThat(keyManagementClient.unwrapKey(encryptedKey, keyId)).isEqualTo(key); + assertThat(result.unwrapKey(encryptedKey, keyId)).isEqualTo(key); + assertThat(result.encryptionAlgorithmSpec()) + .isEqualTo(EncryptionAlgorithmSpec.RSAES_OAEP_SHA_256); + assertThat(result.dataKeySpec()).isEqualTo(DataKeySpec.AES_128); + } + } + @ParameterizedTest @NullSource @EnumSource( diff --git a/aws/src/main/java/org/apache/iceberg/aws/AwsKeyManagementClient.java b/aws/src/main/java/org/apache/iceberg/aws/AwsKeyManagementClient.java index 6d2671f4e26d..3b1c13ebe36f 100644 --- a/aws/src/main/java/org/apache/iceberg/aws/AwsKeyManagementClient.java +++ b/aws/src/main/java/org/apache/iceberg/aws/AwsKeyManagementClient.java @@ -20,7 +20,9 @@ import java.nio.ByteBuffer; import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.iceberg.encryption.KeyManagementClient; +import org.apache.iceberg.util.SerializableMap; import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.services.kms.KmsClient; import software.amazon.awssdk.services.kms.model.DataKeySpec; @@ -39,14 +41,17 @@ */ public class AwsKeyManagementClient implements KeyManagementClient { - private KmsClient kmsClient; + private final AtomicBoolean isResourceClosed = new AtomicBoolean(false); + + private Map allProperties; private EncryptionAlgorithmSpec encryptionAlgorithmSpec; private DataKeySpec dataKeySpec; + private transient volatile KmsClient kmsClient; + @Override public void initialize(Map properties) { - AwsClientFactory clientFactory = AwsClientFactories.from(properties); - this.kmsClient = clientFactory.kms(); + this.allProperties = SerializableMap.copyOf(properties); AwsProperties awsProperties = new AwsProperties(properties); this.encryptionAlgorithmSpec = awsProperties.kmsEncryptionAlgorithmSpec(); @@ -62,7 +67,7 @@ public ByteBuffer wrapKey(ByteBuffer key, String wrappingKeyId) { .plaintext(SdkBytes.fromByteBuffer(key)) .build(); - EncryptResponse result = kmsClient.encrypt(request); + EncryptResponse result = kmsClient().encrypt(request); return result.ciphertextBlob().asByteBuffer(); } @@ -76,11 +81,9 @@ public KeyGenerationResult generateKey(String wrappingKeyId) { GenerateDataKeyRequest request = GenerateDataKeyRequest.builder().keyId(wrappingKeyId).keySpec(dataKeySpec).build(); - GenerateDataKeyResponse response = kmsClient.generateDataKey(request); - KeyGenerationResult result = - new KeyGenerationResult( - response.plaintext().asByteBuffer(), response.ciphertextBlob().asByteBuffer()); - return result; + GenerateDataKeyResponse response = kmsClient().generateDataKey(request); + return new KeyGenerationResult( + response.plaintext().asByteBuffer(), response.ciphertextBlob().asByteBuffer()); } @Override @@ -92,14 +95,36 @@ public ByteBuffer unwrapKey(ByteBuffer wrappedKey, String wrappingKeyId) { .ciphertextBlob(SdkBytes.fromByteBuffer(wrappedKey)) .build(); - DecryptResponse result = kmsClient.decrypt(request); + DecryptResponse result = kmsClient().decrypt(request); return result.plaintext().asByteBuffer(); } @Override public void close() { - if (kmsClient != null) { - kmsClient.close(); + if (isResourceClosed.compareAndSet(false, true)) { + if (kmsClient != null) { + kmsClient.close(); + } + } + } + + EncryptionAlgorithmSpec encryptionAlgorithmSpec() { + return encryptionAlgorithmSpec; + } + + DataKeySpec dataKeySpec() { + return dataKeySpec; + } + + private KmsClient kmsClient() { + if (kmsClient == null) { + synchronized (this) { + if (kmsClient == null) { + AwsClientFactory clientFactory = AwsClientFactories.from(allProperties); + kmsClient = clientFactory.kms(); + } + } } + return kmsClient; } } diff --git a/azure/src/integration/java/org/apache/iceberg/azure/keymanagement/TestAzureKeyManagementClient.java b/azure/src/integration/java/org/apache/iceberg/azure/keymanagement/TestAzureKeyManagementClient.java index 32adcd46b702..88b498d98fa8 100644 --- a/azure/src/integration/java/org/apache/iceberg/azure/keymanagement/TestAzureKeyManagementClient.java +++ b/azure/src/integration/java/org/apache/iceberg/azure/keymanagement/TestAzureKeyManagementClient.java @@ -27,6 +27,7 @@ import com.azure.security.keyvault.keys.models.KeyType; import java.nio.ByteBuffer; import java.time.Duration; +import org.apache.iceberg.TestHelpers; import org.apache.iceberg.encryption.KeyManagementClient; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.junit.jupiter.api.AfterAll; @@ -34,6 +35,8 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; @EnabledIfEnvironmentVariables({ @EnabledIfEnvironmentVariable(named = "AZURE_KEYVAULT_URL", matches = ".*") @@ -41,21 +44,21 @@ public class TestAzureKeyManagementClient { private static final String ICEBERG_TEST_KEY_NAME = "iceberg-test-key"; - private static KeyClient keyClient; + private static final String KEY_VAULT_URI = System.getenv("AZURE_KEYVAULT_URL"); private static KeyManagementClient azureKeyManagementClient; + private static KeyClient keyClient; @BeforeAll public static void beforeClass() { - String keyVaultUri = System.getenv("AZURE_KEYVAULT_URL"); keyClient = new KeyClientBuilder() - .vaultUrl(keyVaultUri) + .vaultUrl(KEY_VAULT_URI) .credential(new DefaultAzureCredentialBuilder().build()) .buildClient(); keyClient.createKey(ICEBERG_TEST_KEY_NAME, KeyType.RSA); azureKeyManagementClient = new AzureKeyManagementClient(); - azureKeyManagementClient.initialize(ImmutableMap.of(AZURE_KEYVAULT_URL, keyVaultUri)); + azureKeyManagementClient.initialize(ImmutableMap.of(AZURE_KEYVAULT_URL, KEY_VAULT_URI)); } @AfterAll @@ -81,4 +84,22 @@ public void keyWrapping() { public void keyGenerationNotSupported() { assertThat(azureKeyManagementClient.supportsKeyGeneration()).isFalse(); } + + @ParameterizedTest + @MethodSource("org.apache.iceberg.TestHelpers#serializers") + public void testSerialization( + TestHelpers.RoundTripSerializer roundTripSerializer) + throws Exception { + try (AzureKeyManagementClient keyManagementClient = new AzureKeyManagementClient()) { + keyManagementClient.initialize(ImmutableMap.of(AZURE_KEYVAULT_URL, KEY_VAULT_URI)); + + AzureKeyManagementClient result = roundTripSerializer.apply(keyManagementClient); + + ByteBuffer key = ByteBuffer.wrap("super-secret-table-master-key".getBytes()); + ByteBuffer encryptedKey = result.wrapKey(key, ICEBERG_TEST_KEY_NAME); + + assertThat(keyManagementClient.unwrapKey(encryptedKey, ICEBERG_TEST_KEY_NAME)).isEqualTo(key); + assertThat(result.unwrapKey(encryptedKey, ICEBERG_TEST_KEY_NAME)).isEqualTo(key); + } + } } diff --git a/azure/src/main/java/org/apache/iceberg/azure/keymanagement/AzureKeyManagementClient.java b/azure/src/main/java/org/apache/iceberg/azure/keymanagement/AzureKeyManagementClient.java index 4732d3d410c4..66bf0678bce9 100644 --- a/azure/src/main/java/org/apache/iceberg/azure/keymanagement/AzureKeyManagementClient.java +++ b/azure/src/main/java/org/apache/iceberg/azure/keymanagement/AzureKeyManagementClient.java @@ -29,40 +29,81 @@ import org.apache.iceberg.azure.AzureProperties; import org.apache.iceberg.encryption.KeyManagementClient; import org.apache.iceberg.util.ByteBuffers; +import org.apache.iceberg.util.SerializableMap; /** Azure key management client which connects to Azure Key Vault. */ public class AzureKeyManagementClient implements KeyManagementClient { - private KeyClient keyClient; - private KeyWrapAlgorithm keyWrapAlgorithm; + + private Map allProperties; + + private transient volatile ClientState state; @Override public void initialize(Map properties) { - AzureProperties azureProperties = new AzureProperties(properties); - - this.keyWrapAlgorithm = azureProperties.keyWrapAlgorithm(); - KeyClientBuilder keyClientBuilder = new KeyClientBuilder(); - azureProperties.keyVaultUrl().ifPresent(keyClientBuilder::vaultUrl); - this.keyClient = - keyClientBuilder - .credential(AdlsTokenCredentialProviders.from(properties).credential()) - .buildClient(); + this.allProperties = SerializableMap.copyOf(properties); } @Override public ByteBuffer wrapKey(ByteBuffer key, String wrappingKeyId) { WrapResult wrapResult = - keyClient + keyClient() .getCryptographyClient(wrappingKeyId) - .wrapKey(keyWrapAlgorithm, ByteBuffers.toByteArray(key)); + .wrapKey(keyWrapAlgorithm(), ByteBuffers.toByteArray(key)); return ByteBuffer.wrap(wrapResult.getEncryptedKey()); } @Override public ByteBuffer unwrapKey(ByteBuffer wrappedKey, String wrappingKeyId) { UnwrapResult unwrapResult = - keyClient + keyClient() .getCryptographyClient(wrappingKeyId) - .unwrapKey(keyWrapAlgorithm, ByteBuffers.toByteArray(wrappedKey)); + .unwrapKey(keyWrapAlgorithm(), ByteBuffers.toByteArray(wrappedKey)); return ByteBuffer.wrap(unwrapResult.getKey()); } + + private KeyClient keyClient() { + return state().keyClient(); + } + + private KeyWrapAlgorithm keyWrapAlgorithm() { + return state().keyWrapAlgorithm(); + } + + private ClientState state() { + if (state == null) { + synchronized (this) { + if (state == null) { + AzureProperties azureProperties = new AzureProperties(allProperties); + KeyClientBuilder keyClientBuilder = new KeyClientBuilder(); + azureProperties.keyVaultUrl().ifPresent(keyClientBuilder::vaultUrl); + KeyClient keyClient = + keyClientBuilder + .credential(AdlsTokenCredentialProviders.from(allProperties).credential()) + .buildClient(); + KeyWrapAlgorithm keyWrapAlgorithm = azureProperties.keyWrapAlgorithm(); + state = new ClientState(keyClient, keyWrapAlgorithm); + } + } + } + return state; + } + + private static class ClientState { + + private final KeyClient keyClient; + private final KeyWrapAlgorithm keyWrapAlgorithm; + + ClientState(KeyClient keyClient, KeyWrapAlgorithm keyWrapAlgorithm) { + this.keyClient = keyClient; + this.keyWrapAlgorithm = keyWrapAlgorithm; + } + + KeyClient keyClient() { + return keyClient; + } + + KeyWrapAlgorithm keyWrapAlgorithm() { + return keyWrapAlgorithm; + } + } } diff --git a/core/src/main/java/org/apache/iceberg/encryption/BaseEncryptedKey.java b/core/src/main/java/org/apache/iceberg/encryption/BaseEncryptedKey.java index 389613ee7937..77dea1c67ae0 100644 --- a/core/src/main/java/org/apache/iceberg/encryption/BaseEncryptedKey.java +++ b/core/src/main/java/org/apache/iceberg/encryption/BaseEncryptedKey.java @@ -21,10 +21,12 @@ import java.nio.ByteBuffer; import java.util.Map; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.util.ByteBuffers; +import org.apache.iceberg.util.SerializableMap; public class BaseEncryptedKey implements EncryptedKey { private final String keyId; - private final ByteBuffer keyMetadata; + private final byte[] keyMetadata; private final String encryptedById; private final Map properties; @@ -33,9 +35,9 @@ public BaseEncryptedKey( Preconditions.checkArgument(keyId != null, "Key id cannot be null"); Preconditions.checkArgument(keyMetadata != null, "Encrypted key metadata cannot be null"); this.keyId = keyId; - this.keyMetadata = keyMetadata; + this.keyMetadata = ByteBuffers.toByteArray(keyMetadata); this.encryptedById = encryptedById; - this.properties = properties; + this.properties = SerializableMap.copyOf(properties); } @Override @@ -45,7 +47,7 @@ public String keyId() { @Override public ByteBuffer encryptedKeyMetadata() { - return keyMetadata; + return ByteBuffer.wrap(keyMetadata); } @Override diff --git a/core/src/main/java/org/apache/iceberg/encryption/StandardEncryptionManager.java b/core/src/main/java/org/apache/iceberg/encryption/StandardEncryptionManager.java index 043f21728b30..bb5126d23e3f 100644 --- a/core/src/main/java/org/apache/iceberg/encryption/StandardEncryptionManager.java +++ b/core/src/main/java/org/apache/iceberg/encryption/StandardEncryptionManager.java @@ -34,6 +34,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.Iterables; import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.util.ByteBuffers; +import org.apache.iceberg.util.SerializableMap; public class StandardEncryptionManager implements EncryptionManager { // Maximal lifespan of key encryption keys is 2 years according to NIST SP 800-57 (PART 1 REV. 5, @@ -43,41 +44,13 @@ public class StandardEncryptionManager implements EncryptionManager { private final String tableKeyId; private final int dataKeyLength; + private final Map encryptionKeys; + private final KeyManagementClient kmsClient; // used in key encryption key rotation unitests private long testTimeShift; - // unserializable elements of the EncryptionManager - private class TransientEncryptionState { - private final KeyManagementClient kmsClient; - private final Map encryptionKeys; - private final LoadingCache unwrappedKeyCache; - - private TransientEncryptionState(KeyManagementClient kmsClient, List keys) { - this.kmsClient = kmsClient; - this.encryptionKeys = Maps.newLinkedHashMap(); - - if (keys != null) { - for (EncryptedKey key : keys) { - encryptionKeys.put( - key.keyId(), - new BaseEncryptedKey( - key.keyId(), key.encryptedKeyMetadata(), key.encryptedById(), key.properties())); - } - } - - this.unwrappedKeyCache = - Caffeine.newBuilder() - .expireAfterWrite(1, TimeUnit.HOURS) - .build( - keyId -> - kmsClient.unwrapKey( - encryptionKeys.get(keyId).encryptedKeyMetadata(), tableKeyId)); - } - } - - private final transient TransientEncryptionState transientState; - + private transient volatile LoadingCache unwrappedKeyCache; private transient volatile SecureRandom lazyRNG = null; /** @@ -107,9 +80,19 @@ public StandardEncryptionManager( dataKeyLength); Preconditions.checkNotNull(kmsClient, "Invalid KMS client: null"); this.tableKeyId = tableKeyId; - this.transientState = new TransientEncryptionState(kmsClient, keys); + this.kmsClient = kmsClient; this.dataKeyLength = dataKeyLength; this.testTimeShift = 0; + + this.encryptionKeys = SerializableMap.copyOf(Maps.newLinkedHashMap()); + if (keys != null) { + for (EncryptedKey key : keys) { + this.encryptionKeys.put( + key.keyId(), + new BaseEncryptedKey( + key.keyId(), key.encryptedKeyMetadata(), key.encryptedById(), key.properties())); + } + } } @Override @@ -132,6 +115,20 @@ public Iterable decrypt(Iterable encrypted) { return Iterables.transform(encrypted, this::decrypt); } + private LoadingCache unwrappedKeyCache() { + if (this.unwrappedKeyCache == null) { + this.unwrappedKeyCache = + Caffeine.newBuilder() + .expireAfterWrite(1, TimeUnit.HOURS) + .build( + keyId -> + kmsClient.unwrapKey( + encryptionKeys.get(keyId).encryptedKeyMetadata(), tableKeyId)); + } + + return unwrappedKeyCache; + } + private SecureRandom workerRNG() { if (this.lazyRNG == null) { this.lazyRNG = new SecureRandom(); @@ -145,11 +142,7 @@ private SecureRandom workerRNG() { */ @Deprecated public ByteBuffer wrapKey(ByteBuffer secretKey) { - Preconditions.checkState( - transientState != null, - "Cannot wrap key after called after serialization (missing KMS client)"); - - return transientState.kmsClient.wrapKey(secretKey, tableKeyId); + return kmsClient.wrapKey(secretKey, tableKeyId); } /** @@ -157,25 +150,17 @@ public ByteBuffer wrapKey(ByteBuffer secretKey) { */ @Deprecated public ByteBuffer unwrapKey(ByteBuffer wrappedSecretKey) { - Preconditions.checkState(transientState != null, "Cannot unwrap key after serialization"); - - return transientState.kmsClient.unwrapKey(wrappedSecretKey, tableKeyId); + return kmsClient.unwrapKey(wrappedSecretKey, tableKeyId); } Map encryptionKeys() { - Preconditions.checkState( - transientState != null, "Cannot return the encryption keys after serialization"); - - return transientState.encryptionKeys; + return encryptionKeys; } String keyEncryptionKeyID() { - Preconditions.checkState( - transientState != null, "Cannot return the current key after serialization"); - // Find unexpired key encryption key - for (String keyID : transientState.encryptionKeys.keySet()) { - EncryptedKey key = transientState.encryptionKeys.get(keyID); + for (String keyID : encryptionKeys.keySet()) { + EncryptedKey key = encryptionKeys.get(keyID); if (key.encryptedById().equals(tableKeyId)) { // this is a key encryption key String timestampProperty = key.properties().get(KEY_TIMESTAMP); long keyTimestamp = Long.parseLong(timestampProperty); @@ -187,14 +172,14 @@ String keyEncryptionKeyID() { // No unexpired key encryption keys; create one ByteBuffer unwrapped = newKey(); - ByteBuffer wrapped = transientState.kmsClient.wrapKey(unwrapped, tableKeyId); + ByteBuffer wrapped = kmsClient.wrapKey(unwrapped, tableKeyId); Map properties = Maps.newHashMap(); properties.put(KEY_TIMESTAMP, "" + currentTimeMillis()); EncryptedKey key = new BaseEncryptedKey(generateKeyId(), wrapped, tableKeyId, properties); // update internal tracking - transientState.unwrappedKeyCache.put(key.keyId(), unwrapped); - transientState.encryptionKeys.put(key.keyId(), key); + unwrappedKeyCache().put(key.keyId(), unwrapped); + encryptionKeys.put(key.keyId(), key); return key.keyId(); } @@ -209,10 +194,7 @@ private long currentTimeMillis() { } ByteBuffer encryptedByKey(String manifestListKeyID) { - Preconditions.checkState( - transientState != null, "Cannot find key encryption key after serialization"); - - EncryptedKey encryptedKeyMetadata = transientState.encryptionKeys.get(manifestListKeyID); + EncryptedKey encryptedKeyMetadata = encryptionKeys.get(manifestListKeyID); Preconditions.checkState( encryptedKeyMetadata != null, @@ -224,25 +206,21 @@ ByteBuffer encryptedByKey(String manifestListKeyID) { "%s is a key encryption key, not manifest list key metadata", manifestListKeyID); - return transientState.unwrappedKeyCache.get(encryptedKeyMetadata.encryptedById()); + return unwrappedKeyCache().get(encryptedKeyMetadata.encryptedById()); } public String addManifestListKeyMetadata(NativeEncryptionKeyMetadata keyMetadata) { - Preconditions.checkState(transientState != null, "Cannot add key metadata after serialization"); - String manifestListKeyID = generateKeyId(); String keyEncryptionKeyID = keyEncryptionKeyID(); String keyEncryptionKeyTimestamp = - transientState.encryptionKeys.get(keyEncryptionKeyID).properties().get(KEY_TIMESTAMP); + encryptionKeys.get(keyEncryptionKeyID).properties().get(KEY_TIMESTAMP); ByteBuffer encryptedKeyMetadata = EncryptionUtil.encryptManifestListKeyMetadata( - transientState.unwrappedKeyCache.get(keyEncryptionKeyID), - keyEncryptionKeyTimestamp, - keyMetadata); + unwrappedKeyCache().get(keyEncryptionKeyID), keyEncryptionKeyTimestamp, keyMetadata); BaseEncryptedKey key = new BaseEncryptedKey(manifestListKeyID, encryptedKeyMetadata, keyEncryptionKeyID, null); - transientState.encryptionKeys.put(key.keyId(), key); + encryptionKeys.put(key.keyId(), key); return manifestListKeyID; } diff --git a/core/src/test/java/org/apache/iceberg/encryption/TestBaseEncryptedKeySerialization.java b/core/src/test/java/org/apache/iceberg/encryption/TestBaseEncryptedKeySerialization.java new file mode 100644 index 000000000000..f6219183ff5f --- /dev/null +++ b/core/src/test/java/org/apache/iceberg/encryption/TestBaseEncryptedKeySerialization.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.encryption; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import org.apache.iceberg.TestHelpers; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +public class TestBaseEncryptedKeySerialization { + + @ParameterizedTest + @MethodSource("org.apache.iceberg.TestHelpers#serializers") + public void testSerialization(TestHelpers.RoundTripSerializer roundTripSerializer) + throws Exception { + byte[] keyBytes = "key".getBytes(StandardCharsets.UTF_8); + EncryptedKey key = + new BaseEncryptedKey("a", ByteBuffer.wrap(keyBytes), "b", Map.of("test", "value")); + + EncryptedKey result = roundTripSerializer.apply(key); + + assertThat(result.keyId()).isEqualTo(key.keyId()); + assertThat(result.encryptedById()).isEqualTo(key.encryptedById()); + assertThat(result.encryptedKeyMetadata()).isEqualTo(key.encryptedKeyMetadata()); + assertThat(result.properties()).isEqualTo(key.properties()); + } +} diff --git a/gcp/src/integration/java/org/apache/iceberg/gcp/TestKeyManagementClient.java b/gcp/src/integration/java/org/apache/iceberg/gcp/TestKeyManagementClient.java index 1e02013b3a0e..f1d954ea62c0 100644 --- a/gcp/src/integration/java/org/apache/iceberg/gcp/TestKeyManagementClient.java +++ b/gcp/src/integration/java/org/apache/iceberg/gcp/TestKeyManagementClient.java @@ -35,9 +35,12 @@ import java.nio.ByteBuffer; import java.util.Map; import java.util.UUID; +import org.apache.iceberg.TestHelpers; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; public abstract class TestKeyManagementClient { @@ -106,10 +109,30 @@ public void testKeyWrapping() { try (GcpKeyManagementClient keyManagementClient = new GcpKeyManagementClient(); ) { keyManagementClient.initialize(properties()); - ByteBuffer key = ByteBuffer.wrap(new String("super-secret-table-master-key").getBytes()); + ByteBuffer key = ByteBuffer.wrap("super-secret-table-master-key".getBytes()); ByteBuffer encryptedKey = keyManagementClient.wrapKey(key, keyname); assertThat(keyManagementClient.unwrapKey(encryptedKey, keyname)).isEqualTo(key); } } + + @ParameterizedTest + @MethodSource("org.apache.iceberg.TestHelpers#serializers") + public void testSerialization( + TestHelpers.RoundTripSerializer roundTripSerializer) + throws Exception { + String keyname = CryptoKeyName.of(projectId, LOCATION, KEY_RING_ID, keyId).toString(); + + try (GcpKeyManagementClient keyManagementClient = new GcpKeyManagementClient(); ) { + keyManagementClient.initialize(properties()); + + GcpKeyManagementClient result = roundTripSerializer.apply(keyManagementClient); + + ByteBuffer key = ByteBuffer.wrap("super-secret-table-master-key".getBytes()); + ByteBuffer encryptedKey = result.wrapKey(key, keyname); + + assertThat(keyManagementClient.unwrapKey(encryptedKey, keyname)).isEqualTo(key); + assertThat(result.unwrapKey(encryptedKey, keyname)).isEqualTo(key); + } + } } diff --git a/gcp/src/main/java/org/apache/iceberg/gcp/GcpKeyManagementClient.java b/gcp/src/main/java/org/apache/iceberg/gcp/GcpKeyManagementClient.java index ff18d5234d2f..22161f9e0a64 100644 --- a/gcp/src/main/java/org/apache/iceberg/gcp/GcpKeyManagementClient.java +++ b/gcp/src/main/java/org/apache/iceberg/gcp/GcpKeyManagementClient.java @@ -29,10 +29,13 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; import org.apache.iceberg.common.DynClasses; import org.apache.iceberg.common.DynMethods; import org.apache.iceberg.encryption.KeyManagementClient; +import org.apache.iceberg.exceptions.RuntimeIOException; import org.apache.iceberg.io.CloseableGroup; +import org.apache.iceberg.util.SerializableMap; /** * Key management client implementation that uses Google Cloud Key Management. To be used for @@ -45,31 +48,17 @@ */ public class GcpKeyManagementClient implements KeyManagementClient { - private KeyManagementServiceClient kmsClient; - private CloseableGroup closeableGroup = new CloseableGroup(); + private final AtomicBoolean isResourceClosed = new AtomicBoolean(false); + + private Map allProperties; + + private transient volatile KeyManagementServiceClient kmsClient; + private transient volatile CloseableGroup closeableGroup = new CloseableGroup(); @Override public void initialize(Map properties) { + this.allProperties = SerializableMap.copyOf(properties); this.closeableGroup = new CloseableGroup(); - closeableGroup.setSuppressCloseFailure(true); - - GCPProperties gcpProperties = new GCPProperties(properties); - - try { - KeyManagementServiceSettings.Builder kmsBuilder = KeyManagementServiceSettings.newBuilder(); - if (gcpProperties.oauth2Token().isPresent()) { - OAuth2Credentials oAuth2Credentials = - GCPAuthUtils.oauth2CredentialsFromGcpProperties(gcpProperties, closeableGroup); - kmsBuilder.setCredentialsProvider(FixedCredentialsProvider.create(oAuth2Credentials)); - } - - // if not OAuth then defaults to GoogleCredentials.getApplicationDefault() - this.kmsClient = KeyManagementServiceClient.create(kmsBuilder.build()); - closeableGroup.addCloseable(kmsClient); - - } catch (IOException e) { - throw new RuntimeException("Failed to create GCP cloud KMS service client", e); - } } @Override @@ -78,7 +67,7 @@ public ByteBuffer wrapKey(ByteBuffer key, String wrappingKeyId) { requestBuilder = ByteStringShim.setPlainText(requestBuilder, key); EncryptRequest encryptRequest = requestBuilder.build(); - EncryptResponse encryptResponse = kmsClient.encrypt(encryptRequest); + EncryptResponse encryptResponse = kmsClient().encrypt(encryptRequest); // need ByteString.copyFrom() leaves the BB in an end position, need to reset key.position(0); @@ -91,7 +80,7 @@ public ByteBuffer unwrapKey(ByteBuffer wrappedKey, String wrappingKeyId) { requestBuilder = ByteStringShim.setCipherText(requestBuilder, wrappedKey); DecryptRequest decryptRequest = requestBuilder.build(); - DecryptResponse decryptResponse = kmsClient.decrypt(decryptRequest); + DecryptResponse decryptResponse = kmsClient().decrypt(decryptRequest); // need ByteString.copyFrom() leaves the BB in an end position, need to reset wrappedKey.position(0); @@ -100,11 +89,43 @@ public ByteBuffer unwrapKey(ByteBuffer wrappedKey, String wrappingKeyId) { @Override public void close() { - try { - closeableGroup.close(); - } catch (IOException ioe) { - // closure exceptions already suppressed and logged in closeableGroup + if (isResourceClosed.compareAndSet(false, true)) { + if (closeableGroup != null) { + closeableGroup.setSuppressCloseFailure(true); + try { + closeableGroup.close(); + } catch (IOException ioe) { + // closure exceptions already suppressed and logged in closeableGroup + } + } + } + } + + private KeyManagementServiceClient kmsClient() { + if (kmsClient == null) { + synchronized (this) { + if (kmsClient == null) { + GCPProperties gcpProperties = new GCPProperties(allProperties); + try { + KeyManagementServiceSettings.Builder kmsBuilder = + KeyManagementServiceSettings.newBuilder(); + if (gcpProperties.oauth2Token().isPresent()) { + OAuth2Credentials oAuth2Credentials = + GCPAuthUtils.oauth2CredentialsFromGcpProperties(gcpProperties, closeableGroup); + kmsBuilder.setCredentialsProvider(FixedCredentialsProvider.create(oAuth2Credentials)); + } + + // if not OAuth then defaults to GoogleCredentials.getApplicationDefault() + this.kmsClient = KeyManagementServiceClient.create(kmsBuilder.build()); + closeableGroup.addCloseable(kmsClient); + + } catch (IOException e) { + throw new RuntimeIOException(e, "Failed to create GCP cloud KMS service client"); + } + } + } } + return kmsClient; } private static final class ByteStringShim { diff --git a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestTableEncryption.java b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestTableEncryption.java index 905516bff92c..a38506d621f9 100644 --- a/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestTableEncryption.java +++ b/spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/sql/TestTableEncryption.java @@ -331,6 +331,25 @@ public void testManifestEncryption() throws IOException { } } + @TestTemplate + public void testDropTableWithPurge() { + List dataFileTable = + sql("SELECT file_path FROM %s.%s", tableName, MetadataTableType.ALL_DATA_FILES); + List dataFiles = + Streams.concat(dataFileTable.stream()) + .map(row -> (String) row[0]) + .collect(Collectors.toList()); + assertThat(dataFiles).isNotEmpty(); + assertThat(dataFiles) + .allSatisfy(filePath -> assertThat(localInput(filePath).exists()).isTrue()); + + sql("DROP TABLE %s PURGE", tableName); + + assertThat(catalog.tableExists(tableIdent)).as("Table should not exist").isFalse(); + assertThat(dataFiles) + .allSatisfy(filePath -> assertThat(localInput(filePath).exists()).isFalse()); + } + private void checkMetadataFileEncryption(InputFile file) throws IOException { SeekableInputStream stream = file.newStream(); byte[] magic = new byte[4]; diff --git a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/sql/TestTableEncryption.java b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/sql/TestTableEncryption.java index 905516bff92c..a38506d621f9 100644 --- a/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/sql/TestTableEncryption.java +++ b/spark/v4.1/spark/src/test/java/org/apache/iceberg/spark/sql/TestTableEncryption.java @@ -331,6 +331,25 @@ public void testManifestEncryption() throws IOException { } } + @TestTemplate + public void testDropTableWithPurge() { + List dataFileTable = + sql("SELECT file_path FROM %s.%s", tableName, MetadataTableType.ALL_DATA_FILES); + List dataFiles = + Streams.concat(dataFileTable.stream()) + .map(row -> (String) row[0]) + .collect(Collectors.toList()); + assertThat(dataFiles).isNotEmpty(); + assertThat(dataFiles) + .allSatisfy(filePath -> assertThat(localInput(filePath).exists()).isTrue()); + + sql("DROP TABLE %s PURGE", tableName); + + assertThat(catalog.tableExists(tableIdent)).as("Table should not exist").isFalse(); + assertThat(dataFiles) + .allSatisfy(filePath -> assertThat(localInput(filePath).exists()).isFalse()); + } + private void checkMetadataFileEncryption(InputFile file) throws IOException { SeekableInputStream stream = file.newStream(); byte[] magic = new byte[4];