Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 2 additions & 13 deletions core/trino-main/src/main/java/io/trino/operator/ChannelSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FLAT_RETURN;
import static io.trino.spi.function.InvocationConvention.simpleConvention;
import static io.trino.spi.type.BigintType.BIGINT;
import static java.util.Objects.requireNonNull;

public class ChannelSet
Expand Down Expand Up @@ -88,24 +87,14 @@ public ChannelSet build()
return new ChannelSet(set);
}

public void addAll(Block valueBlock, Block hashBlock)
public void addAll(Block valueBlock)
{
if (valueBlock.getPositionCount() == 0) {
return;
}

if (valueBlock instanceof RunLengthEncodedBlock rleBlock) {
if (hashBlock != null) {
set.add(rleBlock.getValue(), 0, BIGINT.getLong(hashBlock, 0));
}
else {
set.add(rleBlock.getValue(), 0);
}
}
else if (hashBlock != null) {
for (int position = 0; position < valueBlock.getPositionCount(); position++) {
set.add(valueBlock, position, BIGINT.getLong(hashBlock, position));
}
set.add(rleBlock.getValue(), 0);
}
else {
for (int position = 0; position < valueBlock.getPositionCount(); position++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import jakarta.annotation.Nullable;

import java.util.List;
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
Expand All @@ -40,7 +39,6 @@
import static io.trino.operator.WorkProcessor.TransformationState.finished;
import static io.trino.operator.WorkProcessor.TransformationState.ofResult;
import static io.trino.operator.WorkProcessorOperatorAdapter.createAdapterOperatorFactory;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static java.util.Objects.requireNonNull;

Expand All @@ -52,10 +50,9 @@ public static OperatorFactory createOperatorFactory(
PlanNodeId planNodeId,
SetSupplier setSupplier,
List<? extends Type> probeTypes,
int probeJoinChannel,
Optional<Integer> probeJoinHashChannel)
int probeJoinChannel)
{
return createAdapterOperatorFactory(new Factory(operatorId, planNodeId, setSupplier, probeTypes, probeJoinChannel, probeJoinHashChannel));
return createAdapterOperatorFactory(new Factory(operatorId, planNodeId, setSupplier, probeTypes, probeJoinChannel));
}

private static class Factory
Expand All @@ -66,25 +63,23 @@ private static class Factory
private final SetSupplier setSupplier;
private final List<Type> probeTypes;
private final int probeJoinChannel;
private final Optional<Integer> probeJoinHashChannel;
private boolean closed;

private Factory(int operatorId, PlanNodeId planNodeId, SetSupplier setSupplier, List<? extends Type> probeTypes, int probeJoinChannel, Optional<Integer> probeJoinHashChannel)
private Factory(int operatorId, PlanNodeId planNodeId, SetSupplier setSupplier, List<? extends Type> probeTypes, int probeJoinChannel)
{
this.operatorId = operatorId;
this.planNodeId = requireNonNull(planNodeId, "planNodeId is null");
this.setSupplier = setSupplier;
this.probeTypes = ImmutableList.copyOf(probeTypes);
checkArgument(probeJoinChannel >= 0, "probeJoinChannel is negative");
this.probeJoinChannel = probeJoinChannel;
this.probeJoinHashChannel = probeJoinHashChannel;
}

@Override
public WorkProcessorOperator create(ProcessorContext processorContext, WorkProcessor<Page> sourcePages)
{
checkState(!closed, "Factory is already closed");
return new HashSemiJoinOperator(sourcePages, setSupplier, probeJoinChannel, probeJoinHashChannel, processorContext.getMemoryTrackingContext());
return new HashSemiJoinOperator(sourcePages, setSupplier, probeJoinChannel, processorContext.getMemoryTrackingContext());
}

@Override
Expand Down Expand Up @@ -114,7 +109,7 @@ public void close()
@Override
public Factory duplicate()
{
return new Factory(operatorId, planNodeId, setSupplier, probeTypes, probeJoinChannel, probeJoinHashChannel);
return new Factory(operatorId, planNodeId, setSupplier, probeTypes, probeJoinChannel);
}
}

Expand All @@ -124,14 +119,12 @@ private HashSemiJoinOperator(
WorkProcessor<Page> sourcePages,
SetSupplier channelSetFuture,
int probeJoinChannel,
Optional<Integer> probeHashChannel,
MemoryTrackingContext memoryTrackingContext)
{
pages = sourcePages
.transform(new SemiJoinPages(
channelSetFuture,
probeJoinChannel,
probeHashChannel,
memoryTrackingContext.aggregateUserMemoryContext()));
}

Expand All @@ -144,23 +137,19 @@ public WorkProcessor<Page> getOutputPages()
private static class SemiJoinPages
implements WorkProcessor.Transformation<Page, Page>
{
private static final int NO_PRECOMPUTED_HASH_CHANNEL = -1;

private final int probeJoinChannel;
private final int probeHashChannel; // when >= 0, this is the precomputed hash channel
private final ListenableFuture<ChannelSet> channelSetFuture;
private final LocalMemoryContext localMemoryContext;

@Nullable
private ChannelSet channelSet;

public SemiJoinPages(SetSupplier channelSetFuture, int probeJoinChannel, Optional<Integer> probeHashChannel, AggregatedMemoryContext aggregatedMemoryContext)
public SemiJoinPages(SetSupplier channelSetFuture, int probeJoinChannel, AggregatedMemoryContext aggregatedMemoryContext)
{
checkArgument(probeJoinChannel >= 0, "probeJoinChannel is negative");

this.channelSetFuture = channelSetFuture.getChannelSet();
this.probeJoinChannel = probeJoinChannel;
this.probeHashChannel = probeHashChannel.orElse(NO_PRECOMPUTED_HASH_CHANNEL);
this.localMemoryContext = aggregatedMemoryContext.newLocalMemoryContext(SemiJoinPages.class.getSimpleName());
}

Expand Down Expand Up @@ -190,7 +179,6 @@ public TransformationState<Page> process(Page inputPage)

Block probeBlock = inputPage.getBlock(probeJoinChannel).copyRegion(0, inputPage.getPositionCount());
boolean probeMayHaveNull = probeBlock.mayHaveNull();
Block hashBlock = probeHashChannel >= 0 ? inputPage.getBlock(probeHashChannel).copyRegion(0, inputPage.getPositionCount()) : null;

// update hashing strategy to use probe cursor
for (int position = 0; position < inputPage.getPositionCount(); position++) {
Expand All @@ -203,14 +191,7 @@ public TransformationState<Page> process(Page inputPage)
}
}
else {
boolean contains;
if (hashBlock != null) {
long rawHash = BIGINT.getLong(hashBlock, position);
contains = channelSet.contains(probeBlock, position, rawHash);
}
else {
contains = channelSet.contains(probeBlock, position);
}
boolean contains = channelSet.contains(probeBlock, position);
if (!contains && channelSet.containsNull()) {
blockBuilder.appendNull();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ public static OperatorFactory join(
boolean hasFilter,
List<Type> probeTypes,
List<Integer> probeJoinChannel,
OptionalInt probeHashChannel,
Optional<List<Integer>> probeOutputChannelsOptional)
{
List<Integer> probeOutputChannels = probeOutputChannelsOptional.orElseGet(() -> rangeList(probeTypes.size()));
Expand All @@ -60,7 +59,7 @@ public static OperatorFactory join(
probeOutputChannelTypes,
lookupSourceFactory.getBuildOutputTypes(),
joinType,
new JoinProbe.JoinProbeFactory(probeOutputChannels, probeJoinChannel, probeHashChannel, hasFilter)));
new JoinProbe.JoinProbeFactory(probeOutputChannels, probeJoinChannel, hasFilter)));
}

public static OperatorFactory spillingJoin(
Expand All @@ -70,7 +69,6 @@ public static OperatorFactory spillingJoin(
JoinBridgeManager<? extends LookupSourceFactory> lookupSourceFactory,
List<Type> probeTypes,
List<Integer> probeJoinChannel,
OptionalInt probeHashChannel,
Optional<List<Integer>> probeOutputChannelsOptional,
OptionalInt totalOperatorsCount,
PartitioningSpillerFactory partitioningSpillerFactory,
Expand All @@ -89,11 +87,10 @@ public static OperatorFactory spillingJoin(
probeOutputChannelTypes,
lookupSourceFactory.getBuildOutputTypes(),
joinType,
new JoinProbeFactory(probeOutputChannels.stream().mapToInt(i -> i).toArray(), probeJoinChannel, probeHashChannel),
new JoinProbeFactory(probeOutputChannels.stream().mapToInt(i -> i).toArray(), probeJoinChannel),
typeOperators,
totalOperatorsCount,
probeJoinChannel,
probeHashChannel,
partitioningSpillerFactory));
}

Expand Down
17 changes: 6 additions & 11 deletions core/trino-main/src/main/java/io/trino/operator/PagesIndex.java
Original file line number Diff line number Diff line change
Expand Up @@ -456,19 +456,19 @@ private PagesIndexOrdering createPagesIndexComparator(List<Integer> sortChannels

public Supplier<LookupSource> createLookupSourceSupplier(Session session, List<Integer> joinChannels)
{
return createLookupSourceSupplier(session, joinChannels, OptionalInt.empty(), Optional.empty(), Optional.empty(), ImmutableList.of());
return createLookupSourceSupplier(session, joinChannels, Optional.empty(), Optional.empty(), ImmutableList.of());
}

public PagesHashStrategy createPagesHashStrategy(List<Integer> joinChannels, OptionalInt hashChannel)
public PagesHashStrategy createPagesHashStrategy(List<Integer> joinChannels)
{
return createPagesHashStrategy(joinChannels, hashChannel, Optional.empty());
return createPagesHashStrategy(joinChannels, Optional.empty());
}

private PagesHashStrategy createPagesHashStrategy(List<Integer> joinChannels, OptionalInt hashChannel, Optional<List<Integer>> outputChannels)
private PagesHashStrategy createPagesHashStrategy(List<Integer> joinChannels, Optional<List<Integer>> outputChannels)
{
try {
return joinCompiler.compilePagesHashStrategyFactory(types, joinChannels, outputChannels)
.createPagesHashStrategy(ImmutableList.copyOf(channels), hashChannel);
.createPagesHashStrategy(ImmutableList.copyOf(channels));
}
catch (Exception e) {
log.error(e, "Lookup source compile failed for types=%s error=%s", types, e);
Expand All @@ -480,7 +480,6 @@ private PagesHashStrategy createPagesHashStrategy(List<Integer> joinChannels, Op
outputChannels.orElseGet(() -> rangeList(types.size())),
ImmutableList.copyOf(channels),
joinChannels,
hashChannel,
Optional.empty(),
blockTypeOperators);
}
Expand All @@ -494,12 +493,11 @@ public PagesIndexComparator createChannelComparator(int leftChannel, int rightCh
public LookupSourceSupplier createLookupSourceSupplier(
Session session,
List<Integer> joinChannels,
OptionalInt hashChannel,
Optional<JoinFilterFunctionFactory> filterFunctionFactory,
Optional<Integer> sortChannel,
List<JoinFilterFunctionFactory> searchFunctionFactories)
{
return createLookupSourceSupplier(session, joinChannels, hashChannel, filterFunctionFactory, sortChannel, searchFunctionFactories, Optional.empty(), defaultHashArraySizeSupplier());
return createLookupSourceSupplier(session, joinChannels, filterFunctionFactory, sortChannel, searchFunctionFactories, Optional.empty(), defaultHashArraySizeSupplier());
}

public PagesSpatialIndexSupplier createPagesSpatialIndex(
Expand All @@ -521,7 +519,6 @@ public PagesSpatialIndexSupplier createPagesSpatialIndex(
public LookupSourceSupplier createLookupSourceSupplier(
Session session,
List<Integer> joinChannels,
OptionalInt hashChannel,
Optional<JoinFilterFunctionFactory> filterFunctionFactory,
Optional<Integer> sortChannel,
List<JoinFilterFunctionFactory> searchFunctionFactories,
Expand All @@ -539,7 +536,6 @@ public LookupSourceSupplier createLookupSourceSupplier(
session,
valueAddresses,
channels,
hashChannel,
filterFunctionFactory,
sortChannel,
searchFunctionFactories,
Expand All @@ -551,7 +547,6 @@ public LookupSourceSupplier createLookupSourceSupplier(
outputChannels.orElseGet(() -> rangeList(types.size())),
channels,
joinChannels,
hashChannel,
sortChannel,
blockTypeOperators);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
import io.trino.sql.gen.JoinCompiler;
import io.trino.sql.planner.plan.PlanNodeId;

import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;
Expand Down Expand Up @@ -65,7 +63,6 @@ public static class SetBuilderOperatorFactory
{
private final int operatorId;
private final PlanNodeId planNodeId;
private final Optional<Integer> hashChannel;
private final SetSupplier setProvider;
private final int setChannel;
private final int expectedPositions;
Expand All @@ -78,7 +75,6 @@ public SetBuilderOperatorFactory(
PlanNodeId planNodeId,
Type type,
int setChannel,
Optional<Integer> hashChannel,
int expectedPositions,
JoinCompiler joinCompiler,
TypeOperators typeOperators)
Expand All @@ -88,7 +84,6 @@ public SetBuilderOperatorFactory(
checkArgument(setChannel >= 0, "setChannel is negative");
this.setProvider = new SetSupplier(requireNonNull(type, "type is null"));
this.setChannel = setChannel;
this.hashChannel = requireNonNull(hashChannel, "hashChannel is null");
this.expectedPositions = expectedPositions;
this.joinCompiler = requireNonNull(joinCompiler, "joinCompiler is null");
this.typeOperators = requireNonNull(typeOperators, "blockTypeOperators is null");
Expand All @@ -104,7 +99,7 @@ public Operator createOperator(DriverContext driverContext)
{
checkState(!closed, "Factory is already closed");
OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, planNodeId, SetBuilderOperator.class.getSimpleName());
return new SetBuilderOperator(operatorContext, setProvider, setChannel, hashChannel, expectedPositions, joinCompiler, typeOperators);
return new SetBuilderOperator(operatorContext, setProvider, setChannel, expectedPositions, joinCompiler, typeOperators);
}

@Override
Expand All @@ -116,14 +111,13 @@ public void noMoreOperators()
@Override
public OperatorFactory duplicate()
{
return new SetBuilderOperatorFactory(operatorId, planNodeId, setProvider.getType(), setChannel, hashChannel, expectedPositions, joinCompiler, typeOperators);
return new SetBuilderOperatorFactory(operatorId, planNodeId, setProvider.getType(), setChannel, expectedPositions, joinCompiler, typeOperators);
}
}

private final OperatorContext operatorContext;
private final SetSupplier setSupplier;
private final int setChannel;
private final int hashChannel;

private final ChannelSetBuilder channelSetBuilder;

Expand All @@ -133,7 +127,6 @@ public SetBuilderOperator(
OperatorContext operatorContext,
SetSupplier setSupplier,
int setChannel,
Optional<Integer> hashChannel,
int expectedPositions,
JoinCompiler joinCompiler,
TypeOperators typeOperators)
Expand All @@ -142,7 +135,6 @@ public SetBuilderOperator(
this.setSupplier = requireNonNull(setSupplier, "setSupplier is null");

this.setChannel = setChannel;
this.hashChannel = hashChannel.orElse(-1);

// Set builder has a single channel which goes in channel 0, if hash is present, add a hashBlock to channel 1
this.channelSetBuilder = new ChannelSetBuilder(
Expand Down Expand Up @@ -189,7 +181,7 @@ public void addInput(Page page)
requireNonNull(page, "page is null");
checkState(!isFinished(), "Operator is already finished");

channelSetBuilder.addAll(page.getBlock(setChannel), hashChannel == -1 ? null : page.getBlock(hashChannel));
channelSetBuilder.addAll(page.getBlock(setChannel));
}

@Override
Expand Down
Loading
Loading