Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import io.trino.operator.join.JoinProbe.JoinProbeFactory;
import io.trino.operator.join.LookupJoinOperatorFactory;
import io.trino.operator.join.LookupSourceFactory;
import io.trino.operator.join.unspilled.JoinProbe;
import io.trino.operator.join.unspilled.PartitionedLookupSourceFactory;
import io.trino.spi.type.Type;
import io.trino.spiller.PartitioningSpillerFactory;
Expand Down Expand Up @@ -59,7 +60,7 @@ public OperatorFactory join(
probeOutputChannelTypes,
lookupSourceFactory.getBuildOutputTypes(),
joinType,
new JoinProbeFactory(probeOutputChannels.stream().mapToInt(i -> i).toArray(), probeJoinChannel, probeHashChannel),
new JoinProbe.JoinProbeFactory(probeOutputChannels, probeJoinChannel, probeHashChannel),
blockTypeOperators,
probeJoinChannel,
probeHashChannel);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.Arrays;
import java.util.List;

import static com.google.common.base.Preconditions.checkArgument;
import static io.airlift.slice.SizeOf.sizeOf;
import static io.airlift.units.DataSize.Unit.KILOBYTE;
import static io.trino.operator.SyntheticAddress.decodePosition;
Expand Down Expand Up @@ -179,6 +180,78 @@ public int getAddressIndex(int position, Page hashChannelsPage)
return -1;
}

@Override
public int[] getAddressIndex(int[] positions, Page hashChannelsPage, long[] rawHashes)
{
    // The precomputed rawHashes are intentionally unused: this bigint specialization
    // derives hash positions directly from the long values themselves, so it simply
    // delegates to the overload that does not take hashes.
    return getAddressIndex(positions, hashChannelsPage);
}

@Override
public int[] getAddressIndex(int[] positions, Page hashChannelsPage)
{
checkArgument(hashChannelsPage.getChannelCount() == 1, "Multiple channel page passed to BigintPagesHash");

int positionCount = positions.length;
long[] incomingValues = new long[positionCount];
int[] hashPositions = new int[positionCount];

for (int i = 0; i < positionCount; i++) {
incomingValues[i] = hashChannelsPage.getBlock(0).getLong(positions[i], 0);
hashPositions[i] = getHashPosition(incomingValues[i], mask);
}

int[] found = new int[positionCount];
int foundCount = 0;
int[] result = new int[positionCount];
Arrays.fill(result, -1);
int[] foundKeys = new int[positionCount];

// Search for positions in the hash array. This is the most CPU-consuming part as
// it relies on random memory accesses
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is also values[foundKeys[index]] at line 228 that is random load, right?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. But it only fetches positions that has been found, which is often only a fraction.

for (int i = 0; i < positionCount; i++) {
foundKeys[i] = keys[hashPositions[i]];
}
// Found positions are put into `found` array
for (int i = 0; i < positionCount; i++) {
if (foundKeys[i] != -1) {
found[foundCount++] = i;
}
}

// At this step we determine if the found keys were indeed the proper ones or it is a hash collision.
// The result array is updated for the found ones, while the collisions land into `remaining` array.
int[] remaining = found; // Rename for readability
int remainingCount = 0;

for (int i = 0; i < foundCount; i++) {
int index = found[i];
if (values[foundKeys[index]] == incomingValues[index]) {
result[index] = foundKeys[index];
}
else {
remaining[remainingCount++] = index;
}
}

// At this point for any reasoable load factor of a hash array (< .75), there is no more than
// 10 - 15% of positions left. We search for them in a sequential order and update the result array.
for (int i = 0; i < remainingCount; i++) {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you think it's a good idea to extract this part to a separate method? it would be useful to see the impact of this part in the profiler output

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will get inlined anyway. As of readability, I prefer those perf-optimised methods to be a bit longer but with some comments. I believe it is more readable that the standard clean code approach, given that the code is not easy to understand and never will be.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will get inlined anyway

Probably, but profilers can figure out the original method in most cases and attribute the time correctly e.g. in the flame graph.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WDYM? If the method is inlined it does not exist in the JFR. Or am I missing something?

int index = remaining[i];
int position = (hashPositions[index] + 1) & mask; // hashPositions[index] position has already been checked

while (keys[position] != -1) {
if (values[keys[position]] == incomingValues[index]) {
result[index] = keys[position];
break;
}
// increment position and mask to handler wrap around
position = (position + 1) & mask;
}
}

return result;
}

@Override
public void appendTo(long position, PageBuilder pageBuilder, int outputChannelOffset)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,78 @@ public int getAddressIndex(int rightPosition, Page hashChannelsPage, long rawHas
return -1;
}

@Override
public int[] getAddressIndex(int[] positions, Page hashChannelsPage)
{
    // Compute row hashes only for the selected positions, then delegate to the overload
    // that takes precomputed hashes. The `hashes` array is indexed by page position (not by
    // selection index), so it must be sized to cover the largest selected position — the
    // original sizing by positions[positions.length - 1] assumed a sorted, non-empty array
    // and would throw (or under-allocate) otherwise.
    if (positions.length == 0) {
        return new int[0];
    }

    int maxPosition = 0;
    for (int position : positions) {
        maxPosition = Math.max(maxPosition, position);
    }

    long[] hashes = new long[maxPosition + 1];
    for (int position : positions) {
        hashes[position] = pagesHashStrategy.hashRow(position, hashChannelsPage);
    }

    return getAddressIndex(positions, hashChannelsPage, hashes);
}

@Override
public int[] getAddressIndex(int[] positions, Page hashChannelsPage, long[] rawHashes)
{
    // Batched probe lookup using precomputed raw hashes. `rawHashes` is indexed by page
    // position. Returns, for every entry in `positions`, the address index of the matching
    // build-side row, or -1 if there is none.
    int positionCount = positions.length;
    int[] hashPositions = new int[positionCount];

    for (int i = 0; i < positionCount; i++) {
        hashPositions[i] = getHashPosition(rawHashes[positions[i]], mask);
    }

    int[] found = new int[positionCount];
    int foundCount = 0;
    int[] result = new int[positionCount];
    Arrays.fill(result, -1);
    int[] foundKeys = new int[positionCount];

    // Search for positions in the hash array. This is the most CPU-consuming part as
    // it relies on random memory accesses.
    for (int i = 0; i < positionCount; i++) {
        foundKeys[i] = keys[hashPositions[i]];
    }
    // Found positions are put into the `found` array
    for (int i = 0; i < positionCount; i++) {
        if (foundKeys[i] != -1) {
            found[foundCount++] = i;
        }
    }

    // At this step we determine if the found keys were indeed the proper ones or it is a hash collision.
    // The result array is updated for the found ones, while the collisions land into the `remaining` array.
    // Note: the row-equality check below can itself be expensive (it touches every hash channel),
    // so keeping the number of candidates small matters.
    int[] remaining = found; // Rename for readability; `found[0..i)` has already been consumed when slot i is written
    int remainingCount = 0;
    for (int i = 0; i < foundCount; i++) {
        int index = found[i];
        if (positionEqualsCurrentRowIgnoreNulls(foundKeys[index], (byte) rawHashes[positions[index]], positions[index], hashChannelsPage)) {
            result[index] = foundKeys[index];
        }
        else {
            remaining[remainingCount++] = index;
        }
    }

    // At this point, for any reasonable load factor of the hash array (< .75), there is no more than
    // 10 - 15% of positions left. We search for them in sequential order and update the result array.
    for (int i = 0; i < remainingCount; i++) {
        int index = remaining[i];
        int position = (hashPositions[index] + 1) & mask; // hashPositions[index] position has already been checked

        while (keys[position] != -1) {
            if (positionEqualsCurrentRowIgnoreNulls(keys[position], (byte) rawHashes[positions[index]], positions[index], hashChannelsPage)) {
                result[index] = keys[position];
                break;
            }
            // increment position and mask to handle wrap-around
            position = (position + 1) & mask;
        }
    }

    return result;
}

@Override
public void appendTo(long position, PageBuilder pageBuilder, int outputChannelOffset)
{
Expand Down
42 changes: 42 additions & 0 deletions core/trino-main/src/main/java/io/trino/operator/join/JoinHash.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

Expand Down Expand Up @@ -86,6 +87,20 @@ public long getJoinPosition(int position, Page hashChannelsPage, Page allChannel
return startJoinPosition(addressIndex, position, allChannelsPage);
}

@Override
public void getJoinPosition(int[] positions, Page hashChannelsPage, Page allChannelsPage, long[] rawHashes, long[] result)
{
    // Batched variant: resolve hash-table address indexes for all selected positions at
    // once, then translate them into join positions via the position links.
    // (Fixes misspelled local `addressIndexex`.)
    int[] addressIndexes = pagesHash.getAddressIndex(positions, hashChannelsPage, rawHashes);
    startJoinPosition(addressIndexes, positions, allChannelsPage, result);
}

@Override
public void getJoinPosition(int[] positions, Page hashChannelsPage, Page allChannelsPage, long[] result)
{
    // Batched variant without precomputed hashes; address indexes are resolved in one
    // call and then mapped to join positions. (Fixes misspelled local `addressIndexex`.)
    int[] addressIndexes = pagesHash.getAddressIndex(positions, hashChannelsPage);
    startJoinPosition(addressIndexes, positions, allChannelsPage, result);
}

private long startJoinPosition(int currentJoinPosition, int probePosition, Page allProbeChannelsPage)
{
if (currentJoinPosition == -1) {
Expand All @@ -97,6 +112,33 @@ private long startJoinPosition(int currentJoinPosition, int probePosition, Page
return positionLinks.start(currentJoinPosition, probePosition, allProbeChannelsPage);
}

// Translates address indexes into starting join positions for a batch of probe rows.
// `result` is indexed by probe position (sized for the whole page); entries for positions
// without a match are set to -1.
private long[] startJoinPosition(int[] currentJoinPositions, int[] probePositions, Page allProbeChannelsPage, long[] result)
{
    checkArgument(currentJoinPositions.length == probePositions.length,
            "currentJoinPositions and probePositions arrays must have the same size, %s != %s",
            currentJoinPositions.length,
            probePositions.length);

    // Without position links, the address index already is the join position (including -1).
    if (positionLinks == null) {
        for (int i = 0; i < currentJoinPositions.length; i++) {
            result[probePositions[i]] = currentJoinPositions[i];
        }
        return result;
    }

    for (int i = 0; i < currentJoinPositions.length; i++) {
        int joinPosition = currentJoinPositions[i];
        result[probePositions[i]] = (joinPosition == -1)
                ? -1
                : positionLinks.start(joinPosition, probePositions[i], allProbeChannelsPage);
    }
    return result;
}

@Override
public long getNextJoinPosition(long currentJoinPosition, int probePosition, Page allProbeChannelsPage)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,30 @@ public interface LookupSource

long joinPositionWithinPartition(long joinPosition);

/**
 * Batched fallback that resolves one position at a time. {@code rawHashes} and {@code result}
 * are indexed by page position (and therefore sized for the whole processed page), while
 * {@code positions} may contain any subset of selected positions from that page.
 */
default void getJoinPosition(int[] positions, Page hashChannelsPage, Page allChannelsPage, long[] rawHashes, long[] result)
{
    for (int position : positions) {
        result[position] = getJoinPosition(position, hashChannelsPage, allChannelsPage, rawHashes[position]);
    }
}

long getJoinPosition(int position, Page hashChannelsPage, Page allChannelsPage, long rawHash);

/**
 * Batched fallback that resolves one position at a time. {@code result} is indexed by page
 * position (sized for the whole processed page), while {@code positions} may contain any
 * subset of selected positions from that page.
 */
default void getJoinPosition(int[] positions, Page hashChannelsPage, Page allChannelsPage, long[] result)
{
    for (int position : positions) {
        result[position] = getJoinPosition(position, hashChannelsPage, allChannelsPage);
    }
}

long getJoinPosition(int position, Page hashChannelsPage, Page allChannelsPage);

long getNextJoinPosition(long currentJoinPosition, int probePosition, Page allProbeChannelsPage);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,24 @@ public interface PagesHash

int getAddressIndex(int rightPosition, Page hashChannelsPage, long rawHash);

// Batched fallback: resolves each selected position independently via the scalar overload.
// The returned array is parallel to `positions` (result[i] is the address index for positions[i]).
default int[] getAddressIndex(int[] positions, Page hashChannelsPage)
{
    int positionCount = positions.length;
    int[] addressIndexes = new int[positionCount];
    for (int i = 0; i < positionCount; i++) {
        addressIndexes[i] = getAddressIndex(positions[i], hashChannelsPage);
    }
    return addressIndexes;
}

// Batched fallback using precomputed hashes; `rawHashes` is indexed by page position.
// The returned array is parallel to `positions` (result[i] is the address index for positions[i]).
default int[] getAddressIndex(int[] positions, Page hashChannelsPage, long[] rawHashes)
{
    int positionCount = positions.length;
    int[] addressIndexes = new int[positionCount];
    for (int i = 0; i < positionCount; i++) {
        addressIndexes[i] = getAddressIndex(positions[i], hashChannelsPage, rawHashes[positions[i]]);
    }
    return addressIndexes;
}

void appendTo(long position, PageBuilder pageBuilder, int outputChannelOffset);

static int getHashPosition(long rawHash, long mask)
Expand Down
Loading