diff --git a/core/trino-main/src/main/java/io/trino/operator/BigintGroupByHash.java b/core/trino-main/src/main/java/io/trino/operator/BigintGroupByHash.java
index af91384cb602..f764b400142a 100644
--- a/core/trino-main/src/main/java/io/trino/operator/BigintGroupByHash.java
+++ b/core/trino-main/src/main/java/io/trino/operator/BigintGroupByHash.java
@@ -22,14 +22,18 @@
 import io.trino.spi.block.Block;
 import io.trino.spi.block.BlockBuilder;
 import io.trino.spi.block.DictionaryBlock;
+import io.trino.spi.block.LongArrayBlock;
 import io.trino.spi.block.RunLengthEncodedBlock;
 import io.trino.spi.type.AbstractLongType;
 import io.trino.spi.type.BigintType;
 import io.trino.spi.type.Type;
+import it.unimi.dsi.fastutil.ints.IntArrayList;
+import it.unimi.dsi.fastutil.ints.IntList;
 import org.openjdk.jol.info.ClassLayout;
 
 import java.util.Arrays;
 import java.util.List;
+import java.util.Optional;
 
 import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Preconditions.checkState;
@@ -40,6 +44,7 @@
 import static io.trino.util.HashCollisionsEstimator.estimateNumberOfHashCollisions;
 import static it.unimi.dsi.fastutil.HashCommon.arraySize;
 import static it.unimi.dsi.fastutil.HashCommon.murmurHash3;
+import static java.lang.Math.min;
 import static java.lang.Math.toIntExact;
 import static java.util.Objects.requireNonNull;
 
@@ -51,6 +56,14 @@ public class BigintGroupByHash
     private static final float FILL_RATIO = 0.75f;
     private static final List<Type> TYPES = ImmutableList.of(BIGINT);
     private static final List<Type> TYPES_WITH_RAW_HASH = ImmutableList.of(BIGINT, BIGINT);
+    private static final int BATCH_SIZE = 256;
+    /**
+     * Above this number of groups it is faster to use the batched execution path.
+     * The allocation and method-inlining overhead makes batched execution faster
+     * only for a large number of groups.
+     * This number is an approximation based on microbenchmark results.
+     */
+    private static final int ADD_GROUP_BATCH_THRESHOLD = 200_000;
 
     private final int hashChannel;
     private final boolean outputRawHash;
@@ -243,7 +256,11 @@ private int putIfAbsent(int position, Block block)
         long value = BIGINT.getLong(block, position);
         int hashPosition = getHashPosition(value, mask);
 
-        // look for an empty slot or a slot containing this key
+        return putIfAbsent(value, hashPosition);
+    }
+
+    private int putIfAbsent(long value, int hashPosition)
+    {
         while (true) {
             int groupId = groupIds[hashPosition];
             if (groupId == -1) {
@@ -271,10 +288,6 @@ private int addNewGroup(int hashPosition, long value)
         valuesByGroupId.set(groupId, value);
         groupIds[hashPosition] = groupId;
 
-        // increase capacity, if necessary
-        if (needRehash()) {
-            tryRehash();
-        }
         return groupId;
     }
 
@@ -302,22 +315,23 @@ private boolean tryRehash()
         int[] newGroupIds = new int[newCapacity];
         Arrays.fill(newGroupIds, -1);
 
-        for (int groupId = 0; groupId < nextGroupId; groupId++) {
-            if (groupId == nullGroupId) {
-                continue;
-            }
-            long value = valuesByGroupId.get(groupId);
+        for (int i = 0; i < values.length; i++) {
+            long value = values[i];
+            int groupId = groupIds[i];
 
-            // find an empty slot for the address
-            int hashPosition = getHashPosition(value, newMask);
-            while (newGroupIds[hashPosition] != -1) {
-                hashPosition = (hashPosition + 1) & newMask;
-                hashCollisions++;
-            }
+            if (groupId != nullGroupId && groupId != -1) {
+                int hashPosition = getHashPosition(value, newMask);
+
+                // find an empty slot for the address
+                while (newGroupIds[hashPosition] != -1) {
+                    hashPosition = (hashPosition + 1) & newMask;
+                    hashCollisions++;
+                }
 
-            // record the mapping
-            newValues[hashPosition] = value;
-            newGroupIds[hashPosition] = groupId;
+                // record the mapping
+                newValues[hashPosition] = value;
+                newGroupIds[hashPosition] = groupId;
+            }
         }
 
         mask = newMask;
@@ -326,7 +340,7 @@ private boolean tryRehash()
         values = newValues;
         groupIds = newGroupIds;
 
-        this.valuesByGroupId.ensureCapacity(maxFill);
+        valuesByGroupId.ensureCapacity(maxFill);
         return true;
     }
 
@@ -387,19 +401,27 @@ public boolean process()
         {
             int positionCount = block.getPositionCount();
             checkState(lastPosition < positionCount, "position count out of bound");
-
-            // needRehash() == false indicates we have reached capacity boundary and a rehash is needed.
-            // We can only proceed if tryRehash() successfully did a rehash.
-            if (needRehash() && !tryRehash()) {
-                return false;
-            }
-
-            // putIfAbsent will rehash automatically if rehash is needed, unless there isn't enough memory to do so.
-            // Therefore needRehash will not generally return true even if we have just crossed the capacity boundary.
-            while (lastPosition < positionCount && !needRehash()) {
-                // get the group for the current row
-                putIfAbsent(lastPosition, block);
-                lastPosition++;
+            int remainingPositions = positionCount - lastPosition;
+
+            long[] dummyGroupIds = new long[BATCH_SIZE];
+            while (remainingPositions != 0) {
+                int batchSize = min(remainingPositions, BATCH_SIZE);
+                if (!ensureHashTableSize(batchSize)) {
+                    return false;
+                }
+
+                if (nextGroupId > ADD_GROUP_BATCH_THRESHOLD) {
+                    batchedPutIfAbsent(block, lastPosition, batchSize, dummyGroupIds, 0);
+                }
+                else {
+                    // The only advantage of batching in this path is not checking table capacity every iteration
+                    for (int i = 0; i < batchSize; i++) {
+                        putIfAbsent(lastPosition + i, block);
+                    }
+                }
+
+                lastPosition += batchSize;
+                remainingPositions -= batchSize;
             }
             return lastPosition == positionCount;
         }
@@ -423,7 +445,7 @@ class AddDictionaryPageWork
         public AddDictionaryPageWork(DictionaryBlock block)
         {
             this.block = requireNonNull(block, "block is null");
-            this.dictionary = block.getDictionary();
+            dictionary = block.getDictionary();
             updateDictionaryLookBack(dictionary);
         }
 
@@ -433,18 +455,20 @@ public boolean process()
             int positionCount = block.getPositionCount();
             checkState(lastPosition < positionCount, "position count out of bound");
 
-            // needRehash() == false indicates we have reached capacity boundary and a rehash is needed.
-            // We can only proceed if tryRehash() successfully did a rehash.
-            if (needRehash() && !tryRehash()) {
-                return false;
-            }
+            int remainingPositions = positionCount - lastPosition;
+
+            while (remainingPositions != 0) {
+                int batchSize = min(remainingPositions, BATCH_SIZE);
+                if (!ensureHashTableSize(batchSize)) {
+                    return false;
+                }
 
-            // putIfAbsent will rehash automatically if rehash is needed, unless there isn't enough memory to do so.
-            // Therefore needRehash will not generally return true even if we have just crossed the capacity boundary.
-            while (lastPosition < positionCount && !needRehash()) {
-                int positionInDictionary = block.getId(lastPosition);
-                registerGroupId(dictionary, positionInDictionary);
-                lastPosition++;
+                for (int i = 0; i < batchSize; i++) {
+                    int positionInDictionary = block.getId(lastPosition + i);
+                    registerGroupId(dictionary, positionInDictionary);
+                }
+                lastPosition += batchSize;
+                remainingPositions -= batchSize;
             }
             return lastPosition == positionCount;
         }
@@ -502,7 +526,7 @@ class GetGroupIdsWork
             implements Work<GroupByIdBlock>
     {
-        private final BlockBuilder blockBuilder;
+        private final long[] groupIds;
         private final Block block;
 
        private boolean finished;
@@ -512,7 +536,7 @@ public GetGroupIdsWork(Block block)
         {
             this.block = requireNonNull(block, "block is null");
             // we know the exact size required for the block
-            this.blockBuilder = BIGINT.createFixedSizeBlockBuilder(block.getPositionCount());
+            groupIds = new long[block.getPositionCount()];
         }
 
         @Override
@@ -522,18 +546,18 @@ public boolean process()
         {
             int positionCount = block.getPositionCount();
             checkState(lastPosition < positionCount, "position count out of bound");
             checkState(!finished);
 
-            // needRehash() == false indicates we have reached capacity boundary and a rehash is needed.
-            // We can only proceed if tryRehash() successfully did a rehash.
-            if (needRehash() && !tryRehash()) {
-                return false;
-            }
+            int remainingPositions = positionCount - lastPosition;
+
+            while (remainingPositions != 0) {
+                int batchSize = min(remainingPositions, BATCH_SIZE);
+                if (!ensureHashTableSize(batchSize)) {
+                    return false;
+                }
 
-            // putIfAbsent will rehash automatically if rehash is needed, unless there isn't enough memory to do so.
-            // Therefore needRehash will not generally return true even if we have just crossed the capacity boundary.
-            while (lastPosition < positionCount && !needRehash()) {
-                // output the group id for this row
-                BIGINT.writeLong(blockBuilder, putIfAbsent(lastPosition, block));
-                lastPosition++;
+                batchedPutIfAbsent(block, lastPosition, batchSize, groupIds, lastPosition);
+
+                lastPosition += batchSize;
+                remainingPositions -= batchSize;
             }
             return lastPosition == positionCount;
         }
@@ -544,7 +568,102 @@ public GroupByIdBlock getResult()
             checkState(lastPosition == block.getPositionCount(), "process has not yet finished");
             checkState(!finished, "result has produced");
             finished = true;
-            return new GroupByIdBlock(nextGroupId, blockBuilder.build());
+            return new GroupByIdBlock(nextGroupId, new LongArrayBlock(groupIds.length, Optional.empty(), groupIds));
+        }
+    }
+
+    private void batchedPutIfAbsent(Block block, int blockOffset, int batchSize, long[] batchGroupIds, int batchGroupIdOffset)
+    {
+        if (block.mayHaveNull()) {
+            batchedPutIfAbsentNullable(block, blockOffset, batchSize, batchGroupIds, batchGroupIdOffset);
+        }
+        else {
+            batchedPutIfAbsentNoNull(block, blockOffset, batchSize, batchGroupIds, batchGroupIdOffset);
+        }
+    }
+
+    private void batchedPutIfAbsentNullable(Block block, int blockOffset, int batchSize, long[] batchGroupIds, int batchGroupIdOffset)
+    {
+        if (nullGroupId < 0) {
+            // This branch is executed only until the first null is seen, so likely at most a few times.
+            // This is done for two reasons:
+            // - the order of new groups must correspond to the order of incoming rows
+            // - after the null group id is determined, the code is more streamlined and thus more performant
+            for (int i = 0; i < batchSize; i++) {
+                batchGroupIds[batchGroupIdOffset + i] = putIfAbsent(blockOffset + i, block);
+            }
+            return;
+        }
+
+        // Allocate assuming no null values to prevent resizing
+        IntList nonNulls = new IntArrayList(batchSize);
+        int[] hashPositions = new int[batchSize];
+
+        for (int i = 0; i < batchSize; i++) {
+            if (block.isNull(blockOffset + i)) {
+                batchGroupIds[batchGroupIdOffset + i] = nullGroupId;
+            }
+            else {
+                hashPositions[nonNulls.size()] = getHashPosition(BIGINT.getLong(block, blockOffset + i), mask);
+                nonNulls.add(i);
+            }
+        }
+
+        for (int i = 0; i < nonNulls.size(); i++) {
+            int indexInBatch = nonNulls.getInt(i);
+            batchGroupIds[batchGroupIdOffset + indexInBatch] = groupIds[hashPositions[i]];
+        }
+
+        for (int i = 0; i < nonNulls.size(); i++) {
+            int indexInBatch = nonNulls.getInt(i);
+            if (batchGroupIds[batchGroupIdOffset + indexInBatch] >= 0) {
+                long value = BIGINT.getLong(block, blockOffset + indexInBatch);
+                long storedValue = values[hashPositions[i]];
+                // Same as
+                //   if (value != storedValue) {
+                //       batchGroupIds[batchGroupIdOffset + indexInBatch] = -1;
+                //   }
+                // but without explicit branches
+                int match = value == storedValue ? 1 : 0;
+                batchGroupIds[batchGroupIdOffset + indexInBatch] = (batchGroupIds[batchGroupIdOffset + indexInBatch] + 1) * match - 1;
+            }
+        }
+
+        for (int i = 0; i < nonNulls.size(); i++) {
+            int indexInBatch = nonNulls.getInt(i);
+            if (batchGroupIds[batchGroupIdOffset + indexInBatch] == -1) {
+                batchGroupIds[batchGroupIdOffset + indexInBatch] = putIfAbsent(BIGINT.getLong(block, blockOffset + indexInBatch), hashPositions[i]);
+            }
+        }
+    }
+
+    private void batchedPutIfAbsentNoNull(Block block, int blockOffset, int batchSize, long[] batchGroupIds, int batchGroupIdOffset)
+    {
+        int[] hashPositions = new int[batchSize];
+        for (int i = 0; i < batchSize; i++) {
+            hashPositions[i] = getHashPosition(BIGINT.getLong(block, blockOffset + i), mask);
+        }
+
+        for (int i = 0; i < batchSize; i++) {
+            batchGroupIds[batchGroupIdOffset + i] = groupIds[hashPositions[i]];
+        }
+
+        for (int i = 0; i < batchSize; i++) {
+            if (batchGroupIds[batchGroupIdOffset + i] != -1) {
+                long value = BIGINT.getLong(block, blockOffset + i);
+                long storedValue = values[hashPositions[i]];
+                // Same as
+                //   if (value != storedValue) {
+                //       batchGroupIds[batchGroupIdOffset + i] = -1;
+                //   }
+                // but without explicit branches
+                int match = value == storedValue ? 1 : 0;
+                batchGroupIds[batchGroupIdOffset + i] = (batchGroupIds[batchGroupIdOffset + i] + 1) * match - 1;
+            }
+        }
+
+        for (int i = 0; i < batchSize; i++) {
+            if (batchGroupIds[batchGroupIdOffset + i] == -1) {
+                batchGroupIds[batchGroupIdOffset + i] = putIfAbsent(BIGINT.getLong(block, blockOffset + i), hashPositions[i]);
+            }
         }
     }
 
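The branchless update in batchedPutIfAbsentNullable and batchedPutIfAbsentNoNull relies on a small arithmetic identity. The following standalone sketch (illustrative only, not part of the patch; the helper name is invented) spells out why it is equivalent to the commented-out branch:

    // If the probed value matches the stored value, the candidate group id is kept;
    // otherwise the slot is marked -1 ("not resolved yet"), without a data-dependent branch:
    //   match == 1  ->  (groupId + 1) * 1 - 1 == groupId
    //   match == 0  ->  (groupId + 1) * 0 - 1 == -1
    static long resolveCandidate(long groupId, long value, long storedValue)
    {
        int match = value == storedValue ? 1 : 0;
        return (groupId + 1) * match - 1;
    }
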
- while (lastPosition < positionCount && !needRehash()) { - int positionInDictionary = block.getId(lastPosition); - int groupId = registerGroupId(dictionary, positionInDictionary); - BIGINT.writeLong(blockBuilder, groupId); - lastPosition++; + for (int i = 0; i < batchSize; i++) { + int positionInDictionary = block.getId(lastPosition + i); + groupIds[lastPosition + i] = registerGroupId(dictionary, positionInDictionary); + } + lastPosition += batchSize; + remainingPositions -= batchSize; } return lastPosition == positionCount; } @@ -599,7 +718,7 @@ public GroupByIdBlock getResult() checkState(lastPosition == block.getPositionCount(), "process has not yet finished"); checkState(!finished, "result has produced"); finished = true; - return new GroupByIdBlock(nextGroupId, blockBuilder.build()); + return new GroupByIdBlock(nextGroupId, new LongArrayBlock(groupIds.length, Optional.empty(), groupIds)); } } @@ -654,6 +773,18 @@ public GroupByIdBlock getResult() } } + private boolean ensureHashTableSize(int batchSize) + { + int positionCountUntilRehash = maxFill - nextGroupId; + while (positionCountUntilRehash < batchSize) { + if (!tryRehash()) { + return false; + } + positionCountUntilRehash = maxFill - nextGroupId; + } + return true; + } + private static final class DictionaryLookBack { private final Block dictionary; @@ -662,7 +793,7 @@ private static final class DictionaryLookBack public DictionaryLookBack(Block dictionary) { this.dictionary = dictionary; - this.processed = new int[dictionary.getPositionCount()]; + processed = new int[dictionary.getPositionCount()]; Arrays.fill(processed, -1); } diff --git a/core/trino-main/src/test/java/io/trino/operator/BenchmarkGroupByHashOnSimulatedData.java b/core/trino-main/src/test/java/io/trino/operator/BenchmarkGroupByHashOnSimulatedData.java new file mode 100644 index 000000000000..067378139a1a --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/operator/BenchmarkGroupByHashOnSimulatedData.java @@ -0,0 +1,626 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package io.trino.operator;
+
+import com.google.common.collect.ImmutableList;
+import io.airlift.slice.Slices;
+import io.trino.spi.Page;
+import io.trino.spi.PageBuilder;
+import io.trino.spi.block.Block;
+import io.trino.spi.block.BlockBuilder;
+import io.trino.spi.block.DictionaryBlock;
+import io.trino.spi.type.BigintType;
+import io.trino.spi.type.CharType;
+import io.trino.spi.type.DoubleType;
+import io.trino.spi.type.IntegerType;
+import io.trino.spi.type.Type;
+import io.trino.spi.type.TypeOperators;
+import io.trino.spi.type.VarcharType;
+import io.trino.sql.gen.JoinCompiler;
+import io.trino.type.BlockTypeOperators;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OperationsPerInvocation;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Warmup;
+import org.openjdk.jmh.runner.RunnerException;
+import org.testng.annotations.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Optional;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.IntStream;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static io.trino.jmh.Benchmarks.benchmark;
+import static io.trino.operator.BenchmarkGroupByHashOnSimulatedData.AggregationDefinition.BIGINT_1K_GROUPS;
+import static io.trino.operator.BenchmarkGroupByHashOnSimulatedData.AggregationDefinition.BIGINT_1M_GROUPS;
+import static io.trino.operator.BenchmarkGroupByHashOnSimulatedData.AggregationDefinition.BIGINT_2_GROUPS;
+import static io.trino.operator.BenchmarkGroupByHashOnSimulatedData.WorkType.GET_GROUPS;
+import static io.trino.operator.UpdateMemory.NOOP;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * This class attempts to emulate aggregations done while running real-life queries.
+ * Some of the numbers here have been inspired by TPC-H benchmarks; however,
+ * there is no guarantee that the results correlate with that benchmark itself.
+ */
+@State(Scope.Thread)
+@OutputTimeUnit(TimeUnit.NANOSECONDS)
+@Fork(3)
+@Warmup(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS)
+@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MILLISECONDS)
+@BenchmarkMode(Mode.AverageTime)
+public class BenchmarkGroupByHashOnSimulatedData
+{
+    private static final int DEFAULT_POSITIONS = 10_000_000;
+    private static final int EXPECTED_GROUP_COUNT = 10_000;
+    private static final int DEFAULT_PAGE_SIZE = 8192;
+    private static final TypeOperators TYPE_OPERATORS = new TypeOperators();
+    private static final BlockTypeOperators TYPE_OPERATOR_FACTORY = new BlockTypeOperators(TYPE_OPERATORS);
+
+    private final JoinCompiler joinCompiler = new JoinCompiler(TYPE_OPERATORS);
+
+    @Benchmark
+    @OperationsPerInvocation(DEFAULT_POSITIONS)
+    public Object groupBy(BenchmarkContext data)
+    {
+        GroupByHash groupByHash = GroupByHash.createGroupByHash(
+                data.getTypes(),
+                data.getChannels(),
+                Optional.empty(),
+                EXPECTED_GROUP_COUNT,
+                false,
+                joinCompiler,
+                TYPE_OPERATOR_FACTORY,
+                NOOP);
+        List<GroupByIdBlock> results = addInputPages(groupByHash, data.getPages(), data.getWorkType());
+
+        ImmutableList.Builder<Page> pages = ImmutableList.builder();
+        PageBuilder pageBuilder = new PageBuilder(groupByHash.getTypes());
+        for (int groupId = 0; groupId < groupByHash.getGroupCount(); groupId++) {
+            pageBuilder.declarePosition();
+            groupByHash.appendValuesTo(groupId, pageBuilder);
+            if (pageBuilder.isFull()) {
+                pages.add(pageBuilder.build());
+                pageBuilder.reset();
+            }
+        }
+        pages.add(pageBuilder.build());
+        return ImmutableList.of(pages, results); // all the things that might get erased by the compiler
+    }
+
+    @Test
+    public void testGroupBy()
+    {
+        BenchmarkGroupByHashOnSimulatedData benchmark = new BenchmarkGroupByHashOnSimulatedData();
+        for (double nullChance : new double[] {0, .1, .5, .9}) {
+            for (AggregationDefinition query : AggregationDefinition.values()) {
+                BenchmarkContext data = new BenchmarkContext(GET_GROUPS, query, nullChance, 10_000);
+                data.setup();
+                benchmark.groupBy(data);
+            }
+        }
+    }
+
+    private List<GroupByIdBlock> addInputPages(GroupByHash groupByHash, List<Page> pages, WorkType workType)
+    {
+        List<GroupByIdBlock> results = new ArrayList<>();
+        for (Page page : pages) {
+            if (workType == GET_GROUPS) {
+                Work<GroupByIdBlock> work = groupByHash.getGroupIds(page);
+                boolean finished;
+                do {
+                    finished = work.process();
+                    results.add(work.getResult());
+                }
+                while (!finished);
+            }
+            else {
+                Work<?> work = groupByHash.addPage(page);
+                boolean finished;
+                do {
+                    finished = work.process();
+                }
+                while (!finished);
+            }
+        }
+
+        return results;
+    }
+
+    public interface BlockWriter
+    {
+        void write(BlockBuilder blockBuilder, int positionCount, long randomSeed);
+    }
+
+    public enum ColumnType
+    {
+        BIGINT(BigintType.BIGINT, (blockBuilder, positionCount, seed) -> {
+            Random r = new Random(seed);
+            for (int i = 0; i < positionCount; i++) {
+                blockBuilder.writeLong((r.nextLong() >>> 1)); // Only positives
+            }
+        }),
+        INT(IntegerType.INTEGER, (blockBuilder, positionCount, seed) -> {
+            Random r = new Random(seed);
+            for (int i = 0; i < positionCount; i++) {
+                blockBuilder.writeInt(r.nextInt());
+            }
+        }),
+        DOUBLE(DoubleType.DOUBLE, (blockBuilder, positionCount, seed) -> {
+            Random r = new Random(seed);
+            for (int i = 0; i < positionCount; i++) {
+                blockBuilder.writeLong((r.nextLong() >>> 1)); // Only positives
+            }
+        }),
+        VARCHAR_25(VarcharType.VARCHAR, (blockBuilder, positionCount, seed) -> {
+            writeVarchar(blockBuilder, positionCount, seed, 25);
+        }),
+        VARCHAR_117(VarcharType.VARCHAR, (blockBuilder, positionCount, seed) -> {
+            writeVarchar(blockBuilder, positionCount, seed, 117);
+        }),
+        CHAR_1(CharType.createCharType(1), (blockBuilder, positionCount, seed) -> {
+            Random r = new Random(seed);
+            for (int i = 0; i < positionCount; i++) {
+                byte value = (byte) r.nextInt();
+                while (value == ' ') {
+                    value = (byte) r.nextInt();
+                }
+                CharType.createCharType(1).writeSlice(blockBuilder, Slices.wrappedBuffer(value));
+            }
+        }),
+        /**/;
+
+        private static void writeVarchar(BlockBuilder blockBuilder, int positionCount, long seed, int maxLength)
+        {
+            Random r = new Random(seed);
+
+            for (int i = 0; i < positionCount; i++) {
+                int length = 1 + r.nextInt(maxLength - 1);
+                byte[] bytes = new byte[length];
+                r.nextBytes(bytes);
+                VarcharType.VARCHAR.writeSlice(blockBuilder, Slices.wrappedBuffer(bytes));
+            }
+        }
+
+        final Type type;
+        final BlockWriter blockWriter;
+
+        ColumnType(Type type, BlockWriter blockWriter)
+        {
+            this.type = requireNonNull(type, "type is null");
+            this.blockWriter = requireNonNull(blockWriter, "blockWriter is null");
+        }
+
+        public Type getType()
+        {
+            return type;
+        }
+
+        public BlockWriter getBlockWriter()
+        {
+            return blockWriter;
+        }
+    }
+
+    @State(Scope.Thread)
+    public static class BenchmarkContext
+    {
+        @Param
+        private WorkType workType;
+
+        @Param
+        private AggregationDefinition query;
+
+        @Param({"0", ".1", ".5", ".9"})
+        private double nullChance;
+
+        private final int positions;
+        private List<Page> pages;
+        private List<Type> types;
+        private int[] channels;
+
+        public BenchmarkContext()
+        {
+            this.positions = DEFAULT_POSITIONS;
+        }
+
+        public BenchmarkContext(WorkType workType, AggregationDefinition query, double nullChance, int positions)
+        {
+            this.workType = requireNonNull(workType, "workType is null");
+            this.query = requireNonNull(query, "query is null");
+            this.positions = positions;
+            this.nullChance = nullChance;
+        }
+
+        @Setup
+        public void setup()
+        {
+            types = query.getChannels().stream()
+                    .map(channel -> channel.columnType.type)
+                    .collect(toImmutableList());
+            channels = IntStream.range(0, query.getChannels().size()).toArray();
+            pages = createPages(query);
+        }
+
+        private List<Page> createPages(AggregationDefinition definition)
+        {
+            List<Page> result = new ArrayList<>();
+
+            int channelCount = definition.getChannels().size();
+            int pageSize = definition.pageSize;
+            int pageCount = positions / pageSize;
+
+            Block[][] blocks = new Block[channelCount][];
+            for (int i = 0; i < definition.getChannels().size(); i++) {
+                ChannelDefinition channel = definition.getChannels().get(i);
+                blocks[i] = channel.createBlocks(pageCount, pageSize, i, nullChance);
+            }
+
+            for (int i = 0; i < pageCount; i++) {
+                int pageIndex = i;
+                Block[] pageBlocks = IntStream.range(0, channelCount)
+                        .mapToObj(channel -> blocks[channel][pageIndex])
+                        .toArray(Block[]::new);
+                result.add(new Page(pageBlocks));
+            }
+
+            return result;
+        }
+
+        public List<Page> getPages()
+        {
+            return pages;
+        }
+
+        public List<Type> getTypes()
+        {
+            return types;
+        }
+
+        public int[] getChannels()
+        {
+            return channels;
+        }
+
+        public WorkType getWorkType()
+        {
+            return workType;
+        }
+    }
+
+    public enum WorkType
+    {
+        ADD,
+        GET_GROUPS,
+    }
+
+    public enum AggregationDefinition
+    {
+        BIGINT_2_GROUPS(new ChannelDefinition(ColumnType.BIGINT, 2)),
+        BIGINT_10_GROUPS(new ChannelDefinition(ColumnType.BIGINT, 10)),
+        BIGINT_1K_GROUPS(new ChannelDefinition(ColumnType.BIGINT, 1000)),
+        BIGINT_10K_GROUPS(new ChannelDefinition(ColumnType.BIGINT, 10_000)),
+        BIGINT_100K_GROUPS(new ChannelDefinition(ColumnType.BIGINT, 100_000)),
+        BIGINT_1M_GROUPS(new ChannelDefinition(ColumnType.BIGINT, 1_000_000)),
+        BIGINT_10M_GROUPS(new ChannelDefinition(ColumnType.BIGINT, 10_000_000)),
+        BIGINT_2_GROUPS_1_SMALL_DICTIONARY(new ChannelDefinition(ColumnType.BIGINT, 2, 1, 50)),
+        BIGINT_2_GROUPS_1_BIG_DICTIONARY(new ChannelDefinition(ColumnType.BIGINT, 2, 1, 10000)),
+        BIGINT_2_GROUPS_MULTIPLE_SMALL_DICTIONARY(new ChannelDefinition(ColumnType.BIGINT, 2, 10, 50)),
+        BIGINT_2_GROUPS_MULTIPLE_BIG_DICTIONARY(new ChannelDefinition(ColumnType.BIGINT, 2, 10, 10000)),
+        BIGINT_10K_GROUPS_1_DICTIONARY(new ChannelDefinition(ColumnType.BIGINT, 10000, 1, 20000)),
+        BIGINT_10K_GROUPS_MULTIPLE_DICTIONARY(new ChannelDefinition(ColumnType.BIGINT, 10000, 20, 20000)),
+        DOUBLE_10_GROUPS(new ChannelDefinition(ColumnType.DOUBLE, 10)),
+        TWO_TINY_VARCHAR_DICTIONARIES(
+                new ChannelDefinition(ColumnType.CHAR_1, 2, 10),
+                new ChannelDefinition(ColumnType.CHAR_1, 2, 10)),
+        FIVE_TINY_VARCHAR_DICTIONARIES(
+                new ChannelDefinition(ColumnType.CHAR_1, 2, 10),
+                new ChannelDefinition(ColumnType.CHAR_1, 2, 10),
+                new ChannelDefinition(ColumnType.CHAR_1, 2, 10),
+                new ChannelDefinition(ColumnType.CHAR_1, 2, 10),
+                new ChannelDefinition(ColumnType.CHAR_1, 2, 10)),
+        TWO_SMALL_VARCHAR_DICTIONARIES(
+                new ChannelDefinition(ColumnType.CHAR_1, 30, 10),
+                new ChannelDefinition(ColumnType.CHAR_1, 30, 10)),
+        TWO_SMALL_VARCHAR_DICTIONARIES_WITH_SMALL_PAGE_SIZE( // low cardinality optimisation will not kick in here
+                1000,
+                new ChannelDefinition(ColumnType.CHAR_1, 30, 10),
+                new ChannelDefinition(ColumnType.CHAR_1, 30, 10)),
+        VARCHAR_2_GROUPS(new ChannelDefinition(ColumnType.VARCHAR_25, 2)),
+        VARCHAR_10_GROUPS(new ChannelDefinition(ColumnType.VARCHAR_25, 10)),
+        VARCHAR_1K_GROUPS(new ChannelDefinition(ColumnType.VARCHAR_25, 1000)),
+        VARCHAR_10K_GROUPS(new ChannelDefinition(ColumnType.VARCHAR_25, 10_000)),
+        VARCHAR_100K_GROUPS(new ChannelDefinition(ColumnType.VARCHAR_25, 100_000)),
+        VARCHAR_1M_GROUPS(new ChannelDefinition(ColumnType.VARCHAR_25, 1_000_000)),
+        VARCHAR_10M_GROUPS(new ChannelDefinition(ColumnType.VARCHAR_25, 10_000_000)),
+        VARCHAR_2_GROUPS_1_SMALL_DICTIONARY(new ChannelDefinition(ColumnType.VARCHAR_25, 2, 1, 50)),
+        VARCHAR_2_GROUPS_1_BIG_DICTIONARY(new ChannelDefinition(ColumnType.VARCHAR_25, 2, 1, 10000)),
+        VARCHAR_2_GROUPS_MULTIPLE_SMALL_DICTIONARY(new ChannelDefinition(ColumnType.VARCHAR_25, 2, 10, 50)),
+        VARCHAR_2_GROUPS_MULTIPLE_BIG_DICTIONARY(new ChannelDefinition(ColumnType.VARCHAR_25, 2, 10, 10000)),
+        VARCHAR_10K_GROUPS_1_DICTIONARY(new ChannelDefinition(ColumnType.VARCHAR_25, 10000, 1, 20000)),
+        VARCHAR_10K_GROUPS_MULTIPLE_DICTIONARY(new ChannelDefinition(ColumnType.VARCHAR_25, 10000, 20, 20000)),
+        TINY_CHAR_10_GROUPS(new ChannelDefinition(ColumnType.CHAR_1, 10)),
+        BIG_VARCHAR_10_GROUPS(new ChannelDefinition(ColumnType.VARCHAR_117, 10)),
+        BIG_VARCHAR_1M_GROUPS(new ChannelDefinition(ColumnType.VARCHAR_117, 1_000_000)),
+        DOUBLE_BIGINT_100_GROUPS(
+                new ChannelDefinition(ColumnType.BIGINT, 10),
+                new ChannelDefinition(ColumnType.BIGINT, 10)),
+        BIGINT_AND_TWO_INTS_5K(
+                new ChannelDefinition(ColumnType.BIGINT, 500),
+                new ChannelDefinition(ColumnType.INT, 10),
+                new ChannelDefinition(ColumnType.INT, 10)),
+        FIVE_MIXED_SHORT_COLUMNS_100_GROUPS(
+                new ChannelDefinition(ColumnType.BIGINT, 5),
+                new ChannelDefinition(ColumnType.INT, 5),
+                new ChannelDefinition(ColumnType.VARCHAR_25, 2),
+                new ChannelDefinition(ColumnType.INT, 1),
+                new ChannelDefinition(ColumnType.DOUBLE, 2)),
+        FIVE_MIXED_SHORT_COLUMNS_100K_GROUPS(
+                new ChannelDefinition(ColumnType.BIGINT, 5),
+                new ChannelDefinition(ColumnType.INT, 5),
+                new ChannelDefinition(ColumnType.VARCHAR_25, 20),
+                new ChannelDefinition(ColumnType.INT, 10),
+                new ChannelDefinition(ColumnType.DOUBLE, 20)),
+        FIVE_MIXED_LONG_COLUMNS_100_GROUPS(
+                new ChannelDefinition(ColumnType.BIGINT, 5),
+                new ChannelDefinition(ColumnType.VARCHAR_117, 5),
+                new ChannelDefinition(ColumnType.VARCHAR_25, 2),
+                new ChannelDefinition(ColumnType.VARCHAR_25, 1),
+                new ChannelDefinition(ColumnType.VARCHAR_117, 2)),
+        FIVE_MIXED_LONG_COLUMNS_100K_GROUPS(
+                new ChannelDefinition(ColumnType.BIGINT, 5),
+                new ChannelDefinition(ColumnType.VARCHAR_117, 5),
+                new ChannelDefinition(ColumnType.VARCHAR_25, 20),
+                new ChannelDefinition(ColumnType.VARCHAR_25, 10),
+                new ChannelDefinition(ColumnType.VARCHAR_117, 20)),
+        TEN_MIXED_SHORT_COLUMNS_100_GROUPS(
+                new ChannelDefinition(ColumnType.BIGINT, 1),
+                new ChannelDefinition(ColumnType.INT, 2),
+                new ChannelDefinition(ColumnType.BIGINT, 1),
+                new ChannelDefinition(ColumnType.INT, 5),
+                new ChannelDefinition(ColumnType.DOUBLE, 1),
+                new ChannelDefinition(ColumnType.BIGINT, 2),
+                new ChannelDefinition(ColumnType.INT, 1),
+                new ChannelDefinition(ColumnType.VARCHAR_25, 5),
+                new ChannelDefinition(ColumnType.INT, 1),
+                new ChannelDefinition(ColumnType.DOUBLE, 1)),
+        TEN_MIXED_SHORT_COLUMNS_100K_GROUPS(
+                new ChannelDefinition(ColumnType.BIGINT, 5),
+                new ChannelDefinition(ColumnType.INT, 2),
+                new ChannelDefinition(ColumnType.BIGINT, 2),
+                new ChannelDefinition(ColumnType.INT, 5),
+                new ChannelDefinition(ColumnType.DOUBLE, 5),
+                new ChannelDefinition(ColumnType.BIGINT, 2),
+                new ChannelDefinition(ColumnType.INT, 2),
+                new ChannelDefinition(ColumnType.VARCHAR_25, 5),
+                new ChannelDefinition(ColumnType.INT, 5),
+                new ChannelDefinition(ColumnType.DOUBLE, 2)),
+        TEN_MIXED_LONG_COLUMNS_100_GROUPS(
+                new ChannelDefinition(ColumnType.BIGINT, 1),
+                new ChannelDefinition(ColumnType.VARCHAR_117, 2),
+                new ChannelDefinition(ColumnType.VARCHAR_25, 1),
+                new ChannelDefinition(ColumnType.VARCHAR_117, 5),
+                new ChannelDefinition(ColumnType.DOUBLE, 1),
+                new ChannelDefinition(ColumnType.VARCHAR_25, 2),
+                new ChannelDefinition(ColumnType.VARCHAR_25, 1),
+                new ChannelDefinition(ColumnType.VARCHAR_25, 5),
+                new ChannelDefinition(ColumnType.VARCHAR_117, 1),
+                new ChannelDefinition(ColumnType.DOUBLE, 1)),
+        TEN_MIXED_LONG_COLUMNS_100K_GROUPS(
+                new ChannelDefinition(ColumnType.BIGINT, 5),
+                new ChannelDefinition(ColumnType.VARCHAR_117, 2),
+                new ChannelDefinition(ColumnType.VARCHAR_25, 2),
+                new ChannelDefinition(ColumnType.VARCHAR_117, 5),
+                new ChannelDefinition(ColumnType.DOUBLE, 5),
+                new ChannelDefinition(ColumnType.VARCHAR_25, 2),
+                new ChannelDefinition(ColumnType.VARCHAR_25, 2),
+                new ChannelDefinition(ColumnType.VARCHAR_25, 5),
+                new ChannelDefinition(ColumnType.VARCHAR_117, 5),
+                new ChannelDefinition(ColumnType.DOUBLE, 2)),
+        /**/;
+
+        private final int pageSize;
+        private final List<ChannelDefinition> channels;
+
+        AggregationDefinition(ChannelDefinition... channels)
+        {
+            this(DEFAULT_PAGE_SIZE, channels);
+        }
+
+        AggregationDefinition(int pageSize, ChannelDefinition... channels)
+        {
+            this.pageSize = pageSize;
+            this.channels = Arrays.stream(requireNonNull(channels, "channels is null")).collect(toImmutableList());
+        }
+
+        public int getPageSize()
+        {
+            return pageSize;
+        }
+
+        public List<ChannelDefinition> getChannels()
+        {
+            return channels;
+        }
+    }
+
+    public static class ChannelDefinition
+    {
+        private final ColumnType columnType;
+        private final int distinctValuesCountInColumn;
+        private final int dictionaryPositionsCount;
+        private final int numberOfDistinctDictionaries;
+
+        public ChannelDefinition(ColumnType columnType, int distinctValuesCountInColumn)
+        {
+            this(columnType, distinctValuesCountInColumn, -1, -1);
+        }
+
+        public ChannelDefinition(ColumnType columnType, int distinctValuesCountInColumn, int numberOfDistinctDictionaries)
+        {
+            this(columnType, distinctValuesCountInColumn, numberOfDistinctDictionaries, distinctValuesCountInColumn);
+        }
+
+        public ChannelDefinition(ColumnType columnType, int distinctValuesCountInColumn, int numberOfDistinctDictionaries, int dictionaryPositionsCount)
+        {
+            this.columnType = requireNonNull(columnType, "columnType is null");
+            this.distinctValuesCountInColumn = distinctValuesCountInColumn;
+            this.dictionaryPositionsCount = dictionaryPositionsCount;
+            this.numberOfDistinctDictionaries = numberOfDistinctDictionaries;
+            checkArgument(dictionaryPositionsCount == -1 || dictionaryPositionsCount >= distinctValuesCountInColumn);
+        }
+
+        public ColumnType getColumnType()
+        {
+            return columnType;
+        }
+
+        public Block[] createBlocks(int blockCount, int positionsPerBlock, int channel, double nullChance)
+        {
+            Block[] blocks = new Block[blockCount];
+            if (dictionaryPositionsCount == -1) { // No dictionaries
+                createNonDictionaryBlock(blockCount, positionsPerBlock, channel, nullChance, blocks);
+            }
+            else {
+                createDictionaryBlock(blockCount, positionsPerBlock, channel, nullChance, blocks);
+            }
+            return blocks;
+        }
+
+        private void createDictionaryBlock(int blockCount, int positionsPerBlock, int channel, double nullChance, Block[] blocks)
+        {
+            Random r = new Random(channel);
+
+            // All values that will be stored in dictionaries. Not all of them need to be used in blocks
+            BlockBuilder allValues = generateValues(channel, dictionaryPositionsCount);
+            if (nullChance > 0) {
+                allValues.appendNull();
+            }
+
+            Block[] dictionaries = new Block[numberOfDistinctDictionaries];
+            // Generate 'numberOfDistinctDictionaries' dictionaries that are equal, but are not the same object.
+            // This way the optimization that caches dictionary results will not work
+            for (int i = 0; i < numberOfDistinctDictionaries; i++) {
+                dictionaries[i] = allValues.build();
+            }
+
+            int[] usedValues = nOutOfM(r, distinctValuesCountInColumn, dictionaryPositionsCount).stream()
+                    .mapToInt(x -> x)
+                    .toArray();
+
+            // Generate output blocks
+            for (int i = 0; i < blockCount; i++) {
+                int[] indexes = new int[positionsPerBlock];
+                int dictionaryId = r.nextInt(numberOfDistinctDictionaries);
+                Block dictionary = dictionaries[dictionaryId];
+                for (int j = 0; j < positionsPerBlock; j++) {
+                    if (isNull(r, nullChance)) {
+                        indexes[j] = dictionaryPositionsCount; // Last value in dictionary is null
+                    }
+                    else {
+                        indexes[j] = usedValues[r.nextInt(usedValues.length)];
+                    }
+                }
+
+                blocks[i] = new DictionaryBlock(dictionary, indexes);
+            }
+        }
+
+        private void createNonDictionaryBlock(int blockCount, int positionsPerBlock, int channel, double nullChance, Block[] blocks)
+        {
+            BlockBuilder allValues = generateValues(channel, distinctValuesCountInColumn);
+            Random r = new Random(channel);
+            for (int i = 0; i < blockCount; i++) {
+                BlockBuilder block = columnType.getType().createBlockBuilder(null, positionsPerBlock);
+                for (int j = 0; j < positionsPerBlock; j++) {
+                    if (isNull(r, nullChance)) {
+                        block.appendNull();
+                    }
+                    else {
+                        int position = r.nextInt(distinctValuesCountInColumn);
+                        columnType.getType().appendTo(allValues, position, block);
+                    }
+                }
+                blocks[i] = block.build();
+            }
+        }
+
+        private BlockBuilder generateValues(int channel, int distinctValueCount)
+        {
+            BlockBuilder allValues = columnType.getType().createBlockBuilder(null, distinctValueCount);
+            columnType.getBlockWriter().write(allValues, distinctValueCount, channel);
+            return allValues;
+        }
+
+        private static boolean isNull(Random random, double nullChance)
+        {
+            double value = 0;
+            // null chance has to be 0 to 1 exclusive.
+            while (value == 0) {
+                value = random.nextDouble();
+            }
+            return value < nullChance;
+        }
+
+        private Set<Integer> nOutOfM(Random r, int n, int m)
+        {
+            Set<Integer> usedValues = new HashSet<>();
+            // Double loop for performance reasons
+            while (usedValues.size() < n) {
+                int left = n - usedValues.size();
+                for (int i = 0; i < left; i++) {
+                    usedValues.add(r.nextInt(m));
+                }
+            }
+            return usedValues;
+        }
+    }
+
+    static {
+        // Pollute JVM profile
+        BenchmarkGroupByHashOnSimulatedData benchmark = new BenchmarkGroupByHashOnSimulatedData();
+        for (WorkType workType : WorkType.values()) {
+            for (double nullChance : new double[] {0, .1, .5, .9}) {
+                for (AggregationDefinition query : new AggregationDefinition[] {BIGINT_2_GROUPS, BIGINT_1K_GROUPS, BIGINT_1M_GROUPS}) {
+                    BenchmarkContext context = new BenchmarkContext(workType, query, nullChance, 8000);
+                    context.setup();
+                    benchmark.groupBy(context);
+                }
+            }
+        }
+    }
+
+    public static void main(String[] args)
+            throws RunnerException
+    {
+        benchmark(BenchmarkGroupByHashOnSimulatedData.class)
+                .withOptions(optionsBuilder -> optionsBuilder
+                        .jvmArgs("-Xmx8g"))
+                .run();
+    }
+}
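For reference, a minimal driver modeled on the testGroupBy() method above can exercise a single scenario without the JMH harness (the scenario, null chance, and position count below are arbitrary illustrative values, not part of the patch):

    // Runs one aggregation scenario directly, mirroring testGroupBy() above.
    static void runOneScenario()
    {
        BenchmarkGroupByHashOnSimulatedData benchmark = new BenchmarkGroupByHashOnSimulatedData();
        BenchmarkGroupByHashOnSimulatedData.BenchmarkContext context =
                new BenchmarkGroupByHashOnSimulatedData.BenchmarkContext(
                        BenchmarkGroupByHashOnSimulatedData.WorkType.GET_GROUPS,
                        BenchmarkGroupByHashOnSimulatedData.AggregationDefinition.BIGINT_1K_GROUPS,
                        0.1,      // nullChance
                        10_000);  // positions
        context.setup();
        benchmark.groupBy(context);
    }

The full parameter sweep is run through the main() method above, which delegates to the JMH runner.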