diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/PageDeserializer.java b/core/trino-main/src/main/java/io/trino/execution/buffer/PageDeserializer.java index 49f5dbd2dd3d..60c96879d8fe 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/PageDeserializer.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/PageDeserializer.java @@ -337,7 +337,7 @@ private void decompress() blockSize, sink.getSlice().byteArray(), sink.getSlice().byteArrayOffset() + bytesPreserved, - sink.getSlice().length()); + sink.getSlice().length() - bytesPreserved); } else { System.arraycopy( diff --git a/core/trino-main/src/test/java/io/trino/execution/buffer/TestPagesSerde.java b/core/trino-main/src/test/java/io/trino/execution/buffer/TestPagesSerde.java index 59de15c34c63..744ccee92c78 100644 --- a/core/trino-main/src/test/java/io/trino/execution/buffer/TestPagesSerde.java +++ b/core/trino-main/src/test/java/io/trino/execution/buffer/TestPagesSerde.java @@ -16,15 +16,21 @@ import com.google.common.collect.ImmutableList; import io.airlift.slice.DynamicSliceOutput; import io.airlift.slice.Slice; +import io.airlift.slice.SliceInput; +import io.airlift.slice.SliceOutput; import io.trino.metadata.BlockEncodingManager; import io.trino.metadata.InternalBlockEncodingSerde; import io.trino.spi.Page; import io.trino.spi.PageBuilder; +import io.trino.spi.block.Block; import io.trino.spi.block.BlockBuilder; import io.trino.spi.block.BlockEncodingSerde; +import io.trino.spi.block.VariableWidthBlock; +import io.trino.spi.block.VariableWidthBlockBuilder; import io.trino.spi.type.Type; import io.trino.tpch.LineItem; import io.trino.tpch.LineItemGenerator; +import org.assertj.core.api.Assertions; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -236,4 +242,112 @@ private int serializedSize(List types, Page expectedPage) return slice.length(); } + + @Test + public void testDeserializationWithRollover() + { + // test non-zero rollover when refilling buffer on deserialization + for (int blockSize = 100; blockSize < 500; blockSize += 101) { + for (int numberOfEntries = 500; numberOfEntries < 1000; numberOfEntries += 99) { + testDeserializationWithRollover(blockSize, numberOfEntries); + } + } + } + + private void testDeserializationWithRollover(int blockSize, int numberOfEntries) + { + testDeserializationWithRollover(false, false, numberOfEntries, blockSize); + testDeserializationWithRollover(false, true, numberOfEntries, blockSize); + testDeserializationWithRollover(true, false, numberOfEntries, blockSize); + testDeserializationWithRollover(true, true, numberOfEntries, blockSize); + } + + private void testDeserializationWithRollover(boolean encryptionEnabled, boolean compressionEnabled, int numberOfEntries, int blockSize) + { + RolloverBlockSerde blockSerde = new RolloverBlockSerde(); + Optional encryptionKey = encryptionEnabled ? Optional.of(createRandomAesEncryptionKey()) : Optional.empty(); + PageSerializer serializer = new PageSerializer(blockSerde, compressionEnabled, encryptionKey, blockSize); + PageDeserializer deserializer = new PageDeserializer(blockSerde, compressionEnabled, encryptionKey, blockSize); + + Page page = createTestPage(numberOfEntries); + Slice serialized = serializer.serialize(page); + Page deserialized = deserializer.deserialize(serialized); + assertEquals(deserialized.getChannelCount(), 1); + + VariableWidthBlock expected = (VariableWidthBlock) page.getBlock(0); + VariableWidthBlock actual = (VariableWidthBlock) deserialized.getBlock(0); + + Assertions.assertThat(actual.getRawSlice().getBytes()).isEqualTo(expected.getRawSlice().getBytes()); + } + + private static Page createTestPage(int numberOfEntries) + { + VariableWidthBlockBuilder blockBuilder = new VariableWidthBlockBuilder(null, 1, 1000); + blockBuilder.writeInt(numberOfEntries); + for (int i = 0; i < numberOfEntries; i++) { + blockBuilder.writeLong(i); + } + blockBuilder.closeEntry(); + return new Page(blockBuilder.build()); + } + + private static class RolloverBlockSerde + implements BlockEncodingSerde + { + @Override + public Block readBlock(SliceInput input) + { + int numberOfEntries = input.readInt(); + VariableWidthBlockBuilder blockBuilder = new VariableWidthBlockBuilder(null, 1, 1000); + blockBuilder.writeInt(numberOfEntries); + for (int i = 0; i < numberOfEntries; ++i) { + // read 8 bytes at a time + blockBuilder.writeLong(input.readLong()); + } + blockBuilder.closeEntry(); + return blockBuilder.build(); + } + + @Override + public void writeBlock(SliceOutput output, Block block) + { + int offset = 0; + int numberOfEntries = block.getInt(0, offset); + output.writeInt(numberOfEntries); + offset += 4; + for (int i = 0; i < numberOfEntries; ++i) { + long value = block.getLong(0, offset); + offset += 8; + long b7 = value >> 56 & 0xffL; + long b6 = value >> 48 & 0xffL; + long b5 = value >> 40 & 0xffL; + long b4 = value >> 32 & 0xffL; + long b3 = value >> 24 & 0xffL; + long b2 = value >> 16 & 0xffL; + long b1 = value >> 8 & 0xffL; + long b0 = value & 0xffL; + // write one byte at a time + output.writeByte((int) b0); + output.writeByte((int) b1); + output.writeByte((int) b2); + output.writeByte((int) b3); + output.writeByte((int) b4); + output.writeByte((int) b5); + output.writeByte((int) b6); + output.writeByte((int) b7); + } + } + + @Override + public Type readType(SliceInput sliceInput) + { + throw new RuntimeException("not implemented"); + } + + @Override + public void writeType(SliceOutput sliceOutput, Type type) + { + throw new RuntimeException("not implemented"); + } + } }