diff --git a/core/src/main/java/org/apache/iceberg/encryption/AesGcmInputFile.java b/core/src/main/java/org/apache/iceberg/encryption/AesGcmInputFile.java index 8663cd8745da..ab0e91bcbd61 100644 --- a/core/src/main/java/org/apache/iceberg/encryption/AesGcmInputFile.java +++ b/core/src/main/java/org/apache/iceberg/encryption/AesGcmInputFile.java @@ -18,10 +18,9 @@ */ package org.apache.iceberg.encryption; -import java.io.IOException; -import java.io.UncheckedIOException; import org.apache.iceberg.io.InputFile; import org.apache.iceberg.io.SeekableInputStream; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; public class AesGcmInputFile implements InputFile { private final InputFile sourceFile; @@ -40,10 +39,7 @@ public AesGcmInputFile(InputFile sourceFile, byte[] dataKey, byte[] fileAADPrefi public long getLength() { if (plaintextLength == -1) { // Presumes all streams use hard-coded plaintext block size. - // Actual plaintext block size is checked upon stream creation (exception if different). - plaintextLength = - AesGcmInputStream.calculatePlaintextLength( - sourceFile.getLength(), AesGcmOutputStream.plainBlockSize); + plaintextLength = AesGcmInputStream.calculatePlaintextLength(sourceFile.getLength()); } return plaintextLength; @@ -51,18 +47,12 @@ public long getLength() { @Override public SeekableInputStream newStream() { - getLength(); - AesGcmInputStream result; - - try { - result = - new AesGcmInputStream( - sourceFile.newStream(), sourceFile.getLength(), dataKey, fileAADPrefix); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - - return result; + long ciphertextLength = sourceFile.getLength(); + Preconditions.checkState( + ciphertextLength >= Ciphers.GCM_STREAM_HEADER_LENGTH, + "Invalid encrypted stream: %d is shorter than the GCM stream header length", + ciphertextLength); + return new AesGcmInputStream(sourceFile.newStream(), ciphertextLength, dataKey, fileAADPrefix); } @Override diff --git a/core/src/main/java/org/apache/iceberg/encryption/AesGcmInputStream.java b/core/src/main/java/org/apache/iceberg/encryption/AesGcmInputStream.java index d5cae583cfba..0bbb8ad4556e 100644 --- a/core/src/main/java/org/apache/iceberg/encryption/AesGcmInputStream.java +++ b/core/src/main/java/org/apache/iceberg/encryption/AesGcmInputStream.java @@ -20,92 +20,68 @@ import java.io.EOFException; import java.io.IOException; -import java.io.UncheckedIOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; -import java.util.Arrays; import org.apache.iceberg.io.IOUtil; import org.apache.iceberg.io.SeekableInputStream; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; public class AesGcmInputStream extends SeekableInputStream { private final SeekableInputStream sourceStream; - private final Ciphers.AesGcmDecryptor gcmDecryptor; + private final byte[] fileAADPrefix; + private final Ciphers.AesGcmDecryptor decryptor; private final byte[] cipherBlockBuffer; - private final int cipherBlockSize; - private final int plainBlockSize; - private final int numberOfBlocks; + private final long numBlocks; private final int lastCipherBlockSize; private final long plainStreamSize; - private final byte[] fileAADPrefix; private long plainStreamPosition; - private int currentBlockIndex; - private int currentOffsetInPlainBlock; - private byte[] currentDecryptedBlock; - private int currentDecryptedBlockIndex; + private long currentPlainBlockIndex; + private byte[] currentPlainBlock; + private int currentPlainBlockSize; AesGcmInputStream( - SeekableInputStream sourceStream, long sourceLength, byte[] aesKey, byte[] fileAADPrefix) - throws IOException { - long netSourceLength = netSourceFileLength(sourceLength); - boolean emptyCipherStream = (0 == netSourceLength); + SeekableInputStream sourceStream, long sourceLength, byte[] aesKey, byte[] fileAADPrefix) { this.sourceStream = sourceStream; + this.fileAADPrefix = fileAADPrefix; + this.decryptor = new Ciphers.AesGcmDecryptor(aesKey); + this.cipherBlockBuffer = new byte[Ciphers.CIPHER_BLOCK_SIZE]; + + this.plainStreamPosition = 0; + this.currentPlainBlockIndex = -1; + this.currentPlainBlock = null; + this.currentPlainBlockSize = 0; + + long streamLength = sourceLength - Ciphers.GCM_STREAM_HEADER_LENGTH; + long numFullBlocks = Math.toIntExact(streamLength / Ciphers.CIPHER_BLOCK_SIZE); + long cipherFullBlockLength = numFullBlocks * Ciphers.CIPHER_BLOCK_SIZE; + int cipherBytesInLastBlock = Math.toIntExact(streamLength - cipherFullBlockLength); + boolean fullBlocksOnly = (0 == cipherBytesInLastBlock); + this.numBlocks = fullBlocksOnly ? numFullBlocks : numFullBlocks + 1; + this.lastCipherBlockSize = fullBlocksOnly ? Ciphers.CIPHER_BLOCK_SIZE : cipherBytesInLastBlock; // never 0 + + long lastPlainBlockSize = lastCipherBlockSize - Ciphers.NONCE_LENGTH - Ciphers.GCM_TAG_LENGTH; + this.plainStreamSize = numFullBlocks * Ciphers.PLAIN_BLOCK_SIZE + (fullBlocksOnly ? 0 : lastPlainBlockSize); + } + + private void validateHeader() throws IOException { byte[] headerBytes = new byte[Ciphers.GCM_STREAM_HEADER_LENGTH]; IOUtil.readFully(sourceStream, headerBytes, 0, headerBytes.length); - byte[] magic = new byte[Ciphers.GCM_STREAM_MAGIC_ARRAY.length]; - System.arraycopy(headerBytes, 0, magic, 0, Ciphers.GCM_STREAM_MAGIC_ARRAY.length); + Preconditions.checkState( - Arrays.equals(Ciphers.GCM_STREAM_MAGIC_ARRAY, magic), - "Cannot open encrypted file, it does not begin with magic string " - + Ciphers.GCM_STREAM_MAGIC_STRING); - this.currentDecryptedBlockIndex = -1; - - if (!emptyCipherStream) { - this.plainStreamPosition = 0; - this.fileAADPrefix = fileAADPrefix; - gcmDecryptor = new Ciphers.AesGcmDecryptor(aesKey); - plainBlockSize = - ByteBuffer.wrap(headerBytes, Ciphers.GCM_STREAM_MAGIC_ARRAY.length, 4) - .order(ByteOrder.LITTLE_ENDIAN) - .getInt(); - Preconditions.checkState(plainBlockSize > 0, "Wrong plainBlockSize " + plainBlockSize); - - Preconditions.checkState( - plainBlockSize == AesGcmOutputStream.plainBlockSize, - "Wrong plainBlockSize " - + plainBlockSize - + ". Only size of " - + AesGcmOutputStream.plainBlockSize - + " is currently supported"); - - cipherBlockSize = plainBlockSize + Ciphers.NONCE_LENGTH + Ciphers.GCM_TAG_LENGTH; - this.cipherBlockBuffer = new byte[cipherBlockSize]; - this.currentBlockIndex = 0; - this.currentOffsetInPlainBlock = 0; - - int numberOfFullBlocks = Math.toIntExact(netSourceLength / cipherBlockSize); - int cipherBytesInLastBlock = - Math.toIntExact(netSourceLength - numberOfFullBlocks * cipherBlockSize); - boolean fullBlocksOnly = (0 == cipherBytesInLastBlock); - numberOfBlocks = fullBlocksOnly ? numberOfFullBlocks : numberOfFullBlocks + 1; - lastCipherBlockSize = fullBlocksOnly ? cipherBlockSize : cipherBytesInLastBlock; // never 0 - plainStreamSize = calculatePlaintextLength(sourceLength, plainBlockSize); - } else { - plainStreamSize = 0; - - gcmDecryptor = null; - cipherBlockBuffer = null; - cipherBlockSize = -1; - plainBlockSize = -1; - numberOfBlocks = -1; - lastCipherBlockSize = -1; - this.fileAADPrefix = null; - } + Ciphers.GCM_STREAM_MAGIC.equals(ByteBuffer.wrap(headerBytes, 0, 4)), + "Invalid GCM stream: magic does not match AGS1"); + + int plainBlockSize = ByteBuffer.wrap(headerBytes, 4, 4).order(ByteOrder.LITTLE_ENDIAN).getInt(); + Preconditions.checkState( + plainBlockSize == Ciphers.PLAIN_BLOCK_SIZE, + "Invalid GCM stream: block size %d != %d", + plainBlockSize, + Ciphers.PLAIN_BLOCK_SIZE); } @Override - public int available() throws IOException { + public int available() { long maxAvailable = plainStreamSize - plainStreamPosition; // See InputStream.available contract if (maxAvailable >= Integer.MAX_VALUE) { @@ -115,9 +91,17 @@ public int available() throws IOException { } } + private int availableInCurrentBlock() { + if (currentPlainBlockIndex < 0) { + return 0; + } + + return currentPlainBlockSize - offsetInBlock(plainStreamPosition); + } + @Override public int read(byte[] b, int off, int len) throws IOException { - Preconditions.checkState(len >= 0, "Negative read length " + len); + Preconditions.checkArgument(len >= 0, "Invalid read length: " + len); if (available() <= 0 && len > 0) { throw new EOFException(); @@ -127,50 +111,51 @@ public int read(byte[] b, int off, int len) throws IOException { return 0; } - boolean isLastBlockInStream = (currentBlockIndex + 1 == numberOfBlocks); + int totalBytesRead = 0; int resultBufferOffset = off; int remainingBytesToRead = len; while (remainingBytesToRead > 0) { - byte[] plainBlock = decryptNextBlock(isLastBlockInStream); - - int remainingBytesInBlock = plainBlock.length - currentOffsetInPlainBlock; - boolean finishTheBlock = remainingBytesToRead >= remainingBytesInBlock; - int bytesToCopy = finishTheBlock ? remainingBytesInBlock : remainingBytesToRead; - System.arraycopy(plainBlock, currentOffsetInPlainBlock, b, resultBufferOffset, bytesToCopy); - remainingBytesToRead -= bytesToCopy; - resultBufferOffset += bytesToCopy; - currentOffsetInPlainBlock += bytesToCopy; - - boolean endOfStream = isLastBlockInStream && finishTheBlock; - - if (endOfStream) { + int availableInBlock = availableInCurrentBlock(); + if (availableInBlock > 0) { + int bytesToCopy = Math.min(availableInBlock, remainingBytesToRead); + int offsetInBlock = offsetInBlock(plainStreamPosition); + System.arraycopy(currentPlainBlock, offsetInBlock, b, resultBufferOffset, bytesToCopy); + totalBytesRead += bytesToCopy; + remainingBytesToRead -= bytesToCopy; + resultBufferOffset += bytesToCopy; + this.plainStreamPosition += bytesToCopy; + if (blockIndex(plainStreamPosition) != currentPlainBlockIndex) { + // invalidate the current block + this.currentPlainBlockIndex = -1; + } + + } else if (available() > 0) { + decryptBlock(blockIndex(plainStreamPosition)); + + } else { break; } - - if (finishTheBlock) { - currentBlockIndex++; - currentOffsetInPlainBlock = 0; - isLastBlockInStream = (currentBlockIndex + 1 == numberOfBlocks); - } } - plainStreamPosition += len - remainingBytesToRead; - return len - remainingBytesToRead; + // return -1 for EOF + return totalBytesRead > 0 ? totalBytesRead : -1; } @Override public void seek(long newPos) throws IOException { if (newPos < 0) { - throw new IOException("Negative new position " + newPos); + throw new IOException("Invalid position: " + newPos); } else if (newPos > plainStreamSize) { throw new EOFException( - "New position " + newPos + " exceeds the max stream size " + plainStreamSize); + "Invalid position: " + newPos + " > stream length, " + plainStreamSize); } - currentBlockIndex = Math.toIntExact(newPos / plainBlockSize); - currentOffsetInPlainBlock = Math.toIntExact(newPos % plainBlockSize); - plainStreamPosition = newPos; + this.plainStreamPosition = newPos; + if (blockIndex(plainStreamPosition) != currentPlainBlockIndex) { + // invalidate the current block + this.currentPlainBlockIndex = -1; + } } @Override @@ -179,27 +164,19 @@ public long skip(long n) { return 0; } - if (plainStreamPosition == plainStreamSize) { - return 0; + long bytesLeftInStream = plainStreamSize - plainStreamPosition; + if (n > bytesLeftInStream) { + // skip the rest of the stream + this.plainStreamPosition = plainStreamSize; + return bytesLeftInStream; } - long newPosition = plainStreamPosition + n; - - if (newPosition > plainStreamSize) { - long skipped = plainStreamSize - plainStreamPosition; - try { - seek(plainStreamSize); - } catch (IOException e) { - throw new UncheckedIOException(e); - } - return skipped; + this.plainStreamPosition += n; + if (blockIndex(plainStreamPosition) != currentPlainBlockIndex) { + // invalidate the current block + this.currentPlainBlockIndex = -1; } - try { - seek(newPosition); - } catch (IOException e) { - throw new UncheckedIOException(e); - } return n; } @@ -216,59 +193,62 @@ public int read() throws IOException { @Override public void close() throws IOException { sourceStream.close(); - currentDecryptedBlock = null; + this.currentPlainBlock = null; } - static long calculatePlaintextLength(long sourceLength, int plainBlockSize) { - long netSourceFileLength = netSourceFileLength(sourceLength); - - if (netSourceFileLength == 0) { - return 0; + private void decryptBlock(long blockIndex) throws IOException { + if (blockIndex == currentPlainBlockIndex) { + return; } - int cipherBlockSize = plainBlockSize + Ciphers.NONCE_LENGTH + Ciphers.GCM_TAG_LENGTH; - int numberOfFullBlocks = Math.toIntExact(netSourceFileLength / cipherBlockSize); - int cipherBytesInLastBlock = - Math.toIntExact(netSourceFileLength - numberOfFullBlocks * cipherBlockSize); - boolean fullBlocksOnly = (0 == cipherBytesInLastBlock); - int plainBytesInLastBlock = - fullBlocksOnly - ? 0 - : (cipherBytesInLastBlock - Ciphers.NONCE_LENGTH - Ciphers.GCM_TAG_LENGTH); - - return (long) numberOfFullBlocks * plainBlockSize + plainBytesInLastBlock; - } - - private byte[] decryptNextBlock(boolean isLastBlockInStream) throws IOException { - if (currentBlockIndex == currentDecryptedBlockIndex) { - return currentDecryptedBlock; - } - - long blockPositionInStream = blockOffset(currentBlockIndex); + long blockPositionInStream = blockOffset(blockIndex); if (sourceStream.getPos() != blockPositionInStream) { + if (sourceStream.getPos() == 0) { + validateHeader(); + } + sourceStream.seek(blockPositionInStream); } - int currentCipherBlockSize = isLastBlockInStream ? lastCipherBlockSize : cipherBlockSize; - IOUtil.readFully(sourceStream, cipherBlockBuffer, 0, currentCipherBlockSize); + boolean isLastBlock = blockIndex == numBlocks - 1; + int cipherBlockSize = isLastBlock ? lastCipherBlockSize : Ciphers.CIPHER_BLOCK_SIZE; + IOUtil.readFully(sourceStream, cipherBlockBuffer, 0, cipherBlockSize); - byte[] aad = Ciphers.streamBlockAAD(fileAADPrefix, currentBlockIndex); - byte[] result = gcmDecryptor.decrypt(cipherBlockBuffer, 0, currentCipherBlockSize, aad); - currentDecryptedBlockIndex = currentBlockIndex; - currentDecryptedBlock = result; - return result; + // TODO: the AAD should probably use a long block index. + byte[] blockAAD = Ciphers.streamBlockAAD(fileAADPrefix, Math.toIntExact(blockIndex)); + this.currentPlainBlock = decryptor.decrypt(cipherBlockBuffer, 0, cipherBlockSize, blockAAD); + this.currentPlainBlockSize = cipherBlockSize - Ciphers.NONCE_LENGTH - Ciphers.GCM_TAG_LENGTH; + this.currentPlainBlockIndex = blockIndex; } - private long blockOffset(int blockIndex) { - return (long) blockIndex * cipherBlockSize + Ciphers.GCM_STREAM_HEADER_LENGTH; + private static long blockIndex(long plainPosition) { + return plainPosition / Ciphers.PLAIN_BLOCK_SIZE; } - private static long netSourceFileLength(long sourceFileLength) { - long netSourceLength = sourceFileLength - Ciphers.GCM_STREAM_HEADER_LENGTH; - Preconditions.checkArgument( - netSourceLength >= 0, - "Source length " + sourceFileLength + " is shorter than GCM prefix. File is not encrypted"); + private static int offsetInBlock(long plainPosition) { + return Math.toIntExact(plainPosition % Ciphers.PLAIN_BLOCK_SIZE); + } + + private static long blockOffset(long blockIndex) { + return blockIndex * Ciphers.CIPHER_BLOCK_SIZE + Ciphers.GCM_STREAM_HEADER_LENGTH; + } + + static long calculatePlaintextLength(long sourceLength) { + long streamLength = sourceLength - Ciphers.GCM_STREAM_HEADER_LENGTH; + + if (streamLength == 0) { + return 0; + } + + long numberOfFullBlocks = streamLength / Ciphers.CIPHER_BLOCK_SIZE; + long fullBlockSize = numberOfFullBlocks * Ciphers.CIPHER_BLOCK_SIZE; + long cipherBytesInLastBlock = streamLength - fullBlockSize; + boolean fullBlocksOnly = (0 == cipherBytesInLastBlock); + long plainBytesInLastBlock = + fullBlocksOnly + ? 0 + : (cipherBytesInLastBlock - Ciphers.NONCE_LENGTH - Ciphers.GCM_TAG_LENGTH); - return netSourceLength; + return (numberOfFullBlocks * Ciphers.PLAIN_BLOCK_SIZE) + plainBytesInLastBlock; } } diff --git a/core/src/main/java/org/apache/iceberg/encryption/AesGcmOutputStream.java b/core/src/main/java/org/apache/iceberg/encryption/AesGcmOutputStream.java index ce165c39c70f..2739302b9f4e 100644 --- a/core/src/main/java/org/apache/iceberg/encryption/AesGcmOutputStream.java +++ b/core/src/main/java/org/apache/iceberg/encryption/AesGcmOutputStream.java @@ -24,7 +24,6 @@ import org.apache.iceberg.io.PositionOutputStream; public class AesGcmOutputStream extends PositionOutputStream { - public static final int plainBlockSize = 1024 * 1024; private final Ciphers.AesGcmEncryptor gcmEncryptor; private final PositionOutputStream targetStream; @@ -39,7 +38,7 @@ public class AesGcmOutputStream extends PositionOutputStream { throws IOException { this.targetStream = targetStream; this.gcmEncryptor = new Ciphers.AesGcmEncryptor(aesKey); - this.plainBlockBuffer = new byte[plainBlockSize]; + this.plainBlockBuffer = new byte[Ciphers.PLAIN_BLOCK_SIZE]; this.positionInBuffer = 0; this.streamPosition = 0; this.currentBlockIndex = 0; @@ -49,7 +48,7 @@ public class AesGcmOutputStream extends PositionOutputStream { ByteBuffer.allocate(Ciphers.GCM_STREAM_HEADER_LENGTH) .order(ByteOrder.LITTLE_ENDIAN) .put(Ciphers.GCM_STREAM_MAGIC_ARRAY) - .putInt(plainBlockSize) + .putInt(Ciphers.PLAIN_BLOCK_SIZE) .array(); targetStream.write(headerBytes); } @@ -69,12 +68,12 @@ public void write(byte[] b, int off, int len) throws IOException { int offset = off; while (remaining > 0) { - int freeBlockBytes = plainBlockSize - positionInBuffer; + int freeBlockBytes = Ciphers.PLAIN_BLOCK_SIZE - positionInBuffer; int toWrite = freeBlockBytes <= remaining ? freeBlockBytes : remaining; System.arraycopy(b, offset, plainBlockBuffer, positionInBuffer, toWrite); positionInBuffer += toWrite; - if (positionInBuffer == plainBlockSize) { + if (positionInBuffer == Ciphers.PLAIN_BLOCK_SIZE) { encryptAndWriteBlock(); positionInBuffer = 0; } diff --git a/core/src/main/java/org/apache/iceberg/encryption/Ciphers.java b/core/src/main/java/org/apache/iceberg/encryption/Ciphers.java index cf43880a7475..d68d0774ff9c 100644 --- a/core/src/main/java/org/apache/iceberg/encryption/Ciphers.java +++ b/core/src/main/java/org/apache/iceberg/encryption/Ciphers.java @@ -30,12 +30,16 @@ import org.apache.iceberg.relocated.com.google.common.base.Preconditions; public class Ciphers { + public static final int PLAIN_BLOCK_SIZE = 1024 * 1024; public static final int NONCE_LENGTH = 12; public static final int GCM_TAG_LENGTH = 16; + public static final int CIPHER_BLOCK_SIZE = + PLAIN_BLOCK_SIZE + NONCE_LENGTH + GCM_TAG_LENGTH; public static final String GCM_STREAM_MAGIC_STRING = "AGS1"; static final byte[] GCM_STREAM_MAGIC_ARRAY = GCM_STREAM_MAGIC_STRING.getBytes(StandardCharsets.UTF_8); + static final ByteBuffer GCM_STREAM_MAGIC = ByteBuffer.wrap(GCM_STREAM_MAGIC_ARRAY).asReadOnlyBuffer(); static final int GCM_STREAM_HEADER_LENGTH = GCM_STREAM_MAGIC_ARRAY.length + 4; // magic_len + block_size_len diff --git a/core/src/test/java/org/apache/iceberg/encryption/TestGcmStreams.java b/core/src/test/java/org/apache/iceberg/encryption/TestGcmStreams.java index d6905f1306f5..ec929efe2d7e 100644 --- a/core/src/test/java/org/apache/iceberg/encryption/TestGcmStreams.java +++ b/core/src/test/java/org/apache/iceberg/encryption/TestGcmStreams.java @@ -57,9 +57,9 @@ public void testEmptyFile() throws IOException { @Test public void testRandomWriteRead() throws IOException { Random random = new Random(); - int smallerThanBlock = (int) (AesGcmOutputStream.plainBlockSize * 0.5); - int largerThanBlock = (int) (AesGcmOutputStream.plainBlockSize * 1.5); - int alignedWithBlock = AesGcmOutputStream.plainBlockSize; + int smallerThanBlock = (int) (Ciphers.PLAIN_BLOCK_SIZE * 0.5); + int largerThanBlock = (int) (Ciphers.PLAIN_BLOCK_SIZE * 1.5); + int alignedWithBlock = Ciphers.PLAIN_BLOCK_SIZE; int[] testFileSizes = { smallerThanBlock, largerThanBlock, @@ -162,9 +162,9 @@ public void testRandomWriteRead() throws IOException { public void testAlignedWriteRead() throws IOException { Random random = new Random(); int[] testFileSizes = { - AesGcmOutputStream.plainBlockSize, - AesGcmOutputStream.plainBlockSize + 1, - AesGcmOutputStream.plainBlockSize - 1 + Ciphers.PLAIN_BLOCK_SIZE, + Ciphers.PLAIN_BLOCK_SIZE + 1, + Ciphers.PLAIN_BLOCK_SIZE - 1 }; for (int testFileSize : testFileSizes) { @@ -181,7 +181,7 @@ public void testAlignedWriteRead() throws IOException { PositionOutputStream encryptedStream = encryptedFile.createOrOverwrite(); int offset = 0; - int chunkLen = AesGcmOutputStream.plainBlockSize; + int chunkLen = Ciphers.PLAIN_BLOCK_SIZE; int left = testFileSize; while (left > 0) { @@ -204,7 +204,7 @@ public void testAlignedWriteRead() throws IOException { Assert.assertEquals("File size", testFileSize, decryptedFile.getLength()); offset = 0; - chunkLen = AesGcmOutputStream.plainBlockSize; + chunkLen = Ciphers.PLAIN_BLOCK_SIZE; byte[] chunk = new byte[chunkLen]; left = testFileSize;