diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/CompressingDecryptingPageDeserializer.java b/core/trino-main/src/main/java/io/trino/execution/buffer/CompressingDecryptingPageDeserializer.java new file mode 100644 index 000000000000..20dec4337643 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/CompressingDecryptingPageDeserializer.java @@ -0,0 +1,690 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.buffer; + +import com.google.common.base.VerifyException; +import io.airlift.compress.v3.Decompressor; +import io.airlift.compress.v3.lz4.Lz4Decompressor; +import io.airlift.slice.Slice; +import io.airlift.slice.SliceInput; +import io.airlift.slice.Slices; +import io.trino.spi.Page; +import io.trino.spi.TrinoException; +import io.trino.spi.block.BlockEncodingSerde; + +import javax.crypto.Cipher; +import javax.crypto.SecretKey; +import javax.crypto.spec.IvParameterSpec; +import javax.crypto.spec.SecretKeySpec; + +import java.io.IOException; +import java.io.OutputStream; +import java.io.UnsupportedEncodingException; +import java.security.GeneralSecurityException; +import java.util.Optional; +import java.util.OptionalInt; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.airlift.slice.SizeOf.instanceSize; +import static io.airlift.slice.SizeOf.sizeOf; +import static io.airlift.slice.SizeOf.sizeOfByteArray; +import static 
io.trino.execution.buffer.PagesSerdeUtil.ESTIMATED_AES_CIPHER_RETAINED_SIZE; +import static io.trino.execution.buffer.PagesSerdeUtil.SERIALIZED_PAGE_CIPHER_NAME; +import static io.trino.execution.buffer.PagesSerdeUtil.SERIALIZED_PAGE_COMPRESSED_BLOCK_MASK; +import static io.trino.execution.buffer.PagesSerdeUtil.SERIALIZED_PAGE_HEADER_SIZE; +import static io.trino.execution.buffer.PagesSerdeUtil.getSerializedPagePositionCount; +import static io.trino.execution.buffer.PagesSerdeUtil.readRawPage; +import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static io.trino.util.Ciphers.is256BitSecretKeySpec; +import static java.lang.Math.min; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; +import static javax.crypto.Cipher.DECRYPT_MODE; + +public class CompressingDecryptingPageDeserializer + implements PageDeserializer +{ + private static final int INSTANCE_SIZE = instanceSize(CompressingDecryptingPageDeserializer.class); + + private final BlockEncodingSerde blockEncodingSerde; + private final SerializedPageInput input; + + public CompressingDecryptingPageDeserializer( + BlockEncodingSerde blockEncodingSerde, + Optional decompressor, + Optional encryptionKey, + int blockSizeInBytes, + OptionalInt maxCompressedBlockSizeInBytes) + { + this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null"); + requireNonNull(encryptionKey, "encryptionKey is null"); + encryptionKey.ifPresent(secretKey -> checkArgument(is256BitSecretKeySpec(secretKey), "encryptionKey is expected to be an instance of SecretKeySpec containing a 256bit key")); + input = new SerializedPageInput( + requireNonNull(decompressor, "decompressor is null"), + encryptionKey, + blockSizeInBytes, + maxCompressedBlockSizeInBytes); + } + + @Override + public Page deserialize(Slice serializedPage) + { + int positionCount = input.startPage(serializedPage); + Page page = readRawPage(positionCount, input, blockEncodingSerde); + 
input.finishPage(); + return page; + } + + @Override + public long getRetainedSizeInBytes() + { + return INSTANCE_SIZE + input.getRetainedSize(); + } + + private static class SerializedPageInput + extends SliceInput + { + private static final int INSTANCE_SIZE = instanceSize(SerializedPageInput.class); + // TODO: implement getRetainedSizeInBytes in Lz4Decompressor + private static final int DECOMPRESSOR_RETAINED_SIZE = instanceSize(Lz4Decompressor.class); + private static final int ENCRYPTION_KEY_RETAINED_SIZE = toIntExact(instanceSize(SecretKeySpec.class) + sizeOfByteArray(256 / 8)); + + private final Optional decompressor; + private final Optional encryptionKey; + private final Optional cipher; + + private final ReadBuffer[] buffers; + + private SerializedPageInput(Optional decompressor, Optional encryptionKey, int blockSizeInBytes, OptionalInt maxCompressedBlockSizeInBytes) + { + this.decompressor = requireNonNull(decompressor, "decompressor is null"); + this.encryptionKey = requireNonNull(encryptionKey, "encryptionKey is null"); + + buffers = new ReadBuffer[ + (decompressor.isPresent() ? 1 : 0) // decompression buffer + + (encryptionKey.isPresent() ? 
1 : 0) // decryption buffer + + 1 // input buffer + ]; + if (decompressor.isPresent()) { + int bufferSize = blockSizeInBytes + // to guarantee a single long can always be read entirely + + Long.BYTES; + buffers[0] = new ReadBuffer(Slices.allocate(bufferSize)); + buffers[0].setPosition(bufferSize); + } + if (encryptionKey.isPresent()) { + int bufferSize; + if (decompressor.isPresent()) { + // to store compressed block size + bufferSize = maxCompressedBlockSizeInBytes.orElseThrow() + // to store compressed block size + + Integer.BYTES + // to guarantee a single long can always be read entirely + + Long.BYTES; + } + else { + bufferSize = blockSizeInBytes + // to guarantee a single long can always be read entirely + + Long.BYTES; + } + buffers[buffers.length - 2] = new ReadBuffer(Slices.allocate(bufferSize)); + buffers[buffers.length - 2].setPosition(bufferSize); + + try { + cipher = Optional.of(Cipher.getInstance(SERIALIZED_PAGE_CIPHER_NAME)); + } + catch (GeneralSecurityException e) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, "Failed to create cipher: " + e.getMessage(), e); + } + } + else { + cipher = Optional.empty(); + } + } + + public int startPage(Slice page) + { + int positionCount = getSerializedPagePositionCount(page); + ReadBuffer buffer = new ReadBuffer(page); + buffer.setPosition(SERIALIZED_PAGE_HEADER_SIZE); + buffers[buffers.length - 1] = buffer; + return positionCount; + } + + @Override + public boolean readBoolean() + { + ensureReadable(1); + return buffers[0].readBoolean(); + } + + @Override + public byte readByte() + { + ensureReadable(Byte.BYTES); + return buffers[0].readByte(); + } + + @Override + public short readShort() + { + ensureReadable(Short.BYTES); + return buffers[0].readShort(); + } + + @Override + public int readInt() + { + ensureReadable(Integer.BYTES); + return buffers[0].readInt(); + } + + @Override + public long readLong() + { + ensureReadable(Long.BYTES); + return buffers[0].readLong(); + } + + @Override + public float 
readFloat() + { + ensureReadable(Float.BYTES); + return buffers[0].readFloat(); + } + + @Override + public double readDouble() + { + ensureReadable(Double.BYTES); + return buffers[0].readDouble(); + } + + @Override + public int read(byte[] destination, int destinationIndex, int length) + { + ReadBuffer buffer = buffers[0]; + int bytesRemaining = length; + while (bytesRemaining > 0) { + ensureReadable(min(Long.BYTES, bytesRemaining)); + int bytesToRead = min(bytesRemaining, buffer.available()); + int bytesRead = buffer.read(destination, destinationIndex, bytesToRead); + if (bytesRead == -1) { + break; + } + bytesRemaining -= bytesRead; + destinationIndex += bytesRead; + } + return length - bytesRemaining; + } + + @Override + public void readBytes(byte[] destination, int destinationIndex, int length) + { + ReadBuffer buffer = buffers[0]; + int bytesRemaining = length; + while (bytesRemaining > 0) { + ensureReadable(min(Long.BYTES, bytesRemaining)); + int bytesToRead = min(bytesRemaining, buffer.available()); + buffer.readBytes(destination, destinationIndex, bytesToRead); + bytesRemaining -= bytesToRead; + destinationIndex += bytesToRead; + } + } + + @Override + public void readShorts(short[] destination, int destinationIndex, int length) + { + ReadBuffer buffer = buffers[0]; + int shortsRemaining = length; + while (shortsRemaining > 0) { + ensureReadable(min(Long.BYTES, shortsRemaining * Short.BYTES)); + int shortsToRead = min(shortsRemaining, buffer.available() / Short.BYTES); + buffer.readShorts(destination, destinationIndex, shortsToRead); + shortsRemaining -= shortsToRead; + destinationIndex += shortsToRead; + } + } + + @Override + public void readInts(int[] destination, int destinationIndex, int length) + { + ReadBuffer buffer = buffers[0]; + int intsRemaining = length; + while (intsRemaining > 0) { + ensureReadable(min(Long.BYTES, intsRemaining * Integer.BYTES)); + int intsToRead = min(intsRemaining, buffer.available() / Integer.BYTES); + 
buffer.readInts(destination, destinationIndex, intsToRead); + intsRemaining -= intsToRead; + destinationIndex += intsToRead; + } + } + + @Override + public void readLongs(long[] destination, int destinationIndex, int length) + { + ReadBuffer buffer = buffers[0]; + int longsRemaining = length; + while (longsRemaining > 0) { + ensureReadable(min(Long.BYTES, longsRemaining * Long.BYTES)); + int longsToRead = min(longsRemaining, buffer.available() / Long.BYTES); + buffer.readLongs(destination, destinationIndex, longsToRead); + longsRemaining -= longsToRead; + destinationIndex += longsToRead; + } + } + + @Override + public void readFloats(float[] destination, int destinationIndex, int length) + { + ReadBuffer buffer = buffers[0]; + int floatsRemaining = length; + while (floatsRemaining > 0) { + ensureReadable(min(Long.BYTES, floatsRemaining * Float.BYTES)); + int floatsToRead = min(floatsRemaining, buffer.available() / Float.BYTES); + buffer.readFloats(destination, destinationIndex, floatsToRead); + floatsRemaining -= floatsToRead; + destinationIndex += floatsToRead; + } + } + + @Override + public void readDoubles(double[] destination, int destinationIndex, int length) + { + ReadBuffer buffer = buffers[0]; + int doublesRemaining = length; + while (doublesRemaining > 0) { + ensureReadable(min(Long.BYTES, doublesRemaining * Double.BYTES)); + int doublesToRead = min(doublesRemaining, buffer.available() / Double.BYTES); + buffer.readDoubles(destination, destinationIndex, doublesToRead); + doublesRemaining -= doublesToRead; + destinationIndex += doublesToRead; + } + } + + @Override + public void readBytes(Slice destination, int destinationIndex, int length) + { + ReadBuffer buffer = buffers[0]; + int bytesRemaining = length; + while (bytesRemaining > 0) { + ensureReadable(min(Long.BYTES, bytesRemaining)); + int bytesToRead = min(bytesRemaining, buffer.available()); + buffer.readBytes(destination, destinationIndex, bytesToRead); + bytesRemaining -= bytesToRead; + 
destinationIndex += bytesToRead; + } + } + + private void ensureReadable(int bytes) + { + if (buffers[0].available() >= bytes) { + return; + } + decrypt(); + decompress(); + } + + private void decrypt() + { + if (this.encryptionKey.isEmpty()) { + return; + } + + ReadBuffer source = buffers[buffers.length - 1]; + ReadBuffer sink = buffers[buffers.length - 2]; + int bytesPreserved = sink.rollOver(); + + int encryptedSize = source.readInt(); + int ivSize = cipher.orElseThrow().getBlockSize(); + IvParameterSpec iv = new IvParameterSpec( + source.getSlice().byteArray(), + source.getSlice().byteArrayOffset() + source.getPosition(), + ivSize); + source.setPosition(source.getPosition() + ivSize); + + Cipher cipher = initCipher(encryptionKey.get(), iv); + int decryptedSize; + try { + // Do not refactor into single doFinal call, performance and allocation rate are significantly worse + // See https://github.com/trinodb/trino/pull/5557 + decryptedSize = cipher.update( + source.getSlice().byteArray(), + source.getSlice().byteArrayOffset() + source.getPosition(), + encryptedSize, + sink.getSlice().byteArray(), + sink.getSlice().byteArrayOffset() + bytesPreserved); + decryptedSize += cipher.doFinal( + sink.getSlice().byteArray(), + sink.getSlice().byteArrayOffset() + bytesPreserved + decryptedSize); + } + catch (GeneralSecurityException e) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, "Cannot decrypt previously encrypted data: " + e.getMessage(), e); + } + source.setPosition(source.getPosition() + encryptedSize); + sink.setLimit(bytesPreserved + decryptedSize); + } + + private Cipher initCipher(SecretKey key, IvParameterSpec iv) + { + Cipher cipher = this.cipher.orElseThrow(() -> new VerifyException("cipher is expected to be present")); + try { + cipher.init(DECRYPT_MODE, key, iv); + } + catch (GeneralSecurityException e) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, "Failed to init cipher: " + e.getMessage(), e); + } + return cipher; + } + + private void 
decompress() + { + if (this.decompressor.isEmpty()) { + return; + } + + Decompressor decompressor = this.decompressor.get(); + + ReadBuffer source = buffers[1]; + ReadBuffer sink = buffers[0]; + int bytesPreserved = sink.rollOver(); + + int compressedBlockMarker = source.readInt(); + int blockSize = getCompressedBlockSize(compressedBlockMarker); + boolean compressed = isCompressed(compressedBlockMarker); + + int decompressedSize; + if (compressed) { + decompressedSize = decompressor.decompress( + source.getSlice().byteArray(), + source.getSlice().byteArrayOffset() + source.getPosition(), + blockSize, + sink.getSlice().byteArray(), + sink.getSlice().byteArrayOffset() + bytesPreserved, + sink.getSlice().length() - bytesPreserved); + } + else { + System.arraycopy( + source.getSlice().byteArray(), + source.getSlice().byteArrayOffset() + source.getPosition(), + sink.getSlice().byteArray(), + sink.getSlice().byteArrayOffset() + bytesPreserved, + blockSize); + decompressedSize = blockSize; + } + source.setPosition(source.getPosition() + blockSize); + sink.setLimit(bytesPreserved + decompressedSize); + } + + private static int getCompressedBlockSize(int compressedBlockMarker) + { + return compressedBlockMarker & ~SERIALIZED_PAGE_COMPRESSED_BLOCK_MASK; + } + + private static boolean isCompressed(int compressedBlockMarker) + { + return (compressedBlockMarker & SERIALIZED_PAGE_COMPRESSED_BLOCK_MASK) == SERIALIZED_PAGE_COMPRESSED_BLOCK_MASK; + } + + public void finishPage() + { + buffers[buffers.length - 1] = null; + for (ReadBuffer buffer : buffers) { + if (buffer != null) { + buffer.setPosition(buffer.getSlice().length()); + buffer.setLimit(buffer.getSlice().length()); + } + } + } + + @Override + public int read() + { + return readByte(); + } + + @Override + public int readUnsignedByte() + { + return readByte() & 0xFF; + } + + @Override + public int readUnsignedShort() + { + return readShort() & 0xFFFF; + } + + @Override + public Slice readSlice(int length) + { + Slice slice 
= Slices.allocate(length); + readBytes(slice, 0, length); + return slice; + } + + @Override + public boolean isReadable() + { + return available() > 0; + } + + @Override + public int available() + { + return buffers[0].available(); + } + + @Override + public long skip(long length) + { + return 0; + } + + @Override + public int skipBytes(int length) + { + return toIntExact(skip(length)); + } + + @Override + public long getRetainedSize() + { + long size = INSTANCE_SIZE; + size += sizeOf(decompressor, compressor -> DECOMPRESSOR_RETAINED_SIZE); + size += sizeOf(encryptionKey, encryptionKey -> ENCRYPTION_KEY_RETAINED_SIZE); + size += sizeOf(cipher, cipher -> ESTIMATED_AES_CIPHER_RETAINED_SIZE); + for (ReadBuffer input : buffers) { + if (input != null) { + size += input.getRetainedSizeInBytes(); + } + } + return size; + } + + @Override + public void readBytes(OutputStream out, int length) + throws IOException + { + throw new UnsupportedEncodingException(); + } + + @Override + public long position() + { + throw new UnsupportedOperationException(); + } + + @Override + public void setPosition(long position) + { + throw new UnsupportedOperationException(); + } + } + + private static class ReadBuffer + { + private static final int INSTANCE_SIZE = instanceSize(ReadBuffer.class); + + private final Slice slice; + private int position; + private int limit; + + public ReadBuffer(Slice slice) + { + requireNonNull(slice, "slice is null"); + this.slice = slice; + limit = slice.length(); + } + + public int available() + { + return limit - position; + } + + public Slice getSlice() + { + return slice; + } + + public int getPosition() + { + return position; + } + + public void setPosition(int position) + { + this.position = position; + } + + public void setLimit(int limit) + { + this.limit = limit; + } + + public int rollOver() + { + int bytesToCopy = available(); + if (bytesToCopy != 0) { + slice.setBytes(0, slice, position, bytesToCopy); + } + position = 0; + return bytesToCopy; + } + 
+ public boolean readBoolean() + { + boolean value = slice.getByte(position) == 1; + position += Byte.BYTES; + return value; + } + + public byte readByte() + { + byte value = slice.getByte(position); + position += Byte.BYTES; + return value; + } + + public short readShort() + { + short value = slice.getShort(position); + position += Short.BYTES; + return value; + } + + public int readInt() + { + int value = slice.getInt(position); + position += Integer.BYTES; + return value; + } + + public long readLong() + { + long value = slice.getLong(position); + position += Long.BYTES; + return value; + } + + public float readFloat() + { + float value = slice.getFloat(position); + position += Float.BYTES; + return value; + } + + public double readDouble() + { + double value = slice.getDouble(position); + position += Double.BYTES; + return value; + } + + public int read(byte[] destination, int destinationIndex, int length) + { + int bytesToRead = min(length, slice.length() - position); + slice.getBytes(position, destination, destinationIndex, bytesToRead); + position += bytesToRead; + return bytesToRead; + } + + public void readBytes(byte[] destination, int destinationIndex, int length) + { + slice.getBytes(position, destination, destinationIndex, length); + position += length; + } + + public void readShorts(short[] destination, int destinationIndex, int length) + { + slice.getShorts(position, destination, destinationIndex, length); + position += length * Short.BYTES; + } + + public void readInts(int[] destination, int destinationIndex, int length) + { + slice.getInts(position, destination, destinationIndex, length); + position += length * Integer.BYTES; + } + + public void readLongs(long[] destination, int destinationIndex, int length) + { + slice.getLongs(position, destination, destinationIndex, length); + position += length * Long.BYTES; + } + + public void readFloats(float[] destination, int destinationIndex, int length) + { + slice.getFloats(position, destination, 
destinationIndex, length); + position += length * Float.BYTES; + } + + public void readDoubles(double[] destination, int destinationIndex, int length) + { + slice.getDoubles(position, destination, destinationIndex, length); + position += length * Double.BYTES; + } + + public void readBytes(Slice destination, int destinationIndex, int length) + { + slice.getBytes(position, destination, destinationIndex, length); + position += length; + } + + public long getRetainedSizeInBytes() + { + return INSTANCE_SIZE + slice.getRetainedSize(); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/CompressingEncryptingPageSerializer.java b/core/trino-main/src/main/java/io/trino/execution/buffer/CompressingEncryptingPageSerializer.java new file mode 100644 index 000000000000..54e6d9d2053b --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/CompressingEncryptingPageSerializer.java @@ -0,0 +1,745 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.execution.buffer; + +import com.google.common.base.VerifyException; +import io.airlift.compress.v3.Compressor; +import io.airlift.slice.Slice; +import io.airlift.slice.SliceOutput; +import io.airlift.slice.Slices; +import io.trino.spi.Page; +import io.trino.spi.TrinoException; +import io.trino.spi.block.BlockEncodingSerde; + +import javax.crypto.Cipher; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.Charset; +import java.security.GeneralSecurityException; +import java.util.Optional; +import java.util.OptionalInt; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.airlift.slice.SizeOf.instanceSize; +import static io.airlift.slice.SizeOf.sizeOf; +import static io.airlift.slice.SizeOf.sizeOfByteArray; +import static io.trino.execution.buffer.PageCodecMarker.COMPRESSED; +import static io.trino.execution.buffer.PageCodecMarker.ENCRYPTED; +import static io.trino.execution.buffer.PagesSerdeUtil.ESTIMATED_AES_CIPHER_RETAINED_SIZE; +import static io.trino.execution.buffer.PagesSerdeUtil.SERIALIZED_PAGE_CIPHER_NAME; +import static io.trino.execution.buffer.PagesSerdeUtil.SERIALIZED_PAGE_COMPRESSED_BLOCK_MASK; +import static io.trino.execution.buffer.PagesSerdeUtil.SERIALIZED_PAGE_COMPRESSED_SIZE_OFFSET; +import static io.trino.execution.buffer.PagesSerdeUtil.SERIALIZED_PAGE_HEADER_SIZE; +import static io.trino.execution.buffer.PagesSerdeUtil.SERIALIZED_PAGE_UNCOMPRESSED_SIZE_OFFSET; +import static io.trino.execution.buffer.PagesSerdeUtil.writeRawPage; +import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static io.trino.util.Ciphers.is256BitSecretKeySpec; +import static java.lang.Math.min; +import static java.lang.Math.round; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; +import static javax.crypto.Cipher.ENCRYPT_MODE; + +public class 
CompressingEncryptingPageSerializer + implements PageSerializer +{ + private static final int INSTANCE_SIZE = instanceSize(CompressingEncryptingPageSerializer.class); + + private final BlockEncodingSerde blockEncodingSerde; + private final SerializedPageOutput output; + + public CompressingEncryptingPageSerializer( + BlockEncodingSerde blockEncodingSerde, + Optional compressor, + Optional encryptionKey, + int blockSizeInBytes, + OptionalInt maxCompressedBlockSize) + { + this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null"); + requireNonNull(encryptionKey, "encryptionKey is null"); + encryptionKey.ifPresent(secretKey -> checkArgument(is256BitSecretKeySpec(secretKey), "encryptionKey is expected to be an instance of SecretKeySpec containing a 256bit key")); + output = new SerializedPageOutput( + requireNonNull(compressor, "compressor is null"), + encryptionKey, + blockSizeInBytes, + maxCompressedBlockSize); + } + + @Override + public Slice serialize(Page page) + { + output.startPage(page.getPositionCount(), toIntExact(page.getSizeInBytes())); + writeRawPage(page, output, blockEncodingSerde); + return output.closePage(); + } + + @Override + public long getRetainedSizeInBytes() + { + return INSTANCE_SIZE + output.getRetainedSize(); + } + + private static class SerializedPageOutput + extends SliceOutput + { + private static final int INSTANCE_SIZE = instanceSize(SerializedPageOutput.class); + private static final int ENCRYPTION_KEY_RETAINED_SIZE = toIntExact(instanceSize(SecretKeySpec.class) + sizeOfByteArray(256 / 8)); + + private static final double MINIMUM_COMPRESSION_RATIO = 0.8; + + private final Optional compressor; + private final Optional encryptionKey; + private final int markers; + private final Optional cipher; + + private final WriteBuffer[] buffers; + private int uncompressedSize; + + private SerializedPageOutput( + Optional compressor, + Optional encryptionKey, + int blockSizeInBytes, + OptionalInt 
maxCompressedBlockSize) + { + this.compressor = requireNonNull(compressor, "compressor is null"); + this.encryptionKey = requireNonNull(encryptionKey, "encryptionKey is null"); + + buffers = new WriteBuffer[ + (compressor.isPresent() ? 1 : 0) // compression buffer + + (encryptionKey.isPresent() ? 1 : 0) // encryption buffer + + 1 // output buffer + ]; + PageCodecMarker.MarkerSet markerSet = PageCodecMarker.MarkerSet.empty(); + if (compressor.isPresent()) { + buffers[0] = new WriteBuffer(blockSizeInBytes); + markerSet.add(COMPRESSED); + } + if (encryptionKey.isPresent()) { + int bufferSize = blockSizeInBytes; + if (compressor.isPresent()) { + bufferSize = maxCompressedBlockSize.orElseThrow() + // to store compressed block size + + Integer.BYTES; + } + buffers[buffers.length - 2] = new WriteBuffer(bufferSize); + markerSet.add(ENCRYPTED); + + try { + cipher = Optional.of(Cipher.getInstance(SERIALIZED_PAGE_CIPHER_NAME)); + } + catch (GeneralSecurityException e) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, "Failed to create cipher: " + e.getMessage(), e); + } + } + else { + cipher = Optional.empty(); + } + markers = markerSet.byteValue(); + } + + public void startPage(int positionCount, int sizeInBytes) + { + WriteBuffer buffer = new WriteBuffer(round(sizeInBytes * 1.2F) + SERIALIZED_PAGE_HEADER_SIZE); + buffer.writeInt(positionCount); + buffer.writeByte(markers); + // leave space for uncompressed and compressed sizes + buffer.skip(Integer.BYTES * 2); + + buffers[buffers.length - 1] = buffer; + uncompressedSize = 0; + } + + @Override + public void writeByte(int value) + { + ensureCapacityFor(Byte.BYTES); + buffers[0].writeByte(value); + uncompressedSize += Byte.BYTES; + } + + @Override + public void writeShort(int value) + { + ensureCapacityFor(Short.BYTES); + buffers[0].writeShort(value); + uncompressedSize += Short.BYTES; + } + + @Override + public void writeInt(int value) + { + ensureCapacityFor(Integer.BYTES); + buffers[0].writeInt(value); + uncompressedSize 
+= Integer.BYTES; + } + + @Override + public void writeLong(long value) + { + ensureCapacityFor(Long.BYTES); + buffers[0].writeLong(value); + uncompressedSize += Long.BYTES; + } + + @Override + public void writeFloat(float value) + { + ensureCapacityFor(Float.BYTES); + buffers[0].writeFloat(value); + uncompressedSize += Float.BYTES; + } + + @Override + public void writeDouble(double value) + { + ensureCapacityFor(Double.BYTES); + buffers[0].writeDouble(value); + uncompressedSize += Double.BYTES; + } + + @Override + public void writeBytes(Slice source, int sourceIndex, int length) + { + WriteBuffer buffer = buffers[0]; + int currentIndex = sourceIndex; + int bytesRemaining = length; + while (bytesRemaining > 0) { + ensureCapacityFor(min(Long.BYTES, bytesRemaining)); + int bufferCapacity = buffer.remainingCapacity(); + int bytesToCopy = min(bytesRemaining, bufferCapacity); + buffer.writeBytes(source, currentIndex, bytesToCopy); + currentIndex += bytesToCopy; + bytesRemaining -= bytesToCopy; + } + uncompressedSize += length; + } + + @Override + public void writeBytes(byte[] source, int sourceIndex, int length) + { + WriteBuffer buffer = buffers[0]; + int currentIndex = sourceIndex; + int bytesRemaining = length; + while (bytesRemaining > 0) { + ensureCapacityFor(min(Long.BYTES, bytesRemaining)); + int bufferCapacity = buffer.remainingCapacity(); + int bytesToCopy = min(bytesRemaining, bufferCapacity); + buffer.writeBytes(source, currentIndex, bytesToCopy); + currentIndex += bytesToCopy; + bytesRemaining -= bytesToCopy; + } + uncompressedSize += length; + } + + @Override + public void writeShorts(short[] source, int sourceIndex, int length) + { + WriteBuffer buffer = buffers[0]; + int currentIndex = sourceIndex; + int shortsRemaining = length; + while (shortsRemaining > 0) { + ensureCapacityFor(min(Long.BYTES, shortsRemaining * Short.BYTES)); + int bufferCapacity = buffer.remainingCapacity(); + int shortsToCopy = min(shortsRemaining, bufferCapacity / Short.BYTES); + 
buffer.writeShorts(source, currentIndex, shortsToCopy); + currentIndex += shortsToCopy; + shortsRemaining -= shortsToCopy; + } + uncompressedSize += length * Short.BYTES; + } + + @Override + public void writeInts(int[] source, int sourceIndex, int length) + { + WriteBuffer buffer = buffers[0]; + int currentIndex = sourceIndex; + int intsRemaining = length; + while (intsRemaining > 0) { + ensureCapacityFor(min(Long.BYTES, intsRemaining * Integer.BYTES)); + int bufferCapacity = buffer.remainingCapacity(); + int intsToCopy = min(intsRemaining, bufferCapacity / Integer.BYTES); + buffer.writeInts(source, currentIndex, intsToCopy); + currentIndex += intsToCopy; + intsRemaining -= intsToCopy; + } + uncompressedSize += length * Integer.BYTES; + } + + @Override + public void writeLongs(long[] source, int sourceIndex, int length) + { + WriteBuffer buffer = buffers[0]; + int currentIndex = sourceIndex; + int longsRemaining = length; + while (longsRemaining > 0) { + ensureCapacityFor(min(Long.BYTES, longsRemaining * Long.BYTES)); + int bufferCapacity = buffer.remainingCapacity(); + int longsToCopy = min(longsRemaining, bufferCapacity / Long.BYTES); + buffer.writeLongs(source, currentIndex, longsToCopy); + currentIndex += longsToCopy; + longsRemaining -= longsToCopy; + } + uncompressedSize += length * Long.BYTES; + } + + @Override + public void writeFloats(float[] source, int sourceIndex, int length) + { + WriteBuffer buffer = buffers[0]; + int currentIndex = sourceIndex; + int floatsRemaining = length; + while (floatsRemaining > 0) { + ensureCapacityFor(min(Long.BYTES, floatsRemaining * Float.BYTES)); + int bufferCapacity = buffer.remainingCapacity(); + int floatsToCopy = min(floatsRemaining, bufferCapacity / Float.BYTES); + buffer.writeFloats(source, currentIndex, floatsToCopy); + currentIndex += floatsToCopy; + floatsRemaining -= floatsToCopy; + } + uncompressedSize += length * Float.BYTES; + } + + @Override + public void writeDoubles(double[] source, int sourceIndex, int 
length) + { + WriteBuffer buffer = buffers[0]; + int currentIndex = sourceIndex; + int doublesRemaining = length; + while (doublesRemaining > 0) { + ensureCapacityFor(min(Long.BYTES, doublesRemaining * Double.BYTES)); + int bufferCapacity = buffer.remainingCapacity(); + int doublesToCopy = min(doublesRemaining, bufferCapacity / Double.BYTES); + buffer.writeDoubles(source, currentIndex, doublesToCopy); + currentIndex += doublesToCopy; + doublesRemaining -= doublesToCopy; + } + uncompressedSize += length * Double.BYTES; + } + + public Slice closePage() + { + compress(); + encrypt(); + + WriteBuffer pageBuffer = buffers[buffers.length - 1]; + int serializedPageSize = pageBuffer.getPosition(); + int compressedSize = serializedPageSize - SERIALIZED_PAGE_HEADER_SIZE; + Slice slice = pageBuffer.getSlice(); + slice.setInt(SERIALIZED_PAGE_UNCOMPRESSED_SIZE_OFFSET, uncompressedSize); + slice.setInt(SERIALIZED_PAGE_COMPRESSED_SIZE_OFFSET, compressedSize); + + Slice page; + if (serializedPageSize < slice.length() / 2) { + page = slice.copy(0, serializedPageSize); + } + else { + page = slice.slice(0, serializedPageSize); + } + for (WriteBuffer buffer : buffers) { + buffer.reset(); + } + buffers[buffers.length - 1] = null; + uncompressedSize = 0; + return page; + } + + private void ensureCapacityFor(int bytes) + { + if (buffers[0].remainingCapacity() >= bytes) { + return; + } + // expand page output buffer + buffers[buffers.length - 1].ensureCapacityFor(bytes); + + compress(); + encrypt(); + } + + private void compress() + { + if (this.compressor.isEmpty()) { + return; + } + Compressor compressor = this.compressor.get(); + + WriteBuffer sourceBuffer = buffers[0]; + WriteBuffer sinkBuffer = buffers[1]; + + int maxCompressedLength = compressor.maxCompressedLength(sourceBuffer.getPosition()); + sinkBuffer.ensureCapacityFor(maxCompressedLength + Integer.BYTES); + + int uncompressedSize = sourceBuffer.getPosition(); + int compressedSize = compressor.compress( + 
/**
 * Encrypts the content of the second-to-last buffer and appends it to the last (page output)
 * buffer; does nothing when no encryption key is configured.
 * <p>
 * Each call appends one record to the sink laid out as
 * {@code [int encryptedSize][initialization vector][encrypted bytes]}. The source buffer is
 * rewound afterwards so it can be filled again.
 *
 * @throws TrinoException with {@code GENERIC_INTERNAL_ERROR} when the cipher fails
 */
private void encrypt()
{
    if (encryptionKey.isEmpty()) {
        return;
    }
    // the cipher is re-initialized on every call and the IV it reports is stored with the block
    Cipher cipher = initCipher(encryptionKey.get());
    byte[] iv = cipher.getIV();

    WriteBuffer sourceBuffer = buffers[buffers.length - 2];
    WriteBuffer sinkBuffer = buffers[buffers.length - 1];

    // NOTE(review): iv.length is included here and added again in the capacity request below,
    // which looks like it reserves one IV more than needed - confirm whether that is intentional
    int maxEncryptedSize = cipher.getOutputSize(sourceBuffer.getPosition()) + iv.length;
    sinkBuffer.ensureCapacityFor(maxEncryptedSize
            // to store encrypted block length
            + Integer.BYTES
            // to store initialization vector
            + iv.length);
    // reserve space for encrypted block length
    sinkBuffer.skip(Integer.BYTES);
    // write initialization vector
    sinkBuffer.writeBytes(iv, 0, iv.length);

    int encryptedSize;
    try {
        // Do not refactor into single doFinal call, performance and allocation rate are significantly worse
        // See https://github.com/trinodb/trino/pull/5557
        encryptedSize = cipher.update(
                sourceBuffer.getSlice().byteArray(),
                sourceBuffer.getSlice().byteArrayOffset(),
                sourceBuffer.getPosition(),
                sinkBuffer.getSlice().byteArray(),
                sinkBuffer.getSlice().byteArrayOffset() + sinkBuffer.getPosition());
        // doFinal writes its output directly behind the bytes produced by update
        encryptedSize += cipher.doFinal(
                sinkBuffer.getSlice().byteArray(),
                sinkBuffer.getSlice().byteArrayOffset() + sinkBuffer.getPosition() + encryptedSize);
    }
    catch (GeneralSecurityException e) {
        throw new TrinoException(GENERIC_INTERNAL_ERROR, "Failed to encrypt data: " + e.getMessage(), e);
    }

    // back-fill the length slot reserved before the IV, then move the cursor past the
    // encrypted bytes (they were written straight into the backing array, not via write*)
    sinkBuffer.getSlice().setInt(sinkBuffer.getPosition() - Integer.BYTES - iv.length, encryptedSize);
    sinkBuffer.skip(encryptedSize);

    sourceBuffer.reset();
}
slice() + { + throw new UnsupportedOperationException(); + } + + @Override + public Slice getUnderlyingSlice() + { + throw new UnsupportedOperationException(); + } + + @Override + public void reset() + { + throw new UnsupportedOperationException(); + } + + @Override + public void reset(int position) + { + throw new UnsupportedOperationException(); + } + + @Override + public int size() + { + throw new UnsupportedOperationException(); + } + + @Override + public String toString(Charset charset) + { + throw new UnsupportedOperationException(); + } + + @Override + public SliceOutput appendLong(long value) + { + writeLong(value); + return this; + } + + @Override + public SliceOutput appendDouble(double value) + { + writeDouble(value); + return this; + } + + @Override + public SliceOutput appendInt(int value) + { + writeInt(value); + return this; + } + + @Override + public SliceOutput appendShort(int value) + { + writeShort(value); + return this; + } + + @Override + public SliceOutput appendByte(int value) + { + writeByte(value); + return this; + } + + @Override + public SliceOutput appendBytes(byte[] source, int sourceIndex, int length) + { + writeBytes(source, sourceIndex, length); + return this; + } + + @Override + public SliceOutput appendBytes(byte[] source) + { + return appendBytes(source, 0, source.length); + } + + @Override + public SliceOutput appendBytes(Slice slice) + { + writeBytes(slice); + return this; + } + } + + private static class WriteBuffer + { + private static final int INSTANCE_SIZE = instanceSize(WriteBuffer.class); + + private Slice slice; + private int position; + + public WriteBuffer(int initialCapacity) + { + this.slice = Slices.allocate(initialCapacity); + } + + public void writeByte(int value) + { + slice.setByte(position, value); + position += Byte.BYTES; + } + + public void writeShort(int value) + { + slice.setShort(position, value); + position += Short.BYTES; + } + + public void writeInt(int value) + { + slice.setInt(position, value); + 
position += Integer.BYTES; + } + + public void writeLong(long value) + { + slice.setLong(position, value); + position += Long.BYTES; + } + + public void writeFloat(float value) + { + slice.setFloat(position, value); + position += Float.BYTES; + } + + public void writeDouble(double value) + { + slice.setDouble(position, value); + position += Double.BYTES; + } + + public void writeBytes(Slice source, int sourceIndex, int length) + { + slice.setBytes(position, source, sourceIndex, length); + position += length; + } + + public void writeBytes(byte[] source, int sourceIndex, int length) + { + slice.setBytes(position, source, sourceIndex, length); + position += length; + } + + public void writeShorts(short[] source, int sourceIndex, int length) + { + slice.setShorts(position, source, sourceIndex, length); + position += length * Short.BYTES; + } + + public void writeInts(int[] source, int sourceIndex, int length) + { + slice.setInts(position, source, sourceIndex, length); + position += length * Integer.BYTES; + } + + public void writeLongs(long[] source, int sourceIndex, int length) + { + slice.setLongs(position, source, sourceIndex, length); + position += length * Long.BYTES; + } + + public void writeFloats(float[] source, int sourceIndex, int length) + { + slice.setFloats(position, source, sourceIndex, length); + position += length * Float.BYTES; + } + + public void writeDoubles(double[] source, int sourceIndex, int length) + { + slice.setDoubles(position, source, sourceIndex, length); + position += length * Double.BYTES; + } + + public void skip(int length) + { + position += length; + } + + public int remainingCapacity() + { + return slice.length() - position; + } + + public int getPosition() + { + return position; + } + + public Slice getSlice() + { + return slice; + } + + public void reset() + { + position = 0; + } + + public long getRetainedSizeInBytes() + { + return INSTANCE_SIZE + slice.getRetainedSize(); + } + + public void ensureCapacityFor(int bytes) + { + 
slice = Slices.ensureSize(slice, position + bytes); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/PageDeserializer.java b/core/trino-main/src/main/java/io/trino/execution/buffer/PageDeserializer.java index 48748c702bda..57bf42bc1737 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/PageDeserializer.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/PageDeserializer.java @@ -13,675 +13,12 @@ */ package io.trino.execution.buffer; -import com.google.common.base.VerifyException; -import io.airlift.compress.v3.Decompressor; -import io.airlift.compress.v3.lz4.Lz4Decompressor; import io.airlift.slice.Slice; -import io.airlift.slice.SliceInput; -import io.airlift.slice.Slices; import io.trino.spi.Page; -import io.trino.spi.TrinoException; -import io.trino.spi.block.BlockEncodingSerde; -import javax.crypto.Cipher; -import javax.crypto.SecretKey; -import javax.crypto.spec.IvParameterSpec; -import javax.crypto.spec.SecretKeySpec; - -import java.io.IOException; -import java.io.OutputStream; -import java.io.UnsupportedEncodingException; -import java.security.GeneralSecurityException; -import java.util.Optional; -import java.util.OptionalInt; - -import static com.google.common.base.Preconditions.checkArgument; -import static io.airlift.slice.SizeOf.instanceSize; -import static io.airlift.slice.SizeOf.sizeOf; -import static io.airlift.slice.SizeOf.sizeOfByteArray; -import static io.trino.execution.buffer.PagesSerdeUtil.ESTIMATED_AES_CIPHER_RETAINED_SIZE; -import static io.trino.execution.buffer.PagesSerdeUtil.SERIALIZED_PAGE_CIPHER_NAME; -import static io.trino.execution.buffer.PagesSerdeUtil.SERIALIZED_PAGE_COMPRESSED_BLOCK_MASK; -import static io.trino.execution.buffer.PagesSerdeUtil.SERIALIZED_PAGE_HEADER_SIZE; -import static io.trino.execution.buffer.PagesSerdeUtil.getSerializedPagePositionCount; -import static io.trino.execution.buffer.PagesSerdeUtil.readRawPage; -import static 
io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; -import static io.trino.util.Ciphers.is256BitSecretKeySpec; -import static java.lang.Math.min; -import static java.lang.Math.toIntExact; -import static java.util.Objects.requireNonNull; -import static javax.crypto.Cipher.DECRYPT_MODE; - -public class PageDeserializer +public interface PageDeserializer { - private static final int INSTANCE_SIZE = instanceSize(PageDeserializer.class); - - private final BlockEncodingSerde blockEncodingSerde; - private final SerializedPageInput input; - - public PageDeserializer( - BlockEncodingSerde blockEncodingSerde, - Optional decompressor, - Optional encryptionKey, - int blockSizeInBytes, - OptionalInt maxCompressedBlockSizeInBytes) - { - this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null"); - requireNonNull(encryptionKey, "encryptionKey is null"); - encryptionKey.ifPresent(secretKey -> checkArgument(is256BitSecretKeySpec(secretKey), "encryptionKey is expected to be an instance of SecretKeySpec containing a 256bit key")); - input = new SerializedPageInput( - requireNonNull(decompressor, "decompressor is null"), - encryptionKey, - blockSizeInBytes, - maxCompressedBlockSizeInBytes); - } - - public Page deserialize(Slice serializedPage) - { - int positionCount = input.startPage(serializedPage); - Page page = readRawPage(positionCount, input, blockEncodingSerde); - input.finishPage(); - return page; - } - - public long getRetainedSizeInBytes() - { - return INSTANCE_SIZE + input.getRetainedSize(); - } - - private static class SerializedPageInput - extends SliceInput - { - private static final int INSTANCE_SIZE = instanceSize(SerializedPageInput.class); - // TODO: implement getRetainedSizeInBytes in Lz4Decompressor - private static final int DECOMPRESSOR_RETAINED_SIZE = instanceSize(Lz4Decompressor.class); - private static final int ENCRYPTION_KEY_RETAINED_SIZE = toIntExact(instanceSize(SecretKeySpec.class) + sizeOfByteArray(256 / 8)); - - 
private final Optional decompressor; - private final Optional encryptionKey; - private final Optional cipher; - - private final ReadBuffer[] buffers; - - private SerializedPageInput(Optional decompressor, Optional encryptionKey, int blockSizeInBytes, OptionalInt maxCompressedBlockSizeInBytes) - { - this.decompressor = requireNonNull(decompressor, "decompressor is null"); - this.encryptionKey = requireNonNull(encryptionKey, "encryptionKey is null"); - - buffers = new ReadBuffer[ - (decompressor.isPresent() ? 1 : 0) // decompression buffer - + (encryptionKey.isPresent() ? 1 : 0) // decryption buffer - + 1 // input buffer - ]; - if (decompressor.isPresent()) { - int bufferSize = blockSizeInBytes - // to guarantee a single long can always be read entirely - + Long.BYTES; - buffers[0] = new ReadBuffer(Slices.allocate(bufferSize)); - buffers[0].setPosition(bufferSize); - } - if (encryptionKey.isPresent()) { - int bufferSize; - if (decompressor.isPresent()) { - // to store compressed block size - bufferSize = maxCompressedBlockSizeInBytes.orElseThrow() - // to store compressed block size - + Integer.BYTES - // to guarantee a single long can always be read entirely - + Long.BYTES; - } - else { - bufferSize = blockSizeInBytes - // to guarantee a single long can always be read entirely - + Long.BYTES; - } - buffers[buffers.length - 2] = new ReadBuffer(Slices.allocate(bufferSize)); - buffers[buffers.length - 2].setPosition(bufferSize); - - try { - cipher = Optional.of(Cipher.getInstance(SERIALIZED_PAGE_CIPHER_NAME)); - } - catch (GeneralSecurityException e) { - throw new TrinoException(GENERIC_INTERNAL_ERROR, "Failed to create cipher: " + e.getMessage(), e); - } - } - else { - cipher = Optional.empty(); - } - } - - public int startPage(Slice page) - { - int positionCount = getSerializedPagePositionCount(page); - ReadBuffer buffer = new ReadBuffer(page); - buffer.setPosition(SERIALIZED_PAGE_HEADER_SIZE); - buffers[buffers.length - 1] = buffer; - return positionCount; - } - - 
@Override - public boolean readBoolean() - { - ensureReadable(1); - return buffers[0].readBoolean(); - } - - @Override - public byte readByte() - { - ensureReadable(Byte.BYTES); - return buffers[0].readByte(); - } - - @Override - public short readShort() - { - ensureReadable(Short.BYTES); - return buffers[0].readShort(); - } - - @Override - public int readInt() - { - ensureReadable(Integer.BYTES); - return buffers[0].readInt(); - } - - @Override - public long readLong() - { - ensureReadable(Long.BYTES); - return buffers[0].readLong(); - } - - @Override - public float readFloat() - { - ensureReadable(Float.BYTES); - return buffers[0].readFloat(); - } - - @Override - public double readDouble() - { - ensureReadable(Double.BYTES); - return buffers[0].readDouble(); - } - - @Override - public int read(byte[] destination, int destinationIndex, int length) - { - ReadBuffer buffer = buffers[0]; - int bytesRemaining = length; - while (bytesRemaining > 0) { - ensureReadable(min(Long.BYTES, bytesRemaining)); - int bytesToRead = min(bytesRemaining, buffer.available()); - int bytesRead = buffer.read(destination, destinationIndex, bytesToRead); - if (bytesRead == -1) { - break; - } - bytesRemaining -= bytesRead; - destinationIndex += bytesRead; - } - return length - bytesRemaining; - } - - @Override - public void readBytes(byte[] destination, int destinationIndex, int length) - { - ReadBuffer buffer = buffers[0]; - int bytesRemaining = length; - while (bytesRemaining > 0) { - ensureReadable(min(Long.BYTES, bytesRemaining)); - int bytesToRead = min(bytesRemaining, buffer.available()); - buffer.readBytes(destination, destinationIndex, bytesToRead); - bytesRemaining -= bytesToRead; - destinationIndex += bytesToRead; - } - } - - @Override - public void readShorts(short[] destination, int destinationIndex, int length) - { - ReadBuffer buffer = buffers[0]; - int shortsRemaining = length; - while (shortsRemaining > 0) { - ensureReadable(min(Long.BYTES, shortsRemaining * Short.BYTES)); - 
int shortsToRead = min(shortsRemaining, buffer.available() / Short.BYTES); - buffer.readShorts(destination, destinationIndex, shortsToRead); - shortsRemaining -= shortsToRead; - destinationIndex += shortsToRead; - } - } - - @Override - public void readInts(int[] destination, int destinationIndex, int length) - { - ReadBuffer buffer = buffers[0]; - int intsRemaining = length; - while (intsRemaining > 0) { - ensureReadable(min(Long.BYTES, intsRemaining * Integer.BYTES)); - int intsToRead = min(intsRemaining, buffer.available() / Integer.BYTES); - buffer.readInts(destination, destinationIndex, intsToRead); - intsRemaining -= intsToRead; - destinationIndex += intsToRead; - } - } - - @Override - public void readLongs(long[] destination, int destinationIndex, int length) - { - ReadBuffer buffer = buffers[0]; - int longsRemaining = length; - while (longsRemaining > 0) { - ensureReadable(min(Long.BYTES, longsRemaining * Long.BYTES)); - int longsToRead = min(longsRemaining, buffer.available() / Long.BYTES); - buffer.readLongs(destination, destinationIndex, longsToRead); - longsRemaining -= longsToRead; - destinationIndex += longsToRead; - } - } - - @Override - public void readFloats(float[] destination, int destinationIndex, int length) - { - ReadBuffer buffer = buffers[0]; - int floatsRemaining = length; - while (floatsRemaining > 0) { - ensureReadable(min(Long.BYTES, floatsRemaining * Float.BYTES)); - int floatsToRead = min(floatsRemaining, buffer.available() / Float.BYTES); - buffer.readFloats(destination, destinationIndex, floatsToRead); - floatsRemaining -= floatsToRead; - destinationIndex += floatsToRead; - } - } - - @Override - public void readDoubles(double[] destination, int destinationIndex, int length) - { - ReadBuffer buffer = buffers[0]; - int doublesRemaining = length; - while (doublesRemaining > 0) { - ensureReadable(min(Long.BYTES, doublesRemaining * Double.BYTES)); - int doublesToRead = min(doublesRemaining, buffer.available() / Double.BYTES); - 
buffer.readDoubles(destination, destinationIndex, doublesToRead); - doublesRemaining -= doublesToRead; - destinationIndex += doublesToRead; - } - } - - @Override - public void readBytes(Slice destination, int destinationIndex, int length) - { - ReadBuffer buffer = buffers[0]; - int bytesRemaining = length; - while (bytesRemaining > 0) { - ensureReadable(min(Long.BYTES, bytesRemaining)); - int bytesToRead = min(bytesRemaining, buffer.available()); - buffer.readBytes(destination, destinationIndex, bytesToRead); - bytesRemaining -= bytesToRead; - destinationIndex += bytesToRead; - } - } - - private void ensureReadable(int bytes) - { - if (buffers[0].available() >= bytes) { - return; - } - decrypt(); - decompress(); - } - - private void decrypt() - { - if (this.encryptionKey.isEmpty()) { - return; - } - - ReadBuffer source = buffers[buffers.length - 1]; - ReadBuffer sink = buffers[buffers.length - 2]; - int bytesPreserved = sink.rollOver(); - - int encryptedSize = source.readInt(); - int ivSize = cipher.orElseThrow().getBlockSize(); - IvParameterSpec iv = new IvParameterSpec( - source.getSlice().byteArray(), - source.getSlice().byteArrayOffset() + source.getPosition(), - ivSize); - source.setPosition(source.getPosition() + ivSize); - - Cipher cipher = initCipher(encryptionKey.get(), iv); - int decryptedSize; - try { - // Do not refactor into single doFinal call, performance and allocation rate are significantly worse - // See https://github.com/trinodb/trino/pull/5557 - decryptedSize = cipher.update( - source.getSlice().byteArray(), - source.getSlice().byteArrayOffset() + source.getPosition(), - encryptedSize, - sink.getSlice().byteArray(), - sink.getSlice().byteArrayOffset() + bytesPreserved); - decryptedSize += cipher.doFinal( - sink.getSlice().byteArray(), - sink.getSlice().byteArrayOffset() + bytesPreserved + decryptedSize); - } - catch (GeneralSecurityException e) { - throw new TrinoException(GENERIC_INTERNAL_ERROR, "Cannot decrypt previously encrypted data: " + 
e.getMessage(), e); - } - source.setPosition(source.getPosition() + encryptedSize); - sink.setLimit(bytesPreserved + decryptedSize); - } - - private Cipher initCipher(SecretKey key, IvParameterSpec iv) - { - Cipher cipher = this.cipher.orElseThrow(() -> new VerifyException("cipher is expected to be present")); - try { - cipher.init(DECRYPT_MODE, key, iv); - } - catch (GeneralSecurityException e) { - throw new TrinoException(GENERIC_INTERNAL_ERROR, "Failed to init cipher: " + e.getMessage(), e); - } - return cipher; - } - - private void decompress() - { - if (this.decompressor.isEmpty()) { - return; - } - - Decompressor decompressor = this.decompressor.get(); - - ReadBuffer source = buffers[1]; - ReadBuffer sink = buffers[0]; - int bytesPreserved = sink.rollOver(); - - int compressedBlockMarker = source.readInt(); - int blockSize = getCompressedBlockSize(compressedBlockMarker); - boolean compressed = isCompressed(compressedBlockMarker); - - int decompressedSize; - if (compressed) { - decompressedSize = decompressor.decompress( - source.getSlice().byteArray(), - source.getSlice().byteArrayOffset() + source.getPosition(), - blockSize, - sink.getSlice().byteArray(), - sink.getSlice().byteArrayOffset() + bytesPreserved, - sink.getSlice().length() - bytesPreserved); - } - else { - System.arraycopy( - source.getSlice().byteArray(), - source.getSlice().byteArrayOffset() + source.getPosition(), - sink.getSlice().byteArray(), - sink.getSlice().byteArrayOffset() + bytesPreserved, - blockSize); - decompressedSize = blockSize; - } - source.setPosition(source.getPosition() + blockSize); - sink.setLimit(bytesPreserved + decompressedSize); - } - - private static int getCompressedBlockSize(int compressedBlockMarker) - { - return compressedBlockMarker & ~SERIALIZED_PAGE_COMPRESSED_BLOCK_MASK; - } - - private static boolean isCompressed(int compressedBlockMarker) - { - return (compressedBlockMarker & SERIALIZED_PAGE_COMPRESSED_BLOCK_MASK) == SERIALIZED_PAGE_COMPRESSED_BLOCK_MASK; - } 
- - public void finishPage() - { - buffers[buffers.length - 1] = null; - for (ReadBuffer buffer : buffers) { - if (buffer != null) { - buffer.setPosition(buffer.getSlice().length()); - buffer.setLimit(buffer.getSlice().length()); - } - } - } - - @Override - public int read() - { - return readByte(); - } - - @Override - public int readUnsignedByte() - { - return readByte() & 0xFF; - } - - @Override - public int readUnsignedShort() - { - return readShort() & 0xFFFF; - } - - @Override - public Slice readSlice(int length) - { - Slice slice = Slices.allocate(length); - readBytes(slice, 0, length); - return slice; - } - - @Override - public boolean isReadable() - { - return available() > 0; - } - - @Override - public int available() - { - return buffers[0].available(); - } - - @Override - public long skip(long length) - { - return 0; - } - - @Override - public int skipBytes(int length) - { - return toIntExact(skip(length)); - } - - @Override - public long getRetainedSize() - { - long size = INSTANCE_SIZE; - size += sizeOf(decompressor, compressor -> DECOMPRESSOR_RETAINED_SIZE); - size += sizeOf(encryptionKey, encryptionKey -> ENCRYPTION_KEY_RETAINED_SIZE); - size += sizeOf(cipher, cipher -> ESTIMATED_AES_CIPHER_RETAINED_SIZE); - for (ReadBuffer input : buffers) { - if (input != null) { - size += input.getRetainedSizeInBytes(); - } - } - return size; - } - - @Override - public void readBytes(OutputStream out, int length) - throws IOException - { - throw new UnsupportedEncodingException(); - } - - @Override - public long position() - { - throw new UnsupportedOperationException(); - } - - @Override - public void setPosition(long position) - { - throw new UnsupportedOperationException(); - } - } - - private static class ReadBuffer - { - private static final int INSTANCE_SIZE = instanceSize(ReadBuffer.class); - - private final Slice slice; - private int position; - private int limit; - - public ReadBuffer(Slice slice) - { - requireNonNull(slice, "slice is null"); - this.slice 
= slice; - limit = slice.length(); - } - - public int available() - { - return limit - position; - } - - public Slice getSlice() - { - return slice; - } - - public int getPosition() - { - return position; - } - - public void setPosition(int position) - { - this.position = position; - } - - public void setLimit(int limit) - { - this.limit = limit; - } - - public int rollOver() - { - int bytesToCopy = available(); - if (bytesToCopy != 0) { - slice.setBytes(0, slice, position, bytesToCopy); - } - position = 0; - return bytesToCopy; - } - - public boolean readBoolean() - { - boolean value = slice.getByte(position) == 1; - position += Byte.BYTES; - return value; - } - - public byte readByte() - { - byte value = slice.getByte(position); - position += Byte.BYTES; - return value; - } - - public short readShort() - { - short value = slice.getShort(position); - position += Short.BYTES; - return value; - } - - public int readInt() - { - int value = slice.getInt(position); - position += Integer.BYTES; - return value; - } - - public long readLong() - { - long value = slice.getLong(position); - position += Long.BYTES; - return value; - } - - public float readFloat() - { - float value = slice.getFloat(position); - position += Float.BYTES; - return value; - } - - public double readDouble() - { - double value = slice.getDouble(position); - position += Double.BYTES; - return value; - } - - public int read(byte[] destination, int destinationIndex, int length) - { - int bytesToRead = min(length, slice.length() - position); - slice.getBytes(position, destination, destinationIndex, bytesToRead); - position += bytesToRead; - return bytesToRead; - } - - public void readBytes(byte[] destination, int destinationIndex, int length) - { - slice.getBytes(position, destination, destinationIndex, length); - position += length; - } - - public void readShorts(short[] destination, int destinationIndex, int length) - { - slice.getShorts(position, destination, destinationIndex, length); - position += 
length * Short.BYTES; - } - - public void readInts(int[] destination, int destinationIndex, int length) - { - slice.getInts(position, destination, destinationIndex, length); - position += length * Integer.BYTES; - } - - public void readLongs(long[] destination, int destinationIndex, int length) - { - slice.getLongs(position, destination, destinationIndex, length); - position += length * Long.BYTES; - } - - public void readFloats(float[] destination, int destinationIndex, int length) - { - slice.getFloats(position, destination, destinationIndex, length); - position += length * Float.BYTES; - } - - public void readDoubles(double[] destination, int destinationIndex, int length) - { - slice.getDoubles(position, destination, destinationIndex, length); - position += length * Double.BYTES; - } - - public void readBytes(Slice destination, int destinationIndex, int length) - { - slice.getBytes(position, destination, destinationIndex, length); - position += length; - } + Page deserialize(Slice slice); - public long getRetainedSizeInBytes() - { - return INSTANCE_SIZE + slice.getRetainedSize(); - } - } + long getRetainedSizeInBytes(); } diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/PageSerializer.java b/core/trino-main/src/main/java/io/trino/execution/buffer/PageSerializer.java index 88a96b52da47..027e648683d4 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/PageSerializer.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/PageSerializer.java @@ -13,730 +13,12 @@ */ package io.trino.execution.buffer; -import com.google.common.base.VerifyException; -import io.airlift.compress.v3.Compressor; import io.airlift.slice.Slice; -import io.airlift.slice.SliceOutput; -import io.airlift.slice.Slices; import io.trino.spi.Page; -import io.trino.spi.TrinoException; -import io.trino.spi.block.BlockEncodingSerde; -import javax.crypto.Cipher; -import javax.crypto.SecretKey; -import javax.crypto.spec.SecretKeySpec; - -import 
java.io.IOException; -import java.io.InputStream; -import java.nio.charset.Charset; -import java.security.GeneralSecurityException; -import java.util.Optional; -import java.util.OptionalInt; - -import static com.google.common.base.Preconditions.checkArgument; -import static io.airlift.slice.SizeOf.instanceSize; -import static io.airlift.slice.SizeOf.sizeOf; -import static io.airlift.slice.SizeOf.sizeOfByteArray; -import static io.trino.execution.buffer.PageCodecMarker.COMPRESSED; -import static io.trino.execution.buffer.PageCodecMarker.ENCRYPTED; -import static io.trino.execution.buffer.PagesSerdeUtil.ESTIMATED_AES_CIPHER_RETAINED_SIZE; -import static io.trino.execution.buffer.PagesSerdeUtil.SERIALIZED_PAGE_CIPHER_NAME; -import static io.trino.execution.buffer.PagesSerdeUtil.SERIALIZED_PAGE_COMPRESSED_BLOCK_MASK; -import static io.trino.execution.buffer.PagesSerdeUtil.SERIALIZED_PAGE_COMPRESSED_SIZE_OFFSET; -import static io.trino.execution.buffer.PagesSerdeUtil.SERIALIZED_PAGE_HEADER_SIZE; -import static io.trino.execution.buffer.PagesSerdeUtil.SERIALIZED_PAGE_UNCOMPRESSED_SIZE_OFFSET; -import static io.trino.execution.buffer.PagesSerdeUtil.writeRawPage; -import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; -import static io.trino.util.Ciphers.is256BitSecretKeySpec; -import static java.lang.Math.min; -import static java.lang.Math.round; -import static java.lang.Math.toIntExact; -import static java.util.Objects.requireNonNull; -import static javax.crypto.Cipher.ENCRYPT_MODE; - -public class PageSerializer +public interface PageSerializer { - private static final int INSTANCE_SIZE = instanceSize(PageSerializer.class); - - private final BlockEncodingSerde blockEncodingSerde; - private final SerializedPageOutput output; - - public PageSerializer( - BlockEncodingSerde blockEncodingSerde, - Optional compressor, - Optional encryptionKey, - int blockSizeInBytes, - OptionalInt maxCompressedBlockSize) - { - this.blockEncodingSerde = 
requireNonNull(blockEncodingSerde, "blockEncodingSerde is null"); - requireNonNull(encryptionKey, "encryptionKey is null"); - encryptionKey.ifPresent(secretKey -> checkArgument(is256BitSecretKeySpec(secretKey), "encryptionKey is expected to be an instance of SecretKeySpec containing a 256bit key")); - output = new SerializedPageOutput( - requireNonNull(compressor, "compressor is null"), - encryptionKey, - blockSizeInBytes, - maxCompressedBlockSize); - } - - public Slice serialize(Page page) - { - output.startPage(page.getPositionCount(), toIntExact(page.getSizeInBytes())); - writeRawPage(page, output, blockEncodingSerde); - return output.closePage(); - } - - public long getRetainedSizeInBytes() - { - return INSTANCE_SIZE + output.getRetainedSize(); - } - - private static class SerializedPageOutput - extends SliceOutput - { - private static final int INSTANCE_SIZE = instanceSize(SerializedPageOutput.class); - private static final int ENCRYPTION_KEY_RETAINED_SIZE = toIntExact(instanceSize(SecretKeySpec.class) + sizeOfByteArray(256 / 8)); - - private static final double MINIMUM_COMPRESSION_RATIO = 0.8; - - private final Optional compressor; - private final Optional encryptionKey; - private final int markers; - private final Optional cipher; - - private final WriteBuffer[] buffers; - private int uncompressedSize; - - private SerializedPageOutput( - Optional compressor, - Optional encryptionKey, - int blockSizeInBytes, - OptionalInt maxCompressedBlockSize) - { - this.compressor = requireNonNull(compressor, "compressor is null"); - this.encryptionKey = requireNonNull(encryptionKey, "encryptionKey is null"); - - buffers = new WriteBuffer[ - (compressor.isPresent() ? 1 : 0) // compression buffer - + (encryptionKey.isPresent() ? 
1 : 0) // encryption buffer - + 1 // output buffer - ]; - PageCodecMarker.MarkerSet markerSet = PageCodecMarker.MarkerSet.empty(); - if (compressor.isPresent()) { - buffers[0] = new WriteBuffer(blockSizeInBytes); - markerSet.add(COMPRESSED); - } - if (encryptionKey.isPresent()) { - int bufferSize = blockSizeInBytes; - if (compressor.isPresent()) { - bufferSize = maxCompressedBlockSize.orElseThrow() - // to store compressed block size - + Integer.BYTES; - } - buffers[buffers.length - 2] = new WriteBuffer(bufferSize); - markerSet.add(ENCRYPTED); - - try { - cipher = Optional.of(Cipher.getInstance(SERIALIZED_PAGE_CIPHER_NAME)); - } - catch (GeneralSecurityException e) { - throw new TrinoException(GENERIC_INTERNAL_ERROR, "Failed to create cipher: " + e.getMessage(), e); - } - } - else { - cipher = Optional.empty(); - } - markers = markerSet.byteValue(); - } - - public void startPage(int positionCount, int sizeInBytes) - { - WriteBuffer buffer = new WriteBuffer(round(sizeInBytes * 1.2F) + SERIALIZED_PAGE_HEADER_SIZE); - buffer.writeInt(positionCount); - buffer.writeByte(markers); - // leave space for uncompressed and compressed sizes - buffer.skip(Integer.BYTES * 2); - - buffers[buffers.length - 1] = buffer; - uncompressedSize = 0; - } - - @Override - public void writeByte(int value) - { - ensureCapacityFor(Byte.BYTES); - buffers[0].writeByte(value); - uncompressedSize += Byte.BYTES; - } - - @Override - public void writeShort(int value) - { - ensureCapacityFor(Short.BYTES); - buffers[0].writeShort(value); - uncompressedSize += Short.BYTES; - } - - @Override - public void writeInt(int value) - { - ensureCapacityFor(Integer.BYTES); - buffers[0].writeInt(value); - uncompressedSize += Integer.BYTES; - } - - @Override - public void writeLong(long value) - { - ensureCapacityFor(Long.BYTES); - buffers[0].writeLong(value); - uncompressedSize += Long.BYTES; - } - - @Override - public void writeFloat(float value) - { - ensureCapacityFor(Float.BYTES); - 
buffers[0].writeFloat(value); - uncompressedSize += Float.BYTES; - } - - @Override - public void writeDouble(double value) - { - ensureCapacityFor(Double.BYTES); - buffers[0].writeDouble(value); - uncompressedSize += Double.BYTES; - } - - @Override - public void writeBytes(Slice source, int sourceIndex, int length) - { - WriteBuffer buffer = buffers[0]; - int currentIndex = sourceIndex; - int bytesRemaining = length; - while (bytesRemaining > 0) { - ensureCapacityFor(min(Long.BYTES, bytesRemaining)); - int bufferCapacity = buffer.remainingCapacity(); - int bytesToCopy = min(bytesRemaining, bufferCapacity); - buffer.writeBytes(source, currentIndex, bytesToCopy); - currentIndex += bytesToCopy; - bytesRemaining -= bytesToCopy; - } - uncompressedSize += length; - } - - @Override - public void writeBytes(byte[] source, int sourceIndex, int length) - { - WriteBuffer buffer = buffers[0]; - int currentIndex = sourceIndex; - int bytesRemaining = length; - while (bytesRemaining > 0) { - ensureCapacityFor(min(Long.BYTES, bytesRemaining)); - int bufferCapacity = buffer.remainingCapacity(); - int bytesToCopy = min(bytesRemaining, bufferCapacity); - buffer.writeBytes(source, currentIndex, bytesToCopy); - currentIndex += bytesToCopy; - bytesRemaining -= bytesToCopy; - } - uncompressedSize += length; - } - - @Override - public void writeShorts(short[] source, int sourceIndex, int length) - { - WriteBuffer buffer = buffers[0]; - int currentIndex = sourceIndex; - int shortsRemaining = length; - while (shortsRemaining > 0) { - ensureCapacityFor(min(Long.BYTES, shortsRemaining * Short.BYTES)); - int bufferCapacity = buffer.remainingCapacity(); - int shortsToCopy = min(shortsRemaining, bufferCapacity / Short.BYTES); - buffer.writeShorts(source, currentIndex, shortsToCopy); - currentIndex += shortsToCopy; - shortsRemaining -= shortsToCopy; - } - uncompressedSize += length * Short.BYTES; - } - - @Override - public void writeInts(int[] source, int sourceIndex, int length) - { - 
WriteBuffer buffer = buffers[0]; - int currentIndex = sourceIndex; - int intsRemaining = length; - while (intsRemaining > 0) { - ensureCapacityFor(min(Long.BYTES, intsRemaining * Integer.BYTES)); - int bufferCapacity = buffer.remainingCapacity(); - int intsToCopy = min(intsRemaining, bufferCapacity / Integer.BYTES); - buffer.writeInts(source, currentIndex, intsToCopy); - currentIndex += intsToCopy; - intsRemaining -= intsToCopy; - } - uncompressedSize += length * Integer.BYTES; - } - - @Override - public void writeLongs(long[] source, int sourceIndex, int length) - { - WriteBuffer buffer = buffers[0]; - int currentIndex = sourceIndex; - int longsRemaining = length; - while (longsRemaining > 0) { - ensureCapacityFor(min(Long.BYTES, longsRemaining * Long.BYTES)); - int bufferCapacity = buffer.remainingCapacity(); - int longsToCopy = min(longsRemaining, bufferCapacity / Long.BYTES); - buffer.writeLongs(source, currentIndex, longsToCopy); - currentIndex += longsToCopy; - longsRemaining -= longsToCopy; - } - uncompressedSize += length * Long.BYTES; - } - - @Override - public void writeFloats(float[] source, int sourceIndex, int length) - { - WriteBuffer buffer = buffers[0]; - int currentIndex = sourceIndex; - int floatsRemaining = length; - while (floatsRemaining > 0) { - ensureCapacityFor(min(Long.BYTES, floatsRemaining * Float.BYTES)); - int bufferCapacity = buffer.remainingCapacity(); - int floatsToCopy = min(floatsRemaining, bufferCapacity / Float.BYTES); - buffer.writeFloats(source, currentIndex, floatsToCopy); - currentIndex += floatsToCopy; - floatsRemaining -= floatsToCopy; - } - uncompressedSize += length * Float.BYTES; - } - - @Override - public void writeDoubles(double[] source, int sourceIndex, int length) - { - WriteBuffer buffer = buffers[0]; - int currentIndex = sourceIndex; - int doublesRemaining = length; - while (doublesRemaining > 0) { - ensureCapacityFor(min(Long.BYTES, doublesRemaining * Double.BYTES)); - int bufferCapacity = 
buffer.remainingCapacity(); - int doublesToCopy = min(doublesRemaining, bufferCapacity / Double.BYTES); - buffer.writeDoubles(source, currentIndex, doublesToCopy); - currentIndex += doublesToCopy; - doublesRemaining -= doublesToCopy; - } - uncompressedSize += length * Double.BYTES; - } - - public Slice closePage() - { - compress(); - encrypt(); - - WriteBuffer pageBuffer = buffers[buffers.length - 1]; - int serializedPageSize = pageBuffer.getPosition(); - int compressedSize = serializedPageSize - SERIALIZED_PAGE_HEADER_SIZE; - Slice slice = pageBuffer.getSlice(); - slice.setInt(SERIALIZED_PAGE_UNCOMPRESSED_SIZE_OFFSET, uncompressedSize); - slice.setInt(SERIALIZED_PAGE_COMPRESSED_SIZE_OFFSET, compressedSize); - - Slice page; - if (serializedPageSize < slice.length() / 2) { - page = slice.copy(0, serializedPageSize); - } - else { - page = slice.slice(0, serializedPageSize); - } - for (WriteBuffer buffer : buffers) { - buffer.reset(); - } - buffers[buffers.length - 1] = null; - uncompressedSize = 0; - return page; - } - - private void ensureCapacityFor(int bytes) - { - if (buffers[0].remainingCapacity() >= bytes) { - return; - } - // expand page output buffer - buffers[buffers.length - 1].ensureCapacityFor(bytes); - - compress(); - encrypt(); - } - - private void compress() - { - if (this.compressor.isEmpty()) { - return; - } - Compressor compressor = this.compressor.get(); - - WriteBuffer sourceBuffer = buffers[0]; - WriteBuffer sinkBuffer = buffers[1]; - - int maxCompressedLength = compressor.maxCompressedLength(sourceBuffer.getPosition()); - sinkBuffer.ensureCapacityFor(maxCompressedLength + Integer.BYTES); - - int uncompressedSize = sourceBuffer.getPosition(); - int compressedSize = compressor.compress( - sourceBuffer.getSlice().byteArray(), - sourceBuffer.getSlice().byteArrayOffset(), - uncompressedSize, - sinkBuffer.getSlice().byteArray(), - sinkBuffer.getSlice().byteArrayOffset() + sinkBuffer.getPosition() + Integer.BYTES, - maxCompressedLength); - - boolean 
compressed = uncompressedSize * MINIMUM_COMPRESSION_RATIO > compressedSize; - int blockSize; - if (!compressed) { - System.arraycopy( - sourceBuffer.getSlice().byteArray(), - sourceBuffer.getSlice().byteArrayOffset(), - sinkBuffer.getSlice().byteArray(), - sinkBuffer.getSlice().byteArrayOffset() + sinkBuffer.getPosition() + Integer.BYTES, - uncompressedSize); - blockSize = uncompressedSize; - } - else { - blockSize = compressedSize; - } - - sinkBuffer.writeInt(createBlockMarker(compressed, blockSize)); - sinkBuffer.skip(blockSize); - - sourceBuffer.reset(); - } - - private static int createBlockMarker(boolean compressed, int size) - { - if (compressed) { - return size | SERIALIZED_PAGE_COMPRESSED_BLOCK_MASK; - } - return size; - } - - private void encrypt() - { - if (encryptionKey.isEmpty()) { - return; - } - Cipher cipher = initCipher(encryptionKey.get()); - byte[] iv = cipher.getIV(); - - WriteBuffer sourceBuffer = buffers[buffers.length - 2]; - WriteBuffer sinkBuffer = buffers[buffers.length - 1]; - - int maxEncryptedSize = cipher.getOutputSize(sourceBuffer.getPosition()) + iv.length; - sinkBuffer.ensureCapacityFor(maxEncryptedSize - // to store encrypted block length - + Integer.BYTES - // to store initialization vector - + iv.length); - // reserve space for encrypted block length - sinkBuffer.skip(Integer.BYTES); - // write initialization vector - sinkBuffer.writeBytes(iv, 0, iv.length); - - int encryptedSize; - try { - // Do not refactor into single doFinal call, performance and allocation rate are significantly worse - // See https://github.com/trinodb/trino/pull/5557 - encryptedSize = cipher.update( - sourceBuffer.getSlice().byteArray(), - sourceBuffer.getSlice().byteArrayOffset(), - sourceBuffer.getPosition(), - sinkBuffer.getSlice().byteArray(), - sinkBuffer.getSlice().byteArrayOffset() + sinkBuffer.getPosition()); - encryptedSize += cipher.doFinal( - sinkBuffer.getSlice().byteArray(), - sinkBuffer.getSlice().byteArrayOffset() + sinkBuffer.getPosition() + 
encryptedSize); - } - catch (GeneralSecurityException e) { - throw new TrinoException(GENERIC_INTERNAL_ERROR, "Failed to encrypt data: " + e.getMessage(), e); - } - - sinkBuffer.getSlice().setInt(sinkBuffer.getPosition() - Integer.BYTES - iv.length, encryptedSize); - sinkBuffer.skip(encryptedSize); - - sourceBuffer.reset(); - } - - private Cipher initCipher(SecretKey key) - { - Cipher cipher = this.cipher.orElseThrow(() -> new VerifyException("cipher is expected to be present")); - try { - cipher.init(ENCRYPT_MODE, key); - } - catch (GeneralSecurityException e) { - throw new TrinoException(GENERIC_INTERNAL_ERROR, "Failed to init cipher: " + e.getMessage(), e); - } - return cipher; - } - - @Override - public long getRetainedSize() - { - long size = INSTANCE_SIZE; - size += sizeOf(compressor, compressor -> instanceSize(compressor.getClass()) - + compressor.getRetainedSizeInBytes(uncompressedSize)); - size += sizeOf(encryptionKey, encryptionKey -> ENCRYPTION_KEY_RETAINED_SIZE); - size += sizeOf(cipher, cipher -> ESTIMATED_AES_CIPHER_RETAINED_SIZE); - for (WriteBuffer buffer : buffers) { - if (buffer != null) { - size += buffer.getRetainedSizeInBytes(); - } - } - return size; - } - - @Override - public int writableBytes() - { - return Integer.MAX_VALUE; - } - - @Override - public boolean isWritable() - { - return true; - } - - @Override - public void writeBytes(byte[] source) - { - writeBytes(source, 0, source.length); - } - - @Override - public void writeBytes(Slice source) - { - writeBytes(source, 0, source.length()); - } - - @Override - public void writeBytes(InputStream in, int length) - throws IOException - { - throw new UnsupportedOperationException(); - } - - @Override - public Slice slice() - { - throw new UnsupportedOperationException(); - } - - @Override - public Slice getUnderlyingSlice() - { - throw new UnsupportedOperationException(); - } - - @Override - public void reset() - { - throw new UnsupportedOperationException(); - } - - @Override - public void 
reset(int position) - { - throw new UnsupportedOperationException(); - } - - @Override - public int size() - { - throw new UnsupportedOperationException(); - } - - @Override - public String toString(Charset charset) - { - throw new UnsupportedOperationException(); - } - - @Override - public SliceOutput appendLong(long value) - { - writeLong(value); - return this; - } - - @Override - public SliceOutput appendDouble(double value) - { - writeDouble(value); - return this; - } - - @Override - public SliceOutput appendInt(int value) - { - writeInt(value); - return this; - } - - @Override - public SliceOutput appendShort(int value) - { - writeShort(value); - return this; - } - - @Override - public SliceOutput appendByte(int value) - { - writeByte(value); - return this; - } - - @Override - public SliceOutput appendBytes(byte[] source, int sourceIndex, int length) - { - writeBytes(source, sourceIndex, length); - return this; - } - - @Override - public SliceOutput appendBytes(byte[] source) - { - return appendBytes(source, 0, source.length); - } - - @Override - public SliceOutput appendBytes(Slice slice) - { - writeBytes(slice); - return this; - } - } - - private static class WriteBuffer - { - private static final int INSTANCE_SIZE = instanceSize(WriteBuffer.class); - - private Slice slice; - private int position; - - public WriteBuffer(int initialCapacity) - { - this.slice = Slices.allocate(initialCapacity); - } - - public void writeByte(int value) - { - slice.setByte(position, value); - position += Byte.BYTES; - } - - public void writeShort(int value) - { - slice.setShort(position, value); - position += Short.BYTES; - } - - public void writeInt(int value) - { - slice.setInt(position, value); - position += Integer.BYTES; - } - - public void writeLong(long value) - { - slice.setLong(position, value); - position += Long.BYTES; - } - - public void writeFloat(float value) - { - slice.setFloat(position, value); - position += Float.BYTES; - } - - public void writeDouble(double 
value) - { - slice.setDouble(position, value); - position += Double.BYTES; - } - - public void writeBytes(Slice source, int sourceIndex, int length) - { - slice.setBytes(position, source, sourceIndex, length); - position += length; - } - - public void writeBytes(byte[] source, int sourceIndex, int length) - { - slice.setBytes(position, source, sourceIndex, length); - position += length; - } - - public void writeShorts(short[] source, int sourceIndex, int length) - { - slice.setShorts(position, source, sourceIndex, length); - position += length * Short.BYTES; - } - - public void writeInts(int[] source, int sourceIndex, int length) - { - slice.setInts(position, source, sourceIndex, length); - position += length * Integer.BYTES; - } - - public void writeLongs(long[] source, int sourceIndex, int length) - { - slice.setLongs(position, source, sourceIndex, length); - position += length * Long.BYTES; - } - - public void writeFloats(float[] source, int sourceIndex, int length) - { - slice.setFloats(position, source, sourceIndex, length); - position += length * Float.BYTES; - } - - public void writeDoubles(double[] source, int sourceIndex, int length) - { - slice.setDoubles(position, source, sourceIndex, length); - position += length * Double.BYTES; - } - - public void skip(int length) - { - position += length; - } - - public int remainingCapacity() - { - return slice.length() - position; - } - - public int getPosition() - { - return position; - } - - public Slice getSlice() - { - return slice; - } - - public void reset() - { - position = 0; - } - - public long getRetainedSizeInBytes() - { - return INSTANCE_SIZE + slice.getRetainedSize(); - } + Slice serialize(Page page); - public void ensureCapacityFor(int bytes) - { - slice = Slices.ensureSize(slice, position + bytes); - } - } + long getRetainedSizeInBytes(); } diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/PagesSerdeFactory.java 
b/core/trino-main/src/main/java/io/trino/execution/buffer/PagesSerdeFactory.java index 2fcf3b41e79a..29c0ec873db1 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/PagesSerdeFactory.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/PagesSerdeFactory.java @@ -13,6 +13,7 @@ */ package io.trino.execution.buffer; +import com.google.common.annotations.VisibleForTesting; import io.airlift.compress.v3.Compressor; import io.airlift.compress.v3.Decompressor; import io.airlift.compress.v3.lz4.Lz4Compressor; @@ -23,51 +24,51 @@ import javax.crypto.SecretKey; -import java.util.Map; import java.util.Optional; import java.util.OptionalInt; import static io.trino.execution.buffer.CompressionCodec.LZ4; -import static io.trino.execution.buffer.CompressionCodec.NONE; import static io.trino.execution.buffer.CompressionCodec.ZSTD; import static java.util.Objects.requireNonNull; public class PagesSerdeFactory { private static final int SERIALIZED_PAGE_DEFAULT_BLOCK_SIZE_IN_BYTES = 64 * 1024; - - private static final Map MAX_COMPRESSED_LENGTH = Map.of( - NONE, NONE.maxCompressedLength(SERIALIZED_PAGE_DEFAULT_BLOCK_SIZE_IN_BYTES), - LZ4, LZ4.maxCompressedLength(SERIALIZED_PAGE_DEFAULT_BLOCK_SIZE_IN_BYTES), - ZSTD, ZSTD.maxCompressedLength(SERIALIZED_PAGE_DEFAULT_BLOCK_SIZE_IN_BYTES)); - private final BlockEncodingSerde blockEncodingSerde; private final CompressionCodec compressionCodec; + private final int blockSizeInBytes; public PagesSerdeFactory(BlockEncodingSerde blockEncodingSerde, CompressionCodec compressionCodec) + { + this(blockEncodingSerde, compressionCodec, SERIALIZED_PAGE_DEFAULT_BLOCK_SIZE_IN_BYTES); + } + + @VisibleForTesting + PagesSerdeFactory(BlockEncodingSerde blockEncodingSerde, CompressionCodec compressionCodec, int blockSizeInBytes) { this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null"); this.compressionCodec = requireNonNull(compressionCodec, "compressionCodec is null"); + 
this.blockSizeInBytes = blockSizeInBytes; } public PageSerializer createSerializer(Optional encryptionKey) { - return new PageSerializer( + return new CompressingEncryptingPageSerializer( blockEncodingSerde, createCompressor(compressionCodec), encryptionKey, - SERIALIZED_PAGE_DEFAULT_BLOCK_SIZE_IN_BYTES, - MAX_COMPRESSED_LENGTH.get(compressionCodec)); + blockSizeInBytes, + maxCompressedSize(blockSizeInBytes, compressionCodec)); } public PageDeserializer createDeserializer(Optional encryptionKey) { - return new PageDeserializer( + return new CompressingDecryptingPageDeserializer( blockEncodingSerde, createDecompressor(compressionCodec), encryptionKey, - SERIALIZED_PAGE_DEFAULT_BLOCK_SIZE_IN_BYTES, - MAX_COMPRESSED_LENGTH.get(compressionCodec)); + blockSizeInBytes, + maxCompressedSize(blockSizeInBytes, compressionCodec)); } public static Optional createCompressor(CompressionCodec compressionCodec) @@ -87,4 +88,13 @@ public static Optional createDecompressor(CompressionCodec compres case ZSTD -> Optional.of(ZstdDecompressor.create()); }; } + + private static OptionalInt maxCompressedSize(int uncompressedSize, CompressionCodec compressionCodec) + { + return switch (compressionCodec) { + case NONE -> OptionalInt.of(uncompressedSize); + case LZ4 -> LZ4.maxCompressedLength(uncompressedSize); + case ZSTD -> ZSTD.maxCompressedLength(uncompressedSize); + }; + } } diff --git a/core/trino-main/src/test/java/io/trino/execution/buffer/TestPagesSerde.java b/core/trino-main/src/test/java/io/trino/execution/buffer/TestPagesSerde.java index 8f33bc52108c..2d3dae59b8e0 100644 --- a/core/trino-main/src/test/java/io/trino/execution/buffer/TestPagesSerde.java +++ b/core/trino-main/src/test/java/io/trino/execution/buffer/TestPagesSerde.java @@ -45,8 +45,6 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.execution.buffer.CompressionCodec.NONE; -import static io.trino.execution.buffer.PagesSerdeFactory.createCompressor; -import static 
io.trino.execution.buffer.PagesSerdeFactory.createDecompressor; import static io.trino.execution.buffer.PagesSerdeUtil.readPages; import static io.trino.execution.buffer.PagesSerdeUtil.writePages; import static io.trino.operator.PageAssertions.assertPageEquals; @@ -146,8 +144,10 @@ private void testRoundTrip(List types, List pages, boolean encryptio { Optional encryptionKey = encryptionEnabled ? Optional.of(createRandomAesEncryptionKey()) : Optional.empty(); for (CompressionCodec compressionCodec : CompressionCodec.values()) { - PageSerializer serializer = new PageSerializer(blockEncodingSerde, createCompressor(compressionCodec), encryptionKey, blockSizeInBytes, compressionCodec.maxCompressedLength(blockSizeInBytes)); - PageDeserializer deserializer = new PageDeserializer(blockEncodingSerde, createDecompressor(compressionCodec), encryptionKey, blockSizeInBytes, compressionCodec.maxCompressedLength(blockSizeInBytes)); + PagesSerdeFactory pagesSerdeFactory = new PagesSerdeFactory(blockEncodingSerde, compressionCodec, blockSizeInBytes); + PageSerializer serializer = pagesSerdeFactory.createSerializer(encryptionKey); + PageDeserializer deserializer = pagesSerdeFactory.createDeserializer(encryptionKey); + for (Page page : pages) { Slice serialized = serializer.serialize(page); Page deserialized = deserializer.deserialize(serialized); @@ -270,8 +270,9 @@ private void testDeserializationWithRollover(boolean encryptionEnabled, int numb RolloverBlockSerde blockSerde = new RolloverBlockSerde(); Optional encryptionKey = encryptionEnabled ? 
Optional.of(createRandomAesEncryptionKey()) : Optional.empty(); for (CompressionCodec compressionCodec : CompressionCodec.values()) { - PageSerializer serializer = new PageSerializer(blockSerde, createCompressor(compressionCodec), encryptionKey, blockSize, compressionCodec.maxCompressedLength(blockSize)); - PageDeserializer deserializer = new PageDeserializer(blockSerde, createDecompressor(compressionCodec), encryptionKey, blockSize, compressionCodec.maxCompressedLength(blockSize)); + PagesSerdeFactory pagesSerdeFactory = new PagesSerdeFactory(blockSerde, compressionCodec, blockSize); + PageSerializer serializer = pagesSerdeFactory.createSerializer(encryptionKey); + PageDeserializer deserializer = pagesSerdeFactory.createDeserializer(encryptionKey); Page page = createTestPage(numberOfEntries); Slice serialized = serializer.serialize(page);