diff --git a/core/trino-main/src/main/java/io/trino/operator/BigintGroupByHash.java b/core/trino-main/src/main/java/io/trino/operator/BigintGroupByHash.java index 192b3a060ebd..ae783452f07e 100644 --- a/core/trino-main/src/main/java/io/trino/operator/BigintGroupByHash.java +++ b/core/trino-main/src/main/java/io/trino/operator/BigintGroupByHash.java @@ -162,32 +162,6 @@ public Work getGroupIds(Page page) return new GetGroupIdsWork(block); } - @Override - public boolean contains(int position, Page page) - { - Block block = page.getBlock(0); - if (block.isNull(position)) { - return nullGroupId >= 0; - } - - long value = BIGINT.getLong(block, position); - int hashPosition = getHashPosition(value, mask); - - // look for an empty slot or a slot containing this key - while (true) { - int groupId = groupIds[hashPosition]; - if (groupId == -1) { - return false; - } - if (value == values[hashPosition]) { - return true; - } - - // increment position and mask to handle wrap around - hashPosition = (hashPosition + 1) & mask; - } - } - @Override public long getRawHash(int groupId) { diff --git a/core/trino-main/src/main/java/io/trino/operator/ChannelSet.java b/core/trino-main/src/main/java/io/trino/operator/ChannelSet.java index 581a88ff60ae..dde56ee4ef20 100644 --- a/core/trino-main/src/main/java/io/trino/operator/ChannelSet.java +++ b/core/trino-main/src/main/java/io/trino/operator/ChannelSet.java @@ -13,37 +13,37 @@ */ package io.trino.operator; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableList; import io.trino.memory.context.LocalMemoryContext; -import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; -import io.trino.sql.gen.JoinCompiler; -import static io.trino.operator.GroupByHash.createGroupByHash; -import static io.trino.type.UnknownType.UNKNOWN; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION_NOT_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.FLAT; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL; +import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN; +import static io.trino.spi.function.InvocationConvention.simpleConvention; +import static io.trino.spi.type.BigintType.BIGINT; import static java.util.Objects.requireNonNull; public class ChannelSet { - private final GroupByHash hash; - private final boolean containsNull; + private final FlatSet set; - private ChannelSet(GroupByHash hash, boolean containsNull) + private ChannelSet(FlatSet set) { - this.hash = hash; - this.containsNull = containsNull; + this.set = set; } public long getEstimatedSizeInBytes() { - return hash.getEstimatedSize(); + return set.getEstimatedSize(); } public int size() { - return hash.getGroupCount(); + return set.size(); } public boolean isEmpty() @@ -53,67 +53,67 @@ public boolean isEmpty() public boolean containsNull() { - return containsNull; + return set.containsNull(); } - public boolean contains(int position, Page page) + public boolean contains(Block valueBlock, int position) { - return hash.contains(position, page); + return set.contains(valueBlock, position); } - public boolean contains(int position, Page page, long rawHash) + public boolean contains(Block valueBlock, int position, long rawHash) { - return hash.contains(position, page, rawHash); + return set.contains(valueBlock, position, rawHash); } public static class ChannelSetBuilder { - private final Type type; - private final OperatorContext operatorContext; - private final LocalMemoryContext localMemoryContext; - private final GroupByHash hash; + private final LocalMemoryContext memoryContext; + private final FlatSet set; - public ChannelSetBuilder(Type type, boolean hasPrecomputedHash, int expectedPositions, OperatorContext operatorContext, JoinCompiler joinCompiler, TypeOperators typeOperators) + public ChannelSetBuilder(Type type, TypeOperators typeOperators, LocalMemoryContext memoryContext) { - this.type = requireNonNull(type, "type is null"); - this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); - this.localMemoryContext = operatorContext.localUserMemoryContext(); - this.hash = createGroupByHash( - operatorContext.getSession(), - ImmutableList.of(type), - hasPrecomputedHash, - expectedPositions, - joinCompiler, - typeOperators, - this::updateMemoryReservation); + set = new FlatSet( + type, + typeOperators.getReadValueOperator(type, simpleConvention(FLAT_RETURN, BLOCK_POSITION_NOT_NULL)), + typeOperators.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, FLAT)), + typeOperators.getDistinctFromOperator(type, simpleConvention(FAIL_ON_NULL, FLAT, BLOCK_POSITION_NOT_NULL)), + typeOperators.getHashCodeOperator(type, simpleConvention(FAIL_ON_NULL, BLOCK_POSITION_NOT_NULL))); + this.memoryContext = requireNonNull(memoryContext, "memoryContext is null"); + this.memoryContext.setBytes(set.getEstimatedSize()); } public ChannelSet build() { - Page nullBlockPage = new Page(type.createBlockBuilder(null, 1, UNKNOWN.getFixedSize()).appendNull().build()); - boolean containsNull = hash.contains(0, nullBlockPage); - return new ChannelSet(hash, containsNull); + return new ChannelSet(set); } - public Work addPage(Page page) + public void addAll(Block valueBlock, Block hashBlock) { - // Just add the page to the pending work, which will be processed later. - return hash.addPage(page); - } - - public boolean updateMemoryReservation() - { - // If memory is not available, once we return, this operator will be blocked until memory is available. - localMemoryContext.setBytes(hash.getEstimatedSize()); - - // If memory is not available, inform the caller that we cannot proceed for allocation. - return operatorContext.isWaitingForMemory().isDone(); - } - - @VisibleForTesting - public int getCapacity() - { - return hash.getCapacity(); + if (valueBlock.getPositionCount() == 0) { + return; + } + + if (valueBlock instanceof RunLengthEncodedBlock rleBlock) { + if (hashBlock != null) { + set.add(rleBlock.getValue(), 0, BIGINT.getLong(hashBlock, 0)); + } + else { + set.add(rleBlock.getValue(), 0); + } + } + else if (hashBlock != null) { + for (int position = 0; position < valueBlock.getPositionCount(); position++) { + set.add(valueBlock, position, BIGINT.getLong(hashBlock, position)); + } + } + else { + for (int position = 0; position < valueBlock.getPositionCount(); position++) { + set.add(valueBlock, position); + } + } + + memoryContext.setBytes(set.getEstimatedSize()); } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/FlatGroupByHash.java b/core/trino-main/src/main/java/io/trino/operator/FlatGroupByHash.java index 637fab573d87..3ecd84470f2c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/FlatGroupByHash.java +++ b/core/trino-main/src/main/java/io/trino/operator/FlatGroupByHash.java @@ -161,32 +161,6 @@ public Work getGroupIds(Page page) return new GetNonDictionaryGroupIdsWork(blocks); } - @Override - public boolean contains(int position, Page page) - { - return flatHash.contains(getBlocksForContainsPage(page), position); - } - - @Override - public boolean contains(int position, Page page, long hash) - { - return flatHash.contains(getBlocksForContainsPage(page), position, hash); - } - - private Block[] getBlocksForContainsPage(Page page) - { - // contains page only has the group by channels as the optional hash is passed directly - checkArgument(page.getChannelCount() == groupByChannelCount); - Block[] blocks = currentBlocks; - for (int i = 0; i < page.getChannelCount(); i++) { - blocks[i] = page.getBlock(i); - } - if (hasPrecomputedHash) { - blocks[blocks.length - 1] = null; - } - return blocks; - } - @VisibleForTesting @Override public int getCapacity() @@ -442,7 +416,7 @@ public AddRunLengthEncodedPageWork(Block[] blocks) { for (int i = 0; i < blocks.length; i++) { // GroupBy blocks are guaranteed to be RLE, but hash block might not be an RLE due to bugs - // use getSingleValueBlock here which for RLE is a no-op, but will still work if hash block is not RLE + // use getSingleValueBlock here, which for RLE is a no-op, but will still work if hash block is not RLE blocks[i] = blocks[i].getSingleValueBlock(0); } this.blocks = blocks; @@ -639,7 +613,7 @@ public GetRunLengthEncodedGroupIdsWork(Block[] blocks) positionCount = blocks[0].getPositionCount(); for (int i = 0; i < blocks.length; i++) { // GroupBy blocks are guaranteed to be RLE, but hash block might not be an RLE due to bugs - // use getSingleValueBlock here which for RLE is a no-op, but will still work if hash block is not RLE + // use getSingleValueBlock here, which for RLE is a no-op, but will still work if hash block is not RLE blocks[i] = blocks[i].getSingleValueBlock(0); } this.blocks = blocks; diff --git a/core/trino-main/src/main/java/io/trino/operator/FlatSet.java b/core/trino-main/src/main/java/io/trino/operator/FlatSet.java new file mode 100644 index 000000000000..5b5c298fdd28 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/FlatSet.java @@ -0,0 +1,399 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator; + +import com.google.common.base.Throwables; +import io.trino.spi.TrinoException; +import io.trino.spi.block.Block; +import io.trino.spi.type.Type; + +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.VarHandle; + +import static io.airlift.slice.SizeOf.instanceSize; +import static io.airlift.slice.SizeOf.sizeOf; +import static io.trino.operator.VariableWidthData.EMPTY_CHUNK; +import static io.trino.operator.VariableWidthData.POINTER_SIZE; +import static io.trino.spi.StandardErrorCode.GENERIC_INSUFFICIENT_RESOURCES; +import static java.lang.Math.multiplyExact; +import static java.nio.ByteOrder.LITTLE_ENDIAN; +import static java.util.Objects.requireNonNull; + +final class FlatSet +{ + private static final int INSTANCE_SIZE = instanceSize(FlatSet.class); + + // See jdk.internal.util.ArraysSupport#SOFT_MAX_ARRAY_LENGTH for an explanation + private static final int MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8; + + // Hash table capacity must be a power of two and at least VECTOR_LENGTH + private static final int INITIAL_CAPACITY = 16; + + private static final int RECORDS_PER_GROUP_SHIFT = 10; + private static final int RECORDS_PER_GROUP = 1 << RECORDS_PER_GROUP_SHIFT; + private static final int RECORDS_PER_GROUP_MASK = RECORDS_PER_GROUP - 1; + + private static final int VECTOR_LENGTH = Long.BYTES; + private static final VarHandle LONG_HANDLE = MethodHandles.byteArrayViewVarHandle(long[].class, LITTLE_ENDIAN); + + private final Type type; + private final MethodHandle writeFlat; + private final MethodHandle hashFlat; + private final MethodHandle distinctFlatBlock; + private final MethodHandle hashBlock; + + private final int recordSize; + private final int recordValueOffset; + + private boolean hasNull; + + private int capacity; + private int mask; + + private byte[] control; + private byte[][] recordGroups; + private final VariableWidthData variableWidthData; + + private int size; + private int maxFill; + + public FlatSet( + Type type, + MethodHandle writeFlat, + MethodHandle hashFlat, + MethodHandle distinctFlatBlock, + MethodHandle hashBlock) + { + this.type = requireNonNull(type, "type is null"); + + this.writeFlat = requireNonNull(writeFlat, "writeFlat is null"); + this.hashFlat = requireNonNull(hashFlat, "hashFlat is null"); + this.distinctFlatBlock = requireNonNull(distinctFlatBlock, "distinctFlatBlock is null"); + this.hashBlock = requireNonNull(hashBlock, "hashBlock is null"); + + capacity = INITIAL_CAPACITY; + maxFill = calculateMaxFill(capacity); + mask = capacity - 1; + control = new byte[capacity + VECTOR_LENGTH]; + + boolean variableWidth = type.isFlatVariableWidth(); + variableWidthData = variableWidth ? new VariableWidthData() : null; + + recordValueOffset = (variableWidth ? POINTER_SIZE : 0); + recordSize = recordValueOffset + type.getFlatFixedSize(); + recordGroups = createRecordGroups(capacity, recordSize); + } + + private static byte[][] createRecordGroups(int capacity, int recordSize) + { + if (capacity < RECORDS_PER_GROUP) { + return new byte[][]{new byte[multiplyExact(capacity, recordSize)]}; + } + + byte[][] groups = new byte[(capacity + 1) >> RECORDS_PER_GROUP_SHIFT][]; + for (int i = 0; i < groups.length; i++) { + groups[i] = new byte[multiplyExact(RECORDS_PER_GROUP, recordSize)]; + } + return groups; + } + + public long getEstimatedSize() + { + return INSTANCE_SIZE + + sizeOf(control) + + (sizeOf(recordGroups[0]) * recordGroups.length) + + (variableWidthData == null ? 0 : variableWidthData.getRetainedSizeBytes()); + } + + public int size() + { + return size + (hasNull ? 1 : 0); + } + + public boolean containsNull() + { + return hasNull; + } + + public boolean contains(Block block, int position) + { + if (block.isNull(position)) { + return hasNull; + } + return getIndex(block, position, valueHashCode(block, position)) >= 0; + } + + public boolean contains(Block block, int position, long hash) + { + if (block.isNull(position)) { + return hasNull; + } + return getIndex(block, position, hash) >= 0; + } + + public void add(Block block, int position) + { + if (block.isNull(position)) { + hasNull = true; + return; + } + addNonNull(block, position, valueHashCode(block, position)); + } + + public void add(Block block, int position, long hash) + { + if (block.isNull(position)) { + hasNull = true; + return; + } + addNonNull(block, position, hash); + } + + private void addNonNull(Block block, int position, long hash) + { + int index = getIndex(block, position, hash); + if (index >= 0) { + return; + } + + index = -index - 1; + insert(index, block, position, hash); + size++; + if (size >= maxFill) { + rehash(); + } + } + + private int getIndex(Block block, int position, long hash) + { + byte hashPrefix = (byte) (hash & 0x7F | 0x80); + int bucket = bucket((int) (hash >> 7)); + + int step = 1; + long repeated = repeat(hashPrefix); + + while (true) { + final long controlVector = (long) LONG_HANDLE.get(control, bucket); + + int matchIndex = matchInVector(block, position, bucket, repeated, controlVector); + if (matchIndex >= 0) { + return matchIndex; + } + + int emptyIndex = findEmptyInVector(controlVector, bucket); + if (emptyIndex >= 0) { + return -emptyIndex - 1; + } + + bucket = bucket(bucket + step); + step += VECTOR_LENGTH; + } + } + + private int matchInVector(Block block, int position, int vectorStartBucket, long repeated, long controlVector) + { + long controlMatches = match(controlVector, repeated); + while (controlMatches != 0) { + int bucket = bucket(vectorStartBucket + (Long.numberOfTrailingZeros(controlMatches) >>> 3)); + if (valueNotDistinctFrom(bucket, block, position)) { + return bucket; + } + + controlMatches = controlMatches & (controlMatches - 1); + } + return -1; + } + + private int findEmptyInVector(long vector, int vectorStartBucket) + { + long controlMatches = match(vector, 0x00_00_00_00_00_00_00_00L); + if (controlMatches == 0) { + return -1; + } + int slot = Long.numberOfTrailingZeros(controlMatches) >>> 3; + return bucket(vectorStartBucket + slot); + } + + private void insert(int index, Block block, int position, long hash) + { + setControl(index, (byte) (hash & 0x7F | 0x80)); + + byte[] records = getRecords(index); + int recordOffset = getRecordOffset(index); + + // write value + byte[] variableWidthChunk = EMPTY_CHUNK; + int variableWidthChunkOffset = 0; + if (variableWidthData != null) { + int variableWidthLength = type.getFlatVariableWidthSize(block, position); + variableWidthChunk = variableWidthData.allocate(records, recordOffset, variableWidthLength); + variableWidthChunkOffset = VariableWidthData.getChunkOffset(records, recordOffset); + } + + try { + writeFlat.invokeExact(block, position, records, recordOffset + recordValueOffset, variableWidthChunk, variableWidthChunkOffset); + } + catch (Throwable throwable) { + Throwables.throwIfUnchecked(throwable); + throw new RuntimeException(throwable); + } + } + + private void setControl(int index, byte hashPrefix) + { + control[index] = hashPrefix; + if (index < VECTOR_LENGTH) { + control[index + capacity] = hashPrefix; + } + } + + private void rehash() + { + int oldCapacity = capacity; + byte[] oldControl = control; + byte[][] oldRecordGroups = recordGroups; + + long newCapacityLong = capacity * 2L; + if (newCapacityLong > MAX_ARRAY_SIZE) { + throw new TrinoException(GENERIC_INSUFFICIENT_RESOURCES, "Size of hash table cannot exceed 1 billion entries"); + } + + capacity = (int) newCapacityLong; + maxFill = calculateMaxFill(capacity); + mask = capacity - 1; + + control = new byte[capacity + VECTOR_LENGTH]; + recordGroups = createRecordGroups(capacity, recordSize); + + for (int oldIndex = 0; oldIndex < oldCapacity; oldIndex++) { + if (oldControl[oldIndex] != 0) { + byte[] oldRecords = oldRecordGroups[oldIndex >> RECORDS_PER_GROUP_SHIFT]; + int oldRecordOffset = getRecordOffset(oldIndex); + + long hash = valueHashCode(oldRecords, oldIndex); + byte hashPrefix = (byte) (hash & 0x7F | 0x80); + int bucket = bucket((int) (hash >> 7)); + + int step = 1; + while (true) { + final long controlVector = (long) LONG_HANDLE.get(control, bucket); + // values are already distinct, so just find the first empty slot + int emptyIndex = findEmptyInVector(controlVector, bucket); + if (emptyIndex >= 0) { + setControl(emptyIndex, hashPrefix); + + // copy full record including groupId and count + byte[] records = getRecords(emptyIndex); + int recordOffset = getRecordOffset(emptyIndex); + System.arraycopy(oldRecords, oldRecordOffset, records, recordOffset, recordSize); + break; + } + + bucket = bucket(bucket + step); + step += VECTOR_LENGTH; + } + } + } + } + + private int bucket(int hash) + { + return hash & mask; + } + + private byte[] getRecords(int index) + { + return recordGroups[index >> RECORDS_PER_GROUP_SHIFT]; + } + + private int getRecordOffset(int index) + { + return (index & RECORDS_PER_GROUP_MASK) * recordSize; + } + + private long valueHashCode(byte[] records, int index) + { + int recordOffset = getRecordOffset(index); + + try { + byte[] variableWidthChunk = EMPTY_CHUNK; + if (variableWidthData != null) { + variableWidthChunk = variableWidthData.getChunk(records, recordOffset); + } + + return (long) hashFlat.invokeExact( + records, + recordOffset + recordValueOffset, + variableWidthChunk); + } + catch (Throwable throwable) { + Throwables.throwIfUnchecked(throwable); + throw new RuntimeException(throwable); + } + } + + private long valueHashCode(Block right, int rightPosition) + { + try { + return (long) hashBlock.invokeExact(right, rightPosition); + } + catch (Throwable throwable) { + Throwables.throwIfUnchecked(throwable); + throw new RuntimeException(throwable); + } + } + + private boolean valueNotDistinctFrom(int leftPosition, Block right, int rightPosition) + { + byte[] leftRecords = getRecords(leftPosition); + int leftRecordOffset = getRecordOffset(leftPosition); + + byte[] leftVariableWidthChunk = EMPTY_CHUNK; + if (variableWidthData != null) { + leftVariableWidthChunk = variableWidthData.getChunk(leftRecords, leftRecordOffset); + } + + try { + return !(boolean) distinctFlatBlock.invokeExact( + leftRecords, + leftRecordOffset + recordValueOffset, + leftVariableWidthChunk, + right, + rightPosition); + } + catch (Throwable throwable) { + Throwables.throwIfUnchecked(throwable); + throw new RuntimeException(throwable); + } + } + + private static long repeat(byte value) + { + return ((value & 0xFF) * 0x01_01_01_01_01_01_01_01L); + } + + private static long match(long vector, long repeatedValue) + { + // HD 6-1 + long comparison = vector ^ repeatedValue; + return (comparison - 0x01_01_01_01_01_01_01_01L) & ~comparison & 0x80_80_80_80_80_80_80_80L; + } + + private static int calculateMaxFill(int capacity) + { + // The hash table uses a load factory of 15/16 + return (capacity / 16) * 15; + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/GroupByHash.java b/core/trino-main/src/main/java/io/trino/operator/GroupByHash.java index 02401a9a450e..6211901b165d 100644 --- a/core/trino-main/src/main/java/io/trino/operator/GroupByHash.java +++ b/core/trino-main/src/main/java/io/trino/operator/GroupByHash.java @@ -15,6 +15,7 @@ import com.google.common.annotations.VisibleForTesting; import io.trino.Session; +import io.trino.annotation.NotThreadSafe; import io.trino.spi.Page; import io.trino.spi.PageBuilder; import io.trino.spi.type.Type; @@ -27,6 +28,7 @@ import static io.trino.SystemSessionProperties.isFlatGroupByHash; import static io.trino.spi.type.BigintType.BIGINT; +@NotThreadSafe public interface GroupByHash { static GroupByHash createGroupByHash( @@ -79,13 +81,6 @@ static GroupByHash createGroupByHash( */ Work getGroupIds(Page page); - boolean contains(int position, Page page); - - default boolean contains(int position, Page page, long rawHash) - { - return contains(position, page); - } - long getRawHash(int groupId); @VisibleForTesting diff --git a/core/trino-main/src/main/java/io/trino/operator/HashSemiJoinOperator.java b/core/trino-main/src/main/java/io/trino/operator/HashSemiJoinOperator.java index f8b597edbf56..6bf7d9fe1b44 100644 --- a/core/trino-main/src/main/java/io/trino/operator/HashSemiJoinOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/HashSemiJoinOperator.java @@ -174,7 +174,7 @@ public TransformationState process(Page inputPage) if (channelSet == null) { if (!channelSetFuture.isDone()) { - // This will materialize page but it shouldn't matter for the first page + // This will materialize page, but it shouldn't matter for the first page localMemoryContext.setBytes(inputPage.getSizeInBytes()); return blocked(asVoid(channelSetFuture)); } @@ -182,20 +182,20 @@ public TransformationState process(Page inputPage) channelSet = getFutureValue(channelSetFuture); localMemoryContext.setBytes(0); } - // use an effectively-final local variable instead of the non-final instance field inside of the loop + // use an effectively-final local variable instead of the non-final instance field inside the loop ChannelSet channelSet = requireNonNull(this.channelSet, "channelSet is null"); // create the block builder for the new boolean column // we know the exact size required for the block BlockBuilder blockBuilder = BOOLEAN.createFixedSizeBlockBuilder(inputPage.getPositionCount()); - Page probeJoinPage = inputPage.getLoadedPage(probeJoinChannel); - Block probeJoinNulls = probeJoinPage.getBlock(0).mayHaveNull() ? probeJoinPage.getBlock(0) : null; - Block hashBlock = probeHashChannel >= 0 ? inputPage.getBlock(probeHashChannel) : null; + Block probeBlock = inputPage.getBlock(probeJoinChannel).copyRegion(0, inputPage.getPositionCount()); + boolean probeMayHaveNull = probeBlock.mayHaveNull(); + Block hashBlock = probeHashChannel >= 0 ? inputPage.getBlock(probeHashChannel).copyRegion(0, inputPage.getPositionCount()) : null; // update hashing strategy to use probe cursor for (int position = 0; position < inputPage.getPositionCount(); position++) { - if (probeJoinNulls != null && probeJoinNulls.isNull(position)) { + if (probeMayHaveNull && probeBlock.isNull(position)) { if (channelSet.isEmpty()) { BOOLEAN.writeBoolean(blockBuilder, false); } @@ -207,10 +207,10 @@ public TransformationState process(Page inputPage) boolean contains; if (hashBlock != null) { long rawHash = BIGINT.getLong(hashBlock, position); - contains = channelSet.contains(position, probeJoinPage, rawHash); + contains = channelSet.contains(probeBlock, position, rawHash); } else { - contains = channelSet.contains(position, probeJoinPage); + contains = channelSet.contains(probeBlock, position); } if (!contains && channelSet.containsNull()) { blockBuilder.appendNull(); diff --git a/core/trino-main/src/main/java/io/trino/operator/MultiChannelGroupByHash.java b/core/trino-main/src/main/java/io/trino/operator/MultiChannelGroupByHash.java index 5e27adc96fd0..5e1648045403 100644 --- a/core/trino-main/src/main/java/io/trino/operator/MultiChannelGroupByHash.java +++ b/core/trino-main/src/main/java/io/trino/operator/MultiChannelGroupByHash.java @@ -219,31 +219,6 @@ public Work getGroupIds(Page page) return new GetNonDictionaryGroupIdsWork(page); } - @Override - public boolean contains(int position, Page page) - { - long rawHash = hashStrategy.hashRow(position, page); - return contains(position, page, rawHash); - } - - @Override - public boolean contains(int position, Page page, long rawHash) - { - int hashPosition = getHashPosition(rawHash, mask); - - // look for a slot containing this key - while (groupIdsByHash[hashPosition] != -1) { - if (positionNotDistinctFromCurrentRow(groupIdsByHash[hashPosition], hashPosition, position, page, (byte) rawHash, channels)) { - // found an existing slot for this key - return true; - } - // increment position and mask to handle wrap around - hashPosition = (hashPosition + 1) & mask; - } - - return false; - } - @VisibleForTesting @Override public int getCapacity() diff --git a/core/trino-main/src/main/java/io/trino/operator/NoChannelGroupByHash.java b/core/trino-main/src/main/java/io/trino/operator/NoChannelGroupByHash.java index c507c4253297..37370b4a58e6 100644 --- a/core/trino-main/src/main/java/io/trino/operator/NoChannelGroupByHash.java +++ b/core/trino-main/src/main/java/io/trino/operator/NoChannelGroupByHash.java @@ -58,12 +58,6 @@ public Work getGroupIds(Page page) return new CompletedWork<>(new int[page.getPositionCount()]); } - @Override - public boolean contains(int position, Page page) - { - throw new UnsupportedOperationException("NoChannelGroupByHash does not support getHashCollisions"); - } - @Override public long getRawHash(int groupId) { diff --git a/core/trino-main/src/main/java/io/trino/operator/SetBuilderOperator.java b/core/trino-main/src/main/java/io/trino/operator/SetBuilderOperator.java index cc462b3811fa..6345a68ccd95 100644 --- a/core/trino-main/src/main/java/io/trino/operator/SetBuilderOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/SetBuilderOperator.java @@ -13,7 +13,6 @@ */ package io.trino.operator; -import com.google.common.annotations.VisibleForTesting; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import com.google.errorprone.annotations.ThreadSafe; @@ -23,7 +22,6 @@ import io.trino.spi.type.TypeOperators; import io.trino.sql.gen.JoinCompiler; import io.trino.sql.planner.plan.PlanNodeId; -import jakarta.annotation.Nullable; import java.util.Optional; @@ -124,15 +122,13 @@ public OperatorFactory duplicate() private final OperatorContext operatorContext; private final SetSupplier setSupplier; - private final int[] sourceChannels; + private final int setChannel; + private final int hashChannel; private final ChannelSetBuilder channelSetBuilder; private boolean finished; - @Nullable - private Work unfinishedWork; // The pending work for current page. - public SetBuilderOperator( OperatorContext operatorContext, SetSupplier setSupplier, @@ -145,20 +141,13 @@ public SetBuilderOperator( this.operatorContext = requireNonNull(operatorContext, "operatorContext is null"); this.setSupplier = requireNonNull(setSupplier, "setSupplier is null"); - if (hashChannel.isPresent()) { - this.sourceChannels = new int[] {setChannel, hashChannel.get()}; - } - else { - this.sourceChannels = new int[] {setChannel}; - } + this.setChannel = setChannel; + this.hashChannel = hashChannel.orElse(-1); + // Set builder has a single channel which goes in channel 0, if hash is present, add a hashBlock to channel 1 this.channelSetBuilder = new ChannelSetBuilder( setSupplier.getType(), - hashChannel.isPresent(), - expectedPositions, - requireNonNull(operatorContext, "operatorContext is null"), - requireNonNull(joinCompiler, "joinCompiler is null"), - requireNonNull(typeOperators, "typeOperators is null")); + requireNonNull(typeOperators, "typeOperators is null"), operatorContext.localUserMemoryContext()); } @Override @@ -190,9 +179,8 @@ public boolean isFinished() public boolean needsInput() { // Since SetBuilderOperator doesn't produce any output, the getOutput() - // method may never be called. We need to handle any unfinished work - // before addInput() can be called again. - return !finished && (unfinishedWork == null || processUnfinishedWork()); + // method may never be called. + return !finished; } @Override @@ -201,8 +189,7 @@ public void addInput(Page page) requireNonNull(page, "page is null"); checkState(!isFinished(), "Operator is already finished"); - unfinishedWork = channelSetBuilder.addPage(page.getColumns(sourceChannels)); - processUnfinishedWork(); + channelSetBuilder.addAll(page.getBlock(setChannel), hashChannel == -1 ? null : page.getBlock(hashChannel)); } @Override @@ -210,24 +197,4 @@ public Page getOutput() { return null; } - - private boolean processUnfinishedWork() - { - // Processes the unfinishedWork for this page by adding the data to the hash table. If this page - // can't be fully consumed (e.g. rehashing fails), the unfinishedWork will be left with non-empty value. - checkState(unfinishedWork != null, "unfinishedWork is empty"); - boolean done = unfinishedWork.process(); - if (done) { - unfinishedWork = null; - } - // We need to update the memory reservation again since the page builder memory may also be increasing. - channelSetBuilder.updateMemoryReservation(); - return done; - } - - @VisibleForTesting - public int getCapacity() - { - return channelSetBuilder.getCapacity(); - } } diff --git a/core/trino-main/src/test/java/io/trino/operator/CyclingGroupByHash.java b/core/trino-main/src/test/java/io/trino/operator/CyclingGroupByHash.java index b46e94e9d757..77ce46b3a4e7 100644 --- a/core/trino-main/src/test/java/io/trino/operator/CyclingGroupByHash.java +++ b/core/trino-main/src/test/java/io/trino/operator/CyclingGroupByHash.java @@ -71,12 +71,6 @@ public Work getGroupIds(Page page) return new CompletedWork<>(groupIds); } - @Override - public boolean contains(int position, Page page) - { - throw new UnsupportedOperationException("Not yet supported"); - } - @Override public long getRawHash(int groupId) { diff --git a/core/trino-main/src/test/java/io/trino/operator/TestGroupByHash.java b/core/trino-main/src/test/java/io/trino/operator/TestGroupByHash.java index 21af675ef11a..6e71503a2d15 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestGroupByHash.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestGroupByHash.java @@ -44,13 +44,11 @@ import static io.trino.operator.GroupByHash.createGroupByHash; import static io.trino.operator.UpdateMemory.NOOP; import static io.trino.spi.type.BigintType.BIGINT; -import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.type.TypeTestUtils.getHashBlock; import static org.assertj.core.api.Assertions.assertThat; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; -import static org.testng.Assert.assertTrue; public class TestGroupByHash { @@ -121,9 +119,7 @@ public void testAddPage(GroupByHashType groupByHashType) assertEquals(groupByHash.getGroupCount(), tries == 0 ? value + 1 : MAX_GROUP_ID); // add the page again using get group ids and make sure the group count didn't change - Work work = groupByHash.getGroupIds(page); - work.process(); - int[] groupIds = work.getResult(); + int[] groupIds = getGroupIds(groupByHash, page); assertEquals(groupByHash.getGroupCount(), tries == 0 ? value + 1 : MAX_GROUP_ID); // verify the first position @@ -179,10 +175,7 @@ public void testDictionaryInputPage(GroupByHashType groupByHashType) assertEquals(groupByHash.getGroupCount(), 2); - Work work = groupByHash.getGroupIds(page); - work.process(); - int[] groupIds = work.getResult(); - + int[] groupIds = getGroupIds(groupByHash, page); assertEquals(groupByHash.getGroupCount(), 2); assertEquals(groupIds.length, 4); assertEquals(groupIds[0], 0); @@ -196,10 +189,12 @@ public void testNullGroup(GroupByHashType groupByHashType) { GroupByHash groupByHash = groupByHashType.createGroupByHash(); - Block block = createLongsBlock((Long) null); + Block block = createLongsBlock(0L, null); Block hashBlock = getHashBlock(ImmutableList.of(BIGINT), block); Page page = new Page(block, hashBlock); - groupByHash.addPage(page).process(); + // assign null a groupId (which is one since is it the second value added) + assertThat(getGroupIds(groupByHash, page)) + .containsExactly(0, 1); // Add enough values to force a rehash block = createLongSequenceBlock(1, 132748); @@ -209,11 +204,9 @@ public void testNullGroup(GroupByHashType groupByHashType) block = createLongsBlock((Long) null); hashBlock = getHashBlock(ImmutableList.of(BIGINT), block); - assertTrue(groupByHash.contains(0, new Page(block), BIGINT.getLong(hashBlock, 0))); - - block = createLongsBlock(0); - hashBlock = getHashBlock(ImmutableList.of(BIGINT), block); - assertFalse(groupByHash.contains(0, new Page(block), BIGINT.getLong(hashBlock, 0))); + // null groupId will be 0 (as set above) + assertThat(getGroupIds(groupByHash, new Page(block, hashBlock))) + .containsExactly(1); } @Test(dataProvider = "groupByHashType") @@ -226,9 +219,7 @@ public void testGetGroupIds(GroupByHashType groupByHashType) Block hashBlock = getHashBlock(ImmutableList.of(BIGINT), block); Page page = new Page(block, hashBlock); for (int addValuesTries = 0; addValuesTries < 10; addValuesTries++) { - Work work = groupByHash.getGroupIds(page); - work.process(); - int[] groupIds = work.getResult(); + int[] groupIds = getGroupIds(groupByHash, page); assertEquals(groupByHash.getGroupCount(), tries == 0 ? value + 1 : MAX_GROUP_ID); assertEquals(groupIds.length, 1); long groupId = groupIds[0]; @@ -245,9 +236,7 @@ public void testAppendTo(GroupByHashType groupByHashType) Block hashBlock = getHashBlock(ImmutableList.of(BIGINT), valuesBlock); GroupByHash groupByHash = groupByHashType.createGroupByHash(); - Work work = groupByHash.getGroupIds(new Page(valuesBlock, hashBlock)); - work.process(); - int[] groupIds = work.getResult(); + int[] groupIds = getGroupIds(groupByHash, new Page(valuesBlock, hashBlock)); for (int i = 0; i < valuesBlock.getPositionCount(); i++) { assertEquals(groupIds[i], i); } @@ -292,62 +281,6 @@ public void testAppendToMultipleTuplesPerGroup(GroupByHashType groupByHashType) BlockAssertions.assertBlockEquals(BIGINT, outputPage.getBlock(0), createLongSequenceBlock(0, 50)); } - @Test(dataProvider = "groupByHashType") - public void testContains(GroupByHashType groupByHashType) - { - Block valuesBlock = createLongSequenceBlock(0, 10); - Block hashBlock = getHashBlock(ImmutableList.of(BIGINT), valuesBlock); - GroupByHash groupByHash = groupByHashType.createGroupByHash(); - groupByHash.getGroupIds(new Page(valuesBlock, hashBlock)).process(); - - Block testBlock = createLongsBlock(3); - Block testHashBlock = getHashBlock(ImmutableList.of(BIGINT), testBlock); - assertTrue(groupByHash.contains(0, new Page(testBlock), BIGINT.getLong(testHashBlock, 0))); - - testBlock = createLongsBlock(VARCHAR_EXPECTED_REHASH); - testHashBlock = getHashBlock(ImmutableList.of(BIGINT), testBlock); - assertFalse(groupByHash.contains(0, new Page(testBlock), BIGINT.getLong(testHashBlock, 0))); - } - - @Test - public void testContainsMultipleColumns() - { - Block valuesBlock = BlockAssertions.createDoubleSequenceBlock(0, 10); - Block stringValuesBlock = createStringSequenceBlock(0, 10); - Block hashBlock = getHashBlock(ImmutableList.of(DOUBLE, VARCHAR), valuesBlock, stringValuesBlock); - GroupByHash groupByHash = createGroupByHash(TEST_SESSION, ImmutableList.of(DOUBLE, VARCHAR), true, 100, JOIN_COMPILER, TYPE_OPERATORS, NOOP); - groupByHash.getGroupIds(new Page(valuesBlock, stringValuesBlock, hashBlock)).process(); - - Block testValuesBlock = BlockAssertions.createDoublesBlock((double) 3); - Block testStringValuesBlock = BlockAssertions.createStringsBlock("3"); - Block testHashBlock = getHashBlock(ImmutableList.of(DOUBLE, VARCHAR), testValuesBlock, testStringValuesBlock); - assertTrue(groupByHash.contains(0, new Page(testValuesBlock, testStringValuesBlock), BIGINT.getLong(testHashBlock, 0))); - assertTrue(groupByHash.contains(0, new Page(testValuesBlock, testStringValuesBlock))); - } - - @Test - public void testContainsMultipleVariableColumns() - { - Block valuesBlockStart = createLongSequenceBlock(0, 10); - Block stringValuesBlockA = createStringSequenceBlock(0, 10); - Block stringValuesBlockB = createStringSequenceBlock(10, 20); - Block stringValuesBlockC = createStringSequenceBlock(20, 30); - Block valuesBlockEnd = createLongSequenceBlock(90, 100); - Block hashBlock = getHashBlock(ImmutableList.of(BIGINT, VARCHAR, VARCHAR, VARCHAR, BIGINT), valuesBlockStart, stringValuesBlockA, stringValuesBlockB, stringValuesBlockC, valuesBlockEnd); - GroupByHash groupByHash = createGroupByHash(TEST_SESSION, ImmutableList.of(BIGINT, VARCHAR, VARCHAR, VARCHAR, BIGINT), true, 100, JOIN_COMPILER, TYPE_OPERATORS, NOOP); - Work groupIds = groupByHash.getGroupIds(new Page(valuesBlockStart, stringValuesBlockA, stringValuesBlockB, stringValuesBlockC, valuesBlockEnd, hashBlock)); - assertTrue(groupIds.process()); - - Block testValuesBlock = createLongsBlock((long) 3); - Block testStringValuesBlockA = BlockAssertions.createStringsBlock("3"); - Block testStringValuesBlockB = BlockAssertions.createStringsBlock("13"); - Block testStringValuesBlockC = BlockAssertions.createStringsBlock("23"); - Block testBlockEnd = createLongsBlock((long) 93); - Block testHashBlock = getHashBlock(ImmutableList.of(BIGINT, VARCHAR, VARCHAR, VARCHAR, BIGINT), testValuesBlock, testStringValuesBlockA, testStringValuesBlockB, testStringValuesBlockC, testBlockEnd); - assertTrue(groupByHash.contains(0, new Page(testValuesBlock, testStringValuesBlockA, testStringValuesBlockB, testStringValuesBlockC, testBlockEnd), BIGINT.getLong(testHashBlock, 0))); - assertTrue(groupByHash.contains(0, new Page(testValuesBlock, testStringValuesBlockA, testStringValuesBlockB, testStringValuesBlockC, testBlockEnd))); - } - @Test(dataProvider = "groupByHashType") public void testForceRehash(GroupByHashType groupByHashType) { @@ -360,8 +293,9 @@ public void testForceRehash(GroupByHashType groupByHashType) groupByHash.getGroupIds(new Page(valuesBlock, hashBlock)).process(); // Ensure that all groups are present in GroupByHash - for (int i = 0; i < valuesBlock.getPositionCount(); i++) { - assertTrue(groupByHash.contains(i, new Page(valuesBlock), BIGINT.getLong(hashBlock, i))); + int groupCount = groupByHash.getGroupCount(); + for (int groupId : getGroupIds(groupByHash, new Page(valuesBlock, hashBlock))) { + assertThat(groupId).isLessThan(groupCount); } } @@ -699,4 +633,12 @@ private static void assertGroupByHashWork(Page page, List types, Class // Compare by name since classes are private assertThat(work).isInstanceOf(clazz); } + + private static int[] getGroupIds(GroupByHash groupByHash, Page page) + { + Work work = groupByHash.getGroupIds(page); + work.process(); + int[] groupIds = work.getResult(); + return groupIds; + } } diff --git a/core/trino-main/src/test/java/io/trino/operator/TestHashSemiJoinOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestHashSemiJoinOperator.java index 2dcf77ded807..c43a6c412aa9 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestHashSemiJoinOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestHashSemiJoinOperator.java @@ -37,11 +37,8 @@ import static com.google.common.collect.Iterables.concat; import static io.airlift.concurrent.Threads.daemonThreadsNamed; -import static io.airlift.testing.Assertions.assertGreaterThanOrEqual; import static io.trino.RowPagesBuilder.rowPagesBuilder; import static io.trino.SessionTestUtils.TEST_SESSION; -import static io.trino.operator.GroupByHashYieldAssertion.createPagesWithDistinctHashKeys; -import static io.trino.operator.GroupByHashYieldAssertion.finishOperatorWithYieldingGroupByHash; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.VarcharType.VARCHAR; @@ -49,7 +46,6 @@ import static io.trino.testing.TestingTaskContext.createTaskContext; import static java.util.concurrent.Executors.newCachedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; -import static org.testng.Assert.assertEquals; @Test(singleThreaded = true) public class TestHashSemiJoinOperator @@ -217,36 +213,6 @@ public void testSemiJoinOnVarcharType(boolean hashEnabled) OperatorAssertion.assertOperatorEquals(joinOperatorFactory, driverContext, probeInput, expected, hashEnabled, ImmutableList.of(probeTypes.size())); } - @Test(dataProvider = "dataType") - public void testSemiJoinMemoryReservationYield(Type type) - { - // We only need the first column so we are creating the pages with hashEnabled false - List input = createPagesWithDistinctHashKeys(type, 5_000, 500); - - // create the operator - SetBuilderOperatorFactory setBuilderOperatorFactory = new SetBuilderOperatorFactory( - 1, - new PlanNodeId("test"), - type, - 0, - Optional.of(1), - 10, - new JoinCompiler(typeOperators), - typeOperators); - - // run test - GroupByHashYieldAssertion.GroupByHashYieldResult result = finishOperatorWithYieldingGroupByHash( - input, - type, - setBuilderOperatorFactory, - operator -> ((SetBuilderOperator) operator).getCapacity(), - 450_000); - - assertGreaterThanOrEqual(result.getYieldCount(), 4); - assertGreaterThanOrEqual(result.getMaxReservedBytes(), 20L << 19); - assertEquals(result.getOutput().stream().mapToInt(Page::getPositionCount).sum(), 0); - } - @Test(dataProvider = "hashEnabledValues") public void testBuildSideNulls(boolean hashEnabled) {