From b5c48cc74c5b3ff48e900eaa479fc88516e899fc Mon Sep 17 00:00:00 2001 From: Andrii Rosa Date: Thu, 23 Jun 2022 15:54:04 -0400 Subject: [PATCH 1/7] Refactor DynamicFilterService Union collected domains iteratively --- .../trino/execution/DynamicFilterConfig.java | 17 +- .../io/trino/server/DynamicFilterService.java | 374 ++++++++++-------- .../execution/TestDynamicFilterConfig.java | 3 - .../TestSourcePartitionedScheduler.java | 11 +- .../policy/TestPhasedExecutionSchedule.java | 4 +- .../server/TestDynamicFilterService.java | 51 ++- .../server/remotetask/TestHttpRemoteTask.java | 13 +- 7 files changed, 262 insertions(+), 211 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/execution/DynamicFilterConfig.java b/core/trino-main/src/main/java/io/trino/execution/DynamicFilterConfig.java index 10c85e062cfa..b1680382f2f2 100644 --- a/core/trino-main/src/main/java/io/trino/execution/DynamicFilterConfig.java +++ b/core/trino-main/src/main/java/io/trino/execution/DynamicFilterConfig.java @@ -31,14 +31,14 @@ "dynamic-filtering-max-per-driver-size", "experimental.dynamic-filtering-max-per-driver-size", "dynamic-filtering-range-row-limit-per-driver", - "experimental.dynamic-filtering-refresh-interval" + "experimental.dynamic-filtering-refresh-interval", + "dynamic-filtering.service-thread-count" }) public class DynamicFilterConfig { private boolean enableDynamicFiltering = true; private boolean enableCoordinatorDynamicFiltersDistribution = true; private boolean enableLargeDynamicFilters; - private int serviceThreadCount = 2; private int smallBroadcastMaxDistinctValuesPerDriver = 200; private DataSize smallBroadcastMaxSizePerDriver = DataSize.of(20, KILOBYTE); @@ -96,19 +96,6 @@ public DynamicFilterConfig setEnableLargeDynamicFilters(boolean enableLargeDynam return this; } - @Min(1) - public int getServiceThreadCount() - { - return serviceThreadCount; - } - - @Config("dynamic-filtering.service-thread-count") - public DynamicFilterConfig setServiceThreadCount(int serviceThreadCount) - { - this.serviceThreadCount = serviceThreadCount; - return this; - } - @Min(0) public int getSmallBroadcastMaxDistinctValuesPerDriver() { diff --git a/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java b/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java index c696d07a461b..94e8982f1e49 100644 --- a/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java +++ b/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java @@ -16,18 +16,15 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Multimap; -import com.google.common.collect.SetMultimap; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import com.google.inject.Inject; import io.airlift.units.Duration; import io.trino.Session; -import io.trino.execution.DynamicFilterConfig; import io.trino.execution.SqlQueryExecution; import io.trino.execution.StageId; import io.trino.execution.TaskId; @@ -53,28 +50,27 @@ import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.SemiJoinNode; -import javax.annotation.PreDestroy; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; -import java.util.ArrayList; import java.util.Collection; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; -import java.util.OptionalInt; +import java.util.Queue; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ExecutorService; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import static com.google.common.base.Functions.identity; import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; @@ -82,13 +78,13 @@ import static com.google.common.collect.Sets.intersection; import static com.google.common.collect.Sets.newConcurrentHashSet; import static com.google.common.collect.Sets.union; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.airlift.concurrent.MoreFutures.addSuccessCallback; +import static io.airlift.concurrent.MoreFutures.getFutureValue; import static io.airlift.concurrent.MoreFutures.toCompletableFuture; import static io.airlift.concurrent.MoreFutures.unmodifiableFuture; import static io.airlift.concurrent.MoreFutures.whenAnyComplete; -import static io.airlift.concurrent.Threads.daemonThreadsNamed; -import static io.airlift.units.Duration.succinctNanos; import static io.trino.spi.connector.DynamicFilter.EMPTY; -import static io.trino.spi.predicate.Domain.union; import static io.trino.sql.DynamicFilters.extractDynamicFilters; import static io.trino.sql.DynamicFilters.extractSourceSymbols; import static io.trino.sql.planner.DomainCoercer.applySaturatedCasts; @@ -96,7 +92,6 @@ import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; import static java.lang.String.format; import static java.util.Objects.requireNonNull; -import static java.util.concurrent.Executors.newFixedThreadPool; @ThreadSafe public class DynamicFilterService @@ -104,32 +99,14 @@ public class DynamicFilterService private final Metadata metadata; private final FunctionManager functionManager; private final TypeOperators typeOperators; - private final ExecutorService executor; private final Map dynamicFilterContexts = new ConcurrentHashMap<>(); @Inject - public DynamicFilterService(Metadata metadata, FunctionManager functionManager, TypeOperators typeOperators, DynamicFilterConfig dynamicFilterConfig) - { - this( - metadata, - functionManager, - typeOperators, - newFixedThreadPool(dynamicFilterConfig.getServiceThreadCount(), daemonThreadsNamed("DynamicFilterService"))); - } - - @VisibleForTesting - public DynamicFilterService(Metadata metadata, FunctionManager functionManager, TypeOperators typeOperators, ExecutorService executor) + public DynamicFilterService(Metadata metadata, FunctionManager functionManager, TypeOperators typeOperators) { this.metadata = requireNonNull(metadata, "metadata is null"); this.functionManager = requireNonNull(functionManager, "functionManager is null"); this.typeOperators = requireNonNull(typeOperators, "typeOperators is null"); - this.executor = requireNonNull(executor, "executor is null"); - } - - @PreDestroy - public void stop() - { - executor.shutdownNow(); } public void registerQuery(SqlQueryExecution sqlQueryExecution, SubPlan fragmentedPlan) @@ -395,7 +372,6 @@ public void addTaskDynamicFilters(TaskId taskId, Map ne taskId, taskAttemptId); context.addTaskDynamicFilters(taskId, newDynamicFilters); - executor.submit(() -> collectDynamicFilters(taskId.getStageId(), Optional.of(newDynamicFilters.keySet()))); } public void stageCannotScheduleMoreTasks(StageId stageId, int attemptId, int numberOfTasks) @@ -412,7 +388,6 @@ public void stageCannotScheduleMoreTasks(StageId stageId, int attemptId, int num stageId, attemptId); context.stageCannotScheduleMoreTasks(stageId, numberOfTasks); - executor.submit(() -> collectDynamicFilters(stageId, Optional.empty())); } public static Set getOutboundDynamicFilters(PlanFragment plan) @@ -423,38 +398,6 @@ public static Set getOutboundDynamicFilters(PlanFragment plan) getProducedDynamicFilters(plan.getRoot()))); } - private void collectDynamicFilters(StageId stageId, Optional> selectedFilters) - { - DynamicFilterContext context = dynamicFilterContexts.get(stageId.getQueryId()); - if (context == null) { - // query has been removed - return; - } - - OptionalInt stageNumberOfTasks = context.getNumberOfTasks(stageId); - Map> newDynamicFilters = context.getTaskDynamicFilters(stageId, selectedFilters).entrySet().stream() - .filter(stageDomains -> { - if (stageDomains.getValue().stream().anyMatch(Domain::isAll)) { - // if one of the domains is all, we don't need to get dynamic filters from all tasks - return true; - } - - if (!stageDomains.getValue().isEmpty() && context.getReplicatedDynamicFilters().contains(stageDomains.getKey())) { - // for replicated dynamic filters it's enough to get dynamic filter from a single task - checkState( - stageDomains.getValue().size() == 1, - "Replicated dynamic filter should be collected from single task"); - return true; - } - - // check if all tasks of a dynamic filter source have reported dynamic filter summary - return stageNumberOfTasks.isPresent() && stageDomains.getValue().size() == stageNumberOfTasks.getAsInt(); - }) - .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); - - context.addDynamicFilters(newDynamicFilters); - } - @VisibleForTesting Optional getSummary(QueryId queryId, DynamicFilterId filterId) { @@ -469,7 +412,6 @@ private TupleDomain translateSummaryToTupleDomain( TypeProvider typeProvider) { Collection descriptors = descriptorMultimap.get(filterId); - checkState(descriptors != null, "No descriptors for dynamic filter %s", filterId); Domain summary = dynamicFilterContext.getDynamicFilterSummaries().get(filterId); return TupleDomain.withColumnDomains(descriptors.stream() .collect(toImmutableMap( @@ -702,31 +644,171 @@ public String toString() } } - /* - * DynamicFilterContext can be fully lock-free since computing dynamic filter summaries - * is idempotent. Concurrent computations of DF summaries should produce exact same result - * when partial (from tasks) DFs are available. Partial DFs are only removed when - * final dynamic filter summary is computed. - */ + private static class DynamicFilterCollectionContext + { + private final boolean replicated; + private final Set collectedTasks = newConcurrentHashSet(); + private final Queue summaryDomains = new ConcurrentLinkedQueue<>(); + + @GuardedBy("this") + private volatile Integer expectedTaskCount; + @GuardedBy("this") + private int collectedTaskCount; + + private final long start = System.nanoTime(); + private final AtomicReference collectionDuration = new AtomicReference<>(); + @GuardedBy("this") + private volatile boolean collected; + private final SettableFuture collectedDomainsFuture = SettableFuture.create(); + + private DynamicFilterCollectionContext(boolean replicated) + { + this.replicated = replicated; + } + + public void collect(TaskId taskId, Domain domain) + { + if (collected) { + return; + } + + if (replicated) { + collectReplicated(domain); + } + else { + collectPartitioned(taskId, domain); + } + } + + private void collectReplicated(Domain domain) + { + Domain result; + synchronized (this) { + if (collected) { + return; + } + collectedTaskCount++; + collected = true; + result = domain; + } + collectionDuration.set(Duration.succinctNanos(System.nanoTime() - start)); + collectedDomainsFuture.set(result); + } + + private void collectPartitioned(TaskId taskId, Domain domain) + { + if (!collectedTasks.add(taskId.getPartitionId())) { + return; + } + summaryDomains.add(domain); + unionSummaryDomains(); + + Domain result; + synchronized (this) { + if (collected) { + return; + } + collectedTaskCount++; + boolean allPartitionsCollected = expectedTaskCount != null && expectedTaskCount == collectedTaskCount; + if (allPartitionsCollected) { + // run final compaction as previous concurrent compactions may have left more than a single domain + unionSummaryDomains(); + } + + boolean collectionFinished = domain.isAll() || allPartitionsCollected; + if (!collectionFinished) { + return; + } + collected = true; + if (domain.isAll()) { + result = domain; + } + else { + // run union one more time + unionSummaryDomains(); + int summaryDomainsCount = summaryDomains.size(); + verify(summaryDomainsCount == 1, "summaryDomainsCount is expected to be equal to 1, got: %s", summaryDomainsCount); + result = summaryDomains.poll(); + } + } + + verify(result != null); + collectionDuration.set(Duration.succinctNanos(System.nanoTime() - start)); + collectedDomainsFuture.set(result); + } + + private void unionSummaryDomains() + { + while (true) { + // This method is called every time a new domain is added to the summaryDomains queue. + // In a normal situation (when there's no race) there should be no more than 2 domains in the queue. + Domain first = summaryDomains.poll(); + if (first == null) { + return; + } + Domain second = summaryDomains.poll(); + if (second == null) { + summaryDomains.add(first); + return; + } + summaryDomains.add(first.union(second)); + } + } + + public void setExpectedTaskCount(int count) + { + if (collected || expectedTaskCount != null) { + return; + } + checkArgument(count > 0, "count is expected to be greater than zero: %s", count); + + Domain result; + synchronized (this) { + if (collected || expectedTaskCount != null) { + return; + } + expectedTaskCount = count; + verify(collectedTaskCount <= expectedTaskCount, + "collectedTaskCount is expected to be less than or equal to %s, got: %s", + expectedTaskCount, + collectedTaskCount); + if (collectedTaskCount != expectedTaskCount) { + return; + } + // run union one more time + unionSummaryDomains(); + + verify(summaryDomains.size() == 1); + result = summaryDomains.poll(); + } + verify(result != null); + collectionDuration.set(Duration.succinctNanos(System.nanoTime() - start)); + collectedDomainsFuture.set(result); + } + + public ListenableFuture getCollectedDomainFuture() + { + return collectedDomainsFuture; + } + + public Optional getCollectionDuration() + { + return Optional.ofNullable(collectionDuration.get()); + } + } + private static class DynamicFilterContext { private final Session session; - private final Map dynamicFilterSummaries = new ConcurrentHashMap<>(); - private final Map dynamicFilterCollectionTime = new ConcurrentHashMap<>(); private final Set dynamicFilters; - private final Map> lazyDynamicFilters; private final Set replicatedDynamicFilters; + private final Map> lazyDynamicFilters; + private final Map dynamicFilterCollectionContexts; + private final Map> stageDynamicFilters = new ConcurrentHashMap<>(); private final Map stageNumberOfTasks = new ConcurrentHashMap<>(); - // when map value for given filter id is empty it means that dynamic filter has already been collected - // and no partial task domains are required - private final Map> taskDynamicFilters = new ConcurrentHashMap<>(); - @GuardedBy("dynamicFilterConsumers") - // This should not be a ConcurrentHashMap because we want to prevent concurrent addition of new consumers during the - // removal of existing consumers from this map in addDynamicFilters. This ensures that new consumers don't miss filter completion. - private final Map>>> dynamicFilterConsumers = new HashMap<>(); + private final int attemptId; - private final long queryAttemptStartTime = System.nanoTime(); private DynamicFilterContext( Session session, @@ -741,10 +823,16 @@ private DynamicFilterContext( this.lazyDynamicFilters = lazyDynamicFilters.stream() .collect(toImmutableMap(identity(), filter -> SettableFuture.create())); this.replicatedDynamicFilters = requireNonNull(replicatedDynamicFilters, "replicatedDynamicFilters is null"); - dynamicFilters.forEach(filter -> { - taskDynamicFilters.put(filter, new ConcurrentHashMap<>()); - dynamicFilterConsumers.put(filter, new ArrayList<>()); - }); + ImmutableMap.Builder collectionContexts = ImmutableMap.builder(); + for (DynamicFilterId dynamicFilterId : dynamicFilters) { + DynamicFilterCollectionContext collectionContext = new DynamicFilterCollectionContext(replicatedDynamicFilters.contains(dynamicFilterId)); + collectionContexts.put(dynamicFilterId, collectionContext); + SettableFuture lazyDynamicFilterFuture = this.lazyDynamicFilters.get(dynamicFilterId); + if (lazyDynamicFilterFuture != null) { + collectionContext.getCollectedDomainFuture().addListener(() -> lazyDynamicFilterFuture.set(null), directExecutor()); + } + } + dynamicFilterCollectionContexts = collectionContexts.buildOrThrow(); this.attemptId = attemptId; } @@ -760,22 +848,10 @@ DynamicFilterContext createContextForQueryRetry(int attemptId) void addDynamicFilterConsumer(Set dynamicFilterIds, Consumer> consumer) { - ImmutableMap.Builder collectedDomainsBuilder = ImmutableMap.builder(); - dynamicFilterIds.forEach(dynamicFilterId -> { - List>> consumers; - synchronized (dynamicFilterConsumers) { - consumers = dynamicFilterConsumers.get(dynamicFilterId); - if (consumers != null) { - consumers.add(consumer); - return; - } - } - // filter has already been collected - collectedDomainsBuilder.put(dynamicFilterId, dynamicFilterSummaries.get(dynamicFilterId)); - }); - Map collectedDomains = collectedDomainsBuilder.buildOrThrow(); - if (!collectedDomains.isEmpty()) { - consumer.accept(collectedDomains); + for (DynamicFilterId dynamicFilterId : dynamicFilterIds) { + DynamicFilterCollectionContext collectionContext = dynamicFilterCollectionContexts.get(dynamicFilterId); + verify(collectionContext != null, "collectionContext is missing for %s", dynamicFilterId); + addSuccessCallback(collectionContext.getCollectedDomainFuture(), domain -> consumer.accept(ImmutableMap.of(dynamicFilterId, domain))); } } @@ -789,76 +865,45 @@ private int getTotalDynamicFilters() return dynamicFilters.size(); } - private OptionalInt getNumberOfTasks(StageId stageId) + private void addTaskDynamicFilters(TaskId taskId, Map newDynamicFilters) { - return Optional.ofNullable(stageNumberOfTasks.get(stageId)) - .map(OptionalInt::of) - .orElse(OptionalInt.empty()); - } + newDynamicFilters.forEach((dynamicFilterId, domain) -> { + DynamicFilterCollectionContext collectionContext = dynamicFilterCollectionContexts.get(dynamicFilterId); + verify(collectionContext != null, "collectionContext is missing for %s", dynamicFilterId); + collectionContext.collect(taskId, domain); + }); - private Map> getTaskDynamicFilters(StageId stageId, Optional> selectedFilters) - { - return selectedFilters.orElseGet(() -> stageDynamicFilters.get(stageId)).stream() - .collect(toImmutableMap( - identity(), - filter -> Optional.ofNullable(taskDynamicFilters.get(filter)) - .map(taskDomains -> ImmutableList.copyOf(taskDomains.values())) - // return empty list in case filter has already been collected and task domains have been removed - .orElse(ImmutableList.of()))); + if (stageDynamicFilters.computeIfAbsent(taskId.getStageId(), key -> newConcurrentHashSet()).addAll(newDynamicFilters.keySet())) { + updateExpectedTaskCount(); + } } - private void addDynamicFilters(Map> newDynamicFilters) + private void stageCannotScheduleMoreTasks(StageId stageId, int numberOfTasks) { - SetMultimap>, DynamicFilterId> completedConsumers = HashMultimap.create(); - newDynamicFilters.forEach((filter, domain) -> { - if (taskDynamicFilters.remove(filter) == null) { - // filter has been collected concurrently - return; - } - dynamicFilterCollectionTime.put(filter, System.nanoTime()); - dynamicFilterSummaries.put(filter, union(domain)); - Optional.ofNullable(lazyDynamicFilters.get(filter)).ifPresent(future -> future.set(null)); - List>> consumers; - synchronized (dynamicFilterConsumers) { - // this section is executed only once due to the earlier null check on taskDynamicFilters.remove(filter) - consumers = requireNonNull(dynamicFilterConsumers.remove(filter)); - } - // dynamic filter updates are batched up per-consumer to reduce number of callbacks - consumers.forEach(consumer -> completedConsumers.put(consumer, filter)); - }); - completedConsumers.asMap().forEach((consumer, dynamicFilterIds) -> consumer.accept( - dynamicFilterIds.stream() - .collect(toImmutableMap( - identity(), - filterId -> requireNonNull(dynamicFilterSummaries.get(filterId)))))); + if (stageNumberOfTasks.put(stageId, numberOfTasks) == null) { + updateExpectedTaskCount(); + } } - private void addTaskDynamicFilters(TaskId taskId, Map newDynamicFilters) + private void updateExpectedTaskCount() { - stageDynamicFilters.computeIfAbsent(taskId.getStageId(), ignored -> newConcurrentHashSet()) - .addAll(newDynamicFilters.keySet()); - newDynamicFilters.forEach((filter, domain) -> { - Map taskDomains = taskDynamicFilters.get(filter); - if (taskDomains == null) { - // dynamic filter has already been collected - return; + stageNumberOfTasks.forEach((stage, taskCount) -> { + Set filtersIds = stageDynamicFilters.get(stage); + if (filtersIds != null) { + for (DynamicFilterId filterId : filtersIds) { + DynamicFilterCollectionContext collectionContext = dynamicFilterCollectionContexts.get(filterId); + verify(collectionContext != null, "collectionContext is missing for %s", filterId); + collectionContext.setExpectedTaskCount(taskCount); + } } - // Narrowing down of task dynamic filter is not supported. - // Currently, task dynamic filters are derived from join and semi-join, - // which produce just a single version of dynamic filter. - Domain previousDomain = taskDomains.put(taskId, domain); - checkState(previousDomain == null || domain.equals(previousDomain), "Different task domains were set"); }); } - private void stageCannotScheduleMoreTasks(StageId stageId, int numberOfTasks) - { - stageNumberOfTasks.put(stageId, numberOfTasks); - } - private Map getDynamicFilterSummaries() { - return dynamicFilterSummaries; + return dynamicFilterCollectionContexts.entrySet().stream() + .filter(entry -> entry.getValue().getCollectedDomainFuture().isDone()) + .collect(toImmutableMap(Map.Entry::getKey, entry -> getFutureValue(entry.getValue().getCollectedDomainFuture()))); } private Map> getLazyDynamicFilters() @@ -871,14 +916,11 @@ private Set getReplicatedDynamicFilters() return replicatedDynamicFilters; } - private Optional getDynamicFilterCollectionDuration(DynamicFilterId filterId) + private Optional getDynamicFilterCollectionDuration(DynamicFilterId dynamicFilterId) { - Long filterCollectionTime = dynamicFilterCollectionTime.get(filterId); - if (filterCollectionTime == null) { - return Optional.empty(); - } - - return Optional.of(succinctNanos(filterCollectionTime - queryAttemptStartTime)); + DynamicFilterCollectionContext collectionContext = dynamicFilterCollectionContexts.get(dynamicFilterId); + verify(collectionContext != null, "collectionContext is missing for %s", dynamicFilterId); + return collectionContext.getCollectionDuration(); } private int getAttemptId() diff --git a/core/trino-main/src/test/java/io/trino/execution/TestDynamicFilterConfig.java b/core/trino-main/src/test/java/io/trino/execution/TestDynamicFilterConfig.java index 34b6a39e2560..62e5f269acdd 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestDynamicFilterConfig.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestDynamicFilterConfig.java @@ -34,7 +34,6 @@ public void testDefaults() .setEnableDynamicFiltering(true) .setEnableCoordinatorDynamicFiltersDistribution(true) .setEnableLargeDynamicFilters(false) - .setServiceThreadCount(2) .setSmallBroadcastMaxDistinctValuesPerDriver(200) .setSmallBroadcastMaxSizePerDriver(DataSize.of(20, KILOBYTE)) .setSmallBroadcastRangeRowLimitPerDriver(400) @@ -60,7 +59,6 @@ public void testExplicitPropertyMappings() .put("enable-dynamic-filtering", "false") .put("enable-coordinator-dynamic-filters-distribution", "false") .put("enable-large-dynamic-filters", "true") - .put("dynamic-filtering.service-thread-count", "4") .put("dynamic-filtering.small-broadcast.max-distinct-values-per-driver", "256") .put("dynamic-filtering.small-broadcast.max-size-per-driver", "64kB") .put("dynamic-filtering.small-broadcast.range-row-limit-per-driver", "10000") @@ -83,7 +81,6 @@ public void testExplicitPropertyMappings() .setEnableDynamicFiltering(false) .setEnableCoordinatorDynamicFiltersDistribution(false) .setEnableLargeDynamicFilters(true) - .setServiceThreadCount(4) .setSmallBroadcastMaxDistinctValuesPerDriver(256) .setSmallBroadcastMaxSizePerDriver(DataSize.of(64, KILOBYTE)) .setSmallBroadcastRangeRowLimitPerDriver(10000) diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java index 16947e2619be..218fa97204ca 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java @@ -21,7 +21,6 @@ import io.trino.client.NodeVersion; import io.trino.connector.CatalogName; import io.trino.cost.StatsAndCosts; -import io.trino.execution.DynamicFilterConfig; import io.trino.execution.MockRemoteTaskFactory; import io.trino.execution.MockRemoteTaskFactory.MockRemoteTask; import io.trino.execution.NodeTaskMap; @@ -346,7 +345,7 @@ public void testNoNodes() new ConnectorAwareSplitSource(CONNECTOR_ID, createFixedSplitSource(20, TestingSplit::createRemoteSplit)), new DynamicSplitPlacementPolicy(nodeScheduler.createNodeSelector(session, Optional.of(CONNECTOR_ID)), stage::getAllTasks), 2, - new DynamicFilterService(metadata, functionManager, typeOperators, new DynamicFilterConfig()), + new DynamicFilterService(metadata, functionManager, typeOperators), new TableExecuteContextManager(), () -> false); scheduler.schedule(); @@ -487,7 +486,7 @@ public void testNewTaskScheduledWhenChildStageBufferIsUnderutilized() new ConnectorAwareSplitSource(CONNECTOR_ID, createFixedSplitSource(500, TestingSplit::createRemoteSplit)), new DynamicSplitPlacementPolicy(nodeScheduler.createNodeSelector(session, Optional.of(CONNECTOR_ID)), stage::getAllTasks), 500, - new DynamicFilterService(metadata, functionManager, typeOperators, new DynamicFilterConfig()), + new DynamicFilterService(metadata, functionManager, typeOperators), new TableExecuteContextManager(), () -> false); @@ -531,7 +530,7 @@ public void testNoNewTaskScheduledWhenChildStageBufferIsOverutilized() new ConnectorAwareSplitSource(CONNECTOR_ID, createFixedSplitSource(400, TestingSplit::createRemoteSplit)), new DynamicSplitPlacementPolicy(nodeScheduler.createNodeSelector(session, Optional.of(CONNECTOR_ID)), stage::getAllTasks), 400, - new DynamicFilterService(metadata, functionManager, typeOperators, new DynamicFilterConfig()), + new DynamicFilterService(metadata, functionManager, typeOperators), new TableExecuteContextManager(), () -> true); @@ -560,7 +559,7 @@ public void testDynamicFiltersUnblockedOnBlockedBuildSource() NodeTaskMap nodeTaskMap = new NodeTaskMap(finalizerService); StageExecution stage = createStageExecution(plan, nodeTaskMap); NodeScheduler nodeScheduler = new NodeScheduler(new UniformNodeSelectorFactory(nodeManager, new NodeSchedulerConfig().setIncludeCoordinator(false), nodeTaskMap)); - DynamicFilterService dynamicFilterService = new DynamicFilterService(metadata, functionManager, typeOperators, new DynamicFilterConfig()); + DynamicFilterService dynamicFilterService = new DynamicFilterService(metadata, functionManager, typeOperators); dynamicFilterService.registerQuery( QUERY_ID, TEST_SESSION, @@ -644,7 +643,7 @@ private StageScheduler getSourcePartitionedScheduler( new ConnectorAwareSplitSource(CONNECTOR_ID, splitSource), placementPolicy, splitBatchSize, - new DynamicFilterService(metadata, functionManager, typeOperators, new DynamicFilterConfig()), + new DynamicFilterService(metadata, functionManager, typeOperators), new TableExecuteContextManager(), () -> false); } diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/TestPhasedExecutionSchedule.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/TestPhasedExecutionSchedule.java index e2b25fda30c4..6cb62cb5bf93 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/TestPhasedExecutionSchedule.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/TestPhasedExecutionSchedule.java @@ -41,7 +41,6 @@ import java.util.Set; import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService; import static io.trino.execution.scheduler.StageExecution.State.ABORTED; import static io.trino.execution.scheduler.StageExecution.State.FINISHED; import static io.trino.execution.scheduler.StageExecution.State.FLUSHING; @@ -63,8 +62,7 @@ public class TestPhasedExecutionSchedule private final DynamicFilterService dynamicFilterService = new DynamicFilterService( createTestMetadataManager(), createTestingFunctionManager(), - new TypeOperators(), - newDirectExecutorService()); + new TypeOperators()); @Test public void testPartitionedJoin() diff --git a/core/trino-main/src/test/java/io/trino/server/TestDynamicFilterService.java b/core/trino-main/src/test/java/io/trino/server/TestDynamicFilterService.java index f57048931695..d03d56d57dda 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestDynamicFilterService.java +++ b/core/trino-main/src/test/java/io/trino/server/TestDynamicFilterService.java @@ -55,7 +55,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicInteger; -import static com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService; +import static io.airlift.slice.Slices.utf8Slice; import static io.trino.metadata.MetadataManager.createTestMetadataManager; import static io.trino.server.DynamicFilterService.DynamicFilterDomainStats; import static io.trino.server.DynamicFilterService.DynamicFiltersStats; @@ -78,6 +78,7 @@ import static io.trino.sql.planner.plan.JoinNode.Type.INNER; import static io.trino.testing.TestingHandles.TEST_TABLE_HANDLE; import static io.trino.util.DynamicFiltersTestUtil.getSimplifiedDomainString; +import static org.assertj.core.api.Assertions.assertThat; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertNull; @@ -742,8 +743,8 @@ filterId1, singleValue(INTEGER, 3L), ImmutableMap.of( filterId1, multipleValues(INTEGER, ImmutableList.of(1L, 3L)), filterId2, multipleValues(INTEGER, ImmutableList.of(2L, 4L)))); - // both filters should be received in single callback - assertEquals(callbackCount.get(), 1); + + assertEquals(callbackCount.get(), 2); // register another consumer after both filters have been collected Map secondConsumerCollectedFilters = new HashMap<>(); @@ -761,10 +762,9 @@ filterId1, multipleValues(INTEGER, ImmutableList.of(1L, 3L)), ImmutableMap.of( filterId1, multipleValues(INTEGER, ImmutableList.of(1L, 3L)), filterId2, multipleValues(INTEGER, ImmutableList.of(2L, 4L)))); - // both filters should be received by second consumer in single callback - assertEquals(secondCallbackCount.get(), 1); + assertEquals(secondCallbackCount.get(), 2); // first consumer should not receive callback again since it already got the completed filter - assertEquals(callbackCount.get(), 1); + assertEquals(callbackCount.get(), 2); } @Test @@ -843,13 +843,48 @@ public void testMultipleAttempts() getSimplifiedDomainString(4L, 6L, 3, INTEGER)))); } + @Test + public void testCollectMoreThanOnceForTheSameTask() + { + DynamicFilterService dynamicFilterService = createDynamicFilterService(); + QueryId query = new QueryId("query"); + StageId stage = new StageId(query, 0); + DynamicFilterId filter = new DynamicFilterId("filter"); + + dynamicFilterService.registerQuery( + query, + session, + ImmutableSet.of(filter), + ImmutableSet.of(filter), + ImmutableSet.of()); + + dynamicFilterService.stageCannotScheduleMoreTasks(stage, 0, 2); + + Domain domain1 = Domain.singleValue(VARCHAR, utf8Slice("value1")); + Domain domain2 = Domain.singleValue(VARCHAR, utf8Slice("value2")); + Domain domain3 = Domain.singleValue(VARCHAR, utf8Slice("value3")); + + dynamicFilterService.addTaskDynamicFilters( + new TaskId(stage, 0, 0), + ImmutableMap.of(filter, domain1)); + assertThat(dynamicFilterService.getSummary(query, filter)).isNotPresent(); + dynamicFilterService.addTaskDynamicFilters( + new TaskId(stage, 0, 0), + ImmutableMap.of(filter, domain2)); + assertThat(dynamicFilterService.getSummary(query, filter)).isNotPresent(); + dynamicFilterService.addTaskDynamicFilters( + new TaskId(stage, 1, 0), + ImmutableMap.of(filter, domain3)); + assertThat(dynamicFilterService.getSummary(query, filter)).isPresent(); + assertEquals(dynamicFilterService.getSummary(query, filter).get(), domain1.union(domain3)); + } + private static DynamicFilterService createDynamicFilterService() { return new DynamicFilterService( PLANNER_CONTEXT.getMetadata(), PLANNER_CONTEXT.getFunctionManager(), - PLANNER_CONTEXT.getTypeOperators(), - newDirectExecutorService()); + PLANNER_CONTEXT.getTypeOperators()); } private static PlanFragment createPlan(DynamicFilterId dynamicFilterId, PartitioningHandle stagePartitioning, ExchangeNode.Type exchangeType) diff --git a/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java b/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java index 1307cdce9665..adfb560f0d4c 100644 --- a/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java +++ b/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java @@ -31,7 +31,6 @@ import io.trino.block.BlockJsonSerde; import io.trino.client.NodeVersion; import io.trino.connector.CatalogName; -import io.trino.execution.DynamicFilterConfig; import io.trino.execution.DynamicFiltersCollector.VersionedDynamicFilterDomains; import io.trino.execution.NodeTaskMap; import io.trino.execution.QueryManagerConfig; @@ -107,7 +106,6 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.collect.Iterables.getOnlyElement; -import static com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService; import static com.google.inject.Scopes.SINGLETON; import static io.airlift.json.JsonBinder.jsonBinder; import static io.airlift.json.JsonCodecBinder.jsonCodecBinder; @@ -215,8 +213,7 @@ public void testDynamicFilters() DynamicFilterService dynamicFilterService = new DynamicFilterService( PLANNER_CONTEXT.getMetadata(), PLANNER_CONTEXT.getFunctionManager(), - new TypeOperators(), - newDirectExecutorService()); + new TypeOperators()); HttpRemoteTaskFactory httpRemoteTaskFactory = createHttpRemoteTaskFactory(testingTaskResource, dynamicFilterService); RemoteTask remoteTask = createRemoteTask(httpRemoteTaskFactory, ImmutableSet.of()); @@ -274,7 +271,6 @@ public void testDynamicFilters() assertGreaterThanOrEqual(testingTaskResource.getStatusFetchCounter(), 4L); httpRemoteTaskFactory.stop(); - dynamicFilterService.stop(); } @Test(timeOut = 30_000) @@ -296,8 +292,7 @@ public void testOutboundDynamicFilters() DynamicFilterService dynamicFilterService = new DynamicFilterService( PLANNER_CONTEXT.getMetadata(), PLANNER_CONTEXT.getFunctionManager(), - new TypeOperators(), - newDirectExecutorService()); + new TypeOperators()); dynamicFilterService.registerQuery( queryId, TEST_SESSION, @@ -369,7 +364,6 @@ public void testOutboundDynamicFilters() ImmutableMap.of(filterId2, Domain.singleValue(BIGINT, 2L))); httpRemoteTaskFactory.stop(); - dynamicFilterService.stop(); } private void runTest(FailureScenario failureScenario) @@ -434,8 +428,7 @@ private static HttpRemoteTaskFactory createHttpRemoteTaskFactory(TestingTaskReso return createHttpRemoteTaskFactory(testingTaskResource, new DynamicFilterService( PLANNER_CONTEXT.getMetadata(), PLANNER_CONTEXT.getFunctionManager(), - new TypeOperators(), - new DynamicFilterConfig())); + new TypeOperators())); } private static HttpRemoteTaskFactory createHttpRemoteTaskFactory(TestingTaskResource testingTaskResource, DynamicFilterService dynamicFilterService) From 96c41cc0f43e336741f7c804d8853359de890d0e Mon Sep 17 00:00:00 2001 From: Andrii Rosa Date: Thu, 23 Jun 2022 23:57:34 -0400 Subject: [PATCH 2/7] Limit maximum size of a single dynamic filter --- .../trino/execution/DynamicFilterConfig.java | 31 ++++++ .../io/trino/server/DynamicFilterService.java | 60 +++++++++-- .../execution/TestDynamicFilterConfig.java | 10 +- .../TestSourcePartitionedScheduler.java | 11 +- .../policy/TestPhasedExecutionSchedule.java | 4 +- .../server/TestDynamicFilterService.java | 101 +++++++++++++++++- .../server/remotetask/TestHttpRemoteTask.java | 10 +- 7 files changed, 208 insertions(+), 19 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/execution/DynamicFilterConfig.java b/core/trino-main/src/main/java/io/trino/execution/DynamicFilterConfig.java index b1680382f2f2..15fec226b288 100644 --- a/core/trino-main/src/main/java/io/trino/execution/DynamicFilterConfig.java +++ b/core/trino-main/src/main/java/io/trino/execution/DynamicFilterConfig.java @@ -21,6 +21,7 @@ import io.airlift.units.MaxDataSize; import javax.validation.constraints.Min; +import javax.validation.constraints.NotNull; import static io.airlift.units.DataSize.Unit.KILOBYTE; import static io.airlift.units.DataSize.Unit.MEGABYTE; @@ -48,6 +49,7 @@ public class DynamicFilterConfig private DataSize smallPartitionedMaxSizePerDriver = DataSize.of(10, KILOBYTE); private int smallPartitionedRangeRowLimitPerDriver = 100; private DataSize smallPartitionedMaxSizePerOperator = DataSize.of(100, KILOBYTE); + private DataSize smallMaxSizePerFilter = DataSize.of(1, MEGABYTE); private int largeBroadcastMaxDistinctValuesPerDriver = 5_000; private DataSize largeBroadcastMaxSizePerDriver = DataSize.of(500, KILOBYTE); @@ -57,6 +59,7 @@ public class DynamicFilterConfig private DataSize largePartitionedMaxSizePerDriver = DataSize.of(50, KILOBYTE); private int largePartitionedRangeRowLimitPerDriver = 1_000; private DataSize largePartitionedMaxSizePerOperator = DataSize.of(500, KILOBYTE); + private DataSize largeMaxSizePerFilter = DataSize.of(5, MEGABYTE); public boolean isEnableDynamicFiltering() { @@ -200,6 +203,20 @@ public DynamicFilterConfig setSmallPartitionedMaxSizePerOperator(DataSize smallP return this; } + @NotNull + @MaxDataSize("10MB") + public DataSize getSmallMaxSizePerFilter() + { + return smallMaxSizePerFilter; + } + + @Config("dynamic-filtering.small.max-size-per-filter") + public DynamicFilterConfig setSmallMaxSizePerFilter(DataSize smallMaxSizePerFilter) + { + this.smallMaxSizePerFilter = smallMaxSizePerFilter; + return this; + } + @Min(0) public int getLargeBroadcastMaxDistinctValuesPerDriver() { @@ -303,4 +320,18 @@ public DynamicFilterConfig setLargePartitionedMaxSizePerOperator(DataSize largeP this.largePartitionedMaxSizePerOperator = largePartitionedMaxSizePerOperator; return this; } + + @NotNull + @MaxDataSize("10MB") + public DataSize getLargeMaxSizePerFilter() + { + return largeMaxSizePerFilter; + } + + @Config("dynamic-filtering.large.max-size-per-filter") + public DynamicFilterConfig setLargeMaxSizePerFilter(DataSize largeMaxSizePerFilter) + { + this.largeMaxSizePerFilter = largeMaxSizePerFilter; + return this; + } } diff --git a/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java b/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java index 94e8982f1e49..243af6424319 100644 --- a/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java +++ b/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java @@ -23,8 +23,10 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import com.google.inject.Inject; +import io.airlift.units.DataSize; import io.airlift.units.Duration; import io.trino.Session; +import io.trino.execution.DynamicFilterConfig; import io.trino.execution.SqlQueryExecution; import io.trino.execution.StageId; import io.trino.execution.TaskId; @@ -84,6 +86,7 @@ import static io.airlift.concurrent.MoreFutures.toCompletableFuture; import static io.airlift.concurrent.MoreFutures.unmodifiableFuture; import static io.airlift.concurrent.MoreFutures.whenAnyComplete; +import static io.trino.SystemSessionProperties.isEnableLargeDynamicFilters; import static io.trino.spi.connector.DynamicFilter.EMPTY; import static io.trino.sql.DynamicFilters.extractDynamicFilters; import static io.trino.sql.DynamicFilters.extractSourceSymbols; @@ -99,14 +102,16 @@ public class DynamicFilterService private final Metadata metadata; private final FunctionManager functionManager; private final TypeOperators typeOperators; + private final DynamicFilterConfig dynamicFilterConfig; private final Map dynamicFilterContexts = new ConcurrentHashMap<>(); @Inject - public DynamicFilterService(Metadata metadata, FunctionManager functionManager, TypeOperators typeOperators) + public DynamicFilterService(Metadata metadata, FunctionManager functionManager, TypeOperators typeOperators, DynamicFilterConfig dynamicFilterConfig) { this.metadata = requireNonNull(metadata, "metadata is null"); this.functionManager = requireNonNull(functionManager, "functionManager is null"); this.typeOperators = requireNonNull(typeOperators, "typeOperators is null"); + this.dynamicFilterConfig = requireNonNull(dynamicFilterConfig, "dynamicFilterConfig is null"); } public void registerQuery(SqlQueryExecution sqlQueryExecution, SubPlan fragmentedPlan) @@ -143,9 +148,18 @@ public void registerQuery( dynamicFilters, lazyDynamicFilters, replicatedDynamicFilters, + getDynamicFilterSizeLimit(session), 0)); } + private DataSize getDynamicFilterSizeLimit(Session session) + { + if (isEnableLargeDynamicFilters(session)) { + return dynamicFilterConfig.getLargeMaxSizePerFilter(); + } + return dynamicFilterConfig.getSmallMaxSizePerFilter(); + } + public void registerQueryRetry(QueryId queryId, int attemptId) { DynamicFilterContext context = dynamicFilterContexts.get(queryId); @@ -647,6 +661,7 @@ public String toString() private static class DynamicFilterCollectionContext { private final boolean replicated; + private final long domainSizeLimitInBytes; private final Set collectedTasks = newConcurrentHashSet(); private final Queue summaryDomains = new ConcurrentLinkedQueue<>(); @@ -661,9 +676,10 @@ private static class DynamicFilterCollectionContext private volatile boolean collected; private final SettableFuture collectedDomainsFuture = SettableFuture.create(); - private DynamicFilterCollectionContext(boolean replicated) + private DynamicFilterCollectionContext(boolean replicated, long domainSizeLimitInBytes) { this.replicated = replicated; + this.domainSizeLimitInBytes = domainSizeLimitInBytes; } public void collect(TaskId taskId, Domain domain) @@ -682,6 +698,12 @@ public void collect(TaskId taskId, Domain domain) private void collectReplicated(Domain domain) { + if (domain.getRetainedSizeInBytes() > domainSizeLimitInBytes) { + domain = domain.simplify(1); + } + if (domain.getRetainedSizeInBytes() > domainSizeLimitInBytes) { + domain = Domain.all(domain.getType()); + } Domain result; synchronized (this) { if (collected) { @@ -700,6 +722,7 @@ private void collectPartitioned(TaskId taskId, Domain domain) if (!collectedTasks.add(taskId.getPartitionId())) { return; } + summaryDomains.add(domain); unionSummaryDomains(); @@ -715,17 +738,36 @@ private void collectPartitioned(TaskId taskId, Domain domain) unionSummaryDomains(); } - boolean collectionFinished = domain.isAll() || allPartitionsCollected; + boolean sizeLimitExceeded = false; + Domain allDomain = null; + Domain summary = summaryDomains.poll(); + // summary can be null as another concurrent summary compaction may be running + if (summary != null) { + if (summary.getRetainedSizeInBytes() > domainSizeLimitInBytes) { + summary = summary.simplify(1); + } + if (summary.getRetainedSizeInBytes() > domainSizeLimitInBytes) { + sizeLimitExceeded = true; + allDomain = Domain.all(summary.getType()); + } + else { + summaryDomains.add(summary); + } + } + + boolean collectionFinished = sizeLimitExceeded || domain.isAll() || allPartitionsCollected; if (!collectionFinished) { return; } collected = true; - if (domain.isAll()) { + if (sizeLimitExceeded) { + result = allDomain; + } + else if (domain.isAll()) { result = domain; } else { - // run union one more time - unionSummaryDomains(); + verify(allPartitionsCollected, "allPartitionsCollected is expected to be true"); int summaryDomainsCount = summaryDomains.size(); verify(summaryDomainsCount == 1, "summaryDomainsCount is expected to be equal to 1, got: %s", summaryDomainsCount); result = summaryDomains.poll(); @@ -802,6 +844,7 @@ private static class DynamicFilterContext private final Session session; private final Set dynamicFilters; private final Set replicatedDynamicFilters; + private final DataSize dynamicFilterSizeLimit; private final Map> lazyDynamicFilters; private final Map dynamicFilterCollectionContexts; @@ -815,6 +858,7 @@ private DynamicFilterContext( Set dynamicFilters, Set lazyDynamicFilters, Set replicatedDynamicFilters, + DataSize dynamicFilterSizeLimit, int attemptId) { this.session = requireNonNull(session, "session is null"); @@ -823,9 +867,10 @@ private DynamicFilterContext( this.lazyDynamicFilters = lazyDynamicFilters.stream() .collect(toImmutableMap(identity(), filter -> SettableFuture.create())); this.replicatedDynamicFilters = requireNonNull(replicatedDynamicFilters, "replicatedDynamicFilters is null"); + this.dynamicFilterSizeLimit = requireNonNull(dynamicFilterSizeLimit, "dynamicFilterSizeLimit is null"); ImmutableMap.Builder collectionContexts = ImmutableMap.builder(); for (DynamicFilterId dynamicFilterId : dynamicFilters) { - DynamicFilterCollectionContext collectionContext = new DynamicFilterCollectionContext(replicatedDynamicFilters.contains(dynamicFilterId)); + DynamicFilterCollectionContext collectionContext = new DynamicFilterCollectionContext(replicatedDynamicFilters.contains(dynamicFilterId), dynamicFilterSizeLimit.toBytes()); collectionContexts.put(dynamicFilterId, collectionContext); SettableFuture lazyDynamicFilterFuture = this.lazyDynamicFilters.get(dynamicFilterId); if (lazyDynamicFilterFuture != null) { @@ -843,6 +888,7 @@ DynamicFilterContext createContextForQueryRetry(int attemptId) dynamicFilters, lazyDynamicFilters.keySet(), replicatedDynamicFilters, + dynamicFilterSizeLimit, attemptId); } diff --git a/core/trino-main/src/test/java/io/trino/execution/TestDynamicFilterConfig.java b/core/trino-main/src/test/java/io/trino/execution/TestDynamicFilterConfig.java index 62e5f269acdd..33c927229952 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestDynamicFilterConfig.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestDynamicFilterConfig.java @@ -42,6 +42,7 @@ public void testDefaults() .setSmallPartitionedMaxSizePerDriver(DataSize.of(10, KILOBYTE)) .setSmallPartitionedRangeRowLimitPerDriver(100) .setSmallPartitionedMaxSizePerOperator(DataSize.of(100, KILOBYTE)) + .setSmallMaxSizePerFilter(DataSize.of(1, MEGABYTE)) .setLargeBroadcastMaxDistinctValuesPerDriver(5000) .setLargeBroadcastMaxSizePerDriver(DataSize.of(500, KILOBYTE)) .setLargeBroadcastRangeRowLimitPerDriver(10_000) @@ -49,7 +50,8 @@ public void testDefaults() .setLargePartitionedMaxDistinctValuesPerDriver(500) .setLargePartitionedMaxSizePerDriver(DataSize.of(50, KILOBYTE)) .setLargePartitionedRangeRowLimitPerDriver(1_000) - .setLargePartitionedMaxSizePerOperator(DataSize.of(500, KILOBYTE))); + .setLargePartitionedMaxSizePerOperator(DataSize.of(500, KILOBYTE)) + .setLargeMaxSizePerFilter(DataSize.of(5, MEGABYTE))); } @Test @@ -67,6 +69,7 @@ public void testExplicitPropertyMappings() .put("dynamic-filtering.small-partitioned.max-size-per-driver", "64kB") .put("dynamic-filtering.small-partitioned.range-row-limit-per-driver", "10000") .put("dynamic-filtering.small-partitioned.max-size-per-operator", "641kB") + .put("dynamic-filtering.small.max-size-per-filter", "341kB") .put("dynamic-filtering.large-broadcast.max-distinct-values-per-driver", "256") .put("dynamic-filtering.large-broadcast.max-size-per-driver", "64kB") .put("dynamic-filtering.large-broadcast.range-row-limit-per-driver", "100000") @@ -75,6 +78,7 @@ public void testExplicitPropertyMappings() .put("dynamic-filtering.large-partitioned.max-size-per-driver", "64kB") .put("dynamic-filtering.large-partitioned.range-row-limit-per-driver", "100000") .put("dynamic-filtering.large-partitioned.max-size-per-operator", "643kB") + .put("dynamic-filtering.large.max-size-per-filter", "3411kB") .buildOrThrow(); DynamicFilterConfig expected = new DynamicFilterConfig() @@ -89,6 +93,7 @@ public void testExplicitPropertyMappings() .setSmallPartitionedMaxSizePerDriver(DataSize.of(64, KILOBYTE)) .setSmallPartitionedRangeRowLimitPerDriver(10000) .setSmallPartitionedMaxSizePerOperator(DataSize.of(641, KILOBYTE)) + .setSmallMaxSizePerFilter(DataSize.of(341, KILOBYTE)) .setLargeBroadcastMaxDistinctValuesPerDriver(256) .setLargeBroadcastMaxSizePerDriver(DataSize.of(64, KILOBYTE)) .setLargeBroadcastRangeRowLimitPerDriver(100000) @@ -96,7 +101,8 @@ public void testExplicitPropertyMappings() .setLargePartitionedMaxDistinctValuesPerDriver(256) .setLargePartitionedMaxSizePerDriver(DataSize.of(64, KILOBYTE)) .setLargePartitionedRangeRowLimitPerDriver(100000) - .setLargePartitionedMaxSizePerOperator(DataSize.of(643, KILOBYTE)); + .setLargePartitionedMaxSizePerOperator(DataSize.of(643, KILOBYTE)) + .setLargeMaxSizePerFilter(DataSize.of(3411, KILOBYTE)); assertFullMapping(properties, expected); } diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java index 218fa97204ca..16947e2619be 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java @@ -21,6 +21,7 @@ import io.trino.client.NodeVersion; import io.trino.connector.CatalogName; import io.trino.cost.StatsAndCosts; +import io.trino.execution.DynamicFilterConfig; import io.trino.execution.MockRemoteTaskFactory; import io.trino.execution.MockRemoteTaskFactory.MockRemoteTask; import io.trino.execution.NodeTaskMap; @@ -345,7 +346,7 @@ public void testNoNodes() new ConnectorAwareSplitSource(CONNECTOR_ID, createFixedSplitSource(20, TestingSplit::createRemoteSplit)), new DynamicSplitPlacementPolicy(nodeScheduler.createNodeSelector(session, Optional.of(CONNECTOR_ID)), stage::getAllTasks), 2, - new DynamicFilterService(metadata, functionManager, typeOperators), + new DynamicFilterService(metadata, functionManager, typeOperators, new DynamicFilterConfig()), new TableExecuteContextManager(), () -> false); scheduler.schedule(); @@ -486,7 +487,7 @@ public void testNewTaskScheduledWhenChildStageBufferIsUnderutilized() new ConnectorAwareSplitSource(CONNECTOR_ID, createFixedSplitSource(500, TestingSplit::createRemoteSplit)), new DynamicSplitPlacementPolicy(nodeScheduler.createNodeSelector(session, Optional.of(CONNECTOR_ID)), stage::getAllTasks), 500, - new DynamicFilterService(metadata, functionManager, typeOperators), + new DynamicFilterService(metadata, functionManager, typeOperators, new DynamicFilterConfig()), new TableExecuteContextManager(), () -> false); @@ -530,7 +531,7 @@ public void testNoNewTaskScheduledWhenChildStageBufferIsOverutilized() new ConnectorAwareSplitSource(CONNECTOR_ID, createFixedSplitSource(400, TestingSplit::createRemoteSplit)), new DynamicSplitPlacementPolicy(nodeScheduler.createNodeSelector(session, Optional.of(CONNECTOR_ID)), stage::getAllTasks), 400, - new DynamicFilterService(metadata, functionManager, typeOperators), + new DynamicFilterService(metadata, functionManager, typeOperators, new DynamicFilterConfig()), new TableExecuteContextManager(), () -> true); @@ -559,7 +560,7 @@ public void testDynamicFiltersUnblockedOnBlockedBuildSource() NodeTaskMap nodeTaskMap = new NodeTaskMap(finalizerService); StageExecution stage = createStageExecution(plan, nodeTaskMap); NodeScheduler nodeScheduler = new NodeScheduler(new UniformNodeSelectorFactory(nodeManager, new NodeSchedulerConfig().setIncludeCoordinator(false), nodeTaskMap)); - DynamicFilterService dynamicFilterService = new DynamicFilterService(metadata, functionManager, typeOperators); + DynamicFilterService dynamicFilterService = new DynamicFilterService(metadata, functionManager, typeOperators, new DynamicFilterConfig()); dynamicFilterService.registerQuery( QUERY_ID, TEST_SESSION, @@ -643,7 +644,7 @@ private StageScheduler getSourcePartitionedScheduler( new ConnectorAwareSplitSource(CONNECTOR_ID, splitSource), placementPolicy, splitBatchSize, - new DynamicFilterService(metadata, functionManager, typeOperators), + new DynamicFilterService(metadata, functionManager, typeOperators, new DynamicFilterConfig()), new TableExecuteContextManager(), () -> false); } diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/TestPhasedExecutionSchedule.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/TestPhasedExecutionSchedule.java index 6cb62cb5bf93..fce15e6bb340 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/TestPhasedExecutionSchedule.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/TestPhasedExecutionSchedule.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Multimap; import com.google.common.util.concurrent.ListenableFuture; +import io.trino.execution.DynamicFilterConfig; import io.trino.execution.ExecutionFailureInfo; import io.trino.execution.RemoteTask; import io.trino.execution.StageId; @@ -62,7 +63,8 @@ public class TestPhasedExecutionSchedule private final DynamicFilterService dynamicFilterService = new DynamicFilterService( createTestMetadataManager(), createTestingFunctionManager(), - new TypeOperators()); + new TypeOperators(), + new DynamicFilterConfig()); @Test public void testPartitionedJoin() diff --git a/core/trino-main/src/test/java/io/trino/server/TestDynamicFilterService.java b/core/trino-main/src/test/java/io/trino/server/TestDynamicFilterService.java index d03d56d57dda..4c31a2bd03e3 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestDynamicFilterService.java +++ b/core/trino-main/src/test/java/io/trino/server/TestDynamicFilterService.java @@ -16,8 +16,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import io.airlift.units.DataSize; import io.trino.Session; import io.trino.cost.StatsAndCosts; +import io.trino.execution.DynamicFilterConfig; import io.trino.execution.StageId; import io.trino.execution.TaskId; import io.trino.operator.RetryPolicy; @@ -27,6 +29,7 @@ import io.trino.spi.connector.TestingColumnHandle; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.predicate.ValueSet; import io.trino.sql.DynamicFilters; import io.trino.sql.planner.Partitioning; import io.trino.sql.planner.PartitioningHandle; @@ -54,8 +57,12 @@ import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.IntStream; +import java.util.stream.LongStream; +import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.slice.Slices.utf8Slice; +import static io.airlift.units.DataSize.Unit.KILOBYTE; import static io.trino.metadata.MetadataManager.createTestMetadataManager; import static io.trino.server.DynamicFilterService.DynamicFilterDomainStats; import static io.trino.server.DynamicFilterService.DynamicFiltersStats; @@ -64,6 +71,7 @@ import static io.trino.spi.predicate.Domain.multipleValues; import static io.trino.spi.predicate.Domain.none; import static io.trino.spi.predicate.Domain.singleValue; +import static io.trino.spi.predicate.Range.range; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarcharType.VARCHAR; @@ -78,6 +86,7 @@ import static io.trino.sql.planner.plan.JoinNode.Type.INNER; import static io.trino.testing.TestingHandles.TEST_TABLE_HANDLE; import static io.trino.util.DynamicFiltersTestUtil.getSimplifiedDomainString; +import static java.util.stream.Collectors.joining; import static org.assertj.core.api.Assertions.assertThat; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; @@ -843,6 +852,95 @@ public void testMultipleAttempts() getSimplifiedDomainString(4L, 6L, 3, INTEGER)))); } + @Test + public void testSizeLimit() + { + DataSize sizeLimit = DataSize.of(1, KILOBYTE); + DynamicFilterConfig config = new DynamicFilterConfig(); + config.setSmallMaxSizePerFilter(sizeLimit); + DynamicFilterService dynamicFilterService = new DynamicFilterService( + PLANNER_CONTEXT.getMetadata(), + PLANNER_CONTEXT.getFunctionManager(), + PLANNER_CONTEXT.getTypeOperators(), + config); + + QueryId queryId = new QueryId("query"); + StageId stage1 = new StageId(queryId, 0); + StageId stage2 = new StageId(queryId, 1); + StageId stage3 = new StageId(queryId, 3); + StageId stage4 = new StageId(queryId, 3); + DynamicFilterId compactFilter = new DynamicFilterId("compact"); + DynamicFilterId largeFilter = new DynamicFilterId("large"); + DynamicFilterId replicatedFilter1 = new DynamicFilterId("replicated1"); + DynamicFilterId replicatedFilter2 = new DynamicFilterId("replicated2"); + + dynamicFilterService.registerQuery( + queryId, + session, + ImmutableSet.of(compactFilter, largeFilter, replicatedFilter1, replicatedFilter2), + ImmutableSet.of(compactFilter, largeFilter, replicatedFilter1, replicatedFilter2), + ImmutableSet.of(replicatedFilter1, replicatedFilter2)); + + Domain domain1 = Domain.multipleValues(VARCHAR, LongStream.range(0, 5) + .mapToObj(i -> utf8Slice("value" + i)) + .collect(toImmutableList())); + Domain domain2 = Domain.multipleValues(VARCHAR, LongStream.range(6, 31) + .mapToObj(i -> utf8Slice("value" + i)) + .collect(toImmutableList())); + Domain domain3 = Domain.singleValue(VARCHAR, utf8Slice(IntStream.range(0, 800) + .mapToObj(i -> "x") + .collect(joining()))); + assertThat(domain1.getRetainedSizeInBytes()).isLessThan(sizeLimit.toBytes()); + assertThat(domain1.union(domain2).getRetainedSizeInBytes()).isGreaterThanOrEqualTo(sizeLimit.toBytes()); + assertThat(domain1.union(domain2).union(domain3).simplify(1).getRetainedSizeInBytes()) + .isGreaterThanOrEqualTo(sizeLimit.toBytes()); + + // test filter compaction + dynamicFilterService.addTaskDynamicFilters( + new TaskId(stage1, 0, 0), + ImmutableMap.of(compactFilter, domain1)); + assertThat(dynamicFilterService.getSummary(queryId, compactFilter)).isNotPresent(); + dynamicFilterService.addTaskDynamicFilters( + new TaskId(stage1, 1, 0), + ImmutableMap.of(compactFilter, domain2)); + assertThat(dynamicFilterService.getSummary(queryId, compactFilter)).isNotPresent(); + dynamicFilterService.stageCannotScheduleMoreTasks(stage1, 0, 2); + assertThat(dynamicFilterService.getSummary(queryId, compactFilter)).isPresent(); + Domain compactFilterSummary = dynamicFilterService.getSummary(queryId, compactFilter).get(); + assertEquals(compactFilterSummary.getValues(), ValueSet.ofRanges(range(VARCHAR, utf8Slice("value0"), true, utf8Slice("value9"), true))); + + // test size limit exceeded after compaction + dynamicFilterService.addTaskDynamicFilters( + new TaskId(stage2, 0, 0), + ImmutableMap.of(largeFilter, domain1)); + assertThat(dynamicFilterService.getSummary(queryId, largeFilter)).isNotPresent(); + dynamicFilterService.addTaskDynamicFilters( + new TaskId(stage2, 1, 0), + ImmutableMap.of(largeFilter, domain2)); + assertThat(dynamicFilterService.getSummary(queryId, largeFilter)).isNotPresent(); + dynamicFilterService.addTaskDynamicFilters( + new TaskId(stage2, 2, 0), + ImmutableMap.of(largeFilter, domain3)); + assertThat(dynamicFilterService.getSummary(queryId, largeFilter)).isPresent(); + assertEquals(dynamicFilterService.getSummary(queryId, largeFilter).get(), Domain.all(VARCHAR)); + + // test compaction for replicated filter + dynamicFilterService.addTaskDynamicFilters( + new TaskId(stage3, 0, 0), + ImmutableMap.of(replicatedFilter1, domain1.union(domain2))); + assertThat(dynamicFilterService.getSummary(queryId, replicatedFilter1)).isPresent(); + assertEquals( + dynamicFilterService.getSummary(queryId, replicatedFilter1).get().getValues(), + ValueSet.ofRanges(range(VARCHAR, utf8Slice("value0"), true, utf8Slice("value9"), true))); + + // test size limit exceeded for replicated filter + dynamicFilterService.addTaskDynamicFilters( + new TaskId(stage4, 0, 0), + ImmutableMap.of(replicatedFilter2, domain1.union(domain2).union(domain3))); + assertThat(dynamicFilterService.getSummary(queryId, replicatedFilter2)).isPresent(); + assertEquals(dynamicFilterService.getSummary(queryId, replicatedFilter2).get(), Domain.all(VARCHAR)); + } + @Test public void testCollectMoreThanOnceForTheSameTask() { @@ -884,7 +982,8 @@ private static DynamicFilterService createDynamicFilterService() return new DynamicFilterService( PLANNER_CONTEXT.getMetadata(), PLANNER_CONTEXT.getFunctionManager(), - PLANNER_CONTEXT.getTypeOperators()); + PLANNER_CONTEXT.getTypeOperators(), + new DynamicFilterConfig()); } private static PlanFragment createPlan(DynamicFilterId dynamicFilterId, PartitioningHandle stagePartitioning, ExchangeNode.Type exchangeType) diff --git a/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java b/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java index adfb560f0d4c..7b65d49169e2 100644 --- a/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java +++ b/core/trino-main/src/test/java/io/trino/server/remotetask/TestHttpRemoteTask.java @@ -31,6 +31,7 @@ import io.trino.block.BlockJsonSerde; import io.trino.client.NodeVersion; import io.trino.connector.CatalogName; +import io.trino.execution.DynamicFilterConfig; import io.trino.execution.DynamicFiltersCollector.VersionedDynamicFilterDomains; import io.trino.execution.NodeTaskMap; import io.trino.execution.QueryManagerConfig; @@ -213,7 +214,8 @@ public void testDynamicFilters() DynamicFilterService dynamicFilterService = new DynamicFilterService( PLANNER_CONTEXT.getMetadata(), PLANNER_CONTEXT.getFunctionManager(), - new TypeOperators()); + new TypeOperators(), + new DynamicFilterConfig()); HttpRemoteTaskFactory httpRemoteTaskFactory = createHttpRemoteTaskFactory(testingTaskResource, dynamicFilterService); RemoteTask remoteTask = createRemoteTask(httpRemoteTaskFactory, ImmutableSet.of()); @@ -292,7 +294,8 @@ public void testOutboundDynamicFilters() DynamicFilterService dynamicFilterService = new DynamicFilterService( PLANNER_CONTEXT.getMetadata(), PLANNER_CONTEXT.getFunctionManager(), - new TypeOperators()); + new TypeOperators(), + new DynamicFilterConfig()); dynamicFilterService.registerQuery( queryId, TEST_SESSION, @@ -428,7 +431,8 @@ private static HttpRemoteTaskFactory createHttpRemoteTaskFactory(TestingTaskReso return createHttpRemoteTaskFactory(testingTaskResource, new DynamicFilterService( PLANNER_CONTEXT.getMetadata(), PLANNER_CONTEXT.getFunctionManager(), - new TypeOperators())); + new TypeOperators(), + new DynamicFilterConfig())); } private static HttpRemoteTaskFactory createHttpRemoteTaskFactory(TestingTaskResource testingTaskResource, DynamicFilterService dynamicFilterService) From fdc9b1f1e2ee3cbf330afb49262a7ce99dcce4f7 Mon Sep 17 00:00:00 2001 From: Andrii Rosa Date: Fri, 24 Jun 2022 00:37:27 -0400 Subject: [PATCH 3/7] Use roaring bitmap to store collected tasks A set of collected tasks is stored for each dynamic filter. When the number of dynamic filters collected is high and there are thousands of tasks a naive way of storing collected tasks (via a Set of Integer) could create a significant memory overhead. --- core/trino-main/pom.xml | 5 +++++ .../java/io/trino/server/DynamicFilterService.java | 10 +++++++--- pom.xml | 6 ++++++ 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/core/trino-main/pom.xml b/core/trino-main/pom.xml index dd69613bff5b..3b2f57ace22f 100644 --- a/core/trino-main/pom.xml +++ b/core/trino-main/pom.xml @@ -344,6 +344,11 @@ pcollections + + org.roaringbitmap + RoaringBitmap + + org.weakref jmxutils diff --git a/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java b/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java index 243af6424319..f620a18fa297 100644 --- a/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java +++ b/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java @@ -51,6 +51,7 @@ import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.SemiJoinNode; +import org.roaringbitmap.RoaringBitmap; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; @@ -662,7 +663,8 @@ private static class DynamicFilterCollectionContext { private final boolean replicated; private final long domainSizeLimitInBytes; - private final Set collectedTasks = newConcurrentHashSet(); + @GuardedBy("collectedTasks") + private final RoaringBitmap collectedTasks = new RoaringBitmap(); private final Queue summaryDomains = new ConcurrentLinkedQueue<>(); @GuardedBy("this") @@ -719,8 +721,10 @@ private void collectReplicated(Domain domain) private void collectPartitioned(TaskId taskId, Domain domain) { - if (!collectedTasks.add(taskId.getPartitionId())) { - return; + synchronized (collectedTasks) { + if (!collectedTasks.checkedAdd(taskId.getPartitionId())) { + return; + } } summaryDomains.add(domain); diff --git a/pom.xml b/pom.xml index 9c9ef6694156..f3a6060773b7 100644 --- a/pom.xml +++ b/pom.xml @@ -1658,6 +1658,12 @@ 42.3.4 + + org.roaringbitmap + RoaringBitmap + 0.9.25 + + org.sonatype.aether aether-api From e1bd0d21b90f389b504b91b8d63bfec9e300ca3f Mon Sep 17 00:00:00 2001 From: Andrii Rosa Date: Mon, 27 Jun 2022 12:39:49 -0400 Subject: [PATCH 4/7] Optimize union of domains in DynamicFilterService Union domains only when the size limit is exceeded --- .../io/trino/server/DynamicFilterService.java | 57 +++++++++++++------ 1 file changed, 40 insertions(+), 17 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java b/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java index f620a18fa297..0412a053a0a7 100644 --- a/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java +++ b/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java @@ -56,6 +56,7 @@ import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; +import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; @@ -66,6 +67,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; @@ -89,6 +91,7 @@ import static io.airlift.concurrent.MoreFutures.whenAnyComplete; import static io.trino.SystemSessionProperties.isEnableLargeDynamicFilters; import static io.trino.spi.connector.DynamicFilter.EMPTY; +import static io.trino.spi.predicate.Domain.union; import static io.trino.sql.DynamicFilters.extractDynamicFilters; import static io.trino.sql.DynamicFilters.extractSourceSymbols; import static io.trino.sql.planner.DomainCoercer.applySaturatedCasts; @@ -666,6 +669,7 @@ private static class DynamicFilterCollectionContext @GuardedBy("collectedTasks") private final RoaringBitmap collectedTasks = new RoaringBitmap(); private final Queue summaryDomains = new ConcurrentLinkedQueue<>(); + private final AtomicLong summaryDomainsRetainedSizeInBytes = new AtomicLong(); @GuardedBy("this") private volatile Integer expectedTaskCount; @@ -727,8 +731,9 @@ private void collectPartitioned(TaskId taskId, Domain domain) } } + summaryDomainsRetainedSizeInBytes.addAndGet(domain.getRetainedSizeInBytes()); summaryDomains.add(domain); - unionSummaryDomains(); + unionSummaryDomainsIfNecessary(false); Domain result; synchronized (this) { @@ -739,7 +744,7 @@ private void collectPartitioned(TaskId taskId, Domain domain) boolean allPartitionsCollected = expectedTaskCount != null && expectedTaskCount == collectedTaskCount; if (allPartitionsCollected) { // run final compaction as previous concurrent compactions may have left more than a single domain - unionSummaryDomains(); + unionSummaryDomainsIfNecessary(true); } boolean sizeLimitExceeded = false; @@ -747,14 +752,17 @@ private void collectPartitioned(TaskId taskId, Domain domain) Domain summary = summaryDomains.poll(); // summary can be null as another concurrent summary compaction may be running if (summary != null) { + long originalSize = summary.getRetainedSizeInBytes(); if (summary.getRetainedSizeInBytes() > domainSizeLimitInBytes) { summary = summary.simplify(1); } if (summary.getRetainedSizeInBytes() > domainSizeLimitInBytes) { sizeLimitExceeded = true; allDomain = Domain.all(summary.getType()); + summaryDomainsRetainedSizeInBytes.addAndGet(-originalSize); } else { + summaryDomainsRetainedSizeInBytes.addAndGet(summary.getRetainedSizeInBytes() - originalSize); summaryDomains.add(summary); } } @@ -775,30 +783,42 @@ else if (domain.isAll()) { int summaryDomainsCount = summaryDomains.size(); verify(summaryDomainsCount == 1, "summaryDomainsCount is expected to be equal to 1, got: %s", summaryDomainsCount); result = summaryDomains.poll(); + verify(result != null); + long currentSize = summaryDomainsRetainedSizeInBytes.addAndGet(-result.getRetainedSizeInBytes()); + verify(currentSize == 0, "currentSize is expected to be zero: %s", currentSize); } } - verify(result != null); collectionDuration.set(Duration.succinctNanos(System.nanoTime() - start)); collectedDomainsFuture.set(result); } - private void unionSummaryDomains() + private void unionSummaryDomainsIfNecessary(boolean force) { + if (summaryDomainsRetainedSizeInBytes.get() < domainSizeLimitInBytes && !force) { + return; + } + + List domains = new ArrayList<>(); + long domainsRetainedSizeInBytes = 0; while (true) { - // This method is called every time a new domain is added to the summaryDomains queue. - // In a normal situation (when there's no race) there should be no more than 2 domains in the queue. - Domain first = summaryDomains.poll(); - if (first == null) { - return; - } - Domain second = summaryDomains.poll(); - if (second == null) { - summaryDomains.add(first); - return; + Domain domain = summaryDomains.poll(); + if (domain == null) { + break; } - summaryDomains.add(first.union(second)); + domains.add(domain); + domainsRetainedSizeInBytes += domain.getRetainedSizeInBytes(); } + + if (domains.isEmpty()) { + return; + } + + Domain union = union(domains); + summaryDomainsRetainedSizeInBytes.addAndGet(union.getRetainedSizeInBytes() - domainsRetainedSizeInBytes); + long currentSize = summaryDomainsRetainedSizeInBytes.get(); + verify(currentSize >= 0, "currentSize is expected to be greater than or equal to zero: %s", currentSize); + summaryDomains.add(union); } public void setExpectedTaskCount(int count) @@ -822,12 +842,15 @@ public void setExpectedTaskCount(int count) return; } // run union one more time - unionSummaryDomains(); + unionSummaryDomainsIfNecessary(true); verify(summaryDomains.size() == 1); result = summaryDomains.poll(); + verify(result != null); + long currentSize = summaryDomainsRetainedSizeInBytes.addAndGet(-result.getRetainedSizeInBytes()); + verify(currentSize == 0, "currentSize is expected to be zero: %s", currentSize); } - verify(result != null); + collectionDuration.set(Duration.succinctNanos(System.nanoTime() - start)); collectedDomainsFuture.set(result); } From 07fce36dc4a66d85743a93a93070c98ca189f3d2 Mon Sep 17 00:00:00 2001 From: Andrii Rosa Date: Mon, 27 Jun 2022 19:16:38 -0400 Subject: [PATCH 5/7] Clear collected domains if the collection is finished In DynamicFilterCollectionContext the new domain is optimistically added to the queue. If the collection is finished in the meantime the queue has to be properly cleaned. --- .../io/trino/server/DynamicFilterService.java | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java b/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java index 0412a053a0a7..6060253963ae 100644 --- a/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java +++ b/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java @@ -738,6 +738,7 @@ private void collectPartitioned(TaskId taskId, Domain domain) Domain result; synchronized (this) { if (collected) { + clearSummaryDomains(); return; } collectedTaskCount++; @@ -776,6 +777,7 @@ private void collectPartitioned(TaskId taskId, Domain domain) result = allDomain; } else if (domain.isAll()) { + clearSummaryDomains(); result = domain; } else { @@ -821,6 +823,21 @@ private void unionSummaryDomainsIfNecessary(boolean force) summaryDomains.add(union); } + private void clearSummaryDomains() + { + long domainsRetainedSizeInBytes = 0; + while (true) { + Domain domain = summaryDomains.poll(); + if (domain == null) { + break; + } + domainsRetainedSizeInBytes += domain.getRetainedSizeInBytes(); + } + summaryDomainsRetainedSizeInBytes.addAndGet(-domainsRetainedSizeInBytes); + long currentSize = summaryDomainsRetainedSizeInBytes.get(); + verify(currentSize >= 0, "currentSize is expected to be greater than or equal to zero: %s", currentSize); + } + public void setExpectedTaskCount(int count) { if (collected || expectedTaskCount != null) { From 852448e9d3e47304f75b198b9a220eadcffccbad Mon Sep 17 00:00:00 2001 From: Andrii Rosa Date: Mon, 27 Jun 2022 18:58:13 -0400 Subject: [PATCH 6/7] Extract getRetainedSizeInBytes method in LocalDynamicFilterConsumer --- .../io/trino/sql/planner/LocalDynamicFilterConsumer.java | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalDynamicFilterConsumer.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalDynamicFilterConsumer.java index fdeafefd9a60..6a66dad23812 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalDynamicFilterConsumer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalDynamicFilterConsumer.java @@ -108,10 +108,10 @@ public void addPartition(TupleDomain domain) TupleDomain summary = summaryDomains.poll(); // summary can be null as another concurrent summary compaction may be running if (summary != null) { - if (summary.getRetainedSizeInBytes(DynamicFilterId::getRetainedSizeInBytes) > domainSizeLimitInBytes) { + if (getRetainedSizeInBytes(summary) > domainSizeLimitInBytes) { summary = summary.simplify(1); } - if (summary.getRetainedSizeInBytes(DynamicFilterId::getRetainedSizeInBytes) > domainSizeLimitInBytes) { + if (getRetainedSizeInBytes(summary) > domainSizeLimitInBytes) { sizeLimitExceeded = true; } summaryDomains.add(summary); @@ -248,4 +248,9 @@ public synchronized String toString() .add("summaryDomains", summaryDomains) .toString(); } + + private static long getRetainedSizeInBytes(TupleDomain summary) + { + return summary.getRetainedSizeInBytes(DynamicFilterId::getRetainedSizeInBytes); + } } From ad5b2c08c935bb6a3c2f4f1d231577272be7747c Mon Sep 17 00:00:00 2001 From: Andrii Rosa Date: Mon, 27 Jun 2022 19:12:44 -0400 Subject: [PATCH 7/7] Optimize union of domains in LocalDynamicFilterConsumer Union domains only when the size limit is exceeded --- .../planner/LocalDynamicFilterConsumer.java | 76 ++++++++++++++----- 1 file changed, 58 insertions(+), 18 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalDynamicFilterConsumer.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalDynamicFilterConsumer.java index 6a66dad23812..5c8e1247bbd7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalDynamicFilterConsumer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalDynamicFilterConsumer.java @@ -25,12 +25,14 @@ import javax.annotation.concurrent.GuardedBy; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Queue; import java.util.Set; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import static com.google.common.base.MoreObjects.toStringHelper; @@ -61,6 +63,7 @@ public class LocalDynamicFilterConsumer private volatile boolean collected; private final Queue> summaryDomains = new ConcurrentLinkedQueue<>(); + private final AtomicLong summaryDomainsRetainedSizeInBytes = new AtomicLong(); public LocalDynamicFilterConsumer(Map buildChannels, Map filterBuildTypes, List>> collectors, DataSize domainSizeLimit) { @@ -80,20 +83,22 @@ public void addPartition(TupleDomain domain) return; } + long domainRetainedSizeInBytes = getRetainedSizeInBytes(domain); + summaryDomainsRetainedSizeInBytes.addAndGet(domainRetainedSizeInBytes); summaryDomains.add(domain); // Operators collecting dynamic filters tend to finish all at the same time // when filters are collected right before the HashBuilderOperator. // To avoid multiple task executor threads being blocked on waiting // for each other when collecting the filters run the heavy union operation // outside the lock. - unionSummaryDomains(); + unionSummaryDomainsIfNecessary(false); TupleDomain result; synchronized (this) { verify(expectedPartitionCount == null || collectedPartitionCount < expectedPartitionCount); if (collected) { - summaryDomains.clear(); + clearSummaryDomains(); return; } collectedPartitionCount++; @@ -101,20 +106,25 @@ public void addPartition(TupleDomain domain) boolean allPartitionsCollected = expectedPartitionCount != null && collectedPartitionCount == expectedPartitionCount; if (allPartitionsCollected) { // run final compaction as previous concurrent compactions may have left more than a single domain - unionSummaryDomains(); + unionSummaryDomainsIfNecessary(true); } boolean sizeLimitExceeded = false; TupleDomain summary = summaryDomains.poll(); // summary can be null as another concurrent summary compaction may be running if (summary != null) { - if (getRetainedSizeInBytes(summary) > domainSizeLimitInBytes) { + long originalSize = getRetainedSizeInBytes(summary); + if (originalSize > domainSizeLimitInBytes) { summary = summary.simplify(1); } if (getRetainedSizeInBytes(summary) > domainSizeLimitInBytes) { + summaryDomainsRetainedSizeInBytes.addAndGet(-originalSize); sizeLimitExceeded = true; } - summaryDomains.add(summary); + else { + summaryDomainsRetainedSizeInBytes.addAndGet(getRetainedSizeInBytes(summary) - originalSize); + summaryDomains.add(summary); + } } if (!allPartitionsCollected && !sizeLimitExceeded && !domain.isAll()) { @@ -122,7 +132,7 @@ public void addPartition(TupleDomain domain) } if (sizeLimitExceeded || domain.isAll()) { - summaryDomains.clear(); + clearSummaryDomains(); result = TupleDomain.all(); } else { @@ -130,6 +140,8 @@ public void addPartition(TupleDomain domain) verify(summaryDomains.size() == 1); result = summaryDomains.poll(); verify(result != null); + long currentSize = summaryDomainsRetainedSizeInBytes.addAndGet(-getRetainedSizeInBytes(result)); + verify(currentSize == 0, "currentSize is expected to be zero: %s", currentSize); } collected = true; } @@ -155,10 +167,12 @@ public void setPartitionCount(int partitionCount) } else { // run final compaction as previous concurrent compactions may have left more than a single domain - unionSummaryDomains(); + unionSummaryDomainsIfNecessary(true); verify(summaryDomains.size() == 1); result = summaryDomains.poll(); verify(result != null); + long currentSize = summaryDomainsRetainedSizeInBytes.addAndGet(-getRetainedSizeInBytes(result)); + verify(currentSize == 0, "currentSize is expected to be zero: %s", currentSize); } collected = true; } @@ -166,22 +180,47 @@ public void setPartitionCount(int partitionCount) collectors.forEach(collector -> collector.accept(convertTupleDomain(result))); } - private void unionSummaryDomains() + private void unionSummaryDomainsIfNecessary(boolean force) { + if (summaryDomainsRetainedSizeInBytes.get() < domainSizeLimitInBytes && !force) { + return; + } + + List> domains = new ArrayList<>(); + long domainsRetainedSizeInBytes = 0; while (true) { - // This method is called every time a new domain is added to the summaryDomains queue. - // In a normal situation (when there's no race) there should be no more than 2 domains in the queue. - TupleDomain first = summaryDomains.poll(); - if (first == null) { - return; + TupleDomain domain = summaryDomains.poll(); + if (domain == null) { + break; } - TupleDomain second = summaryDomains.poll(); - if (second == null) { - summaryDomains.add(first); - return; + domains.add(domain); + domainsRetainedSizeInBytes += getRetainedSizeInBytes(domain); + } + + if (domains.isEmpty()) { + return; + } + + TupleDomain union = columnWiseUnion(domains); + summaryDomainsRetainedSizeInBytes.addAndGet(getRetainedSizeInBytes(union) - domainsRetainedSizeInBytes); + long currentSize = summaryDomainsRetainedSizeInBytes.get(); + verify(currentSize >= 0, "currentSize is expected to be greater than or equal to zero: %s", currentSize); + summaryDomains.add(union); + } + + private void clearSummaryDomains() + { + long domainsRetainedSizeInBytes = 0; + while (true) { + TupleDomain domain = summaryDomains.poll(); + if (domain == null) { + break; } - summaryDomains.add(columnWiseUnion(first, second)); + domainsRetainedSizeInBytes += getRetainedSizeInBytes(domain); } + summaryDomainsRetainedSizeInBytes.addAndGet(-domainsRetainedSizeInBytes); + long currentSize = summaryDomainsRetainedSizeInBytes.get(); + verify(currentSize >= 0, "currentSize is expected to be greater than or equal to zero: %s", currentSize); } private Map convertTupleDomain(TupleDomain result) @@ -246,6 +285,7 @@ public synchronized String toString() .add("collectedPartitionCount", collectedPartitionCount) .add("collected", collected) .add("summaryDomains", summaryDomains) + .add("summaryDomainsRetainedSizeInBytes", summaryDomainsRetainedSizeInBytes) .toString(); }