Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.predicate.ValueSet;
import io.trino.spi.type.Type;
import io.trino.sql.planner.DynamicFilterSourceConsumer;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.type.BlockTypeOperators;
Expand All @@ -32,7 +33,6 @@
import javax.annotation.Nullable;

import java.util.List;
import java.util.function.Consumer;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
Expand Down Expand Up @@ -77,19 +77,20 @@ public static class DynamicFilterSourceOperatorFactory
{
private final int operatorId;
private final PlanNodeId planNodeId;
private final Consumer<TupleDomain<DynamicFilterId>> dynamicPredicateConsumer;
private final DynamicFilterSourceConsumer dynamicPredicateConsumer;
private final List<Channel> channels;
private final int maxDisinctValues;
private final DataSize maxFilterSize;
private final int minMaxCollectionLimit;
private final BlockTypeOperators blockTypeOperators;

private boolean closed;
private int createdOperatorsCount;

public DynamicFilterSourceOperatorFactory(
int operatorId,
PlanNodeId planNodeId,
Consumer<TupleDomain<DynamicFilterId>> dynamicPredicateConsumer,
DynamicFilterSourceConsumer dynamicPredicateConsumer,
List<Channel> channels,
int maxDisinctValues,
DataSize maxFilterSize,
Expand All @@ -114,6 +115,7 @@ public DynamicFilterSourceOperatorFactory(
public Operator createOperator(DriverContext driverContext)
{
checkState(!closed, "Factory is already closed");
createdOperatorsCount++;
return new DynamicFilterSourceOperator(
driverContext.addOperatorContext(operatorId, planNodeId, DynamicFilterSourceOperator.class.getSimpleName()),
dynamicPredicateConsumer,
Expand All @@ -130,6 +132,7 @@ public void noMoreOperators()
{
checkState(!closed, "Factory is already closed");
closed = true;
dynamicPredicateConsumer.setPartitionCount(createdOperatorsCount);
}

@Override
Expand All @@ -142,7 +145,7 @@ public OperatorFactory duplicate()
private final OperatorContext context;
private boolean finished;
private Page current;
private final Consumer<TupleDomain<DynamicFilterId>> dynamicPredicateConsumer;
private final DynamicFilterSourceConsumer dynamicPredicateConsumer;
private final int maxDistinctValues;
private final long maxFilterSizeInBytes;

Expand All @@ -164,7 +167,7 @@ public OperatorFactory duplicate()

private DynamicFilterSourceOperator(
OperatorContext context,
Consumer<TupleDomain<DynamicFilterId>> dynamicPredicateConsumer,
DynamicFilterSourceConsumer dynamicPredicateConsumer,
List<Channel> channels,
PlanNodeId planNodeId,
int maxDistinctValues,
Expand Down Expand Up @@ -270,7 +273,7 @@ private void handleTooLargePredicate()
// The resulting predicate is too large
if (minMaxChannels.isEmpty()) {
// allow all probe-side values to be read.
dynamicPredicateConsumer.accept(TupleDomain.all());
dynamicPredicateConsumer.addPartition(TupleDomain.all());
}
else {
if (minMaxCollectionLimit < 0) {
Expand All @@ -294,7 +297,7 @@ private void handleTooLargePredicate()
private void handleMinMaxCollectionLimitExceeded()
{
// allow all probe-side values to be read.
dynamicPredicateConsumer.accept(TupleDomain.all());
dynamicPredicateConsumer.addPartition(TupleDomain.all());
// Drop references to collected values.
minValues = null;
maxValues = null;
Expand Down Expand Up @@ -387,7 +390,7 @@ public void finish()
}
minValues = null;
maxValues = null;
dynamicPredicateConsumer.accept(TupleDomain.withColumnDomains(domainsBuilder.buildOrThrow()));
dynamicPredicateConsumer.addPartition(TupleDomain.withColumnDomains(domainsBuilder.buildOrThrow()));
return;
}
for (int channelIndex = 0; channelIndex < channels.size(); ++channelIndex) {
Expand All @@ -397,7 +400,7 @@ public void finish()
}
valueSets = null;
blockBuilders = null;
dynamicPredicateConsumer.accept(TupleDomain.withColumnDomains(domainsBuilder.buildOrThrow()));
dynamicPredicateConsumer.addPartition(TupleDomain.withColumnDomains(domainsBuilder.buildOrThrow()));
}

private Domain convertToDomain(Type type, Block block)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.sql.planner;

import io.trino.spi.predicate.TupleDomain;
import io.trino.sql.planner.plan.DynamicFilterId;

public interface DynamicFilterSourceConsumer
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: DynamicFilterConsumer, or maybe simply drop this interface and override the method you need for testing?

{
void addPartition(TupleDomain<DynamicFilterId> tupleDomain);

void setPartitionCount(int partitionCount);
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,27 @@
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;

import javax.annotation.concurrent.GuardedBy;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;

public class LocalDynamicFilterConsumer
implements DynamicFilterSourceConsumer
{
private static final int PARTITION_COUNT_INITIAL_VALUE = -1;
Comment thread
raunaqmorarka marked this conversation as resolved.
Outdated
// Mapping from dynamic filter ID to its build channel indices.
private final Map<DynamicFilterId, Integer> buildChannels;

Expand All @@ -49,39 +53,61 @@ public class LocalDynamicFilterConsumer

private final SettableFuture<TupleDomain<DynamicFilterId>> resultFuture;

// Number of build-side partitions to be collected.
private final int partitionCount;
// Number of build-side partitions to be collected, must be provided by setPartitionCount
@GuardedBy("this")
private int expectedPartitionCount = PARTITION_COUNT_INITIAL_VALUE;

// The resulting predicates from each build-side partition.
@GuardedBy("this")
private final List<TupleDomain<DynamicFilterId>> partitions;

public LocalDynamicFilterConsumer(Map<DynamicFilterId, Integer> buildChannels, Map<DynamicFilterId, Type> filterBuildTypes, int partitionCount)
public LocalDynamicFilterConsumer(Map<DynamicFilterId, Integer> buildChannels, Map<DynamicFilterId, Type> filterBuildTypes)
{
this.buildChannels = requireNonNull(buildChannels, "buildChannels is null");
this.filterBuildTypes = requireNonNull(filterBuildTypes, "filterBuildTypes is null");
verify(buildChannels.keySet().equals(filterBuildTypes.keySet()), "filterBuildTypes and buildChannels must have same keys");

this.resultFuture = SettableFuture.create();

this.partitionCount = partitionCount;
this.partitions = new ArrayList<>(partitionCount);
this.partitions = new ArrayList<>();
}

/**
 * Returns a future that completes with the per-filter domains once all expected
 * build-side partitions have reported (or a partition reported an "all" domain).
 */
public ListenableFuture<Map<DynamicFilterId, Domain>> getDynamicFilterDomains()
{
    return Futures.transform(resultFuture, tupleDomain -> convertTupleDomain(tupleDomain), directExecutor());
}

private void addPartition(TupleDomain<DynamicFilterId> tupleDomain)
@Override
public void addPartition(TupleDomain<DynamicFilterId> tupleDomain)
Comment thread
raunaqmorarka marked this conversation as resolved.
Outdated
{
if (resultFuture.isDone()) {
return;
}
TupleDomain<DynamicFilterId> result = null;
synchronized (this) {
// Called concurrently by each DynamicFilterSourceOperator instance (when collection is over).
verify(partitions.size() < partitionCount);
verify(expectedPartitionCount == PARTITION_COUNT_INITIAL_VALUE || partitions.size() < expectedPartitionCount);
// NOTE: may result in a bit more relaxed constraint if there are multiple columns and multiple rows.
// See the comment at TupleDomain::columnWiseUnion() for more details.
partitions.add(tupleDomain);
if (partitions.size() == partitionCount || tupleDomain.isAll()) {
if (partitions.size() == expectedPartitionCount || tupleDomain.isAll()) {
// No more partitions are left to be processed.
result = TupleDomain.columnWiseUnion(partitions);
}
}

if (result != null) {
resultFuture.set(result);
}
}

@Override
public void setPartitionCount(int partitionCount)
{
TupleDomain<DynamicFilterId> result = null;
synchronized (this) {
checkState(expectedPartitionCount == PARTITION_COUNT_INITIAL_VALUE, "setPartitionCount should be called only once");
expectedPartitionCount = partitionCount;
if (partitions.size() == expectedPartitionCount) {
// No more partitions are left to be processed.
result = TupleDomain.columnWiseUnion(partitions);
}
Expand Down Expand Up @@ -109,7 +135,6 @@ private Map<DynamicFilterId, Domain> convertTupleDomain(TupleDomain<DynamicFilte
public static LocalDynamicFilterConsumer create(
JoinNode planNode,
List<Type> buildSourceTypes,
int partitionCount,
Set<DynamicFilterId> collectedFilters)
{
checkArgument(!planNode.getDynamicFilters().isEmpty(), "Join node dynamicFilters is empty.");
Expand All @@ -134,25 +159,20 @@ public static LocalDynamicFilterConsumer create(
.collect(toImmutableMap(
Map.Entry::getKey,
entry -> buildSourceTypes.get(entry.getValue())));
return new LocalDynamicFilterConsumer(buildChannels, filterBuildTypes, partitionCount);
return new LocalDynamicFilterConsumer(buildChannels, filterBuildTypes);
}

// Returns the mapping from dynamic filter ID to its build-side channel index.
// NOTE(review): exposes the internal map directly; callers are expected not to mutate it — confirm.
public Map<DynamicFilterId, Integer> getBuildChannels()
{
return buildChannels;
}

public Consumer<TupleDomain<DynamicFilterId>> getTupleDomainConsumer()
{
return this::addPartition;
}

/**
 * Debug representation including the build channels, the expected partition count,
 * and the partitions collected so far.
 */
@Override
public String toString()
{
    var helper = toStringHelper(this);
    helper.add("buildChannels", buildChannels);
    helper.add("expectedPartitionCount", expectedPartitionCount);
    helper.add("partitions", partitions);
    return helper.toString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2553,7 +2553,7 @@ private PhysicalOperation createNestedLoopJoin(JoinNode node, Set<DynamicFilterI
checkArgument(partitionCount == 1, "Expected local execution to not be parallel");

int operatorId = buildContext.getNextOperatorId();
Optional<LocalDynamicFilterConsumer> localDynamicFilter = createDynamicFilter(buildSource, node, context, partitionCount, localDynamicFilters);
Optional<LocalDynamicFilterConsumer> localDynamicFilter = createDynamicFilter(buildSource, node, context, localDynamicFilters);
if (localDynamicFilter.isPresent()) {
buildSource = createDynamicFilterSourceOperatorFactory(operatorId, localDynamicFilter.get(), node, buildSource, buildContext);
}
Expand Down Expand Up @@ -2817,7 +2817,7 @@ private JoinBridgeManager<PartitionedLookupSourceFactory> createLookupSourceFact
buildOutputTypes);

int operatorId = buildContext.getNextOperatorId();
Optional<LocalDynamicFilterConsumer> localDynamicFilter = createDynamicFilter(buildSource, node, context, partitionCount, localDynamicFilters);
Optional<LocalDynamicFilterConsumer> localDynamicFilter = createDynamicFilter(buildSource, node, context, localDynamicFilters);
if (localDynamicFilter.isPresent()) {
buildSource = createDynamicFilterSourceOperatorFactory(operatorId, localDynamicFilter.get(), node, buildSource, buildContext);
}
Expand Down Expand Up @@ -2874,7 +2874,7 @@ private PhysicalOperation createDynamicFilterSourceOperatorFactory(
new DynamicFilterSourceOperatorFactory(
operatorId,
node.getId(),
dynamicFilter.getTupleDomainConsumer(),
dynamicFilter,
filterBuildChannels,
multipleIf(getDynamicFilteringMaxDistinctValuesPerDriver(session, isReplicatedJoin), taskConcurrency, isBuildSideSingle),
multipleIf(getDynamicFilteringMaxSizePerDriver(session, isReplicatedJoin), taskConcurrency, isBuildSideSingle),
Expand All @@ -2899,7 +2899,6 @@ private Optional<LocalDynamicFilterConsumer> createDynamicFilter(
PhysicalOperation buildSource,
JoinNode node,
LocalExecutionPlanContext context,
int partitionCount,
Set<DynamicFilterId> localDynamicFilters)
{
Set<DynamicFilterId> coordinatorDynamicFilters = getCoordinatorDynamicFilters(node.getDynamicFilters().keySet(), node, context.getTaskId());
Expand All @@ -2914,7 +2913,7 @@ private Optional<LocalDynamicFilterConsumer> createDynamicFilter(
buildSource.getPipelineExecutionStrategy() != GROUPED_EXECUTION,
"Dynamic filtering cannot be used with grouped execution");
log.debug("[Join] Dynamic filters: %s", node.getDynamicFilters());
LocalDynamicFilterConsumer filterConsumer = LocalDynamicFilterConsumer.create(node, buildSource.getTypes(), partitionCount, collectedDynamicFilters);
LocalDynamicFilterConsumer filterConsumer = LocalDynamicFilterConsumer.create(node, buildSource.getTypes(), collectedDynamicFilters);
ListenableFuture<Map<DynamicFilterId, Domain>> domainsFuture = filterConsumer.getDynamicFilterDomains();
if (!localDynamicFilters.isEmpty()) {
addSuccessCallback(domainsFuture, context::addLocalDynamicFilters);
Expand Down Expand Up @@ -3080,8 +3079,7 @@ public PhysicalOperation visitSemiJoin(SemiJoinNode node, LocalExecutionPlanCont
log.debug("[Semi-join] Dynamic filter: %s", filterId);
LocalDynamicFilterConsumer filterConsumer = new LocalDynamicFilterConsumer(
ImmutableMap.of(filterId, buildChannel),
ImmutableMap.of(filterId, buildSource.getTypes().get(buildChannel)),
partitionCount);
ImmutableMap.of(filterId, buildSource.getTypes().get(buildChannel)));
ListenableFuture<Map<DynamicFilterId, Domain>> domainsFuture = filterConsumer.getDynamicFilterDomains();
if (isLocalDynamicFilter) {
addSuccessCallback(domainsFuture, context::addLocalDynamicFilters);
Expand All @@ -3094,7 +3092,7 @@ public PhysicalOperation visitSemiJoin(SemiJoinNode node, LocalExecutionPlanCont
new DynamicFilterSourceOperatorFactory(
operatorId,
node.getId(),
filterConsumer.getTupleDomainConsumer(),
filterConsumer,
ImmutableList.of(new DynamicFilterSourceOperator.Channel(filterId, buildSource.getTypes().get(buildChannel), buildChannel)),
getDynamicFilteringMaxDistinctValuesPerDriver(session, isReplicatedJoin),
getDynamicFilteringMaxSizePerDriver(session, isReplicatedJoin),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
import io.airlift.units.DataSize;
import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.TypeOperators;
import io.trino.sql.planner.DynamicFilterSourceConsumer;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.TestingTaskContext;
Expand Down Expand Up @@ -93,7 +95,13 @@ public void setup()
operatorFactory = new DynamicFilterSourceOperator.DynamicFilterSourceOperatorFactory(
1,
new PlanNodeId("joinNodeId"),
(tupleDomain -> {}),
new DynamicFilterSourceConsumer() {
@Override
public void addPartition(TupleDomain<DynamicFilterId> tupleDomain) {}

@Override
public void setPartitionCount(int partitionCount) {}
},
ImmutableList.of(new DynamicFilterSourceOperator.Channel(new DynamicFilterId("0"), BIGINT, 0)),
maxDistinctValuesCount,
DataSize.ofBytes(Long.MAX_VALUE),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.trino.spi.predicate.ValueSet;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
import io.trino.sql.planner.DynamicFilterSourceConsumer;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.testing.MaterializedResult;
Expand Down Expand Up @@ -132,19 +133,23 @@ private OperatorFactory createOperatorFactory(
return new DynamicFilterSourceOperator.DynamicFilterSourceOperatorFactory(
0,
new PlanNodeId("PLAN_NODE_ID"),
this::consumePredicate,
new DynamicFilterSourceConsumer() {
@Override
public void addPartition(TupleDomain<DynamicFilterId> tupleDomain)
{
partitions.add(tupleDomain);
}

@Override
public void setPartitionCount(int partitionCount) {}
},
ImmutableList.copyOf(buildChannels),
maxFilterDistinctValues,
maxFilterSize,
minMaxCollectionLimit,
blockTypeOperators);
}

// Test callback: records each per-partition predicate emitted by the operator under test
// so assertions can inspect the collected tuple domains.
private void consumePredicate(TupleDomain<DynamicFilterId> partitionPredicate)
{
partitions.add(partitionPredicate);
}

private Operator createOperator(OperatorFactory operatorFactory)
{
return operatorFactory.createOperator(pipelineContext.addDriverContext());
Expand Down
Loading