Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.trino.operator.join.LookupJoinOperatorFactory;
import io.trino.operator.join.LookupJoinOperatorFactory.JoinType;
import io.trino.operator.join.LookupSourceFactory;
import io.trino.operator.join.unspilled.JoinProbe;
import io.trino.spi.type.Type;
import io.trino.spiller.PartitioningSpillerFactory;
import io.trino.sql.planner.plan.PlanNodeId;
Expand Down Expand Up @@ -58,6 +59,7 @@ public OperatorFactory innerJoin(
operatorId,
planNodeId,
lookupSourceFactory,
hasFilter,
probeTypes,
probeJoinChannel,
probeHashChannel,
Expand Down Expand Up @@ -91,6 +93,7 @@ public OperatorFactory probeOuterJoin(
operatorId,
planNodeId,
lookupSourceFactory,
hasFilter,
probeTypes,
probeJoinChannel,
probeHashChannel,
Expand Down Expand Up @@ -124,6 +127,7 @@ public OperatorFactory lookupOuterJoin(
operatorId,
planNodeId,
lookupSourceFactory,
hasFilter,
probeTypes,
probeJoinChannel,
probeHashChannel,
Expand Down Expand Up @@ -156,6 +160,7 @@ public OperatorFactory fullOuterJoin(
operatorId,
planNodeId,
lookupSourceFactory,
hasFilter,
probeTypes,
probeJoinChannel,
probeHashChannel,
Expand All @@ -180,6 +185,7 @@ private OperatorFactory createJoinOperatorFactory(
int operatorId,
PlanNodeId planNodeId,
JoinBridgeManager<?> lookupSourceFactoryManager,
boolean hasFilter,
List<Type> probeTypes,
List<Integer> probeJoinChannel,
OptionalInt probeHashChannel,
Expand Down Expand Up @@ -225,7 +231,7 @@ private OperatorFactory createJoinOperatorFactory(
joinType,
outputSingleMatch,
waitForBuild,
new JoinProbeFactory(probeOutputChannels.stream().mapToInt(i -> i).toArray(), probeJoinChannel, probeHashChannel),
new JoinProbe.JoinProbeFactory(probeOutputChannels.stream().mapToInt(i -> i).toArray(), probeJoinChannel, probeHashChannel, hasFilter),
blockTypeOperators,
probeJoinChannel,
probeHashChannel);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.Arrays;
import java.util.List;

import static com.google.common.base.Preconditions.checkArgument;
import static io.airlift.slice.SizeOf.sizeOf;
import static io.airlift.units.DataSize.Unit.KILOBYTE;
import static io.trino.operator.SyntheticAddress.decodePosition;
Expand Down Expand Up @@ -53,6 +54,7 @@ public final class BigintPagesHash

private final long hashCollisions;
private final double expectedHashCollisions;
private final boolean uniqueMapping;

public BigintPagesHash(
LongArrayList addresses,
Expand Down Expand Up @@ -82,6 +84,7 @@ public BigintPagesHash(
// We will process addresses in batches, to save memory on array of hashes.
int positionsInStep = Math.min(addresses.size() + 1, (int) CACHE_SIZE.toBytes() / Integer.SIZE);
long hashCollisionsLocal = 0;
boolean uniqueMapping = true;

for (int step = 0; step * positionsInStep <= addresses.size(); step++) {
int stepBeginPosition = step * positionsInStep;
Expand Down Expand Up @@ -110,6 +113,7 @@ public BigintPagesHash(
// found a slot for this key
// link the new key position to the current key position
realPosition = positionLinks.link(realPosition, currentKey);
uniqueMapping = false;

// key[pos] updated outside of this loop
break;
Expand All @@ -123,6 +127,7 @@ public BigintPagesHash(
values[pos] = value;
}
}
this.uniqueMapping = uniqueMapping;

size = sizeOf(addresses.elements()) + pagesHashStrategy.getSizeInBytes() +
sizeOf(key) + sizeOf(values);
Expand Down Expand Up @@ -176,6 +181,74 @@ public int getAddressIndex(int position, Page hashChannelsPage)
return -1;
}

// Batched lookup variant that ignores the caller-supplied raw hashes: this hash probes
// on the raw long key values directly (usesHash() returns false), so rawHashes is unused.
@Override
public int[] getAddressIndex(int[] positions, Page hashChannelsPage, long[] rawHashes)
{
    return getAddressIndex(positions, hashChannelsPage);
}

/**
 * Batched lookup of hash-table address indexes for the given probe positions.
 * Returns an array parallel to {@code positions}; entries with no match are -1.
 * The probe is split into three passes (hash, verify, linear-probe the leftovers)
 * to keep each loop branch-predictable and cache-friendly.
 */
@Override
public int[] getAddressIndex(int[] positions, Page hashChannelsPage)
{
    // Fixed: error message previously read "Non-signle"
    checkArgument(hashChannelsPage.getChannelCount() == 1, "Non-single channel page passed to BigintPagesHash");

    int positionCount = positions.length;
    long[] incomingValues = new long[positionCount];
    int[] hashPositions = new int[positionCount];

    for (int i = 0; i < positionCount; i++) {
        incomingValues[i] = hashChannelsPage.getBlock(0).getLong(positions[i], 0);
        hashPositions[i] = getHashPosition(incomingValues[i], mask);
    }

    int[] found = new int[positionCount];
    int foundCount = 0;
    int[] result = new int[positionCount];
    Arrays.fill(result, -1);
    int[] foundKeys = new int[positionCount];

    // Search for positions in the hash array. The ones that were found are put into the `found` array,
    // while the `foundKeys` array holds the keys that have been read from the hash array.
    for (int i = 0; i < positionCount; i++) {
        if (key[hashPositions[i]] != -1) {
            found[foundCount] = i;
            foundKeys[foundCount++] = key[hashPositions[i]];
        }
    }

    // At this step we determine if the found keys were indeed the proper ones or it is a hash collision.
    // The result array is updated for the found ones, while the collisions land into the `remaining` array.
    int[] remaining = found; // Rename for readability
    int remainingCount = 0;
    for (int i = 0; i < foundCount; i++) {
        int index = found[i];
        if (values[hashPositions[index]] == incomingValues[index]) {
            result[index] = foundKeys[i];
        }
        else {
            remaining[remainingCount++] = index;
        }
    }

    // At this point for any reasonable load factor of a hash array (< .75), there is no more than
    // 10 - 15% of positions left. We search for them in a sequential order and update the result array.
    for (int i = 0; i < remainingCount; i++) {
        int index = remaining[i];
        int position = (hashPositions[index] + 1) & mask; // hashPositions[index] position has already been checked

        while (key[position] != -1) {
            if (values[position] == incomingValues[index]) {
                result[index] = key[position];
                break;
            }
            // increment position and mask to handle wrap around
            position = (position + 1) & mask;
        }
    }

    return result;
}

@Override
public void appendTo(long position, PageBuilder pageBuilder, int outputChannelOffset)
{
Expand All @@ -186,6 +259,18 @@ public void appendTo(long position, PageBuilder pageBuilder, int outputChannelOf
pagesHashStrategy.appendTo(blockIndex, blockPosition, pageBuilder, outputChannelOffset);
}

/**
 * True when the build side contained no duplicate join keys, i.e. no position
 * was ever linked to an already-present key during hash construction.
 */
@Override
public boolean isMappingUnique()
{
    return uniqueMapping;
}

/**
 * This implementation probes directly on the raw long key values, so callers
 * gain nothing by precomputing raw hashes for batched lookups.
 */
@Override
public boolean usesHash()
{
    return false;
}

private boolean isPositionNull(int position)
{
long pageAddress = addresses.getLong(position);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public final class DefaultPagesHash
private final byte[] positionToHashes;
private final long hashCollisions;
private final double expectedHashCollisions;
private final boolean uniqueMapping;

public DefaultPagesHash(
LongArrayList addresses,
Expand All @@ -77,6 +78,7 @@ public DefaultPagesHash(
int positionsInStep = Math.min(addresses.size() + 1, (int) CACHE_SIZE.toBytes() / Integer.SIZE);
long[] positionToFullHashes = new long[positionsInStep];
long hashCollisionsLocal = 0;
boolean uniqueMapping = true;

for (int step = 0; step * positionsInStep <= addresses.size(); step++) {
int stepBeginPosition = step * positionsInStep;
Expand Down Expand Up @@ -110,6 +112,7 @@ public DefaultPagesHash(
// found a slot for this key
// link the new key position to the current key position
realPosition = positionLinks.link(realPosition, currentKey);
uniqueMapping = false;

// key[pos] updated outside of this loop
break;
Expand All @@ -122,6 +125,7 @@ public DefaultPagesHash(
key[pos] = realPosition;
}
}
this.uniqueMapping = uniqueMapping;

size = sizeOf(addresses.elements()) + pagesHashStrategy.getSizeInBytes() +
sizeOf(key) + sizeOf(positionToHashes);
Expand Down Expand Up @@ -174,6 +178,75 @@ public int getAddressIndex(int rightPosition, Page hashChannelsPage, long rawHas
return -1;
}

/**
 * Batched address-index lookup when the caller has no precomputed hashes:
 * computes the row hashes first, then delegates to the hash-taking overload.
 */
@Override
public int[] getAddressIndex(int[] positions, Page hashChannelsPage)
{
    long[] hashes = new long[positions.length];
    for (int i = 0; i < positions.length; i++) {
        // Hash the actual probe position, not the loop index: `positions` is not
        // necessarily the identity mapping over the page's rows. The original code
        // passed `i` here, producing wrong hashes for any non-identity selection.
        hashes[i] = pagesHashStrategy.hashRow(positions[i], hashChannelsPage);
    }

    return getAddressIndex(positions, hashChannelsPage, hashes);
}

/**
 * Batched lookup of hash-table address indexes for the given probe positions using
 * caller-supplied raw hashes. Returns an array parallel to {@code positions}; entries
 * with no match are -1. Split into three passes (bucket probe, equality verification,
 * linear-probe of the collided leftovers) so each loop stays tight and predictable.
 */
@Override
public int[] getAddressIndex(int[] positions, Page hashChannelsPage, long[] rawHashes)
{
    int positionCount = positions.length;
    int[] hashPositions = new int[positionCount];

    for (int i = 0; i < positionCount; i++) {
        hashPositions[i] = getHashPosition(rawHashes[i], mask);
    }

    int[] found = new int[positionCount];
    int foundCount = 0;
    int[] result = new int[positionCount];
    Arrays.fill(result, -1);
    int[] foundKeys = new int[positionCount];

    // Search for positions in the hash array. The ones that were found are put into the `found` array,
    // while the `foundKeys` array holds the keys that have been read from the hash array.
    for (int i = 0; i < positionCount; i++) {
        if (key[hashPositions[i]] != -1) {
            found[foundCount] = i;
            foundKeys[foundCount++] = key[hashPositions[i]];
        }
    }

    // At this step we determine if the found keys were indeed the proper ones or it is a hash collision.
    // The result array is updated for the found ones, while the collisions land into the `remaining` array.
    // Note: only the low byte of the raw hash is compared first (see positionToHashes), full row
    // equality is checked inside positionEqualsCurrentRowIgnoreNulls.
    int[] remaining = found; // Rename for readability
    int remainingCount = 0;
    for (int i = 0; i < foundCount; i++) {
        int index = found[i];
        if (positionEqualsCurrentRowIgnoreNulls(foundKeys[i], (byte) rawHashes[index], positions[index], hashChannelsPage)) {
            result[index] = foundKeys[i];
        }
        else {
            remaining[remainingCount++] = index;
        }
    }

    // At this point for any reasonable load factor of a hash array (< .75), there is no more than
    // 10 - 15% of positions left. We search for them in a sequential order and update the result array.
    for (int i = 0; i < remainingCount; i++) {
        int index = remaining[i];
        int position = (hashPositions[index] + 1) & mask; // hashPositions[index] position has already been checked

        while (key[position] != -1) {
            if (positionEqualsCurrentRowIgnoreNulls(key[position], (byte) rawHashes[index], positions[index], hashChannelsPage)) {
                result[index] = key[position];
                break;
            }
            // increment position and mask to handle wrap around
            position = (position + 1) & mask;
        }
    }

    return result;
}

@Override
public void appendTo(long position, PageBuilder pageBuilder, int outputChannelOffset)
{
Expand All @@ -184,6 +257,12 @@ public void appendTo(long position, PageBuilder pageBuilder, int outputChannelOf
pagesHashStrategy.appendTo(blockIndex, blockPosition, pageBuilder, outputChannelOffset);
}

/**
 * True when the build side contained no duplicate join keys, i.e. no position
 * was ever linked to an already-present key during hash construction.
 */
@Override
public boolean isMappingUnique()
{
    return uniqueMapping;
}

private boolean isPositionNull(int position)
{
long pageAddress = addresses.getLong(position);
Expand Down
58 changes: 58 additions & 0 deletions core/trino-main/src/main/java/io/trino/operator/join/JoinHash.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

Expand Down Expand Up @@ -86,6 +87,20 @@ public long getJoinPosition(int position, Page hashChannelsPage, Page allChannel
return startJoinPosition(addressIndex, position, allChannelsPage);
}

// Batched join-position lookup with precomputed raw hashes: resolve the address
// indexes in the pages hash, then map them to starting join positions.
@Override
public long[] getJoinPosition(int[] positions, Page hashChannelsPage, Page allChannelsPage, long[] rawHashes)
{
    int[] addressIndexes = pagesHash.getAddressIndex(positions, hashChannelsPage, rawHashes);
    return startJoinPosition(addressIndexes, positions, allChannelsPage);
}

// Batched join-position lookup without precomputed hashes; the pages hash
// computes whatever hashing it needs internally.
@Override
public long[] getJoinPosition(int[] positions, Page hashChannelsPage, Page allChannelsPage)
{
    int[] addressIndexes = pagesHash.getAddressIndex(positions, hashChannelsPage);
    return startJoinPosition(addressIndexes, positions, allChannelsPage);
}

private long startJoinPosition(int currentJoinPosition, int probePosition, Page allProbeChannelsPage)
{
if (currentJoinPosition == -1) {
Expand All @@ -97,6 +112,31 @@ private long startJoinPosition(int currentJoinPosition, int probePosition, Page
return positionLinks.start(currentJoinPosition, probePosition, allProbeChannelsPage);
}

/**
 * Maps batched address indexes to starting join positions, preserving -1 for
 * positions without a match. Mirrors the scalar startJoinPosition overload.
 */
private long[] startJoinPosition(int[] currentJoinPosition, int[] probePosition, Page allProbeChannelsPage)
{
    // Added a message to the bare checkArgument, matching the file's convention elsewhere.
    checkArgument(currentJoinPosition.length == probePosition.length, "currentJoinPosition and probePosition must have the same length");
    int positionCount = currentJoinPosition.length;
    long[] result = new long[positionCount];

    if (positionLinks == null) {
        // Without position links, the address index is already the final join position
        // (widened from int to long).
        for (int i = 0; i < positionCount; i++) {
            result[i] = currentJoinPosition[i];
        }
        return result;
    }

    for (int i = 0; i < positionCount; i++) {
        if (currentJoinPosition[i] == -1) {
            // No match for this probe position
            result[i] = -1;
        }
        else {
            result[i] = positionLinks.start(currentJoinPosition[i], probePosition[i], allProbeChannelsPage);
        }
    }

    return result;
}

@Override
public long getNextJoinPosition(long currentJoinPosition, int probePosition, Page allProbeChannelsPage)
{
Expand All @@ -118,6 +158,24 @@ public void appendTo(long position, PageBuilder pageBuilder, int outputChannelOf
pagesHash.appendTo(toIntExact(position), pageBuilder, outputChannelOffset);
}

/**
 * Delegates to the underlying pages hash: true when the build side
 * contained no duplicate join keys.
 */
@Override
public boolean isMappingUnique()
{
    return pagesHash.isMappingUnique();
}

/**
 * A join position is always eligible when there is no filter function to apply;
 * otherwise eligibility must be evaluated per position.
 */
@Override
public boolean isJoinPositionAlwaysEligible()
{
    return filterFunction == null;
}

/**
 * Delegates to the underlying pages hash: whether batched lookups benefit
 * from caller-precomputed raw hashes.
 */
@Override
public boolean usesHash()
{
    return pagesHash.usesHash();
}

@Override
public void close()
{
Expand Down
Loading