diff --git a/mr/src/main/java/org/apache/iceberg/mr/mapreduce/IcebergInputFormat.java b/mr/src/main/java/org/apache/iceberg/mr/mapreduce/IcebergInputFormat.java
index ccfbb6a31006..3b5225d01510 100644
--- a/mr/src/main/java/org/apache/iceberg/mr/mapreduce/IcebergInputFormat.java
+++ b/mr/src/main/java/org/apache/iceberg/mr/mapreduce/IcebergInputFormat.java
@@ -57,8 +57,7 @@
 import org.apache.iceberg.data.avro.DataReader;
 import org.apache.iceberg.data.orc.GenericOrcReader;
 import org.apache.iceberg.data.parquet.GenericParquetReaders;
-import org.apache.iceberg.encryption.EncryptedFiles;
-import org.apache.iceberg.encryption.EncryptionManager;
+import org.apache.iceberg.encryption.InputFilesDecryptor;
 import org.apache.iceberg.expressions.Evaluator;
 import org.apache.iceberg.expressions.Expression;
 import org.apache.iceberg.expressions.Expressions;
@@ -208,7 +207,7 @@ private static final class IcebergRecordReader<T> extends RecordReader<Void, T>
     private T current;
     private CloseableIterator<T> currentIterator;
     private FileIO io;
-    private EncryptionManager encryptionManager;
+    private InputFilesDecryptor inputFilesDecryptor;

     @Override
     public void initialize(InputSplit split, TaskAttemptContext newContext) {
@@ -219,7 +218,7 @@ public void initialize(InputSplit split, TaskAttemptContext newContext) {
       Table table = ((IcebergSplit) split).table();
       HiveIcebergStorageHandler.checkAndSetIoConfig(conf, table);
       this.io = table.io();
-      this.encryptionManager = table.encryption();
+      this.inputFilesDecryptor = new InputFilesDecryptor(task, io, table.encryption());
       this.tasks = task.files().iterator();
       this.tableSchema = InputFormatConfig.tableSchema(conf);
       this.nameMapping = table.properties().get(TableProperties.DEFAULT_NAME_MAPPING);
@@ -275,9 +274,7 @@ public void close() throws IOException {

     private CloseableIterable<T> openTask(FileScanTask currentTask, Schema readSchema) {
       DataFile file = currentTask.file();
-      InputFile inputFile = encryptionManager.decrypt(EncryptedFiles.encryptedInput(
-          io.newInputFile(file.path().toString()),
-          file.keyMetadata()));
+      InputFile inputFile = inputFilesDecryptor.getInputFile(currentTask);

       CloseableIterable<T> iterable;
       switch (file.format()) {
diff --git a/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/source/BaseDataReader.java b/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/source/BaseDataReader.java
index f0664c7e8e29..b558edbd0a9f 100644
--- a/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/source/BaseDataReader.java
+++ b/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/source/BaseDataReader.java
@@ -26,7 +26,6 @@
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
-import java.util.stream.Stream;
 import org.apache.avro.generic.GenericData;
 import org.apache.avro.util.Utf8;
 import org.apache.iceberg.CombinedScanTask;
@@ -36,13 +35,10 @@
 import org.apache.iceberg.Schema;
 import org.apache.iceberg.StructLike;
 import org.apache.iceberg.Table;
-import org.apache.iceberg.encryption.EncryptedFiles;
-import org.apache.iceberg.encryption.EncryptedInputFile;
+import org.apache.iceberg.encryption.InputFilesDecryptor;
 import org.apache.iceberg.io.CloseableIterator;
 import org.apache.iceberg.io.InputFile;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
-import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
-import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.iceberg.types.Type;
 import org.apache.iceberg.types.Types.NestedField;
 import org.apache.iceberg.types.Types.StructType;
@@ -65,7 +61,7 @@ abstract class BaseDataReader<T> implements Closeable {

   private final Table table;
   private final Iterator<FileScanTask> tasks;
-  private final Map<String, InputFile> inputFiles;
+  private final InputFilesDecryptor inputFilesDecryptor;

   private CloseableIterator<T> currentIterator;
   private T current = null;
@@ -74,20 +70,7 @@ abstract class BaseDataReader<T> implements Closeable {
   BaseDataReader(Table table, CombinedScanTask task) {
     this.table = table;
     this.tasks = task.files().iterator();
-    Map<String, ByteBuffer> keyMetadata = Maps.newHashMap();
-    task.files().stream()
-        .flatMap(fileScanTask -> Stream.concat(Stream.of(fileScanTask.file()), fileScanTask.deletes().stream()))
-        .forEach(file -> keyMetadata.put(file.path().toString(), file.keyMetadata()));
-    Stream<EncryptedInputFile> encrypted = keyMetadata.entrySet().stream()
-        .map(entry -> EncryptedFiles.encryptedInput(table.io().newInputFile(entry.getKey()), entry.getValue()));
-
-    // decrypt with the batch call to avoid multiple RPCs to a key server, if possible
-    Iterable<InputFile> decryptedFiles = table.encryption().decrypt(encrypted::iterator);
-
-    Map<String, InputFile> files = Maps.newHashMapWithExpectedSize(task.files().size());
-    decryptedFiles.forEach(decrypted -> files.putIfAbsent(decrypted.location(), decrypted));
-    this.inputFiles = ImmutableMap.copyOf(files);
-
+    this.inputFilesDecryptor = new InputFilesDecryptor(task, table().io(), table.encryption());
     this.currentIterator = CloseableIterator.empty();
   }

@@ -139,11 +122,11 @@ public void close() throws IOException {

   protected InputFile getInputFile(FileScanTask task) {
     Preconditions.checkArgument(!task.isDataTask(), "Invalid task type");
-    return inputFiles.get(task.file().path().toString());
+    return inputFilesDecryptor.getInputFile(task);
   }

   protected InputFile getInputFile(String location) {
-    return inputFiles.get(location);
+    return inputFilesDecryptor.getInputFile(location);
   }

   protected Map<Integer, ?> constantsMap(FileScanTask task, Schema readSchema) {
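For reference, the new `InputFilesDecryptor` class itself is not shown in this diff. Below is a minimal sketch of what it might look like, assuming it simply centralizes the batch-decryption logic removed from the `BaseDataReader` constructor above. Only the constructor arguments and the two `getInputFile` overloads are confirmed by the call sites in this diff; the field name `decryptedInputFiles` and the comments are illustrative.

```java
// Sketch of org.apache.iceberg.encryption.InputFilesDecryptor, reconstructed
// from the logic removed from BaseDataReader in this diff. Field names and
// comments are assumptions, not the actual implementation.
package org.apache.iceberg.encryption;

import java.nio.ByteBuffer;
import java.util.Map;
import java.util.stream.Stream;
import org.apache.iceberg.CombinedScanTask;
import org.apache.iceberg.FileScanTask;
import org.apache.iceberg.io.FileIO;
import org.apache.iceberg.io.InputFile;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;

public class InputFilesDecryptor {
  private final Map<String, InputFile> decryptedInputFiles;

  public InputFilesDecryptor(CombinedScanTask combinedTask, FileIO io, EncryptionManager encryption) {
    // collect key metadata for every data file and delete file in the combined task
    Map<String, ByteBuffer> keyMetadata = Maps.newHashMap();
    combinedTask.files().stream()
        .flatMap(fileScanTask ->
            Stream.concat(Stream.of(fileScanTask.file()), fileScanTask.deletes().stream()))
        .forEach(file -> keyMetadata.put(file.path().toString(), file.keyMetadata()));
    Stream<EncryptedInputFile> encrypted = keyMetadata.entrySet().stream()
        .map(entry -> EncryptedFiles.encryptedInput(io.newInputFile(entry.getKey()), entry.getValue()));

    // decrypt with the batch call to avoid multiple RPCs to a key server, if possible
    Iterable<InputFile> decryptedFiles = encryption.decrypt(encrypted::iterator);

    Map<String, InputFile> files = Maps.newHashMapWithExpectedSize(combinedTask.files().size());
    decryptedFiles.forEach(decrypted -> files.putIfAbsent(decrypted.location(), decrypted));
    this.decryptedInputFiles = ImmutableMap.copyOf(files);
  }

  // look up the decrypted input file for a scan task's data file
  public InputFile getInputFile(FileScanTask task) {
    Preconditions.checkArgument(!task.isDataTask(), "Invalid task type");
    return decryptedInputFiles.get(task.file().path().toString());
  }

  // look up a decrypted input file by location
  public InputFile getInputFile(String location) {
    return decryptedInputFiles.get(location);
  }
}
```

The design point carried over from `BaseDataReader` is the single batch `decrypt()` call: keys for all data and delete files in the task are resolved in one round trip where the `EncryptionManager` supports it, instead of one RPC per file, and both the MR and Spark readers now share that behavior.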