diff --git a/src/main/java/com/facebook/presto/DictionarySerde.java b/src/main/java/com/facebook/presto/DictionarySerde.java new file mode 100644 index 0000000000000..6d19a167a6681 --- /dev/null +++ b/src/main/java/com/facebook/presto/DictionarySerde.java @@ -0,0 +1,160 @@ +package com.facebook.presto; + +import com.google.common.base.Preconditions; +import com.google.common.collect.AbstractIterator; +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; +import com.google.common.collect.ImmutableBiMap; + +import java.util.Iterator; +import java.util.Map; + +public class DictionarySerde +{ + private final long maxCardinality; + // TODO: we may be able to determine and adjust this value dynamically with a smarter implementation + private final int reqBitSpace; + + public DictionarySerde(long maxCardinality) { + this.maxCardinality = maxCardinality; + reqBitSpace = Long.SIZE - Long.numberOfLeadingZeros(maxCardinality - 1); + } + + public DictionarySerde() { + this(Long.MAX_VALUE); + } + + public void serialize(final Iterable slices, SliceOutput sliceOutput) + { + final BiMap idMap = HashBiMap.create(); + + PackedLongSerde packedLongSerde = new PackedLongSerde(reqBitSpace); + packedLongSerde.serialize( + new Iterable() { + @Override + public Iterator iterator() { + return new AbstractIterator() { + Iterator sliceIterator = slices.iterator(); + // Start ID at the smallest possible value to fully utilize available bit space + long nextId = -1L << (reqBitSpace - 1); + + @Override + protected Long computeNext() { + if (!sliceIterator.hasNext()) { + return endOfData(); + } + + Slice slice = sliceIterator.next(); + + Long id = idMap.get(slice); + if (id == null) { + id = nextId; + nextId++; + idMap.put(slice, id); + } + return id; + } + }; + } + }, + sliceOutput + ); + + // Serialize Footer + int footerBytes = new Footer(idMap.inverse()).serialize(sliceOutput); + + // Write length of Footer + sliceOutput.writeInt(footerBytes); + } + + public static Iterable deserialize(final SliceInput sliceInput) { + // Get map serialized byte length from tail and reset to beginning + int totalBytes = sliceInput.available(); + sliceInput.skipBytes(totalBytes - SizeOf.SIZE_OF_INT); + int idMapByteLength = sliceInput.readInt(); + + // Slice out Footer data and extract it + sliceInput.setPosition(totalBytes - idMapByteLength - SizeOf.SIZE_OF_INT); + Footer footer = Footer.deserialize(sliceInput.readSlice(idMapByteLength).input()); + + final Map idMap = footer.getIdMap(); + + sliceInput.setPosition(0); + final SliceInput paylodSliceInput = + sliceInput.readSlice(totalBytes - idMapByteLength - SizeOf.SIZE_OF_INT) + .input(); + return new Iterable() { + @Override + public Iterator iterator() { + return new AbstractIterator() { + Iterator iterator = PackedLongSerde.deserialize(paylodSliceInput).iterator(); + + @Override + protected Slice computeNext() { + if (!iterator.hasNext()) { + return endOfData(); + } + Slice slice = idMap.get(iterator.next()); + Preconditions.checkNotNull(slice, "Missing entry in dictionary"); + return slice; + } + }; + } + }; + } + + // TODO: this encoding can be made more compact if we leverage sorted order of the map + private static class Footer + { + Map idMap; + + private Footer(Map idMap) + { + this.idMap = idMap; + } + + /** + * Serialize this Footer to the specified SliceOutput + * + * @param sliceOutput + * @return bytes written to sliceOutput + */ + private int serialize(SliceOutput sliceOutput) + { + int startBytesWriteable = sliceOutput.writableBytes(); + for (Map.Entry entry : idMap.entrySet()) { + // Write ID number + sliceOutput.writeLong(entry.getKey()); + // Write Slice length + sliceOutput.writeInt(entry.getValue().length()); + // Write Slice + sliceOutput.writeBytes(entry.getValue()); + } + int endBytesWriteable = sliceOutput.writableBytes(); + return startBytesWriteable - endBytesWriteable; + } + + private static Footer deserialize(SliceInput sliceInput) + { + ImmutableBiMap.Builder builder = ImmutableBiMap.builder(); + + while (sliceInput.isReadable()) { + // Read Slice ID number + long id = sliceInput.readLong(); + // Read Slice Length + int sliceLength = sliceInput.readInt(); + // Read Slice + Slice slice = sliceInput.readSlice(sliceLength); + + builder.put(id, slice); + } + + return new Footer(builder.build()); + } + + public Map getIdMap() + { + return idMap; + } + } +} diff --git a/src/main/java/com/facebook/presto/PackedLongSerde.java b/src/main/java/com/facebook/presto/PackedLongSerde.java new file mode 100644 index 0000000000000..d451cbdc36f7c --- /dev/null +++ b/src/main/java/com/facebook/presto/PackedLongSerde.java @@ -0,0 +1,154 @@ +package com.facebook.presto; + +import com.google.common.base.Preconditions; +import com.google.common.collect.AbstractIterator; +import com.google.common.collect.Range; +import com.google.common.collect.Ranges; + +import java.util.Iterator; + +public class PackedLongSerde +{ + private final byte bitWidth; + private final Range allowedRange; + + public PackedLongSerde(int bitWidth) + { + Preconditions.checkArgument(bitWidth > 0 && bitWidth <= Long.SIZE); + this.bitWidth = (byte) bitWidth; + this.allowedRange = Ranges.closed(-1L << (bitWidth - 1), ~(-1L << (bitWidth - 1))); + } + + public void serialize(Iterable items, SliceOutput sliceOutput) + { + int packCapacity = Long.SIZE / bitWidth; + long mask = (~0L) >>> (Long.SIZE - bitWidth); + + // Write the packed longs + int itemCount = 0; + Iterator iter = items.iterator(); + while (iter.hasNext()) { + long pack = 0; + boolean packUsed = false; + for (int idx = 0; idx < packCapacity; idx++) { + if (!iter.hasNext()) { + break; + } + + long rawValue = iter.next(); + Preconditions.checkArgument( + allowedRange.contains(rawValue), + "Provided value does not fit into bitspace" + ); + long maskedValue = rawValue & mask; + itemCount++; + + pack |= maskedValue << (bitWidth * idx); + packUsed = true; + } + if (packUsed) { + sliceOutput.writeLong(pack); + } + } + + // Write the Footer + new Footer(itemCount, bitWidth).serialize(sliceOutput); + } + + public static Iterable deserialize(final SliceInput sliceInput) + { + Preconditions.checkArgument( + sliceInput.available() >= Footer.BYTE_SIZE, + "sliceInput not large enough to read a footer" + ); + Preconditions.checkArgument( + (sliceInput.available() - Footer.BYTE_SIZE) % (SizeOf.SIZE_OF_LONG) == 0, + "sliceInput byte alignment incorrect" + ); + + // Extract Footer and then reset slice cursor + int totalBytes = sliceInput.available(); + sliceInput.skipBytes(totalBytes - Footer.BYTE_SIZE); + final Footer footer = Footer.deserialize(sliceInput.readSlice(Footer.BYTE_SIZE).input()); + sliceInput.setPosition(0); + + final int packCapacity = Long.SIZE / footer.getBitWidth(); + + return new Iterable() + { + @Override + public Iterator iterator() + { + return new AbstractIterator() + { + int itemIdx = 0; + int packInternalIdx = 0; + long packValue = 0; + + @Override + protected Long computeNext() + { + if (itemIdx >= footer.getItemCount()) { + return endOfData(); + } + if (packInternalIdx == 0) { + packValue = sliceInput.readLong(); + } + // TODO: replace with something more efficient (but needs sign extend) + long value = (packValue << (Long.SIZE - ((packInternalIdx + 1) * footer.getBitWidth()))) >> (Long.SIZE - footer.getBitWidth()); + + itemIdx++; + packInternalIdx = (packInternalIdx + 1) % packCapacity; + return value; + } + }; + } + }; + } + + private static class Footer + { + private static final int BYTE_SIZE = SizeOf.SIZE_OF_INT + SizeOf.SIZE_OF_BYTE; + + private final int itemCount; + private final byte bitWidth; + + private Footer(int itemCount, byte bitWidth) + { + Preconditions.checkArgument(itemCount >= 0, "itemCount must be non-negative"); + Preconditions.checkArgument(bitWidth > 0, "bitWidth must be greater than zero"); + this.itemCount = itemCount; + this.bitWidth = bitWidth; + } + + /** + * Serialize this Header into the specified SliceOutput + * + * @param sliceOutput + * @return bytes written to sliceOutput + */ + public int serialize(SliceOutput sliceOutput) + { + sliceOutput.writeInt(itemCount); + sliceOutput.writeByte(bitWidth); + return BYTE_SIZE; + } + + public static Footer deserialize(SliceInput sliceInput) + { + int itemCount = sliceInput.readInt(); + byte bitWidth = sliceInput.readByte(); + return new Footer(itemCount, bitWidth); + } + + public int getItemCount() + { + return itemCount; + } + + public byte getBitWidth() + { + return bitWidth; + } + } +} diff --git a/src/test/java/com/facebook/presto/TestDictionarySerde.java b/src/test/java/com/facebook/presto/TestDictionarySerde.java new file mode 100644 index 0000000000000..d25a141a1d3ad --- /dev/null +++ b/src/test/java/com/facebook/presto/TestDictionarySerde.java @@ -0,0 +1,94 @@ +package com.facebook.presto; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import org.testng.Assert; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.Collections; +import java.util.List; + +public class TestDictionarySerde { + private SliceOutput sliceOutput; + private DictionarySerde dictionarySerde; + + @BeforeMethod(alwaysRun = true) + public void setUp() throws Exception { + sliceOutput = new DynamicSliceOutput(1024); + dictionarySerde = new DictionarySerde(); + } + + @Test + public void testSanity() throws Exception { + List slices = slicesFromStrings("a", "b", "cde", "fuu", "a", "fuu"); + dictionarySerde.serialize(slices, sliceOutput); + Assert.assertTrue( + Iterables.elementsEqual( + slices, + DictionarySerde.deserialize(sliceOutput.slice().input()) + ) + ); + } + + @Test + public void testEmpty() throws Exception { + List slices = Collections.EMPTY_LIST; + dictionarySerde.serialize(slices, sliceOutput); + Assert.assertTrue( + Iterables.elementsEqual( + slices, + DictionarySerde.deserialize(sliceOutput.slice().input()) + ) + ); + } + + @Test + public void testAllSame() throws Exception { + List slices = slicesFromStrings("a", "a", "a", "a", "a", "a", "a"); + dictionarySerde.serialize(slices, sliceOutput); + Assert.assertTrue( + Iterables.elementsEqual( + slices, + DictionarySerde.deserialize(sliceOutput.slice().input()) + ) + ); + } + + @Test + public void testAllUnique() throws Exception { + List slices = slicesFromStrings("a", "b", "c", "d", "e", "f", "g"); + dictionarySerde.serialize(slices, sliceOutput); + Assert.assertTrue( + Iterables.elementsEqual( + slices, + DictionarySerde.deserialize(sliceOutput.slice().input()) + ) + ); + } + + @Test + public void testSmallCardinality() throws Exception { + List slices = slicesFromStrings("a", "b", "c", "a", "c", "d", "c", "b"); + dictionarySerde = new DictionarySerde(4); + dictionarySerde.serialize(slices, sliceOutput); + Assert.assertTrue( + Iterables.elementsEqual( + slices, + DictionarySerde.deserialize(sliceOutput.slice().input()) + ) + ); + } + + private List slicesFromStrings(String... strs) { + ImmutableList.Builder builder = ImmutableList.builder(); + for (String str : strs) { + builder.add(sliceFromString(str)); + } + return builder.build(); + } + + private ByteArraySlice sliceFromString(String str) { + return new ByteArraySlice(str.getBytes()); + } +} diff --git a/src/test/java/com/facebook/presto/TestPackedLongSerde.java b/src/test/java/com/facebook/presto/TestPackedLongSerde.java new file mode 100644 index 0000000000000..270c1980fba3d --- /dev/null +++ b/src/test/java/com/facebook/presto/TestPackedLongSerde.java @@ -0,0 +1,79 @@ +package com.facebook.presto; + +import com.google.common.collect.Iterables; +import org.testng.Assert; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +public class TestPackedLongSerde { + private SliceOutput sliceOutput; + + @BeforeMethod(alwaysRun = true) + public void setUp() throws Exception { + sliceOutput = new DynamicSliceOutput(128); + } + + @Test + public void testFullLong() throws Exception { + List list = Arrays.asList(0L, -1L, 2L, Long.MAX_VALUE, Long.MIN_VALUE); + new PackedLongSerde(Long.SIZE).serialize(list, sliceOutput); + Assert.assertTrue( + Iterables.elementsEqual( + PackedLongSerde.deserialize(sliceOutput.slice().input()), + list + ) + ); + } + + @Test + public void testLowDensity() throws Exception { + List list = Arrays.asList(0L, -1L, 2L, Long.MAX_VALUE/2, Long.MIN_VALUE/2); + new PackedLongSerde(Long.SIZE - 1).serialize(list, sliceOutput); + Assert.assertTrue( + Iterables.elementsEqual( + PackedLongSerde.deserialize(sliceOutput.slice().input()), + list + ) + ); + } + + @Test + public void testAligned() throws Exception { + List list = Arrays.asList(0L, -1L, 2L,(long) Integer.MAX_VALUE, (long) Integer.MIN_VALUE); + new PackedLongSerde(Integer.SIZE).serialize(list, sliceOutput); + Assert.assertTrue( + Iterables.elementsEqual( + PackedLongSerde.deserialize(sliceOutput.slice().input()), + list + ) + ); + } + + @Test + public void testUnaligned() throws Exception { + List list = Arrays.asList(0L, -1L, 2L, 65535L, -65536L, 64L, -3L); + new PackedLongSerde(17).serialize(list, sliceOutput); + Assert.assertTrue( + Iterables.elementsEqual( + PackedLongSerde.deserialize(sliceOutput.slice().input()), + list + ) + ); + } + + @Test + public void testEmpty() throws Exception { + List list = Collections.EMPTY_LIST; + new PackedLongSerde(1).serialize(list, sliceOutput); + Assert.assertTrue( + Iterables.elementsEqual( + PackedLongSerde.deserialize(sliceOutput.slice().input()), + list + ) + ); + } +}