Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,28 @@
package io.trino.sql.planner;

import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.Type;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;

import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;

Expand All @@ -51,71 +49,106 @@ public class LocalDynamicFilterConsumer
// Mapping from dynamic filter ID to its build channel type.
private final Map<DynamicFilterId, Type> filterBuildTypes;

private final SettableFuture<TupleDomain<DynamicFilterId>> resultFuture;
private final List<Consumer<Map<DynamicFilterId, Domain>>> collectors;

// Number of build-side partitions to be collected, must be provided by setPartitionCount
@GuardedBy("this")
private int expectedPartitionCount = PARTITION_COUNT_INITIAL_VALUE;

@GuardedBy("this")
private boolean collected;

// The resulting predicates from each build-side partition.
@Nullable
@GuardedBy("this")
private final List<TupleDomain<DynamicFilterId>> partitions;
private List<TupleDomain<DynamicFilterId>> partitions;

public LocalDynamicFilterConsumer(Map<DynamicFilterId, Integer> buildChannels, Map<DynamicFilterId, Type> filterBuildTypes)
public LocalDynamicFilterConsumer(
Map<DynamicFilterId, Integer> buildChannels,
Map<DynamicFilterId, Type> filterBuildTypes,
List<Consumer<Map<DynamicFilterId, Domain>>> collectors)
{
this.buildChannels = requireNonNull(buildChannels, "buildChannels is null");
this.filterBuildTypes = requireNonNull(filterBuildTypes, "filterBuildTypes is null");
verify(buildChannels.keySet().equals(filterBuildTypes.keySet()), "filterBuildTypes and buildChannels must have same keys");

this.resultFuture = SettableFuture.create();
requireNonNull(collectors, "collectors is null");
checkArgument(!collectors.isEmpty(), "collectors is empty");
this.collectors = collectors;
this.partitions = new ArrayList<>();
}

public ListenableFuture<Map<DynamicFilterId, Domain>> getDynamicFilterDomains()
{
return Futures.transform(resultFuture, this::convertTupleDomain, directExecutor());
}

@Override
public void addPartition(TupleDomain<DynamicFilterId> tupleDomain)
{
if (resultFuture.isDone()) {
return;
}
TupleDomain<DynamicFilterId> result = null;
TupleDomain<DynamicFilterId> result;
synchronized (this) {
if (collected) {
return;
}
requireNonNull(partitions, "partitions is null");
// Called concurrently by each DynamicFilterSourceOperator instance (when collection is over).
verify(expectedPartitionCount == PARTITION_COUNT_INITIAL_VALUE || partitions.size() < expectedPartitionCount);
// NOTE: may result in a bit more relaxed constraint if there are multiple columns and multiple rows.
// See the comment at TupleDomain::columnWiseUnion() for more details.
partitions.add(tupleDomain);
if (partitions.size() == expectedPartitionCount || tupleDomain.isAll()) {
if (tupleDomain.isAll()) {
result = tupleDomain;
}
else if (partitions.size() == expectedPartitionCount) {
// No more partitions are left to be processed.
result = TupleDomain.columnWiseUnion(partitions);
if (partitions.isEmpty()) {
result = TupleDomain.none();
}
else {
result = TupleDomain.columnWiseUnion(partitions);
}
}
else {
return;
}
collected = true;
partitions = null;
}

if (result != null) {
resultFuture.set(result);
}
notifyConsumers(result);
}

@Override
public void setPartitionCount(int partitionCount)
{
TupleDomain<DynamicFilterId> result = null;
TupleDomain<DynamicFilterId> result;
synchronized (this) {
if (collected) {
return;
}
checkState(expectedPartitionCount == PARTITION_COUNT_INITIAL_VALUE, "setPartitionCount should be called only once");
requireNonNull(partitions, "partitions is null");
expectedPartitionCount = partitionCount;
if (partitions.size() == expectedPartitionCount) {
// No more partitions are left to be processed.
result = TupleDomain.columnWiseUnion(partitions);
if (partitions.isEmpty()) {
result = TupleDomain.none();
}
else {
result = TupleDomain.columnWiseUnion(partitions);
}
collected = true;
partitions = null;
}
else {
return;
}
}

if (result != null) {
resultFuture.set(result);
}
notifyConsumers(result);
}

private void notifyConsumers(TupleDomain<DynamicFilterId> result)
{
requireNonNull(result, "result is null");
Map<DynamicFilterId, Domain> dynamicFilterDomains = convertTupleDomain(result);
collectors.forEach(consumer -> consumer.accept(dynamicFilterDomains));
}

private Map<DynamicFilterId, Domain> convertTupleDomain(TupleDomain<DynamicFilterId> result)
Expand All @@ -135,7 +168,8 @@ private Map<DynamicFilterId, Domain> convertTupleDomain(TupleDomain<DynamicFilte
public static LocalDynamicFilterConsumer create(
JoinNode planNode,
List<Type> buildSourceTypes,
Set<DynamicFilterId> collectedFilters)
Set<DynamicFilterId> collectedFilters,
List<Consumer<Map<DynamicFilterId, Domain>>> collectors)
{
checkArgument(!planNode.getDynamicFilters().isEmpty(), "Join node dynamicFilters is empty.");
checkArgument(!collectedFilters.isEmpty(), "Collected dynamic filters set is empty");
Expand All @@ -159,7 +193,7 @@ public static LocalDynamicFilterConsumer create(
.collect(toImmutableMap(
Map.Entry::getKey,
entry -> buildSourceTypes.get(entry.getValue())));
return new LocalDynamicFilterConsumer(buildChannels, filterBuildTypes);
return new LocalDynamicFilterConsumer(buildChannels, filterBuildTypes, collectors);
}

public Map<DynamicFilterId, Integer> getBuildChannels()
Expand All @@ -168,11 +202,12 @@ public Map<DynamicFilterId, Integer> getBuildChannels()
}

@Override
public String toString()
public synchronized String toString()
{
return toStringHelper(this)
.add("buildChannels", buildChannels)
.add("expectedPartitionCount", expectedPartitionCount)
.add("collected", collected)
.add("partitions", partitions)
.toString();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import com.google.common.collect.Multimap;
import com.google.common.collect.SetMultimap;
import com.google.common.primitives.Ints;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.log.Logger;
import io.airlift.units.DataSize;
import io.trino.Session;
Expand Down Expand Up @@ -268,6 +267,7 @@
import java.util.OptionalInt;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
Expand All @@ -286,7 +286,6 @@
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.Range.closedOpen;
import static com.google.common.collect.Sets.difference;
import static io.airlift.concurrent.MoreFutures.addSuccessCallback;
import static io.trino.SystemSessionProperties.getAdaptivePartialAggregationMinRows;
import static io.trino.SystemSessionProperties.getAdaptivePartialAggregationUniqueRowsRatioThreshold;
import static io.trino.SystemSessionProperties.getAggregationOperatorUnspillMemoryLimit;
Expand Down Expand Up @@ -745,9 +744,12 @@ private void registerCoordinatorDynamicFilters(List<DynamicFilters.Descriptor> d
difference(consumedFilterIds, dynamicFiltersCollector.getRegisteredDynamicFilterIds()));
}

private void addCoordinatorDynamicFilters(Map<DynamicFilterId, Domain> dynamicTupleDomain)
private Consumer<Map<DynamicFilterId, Domain>> getCoordinatorDynamicFilterDomainsCollector(Set<DynamicFilterId> coordinatorDynamicFilters)
{
taskContext.updateDomains(dynamicTupleDomain);
return domains -> taskContext.updateDomains(
domains.entrySet().stream()
.filter(entry -> coordinatorDynamicFilters.contains(entry.getKey()))
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)));
}

public Optional<IndexSourceContext> getIndexSourceContext()
Expand Down Expand Up @@ -2913,18 +2915,19 @@ private Optional<LocalDynamicFilterConsumer> createDynamicFilter(
buildSource.getPipelineExecutionStrategy() != GROUPED_EXECUTION,
"Dynamic filtering cannot be used with grouped execution");
log.debug("[Join] Dynamic filters: %s", node.getDynamicFilters());
LocalDynamicFilterConsumer filterConsumer = LocalDynamicFilterConsumer.create(node, buildSource.getTypes(), collectedDynamicFilters);
ListenableFuture<Map<DynamicFilterId, Domain>> domainsFuture = filterConsumer.getDynamicFilterDomains();
ImmutableList.Builder<Consumer<Map<DynamicFilterId, Domain>>> collectors = ImmutableList.builder();
if (!localDynamicFilters.isEmpty()) {
addSuccessCallback(domainsFuture, context::addLocalDynamicFilters);
collectors.add(context::addLocalDynamicFilters);
}
if (!coordinatorDynamicFilters.isEmpty()) {
addSuccessCallback(
domainsFuture,
domains -> context.addCoordinatorDynamicFilters(domains.entrySet().stream()
.filter(entry -> coordinatorDynamicFilters.contains(entry.getKey()))
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue))));
collectors.add(context.getCoordinatorDynamicFilterDomainsCollector(coordinatorDynamicFilters));
}
LocalDynamicFilterConsumer filterConsumer = LocalDynamicFilterConsumer.create(
node,
buildSource.getTypes(),
collectedDynamicFilters,
collectors.build());

return Optional.of(filterConsumer);
}

Expand Down Expand Up @@ -3077,17 +3080,18 @@ public PhysicalOperation visitSemiJoin(SemiJoinNode node, LocalExecutionPlanCont
// Add a DynamicFilterSourceOperatorFactory to build operator factories
DynamicFilterId filterId = node.getDynamicFilterId().get();
log.debug("[Semi-join] Dynamic filter: %s", filterId);
LocalDynamicFilterConsumer filterConsumer = new LocalDynamicFilterConsumer(
ImmutableMap.of(filterId, buildChannel),
ImmutableMap.of(filterId, buildSource.getTypes().get(buildChannel)));
ListenableFuture<Map<DynamicFilterId, Domain>> domainsFuture = filterConsumer.getDynamicFilterDomains();
ImmutableList.Builder<Consumer<Map<DynamicFilterId, Domain>>> collectors = ImmutableList.builder();
if (isLocalDynamicFilter) {
addSuccessCallback(domainsFuture, context::addLocalDynamicFilters);
collectors.add(context::addLocalDynamicFilters);
}
if (isCoordinatorDynamicFilter) {
addSuccessCallback(domainsFuture, context::addCoordinatorDynamicFilters);
collectors.add(context.getCoordinatorDynamicFilterDomainsCollector(ImmutableSet.of(filterId)));
}
boolean isReplicatedJoin = isBuildSideReplicated(node);
LocalDynamicFilterConsumer filterConsumer = new LocalDynamicFilterConsumer(
ImmutableMap.of(filterId, buildChannel),
ImmutableMap.of(filterId, buildSource.getTypes().get(buildChannel)),
collectors.build());
buildSource = new PhysicalOperation(
new DynamicFilterSourceOperatorFactory(
operatorId,
Expand Down
Loading