diff --git a/core/trino-main/src/main/java/io/trino/operator/TrinoOperatorFactories.java b/core/trino-main/src/main/java/io/trino/operator/TrinoOperatorFactories.java index c154f035c18f..8957a188b555 100644 --- a/core/trino-main/src/main/java/io/trino/operator/TrinoOperatorFactories.java +++ b/core/trino-main/src/main/java/io/trino/operator/TrinoOperatorFactories.java @@ -17,6 +17,7 @@ import io.trino.operator.join.JoinProbe.JoinProbeFactory; import io.trino.operator.join.LookupJoinOperatorFactory; import io.trino.operator.join.LookupSourceFactory; +import io.trino.operator.join.unspilled.JoinProbe; import io.trino.operator.join.unspilled.PartitionedLookupSourceFactory; import io.trino.spi.type.Type; import io.trino.spiller.PartitioningSpillerFactory; @@ -59,7 +60,7 @@ public OperatorFactory join( probeOutputChannelTypes, lookupSourceFactory.getBuildOutputTypes(), joinType, - new JoinProbeFactory(probeOutputChannels.stream().mapToInt(i -> i).toArray(), probeJoinChannel, probeHashChannel), + new JoinProbe.JoinProbeFactory(probeOutputChannels, probeJoinChannel, probeHashChannel), blockTypeOperators, probeJoinChannel, probeHashChannel); diff --git a/core/trino-main/src/main/java/io/trino/operator/join/BigintPagesHash.java b/core/trino-main/src/main/java/io/trino/operator/join/BigintPagesHash.java index 7edd8958d15b..f0babd315448 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/BigintPagesHash.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/BigintPagesHash.java @@ -26,6 +26,7 @@ import java.util.Arrays; import java.util.List; +import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.slice.SizeOf.sizeOf; import static io.airlift.units.DataSize.Unit.KILOBYTE; import static io.trino.operator.SyntheticAddress.decodePosition; @@ -179,6 +180,78 @@ public int getAddressIndex(int position, Page hashChannelsPage) return -1; } + @Override + public int[] getAddressIndex(int[] positions, Page hashChannelsPage, 
long[] rawHashes) + { + return getAddressIndex(positions, hashChannelsPage); + } + + @Override + public int[] getAddressIndex(int[] positions, Page hashChannelsPage) + { + checkArgument(hashChannelsPage.getChannelCount() == 1, "Multiple channel page passed to BigintPagesHash"); + + int positionCount = positions.length; + long[] incomingValues = new long[positionCount]; + int[] hashPositions = new int[positionCount]; + + for (int i = 0; i < positionCount; i++) { + incomingValues[i] = hashChannelsPage.getBlock(0).getLong(positions[i], 0); + hashPositions[i] = getHashPosition(incomingValues[i], mask); + } + + int[] found = new int[positionCount]; + int foundCount = 0; + int[] result = new int[positionCount]; + Arrays.fill(result, -1); + int[] foundKeys = new int[positionCount]; + + // Search for positions in the hash array. This is the most CPU-consuming part as + // it relies on random memory accesses + for (int i = 0; i < positionCount; i++) { + foundKeys[i] = keys[hashPositions[i]]; + } + // Found positions are put into `found` array + for (int i = 0; i < positionCount; i++) { + if (foundKeys[i] != -1) { + found[foundCount++] = i; + } + } + + // At this step we determine if the found keys were indeed the proper ones or it is a hash collision. + // The result array is updated for the found ones, while the collisions land into `remaining` array. + int[] remaining = found; // Rename for readability + int remainingCount = 0; + + for (int i = 0; i < foundCount; i++) { + int index = found[i]; + if (values[foundKeys[index]] == incomingValues[index]) { + result[index] = foundKeys[index]; + } + else { + remaining[remainingCount++] = index; + } + } + + // At this point for any reasonable load factor of a hash array (< .75), there is no more than + // 10 - 15% of positions left. We search for them in a sequential order and update the result array. 
+ for (int i = 0; i < remainingCount; i++) { + int index = remaining[i]; + int position = (hashPositions[index] + 1) & mask; // hashPositions[index] position has already been checked + + while (keys[position] != -1) { + if (values[keys[position]] == incomingValues[index]) { + result[index] = keys[position]; + break; + } + // increment position and mask to handle wrap around + position = (position + 1) & mask; + } + } + + return result; + } + @Override public void appendTo(long position, PageBuilder pageBuilder, int outputChannelOffset) { diff --git a/core/trino-main/src/main/java/io/trino/operator/join/DefaultPagesHash.java b/core/trino-main/src/main/java/io/trino/operator/join/DefaultPagesHash.java index f9b748bac60d..bc086026b226 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/DefaultPagesHash.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/DefaultPagesHash.java @@ -175,6 +175,82 @@ public int getAddressIndex(int rightPosition, Page hashChannelsPage, long rawHas return -1; } + @Override + public int[] getAddressIndex(int[] positions, Page hashChannelsPage) + { + // Nothing to look up (e.g. every probe row has a null join key); indexing positions[positions.length - 1] below would throw + if (positions.length == 0) { + return new int[0]; + } + long[] hashes = new long[positions[positions.length - 1] + 1]; + for (int i = 0; i < positions.length; i++) { + hashes[positions[i]] = pagesHashStrategy.hashRow(positions[i], hashChannelsPage); + } + + return getAddressIndex(positions, hashChannelsPage, hashes); + } + + @Override + public int[] getAddressIndex(int[] positions, Page hashChannelsPage, long[] rawHashes) + { + int positionCount = positions.length; + int[] hashPositions = new int[positionCount]; + + for (int i = 0; i < positionCount; i++) { + hashPositions[i] = getHashPosition(rawHashes[positions[i]], mask); + } + + int[] found = new int[positionCount]; + int foundCount = 0; + int[] result = new int[positionCount]; + Arrays.fill(result, -1); + int[] foundKeys = new int[positionCount]; + + // Search for positions in the hash array. 
This is the most CPU-consuming part as + // it relies on random memory accesses + for (int i = 0; i < positionCount; i++) { + foundKeys[i] = keys[hashPositions[i]]; + } + // Found positions are put into `found` array + for (int i = 0; i < positionCount; i++) { + if (foundKeys[i] != -1) { + found[foundCount++] = i; + } + } + + // At this step we determine if the found keys were indeed the proper ones or it is a hash collision. + // The result array is updated for the found ones, while the collisions land into `remaining` array. + int[] remaining = found; // Rename for readability + int remainingCount = 0; + for (int i = 0; i < foundCount; i++) { + int index = found[i]; + if (positionEqualsCurrentRowIgnoreNulls(foundKeys[index], (byte) rawHashes[positions[index]], positions[index], hashChannelsPage)) { + result[index] = foundKeys[index]; + } + else { + remaining[remainingCount++] = index; + } + } + + // At this point for any reasonable load factor of a hash array (< .75), there is no more than + // 10 - 15% of positions left. We search for them in a sequential order and update the result array. 
+ for (int i = 0; i < remainingCount; i++) { + int index = remaining[i]; + int position = (hashPositions[index] + 1) & mask; // hashPositions[index] position has already been checked + + while (keys[position] != -1) { + if (positionEqualsCurrentRowIgnoreNulls(keys[position], (byte) rawHashes[positions[index]], positions[index], hashChannelsPage)) { + result[index] = keys[position]; + break; + } + // increment position and mask to handle wrap around + position = (position + 1) & mask; + } + } + + return result; + } + @Override public void appendTo(long position, PageBuilder pageBuilder, int outputChannelOffset) { diff --git a/core/trino-main/src/main/java/io/trino/operator/join/JoinHash.java b/core/trino-main/src/main/java/io/trino/operator/join/JoinHash.java index 6756187a7137..81befd2d6613 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/JoinHash.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/JoinHash.java @@ -21,6 +21,7 @@ import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; @@ -86,6 +87,20 @@ public long getJoinPosition(int position, Page hashChannelsPage, Page allChannel return startJoinPosition(addressIndex, position, allChannelsPage); } + @Override + public void getJoinPosition(int[] positions, Page hashChannelsPage, Page allChannelsPage, long[] rawHashes, long[] result) + { + int[] addressIndexex = pagesHash.getAddressIndex(positions, hashChannelsPage, rawHashes); + startJoinPosition(addressIndexex, positions, allChannelsPage, result); + } + + @Override + public void getJoinPosition(int[] positions, Page hashChannelsPage, Page allChannelsPage, long[] result) + { + int[] addressIndexex = pagesHash.getAddressIndex(positions, hashChannelsPage); + startJoinPosition(addressIndexex, positions, allChannelsPage, result); + } + private long startJoinPosition(int currentJoinPosition, int probePosition, 
Page allProbeChannelsPage) { if (currentJoinPosition == -1) { @@ -97,6 +112,33 @@ private long startJoinPosition(int currentJoinPosition, int probePosition, Page return positionLinks.start(currentJoinPosition, probePosition, allProbeChannelsPage); } + private long[] startJoinPosition(int[] currentJoinPositions, int[] probePositions, Page allProbeChannelsPage, long[] result) + { + checkArgument(currentJoinPositions.length == probePositions.length, + "currentJoinPositions and probePositions arrays must have the same size, %s != %s", + currentJoinPositions.length, + probePositions.length); + int positionCount = currentJoinPositions.length; + + if (positionLinks == null) { + for (int i = 0; i < positionCount; i++) { + result[probePositions[i]] = currentJoinPositions[i]; + } + return result; + } + + for (int i = 0; i < positionCount; i++) { + if (currentJoinPositions[i] == -1) { + result[probePositions[i]] = -1; + } + else { + result[probePositions[i]] = positionLinks.start(currentJoinPositions[i], probePositions[i], allProbeChannelsPage); + } + } + + return result; + } + @Override public long getNextJoinPosition(long currentJoinPosition, int probePosition, Page allProbeChannelsPage) { diff --git a/core/trino-main/src/main/java/io/trino/operator/join/LookupSource.java b/core/trino-main/src/main/java/io/trino/operator/join/LookupSource.java index c889296e8021..5b83fdf931b2 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/LookupSource.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/LookupSource.java @@ -30,8 +30,30 @@ public interface LookupSource long joinPositionWithinPartition(long joinPosition); + /** + * The `rawHashes` and `result` arrays are global to the entire processed page (thus, the same size), + * while the `positions` array may hold any number of selected positions from this page + */ + default void getJoinPosition(int[] positions, Page hashChannelsPage, Page allChannelsPage, long[] rawHashes, long[] result) + { + for (int i 
= 0; i < positions.length; i++) { + result[positions[i]] = getJoinPosition(positions[i], hashChannelsPage, allChannelsPage, rawHashes[positions[i]]); + } + } + long getJoinPosition(int position, Page hashChannelsPage, Page allChannelsPage, long rawHash); + /** + * The `result` array is global to the entire processed page, while the `positions` array may hold + * any number of selected positions from this page + */ + default void getJoinPosition(int[] positions, Page hashChannelsPage, Page allChannelsPage, long[] result) + { + for (int i = 0; i < positions.length; i++) { + result[positions[i]] = getJoinPosition(positions[i], hashChannelsPage, allChannelsPage); + } + } + long getJoinPosition(int position, Page hashChannelsPage, Page allChannelsPage); long getNextJoinPosition(long currentJoinPosition, int probePosition, Page allProbeChannelsPage); diff --git a/core/trino-main/src/main/java/io/trino/operator/join/PagesHash.java b/core/trino-main/src/main/java/io/trino/operator/join/PagesHash.java index 60cd02f185c9..084ce12c2507 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/PagesHash.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/PagesHash.java @@ -30,6 +30,24 @@ public interface PagesHash int getAddressIndex(int rightPosition, Page hashChannelsPage, long rawHash); + default int[] getAddressIndex(int[] positions, Page hashChannelsPage) + { + int[] result = new int[positions.length]; + for (int i = 0; i < positions.length; i++) { + result[i] = getAddressIndex(positions[i], hashChannelsPage); + } + return result; + } + + default int[] getAddressIndex(int[] positions, Page hashChannelsPage, long[] rawHashes) + { + int[] result = new int[positions.length]; + for (int i = 0; i < positions.length; i++) { + result[i] = getAddressIndex(positions[i], hashChannelsPage, rawHashes[positions[i]]); + } + return result; + } + void appendTo(long position, PageBuilder pageBuilder, int outputChannelOffset); static int getHashPosition(long rawHash, 
long mask) diff --git a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/JoinProbe.java b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/JoinProbe.java new file mode 100644 index 000000000000..00d70ca74695 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/JoinProbe.java @@ -0,0 +1,176 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.operator.join.unspilled; + +import com.google.common.primitives.Ints; +import io.trino.operator.join.LookupSource; +import io.trino.spi.Page; +import io.trino.spi.block.Block; + +import javax.annotation.Nullable; + +import java.util.Arrays; +import java.util.List; +import java.util.OptionalInt; +import java.util.stream.IntStream; + +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.spi.type.BigintType.BIGINT; +import static java.util.Objects.requireNonNull; + +public class JoinProbe +{ + public static class JoinProbeFactory + { + private final int[] probeOutputChannels; + private final int[] probeJoinChannels; + private final int probeHashChannel; // only valid when >= 0 + + public JoinProbeFactory(List<Integer> probeOutputChannels, List<Integer> probeJoinChannels, OptionalInt probeHashChannel) + { + this.probeOutputChannels = Ints.toArray(requireNonNull(probeOutputChannels, "probeOutputChannels is null")); + this.probeJoinChannels = 
Ints.toArray(requireNonNull(probeJoinChannels, "probeJoinChannels is null")); + this.probeHashChannel = requireNonNull(probeHashChannel, "probeHashChannel is null").orElse(-1); + } + + public JoinProbe createJoinProbe(Page page, LookupSource lookupSource) + { + Page probePage = page.getLoadedPage(probeJoinChannels); + return new JoinProbe(probeOutputChannels, page, probePage, lookupSource, probeHashChannel >= 0 ? page.getBlock(probeHashChannel).getLoadedBlock() : null); + } + } + + private final int[] probeOutputChannels; + private final Page page; + private final long[] joinPositionCache; + private int position = -1; + + private JoinProbe(int[] probeOutputChannels, Page page, Page probePage, LookupSource lookupSource, @Nullable Block probeHashBlock) + { + this.probeOutputChannels = requireNonNull(probeOutputChannels, "probeOutputChannels is null"); + this.page = requireNonNull(page, "page is null"); + + joinPositionCache = fillCache(lookupSource, page, probeHashBlock, probePage); + } + + public int[] getOutputChannels() + { + return probeOutputChannels; + } + + public boolean advanceNextPosition() + { + verify(++position <= page.getPositionCount(), "already finished"); + return !isFinished(); + } + + public boolean isFinished() + { + return position == page.getPositionCount(); + } + + public long getCurrentJoinPosition() + { + return joinPositionCache[position]; + } + + public int getPosition() + { + return position; + } + + public Page getPage() + { + return page; + } + + private static long[] fillCache( + LookupSource lookupSource, + Page page, + Block probeHashBlock, + Page probePage) + { + int positionCount = page.getPositionCount(); + List<Block> nullableBlocks = IntStream.range(0, probePage.getChannelCount()) + .mapToObj(i -> probePage.getBlock(i)) + .filter(Block::mayHaveNull) + .collect(toImmutableList()); + + long[] joinPositionCache = new long[positionCount]; + if (!nullableBlocks.isEmpty()) { + Arrays.fill(joinPositionCache, -1); + boolean[] isNull = new 
boolean[positionCount]; + int nonNullCount = getIsNull(nullableBlocks, positionCount, isNull); + if (nonNullCount < positionCount) { + // We only store positions that are not null + int[] positions = new int[nonNullCount]; + nonNullCount = 0; + for (int i = 0; i < positionCount; i++) { + if (!isNull[i]) { + positions[nonNullCount] = i; + } + // This way less code is in the if branch and CPU should be able to optimize branch prediction better + nonNullCount += isNull[i] ? 0 : 1; + } + if (probeHashBlock != null) { + long[] hashes = new long[positionCount]; + for (int i = 0; i < positionCount; i++) { + hashes[i] = BIGINT.getLong(probeHashBlock, i); + } + lookupSource.getJoinPosition(positions, probePage, page, hashes, joinPositionCache); + } + else { + lookupSource.getJoinPosition(positions, probePage, page, joinPositionCache); + } + return joinPositionCache; + } // else fall back to non-null path + } + int[] positions = new int[positionCount]; + for (int i = 0; i < positionCount; i++) { + positions[i] = i; + } + if (probeHashBlock != null) { + long[] hashes = new long[positionCount]; + for (int i = 0; i < positionCount; i++) { + hashes[i] = BIGINT.getLong(probeHashBlock, i); + } + lookupSource.getJoinPosition(positions, probePage, page, hashes, joinPositionCache); + } + else { + lookupSource.getJoinPosition(positions, probePage, page, joinPositionCache); + } + + return joinPositionCache; + } + + private static int getIsNull(List<Block> nullableBlocks, int positionCount, boolean[] isNull) + { + for (int i = 0; i < nullableBlocks.size() - 1; i++) { + Block block = nullableBlocks.get(i); + for (int position = 0; position < positionCount; position++) { + isNull[position] |= block.isNull(position); + } + } + // Last block will also calculate `nonNullCount` + int nonNullCount = 0; + Block lastBlock = nullableBlocks.get(nullableBlocks.size() - 1); + for (int position = 0; position < positionCount; position++) { + isNull[position] |= lastBlock.isNull(position); + nonNullCount += 
isNull[position] ? 0 : 1; + } + + return nonNullCount; + } +} diff --git a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinOperator.java b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinOperator.java index ad16e077fcbd..1498d4badc66 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinOperator.java @@ -21,10 +21,10 @@ import io.trino.operator.ProcessorContext; import io.trino.operator.WorkProcessor; import io.trino.operator.WorkProcessorOperatorAdapter.AdapterWorkProcessorOperator; -import io.trino.operator.join.JoinProbe.JoinProbeFactory; import io.trino.operator.join.JoinStatisticsCounter; import io.trino.operator.join.LookupJoinOperatorFactory.JoinType; import io.trino.operator.join.LookupSource; +import io.trino.operator.join.unspilled.JoinProbe.JoinProbeFactory; import io.trino.spi.Page; import io.trino.spi.type.Type; diff --git a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinOperatorFactory.java b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinOperatorFactory.java index 1352737f1a63..6d6f39ae7f51 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinOperatorFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinOperatorFactory.java @@ -30,9 +30,9 @@ import io.trino.operator.WorkProcessorOperatorAdapter.AdapterWorkProcessorOperatorFactory; import io.trino.operator.join.JoinBridgeManager; import io.trino.operator.join.JoinOperatorFactory; -import io.trino.operator.join.JoinProbe.JoinProbeFactory; import io.trino.operator.join.LookupJoinOperatorFactory.JoinType; import io.trino.operator.join.LookupOuterOperator.LookupOuterOperatorFactory; +import io.trino.operator.join.unspilled.JoinProbe.JoinProbeFactory; import io.trino.spi.Page; import io.trino.spi.type.Type; import 
io.trino.sql.planner.plan.PlanNodeId; diff --git a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinPageBuilder.java b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinPageBuilder.java index fe0246092721..d5fb4d890002 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinPageBuilder.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/LookupJoinPageBuilder.java @@ -13,7 +13,6 @@ */ package io.trino.operator.join.unspilled; -import io.trino.operator.join.JoinProbe; import io.trino.operator.join.LookupSource; import io.trino.spi.Page; import io.trino.spi.PageBuilder; diff --git a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PageJoiner.java b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PageJoiner.java index f0826b19a2d6..18751f386d35 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PageJoiner.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PageJoiner.java @@ -18,11 +18,10 @@ import io.trino.operator.DriverYieldSignal; import io.trino.operator.ProcessorContext; import io.trino.operator.WorkProcessor; -import io.trino.operator.join.JoinProbe; -import io.trino.operator.join.JoinProbe.JoinProbeFactory; import io.trino.operator.join.JoinStatisticsCounter; import io.trino.operator.join.LookupJoinOperatorFactory.JoinType; import io.trino.operator.join.LookupSource; +import io.trino.operator.join.unspilled.JoinProbe.JoinProbeFactory; import io.trino.spi.Page; import io.trino.spi.type.Type; @@ -31,7 +30,6 @@ import java.io.Closeable; import java.util.List; -import static com.google.common.base.Verify.verify; import static com.google.common.base.Verify.verifyNotNull; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.concurrent.MoreFutures.addSuccessCallback; @@ -97,17 +95,10 @@ public WorkProcessor.TransformationState process(@Nullable 
Page probePage) { boolean finishing = probePage == null; - if (probe == null) { - if (!finishing) { - // create new probe for next probe page - probe = joinProbeFactory.createJoinProbe(probePage); - } - else { - close(); - return finished(); - } + if (probe == null && finishing) { + close(); + return finished(); } - verify(probe != null, "no probe to work with"); if (lookupSource == null) { if (!lookupSourceFuture.isDone()) { @@ -117,6 +108,9 @@ public WorkProcessor.TransformationState process(@Nullable Page probePage) lookupSource = requireNonNull(getDone(lookupSourceFuture)); statisticsCounter.updateLookupSourcePositions(lookupSource.getJoinPositionCount()); } + if (probe == null) { + probe = joinProbeFactory.createJoinProbe(probePage, lookupSource); + } processProbe(lookupSource); @@ -153,7 +147,7 @@ private void processProbe(LookupSource lookupSource) } statisticsCounter.recordProbe(joinSourcePositions); } - if (!advanceProbePosition(lookupSource)) { + if (!advanceProbePosition()) { break; } } @@ -211,14 +205,14 @@ private boolean outerJoinCurrentPosition() /** * @return whether there are more positions on probe side */ - private boolean advanceProbePosition(LookupSource lookupSource) + private boolean advanceProbePosition() { if (!probe.advanceNextPosition()) { return false; } // update join position - joinPosition = probe.getCurrentJoinPosition(lookupSource); + joinPosition = probe.getCurrentJoinPosition(); // reset row join state for next row joinSourcePositions = 0; currentProbePositionProducedRow = false; diff --git a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PartitionedLookupSource.java b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PartitionedLookupSource.java index a701de4c7049..0b11a75d81e8 100644 --- a/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PartitionedLookupSource.java +++ b/core/trino-main/src/main/java/io/trino/operator/join/unspilled/PartitionedLookupSource.java @@ -142,6 +142,58 @@ 
public long getJoinPosition(int position, Page hashChannelsPage, Page allChannel return encodePartitionedJoinPosition(partition, toIntExact(joinPosition)); } + @Override + public void getJoinPosition(int[] positions, Page hashChannelsPage, Page allChannelsPage, long[] rawHashes, long[] result) + { + int positionCount = positions.length; + int partitionCount = partitionGenerator.getPartitionCount(); + + int[] partitions = new int[positionCount]; + int[] partitionPositionsCount = new int[partitionCount]; + + // Get the partitions for every position and calculate the size of every partition + for (int i = 0; i < positionCount; i++) { + int partition = partitionGenerator.getPartition(rawHashes[positions[i]]); + partitions[i] = partition; + partitionPositionsCount[partition]++; + } + + int[][] positionsPerPartition = new int[partitionCount][]; + for (int partition = 0; partition < partitionCount; partition++) { + positionsPerPartition[partition] = new int[partitionPositionsCount[partition]]; + } + + // Split input positions into partitions + int[] positionsPerPartitionCount = new int[partitionCount]; + for (int i = 0; i < positionCount; i++) { + int partition = partitions[i]; + positionsPerPartition[partition][positionsPerPartitionCount[partition]] = positions[i]; + positionsPerPartitionCount[partition]++; + } + + // Delegate partitioned positions to designated lookup sources + for (int partition = 0; partition < partitionCount; partition++) { + lookupSources[partition].getJoinPosition(positionsPerPartition[partition], hashChannelsPage, allChannelsPage, rawHashes, result); + } + + for (int i = 0; i < positionCount; i++) { + int partition = partitions[i]; + result[positions[i]] = encodePartitionedJoinPosition(partition, (int) result[positions[i]]); + } + } + + @Override + public void getJoinPosition(int[] positions, Page hashChannelsPage, Page allChannelsPage, long[] result) + { + int positionCount = positions.length; + long[] rawHashes = new long[result.length]; + for 
(int i = 0; i < positionCount; i++) { + rawHashes[positions[i]] = partitionGenerator.getRawHash(hashChannelsPage, positions[i]); + } + + getJoinPosition(positions, hashChannelsPage, allChannelsPage, rawHashes, result); + } + @Override public long getNextJoinPosition(long currentJoinPosition, int probePosition, Page allProbeChannelsPage) { diff --git a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestLookupJoinPageBuilder.java b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestLookupJoinPageBuilder.java index 4bba9ea19923..30fb9a5a8a87 100644 --- a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestLookupJoinPageBuilder.java +++ b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestLookupJoinPageBuilder.java @@ -14,9 +14,8 @@ package io.trino.operator.join.unspilled; import com.google.common.collect.ImmutableList; -import io.trino.operator.join.JoinProbe; -import io.trino.operator.join.JoinProbe.JoinProbeFactory; import io.trino.operator.join.LookupSource; +import io.trino.operator.join.unspilled.JoinProbe.JoinProbeFactory; import io.trino.spi.Page; import io.trino.spi.PageBuilder; import io.trino.spi.block.Block; @@ -46,9 +45,9 @@ public void testPageBuilder() Block block = blockBuilder.build(); Page page = new Page(block, block); - JoinProbeFactory joinProbeFactory = new JoinProbeFactory(new int[] {0, 1}, ImmutableList.of(0, 1), OptionalInt.empty()); - JoinProbe probe = joinProbeFactory.createJoinProbe(page); + JoinProbeFactory joinProbeFactory = new JoinProbeFactory(ImmutableList.of(0, 1), ImmutableList.of(0, 1), OptionalInt.empty()); LookupSource lookupSource = new TestLookupSource(ImmutableList.of(BIGINT, BIGINT), page); + JoinProbe probe = joinProbeFactory.createJoinProbe(page, lookupSource); LookupJoinPageBuilder lookupJoinPageBuilder = new LookupJoinPageBuilder(ImmutableList.of(BIGINT, BIGINT)); int joinPosition = 0; @@ -94,12 +93,12 @@ public void testDifferentPositions() } Block block = 
blockBuilder.build(); Page page = new Page(block); - JoinProbeFactory joinProbeFactory = new JoinProbeFactory(new int[] {0}, ImmutableList.of(0), OptionalInt.empty()); + JoinProbeFactory joinProbeFactory = new JoinProbeFactory(ImmutableList.of(0), ImmutableList.of(0), OptionalInt.empty()); LookupSource lookupSource = new TestLookupSource(ImmutableList.of(BIGINT), page); LookupJoinPageBuilder lookupJoinPageBuilder = new LookupJoinPageBuilder(ImmutableList.of(BIGINT)); // empty - JoinProbe probe = joinProbeFactory.createJoinProbe(page); + JoinProbe probe = joinProbeFactory.createJoinProbe(page, lookupSource); Page output = lookupJoinPageBuilder.build(probe); assertEquals(output.getChannelCount(), 2); assertTrue(output.getBlock(0) instanceof DictionaryBlock); @@ -107,7 +106,7 @@ public void testDifferentPositions() lookupJoinPageBuilder.reset(); // the probe covers non-sequential positions - probe = joinProbeFactory.createJoinProbe(page); + probe = joinProbeFactory.createJoinProbe(page, lookupSource); for (int joinPosition = 0; probe.advanceNextPosition(); joinPosition++) { if (joinPosition % 2 == 1) { continue; @@ -125,7 +124,7 @@ public void testDifferentPositions() lookupJoinPageBuilder.reset(); // the probe covers everything - probe = joinProbeFactory.createJoinProbe(page); + probe = joinProbeFactory.createJoinProbe(page, lookupSource); for (int joinPosition = 0; probe.advanceNextPosition(); joinPosition++) { lookupJoinPageBuilder.appendRow(probe, lookupSource, joinPosition); } @@ -140,7 +139,7 @@ public void testDifferentPositions() lookupJoinPageBuilder.reset(); // the probe covers some sequential positions - probe = joinProbeFactory.createJoinProbe(page); + probe = joinProbeFactory.createJoinProbe(page, lookupSource); for (int joinPosition = 0; probe.advanceNextPosition(); joinPosition++) { if (joinPosition < 10 || joinPosition >= 50) { continue; @@ -166,7 +165,7 @@ public void testCrossJoinWithEmptyBuild() // nothing on the build side so we don't append 
anything LookupSource lookupSource = new TestLookupSource(ImmutableList.of(), page); - JoinProbe probe = (new JoinProbeFactory(new int[] {0}, ImmutableList.of(0), OptionalInt.empty())).createJoinProbe(page); + JoinProbe probe = (new JoinProbeFactory(ImmutableList.of(0), ImmutableList.of(0), OptionalInt.empty())).createJoinProbe(page, lookupSource); LookupJoinPageBuilder lookupJoinPageBuilder = new LookupJoinPageBuilder(ImmutableList.of(BIGINT)); // append the same row many times should also flush in the end @@ -222,7 +221,7 @@ public long getJoinPosition(int position, Page page, Page allChannelsPage, long @Override public long getJoinPosition(int position, Page hashChannelsPage, Page allChannelsPage) { - throw new UnsupportedOperationException(); + return -1; } @Override