diff --git a/api/src/main/java/org/apache/iceberg/encryption/EncryptingFileIO.java b/api/src/main/java/org/apache/iceberg/encryption/EncryptingFileIO.java
index 0203361844a5..fb86769173a0 100644
--- a/api/src/main/java/org/apache/iceberg/encryption/EncryptingFileIO.java
+++ b/api/src/main/java/org/apache/iceberg/encryption/EncryptingFileIO.java
@@ -28,9 +28,11 @@
 import org.apache.iceberg.DataFile;
 import org.apache.iceberg.DeleteFile;
 import org.apache.iceberg.ManifestFile;
+import org.apache.iceberg.exceptions.RuntimeIOException;
 import org.apache.iceberg.io.FileIO;
 import org.apache.iceberg.io.InputFile;
 import org.apache.iceberg.io.OutputFile;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
@@ -109,14 +111,28 @@ public InputFile newInputFile(ManifestFile manifest) {
     }
   }
 
+  /**
+   * @deprecated will be removed in 2.0.0. use {@link #newDecryptingInputFile(String, long,
+   *     ByteBuffer)} instead.
+   */
+  @Deprecated
   public InputFile newDecryptingInputFile(String path, ByteBuffer buffer) {
-    return em.decrypt(wrap(io.newInputFile(path), buffer));
+    throw new RuntimeIOException("Deprecated API. File decryption without length is not safe");
   }
 
   public InputFile newDecryptingInputFile(String path, long length, ByteBuffer buffer) {
-    // TODO: is the length correct for the encrypted file? It may be the length of the plaintext
-    // stream
-    return em.decrypt(wrap(io.newInputFile(path, length), buffer));
+    Preconditions.checkArgument(
+        length > 0, "Cannot safely decrypt file %s because its length is not specified", path);
+
+    InputFile inputFile = io.newInputFile(path, length);
+
+    if (inputFile.getLength() != length) {
+      throw new RuntimeIOException(
+          "Cannot safely decrypt a file because its size was changed by FileIO %s from %s to %s",
+          io.getClass(), length, inputFile.getLength());
+    }
+
+    return em.decrypt(wrap(inputFile, buffer));
   }
 
   @Override
@@ -157,7 +173,7 @@ private static SimpleEncryptedInputFile wrap(InputFile encryptedInputFile, ByteBuffer buffer) {
   }
 
   private static EncryptionKeyMetadata toKeyMetadata(ByteBuffer buffer) {
-    return buffer != null ? new SimpleKeyMetadata(buffer) : EmptyKeyMetadata.get();
+    return buffer != null ? new SimpleKeyMetadata(buffer) : EncryptionKeyMetadata.empty();
   }
 
   private static class SimpleEncryptedInputFile implements EncryptedInputFile {
@@ -198,22 +214,4 @@ public EncryptionKeyMetadata copy() {
       return new SimpleKeyMetadata(metadataBuffer.duplicate());
     }
   }
-
-  private static class EmptyKeyMetadata implements EncryptionKeyMetadata {
-    private static final EmptyKeyMetadata INSTANCE = new EmptyKeyMetadata();
-
-    private static EmptyKeyMetadata get() {
-      return INSTANCE;
-    }
-
-    @Override
-    public ByteBuffer buffer() {
-      return null;
-    }
-
-    @Override
-    public EncryptionKeyMetadata copy() {
-      return this;
-    }
-  }
 }
diff --git a/core/src/main/java/org/apache/iceberg/TableMetadata.java b/core/src/main/java/org/apache/iceberg/TableMetadata.java
index 9587c57a0fd2..512bedde00ca 100644
--- a/core/src/main/java/org/apache/iceberg/TableMetadata.java
+++ b/core/src/main/java/org/apache/iceberg/TableMetadata.java
@@ -29,6 +29,7 @@
 import java.util.function.Predicate;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
+import org.apache.iceberg.encryption.KeyEncryptionKey;
 import org.apache.iceberg.exceptions.ValidationException;
 import org.apache.iceberg.relocated.com.google.common.base.MoreObjects;
 import org.apache.iceberg.relocated.com.google.common.base.Objects;
@@ -260,6 +261,7 @@ public String toString() {
   private final List<PartitionStatisticsFile> partitionStatisticsFiles;
   private final List<MetadataUpdate> changes;
   private SerializableSupplier<List<Snapshot>> snapshotsSupplier;
+  private Map<String, KeyEncryptionKey> kekCache;
   private volatile List<Snapshot> snapshots;
   private volatile Map<Long, Snapshot> snapshotsById;
   private volatile Map<String, SnapshotRef> refs;
@@ -512,6 +514,14 @@ public List<Snapshot> snapshots() {
     return snapshots;
   }
 
+  public void setKekCache(Map<String, KeyEncryptionKey> kekCache) {
+    this.kekCache = kekCache;
+  }
+
+  public Map<String, KeyEncryptionKey> kekCache() {
+    return kekCache;
+  }
+
   private synchronized void ensureSnapshotsLoaded() {
     if (!snapshotsLoaded) {
       List<Snapshot> loadedSnapshots = Lists.newArrayList(snapshotsSupplier.get());
diff --git a/core/src/main/java/org/apache/iceberg/TableMetadataParser.java b/core/src/main/java/org/apache/iceberg/TableMetadataParser.java
index 8bda184142cd..eab9aceef110 100644
--- a/core/src/main/java/org/apache/iceberg/TableMetadataParser.java
+++ b/core/src/main/java/org/apache/iceberg/TableMetadataParser.java
@@ -34,6 +34,8 @@
 import java.util.zip.GZIPOutputStream;
 import org.apache.iceberg.TableMetadata.MetadataLogEntry;
 import org.apache.iceberg.TableMetadata.SnapshotLogEntry;
+import org.apache.iceberg.encryption.EncryptionUtil;
+import org.apache.iceberg.encryption.KeyEncryptionKey;
 import org.apache.iceberg.exceptions.RuntimeIOException;
 import org.apache.iceberg.io.FileIO;
 import org.apache.iceberg.io.InputFile;
@@ -42,6 +44,7 @@
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.iceberg.util.JsonUtil;
 
 public class TableMetadataParser {
@@ -104,6 +107,9 @@ private TableMetadataParser() {}
   static final String REFS = "refs";
   static final String SNAPSHOTS = "snapshots";
   static final String SNAPSHOT_ID = "snapshot-id";
+  static final String KEK_CACHE = "kek-cache";
+  static final String KEK_ID = "kek-id";
+  static final String KEK_WRAP = "kek-wrap";
   static final String TIMESTAMP_MS = "timestamp-ms";
   static final String SNAPSHOT_LOG = "snapshot-log";
   static final String METADATA_FILE = "metadata-file";
@@ -220,6 +226,18 @@ public static void toJson(TableMetadata metadata, JsonGenerator generator) throws IOException {
     toJson(metadata.refs(), generator);
 
+    if (metadata.kekCache() != null && !metadata.kekCache().isEmpty()) {
+      generator.writeArrayFieldStart(KEK_CACHE);
+      for (Map.Entry<String, KeyEncryptionKey> entry : metadata.kekCache().entrySet()) {
+        generator.writeStartObject();
+        generator.writeStringField(KEK_ID, entry.getKey());
+        generator.writeStringField(KEK_WRAP, entry.getValue().wrappedKey());
+        generator.writeNumberField(TIMESTAMP_MS, entry.getValue().timestamp());
+        generator.writeEndObject();
+      }
+      generator.writeEndArray();
+    }
+
     generator.writeArrayFieldStart(SNAPSHOTS);
     for (Snapshot snapshot : metadata.snapshots()) {
       SnapshotParser.toJson(snapshot, generator);
@@ -277,7 +295,11 @@ public static TableMetadata read(FileIO io, InputFile file) {
     Codec codec = Codec.fromFileName(file.location());
     try (InputStream is =
         codec == Codec.GZIP ? new GZIPInputStream(file.newStream()) : file.newStream()) {
-      return fromJson(file, JsonUtil.mapper().readValue(is, JsonNode.class));
+      TableMetadata tableMetadata = fromJson(file, JsonUtil.mapper().readValue(is, JsonNode.class));
+      if (tableMetadata.kekCache() != null) {
+        EncryptionUtil.getKekCacheFromMetadata(io, tableMetadata.kekCache());
+      }
+      return tableMetadata;
     } catch (IOException e) {
       throw new RuntimeIOException(e, "Failed to read file: %s", file);
     }
@@ -466,6 +488,23 @@ public static TableMetadata fromJson(String metadataLocation, JsonNode node) {
       refs = ImmutableMap.of();
     }
 
+    Map<String, KeyEncryptionKey> kekCache = null;
+    if (node.has(KEK_CACHE)) {
+      kekCache = Maps.newHashMap();
+      Iterator<JsonNode> cacheIterator = node.get(KEK_CACHE).elements();
+      while (cacheIterator.hasNext()) {
+        JsonNode entryNode = cacheIterator.next();
+        String kekID = JsonUtil.getString(KEK_ID, entryNode);
+        kekCache.put(
+            kekID,
+            new KeyEncryptionKey(
+                kekID,
+                null, // key will be unwrapped later
+                JsonUtil.getString(KEK_WRAP, entryNode),
+                JsonUtil.getLong(TIMESTAMP_MS, entryNode)));
+      }
+    }
+
     List<Snapshot> snapshots;
     if (node.has(SNAPSHOTS)) {
       JsonNode snapshotArray = JsonUtil.get(SNAPSHOTS, node);
@@ -519,31 +558,38 @@ public static TableMetadata fromJson(String metadataLocation, JsonNode node) {
       }
     }
 
-    return new TableMetadata(
-        metadataLocation,
-        formatVersion,
-        uuid,
-        location,
-        lastSequenceNumber,
-        lastUpdatedMillis,
-        lastAssignedColumnId,
-        currentSchemaId,
-        schemas,
-        defaultSpecId,
-        specs,
-        lastAssignedPartitionId,
-        defaultSortOrderId,
-        sortOrders,
-        properties,
-        currentSnapshotId,
-        snapshots,
-        null,
-        entries.build(),
-        metadataEntries.build(),
-        refs,
-        statisticsFiles,
-        partitionStatisticsFiles,
-        ImmutableList.of() /* no changes from the file */);
+    TableMetadata result =
+        new TableMetadata(
+            metadataLocation,
+            formatVersion,
+            uuid,
+            location,
+            lastSequenceNumber,
+            lastUpdatedMillis,
+            lastAssignedColumnId,
+            currentSchemaId,
+            schemas,
+            defaultSpecId,
+            specs,
+            lastAssignedPartitionId,
+            defaultSortOrderId,
+            sortOrders,
+            properties,
+            currentSnapshotId,
+            snapshots,
+            null,
+            entries.build(),
+            metadataEntries.build(),
+            refs,
+            statisticsFiles,
+            partitionStatisticsFiles,
+            ImmutableList.of() /* no changes from the file */);
+
+    if (kekCache != null) {
+      result.setKekCache(kekCache);
+    }
+
+    return result;
   }
 
   private static Map<String, SnapshotRef> refsFromJson(JsonNode refMap) {
diff --git a/core/src/main/java/org/apache/iceberg/encryption/AesGcmOutputStream.java b/core/src/main/java/org/apache/iceberg/encryption/AesGcmOutputStream.java
index da437b7540db..0922a67622c9 100644
--- a/core/src/main/java/org/apache/iceberg/encryption/AesGcmOutputStream.java
+++ b/core/src/main/java/org/apache/iceberg/encryption/AesGcmOutputStream.java
@@ -117,6 +117,10 @@ public void flush() throws IOException {
 
   @Override
   public void close() throws IOException {
+    if (isClosed) {
+      return;
+    }
+
     if (!isHeaderWritten) {
       writeHeader();
     }
diff --git a/core/src/main/java/org/apache/iceberg/encryption/KeyManagementClient.java b/core/src/main/java/org/apache/iceberg/encryption/KeyManagementClient.java
index a7fb494cc8e1..6f834c69ed86 100644
--- a/core/src/main/java/org/apache/iceberg/encryption/KeyManagementClient.java
+++ b/core/src/main/java/org/apache/iceberg/encryption/KeyManagementClient.java
@@ -24,7 +24,7 @@
 import java.util.Map;
 
 /** A minimum client interface to connect to a key management service (KMS). */
-interface KeyManagementClient extends Serializable, Closeable {
+public interface KeyManagementClient extends Serializable, Closeable {
 
   /**
    * Wrap a secret key, using a wrapping/master key which is stored in KMS and referenced by an ID.
diff --git a/hive-metastore/src/main/java/org/apache/iceberg/hive/HiveCatalog.java b/hive-metastore/src/main/java/org/apache/iceberg/hive/HiveCatalog.java
index b4f49e29fc49..02525fd34af1 100644
--- a/hive-metastore/src/main/java/org/apache/iceberg/hive/HiveCatalog.java
+++ b/hive-metastore/src/main/java/org/apache/iceberg/hive/HiveCatalog.java
@@ -43,6 +43,8 @@
 import org.apache.iceberg.catalog.Namespace;
 import org.apache.iceberg.catalog.SupportsNamespaces;
 import org.apache.iceberg.catalog.TableIdentifier;
+import org.apache.iceberg.encryption.EncryptionUtil;
+import org.apache.iceberg.encryption.KeyManagementClient;
 import org.apache.iceberg.exceptions.NamespaceNotEmptyException;
 import org.apache.iceberg.exceptions.NoSuchNamespaceException;
 import org.apache.iceberg.exceptions.NoSuchTableException;
@@ -56,6 +58,7 @@
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.iceberg.util.LocationUtil;
+import org.apache.iceberg.util.PropertyUtil;
 import org.apache.thrift.TException;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -76,9 +79,11 @@ public class HiveCatalog extends BaseMetastoreCatalog implements SupportsNamespaces, Configurable {
   private String name;
   private Configuration conf;
   private FileIO fileIO;
+  private KeyManagementClient keyManagementClient;
   private ClientPool<IMetaStoreClient, TException> clients;
   private boolean listAllTables = false;
   private Map<String, String> catalogProperties;
+  private long writerKekTimeout;
 
   public HiveCatalog() {}
 
@@ -110,6 +115,15 @@ public void initialize(String inputName, Map<String, String> properties) {
             ? new HadoopFileIO(conf)
             : CatalogUtil.loadFileIO(fileIOImpl, properties, conf);
 
+    if (catalogProperties.containsKey(CatalogProperties.ENCRYPTION_KMS_IMPL)) {
+      this.keyManagementClient = EncryptionUtil.createKmsClient(properties);
+      this.writerKekTimeout =
+          PropertyUtil.propertyAsLong(
+              properties,
+              CatalogProperties.WRITER_KEK_TIMEOUT_MS,
+              CatalogProperties.WRITER_KEK_TIMEOUT_MS_DEFAULT);
+    }
+
     this.clients = new CachedClientPool(conf, properties);
   }
 
@@ -512,7 +526,8 @@ private boolean isValidateNamespace(Namespace namespace) {
   public TableOperations newTableOps(TableIdentifier tableIdentifier) {
     String dbName = tableIdentifier.namespace().level(0);
     String tableName = tableIdentifier.name();
-    return new HiveTableOperations(conf, clients, fileIO, name, dbName, tableName);
+    return new HiveTableOperations(
+        conf, clients, fileIO, keyManagementClient, name, dbName, tableName, writerKekTimeout);
   }
 
   @Override
diff --git a/hive-metastore/src/main/java/org/apache/iceberg/hive/HiveTableOperations.java b/hive-metastore/src/main/java/org/apache/iceberg/hive/HiveTableOperations.java
index 64f091385297..793d05a3bd46 100644
--- a/hive-metastore/src/main/java/org/apache/iceberg/hive/HiveTableOperations.java
+++ b/hive-metastore/src/main/java/org/apache/iceberg/hive/HiveTableOperations.java
@@ -45,6 +45,12 @@
 import org.apache.iceberg.SortOrderParser;
 import org.apache.iceberg.TableMetadata;
 import org.apache.iceberg.TableProperties;
+import org.apache.iceberg.encryption.EncryptingFileIO;
+import org.apache.iceberg.encryption.EncryptionManager;
+import org.apache.iceberg.encryption.EncryptionUtil;
+import org.apache.iceberg.encryption.KeyManagementClient;
+import org.apache.iceberg.encryption.PlaintextEncryptionManager;
+import org.apache.iceberg.encryption.StandardEncryptionManager;
 import org.apache.iceberg.exceptions.AlreadyExistsException;
 import org.apache.iceberg.exceptions.CommitFailedException;
 import org.apache.iceberg.exceptions.CommitStateUnknownException;
@@ -106,8 +112,17 @@ public static String translateToIcebergProp(String hmsProp) {
   private final long maxHiveTablePropertySize;
   private final int metadataRefreshMaxRetries;
   private final FileIO fileIO;
+  private final KeyManagementClient keyManagementClient;
   private final ClientPool<IMetaStoreClient, TException> metaClients;
+  private final long writerKekTimeout;
+  private EncryptionManager encryptionManager;
+  private EncryptingFileIO encryptingFileIO;
+  private boolean encryptedTable;
+  private String encryptionKeyId;
+  private int encryptionDekLength;
+
+  @VisibleForTesting
   protected HiveTableOperations(
       Configuration conf,
       ClientPool<IMetaStoreClient, TException> metaClients,
       FileIO fileIO,
@@ -115,9 +130,22 @@ protected HiveTableOperations(
       String catalogName,
       String database,
       String table) {
+    this(conf, metaClients, fileIO, null, catalogName, database, table, 0L);
+  }
+
+  protected HiveTableOperations(
+      Configuration conf,
+      ClientPool<IMetaStoreClient, TException> metaClients,
+      FileIO fileIO,
+      KeyManagementClient keyManagementClient,
+      String catalogName,
+      String database,
+      String table,
+      long writerKekTimeout) {
     this.conf = conf;
     this.metaClients = metaClients;
     this.fileIO = fileIO;
+    this.keyManagementClient = keyManagementClient;
     this.fullName = catalogName + "." + database + "." + table;
     this.catalogName = catalogName;
     this.database = database;
@@ -128,6 +156,8 @@ protected HiveTableOperations(
             HIVE_ICEBERG_METADATA_REFRESH_MAX_RETRIES_DEFAULT);
     this.maxHiveTablePropertySize =
         conf.getLong(HIVE_TABLE_PROPERTY_MAX_SIZE, HIVE_TABLE_PROPERTY_MAX_SIZE_DEFAULT);
+    this.encryptedTable = false;
+    this.writerKekTimeout = writerKekTimeout;
   }
 
   @Override
@@ -137,7 +167,47 @@ protected String tableName() {
 
   @Override
   public FileIO io() {
-    return fileIO;
+    if (encryptionManager == null) {
+      encryptionManager = encryption();
+    }
+
+    if (!encryptedTable) {
+      return fileIO;
+    }
+
+    if (encryptingFileIO != null) {
+      return encryptingFileIO;
+    }
+
+    encryptingFileIO = EncryptingFileIO.combine(fileIO, encryptionManager);
+    return encryptingFileIO;
+  }
+
+  @Override
+  public EncryptionManager encryption() {
+    if (encryptionManager != null) {
+      return encryptionManager;
+    }
+
+    if (encryptionKeyId == null) {
+      encryptionPropsFromHms();
+    }
+
+    if (encryptionKeyId != null) {
+      if (keyManagementClient == null) {
+        throw new RuntimeException(
+            "Cannot create encryption manager because key management client is not set");
+      }
+
+      encryptedTable = true;
+      encryptionManager =
+          EncryptionUtil.createEncryptionManager(
+              encryptionKeyId, encryptionDekLength, keyManagementClient, writerKekTimeout);
+    } else {
+      encryptionManager = PlaintextEncryptionManager.instance();
+    }
+
+    return encryptionManager;
   }
 
   @Override
@@ -148,7 +218,6 @@ protected void doRefresh() {
       HiveOperationsBase.validateTableIsIceberg(table, fullName);
 
       metadataLocation = table.getParameters().get(METADATA_LOCATION_PROP);
-
     } catch (NoSuchObjectException e) {
       if (currentMetadataLocation() != null) {
         throw new NoSuchTableException("No such table: %s.%s", database, tableName);
@@ -167,10 +236,17 @@ protected void doRefresh() {
     refreshFromMetadataLocation(metadataLocation, metadataRefreshMaxRetries);
   }
 
-  @SuppressWarnings("checkstyle:CyclomaticComplexity")
+  @SuppressWarnings({"checkstyle:CyclomaticComplexity", "checkstyle:MethodLength"})
   @Override
   protected void doCommit(TableMetadata base, TableMetadata metadata) {
     boolean newTable = base == null;
+
+    encryptionPropsFromMetadata(metadata.properties());
+
+    if (encryption() instanceof StandardEncryptionManager) {
+      metadata.setKekCache(EncryptionUtil.kekCache(encryptionManager));
+    }
+
     String newMetadataLocation = writeNewMetadataIfRequired(newTable, metadata);
     boolean hiveEngineEnabled = hiveEngineEnabled(metadata, conf);
     boolean keepHiveStats = conf.getBoolean(ConfigProperties.KEEP_HIVE_STATS, false);
@@ -226,6 +302,10 @@ protected void doCommit(TableMetadata base, TableMetadata metadata) {
               .collect(Collectors.toSet());
     }
 
+    if (removedProps.contains(TableProperties.ENCRYPTION_TABLE_KEY)) {
+      throw new RuntimeException("Cannot remove key in encrypted table");
+    }
+
     Map<String, String> summary =
         Optional.ofNullable(metadata.currentSnapshot())
             .map(Snapshot::summary)
@@ -319,6 +399,46 @@ protected void doCommit(TableMetadata base, TableMetadata metadata) {
         "Committed to table {} with the new metadata location {}", fullName, newMetadataLocation);
   }
 
+  private void encryptionPropsFromMetadata(Map<String, String> tableProperties) {
+    if (encryptionKeyId == null) {
+      encryptionKeyId = tableProperties.get(TableProperties.ENCRYPTION_TABLE_KEY);
+    }
+
+    if (encryptionKeyId != null && encryptionDekLength <= 0) {
+      String dekLength = tableProperties.get(TableProperties.ENCRYPTION_DEK_LENGTH);
+      encryptionDekLength =
+          (dekLength == null)
+              ? TableProperties.ENCRYPTION_DEK_LENGTH_DEFAULT
+              : Integer.valueOf(dekLength);
+    }
+  }
+
+  private void encryptionPropsFromHms() {
+    try {
+      Table table = loadHmsTable();
+      if (table == null) {
+        return;
+      }
+
+      encryptionKeyId = table.getParameters().get(TableProperties.ENCRYPTION_TABLE_KEY);
+      if (encryptionKeyId != null) {
+        String dekLength = table.getParameters().get(TableProperties.ENCRYPTION_DEK_LENGTH);
+        encryptionDekLength =
+            (dekLength == null)
+                ? TableProperties.ENCRYPTION_DEK_LENGTH_DEFAULT
+                : Integer.valueOf(dekLength);
+      }
+    } catch (TException e) {
+      String errMsg =
+          String.format("Failed to get table info from metastore %s.%s", database, tableName);
+      throw new RuntimeException(errMsg, e);
+    } catch (InterruptedException e) {
+      Thread.currentThread().interrupt();
+      throw new RuntimeException("Interrupted during encryption key id retrieval", e);
+    }
+  }
+
+  @SuppressWarnings("checkstyle:CyclomaticComplexity")
   private void setHmsTableParameters(
       String newMetadataLocation,
       Table tbl,
diff --git a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAvroTableEncryption.java b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAvroTableEncryption.java
new file mode 100644
index 000000000000..1ba2616cdd61
--- /dev/null
+++ b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestAvroTableEncryption.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.sql;
+
+import static org.apache.iceberg.Files.localInput;
+import static org.apache.iceberg.types.Types.NestedField.optional;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.CatalogProperties;
+import org.apache.iceberg.MetadataTableType;
+import org.apache.iceberg.Parameters;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.avro.Avro;
+import org.apache.iceberg.encryption.Ciphers;
+import org.apache.iceberg.encryption.UnitestKMS;
+import org.apache.iceberg.exceptions.RuntimeIOException;
+import org.apache.iceberg.io.SeekableInputStream;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.relocated.com.google.common.collect.Streams;
+import org.apache.iceberg.spark.CatalogTestBase;
+import org.apache.iceberg.spark.SparkCatalogConfig;
+import org.apache.iceberg.types.Types;
+import org.junit.Assert;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.TestTemplate;
+
+public class TestAvroTableEncryption extends CatalogTestBase {
+
+  private static Map<String, String> appendCatalogEncryptionProperties(
+      Map<String, String> props) {
+    Map<String, String> newProps = Maps.newHashMap();
+    newProps.putAll(props);
+    newProps.put(CatalogProperties.ENCRYPTION_KMS_IMPL, UnitestKMS.class.getCanonicalName());
+    return newProps;
+  }
+
+  @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}")
+  protected static Object[][] parameters() {
+    return new Object[][] {
+      {
+        SparkCatalogConfig.HIVE.catalogName(),
+        SparkCatalogConfig.HIVE.implementation(),
+        appendCatalogEncryptionProperties(SparkCatalogConfig.HIVE.properties())
+      }
+    };
+  }
+
+  @BeforeEach
+  public void createTables() {
+    sql(
+        "CREATE TABLE %s (id bigint, data string, float float) USING iceberg "
+            + "TBLPROPERTIES ( "
+            + "'encryption.key-id'='%s' , "
+            + "'write.format.default'='AVRO')",
+        tableName, UnitestKMS.MASTER_KEY_NAME1);
+
+    sql("INSERT INTO %s VALUES (1, 'a', 1.0), (2, 'b', 2.0), (3, 'c', float('NaN'))", tableName);
+  }
+
+  @AfterEach
+  public void removeTables() {
+    sql("DROP TABLE IF EXISTS %s", tableName);
+  }
+
+  @TestTemplate
+  public void testSelect() {
+    List<Object[]> expected =
+        ImmutableList.of(row(1L, "a", 1.0F), row(2L, "b", 2.0F), row(3L, "c", Float.NaN));
+
+    assertEquals("Should return all expected rows", expected, sql("SELECT * FROM %s", tableName));
+  }
+
+  @TestTemplate
+  public void testDirectDataFileRead() throws IOException {
+    List<Object[]> dataFileTable =
+        sql("SELECT file_path FROM %s.%s", tableName, MetadataTableType.ALL_DATA_FILES);
+    List<String> dataFiles =
+        Streams.concat(dataFileTable.stream())
+            .map(row -> (String) row[0])
+            .collect(Collectors.toList());
+    Schema schema = new Schema(optional(0, "id", Types.IntegerType.get()));
+    byte[] magic = new byte[4];
+    for (String filePath : dataFiles) {
+      AssertHelpers.assertThrows(
+          "Read without keys",
+          RuntimeIOException.class,
+          "Failed to open file",
+          () -> Avro.read(localInput(filePath)).project(schema).build().iterator().next());
+
+      // Verify encryption of data files
+      SeekableInputStream dataFileReader = localInput(filePath).newStream();
+      dataFileReader.read(magic);
+      dataFileReader.close();
+      Assert.assertArrayEquals(
+          magic, Ciphers.GCM_STREAM_MAGIC_STRING.getBytes(StandardCharsets.UTF_8));
+    }
+  }
+}
diff --git a/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestTableEncryption.java b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestTableEncryption.java
new file mode 100644
index 000000000000..2a2358cacc94
--- /dev/null
+++ b/spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/sql/TestTableEncryption.java
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.sql;
+
+import static org.apache.iceberg.Files.localInput;
+import static org.apache.iceberg.types.Types.NestedField.optional;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.CatalogProperties;
+import org.apache.iceberg.MetadataTableType;
+import org.apache.iceberg.Parameters;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.encryption.Ciphers;
+import org.apache.iceberg.encryption.UnitestKMS;
+import org.apache.iceberg.io.InputFile;
+import org.apache.iceberg.io.SeekableInputStream;
+import org.apache.iceberg.parquet.Parquet;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.relocated.com.google.common.collect.Streams;
+import org.apache.iceberg.spark.CatalogTestBase;
+import org.apache.iceberg.spark.SparkCatalogConfig;
+import org.apache.iceberg.types.Types;
+import org.apache.parquet.crypto.ParquetCryptoRuntimeException;
+import org.junit.Assert;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.TestTemplate;
+
+public class TestTableEncryption extends CatalogTestBase {
+
+  private static Map<String, String> appendCatalogEncryptionProperties(
+      Map<String, String> props) {
+    Map<String, String> newProps = Maps.newHashMap();
+    newProps.putAll(props);
+    newProps.put(CatalogProperties.ENCRYPTION_KMS_IMPL, UnitestKMS.class.getCanonicalName());
+    return newProps;
+  }
+
+  @Parameters(name = "catalogName = {0}, implementation = {1}, config = {2}")
+  protected static Object[][] parameters() {
+    return new Object[][] {
+      {
+        SparkCatalogConfig.HIVE.catalogName(),
+        SparkCatalogConfig.HIVE.implementation(),
+        appendCatalogEncryptionProperties(SparkCatalogConfig.HIVE.properties())
+      }
+    };
+  }
+
+  @BeforeEach
+  public void createTables() {
+    sql(
+        "CREATE TABLE %s (id bigint, data string, float float) USING iceberg "
+            + "TBLPROPERTIES ( "
+            + "'encryption.key-id'='%s')",
+        tableName, UnitestKMS.MASTER_KEY_NAME1);
+
+    sql("INSERT INTO %s VALUES (1, 'a', 1.0), (2, 'b', 2.0), (3, 'c', float('NaN'))", tableName);
float('NaN'))", tableName); + } + + @AfterEach + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @TestTemplate + public void testSelect() { + List expected = + ImmutableList.of(row(1L, "a", 1.0F), row(2L, "b", 2.0F), row(3L, "c", Float.NaN)); + + assertEquals("Should return all expected rows", expected, sql("SELECT * FROM %s", tableName)); + } + + @TestTemplate + public void testInsertAndDelete() { + sql("INSERT INTO %s VALUES (4, 'd', 4.0), (5, 'e', 5.0), (6, 'f', float('NaN'))", tableName); + + List expected = + ImmutableList.of( + row(1L, "a", 1.0F), + row(2L, "b", 2.0F), + row(3L, "c", Float.NaN), + row(4L, "d", 4.0F), + row(5L, "e", 5.0F), + row(6L, "f", Float.NaN)); + + assertEquals( + "Should return all expected rows", + expected, + sql("SELECT * FROM %s ORDER BY id", tableName)); + + sql("DELETE FROM %s WHERE id < 4", tableName); + + expected = ImmutableList.of(row(4L, "d", 4.0F), row(5L, "e", 5.0F), row(6L, "f", Float.NaN)); + + assertEquals( + "Should return all expected rows", + expected, + sql("SELECT * FROM %s ORDER BY id", tableName)); + } + + @TestTemplate + public void testDirectDataFileRead() { + List dataFileTable = + sql("SELECT file_path FROM %s.%s", tableName, MetadataTableType.ALL_DATA_FILES); + List dataFiles = + Streams.concat(dataFileTable.stream()) + .map(row -> (String) row[0]) + .collect(Collectors.toList()); + Schema schema = new Schema(optional(0, "id", Types.IntegerType.get())); + for (String filePath : dataFiles) { + AssertHelpers.assertThrows( + "Read without keys", + ParquetCryptoRuntimeException.class, + "Trying to read file with encrypted footer. No keys available", + () -> + Parquet.read(localInput(filePath)) + .project(schema) + .callInit() + .build() + .iterator() + .next()); + } + } + + @TestTemplate + public void testManifestEncryption() throws IOException { + List manifestFileTable = + sql("SELECT path FROM %s.%s", tableName, MetadataTableType.MANIFESTS); + + List manifestFiles = + Streams.concat(manifestFileTable.stream()) + .map(row -> (String) row[0]) + .collect(Collectors.toList()); + + if (!(manifestFiles.size() > 0)) { + throw new RuntimeException("No manifest files found for table " + tableName); + } + + String metadataFolderPath = null; + + // Check encryption of manifest files + for (String manifestFilePath : manifestFiles) { + checkMetadataFileEncryption(localInput(manifestFilePath)); + + if (metadataFolderPath == null) { + metadataFolderPath = new File(manifestFilePath).getParent().replaceFirst("file:", ""); + } + } + + if (metadataFolderPath == null) { + throw new RuntimeException("No metadata folder found for table " + tableName); + } + + // Find manifest list and metadata files; check their encryption + File[] listOfMetadataFiles = new File(metadataFolderPath).listFiles(); + boolean foundManifestListFile = false; + boolean foundMetadataJson = false; + + for (File metadataFile : listOfMetadataFiles) { + if (metadataFile.getName().startsWith("snap-")) { + foundManifestListFile = true; + checkMetadataFileEncryption(localInput(metadataFile)); + } + } + + if (!foundManifestListFile) { + throw new RuntimeException("No manifest list files found for table " + tableName); + } + } + + private void checkMetadataFileEncryption(InputFile file) throws IOException { + SeekableInputStream stream = file.newStream(); + byte[] magic = new byte[4]; + stream.read(magic); + stream.close(); + Assert.assertArrayEquals( + magic, Ciphers.GCM_STREAM_MAGIC_STRING.getBytes(StandardCharsets.UTF_8)); + } +}