diff --git a/core/trino-main/src/main/java/io/trino/operator/TrinoOperatorFactories.java b/core/trino-main/src/main/java/io/trino/operator/TrinoOperatorFactories.java index a73d63b3a5d9..aebbdf8490d9 100644 --- a/core/trino-main/src/main/java/io/trino/operator/TrinoOperatorFactories.java +++ b/core/trino-main/src/main/java/io/trino/operator/TrinoOperatorFactories.java @@ -18,6 +18,7 @@ import io.trino.operator.join.LookupJoinOperatorFactory; import io.trino.operator.join.LookupJoinOperatorFactory.JoinType; import io.trino.operator.join.LookupSourceFactory; +import io.trino.operator.join.unspilled.JoinProbe; import io.trino.spi.type.Type; import io.trino.spiller.PartitioningSpillerFactory; import io.trino.sql.planner.plan.PlanNodeId; @@ -58,6 +59,7 @@ public OperatorFactory innerJoin( operatorId, planNodeId, lookupSourceFactory, + hasFilter, probeTypes, probeJoinChannel, probeHashChannel, @@ -91,6 +93,7 @@ public OperatorFactory probeOuterJoin( operatorId, planNodeId, lookupSourceFactory, + hasFilter, probeTypes, probeJoinChannel, probeHashChannel, @@ -124,6 +127,7 @@ public OperatorFactory lookupOuterJoin( operatorId, planNodeId, lookupSourceFactory, + hasFilter, probeTypes, probeJoinChannel, probeHashChannel, @@ -156,6 +160,7 @@ public OperatorFactory fullOuterJoin( operatorId, planNodeId, lookupSourceFactory, + hasFilter, probeTypes, probeJoinChannel, probeHashChannel, @@ -180,6 +185,7 @@ private OperatorFactory createJoinOperatorFactory( int operatorId, PlanNodeId planNodeId, JoinBridgeManager lookupSourceFactoryManager, + boolean hasFilter, List probeTypes, List probeJoinChannel, OptionalInt probeHashChannel, @@ -225,7 +231,7 @@ private OperatorFactory createJoinOperatorFactory( joinType, outputSingleMatch, waitForBuild, - new JoinProbeFactory(probeOutputChannels.stream().mapToInt(i -> i).toArray(), probeJoinChannel, probeHashChannel), + new JoinProbe.JoinProbeFactory(probeOutputChannels.stream().mapToInt(i -> i).toArray(), probeJoinChannel, probeHashChannel, hasFilter), blockTypeOperators, probeJoinChannel, probeHashChannel); diff --git a/core/trino-main/src/main/java/io/trino/operator/join/BigintPagesHash.java b/core/trino-main/src/main/java/io/trino/operator/join/BigintPagesHash.java index 84034a825f01..d3db6e6da1b4 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/BigintPagesHash.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/BigintPagesHash.java @@ -26,6 +26,7 @@ import java.util.Arrays; import java.util.List; +import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.SizeOf.sizeOf; import static io.airlift.units.DataSize.Unit.KILOBYTE; import static io.trino.operator.SyntheticAddress.decodePosition; @@ -53,6 +54,7 @@ public final class BigintPagesHash private final long hashCollisions; private final double expectedHashCollisions; + private final boolean uniqueMapping; public BigintPagesHash( LongArrayList addresses, @@ -82,6 +84,7 @@ public BigintPagesHash( // We will process addresses in batches, to save memory on array of hashes. int positionsInStep = Math.min(addresses.size() + 1, (int) CACHE_SIZE.toBytes() / Integer.SIZE); long hashCollisionsLocal = 0; + boolean uniqueMapping = true; for (int step = 0; step * positionsInStep <= addresses.size(); step++) { int stepBeginPosition = step * positionsInStep; @@ -110,6 +113,7 @@ public BigintPagesHash( // found a slot for this key // link the new key position to the current key position realPosition = positionLinks.link(realPosition, currentKey); + uniqueMapping = false; // key[pos] updated outside of this loop break; @@ -123,6 +127,7 @@ public BigintPagesHash( values[pos] = value; } } + this.uniqueMapping = uniqueMapping; size = sizeOf(addresses.elements()) + pagesHashStrategy.getSizeInBytes() + sizeOf(key) + sizeOf(values); @@ -176,6 +181,74 @@ public int getAddressIndex(int position, Page hashChannelsPage) return -1; } + @Override + public int[] getAddressIndex(int[] positions, Page hashChannelsPage, long[] rawHashes) + { + return getAddressIndex(positions, hashChannelsPage); + } + + @Override + public int[] getAddressIndex(int[] positions, Page hashChannelsPage) + { + checkArgument(hashChannelsPage.getChannelCount() == 1, "Non-signle channel page passed to BigintPagesHash"); + + int positionCount = positions.length; + long[] incomingValues = new long[positionCount]; + int[] hashPositions = new int[positionCount]; + + for (int i = 0; i < positionCount; i++) { + incomingValues[i] = hashChannelsPage.getBlock(0).getLong(positions[i], 0); + hashPositions[i] = getHashPosition(incomingValues[i], mask); + } + + int[] found = new int[positionCount]; + int foundCount = 0; + int[] result = new int[positionCount]; + Arrays.fill(result, -1); + int[] foundKeys = new int[positionCount]; + + // Search for positions in the hash array. The ones that were found are put into `found` array, + // while the `foundKeys` arrays holds the keys that has been read from the hash array + for (int i = 0; i < positionCount; i++) { + if (key[hashPositions[i]] != -1) { + found[foundCount] = i; + foundKeys[foundCount++] = key[hashPositions[i]]; + } + } + + // At this step we determine if the found keys were indeed the proper ones or it is a hash collision. + // The result array is updated for the found ones, while the collisions land into `remaining` array. + int[] remaining = found; // Rename for readability + int remainingCount = 0; + for (int i = 0; i < foundCount; i++) { + int index = found[i]; + if (values[hashPositions[index]] == incomingValues[index]) { + result[index] = foundKeys[i]; + } + else { + remaining[remainingCount++] = index; + } + } + + // At this point for any reasoable load factor of a hash array (< .75), there is no more than + // 10 - 15% of positions left. We search for them in a sequential order and update the result array. + for (int i = 0; i < remainingCount; i++) { + int index = remaining[i]; + int position = (hashPositions[index] + 1) & mask; // hashPositions[index] position has already been checked + + while (key[position] != -1) { + if (values[position] == incomingValues[index]) { + result[index] = key[position]; + break; + } + // increment position and mask to handler wrap around + position = (position + 1) & mask; + } + } + + return result; + } + @Override public void appendTo(long position, PageBuilder pageBuilder, int outputChannelOffset) { @@ -186,6 +259,18 @@ public void appendTo(long position, PageBuilder pageBuilder, int outputChannelOf pagesHashStrategy.appendTo(blockIndex, blockPosition, pageBuilder, outputChannelOffset); } + @Override + public boolean isMappingUnique() + { + return uniqueMapping; + } + + @Override + public boolean usesHash() + { + return false; + } + private boolean isPositionNull(int position) { long pageAddress = addresses.getLong(position); diff --git a/core/trino-main/src/main/java/io/trino/operator/join/DefaultPagesHash.java b/core/trino-main/src/main/java/io/trino/operator/join/DefaultPagesHash.java index e6c39580f88f..31bc383ec936 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/DefaultPagesHash.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/DefaultPagesHash.java @@ -54,6 +54,7 @@ public final class DefaultPagesHash private final byte[] positionToHashes; private final long hashCollisions; private final double expectedHashCollisions; + private final boolean uniqueMapping; public DefaultPagesHash( LongArrayList addresses, @@ -77,6 +78,7 @@ public DefaultPagesHash( int positionsInStep = Math.min(addresses.size() + 1, (int) CACHE_SIZE.toBytes() / Integer.SIZE); long[] positionToFullHashes = new long[positionsInStep]; long hashCollisionsLocal = 0; + boolean uniqueMapping = true; for (int step = 0; step * positionsInStep <= addresses.size(); step++) { int stepBeginPosition = step * positionsInStep; @@ -110,6 +112,7 @@ public DefaultPagesHash( // found a slot for this key // link the new key position to the current key position realPosition = positionLinks.link(realPosition, currentKey); + uniqueMapping = false; // key[pos] updated outside of this loop break; @@ -122,6 +125,7 @@ public DefaultPagesHash( key[pos] = realPosition; } } + this.uniqueMapping = uniqueMapping; size = sizeOf(addresses.elements()) + pagesHashStrategy.getSizeInBytes() + sizeOf(key) + sizeOf(positionToHashes); @@ -174,6 +178,75 @@ public int getAddressIndex(int rightPosition, Page hashChannelsPage, long rawHas return -1; } + @Override + public int[] getAddressIndex(int[] positions, Page hashChannelsPage) + { + long[] hashes = new long[positions.length]; + for (int i = 0; i < positions.length; i++) { + hashes[i] = pagesHashStrategy.hashRow(i, hashChannelsPage); + } + + return getAddressIndex(positions, hashChannelsPage, hashes); + } + + @Override + public int[] getAddressIndex(int[] positions, Page hashChannelsPage, long[] rawHashes) + { + int positionCount = positions.length; + int[] hashPositions = new int[positionCount]; + + for (int i = 0; i < positionCount; i++) { + hashPositions[i] = getHashPosition(rawHashes[i], mask); + } + + int[] found = new int[positionCount]; + int foundCount = 0; + int[] result = new int[positionCount]; + Arrays.fill(result, -1); + int[] foundKeys = new int[positionCount]; + + // Search for positions in the hash array. The ones that were found are put into `found` array, + // while the `foundKeys` arrays holds the keys that has been read from the hash array + for (int i = 0; i < positionCount; i++) { + if (key[hashPositions[i]] != -1) { + found[foundCount] = i; + foundKeys[foundCount++] = key[hashPositions[i]]; + } + } + + // At this step we determine if the found keys were indeed the proper ones or it is a hash collision. + // The result array is updated for the found ones, while the collisions land into `remaining` array. + int[] remaining = found; // Rename for readability + int remainingCount = 0; + for (int i = 0; i < foundCount; i++) { + int index = found[i]; + if (positionEqualsCurrentRowIgnoreNulls(foundKeys[i], (byte) rawHashes[index], positions[index], hashChannelsPage)) { + result[index] = foundKeys[i]; + } + else { + remaining[remainingCount++] = index; + } + } + + // At this point for any reasoable load factor of a hash array (< .75), there is no more than + // 10 - 15% of positions left. We search for them in a sequential order and update the result array. + for (int i = 0; i < remainingCount; i++) { + int index = remaining[i]; + int position = (hashPositions[index] + 1) & mask; // hashPositions[index] position has already been checked + + while (key[position] != -1) { + if (positionEqualsCurrentRowIgnoreNulls(key[position], (byte) rawHashes[index], positions[index], hashChannelsPage)) { + result[index] = key[position]; + break; + } + // increment position and mask to handler wrap around + position = (position + 1) & mask; + } + } + + return result; + } + @Override public void appendTo(long position, PageBuilder pageBuilder, int outputChannelOffset) { @@ -184,6 +257,12 @@ public void appendTo(long position, PageBuilder pageBuilder, int outputChannelOf pagesHashStrategy.appendTo(blockIndex, blockPosition, pageBuilder, outputChannelOffset); } + @Override + public boolean isMappingUnique() + { + return uniqueMapping; + } + private boolean isPositionNull(int position) { long pageAddress = addresses.getLong(position); diff --git a/core/trino-main/src/main/java/io/trino/operator/join/JoinHash.java b/core/trino-main/src/main/java/io/trino/operator/join/JoinHash.java index dd1ef7bfa29b..27573e8e8817 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/JoinHash.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/JoinHash.java @@ -21,6 +21,7 @@ import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; @@ -86,6 +87,20 @@ public long getJoinPosition(int position, Page hashChannelsPage, Page allChannel return startJoinPosition(addressIndex, position, allChannelsPage); } + @Override + public long[] getJoinPosition(int[] positions, Page hashChannelsPage, Page allChannelsPage, long[] rawHashes) + { + int[] addressIndexex = pagesHash.getAddressIndex(positions, hashChannelsPage, rawHashes); + return startJoinPosition(addressIndexex, positions, allChannelsPage); + } + + @Override + public long[] getJoinPosition(int[] positions, Page hashChannelsPage, Page allChannelsPage) + { + int[] addressIndexex = pagesHash.getAddressIndex(positions, hashChannelsPage); + return startJoinPosition(addressIndexex, positions, allChannelsPage); + } + private long startJoinPosition(int currentJoinPosition, int probePosition, Page allProbeChannelsPage) { if (currentJoinPosition == -1) { @@ -97,6 +112,31 @@ private long startJoinPosition(int currentJoinPosition, int probePosition, Page return positionLinks.start(currentJoinPosition, probePosition, allProbeChannelsPage); } + private long[] startJoinPosition(int[] currentJoinPosition, int[] probePosition, Page allProbeChannelsPage) + { + checkArgument(currentJoinPosition.length == probePosition.length); + int positionCount = currentJoinPosition.length; + long[] result = new long[positionCount]; + + if (positionLinks == null) { + for (int i = 0; i < positionCount; i++) { + result[i] = currentJoinPosition[i]; + } + return result; + } + + for (int i = 0; i < positionCount; i++) { + if (currentJoinPosition[i] == -1) { + result[i] = -1; + } + else { + result[i] = positionLinks.start(currentJoinPosition[i], probePosition[i], allProbeChannelsPage); + } + } + + return result; + } + @Override public long getNextJoinPosition(long currentJoinPosition, int probePosition, Page allProbeChannelsPage) { @@ -118,6 +158,24 @@ public void appendTo(long position, PageBuilder pageBuilder, int outputChannelOf pagesHash.appendTo(toIntExact(position), pageBuilder, outputChannelOffset); } + @Override + public boolean isMappingUnique() + { + return pagesHash.isMappingUnique(); + } + + @Override + public boolean isJoinPositionAlwaysEligible() + { + return filterFunction == null; + } + + @Override + public boolean usesHash() + { + return pagesHash.usesHash(); + } + @Override public void close() { diff --git a/core/trino-main/src/main/java/io/trino/operator/join/JoinStatisticsCounter.java b/core/trino-main/src/main/java/io/trino/operator/join/JoinStatisticsCounter.java index 2f1421ee78fd..3cda1bd30255 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/JoinStatisticsCounter.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/JoinStatisticsCounter.java @@ -53,6 +53,11 @@ public void updateLookupSourcePositions(long lookupSourcePositionsDelta) } public void recordProbe(int numSourcePositions) + { + recordProbe(numSourcePositions, 1); + } + + public void recordProbe(int numSourcePositions, int numberOfValues) { int bucket; if (numSourcePositions <= INDIVIDUAL_BUCKETS) { @@ -67,8 +72,8 @@ else if (numSourcePositions <= 100) { else { bucket = INDIVIDUAL_BUCKETS + 3; } - logHistogramCounters[2 * bucket]++; - logHistogramCounters[2 * bucket + 1] += numSourcePositions; + logHistogramCounters[2 * bucket] += numberOfValues; + logHistogramCounters[2 * bucket + 1] += numSourcePositions * numberOfValues; } @Override diff --git a/core/trino-main/src/main/java/io/trino/operator/join/LookupSource.java b/core/trino-main/src/main/java/io/trino/operator/join/LookupSource.java index c889296e8021..34ddd1303fd1 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/LookupSource.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/LookupSource.java @@ -32,8 +32,28 @@ public interface LookupSource long getJoinPosition(int position, Page hashChannelsPage, Page allChannelsPage, long rawHash); + default long[] getJoinPosition(int[] positions, Page hashChannelsPage, Page allChannelsPage, long[] rawHashes) + { + long[] result = new long[positions.length]; + for (int i = 0; i < positions.length; i++) { + result[i] = getJoinPosition(positions[i], hashChannelsPage, allChannelsPage, rawHashes[i]); + } + + return result; + } + long getJoinPosition(int position, Page hashChannelsPage, Page allChannelsPage); + default long[] getJoinPosition(int[] positions, Page hashChannelsPage, Page allChannelsPage) + { + long[] result = new long[positions.length]; + for (int i = 0; i < positions.length; i++) { + result[i] = getJoinPosition(positions[i], hashChannelsPage, allChannelsPage); + } + + return result; + } + long getNextJoinPosition(long currentJoinPosition, int probePosition, Page allProbeChannelsPage); void appendTo(long position, PageBuilder pageBuilder, int outputChannelOffset); @@ -44,4 +64,30 @@ public interface LookupSource @Override void close(); + + /** + * @return true if there is a certainty that every position from the probe side is joined + * with at most a single position on the build side. + * This is true for queries where joins are carried out on the indexed/unique column. + */ + default boolean isMappingUnique() + { + return false; + } + + /** + * @return true if `isJoinPositionEligible` always returns true, regardless of the input arguments + */ + default boolean isJoinPositionAlwaysEligible() + { + return false; + } + + /** + * @return false if the hash argument is ignored by the lookup source in `getJoinPosition` methods + */ + default boolean usesHash() + { + return true; + } } diff --git a/core/trino-main/src/main/java/io/trino/operator/join/PagesHash.java b/core/trino-main/src/main/java/io/trino/operator/join/PagesHash.java index d7c158098b55..b16d39df73d8 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/PagesHash.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/PagesHash.java @@ -30,5 +30,39 @@ public interface PagesHash int getAddressIndex(int rightPosition, Page hashChannelsPage, long rawHash); + default int[] getAddressIndex(int[] positions, Page hashChannelsPage) + { + int[] result = new int[positions.length]; + for (int i = 0; i < positions.length; i++) { + result[i] = getAddressIndex(positions[i], hashChannelsPage); + } + return result; + } + + default int[] getAddressIndex(int[] positions, Page hashChannelsPage, long[] rawHashes) + { + int[] result = new int[positions.length]; + for (int i = 0; i < positions.length; i++) { + result[i] = getAddressIndex(positions[i], hashChannelsPage, rawHashes[i]); + } + return result; + } + void appendTo(long position, PageBuilder pageBuilder, int outputChannelOffset); + + /** + * {@link LookupSource#isMappingUnique()} + */ + default boolean isMappingUnique() + { + return false; + } + + /** + * {@link LookupSource#usesHash()} + */ + default boolean usesHash() + { + return true; + } } diff --git a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/JoinProbe.java b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/JoinProbe.java new file mode 100644 index 000000000000..dc212324b909 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/JoinProbe.java @@ -0,0 +1,228 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.join.unspilled; + +import com.google.common.primitives.Ints; +import io.trino.operator.join.LookupSource; +import io.trino.spi.Page; +import io.trino.spi.block.Block; +import io.trino.spi.block.RunLengthEncodedBlock; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.List; +import java.util.OptionalInt; + +import static com.google.common.base.Verify.verify; +import static io.trino.spi.type.BigintType.BIGINT; +import static java.util.Objects.requireNonNull; + +public class JoinProbe +{ + private static final int RLE_MASK = 0; + private static final int NON_RLE_MASK = -1; + + public static class JoinProbeFactory + { + private final int[] probeOutputChannels; + private final int[] probeJoinChannels; + private final int probeHashChannel; // only valid when >= 0 + private final boolean hasFilter; + + public JoinProbeFactory(int[] probeOutputChannels, List probeJoinChannels, OptionalInt probeHashChannel, boolean hasFilter) + { + this.probeOutputChannels = probeOutputChannels; + this.probeJoinChannels = Ints.toArray(probeJoinChannels); + this.probeHashChannel = probeHashChannel.orElse(-1); + this.hasFilter = hasFilter; + } + + public JoinProbe createJoinProbe(Page page, LookupSource lookupSource) + { + Page probePage = page.getLoadedPage(probeJoinChannels); + return new JoinProbe(probeOutputChannels, page, probePage, lookupSource, probeHashChannel >= 0 ? page.getBlock(probeHashChannel).getLoadedBlock() : null, hasFilter); + } + } + + private final int[] probeOutputChannels; + private final int positionCount; + private final Page page; + private final Page probePage; + @Nullable + private final Block probeHashBlock; + private final LookupSource lookupSource; + private final long[] joinPositionCache; + private final boolean isRle; + /** + * This value is 0xFFFFFFFF for non-rle and 0x00000000 for rle. This way if we access cache by: + * `cache[position & rleMask]` we will always get the first position for RLE and the proper position for non-rle. + * This way there is no branch (if) in the code + */ + private final int rleMask; + private int position = -1; + + private JoinProbe(int[] probeOutputChannels, Page page, Page probePage, LookupSource lookupSource, @Nullable Block probeHashBlock, boolean hasFilter) + { + this.probeOutputChannels = requireNonNull(probeOutputChannels, "probeOutputChannels is null"); + this.page = requireNonNull(page, "page is null"); + this.positionCount = page.getPositionCount(); + this.probePage = requireNonNull(probePage, "probePage is null"); + this.lookupSource = requireNonNull(lookupSource, "lookupSource is null"); + this.probeHashBlock = probeHashBlock; + + isRle = hasOnlyRleBlocks(probePage); + rleMask = isRle ? RLE_MASK : NON_RLE_MASK; + joinPositionCache = fillCache(); + } + + public int[] getOutputChannels() + { + return probeOutputChannels; + } + + public boolean advanceNextPosition() + { + verify(++position <= positionCount, "already finished"); + return !isFinished(); + } + + public boolean isFinished() + { + return position == positionCount; + } + + public long getCurrentJoinPosition() + { + return joinPositionCache[position & rleMask]; + } + + public int getPosition() + { + return position; + } + + public Page getPage() + { + return page; + } + + public boolean isRle() + { + return isRle; + } + + private long[] fillCache() + { + if (isRle) { + long[] joinPositionCache = new long[1]; + if (rowContainsNull(0)) { + joinPositionCache[0] = -1; + } + else { + joinPositionCache[0] = lookupSource.getJoinPosition(0, probePage, page); + } + + return joinPositionCache; + } + + long[] joinPositionCache = new long[positionCount]; + Arrays.fill(joinPositionCache, -1); + if (probeMayHaveNull(probePage)) { + int nonNullCount = 0; + boolean[] isNull = new boolean[positionCount]; + for (int i = 0; i < positionCount; i++) { + isNull[i] = rowContainsNull(i); + nonNullCount += isNull[i] ? 0 : 1; + } + if (nonNullCount < positionCount) { + // We only store positions that are not null + int[] positions = new int[nonNullCount]; + nonNullCount = 0; + for (int i = 0; i < positionCount; i++) { + if (!isNull[i]) { + positions[nonNullCount++] = i; + } + } + long[] packedPositionCache; + if (probeHashBlock != null) { + long[] hashes = new long[nonNullCount]; + for (int i = 0; i < nonNullCount; i++) { + hashes[i] = BIGINT.getLong(probeHashBlock, positions[i]); + } + packedPositionCache = lookupSource.getJoinPosition(positions, probePage, page, hashes); + } + else { + packedPositionCache = lookupSource.getJoinPosition(positions, probePage, page); + } + // Unpack + nonNullCount = 0; + for (int i = 0; i < positionCount; i++) { + if (!isNull[i]) { + joinPositionCache[i] = packedPositionCache[nonNullCount++]; + } + } + return joinPositionCache; + } // else fall back to non-null path + } + int[] positions = new int[positionCount]; + for (int i = 0; i < positionCount; i++) { + positions[i] = i; + } + if (probeHashBlock != null) { + long[] hashes = new long[positionCount]; + for (int i = 0; i < positionCount; i++) { + hashes[i] = BIGINT.getLong(probeHashBlock, i); + } + return lookupSource.getJoinPosition(positions, probePage, page, hashes); + } + else { + return lookupSource.getJoinPosition(positions, probePage, page); + } + } + + private boolean rowContainsNull(int position) + { + for (int i = 0; i < probePage.getChannelCount(); i++) { + if (probePage.getBlock(i).isNull(position)) { + return true; + } + } + return false; + } + + private static boolean probeMayHaveNull(Page probePage) + { + for (int i = 0; i < probePage.getChannelCount(); i++) { + if (probePage.getBlock(i).mayHaveNull()) { + return true; + } + } + return false; + } + + private static boolean hasOnlyRleBlocks(Page probePage) + { + if (probePage.getChannelCount() == 0) { + return false; + } + + for (int i = 0; i < probePage.getChannelCount(); i++) { + if (!(probePage.getBlock(i) instanceof RunLengthEncodedBlock)) { + return false; + } + } + return true; + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinOperator.java b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinOperator.java index ad16e077fcbd..1498d4badc66 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinOperator.java @@ -21,10 +21,10 @@ import io.trino.operator.ProcessorContext; import io.trino.operator.WorkProcessor; import io.trino.operator.WorkProcessorOperatorAdapter.AdapterWorkProcessorOperator; -import io.trino.operator.join.JoinProbe.JoinProbeFactory; import io.trino.operator.join.JoinStatisticsCounter; import io.trino.operator.join.LookupJoinOperatorFactory.JoinType; import io.trino.operator.join.LookupSource; +import io.trino.operator.join.unspilled.JoinProbe.JoinProbeFactory; import io.trino.spi.Page; import io.trino.spi.type.Type; diff --git a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinOperatorFactory.java b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinOperatorFactory.java index 881593987516..bc3e8406180c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinOperatorFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinOperatorFactory.java @@ -29,9 +29,9 @@ import io.trino.operator.WorkProcessorOperatorAdapter.AdapterWorkProcessorOperatorFactory; import io.trino.operator.join.JoinBridgeManager; import io.trino.operator.join.JoinOperatorFactory; -import io.trino.operator.join.JoinProbe.JoinProbeFactory; import io.trino.operator.join.LookupJoinOperatorFactory.JoinType; import io.trino.operator.join.LookupOuterOperator.LookupOuterOperatorFactory; +import io.trino.operator.join.unspilled.JoinProbe.JoinProbeFactory; import io.trino.spi.Page; import io.trino.spi.type.Type; import io.trino.sql.planner.plan.PlanNodeId; diff --git a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinPageBuilder.java b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinPageBuilder.java index 3d759a47f950..dc0fc7c1ca4b 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinPageBuilder.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinPageBuilder.java @@ -13,7 +13,6 @@ */ package io.trino.operator.join.unspilled; -import io.trino.operator.join.JoinProbe; import io.trino.operator.join.LookupSource; import io.trino.spi.Page; import io.trino.spi.PageBuilder; diff --git a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PageJoiner.java b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PageJoiner.java index f0826b19a2d6..cd956153ecc8 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PageJoiner.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PageJoiner.java @@ -18,11 +18,10 @@ import io.trino.operator.DriverYieldSignal; import io.trino.operator.ProcessorContext; import io.trino.operator.WorkProcessor; -import io.trino.operator.join.JoinProbe; -import io.trino.operator.join.JoinProbe.JoinProbeFactory; import io.trino.operator.join.JoinStatisticsCounter; import io.trino.operator.join.LookupJoinOperatorFactory.JoinType; import io.trino.operator.join.LookupSource; +import io.trino.operator.join.unspilled.JoinProbe.JoinProbeFactory; import io.trino.spi.Page; import io.trino.spi.type.Type; @@ -31,7 +30,6 @@ import java.io.Closeable; import java.util.List; -import static com.google.common.base.Verify.verify; import static com.google.common.base.Verify.verifyNotNull; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.concurrent.MoreFutures.addSuccessCallback; @@ -97,17 +95,10 @@ public WorkProcessor.TransformationState process(@Nullable Page probePage) { boolean finishing = probePage == null; - if (probe == null) { - if (!finishing) { - // create new probe for next probe page - probe = joinProbeFactory.createJoinProbe(probePage); - } - else { - close(); - return finished(); - } + if (probe == null && finishing) { + close(); + return finished(); } - verify(probe != null, "no probe to work with"); if (lookupSource == null) { if (!lookupSourceFuture.isDone()) { @@ -117,6 +108,9 @@ public WorkProcessor.TransformationState process(@Nullable Page probePage) lookupSource = requireNonNull(getDone(lookupSourceFuture)); statisticsCounter.updateLookupSourcePositions(lookupSource.getJoinPositionCount()); } + if (probe == null) { + probe = joinProbeFactory.createJoinProbe(probePage, lookupSource); + } processProbe(lookupSource); @@ -142,6 +136,62 @@ public WorkProcessor.TransformationState process(@Nullable Page probePage) } private void processProbe(LookupSource lookupSource) + { + if (lookupSource.isMappingUnique()) { + processUniqueMappingProbe(lookupSource); + } + else { + processStandardProbe(lookupSource); + } + } + + private void processUniqueMappingProbe(LookupSource lookupSource) + { + if (probe.getPosition() == -1) { + probe.advanceNextPosition(); + } + if (probe.isFinished()) { + return; + } + + int matches = 0; + int mismatches = 0; + + if (lookupSource.isJoinPositionAlwaysEligible()) { + do { + joinPosition = probe.getCurrentJoinPosition(); + boolean match = joinPosition >= 0; + matches += match ? 1 : 0; + mismatches += match ? 0 : 1; + if (match) { + pageBuilder.appendRow(probe, lookupSource, joinPosition); + } + else if (probeOnOuterSide) { + pageBuilder.appendNullForBuild(probe); + } + } + while (probe.advanceNextPosition() && !pageBuilder.isFull() && !yieldSignal.isSet()); + } + else { + do { + joinPosition = probe.getCurrentJoinPosition(); + boolean match = joinPosition >= 0 && lookupSource.isJoinPositionEligible(joinPosition, probe.getPosition(), probe.getPage()); + matches += match ? 1 : 0; + mismatches += match ? 0 : 1; + if (match) { + pageBuilder.appendRow(probe, lookupSource, joinPosition); + } + else if (probeOnOuterSide) { + pageBuilder.appendNullForBuild(probe); + } + } + while (probe.advanceNextPosition() && !pageBuilder.isFull() && !yieldSignal.isSet()); + } + statisticsCounter.recordProbe(0, mismatches); + statisticsCounter.recordProbe(1, matches); + } + + private void processStandardProbe(LookupSource lookupSource) { do { if (probe.getPosition() >= 0) { @@ -153,7 +203,7 @@ private void processProbe(LookupSource lookupSource) } statisticsCounter.recordProbe(joinSourcePositions); } - if (!advanceProbePosition(lookupSource)) { + if (!advanceProbePosition()) { break; } } @@ -211,14 +261,14 @@ private boolean outerJoinCurrentPosition() /** * @return whether there are more positions on probe side */ - private boolean advanceProbePosition(LookupSource lookupSource) + private boolean advanceProbePosition() { if (!probe.advanceNextPosition()) { return false; } // update join position - joinPosition = probe.getCurrentJoinPosition(lookupSource); + joinPosition = probe.getCurrentJoinPosition(); // reset row join state for next row joinSourcePositions = 0; currentProbePositionProducedRow = false; diff --git a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PartitionedLookupSource.java b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PartitionedLookupSource.java index 9cbf2bffbc91..3ca9f42f50a1 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PartitionedLookupSource.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PartitionedLookupSource.java @@ -46,6 +46,11 @@ public class PartitionedLookupSource implements LookupSource { + // If the estimated size of positions per partition is smaller than that number, + // the batched version will fall back to the sequential one. + // This number has been determined by TPC benchmark results. + private static final int MIN_PARTITION_SIZE_FOR_BATCHING = 8; + public static TrackingLookupSourceSupplier createPartitionedLookupSourceSupplier(List> partitions, List hashChannelTypes, boolean outer, BlockTypeOperators blockTypeOperators) { if (outer) { @@ -90,6 +95,9 @@ public OuterPositionIterator getOuterPositionIterator() private final int shiftSize; @Nullable private final OuterPositionTracker outerPositionTracker; + private final boolean uniqueMapping; + private final boolean joinPositionsAlwaysEligible; + private final boolean usesHash; private boolean closed; @@ -104,6 +112,13 @@ private PartitionedLookupSource(List lookupSources, List this.partitionMask = lookupSources.size() - 1; this.shiftSize = numberOfTrailingZeros(lookupSources.size()) + 1; this.outerPositionTracker = outerPositionTracker.orElse(null); + + uniqueMapping = lookupSources.stream() + .allMatch(lookupSource -> lookupSource.isMappingUnique()); + joinPositionsAlwaysEligible = lookupSources.stream() + .allMatch(lookupSource -> lookupSource.isJoinPositionAlwaysEligible()); + usesHash = lookupSources.stream() + .anyMatch(lookupSource -> lookupSource.usesHash()); } @Override @@ -144,6 +159,98 @@ public long getJoinPosition(int position, Page hashChannelsPage, Page allChannel return encodePartitionedJoinPosition(partition, toIntExact(joinPosition)); } + @Override + public long[] getJoinPosition(int[] positions, Page hashChannelsPage, Page allChannelsPage, long[] rawHashes) + { + int positionCount = positions.length; + int partitionCount = partitionGenerator.getPartitionCount(); + + if (positionCount / partitionCount < MIN_PARTITION_SIZE_FOR_BATCHING) { + return LookupSource.super.getJoinPosition(positions, hashChannelsPage, allChannelsPage, rawHashes); + } + + int[] partitions = new int[positionCount]; + int[] partitionPositionsCount = new int[partitionCount]; + + // Get the partitions for every position and calculate the size of every partition + for (int i = 0; i < positionCount; i++) { + int partition = partitionGenerator.getPartition(rawHashes[i]); + partitions[i] = partition; + partitionPositionsCount[partition]++; + } + + int[][] positionsPerPartition = new int[partitionCount][]; + int[] positionPerPartitionCount = new int[partitionCount]; + long[][] resultPerPartition = new long[partitionCount][]; + @Nullable + long[][] hashesPerPartition; + if (usesHash) { + hashesPerPartition = new long[partitionCount][]; + for (int partition = 0; partition < partitionCount; partition++) { + positionsPerPartition[partition] = new int[partitionPositionsCount[partition]]; + hashesPerPartition[partition] = new long[partitionPositionsCount[partition]]; + } + + // Split input positions into partitions + for (int i = 0; i < positionCount; i++) { + int partition = partitions[i]; + positionsPerPartition[partition][positionPerPartitionCount[partition]] = positions[i]; + hashesPerPartition[partition][positionPerPartitionCount[partition]] = rawHashes[i]; + positionPerPartitionCount[partition]++; + } + + // Delegate partitioned positions to designated lookup sources + for (int partition = 0; partition < partitionCount; partition++) { + resultPerPartition[partition] = lookupSources[partition].getJoinPosition(positionsPerPartition[partition], hashChannelsPage, allChannelsPage, hashesPerPartition[partition]); + } + } + else { + hashesPerPartition = null; + for (int partition = 0; partition < partitionCount; partition++) { + positionsPerPartition[partition] = new int[partitionPositionsCount[partition]]; + } + + // Split input positions into partitions + for (int i = 0; i < positionCount; i++) { + int partition = partitions[i]; + positionsPerPartition[partition][positionPerPartitionCount[partition]] = positions[i]; + positionPerPartitionCount[partition]++; + } + + // Delegate partitioned positions to designated lookup sources + for (int partition = 0; partition < partitionCount; partition++) { + resultPerPartition[partition] = lookupSources[partition].getJoinPosition(positionsPerPartition[partition], hashChannelsPage, allChannelsPage); + } + } + + // Merge results into a single array + long[] result = new long[positionCount]; + for (int partition = 0; partition < partitionCount; partition++) { + positionPerPartitionCount[partition] = 0; + } + for (int i = 0; i < positionCount; i++) { + int partition = partitions[i]; + result[i] = toIntExact(resultPerPartition[partition][positionPerPartitionCount[partition]++]); + if (result[i] != -1) { + result[i] = encodePartitionedJoinPosition(partition, (int) result[i]); + } + } + + return result; + } + + @Override + public long[] getJoinPosition(int[] positions, Page hashChannelsPage, Page allChannelsPage) + { + int positionCount = positions.length; + long[] rawHashes = new long[positionCount]; + for (int i = 0; i < positionCount; i++) { + rawHashes[i] = partitionGenerator.getRawHash(hashChannelsPage, positions[i]); + } + + return getJoinPosition(positions, hashChannelsPage, allChannelsPage, rawHashes); + } + @Override public long getNextJoinPosition(long currentJoinPosition, int probePosition, Page allProbeChannelsPage) { @@ -339,4 +446,16 @@ public void commit() } } } + + @Override + public boolean isMappingUnique() + { + return uniqueMapping; + } + + @Override + public boolean isJoinPositionAlwaysEligible() + { + return joinPositionsAlwaysEligible; + } } diff --git a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestLookupJoinPageBuilder.java b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestLookupJoinPageBuilder.java index 86a549e5000e..6b72e3a5b663 100644 --- a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestLookupJoinPageBuilder.java +++ b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestLookupJoinPageBuilder.java @@ -14,9 +14,8 @@ package io.trino.operator.join.unspilled; import com.google.common.collect.ImmutableList; -import io.trino.operator.join.JoinProbe; -import io.trino.operator.join.JoinProbe.JoinProbeFactory; import io.trino.operator.join.LookupSource; +import io.trino.operator.join.unspilled.JoinProbe.JoinProbeFactory; import io.trino.spi.Page; import io.trino.spi.PageBuilder; import io.trino.spi.block.Block; @@ -46,10 +45,10 @@ public void testPageBuilder() Block block = blockBuilder.build(); Page page = new Page(block, block); - JoinProbeFactory joinProbeFactory = new JoinProbeFactory(new int[] {0, 1}, ImmutableList.of(0, 1), OptionalInt.empty()); - JoinProbe probe = joinProbeFactory.createJoinProbe(page); + JoinProbeFactory joinProbeFactory = new JoinProbeFactory(new int[] {0, 1}, ImmutableList.of(0, 1), OptionalInt.empty(), false); LookupSource lookupSource = new TestLookupSource(ImmutableList.of(BIGINT, BIGINT), page); - io.trino.operator.join.LookupJoinPageBuilder lookupJoinPageBuilder = new io.trino.operator.join.LookupJoinPageBuilder(ImmutableList.of(BIGINT, BIGINT)); + JoinProbe probe = joinProbeFactory.createJoinProbe(page, lookupSource); + LookupJoinPageBuilder lookupJoinPageBuilder = new LookupJoinPageBuilder(ImmutableList.of(BIGINT, BIGINT)); int joinPosition = 0; while (!lookupJoinPageBuilder.isFull() && probe.advanceNextPosition()) { @@ -94,12 +93,12 @@ public void testDifferentPositions() } Block block = blockBuilder.build(); Page page = new Page(block); - JoinProbeFactory joinProbeFactory = new JoinProbeFactory(new int[] {0}, ImmutableList.of(0), OptionalInt.empty()); + JoinProbeFactory joinProbeFactory = new JoinProbeFactory(new int[] {0}, ImmutableList.of(0), OptionalInt.empty(), false); LookupSource lookupSource = new TestLookupSource(ImmutableList.of(BIGINT), page); - io.trino.operator.join.LookupJoinPageBuilder lookupJoinPageBuilder = new io.trino.operator.join.LookupJoinPageBuilder(ImmutableList.of(BIGINT)); + LookupJoinPageBuilder lookupJoinPageBuilder = new LookupJoinPageBuilder(ImmutableList.of(BIGINT)); // empty - JoinProbe probe = joinProbeFactory.createJoinProbe(page); + JoinProbe probe = joinProbeFactory.createJoinProbe(page, lookupSource); Page output = lookupJoinPageBuilder.build(probe); assertEquals(output.getChannelCount(), 2); assertTrue(output.getBlock(0) instanceof DictionaryBlock); @@ -107,7 +106,7 @@ public void testDifferentPositions() lookupJoinPageBuilder.reset(); // the probe covers non-sequential positions - probe = joinProbeFactory.createJoinProbe(page); + probe = joinProbeFactory.createJoinProbe(page, lookupSource); for (int joinPosition = 0; probe.advanceNextPosition(); joinPosition++) { if (joinPosition % 2 == 1) { continue; @@ -125,7 +124,7 @@ public void testDifferentPositions() lookupJoinPageBuilder.reset(); // the probe covers everything - probe = joinProbeFactory.createJoinProbe(page); + probe = joinProbeFactory.createJoinProbe(page, lookupSource); for (int joinPosition = 0; probe.advanceNextPosition(); joinPosition++) { lookupJoinPageBuilder.appendRow(probe, lookupSource, joinPosition); } @@ -140,7 +139,7 @@ public void testDifferentPositions() lookupJoinPageBuilder.reset(); // the probe covers some sequential positions - probe = joinProbeFactory.createJoinProbe(page); + probe = joinProbeFactory.createJoinProbe(page, lookupSource); for (int joinPosition = 0; probe.advanceNextPosition(); joinPosition++) { if (joinPosition < 10 || joinPosition >= 50) { continue; @@ -166,7 +165,7 @@ public void testCrossJoinWithEmptyBuild() // nothing on the build side so we don't append anything LookupSource lookupSource = new TestLookupSource(ImmutableList.of(), page); - JoinProbe probe = (new JoinProbeFactory(new int[] {0}, ImmutableList.of(0), OptionalInt.empty())).createJoinProbe(page); + JoinProbe probe = (new JoinProbeFactory(new int[] {0}, ImmutableList.of(0), OptionalInt.empty(), false)).createJoinProbe(page, lookupSource); LookupJoinPageBuilder lookupJoinPageBuilder = new LookupJoinPageBuilder(ImmutableList.of(BIGINT)); // append the same row many times should also flush in the end @@ -222,7 +221,7 @@ public long getJoinPosition(int position, Page page, Page allChannelsPage, long @Override public long getJoinPosition(int position, Page hashChannelsPage, Page allChannelsPage) { - throw new UnsupportedOperationException(); + return -1; } @Override