diff --git a/core/src/main/java/org/apache/iceberg/encryption/AesGcmInputFile.java b/core/src/main/java/org/apache/iceberg/encryption/AesGcmInputFile.java index ab0e91bcbd61..a43643fcc779 100644 --- a/core/src/main/java/org/apache/iceberg/encryption/AesGcmInputFile.java +++ b/core/src/main/java/org/apache/iceberg/encryption/AesGcmInputFile.java @@ -49,8 +49,8 @@ public long getLength() { public SeekableInputStream newStream() { long ciphertextLength = sourceFile.getLength(); Preconditions.checkState( - ciphertextLength >= Ciphers.GCM_STREAM_HEADER_LENGTH, - "Invalid encrypted stream: %d is shorter than the GCM stream header length", + ciphertextLength >= Ciphers.MIN_STREAM_LENGTH, + "Invalid encrypted stream: %d is shorter than the minimum possible stream length", ciphertextLength); return new AesGcmInputStream(sourceFile.newStream(), ciphertextLength, dataKey, fileAADPrefix); } diff --git a/core/src/main/java/org/apache/iceberg/encryption/AesGcmInputStream.java b/core/src/main/java/org/apache/iceberg/encryption/AesGcmInputStream.java index e033dd62ee14..57fed69c0172 100644 --- a/core/src/main/java/org/apache/iceberg/encryption/AesGcmInputStream.java +++ b/core/src/main/java/org/apache/iceberg/encryption/AesGcmInputStream.java @@ -30,6 +30,8 @@ public class AesGcmInputStream extends SeekableInputStream { private final SeekableInputStream sourceStream; private final byte[] fileAADPrefix; private final Ciphers.AesGcmDecryptor decryptor; + private final byte[] cipherBlockBuffer; + private final byte[] currentPlainBlock; private final long numBlocks; private final int lastCipherBlockSize; private final long plainStreamSize; @@ -37,8 +39,6 @@ public class AesGcmInputStream extends SeekableInputStream { private long plainStreamPosition; private long currentPlainBlockIndex; - private byte[] cipherBlockBuffer; - private byte[] currentPlainBlock; private int currentPlainBlockSize; AesGcmInputStream( @@ -107,6 +107,10 @@ private int availableInCurrentBlock() { public int 
read(byte[] b, int off, int len) throws IOException { Preconditions.checkArgument(len >= 0, "Invalid read length: " + len); + if (currentPlainBlockIndex < 0) { + decryptBlock(0); + } + if (available() <= 0 && len > 0) { throw new EOFException(); } @@ -183,16 +187,12 @@ public int read() throws IOException { return -1; } - int unsignedByte = singleByte[0] >= 0 ? singleByte[0] : 256 + singleByte[0]; - - return unsignedByte; + return singleByte[0] >= 0 ? singleByte[0] : 256 + singleByte[0]; } @Override public void close() throws IOException { sourceStream.close(); - this.currentPlainBlock = null; - this.cipherBlockBuffer = null; } private void decryptBlock(long blockIndex) throws IOException { diff --git a/core/src/main/java/org/apache/iceberg/encryption/AesGcmOutputStream.java b/core/src/main/java/org/apache/iceberg/encryption/AesGcmOutputStream.java index 3a3fbd6b3f3e..4db0802ea1b3 100644 --- a/core/src/main/java/org/apache/iceberg/encryption/AesGcmOutputStream.java +++ b/core/src/main/java/org/apache/iceberg/encryption/AesGcmOutputStream.java @@ -22,6 +22,7 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import org.apache.iceberg.io.PositionOutputStream; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; public class AesGcmOutputStream extends PositionOutputStream { @@ -31,29 +32,30 @@ public class AesGcmOutputStream extends PositionOutputStream { .put(Ciphers.GCM_STREAM_MAGIC_ARRAY) .putInt(Ciphers.PLAIN_BLOCK_SIZE) .array(); + private final Ciphers.AesGcmEncryptor gcmEncryptor; private final PositionOutputStream targetStream; private final byte[] fileAadPrefix; private final byte[] singleByte; + private final byte[] plainBlock; + private final byte[] cipherBlock; - private byte[] plainBlock; - private byte[] cipherBlock; private int positionInPlainBlock; - private long streamPosition; private int currentBlockIndex; private boolean isHeaderWritten; + private boolean lastBlockWritten; AesGcmOutputStream(PositionOutputStream 
targetStream, byte[] aesKey, byte[] fileAadPrefix) { this.targetStream = targetStream; this.gcmEncryptor = new Ciphers.AesGcmEncryptor(aesKey); + this.fileAadPrefix = fileAadPrefix; + this.singleByte = new byte[1]; this.plainBlock = new byte[Ciphers.PLAIN_BLOCK_SIZE]; this.cipherBlock = new byte[Ciphers.CIPHER_BLOCK_SIZE]; this.positionInPlainBlock = 0; - this.streamPosition = 0; this.currentBlockIndex = 0; - this.fileAadPrefix = fileAadPrefix; this.isHeaderWritten = false; - this.singleByte = new byte[1]; + this.lastBlockWritten = false; } @Override @@ -85,17 +87,15 @@ public void write(byte[] b, int off, int len) throws IOException { offset += toWrite; remaining -= toWrite; - if (positionInPlainBlock == Ciphers.PLAIN_BLOCK_SIZE) { + if (positionInPlainBlock == plainBlock.length) { encryptAndWriteBlock(); } } - - streamPosition += len; } @Override public long getPos() throws IOException { - return streamPosition; + return (long) currentBlockIndex * Ciphers.PLAIN_BLOCK_SIZE + positionInPlainBlock; } @Override @@ -109,28 +109,31 @@ public void close() throws IOException { writeHeader(); } - if (positionInPlainBlock > 0) { - encryptAndWriteBlock(); - } + encryptAndWriteBlock(); targetStream.close(); - plainBlock = null; - cipherBlock = null; } private void writeHeader() throws IOException { - targetStream.write(HEADER_BYTES); isHeaderWritten = true; } private void encryptAndWriteBlock() throws IOException { + Preconditions.checkState( + !lastBlockWritten, "Cannot encrypt block: a partial block has already been written"); + if (currentBlockIndex == Integer.MAX_VALUE) { throw new IOException("Cannot write block: exceeded Integer.MAX_VALUE blocks"); } - if (positionInPlainBlock == 0) { - throw new IOException("Empty plain block"); + if (positionInPlainBlock == 0 && currentBlockIndex != 0) { + return; + } + + if (positionInPlainBlock != plainBlock.length) { + // signal that a partial block has been written and must be the last + this.lastBlockWritten = true; } byte[] aad 
= Ciphers.streamBlockAAD(fileAadPrefix, currentBlockIndex); diff --git a/core/src/main/java/org/apache/iceberg/encryption/Ciphers.java b/core/src/main/java/org/apache/iceberg/encryption/Ciphers.java index 381e1372952b..4aeb1ecad919 100644 --- a/core/src/main/java/org/apache/iceberg/encryption/Ciphers.java +++ b/core/src/main/java/org/apache/iceberg/encryption/Ciphers.java @@ -45,6 +45,8 @@ public class Ciphers { private static final int GCM_TAG_LENGTH_BITS = 8 * GCM_TAG_LENGTH; + static final int MIN_STREAM_LENGTH = GCM_STREAM_HEADER_LENGTH + NONCE_LENGTH + GCM_TAG_LENGTH; + private Ciphers() {} public static class AesGcmEncryptor { @@ -54,20 +56,8 @@ public static class AesGcmEncryptor { private final byte[] nonce; public AesGcmEncryptor(byte[] keyBytes) { - Preconditions.checkArgument(keyBytes != null, "Key can't be null"); - int keyLength = keyBytes.length; - Preconditions.checkArgument( - (keyLength == 16 || keyLength == 24 || keyLength == 32), - "Cannot use a key of length " - + keyLength - + " because AES only allows 16, 24 or 32 bytes"); - this.aesKey = new SecretKeySpec(keyBytes, "AES"); - - try { - this.cipher = Cipher.getInstance("AES/GCM/NoPadding"); - } catch (GeneralSecurityException e) { - throw new RuntimeException("Failed to create GCM cipher", e); - } + this.aesKey = newKey(keyBytes); + this.cipher = newCipher(); this.randomGenerator = new SecureRandom(); this.nonce = new byte[NONCE_LENGTH]; @@ -92,7 +82,7 @@ public int encrypt( int ciphertextOffset, byte[] aad) { Preconditions.checkArgument( - plaintextLength > 0, "Invalid plain text length: %s", plaintextLength); + plaintextLength >= 0, "Invalid plain text length: %s", plaintextLength); randomGenerator.nextBytes(nonce); int enciphered; @@ -136,20 +126,8 @@ public static class AesGcmDecryptor { private final Cipher cipher; public AesGcmDecryptor(byte[] keyBytes) { - Preconditions.checkArgument(keyBytes != null, "Key can't be null"); - int keyLength = keyBytes.length; - Preconditions.checkArgument( 
- (keyLength == 16 || keyLength == 24 || keyLength == 32), - "Cannot use a key of length " - + keyLength - + " because AES only allows 16, 24 or 32 bytes"); - this.aesKey = new SecretKeySpec(keyBytes, "AES"); - - try { - this.cipher = Cipher.getInstance("AES/GCM/NoPadding"); - } catch (GeneralSecurityException e) { - throw new RuntimeException("Failed to create GCM cipher", e); - } + this.aesKey = newKey(keyBytes); + this.cipher = newCipher(); } public byte[] decrypt(byte[] ciphertext, byte[] aad) { @@ -172,7 +150,7 @@ public int decrypt( int plaintextOffset, byte[] aad) { Preconditions.checkState( - ciphertextLength - GCM_TAG_LENGTH - NONCE_LENGTH >= 1, + ciphertextLength - GCM_TAG_LENGTH - NONCE_LENGTH >= 0, "Cannot decrypt cipher text of length " + ciphertext.length + " because text must longer than GCM_TAG_LENGTH + NONCE_LENGTH bytes. Text may not be encrypted" @@ -207,6 +185,24 @@ public int decrypt( } } + private static SecretKeySpec newKey(byte[] keyBytes) { + Preconditions.checkArgument(keyBytes != null, "Invalid key: null"); + int keyLength = keyBytes.length; + Preconditions.checkArgument( + (keyLength == 16 || keyLength == 24 || keyLength == 32), + "Invalid key length: %s (must be 16, 24, or 32 bytes)", + keyLength); + return new SecretKeySpec(keyBytes, "AES"); + } + + private static Cipher newCipher() { + try { + return Cipher.getInstance("AES/GCM/NoPadding"); + } catch (GeneralSecurityException e) { + throw new RuntimeException("Failed to create GCM cipher", e); + } + } + static byte[] streamBlockAAD(byte[] fileAadPrefix, int currentBlockIndex) { byte[] blockAAD = ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putInt(currentBlockIndex).array(); diff --git a/core/src/test/java/org/apache/iceberg/encryption/TestGcmStreams.java b/core/src/test/java/org/apache/iceberg/encryption/TestGcmStreams.java index 0b3f085bc9d7..773d5f41af94 100644 --- a/core/src/test/java/org/apache/iceberg/encryption/TestGcmStreams.java +++ 
b/core/src/test/java/org/apache/iceberg/encryption/TestGcmStreams.java @@ -18,13 +18,19 @@ */ package org.apache.iceberg.encryption; +import java.io.EOFException; import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.StandardOpenOption; +import java.util.Arrays; import java.util.Random; +import javax.crypto.AEADBadTagException; import org.apache.iceberg.Files; import org.apache.iceberg.io.PositionOutputStream; import org.apache.iceberg.io.SeekableInputStream; +import org.assertj.core.api.Assertions; import org.junit.Assert; import org.junit.Rule; import org.junit.Test; @@ -41,17 +47,196 @@ public void testEmptyFile() throws IOException { random.nextBytes(key); byte[] aadPrefix = new byte[16]; random.nextBytes(aadPrefix); + byte[] readBytes = new byte[1]; + File testFile = temp.newFile(); + AesGcmOutputFile encryptedFile = new AesGcmOutputFile(Files.localOutput(testFile), key, aadPrefix); PositionOutputStream encryptedStream = encryptedFile.createOrOverwrite(); encryptedStream.close(); AesGcmInputFile decryptedFile = new AesGcmInputFile(Files.localInput(testFile), key, aadPrefix); - SeekableInputStream decryptedStream = decryptedFile.newStream(); - Assert.assertEquals("File size", 0, decryptedFile.getLength()); - decryptedStream.close(); + + try (SeekableInputStream decryptedStream = decryptedFile.newStream()) { + Assertions.assertThatThrownBy(() -> decryptedStream.read(readBytes)) + .isInstanceOf(EOFException.class); + } + + // check that the AAD is still verified, even for an empty file + byte[] badAAD = Arrays.copyOf(aadPrefix, aadPrefix.length); + badAAD[1] -= 1; // modify the AAD slightly + AesGcmInputFile badAADFile = new AesGcmInputFile(Files.localInput(testFile), key, badAAD); + Assert.assertEquals("File size", 0, badAADFile.getLength()); + + try (SeekableInputStream decryptedStream = badAADFile.newStream()) { + Assertions.assertThatThrownBy(() -> 
decryptedStream.read(readBytes)) + .isInstanceOf(RuntimeException.class) + .hasCauseInstanceOf(AEADBadTagException.class) + .hasMessageContaining("GCM tag check failed"); + } + } + + @Test + public void testAADValidation() throws IOException { + Random random = new Random(); + byte[] key = new byte[16]; + random.nextBytes(key); + byte[] aadPrefix = new byte[16]; + random.nextBytes(aadPrefix); + byte[] content = new byte[Ciphers.PLAIN_BLOCK_SIZE / 2]; // half a block + random.nextBytes(content); + + File testFile = temp.newFile(); + + AesGcmOutputFile encryptedFile = + new AesGcmOutputFile(Files.localOutput(testFile), key, aadPrefix); + try (PositionOutputStream encryptedStream = encryptedFile.createOrOverwrite()) { + encryptedStream.write(content); + } + + // verify the data can be read correctly with the right AAD + AesGcmInputFile decryptedFile = new AesGcmInputFile(Files.localInput(testFile), key, aadPrefix); + Assert.assertEquals("File size", content.length, decryptedFile.getLength()); + + try (SeekableInputStream decryptedStream = decryptedFile.newStream()) { + byte[] readContent = new byte[Ciphers.PLAIN_BLOCK_SIZE]; + int bytesRead = decryptedStream.read(readContent); + Assert.assertEquals("Bytes read should match bytes written", content.length, bytesRead); + Assert.assertEquals( + "Content should match", + ByteBuffer.wrap(content), + ByteBuffer.wrap(readContent, 0, bytesRead)); + } + + // test with the wrong AAD + byte[] badAAD = Arrays.copyOf(aadPrefix, aadPrefix.length); + badAAD[1] -= 1; // modify the AAD slightly + AesGcmInputFile badAADFile = new AesGcmInputFile(Files.localInput(testFile), key, badAAD); + Assert.assertEquals("File size", content.length, badAADFile.getLength()); + + try (SeekableInputStream decryptedStream = badAADFile.newStream()) { + byte[] readContent = new byte[Ciphers.PLAIN_BLOCK_SIZE]; + Assertions.assertThatThrownBy(() -> decryptedStream.read(readContent)) + .isInstanceOf(RuntimeException.class) + 
.hasCauseInstanceOf(AEADBadTagException.class) + .hasMessageContaining("GCM tag check failed"); + } + + // modify the file contents + try (FileChannel out = FileChannel.open(testFile.toPath(), StandardOpenOption.WRITE)) { + long lastTagPosition = testFile.length() - Ciphers.GCM_TAG_LENGTH; + out.position(lastTagPosition); + out.write(ByteBuffer.wrap(key)); // overwrite the tag with other random bytes (the key) + } + + // read with the correct AAD and verify the tag check fails + try (SeekableInputStream decryptedStream = decryptedFile.newStream()) { + byte[] readContent = new byte[Ciphers.PLAIN_BLOCK_SIZE]; + Assertions.assertThatThrownBy(() -> decryptedStream.read(readContent)) + .isInstanceOf(RuntimeException.class) + .hasCauseInstanceOf(AEADBadTagException.class) + .hasMessageContaining("GCM tag check failed"); + } + } + + @Test + public void testCorruptNonce() throws IOException { + Random random = new Random(); + byte[] key = new byte[16]; + random.nextBytes(key); + byte[] aadPrefix = new byte[16]; + random.nextBytes(aadPrefix); + byte[] content = new byte[Ciphers.PLAIN_BLOCK_SIZE / 2]; // half a block + random.nextBytes(content); + + File testFile = temp.newFile(); + + AesGcmOutputFile encryptedFile = + new AesGcmOutputFile(Files.localOutput(testFile), key, aadPrefix); + try (PositionOutputStream encryptedStream = encryptedFile.createOrOverwrite()) { + encryptedStream.write(content); + } + + // verify the data can be read correctly with the right AAD + AesGcmInputFile decryptedFile = new AesGcmInputFile(Files.localInput(testFile), key, aadPrefix); + Assert.assertEquals("File size", content.length, decryptedFile.getLength()); + + try (SeekableInputStream decryptedStream = decryptedFile.newStream()) { + byte[] readContent = new byte[Ciphers.PLAIN_BLOCK_SIZE]; + int bytesRead = decryptedStream.read(readContent); + Assert.assertEquals("Bytes read should match bytes written", content.length, bytesRead); + Assert.assertEquals( + "Content should match", + 
ByteBuffer.wrap(content), + ByteBuffer.wrap(readContent, 0, bytesRead)); + } + + // replace the first block's nonce + try (FileChannel out = FileChannel.open(testFile.toPath(), StandardOpenOption.WRITE)) { + out.position(Ciphers.GCM_STREAM_HEADER_LENGTH); + // overwrite the nonce with other random bytes (the key) + out.write(ByteBuffer.wrap(key, 0, Ciphers.NONCE_LENGTH)); + } + + // read with the correct AAD and verify the read fails + try (SeekableInputStream decryptedStream = decryptedFile.newStream()) { + byte[] readContent = new byte[Ciphers.PLAIN_BLOCK_SIZE]; + Assertions.assertThatThrownBy(() -> decryptedStream.read(readContent)) + .isInstanceOf(RuntimeException.class) + .hasCauseInstanceOf(AEADBadTagException.class) + .hasMessageContaining("GCM tag check failed"); + } + } + + @Test + public void testCorruptCiphertext() throws IOException { + Random random = new Random(); + byte[] key = new byte[16]; + random.nextBytes(key); + byte[] aadPrefix = new byte[16]; + random.nextBytes(aadPrefix); + byte[] content = new byte[Ciphers.PLAIN_BLOCK_SIZE / 2]; // half a block + random.nextBytes(content); + + File testFile = temp.newFile(); + + AesGcmOutputFile encryptedFile = + new AesGcmOutputFile(Files.localOutput(testFile), key, aadPrefix); + try (PositionOutputStream encryptedStream = encryptedFile.createOrOverwrite()) { + encryptedStream.write(content); + } + + // verify the data can be read correctly with the right AAD + AesGcmInputFile decryptedFile = new AesGcmInputFile(Files.localInput(testFile), key, aadPrefix); + Assert.assertEquals("File size", content.length, decryptedFile.getLength()); + + try (SeekableInputStream decryptedStream = decryptedFile.newStream()) { + byte[] readContent = new byte[Ciphers.PLAIN_BLOCK_SIZE]; + int bytesRead = decryptedStream.read(readContent); + Assert.assertEquals("Bytes read should match bytes written", content.length, bytesRead); + Assert.assertEquals( + "Content should match", + ByteBuffer.wrap(content), + 
ByteBuffer.wrap(readContent, 0, bytesRead)); + } + + // replace part of the first block's content + try (FileChannel out = FileChannel.open(testFile.toPath(), StandardOpenOption.WRITE)) { + out.position(Ciphers.GCM_STREAM_HEADER_LENGTH + Ciphers.NONCE_LENGTH + 34); + // overwrite part of the ciphertext with other random bytes (the key) + out.write(ByteBuffer.wrap(key)); + } + + // read with the correct AAD and verify the read fails + try (SeekableInputStream decryptedStream = decryptedFile.newStream()) { + byte[] readContent = new byte[Ciphers.PLAIN_BLOCK_SIZE]; + Assertions.assertThatThrownBy(() -> decryptedStream.read(readContent)) + .isInstanceOf(RuntimeException.class) + .hasCauseInstanceOf(AEADBadTagException.class) + .hasMessageContaining("GCM tag check failed"); + } } @Test