diff --git a/presto-common/src/main/java/com/facebook/presto/common/block/DictionaryBlock.java b/presto-common/src/main/java/com/facebook/presto/common/block/DictionaryBlock.java index c53061500e68e..36051189bb989 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/block/DictionaryBlock.java +++ b/presto-common/src/main/java/com/facebook/presto/common/block/DictionaryBlock.java @@ -264,14 +264,25 @@ public long getRegionLogicalSizeInBytes(int positionOffset, int length) } long sizeInBytes = 0; - long[] seenSizes = new long[dictionary.getPositionCount()]; - Arrays.fill(seenSizes, -1L); - for (int i = positionOffset; i < positionOffset + length; i++) { - int position = getId(i); - if (seenSizes[position] < 0) { - seenSizes[position] = dictionary.getRegionLogicalSizeInBytes(position, 1); + // Dictionary Block may contain large number of keys and small region length may be requested. + // If the length is less than keys the cache is likely to be not used. + if (length > dictionary.getPositionCount()) { + // Cache code path. + long[] seenSizes = new long[dictionary.getPositionCount()]; + Arrays.fill(seenSizes, -1L); + for (int i = positionOffset; i < positionOffset + length; i++) { + int position = getId(i); + if (seenSizes[position] < 0) { + seenSizes[position] = dictionary.getRegionLogicalSizeInBytes(position, 1); + } + sizeInBytes += seenSizes[position]; + } + } + else { + // In-place code path. + for (int i = positionOffset; i < positionOffset + length; i++) { + sizeInBytes += dictionary.getRegionLogicalSizeInBytes(getId(i), 1); } - sizeInBytes += seenSizes[position]; } if (positionOffset == 0 && length == getPositionCount()) { diff --git a/presto-main/src/test/java/com/facebook/presto/block/TestDictionaryBlock.java b/presto-main/src/test/java/com/facebook/presto/block/TestDictionaryBlock.java index 6638986e9524d..8af1ce5d0c87c 100644 --- a/presto-main/src/test/java/com/facebook/presto/block/TestDictionaryBlock.java +++ b/presto-main/src/test/java/com/facebook/presto/block/TestDictionaryBlock.java @@ -22,11 +22,15 @@ import io.airlift.slice.Slice; import org.testng.annotations.Test; +import java.util.Arrays; + import static com.facebook.presto.block.BlockAssertions.createRLEBlock; import static com.facebook.presto.block.BlockAssertions.createRandomDictionaryBlock; import static com.facebook.presto.block.BlockAssertions.createRandomLongsBlock; import static com.facebook.presto.block.BlockAssertions.createSlicesBlock; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static io.airlift.slice.SizeOf.SIZE_OF_INT; +import static io.airlift.slice.Slices.utf8Slice; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNotEquals; @@ -44,6 +48,44 @@ public void testSizeInBytes() assertEquals(dictionaryBlock.getSizeInBytes(), dictionaryBlock.getDictionary().getSizeInBytes() + (100 * SIZE_OF_INT)); } + @Test + public void testNonCachedLogicalBytes() + { + int numEntries = 10; + BlockBuilder blockBuilder = VARCHAR.createBlockBuilder(null, numEntries); + + // Over allocate dictionary indexes but only use the required limit. + int[] dictionaryIndexes = new int[numEntries + 10]; + Arrays.fill(dictionaryIndexes, 1); + blockBuilder.appendNull(); + dictionaryIndexes[0] = 0; + + String string = ""; + for (int i = 1; i < numEntries; i++) { + string += "a"; + VARCHAR.writeSlice(blockBuilder, utf8Slice(string)); + dictionaryIndexes[i] = numEntries - i; + } + + // A dictionary block of size 10, 1st element -> null, 2nd element size -> 9....9th element size -> 1 + // Pass different maxChunkSize and different offset and verify if it computes the chunk lengths correctly. + Block elementBlock = blockBuilder.build(); + DictionaryBlock block = new DictionaryBlock(numEntries, elementBlock, dictionaryIndexes); + int elementSize = Integer.BYTES + Byte.BYTES; + + long size = block.getRegionLogicalSizeInBytes(0, 1); + assertEquals(size, 0 + 1 * elementSize); + + size = block.getRegionLogicalSizeInBytes(0, numEntries); + assertEquals(size, 45 + numEntries * elementSize); + + size = block.getRegionLogicalSizeInBytes(1, 2); + assertEquals(size, 9 + 8 + 2 * elementSize); + + size = block.getRegionLogicalSizeInBytes(9, 1); + assertEquals(size, 1 + 1 * elementSize); + } + @Test public void testLogicalSizeInBytes() { diff --git a/presto-orc/src/main/java/com/facebook/presto/orc/writer/SliceDictionaryColumnWriter.java b/presto-orc/src/main/java/com/facebook/presto/orc/writer/SliceDictionaryColumnWriter.java index eb753a28b6425..37b6c215cdeb8 100644 --- a/presto-orc/src/main/java/com/facebook/presto/orc/writer/SliceDictionaryColumnWriter.java +++ b/presto-orc/src/main/java/com/facebook/presto/orc/writer/SliceDictionaryColumnWriter.java @@ -15,6 +15,7 @@ import com.facebook.presto.common.block.Block; import com.facebook.presto.common.block.DictionaryBlock; +import com.facebook.presto.common.block.DictionaryId; import com.facebook.presto.common.type.Type; import com.facebook.presto.orc.DwrfDataEncryptor; import com.facebook.presto.orc.OrcEncoding; @@ -29,6 +30,7 @@ import com.facebook.presto.orc.stream.LongOutputStream; import com.facebook.presto.orc.stream.PresentOutputStream; import com.facebook.presto.orc.stream.StreamDataOutput; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import io.airlift.slice.Slice; import io.airlift.units.DataSize; @@ -92,6 +94,23 @@ public int getDictionaryEntries() return dictionary.getEntryCount(); } + @VisibleForTesting + static int getChunkLength(int offset, int[] dictionaryIndexes, int positionCount, Block elementBlock, int maxChunkSize) + { + int endOffset = offset; + long size = elementBlock.getSliceLength(dictionaryIndexes[endOffset++]); + while (endOffset < positionCount) { + // getSliceLength does not include the nulls and length array. But this + // is a heuristic to avoid too much memory allocation, so this is fine. + size += elementBlock.getSliceLength(dictionaryIndexes[endOffset]); + if (size > maxChunkSize) { + break; + } + endOffset++; + } + return endOffset - offset; + } + @Override protected boolean tryConvertToDirect(int dictionaryIndexCount, IntBigArray dictionaryIndexes, int maxDirectBytes) { @@ -99,30 +118,20 @@ protected boolean tryConvertToDirect(int dictionaryIndexCount, IntBigArray dicti for (int i = 0; dictionaryIndexCount > 0 && i < segments.length; i++) { int[] segment = segments[i]; int positionCount = Math.min(dictionaryIndexCount, segment.length); - Block block = new DictionaryBlock(positionCount, dictionary.getElementBlock(), segment); - - while (block != null) { - int chunkPositionCount = block.getPositionCount(); - Block chunk = block.getRegion(0, chunkPositionCount); - - // avoid chunk with huge logical size - while (chunkPositionCount > 1 && chunk.getLogicalSizeInBytes() > DIRECT_CONVERSION_CHUNK_MAX_LOGICAL_BYTES) { - chunkPositionCount /= 2; - chunk = chunk.getRegion(0, chunkPositionCount); - } - + Block elementBlock = dictionary.getElementBlock(); + DictionaryId dictionaryId = DictionaryId.randomDictionaryId(); + + int offset = 0; + while (offset < positionCount) { + // Dictionary can contain large values that are repeated. In such a case, the conversion will be abandoned + // due to maxDirectBytes. To avoid allocating too much memory on those cases, process the dictionary in chunks. + int length = getChunkLength(offset, segment, positionCount, elementBlock, DIRECT_CONVERSION_CHUNK_MAX_LOGICAL_BYTES); + Block chunk = new DictionaryBlock(offset, length, elementBlock, segment, false, dictionaryId); + offset += length; directColumnWriter.writeBlock(chunk); if (directColumnWriter.getBufferedBytes() > maxDirectBytes) { return false; } - - // slice block to only unconverted rows - if (chunkPositionCount < block.getPositionCount()) { - block = block.getRegion(chunkPositionCount, block.getPositionCount() - chunkPositionCount); - } - else { - block = null; - } } dictionaryIndexCount -= positionCount; diff --git a/presto-orc/src/test/java/com/facebook/presto/orc/writer/TestSliceDictionaryColumnWriter.java b/presto-orc/src/test/java/com/facebook/presto/orc/writer/TestSliceDictionaryColumnWriter.java new file mode 100644 index 0000000000000..f942d706f7eae --- /dev/null +++ b/presto-orc/src/test/java/com/facebook/presto/orc/writer/TestSliceDictionaryColumnWriter.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.orc.writer; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockBuilder; +import org.testng.annotations.Test; + +import java.util.Arrays; + +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static io.airlift.slice.Slices.utf8Slice; +import static org.testng.Assert.assertEquals; + +public class TestSliceDictionaryColumnWriter +{ + @Test + public void testChunkLength() + { + int numEntries = 10; + BlockBuilder blockBuilder = VARCHAR.createBlockBuilder(null, numEntries); + + // Over allocate dictionary indexes but only use the required limit. + int[] dictionaryIndexes = new int[numEntries + 10]; + Arrays.fill(dictionaryIndexes, 1); + blockBuilder.appendNull(); + dictionaryIndexes[0] = 0; + + String string = ""; + for (int i = 1; i < numEntries; i++) { + string += "a"; + VARCHAR.writeSlice(blockBuilder, utf8Slice(string)); + dictionaryIndexes[i] = numEntries - i; + } + + // A dictionary block of size 10, 1st element -> null, 2nd element size -> 9....9th element size -> 1 + // Pass different maxChunkSize and different offset and verify if it computes the chunk lengths correctly. + Block elementBlock = blockBuilder.build(); + int length = SliceDictionaryColumnWriter.getChunkLength(0, dictionaryIndexes, numEntries, elementBlock, 10); + assertEquals(length, 2); + + length = SliceDictionaryColumnWriter.getChunkLength(0, dictionaryIndexes, numEntries, elementBlock, 1_000_000); + assertEquals(length, numEntries); + + length = SliceDictionaryColumnWriter.getChunkLength(0, dictionaryIndexes, numEntries, elementBlock, 20); + assertEquals(length, 3); + + length = SliceDictionaryColumnWriter.getChunkLength(1, dictionaryIndexes, numEntries, elementBlock, 9 + 8 + 7); + assertEquals(length, 3); + + length = SliceDictionaryColumnWriter.getChunkLength(2, dictionaryIndexes, numEntries, elementBlock, 0); + assertEquals(length, 1); + + length = SliceDictionaryColumnWriter.getChunkLength(9, dictionaryIndexes, numEntries, elementBlock, 0); + assertEquals(length, 1); + } +}