diff --git a/core/trino-main/src/main/java/io/trino/operator/DynamicFilterSourceOperator.java b/core/trino-main/src/main/java/io/trino/operator/DynamicFilterSourceOperator.java index 3661584a51bd..cef72034675f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/DynamicFilterSourceOperator.java +++ b/core/trino-main/src/main/java/io/trino/operator/DynamicFilterSourceOperator.java @@ -24,6 +24,7 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.spi.predicate.ValueSet; import io.trino.spi.type.Type; +import io.trino.sql.planner.DynamicFilterSourceConsumer; import io.trino.sql.planner.plan.DynamicFilterId; import io.trino.sql.planner.plan.PlanNodeId; import io.trino.type.BlockTypeOperators; @@ -32,7 +33,6 @@ import javax.annotation.Nullable; import java.util.List; -import java.util.function.Consumer; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; @@ -77,7 +77,7 @@ public static class DynamicFilterSourceOperatorFactory { private final int operatorId; private final PlanNodeId planNodeId; - private final Consumer> dynamicPredicateConsumer; + private final DynamicFilterSourceConsumer dynamicPredicateConsumer; private final List channels; private final int maxDisinctValues; private final DataSize maxFilterSize; @@ -85,11 +85,12 @@ public static class DynamicFilterSourceOperatorFactory private final BlockTypeOperators blockTypeOperators; private boolean closed; + private int createdOperatorsCount; public DynamicFilterSourceOperatorFactory( int operatorId, PlanNodeId planNodeId, - Consumer> dynamicPredicateConsumer, + DynamicFilterSourceConsumer dynamicPredicateConsumer, List channels, int maxDisinctValues, DataSize maxFilterSize, @@ -114,6 +115,7 @@ public DynamicFilterSourceOperatorFactory( public Operator createOperator(DriverContext driverContext) { checkState(!closed, "Factory is already closed"); + createdOperatorsCount++; return new DynamicFilterSourceOperator( driverContext.addOperatorContext(operatorId, planNodeId, DynamicFilterSourceOperator.class.getSimpleName()), dynamicPredicateConsumer, @@ -130,6 +132,7 @@ public void noMoreOperators() { checkState(!closed, "Factory is already closed"); closed = true; + dynamicPredicateConsumer.setPartitionCount(createdOperatorsCount); } @Override @@ -142,7 +145,7 @@ public OperatorFactory duplicate() private final OperatorContext context; private boolean finished; private Page current; - private final Consumer> dynamicPredicateConsumer; + private final DynamicFilterSourceConsumer dynamicPredicateConsumer; private final int maxDistinctValues; private final long maxFilterSizeInBytes; @@ -164,7 +167,7 @@ public OperatorFactory duplicate() private DynamicFilterSourceOperator( OperatorContext context, - Consumer> dynamicPredicateConsumer, + DynamicFilterSourceConsumer dynamicPredicateConsumer, List channels, PlanNodeId planNodeId, int maxDistinctValues, @@ -270,7 +273,7 @@ private void handleTooLargePredicate() // The resulting predicate is too large if (minMaxChannels.isEmpty()) { // allow all probe-side values to be read. - dynamicPredicateConsumer.accept(TupleDomain.all()); + dynamicPredicateConsumer.addPartition(TupleDomain.all()); } else { if (minMaxCollectionLimit < 0) { @@ -294,7 +297,7 @@ private void handleTooLargePredicate() private void handleMinMaxCollectionLimitExceeded() { // allow all probe-side values to be read. - dynamicPredicateConsumer.accept(TupleDomain.all()); + dynamicPredicateConsumer.addPartition(TupleDomain.all()); // Drop references to collected values. minValues = null; maxValues = null; @@ -387,7 +390,7 @@ public void finish() } minValues = null; maxValues = null; - dynamicPredicateConsumer.accept(TupleDomain.withColumnDomains(domainsBuilder.buildOrThrow())); + dynamicPredicateConsumer.addPartition(TupleDomain.withColumnDomains(domainsBuilder.buildOrThrow())); return; } for (int channelIndex = 0; channelIndex < channels.size(); ++channelIndex) { @@ -397,7 +400,7 @@ public void finish() } valueSets = null; blockBuilders = null; - dynamicPredicateConsumer.accept(TupleDomain.withColumnDomains(domainsBuilder.buildOrThrow())); + dynamicPredicateConsumer.addPartition(TupleDomain.withColumnDomains(domainsBuilder.buildOrThrow())); } private Domain convertToDomain(Type type, Block block) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/DynamicFilterSourceConsumer.java b/core/trino-main/src/main/java/io/trino/sql/planner/DynamicFilterSourceConsumer.java new file mode 100644 index 000000000000..ed9ef53d5cc8 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/DynamicFilterSourceConsumer.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner; + +import io.trino.spi.predicate.TupleDomain; +import io.trino.sql.planner.plan.DynamicFilterId; + +public interface DynamicFilterSourceConsumer +{ + void addPartition(TupleDomain tupleDomain); + + void setPartitionCount(int partitionCount); +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalDynamicFilterConsumer.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalDynamicFilterConsumer.java index b9a70320d576..98bfa93158a3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalDynamicFilterConsumer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalDynamicFilterConsumer.java @@ -24,15 +24,17 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.PlanNode; +import javax.annotation.concurrent.GuardedBy; + import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.function.Consumer; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; @@ -40,7 +42,9 @@ import static java.util.function.Function.identity; public class LocalDynamicFilterConsumer + implements DynamicFilterSourceConsumer { + private static final int PARTITION_COUNT_INITIAL_VALUE = -1; // Mapping from dynamic filter ID to its build channel indices. private final Map buildChannels; @@ -49,22 +53,22 @@ public class LocalDynamicFilterConsumer private final SettableFuture> resultFuture; - // Number of build-side partitions to be collected. - private final int partitionCount; + // Number of build-side partitions to be collected, must be provided by setPartitionCount + @GuardedBy("this") + private int expectedPartitionCount = PARTITION_COUNT_INITIAL_VALUE; // The resulting predicates from each build-side partition. + @GuardedBy("this") private final List> partitions; - public LocalDynamicFilterConsumer(Map buildChannels, Map filterBuildTypes, int partitionCount) + public LocalDynamicFilterConsumer(Map buildChannels, Map filterBuildTypes) { this.buildChannels = requireNonNull(buildChannels, "buildChannels is null"); this.filterBuildTypes = requireNonNull(filterBuildTypes, "filterBuildTypes is null"); verify(buildChannels.keySet().equals(filterBuildTypes.keySet()), "filterBuildTypes and buildChannels must have same keys"); this.resultFuture = SettableFuture.create(); - - this.partitionCount = partitionCount; - this.partitions = new ArrayList<>(partitionCount); + this.partitions = new ArrayList<>(); } public ListenableFuture> getDynamicFilterDomains() @@ -72,16 +76,38 @@ public ListenableFuture> getDynamicFilterDomains() return Futures.transform(resultFuture, this::convertTupleDomain, directExecutor()); } - private void addPartition(TupleDomain tupleDomain) + @Override + public void addPartition(TupleDomain tupleDomain) { + if (resultFuture.isDone()) { + return; + } TupleDomain result = null; synchronized (this) { // Called concurrently by each DynamicFilterSourceOperator instance (when collection is over). - verify(partitions.size() < partitionCount); + verify(expectedPartitionCount == PARTITION_COUNT_INITIAL_VALUE || partitions.size() < expectedPartitionCount); // NOTE: may result in a bit more relaxed constraint if there are multiple columns and multiple rows. // See the comment at TupleDomain::columnWiseUnion() for more details. partitions.add(tupleDomain); - if (partitions.size() == partitionCount || tupleDomain.isAll()) { + if (partitions.size() == expectedPartitionCount || tupleDomain.isAll()) { + // No more partitions are left to be processed. + result = TupleDomain.columnWiseUnion(partitions); + } + } + + if (result != null) { + resultFuture.set(result); + } + } + + @Override + public void setPartitionCount(int partitionCount) + { + TupleDomain result = null; + synchronized (this) { + checkState(expectedPartitionCount == PARTITION_COUNT_INITIAL_VALUE, "setPartitionCount should be called only once"); + expectedPartitionCount = partitionCount; + if (partitions.size() == expectedPartitionCount) { // No more partitions are left to be processed. result = TupleDomain.columnWiseUnion(partitions); } @@ -109,7 +135,6 @@ private Map convertTupleDomain(TupleDomain buildSourceTypes, - int partitionCount, Set collectedFilters) { checkArgument(!planNode.getDynamicFilters().isEmpty(), "Join node dynamicFilters is empty."); @@ -134,7 +159,7 @@ public static LocalDynamicFilterConsumer create( .collect(toImmutableMap( Map.Entry::getKey, entry -> buildSourceTypes.get(entry.getValue()))); - return new LocalDynamicFilterConsumer(buildChannels, filterBuildTypes, partitionCount); + return new LocalDynamicFilterConsumer(buildChannels, filterBuildTypes); } public Map getBuildChannels() @@ -142,17 +167,12 @@ public Map getBuildChannels() return buildChannels; } - public Consumer> getTupleDomainConsumer() - { - return this::addPartition; - } - @Override public String toString() { return toStringHelper(this) .add("buildChannels", buildChannels) - .add("partitionCount", partitionCount) + .add("expectedPartitionCount", expectedPartitionCount) .add("partitions", partitions) .toString(); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index e5cbeeebdc3c..cf2097fe9bcf 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -2553,7 +2553,7 @@ private PhysicalOperation createNestedLoopJoin(JoinNode node, Set localDynamicFilter = createDynamicFilter(buildSource, node, context, partitionCount, localDynamicFilters); + Optional localDynamicFilter = createDynamicFilter(buildSource, node, context, localDynamicFilters); if (localDynamicFilter.isPresent()) { buildSource = createDynamicFilterSourceOperatorFactory(operatorId, localDynamicFilter.get(), node, buildSource, buildContext); } @@ -2817,7 +2817,7 @@ private JoinBridgeManager createLookupSourceFact buildOutputTypes); int operatorId = buildContext.getNextOperatorId(); - Optional localDynamicFilter = createDynamicFilter(buildSource, node, context, partitionCount, localDynamicFilters); + Optional localDynamicFilter = createDynamicFilter(buildSource, node, context, localDynamicFilters); if (localDynamicFilter.isPresent()) { buildSource = createDynamicFilterSourceOperatorFactory(operatorId, localDynamicFilter.get(), node, buildSource, buildContext); } @@ -2874,7 +2874,7 @@ private PhysicalOperation createDynamicFilterSourceOperatorFactory( new DynamicFilterSourceOperatorFactory( operatorId, node.getId(), - dynamicFilter.getTupleDomainConsumer(), + dynamicFilter, filterBuildChannels, multipleIf(getDynamicFilteringMaxDistinctValuesPerDriver(session, isReplicatedJoin), taskConcurrency, isBuildSideSingle), multipleIf(getDynamicFilteringMaxSizePerDriver(session, isReplicatedJoin), taskConcurrency, isBuildSideSingle), @@ -2899,7 +2899,6 @@ private Optional createDynamicFilter( PhysicalOperation buildSource, JoinNode node, LocalExecutionPlanContext context, - int partitionCount, Set localDynamicFilters) { Set coordinatorDynamicFilters = getCoordinatorDynamicFilters(node.getDynamicFilters().keySet(), node, context.getTaskId()); @@ -2914,7 +2913,7 @@ private Optional createDynamicFilter( buildSource.getPipelineExecutionStrategy() != GROUPED_EXECUTION, "Dynamic filtering cannot be used with grouped execution"); log.debug("[Join] Dynamic filters: %s", node.getDynamicFilters()); - LocalDynamicFilterConsumer filterConsumer = LocalDynamicFilterConsumer.create(node, buildSource.getTypes(), partitionCount, collectedDynamicFilters); + LocalDynamicFilterConsumer filterConsumer = LocalDynamicFilterConsumer.create(node, buildSource.getTypes(), collectedDynamicFilters); ListenableFuture> domainsFuture = filterConsumer.getDynamicFilterDomains(); if (!localDynamicFilters.isEmpty()) { addSuccessCallback(domainsFuture, context::addLocalDynamicFilters); @@ -3080,8 +3079,7 @@ public PhysicalOperation visitSemiJoin(SemiJoinNode node, LocalExecutionPlanCont log.debug("[Semi-join] Dynamic filter: %s", filterId); LocalDynamicFilterConsumer filterConsumer = new LocalDynamicFilterConsumer( ImmutableMap.of(filterId, buildChannel), - ImmutableMap.of(filterId, buildSource.getTypes().get(buildChannel)), - partitionCount); + ImmutableMap.of(filterId, buildSource.getTypes().get(buildChannel))); ListenableFuture> domainsFuture = filterConsumer.getDynamicFilterDomains(); if (isLocalDynamicFilter) { addSuccessCallback(domainsFuture, context::addLocalDynamicFilters); @@ -3094,7 +3092,7 @@ public PhysicalOperation visitSemiJoin(SemiJoinNode node, LocalExecutionPlanCont new DynamicFilterSourceOperatorFactory( operatorId, node.getId(), - filterConsumer.getTupleDomainConsumer(), + filterConsumer, ImmutableList.of(new DynamicFilterSourceOperator.Channel(filterId, buildSource.getTypes().get(buildChannel), buildChannel)), getDynamicFilteringMaxDistinctValuesPerDriver(session, isReplicatedJoin), getDynamicFilteringMaxSizePerDriver(session, isReplicatedJoin), diff --git a/core/trino-main/src/test/java/io/trino/operator/BenchmarkDynamicFilterSourceOperator.java b/core/trino-main/src/test/java/io/trino/operator/BenchmarkDynamicFilterSourceOperator.java index 855619d8d4f7..13efb7be98a4 100644 --- a/core/trino-main/src/test/java/io/trino/operator/BenchmarkDynamicFilterSourceOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/BenchmarkDynamicFilterSourceOperator.java @@ -17,7 +17,9 @@ import io.airlift.units.DataSize; import io.trino.spi.Page; import io.trino.spi.PageBuilder; +import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.TypeOperators; +import io.trino.sql.planner.DynamicFilterSourceConsumer; import io.trino.sql.planner.plan.DynamicFilterId; import io.trino.sql.planner.plan.PlanNodeId; import io.trino.testing.TestingTaskContext; @@ -93,7 +95,13 @@ public void setup() operatorFactory = new DynamicFilterSourceOperator.DynamicFilterSourceOperatorFactory( 1, new PlanNodeId("joinNodeId"), - (tupleDomain -> {}), + new DynamicFilterSourceConsumer() { + @Override + public void addPartition(TupleDomain tupleDomain) {} + + @Override + public void setPartitionCount(int partitionCount) {} + }, ImmutableList.of(new DynamicFilterSourceOperator.Channel(new DynamicFilterId("0"), BIGINT, 0)), maxDistinctValuesCount, DataSize.ofBytes(Long.MAX_VALUE), diff --git a/core/trino-main/src/test/java/io/trino/operator/TestDynamicFilterSourceOperator.java b/core/trino-main/src/test/java/io/trino/operator/TestDynamicFilterSourceOperator.java index 16b325c2c378..05c6838ef6f8 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestDynamicFilterSourceOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestDynamicFilterSourceOperator.java @@ -24,6 +24,7 @@ import io.trino.spi.predicate.ValueSet; import io.trino.spi.type.Type; import io.trino.spi.type.TypeOperators; +import io.trino.sql.planner.DynamicFilterSourceConsumer; import io.trino.sql.planner.plan.DynamicFilterId; import io.trino.sql.planner.plan.PlanNodeId; import io.trino.testing.MaterializedResult; @@ -132,7 +133,16 @@ private OperatorFactory createOperatorFactory( return new DynamicFilterSourceOperator.DynamicFilterSourceOperatorFactory( 0, new PlanNodeId("PLAN_NODE_ID"), - this::consumePredicate, + new DynamicFilterSourceConsumer() { + @Override + public void addPartition(TupleDomain tupleDomain) + { + partitions.add(tupleDomain); + } + + @Override + public void setPartitionCount(int partitionCount) {} + }, ImmutableList.copyOf(buildChannels), maxFilterDistinctValues, maxFilterSize, @@ -140,11 +150,6 @@ private OperatorFactory createOperatorFactory( blockTypeOperators); } - private void consumePredicate(TupleDomain partitionPredicate) - { - partitions.add(partitionPredicate); - } - private Operator createOperator(OperatorFactory operatorFactory) { return operatorFactory.createOperator(pipelineContext.addDriverContext()); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestLocalDynamicFilterConsumer.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestLocalDynamicFilterConsumer.java index 2d7bf8fd056e..219174f66816 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestLocalDynamicFilterConsumer.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestLocalDynamicFilterConsumer.java @@ -31,7 +31,6 @@ import java.util.Map; import java.util.Optional; -import java.util.function.Consumer; import static io.trino.SystemSessionProperties.ENABLE_DYNAMIC_FILTERING; import static io.trino.SystemSessionProperties.FORCE_SINGLE_NODE_OUTPUT; @@ -63,14 +62,13 @@ public void testSimple() { LocalDynamicFilterConsumer filter = new LocalDynamicFilterConsumer( ImmutableMap.of(new DynamicFilterId("123"), 0), - ImmutableMap.of(new DynamicFilterId("123"), INTEGER), - 1); + ImmutableMap.of(new DynamicFilterId("123"), INTEGER)); + filter.setPartitionCount(1); assertEquals(filter.getBuildChannels(), ImmutableMap.of(new DynamicFilterId("123"), 0)); - Consumer> consumer = filter.getTupleDomainConsumer(); ListenableFuture> result = filter.getDynamicFilterDomains(); assertFalse(result.isDone()); - consumer.accept(TupleDomain.withColumnDomains(ImmutableMap.of( + filter.addPartition(TupleDomain.withColumnDomains(ImmutableMap.of( new DynamicFilterId("123"), Domain.singleValue(INTEGER, 7L)))); assertEquals(result.get(), ImmutableMap.of( new DynamicFilterId("123"), Domain.singleValue(INTEGER, 7L))); @@ -82,19 +80,18 @@ public void testShortCircuitOnAllTupleDomain() { LocalDynamicFilterConsumer filter = new LocalDynamicFilterConsumer( ImmutableMap.of(new DynamicFilterId("123"), 0), - ImmutableMap.of(new DynamicFilterId("123"), INTEGER), - 2); + ImmutableMap.of(new DynamicFilterId("123"), INTEGER)); - Consumer> consumer = filter.getTupleDomainConsumer(); ListenableFuture> result = filter.getDynamicFilterDomains(); assertFalse(result.isDone()); - consumer.accept(TupleDomain.withColumnDomains(ImmutableMap.of( + filter.addPartition(TupleDomain.withColumnDomains(ImmutableMap.of( new DynamicFilterId("123"), Domain.all(INTEGER)))); assertEquals(result.get(), ImmutableMap.of(new DynamicFilterId("123"), Domain.all(INTEGER))); + filter.setPartitionCount(2); // adding another partition domain won't change final domain - consumer.accept(TupleDomain.withColumnDomains(ImmutableMap.of( + filter.addPartition(TupleDomain.withColumnDomains(ImmutableMap.of( new DynamicFilterId("123"), Domain.singleValue(INTEGER, 1L)))); assertEquals(result.get(), ImmutableMap.of(new DynamicFilterId("123"), Domain.all(INTEGER))); } @@ -105,20 +102,20 @@ public void testMultiplePartitions() { LocalDynamicFilterConsumer filter = new LocalDynamicFilterConsumer( ImmutableMap.of(new DynamicFilterId("123"), 0), - ImmutableMap.of(new DynamicFilterId("123"), INTEGER), - 2); + ImmutableMap.of(new DynamicFilterId("123"), INTEGER)); assertEquals(filter.getBuildChannels(), ImmutableMap.of(new DynamicFilterId("123"), 0)); - Consumer> consumer = filter.getTupleDomainConsumer(); ListenableFuture> result = filter.getDynamicFilterDomains(); assertFalse(result.isDone()); - consumer.accept(TupleDomain.withColumnDomains(ImmutableMap.of( + filter.addPartition(TupleDomain.withColumnDomains(ImmutableMap.of( new DynamicFilterId("123"), Domain.singleValue(INTEGER, 10L)))); assertFalse(result.isDone()); - consumer.accept(TupleDomain.withColumnDomains(ImmutableMap.of( + filter.addPartition(TupleDomain.withColumnDomains(ImmutableMap.of( new DynamicFilterId("123"), Domain.singleValue(INTEGER, 20L)))); + assertFalse(result.isDone()); + filter.setPartitionCount(2); assertEquals(result.get(), ImmutableMap.of( new DynamicFilterId("123"), Domain.multipleValues(INTEGER, ImmutableList.of(10L, 20L)))); } @@ -135,14 +132,13 @@ public void testAllDomain() filter2, 1), ImmutableMap.of( filter1, INTEGER, - filter2, INTEGER), - 1); + filter2, INTEGER)); + filter.setPartitionCount(1); - Consumer> consumer = filter.getTupleDomainConsumer(); ListenableFuture> result = filter.getDynamicFilterDomains(); assertFalse(result.isDone()); - consumer.accept(TupleDomain.withColumnDomains(ImmutableMap.of( + filter.addPartition(TupleDomain.withColumnDomains(ImmutableMap.of( filter1, Domain.all(INTEGER), filter2, Domain.singleValue(INTEGER, 1L)))); assertEquals(result.get(), ImmutableMap.of(filter1, Domain.all(INTEGER), filter2, Domain.singleValue(INTEGER, 1L))); @@ -154,14 +150,13 @@ public void testNone() { LocalDynamicFilterConsumer filter = new LocalDynamicFilterConsumer( ImmutableMap.of(new DynamicFilterId("123"), 0), - ImmutableMap.of(new DynamicFilterId("123"), INTEGER), - 1); + ImmutableMap.of(new DynamicFilterId("123"), INTEGER)); + filter.setPartitionCount(1); assertEquals(filter.getBuildChannels(), ImmutableMap.of(new DynamicFilterId("123"), 0)); - Consumer> consumer = filter.getTupleDomainConsumer(); ListenableFuture> result = filter.getDynamicFilterDomains(); assertFalse(result.isDone()); - consumer.accept(TupleDomain.none()); + filter.addPartition(TupleDomain.none()); assertEquals(result.get(), ImmutableMap.of( new DynamicFilterId("123"), Domain.none(INTEGER))); @@ -173,14 +168,13 @@ public void testMultipleColumns() { LocalDynamicFilterConsumer filter = new LocalDynamicFilterConsumer( ImmutableMap.of(new DynamicFilterId("123"), 0, new DynamicFilterId("456"), 1), - ImmutableMap.of(new DynamicFilterId("123"), INTEGER, new DynamicFilterId("456"), INTEGER), - 1); + ImmutableMap.of(new DynamicFilterId("123"), INTEGER, new DynamicFilterId("456"), INTEGER)); + filter.setPartitionCount(1); assertEquals(filter.getBuildChannels(), ImmutableMap.of(new DynamicFilterId("123"), 0, new DynamicFilterId("456"), 1)); - Consumer> consumer = filter.getTupleDomainConsumer(); ListenableFuture> result = filter.getDynamicFilterDomains(); assertFalse(result.isDone()); - consumer.accept(TupleDomain.withColumnDomains(ImmutableMap.of( + filter.addPartition(TupleDomain.withColumnDomains(ImmutableMap.of( new DynamicFilterId("123"), Domain.singleValue(INTEGER, 10L), new DynamicFilterId("456"), Domain.singleValue(INTEGER, 20L)))); assertEquals(result.get(), ImmutableMap.of( @@ -194,19 +188,18 @@ public void testMultiplePartitionsAndColumns() { LocalDynamicFilterConsumer filter = new LocalDynamicFilterConsumer( ImmutableMap.of(new DynamicFilterId("123"), 0, new DynamicFilterId("456"), 1), - ImmutableMap.of(new DynamicFilterId("123"), INTEGER, new DynamicFilterId("456"), BIGINT), - 2); + ImmutableMap.of(new DynamicFilterId("123"), INTEGER, new DynamicFilterId("456"), BIGINT)); + filter.setPartitionCount(2); assertEquals(filter.getBuildChannels(), ImmutableMap.of(new DynamicFilterId("123"), 0, new DynamicFilterId("456"), 1)); - Consumer> consumer = filter.getTupleDomainConsumer(); ListenableFuture> result = filter.getDynamicFilterDomains(); assertFalse(result.isDone()); - consumer.accept(TupleDomain.withColumnDomains(ImmutableMap.of( + filter.addPartition(TupleDomain.withColumnDomains(ImmutableMap.of( new DynamicFilterId("123"), Domain.singleValue(INTEGER, 10L), new DynamicFilterId("456"), Domain.singleValue(BIGINT, 100L)))); assertFalse(result.isDone()); - consumer.accept(TupleDomain.withColumnDomains(ImmutableMap.of( + filter.addPartition(TupleDomain.withColumnDomains(ImmutableMap.of( new DynamicFilterId("123"), Domain.singleValue(INTEGER, 20L), new DynamicFilterId("456"), Domain.singleValue(BIGINT, 200L)))); @@ -246,12 +239,14 @@ public void testDynamicFilterPruning() LocalDynamicFilterConsumer consumer = LocalDynamicFilterConsumer.create( joinNode, ImmutableList.of(BIGINT, INTEGER, SMALLINT), - 1, ImmutableSet.of(filter1, filter3)); assertEquals(consumer.getBuildChannels(), ImmutableMap.of(filter1, 0, filter3, 2)); // make sure domain types got propagated correctly - consumer.getTupleDomainConsumer().accept(TupleDomain.none()); + assertFalse(consumer.getDynamicFilterDomains().isDone()); + consumer.addPartition(TupleDomain.none()); + assertFalse(consumer.getDynamicFilterDomains().isDone()); + consumer.setPartitionCount(1); assertEquals( consumer.getDynamicFilterDomains().get(), ImmutableMap.of(filter1, Domain.none(BIGINT), filter3, Domain.none(SMALLINT)));