From 526d5d097a62e731bcd3e302ad57609b9305c411 Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Tue, 1 Feb 2022 16:13:42 +0100 Subject: [PATCH 1/2] Add missing @Language annotation --- .../plugin/iceberg/TestIcebergDynamicPartitionPruningTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergDynamicPartitionPruningTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergDynamicPartitionPruningTest.java index b02965a40ef9..88d752e93b86 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergDynamicPartitionPruningTest.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergDynamicPartitionPruningTest.java @@ -38,7 +38,7 @@ protected QueryRunner createQueryRunner() @Override protected void createLineitemTable(String tableName, List columns, List partitionColumns) { - String sql = format( + @Language("SQL") String sql = format( "CREATE TABLE %s WITH (partitioning=array[%s]) AS SELECT %s FROM tpch.tiny.lineitem", tableName, partitionColumns.stream().map(column -> "'" + column + "'").collect(joining(",")), From bce823e216c207de95d656cd32e474997f96fdb9 Mon Sep 17 00:00:00 2001 From: Karol Sobczak Date: Wed, 2 Feb 2022 12:10:33 +0100 Subject: [PATCH 2/2] Schedule dynamic filtering collecting task immediately In case of plan J1 / \ J2 S3 / \ S1 S2 It might happen that dynamic filtering evaluation order is: S3 => S2 => S1 With the phased scheduler, the source stage consisting of (J1, J2, S1) won't be scheduled until stages running S3 and S2 have finished split enumeration. However, it might happen that S2 is waiting for dynamic filters produced for S3. In that case, S2 will never complete because DFs for S3 are collected in stage (J1, J2, S1) which won't be scheduled until all S2 splits are enumerated. This commit schedules the DF collecting task immediately, which prevents queries from deadlocking. 
--- .../FixedSourcePartitionedScheduler.java | 6 ++ .../scheduler/SourcePartitionedScheduler.java | 29 ++++-- .../execution/scheduler/SourceScheduler.java | 2 + .../scheduler/SqlQueryScheduler.java | 1 + .../execution/scheduler/StageScheduler.java | 11 +++ .../policy/PhasedExecutionPolicy.java | 15 ++- .../policy/PhasedExecutionSchedule.java | 36 ++++--- .../io/trino/server/DynamicFilterService.java | 16 ++- .../TestSourcePartitionedScheduler.java | 11 ++- .../policy/TestPhasedExecutionSchedule.java | 19 ++-- .../TestHiveDynamicPartitionPruningTest.java | 23 +++++ ...estIcebergDynamicPartitionPruningTest.java | 26 +++++ .../BaseDynamicPartitionPruningTest.java | 99 ++++++++++++++++++- 13 files changed, 263 insertions(+), 31 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java index b8ed401e3c29..a74941fd1c84 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java @@ -322,6 +322,12 @@ public AsGroupedSourceScheduler(SourceScheduler sourceScheduler) pendingCompleted = new ArrayList<>(); } + @Override + public void start() + { + sourceScheduler.start(); + } + @Override public ScheduleResult schedule() { diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/SourcePartitionedScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/SourcePartitionedScheduler.java index cee23ccc4636..42ab0682953a 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/SourcePartitionedScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/SourcePartitionedScheduler.java @@ -176,6 +176,12 @@ public static StageScheduler newSourcePartitionedSchedulerAsStageScheduler( return new StageScheduler() { + 
@Override + public void start() + { + sourcePartitionedScheduler.start(); + } + @Override public ScheduleResult schedule() { @@ -251,6 +257,18 @@ public synchronized void noMoreLifespans() whenFinishedOrNewLifespanAdded = SettableFuture.create(); } + @Override + public synchronized void start() + { + // Avoid deadlocks by immediately scheduling a task for collecting dynamic filters because: + // * there can be a task in another stage blocked waiting for the dynamic filters, or + // * connector split source for this stage might be blocked waiting for the dynamic filters. + if (dynamicFilterService.isCollectingTaskNeeded(stageExecution.getStageId().getQueryId(), stageExecution.getFragment())) { + stageExecution.beginScheduling(); + createTaskOnRandomNode(); + } + } + @Override public synchronized ScheduleResult schedule() { @@ -406,13 +424,6 @@ else if (pendingSplits.isEmpty()) { return new ScheduleResult(false, overallNewTasks.build(), overallSplitAssignmentCount); } - if (anyBlockedOnNextSplitBatch - && scheduledTasks.isEmpty() - && dynamicFilterService.isCollectingTaskNeeded(stageExecution.getStageId().getQueryId(), stageExecution.getFragment())) { - // schedule a task for collecting dynamic filters in case probe split generator is waiting for them - createTaskOnRandomNode().ifPresent(overallNewTasks::add); - } - boolean anySourceTaskBlocked = this.anySourceTaskBlocked.getAsBoolean(); if (anySourceTaskBlocked) { // Dynamic filters might not be collected due to build side source tasks being blocked on full buffer. 
@@ -541,13 +552,13 @@ private Set assignSplits(Multimap splitAssignme return newTasks.build(); } - private Optional createTaskOnRandomNode() + private void createTaskOnRandomNode() { checkState(scheduledTasks.isEmpty(), "Stage task is already scheduled on node"); List allNodes = splitPlacementPolicy.allNodes(); checkState(allNodes.size() > 0, "No nodes available"); InternalNode node = allNodes.get(ThreadLocalRandom.current().nextInt(0, allNodes.size())); - return scheduleTask(node, ImmutableMultimap.of(), ImmutableMultimap.of()); + scheduleTask(node, ImmutableMultimap.of(), ImmutableMultimap.of()); } private Set finalizeTaskCreationIfNecessary() diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/SourceScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/SourceScheduler.java index e78817c17f8b..f675eeb59a99 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/SourceScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/SourceScheduler.java @@ -22,6 +22,8 @@ public interface SourceScheduler { + void start(); + ScheduleResult schedule(); void close(); diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java index c2a1a774ab53..444d8fe63c06 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java @@ -1543,6 +1543,7 @@ public void schedule() checkState(started.compareAndSet(false, true), "already started"); try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) { + stageSchedulers.values().forEach(StageScheduler::start); while (!executionSchedule.isFinished()) { List> blockedStages = new ArrayList<>(); StagesScheduleResult stagesScheduleResult = executionSchedule.getStagesToSchedule(); diff --git 
a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageScheduler.java index aeba9632946d..3f1d9537a040 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageScheduler.java @@ -18,6 +18,17 @@ public interface StageScheduler extends Closeable { + /** + * Called by the query scheduler when the scheduling process begins. + * This method is called before the ExecutionSchedule takes a decision + * to schedule a stage but after the query scheduling has been fully initialized. + * Within this method the scheduler may decide to schedule tasks that + * are necessary for query execution to make progress. + * For example the scheduler may decide to schedule a task without + * assigning any splits to unblock dynamic filter collection. + */ + default void start() {} + /** * Schedules as much work as possible without blocking. 
* The schedule results is a hint to the query scheduler if and diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/policy/PhasedExecutionPolicy.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/policy/PhasedExecutionPolicy.java index 3e49e510814f..5f0ce506b5af 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/policy/PhasedExecutionPolicy.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/policy/PhasedExecutionPolicy.java @@ -14,15 +14,28 @@ package io.trino.execution.scheduler.policy; import io.trino.execution.scheduler.StageExecution; +import io.trino.server.DynamicFilterService; + +import javax.inject.Inject; import java.util.Collection; +import static java.util.Objects.requireNonNull; + public class PhasedExecutionPolicy implements ExecutionPolicy { + private final DynamicFilterService dynamicFilterService; + + @Inject + public PhasedExecutionPolicy(DynamicFilterService dynamicFilterService) + { + this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); + } + @Override public ExecutionSchedule createExecutionSchedule(Collection stages) { - return PhasedExecutionSchedule.forStages(stages); + return PhasedExecutionSchedule.forStages(stages, dynamicFilterService); } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/policy/PhasedExecutionSchedule.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/policy/PhasedExecutionSchedule.java index 6ada9758b5b0..2ef61b10dc1e 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/policy/PhasedExecutionSchedule.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/policy/PhasedExecutionSchedule.java @@ -20,6 +20,8 @@ import com.google.common.util.concurrent.SettableFuture; import io.trino.execution.scheduler.StageExecution; import io.trino.execution.scheduler.StageExecution.State; +import io.trino.server.DynamicFilterService; +import 
io.trino.spi.QueryId; import io.trino.sql.planner.PlanFragment; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.ExchangeNode; @@ -88,6 +90,7 @@ public class PhasedExecutionSchedule private final List sortedFragments = new ArrayList<>(); private final Map stagesByFragmentId; private final Set activeStages = new LinkedHashSet<>(); + private final DynamicFilterService dynamicFilterService; /** * Set by {@link PhasedExecutionSchedule#init(Collection)} method. @@ -97,28 +100,26 @@ public class PhasedExecutionSchedule @GuardedBy("this") private SettableFuture rescheduleFuture = SettableFuture.create(); - public static PhasedExecutionSchedule forStages(Collection stages) + public static PhasedExecutionSchedule forStages(Collection stages, DynamicFilterService dynamicFilterService) { - PhasedExecutionSchedule schedule = new PhasedExecutionSchedule(stages); + PhasedExecutionSchedule schedule = new PhasedExecutionSchedule(stages, dynamicFilterService); schedule.init(stages); return schedule; } - private PhasedExecutionSchedule(Collection stages) + private PhasedExecutionSchedule(Collection stages, DynamicFilterService dynamicFilterService) { fragmentDependency = new DefaultDirectedGraph<>(new FragmentsEdgeFactory()); fragmentTopology = new DefaultDirectedGraph<>(new FragmentsEdgeFactory()); stagesByFragmentId = stages.stream() .collect(toImmutableMap(stage -> stage.getFragment().getId(), identity())); + this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); } private void init(Collection stages) { ImmutableSet.Builder fragmentsToExecute = ImmutableSet.builder(); - fragmentsToExecute.addAll(extractDependenciesAndReturnNonLazyFragments( - stages.stream() - .map(StageExecution::getFragment) - .collect(toImmutableList()))); + fragmentsToExecute.addAll(extractDependenciesAndReturnNonLazyFragments(stages)); // start stages without any dependencies fragmentDependency.vertexSet().stream() 
.filter(fragmentId -> fragmentDependency.inDegreeOf(fragmentId) == 0) @@ -267,13 +268,24 @@ private boolean isStageCompleted(StageExecution stage) return state == SCHEDULED || state == RUNNING || state == FLUSHING || state.isDone(); } - private Set extractDependenciesAndReturnNonLazyFragments(Collection fragments) + private Set extractDependenciesAndReturnNonLazyFragments(Collection stages) { + if (stages.isEmpty()) { + return ImmutableSet.of(); + } + + QueryId queryId = stages.stream() + .map(stage -> stage.getStageId().getQueryId()) + .findAny().orElseThrow(); + List fragments = stages.stream() + .map(StageExecution::getFragment) + .collect(toImmutableList()); + // Build a graph where the plan fragments are vertexes and the edges represent // a before -> after relationship. Destination fragment should be started only // when source fragment is completed. For example, a join hash build has an edge // to the join probe. - Visitor visitor = new Visitor(fragments); + Visitor visitor = new Visitor(queryId, fragments); visitor.processAllFragments(); // Make sure there are no strongly connected components as it would mean circular dependency between stages @@ -286,12 +298,14 @@ private Set extractDependenciesAndReturnNonLazyFragments(Collect private class Visitor extends PlanVisitor { + private final QueryId queryId; private final Map fragments; private final ImmutableSet.Builder nonLazyFragments = ImmutableSet.builder(); private final Map fragmentSubGraphs = new HashMap<>(); - public Visitor(Collection fragments) + public Visitor(QueryId queryId, Collection fragments) { + this.queryId = queryId; this.fragments = requireNonNull(fragments, "fragments is null").stream() .collect(toImmutableMap(PlanFragment::getId, identity())); } @@ -410,7 +424,7 @@ private FragmentSubGraph processJoin(boolean replicated, PlanNode probe, PlanNod addDependencyEdges(buildSubGraph.getUpstreamFragments(), probeSubGraph.getLazyUpstreamFragments()); boolean currentFragmentLazy = 
probeSubGraph.isCurrentFragmentLazy() && buildSubGraph.isCurrentFragmentLazy(); - if (replicated && currentFragmentLazy) { + if (replicated && currentFragmentLazy && !dynamicFilterService.isStageSchedulingNeededToCollectDynamicFilters(queryId, fragments.get(currentFragmentId))) { // Do not start join stage (which can also be a source stage with table scans) // for replicated join until build source stage enters FLUSHING state. // Broadcast join limit for CBO is set in such a way that build source data should diff --git a/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java b/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java index c56c7184b3e5..e96cd496727b 100644 --- a/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java +++ b/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java @@ -232,7 +232,21 @@ public boolean isCollectingTaskNeeded(QueryId queryId, PlanFragment plan) return false; } - return !getSourceStageInnerLazyDynamicFilters(plan).isEmpty(); + // dynamic filters are collected by additional task only for non-fixed source stage + return plan.getPartitioning().equals(SOURCE_DISTRIBUTION) && !getLazyDynamicFilters(plan).isEmpty(); + } + + public boolean isStageSchedulingNeededToCollectDynamicFilters(QueryId queryId, PlanFragment plan) + { + DynamicFilterContext context = dynamicFilterContexts.get(queryId); + if (context == null) { + // query has been removed or not registered (e.g dynamic filtering is disabled) + return false; + } + + // stage scheduling is not needed to collect dynamic filters for non-fixed source stage, because + // for such stage collecting task is created + return !plan.getPartitioning().equals(SOURCE_DISTRIBUTION) && !getLazyDynamicFilters(plan).isEmpty(); } /** diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java 
b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java index fa5182398b90..bdfc78801d52 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java @@ -88,6 +88,8 @@ import static io.trino.execution.scheduler.PipelinedStageExecution.createPipelinedStageExecution; import static io.trino.execution.scheduler.ScheduleResult.BlockedReason.SPLIT_QUEUES_FULL; import static io.trino.execution.scheduler.SourcePartitionedScheduler.newSourcePartitionedSchedulerAsStageScheduler; +import static io.trino.execution.scheduler.StageExecution.State.PLANNED; +import static io.trino.execution.scheduler.StageExecution.State.SCHEDULING; import static io.trino.metadata.MetadataManager.createTestMetadataManager; import static io.trino.operator.StageExecutionDescriptor.ungroupedExecution; import static io.trino.spi.StandardErrorCode.NO_NODES_AVAILABLE; @@ -584,6 +586,12 @@ public void testDynamicFiltersUnblockedOnBlockedBuildSource() ImmutableMap.of(symbol, new TestingColumnHandle("probeColumnA")), symbolAllocator.getTypes()); + // make sure dynamic filtering collecting task was created immediately + assertEquals(stage.getState(), PLANNED); + scheduler.start(); + assertEquals(stage.getAllTasks().size(), 1); + assertEquals(stage.getState(), SCHEDULING); + // make sure dynamic filter is initially blocked assertFalse(dynamicFilter.isBlocked().isDone()); @@ -591,8 +599,7 @@ public void testDynamicFiltersUnblockedOnBlockedBuildSource() ScheduleResult scheduleResult = scheduler.schedule(); assertTrue(dynamicFilter.isBlocked().isDone()); - // make sure that an extra task for collecting dynamic filters was created - assertEquals(scheduleResult.getNewTasks().size(), 1); + // no new probe splits should be scheduled assertEquals(scheduleResult.getSplitsScheduled(), 0); } diff --git 
a/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/TestPhasedExecutionSchedule.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/TestPhasedExecutionSchedule.java index 3cb966f3574c..a68d88372ced 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/TestPhasedExecutionSchedule.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/TestPhasedExecutionSchedule.java @@ -28,6 +28,9 @@ import io.trino.execution.scheduler.policy.PhasedExecutionSchedule.FragmentsEdge; import io.trino.metadata.InternalNode; import io.trino.metadata.Split; +import io.trino.server.DynamicFilterService; +import io.trino.spi.QueryId; +import io.trino.spi.type.TypeOperators; import io.trino.sql.planner.PlanFragment; import io.trino.sql.planner.plan.PlanFragmentId; import io.trino.sql.planner.plan.PlanNodeId; @@ -40,6 +43,7 @@ import java.util.function.Consumer; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService; import static io.trino.execution.scheduler.StageExecution.State.ABORTED; import static io.trino.execution.scheduler.StageExecution.State.FINISHED; import static io.trino.execution.scheduler.StageExecution.State.FLUSHING; @@ -48,6 +52,7 @@ import static io.trino.execution.scheduler.policy.PlanUtils.createBroadcastJoinPlanFragment; import static io.trino.execution.scheduler.policy.PlanUtils.createJoinPlanFragment; import static io.trino.execution.scheduler.policy.PlanUtils.createTableScanPlanFragment; +import static io.trino.metadata.MetadataManager.createTestMetadataManager; import static io.trino.sql.planner.plan.JoinNode.DistributionType.PARTITIONED; import static io.trino.sql.planner.plan.JoinNode.DistributionType.REPLICATED; import static io.trino.sql.planner.plan.JoinNode.Type.INNER; @@ -56,6 +61,8 @@ public class TestPhasedExecutionSchedule { + private final DynamicFilterService 
dynamicFilterService = new DynamicFilterService(createTestMetadataManager(), new TypeOperators(), newDirectExecutorService()); + @Test public void testPartitionedJoin() { @@ -67,7 +74,7 @@ public void testPartitionedJoin() TestingStageExecution probeStage = new TestingStageExecution(probeFragment); TestingStageExecution joinStage = new TestingStageExecution(joinFragment); - PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of(buildStage, probeStage, joinStage)); + PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of(buildStage, probeStage, joinStage), dynamicFilterService); assertThat(schedule.getSortedFragments()).containsExactly(buildFragment.getId(), probeFragment.getId(), joinFragment.getId()); // single dependency between build and probe stages @@ -105,7 +112,7 @@ public void testBroadcastSourceJoin() TestingStageExecution buildStage = new TestingStageExecution(buildFragment); TestingStageExecution joinSourceStage = new TestingStageExecution(joinSourceFragment); - PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of(joinSourceStage, buildStage)); + PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of(joinSourceStage, buildStage), dynamicFilterService); assertThat(schedule.getSortedFragments()).containsExactly(buildFragment.getId(), joinSourceFragment.getId()); // single dependency between build and join stages @@ -134,7 +141,7 @@ public void testAggregation() TestingStageExecution buildStage = new TestingStageExecution(buildFragment); TestingStageExecution joinStage = new TestingStageExecution(joinFragment); - PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of(sourceStage, aggregationStage, buildStage, joinStage)); + PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of(sourceStage, aggregationStage, buildStage, joinStage), dynamicFilterService); 
assertThat(schedule.getSortedFragments()).containsExactly(buildFragment.getId(), sourceFragment.getId(), aggregationFragment.getId(), joinFragment.getId()); // aggregation and source stage should start immediately, join stage should wait for build stage to complete @@ -156,7 +163,7 @@ public void testDependentStageAbortedBeforeStarted() TestingStageExecution buildStage = new TestingStageExecution(buildFragment); TestingStageExecution joinStage = new TestingStageExecution(joinFragment); - PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of(sourceStage, aggregationStage, buildStage, joinStage)); + PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of(sourceStage, aggregationStage, buildStage, joinStage), dynamicFilterService); assertThat(schedule.getSortedFragments()).containsExactly(buildFragment.getId(), sourceFragment.getId(), aggregationFragment.getId(), joinFragment.getId()); // aggregation and source stage should start immediately, join stage should wait for build stage to complete @@ -191,7 +198,7 @@ public void testStageWithBroadcastAndPartitionedJoin() TestingStageExecution joinStage = new TestingStageExecution(joinFragment); PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of( - broadcastBuildStage, partitionedBuildStage, probeStage, joinStage)); + broadcastBuildStage, partitionedBuildStage, probeStage, joinStage), dynamicFilterService); // join stage should start immediately because partitioned join forces that DirectedGraph dependencies = schedule.getFragmentDependency(); @@ -272,7 +279,7 @@ public void addStateChangeListener(StateChangeListener stateChangeListene @Override public StageId getStageId() { - throw new UnsupportedOperationException(); + return new StageId(new QueryId("id"), 0); } @Override diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveDynamicPartitionPruningTest.java 
b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveDynamicPartitionPruningTest.java index ddaf94425e92..4767c987fd5e 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveDynamicPartitionPruningTest.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveDynamicPartitionPruningTest.java @@ -47,4 +47,27 @@ protected void createLineitemTable(String tableName, List columns, List< String.join(",", columns)); getQueryRunner().execute(sql); } + + @Override + protected void createPartitionedTable(String tableName, List columns, List partitionColumns) + { + @Language("SQL") String sql = format( + "CREATE TABLE %s (%s) WITH (partitioned_by=array[%s])", + tableName, + String.join(",", columns), + partitionColumns.stream().map(column -> "'" + column + "'").collect(joining(","))); + getQueryRunner().execute(sql); + } + + @Override + protected void createPartitionedAndBucketedTable(String tableName, List columns, List partitionColumns, List bucketColumns) + { + @Language("SQL") String sql = format( + "CREATE TABLE %s (%s) WITH (partitioned_by=array[%s], bucketed_by=array[%s], bucket_count=10)", + tableName, + String.join(",", columns), + partitionColumns.stream().map(column -> "'" + column + "'").collect(joining(",")), + bucketColumns.stream().map(column -> "'" + column + "'").collect(joining(","))); + getQueryRunner().execute(sql); + } } diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergDynamicPartitionPruningTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergDynamicPartitionPruningTest.java index 88d752e93b86..c0d1fa839534 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergDynamicPartitionPruningTest.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergDynamicPartitionPruningTest.java @@ -14,8 +14,11 @@ package io.trino.plugin.iceberg; import com.google.common.collect.ImmutableMap; +import 
io.trino.FeaturesConfig.JoinDistributionType; import io.trino.testing.BaseDynamicPartitionPruningTest; import io.trino.testing.QueryRunner; +import org.intellij.lang.annotations.Language; +import org.testng.SkipException; import java.util.List; @@ -35,6 +38,12 @@ protected QueryRunner createQueryRunner() REQUIRED_TABLES); } + @Override + public void testJoinDynamicFilteringMultiJoinOnBucketedTables(JoinDistributionType joinDistributionType) + { + throw new SkipException("Iceberg does not support bucketing"); + } + @Override protected void createLineitemTable(String tableName, List columns, List partitionColumns) { @@ -45,4 +54,21 @@ protected void createLineitemTable(String tableName, List columns, List< String.join(",", columns)); getQueryRunner().execute(sql); } + + @Override + protected void createPartitionedTable(String tableName, List columns, List partitionColumns) + { + @Language("SQL") String sql = format( + "CREATE TABLE %s (%s) WITH (partitioning=array[%s])", + tableName, + String.join(",", columns), + partitionColumns.stream().map(column -> "'" + column + "'").collect(joining(","))); + getQueryRunner().execute(sql); + } + + @Override + protected void createPartitionedAndBucketedTable(String tableName, List columns, List partitionColumns, List bucketColumns) + { + throw new UnsupportedOperationException(); + } } diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseDynamicPartitionPruningTest.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseDynamicPartitionPruningTest.java index eecd18b97a05..76f6ea8b836b 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/BaseDynamicPartitionPruningTest.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseDynamicPartitionPruningTest.java @@ -16,20 +16,25 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import io.trino.FeaturesConfig.JoinDistributionType; import 
io.trino.Session; +import io.trino.execution.QueryStats; import io.trino.operator.OperatorStats; import io.trino.server.DynamicFilterService.DynamicFilterDomainStats; import io.trino.server.DynamicFilterService.DynamicFiltersStats; +import io.trino.spi.QueryId; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.ValueSet; import io.trino.tpch.TpchTable; import org.intellij.lang.annotations.Language; import org.testng.annotations.BeforeClass; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Stream; import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.FeaturesConfig.JoinDistributionType.PARTITIONED; @@ -41,11 +46,13 @@ import static io.trino.spi.predicate.Domain.singleValue; import static io.trino.spi.predicate.Range.range; import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.testing.DataProviders.toDataProvider; import static io.trino.testing.QueryAssertions.assertEqualsIgnoreOrder; import static io.trino.tpch.TpchTable.LINE_ITEM; import static io.trino.tpch.TpchTable.ORDERS; import static io.trino.tpch.TpchTable.SUPPLIER; import static io.trino.util.DynamicFiltersTestUtil.getSimplifiedDomainString; +import static java.lang.String.format; import static org.assertj.core.api.Assertions.assertThat; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -73,6 +80,10 @@ public void initTables() protected abstract void createLineitemTable(String tableName, List columns, List partitionColumns); + protected abstract void createPartitionedTable(String tableName, List columns, List partitionColumns); + + protected abstract void createPartitionedAndBucketedTable(String tableName, List columns, List partitionColumns, List bucketColumns); + @Override protected Session getSession() { @@ -416,9 +427,95 @@ public void 
testRightJoinWithNonSelectiveBuildSide() .isEqualTo(getSimplifiedDomainString(1L, 100L, 100, BIGINT)); } + @Test(timeOut = 30_000, dataProvider = "joinDistributionTypes") + public void testJoinDynamicFilteringMultiJoinOnPartitionedTables(JoinDistributionType joinDistributionType) + { + assertUpdate("DROP TABLE IF EXISTS t0_part"); + assertUpdate("DROP TABLE IF EXISTS t1_part"); + assertUpdate("DROP TABLE IF EXISTS t2_part"); + createPartitionedTable("t0_part", ImmutableList.of("v0 real", "k0 integer"), ImmutableList.of("k0")); + createPartitionedTable("t1_part", ImmutableList.of("v1 real", "i1 integer"), ImmutableList.of()); + createPartitionedTable("t2_part", ImmutableList.of("v2 real", "i2 integer", "k2 integer"), ImmutableList.of("k2")); + assertUpdate("INSERT INTO t0_part VALUES (1.0, 1), (1.0, 2)", 2); + assertUpdate("INSERT INTO t1_part VALUES (2.0, 10), (2.0, 20)", 2); + assertUpdate("INSERT INTO t2_part VALUES (3.0, 1, 1), (3.0, 2, 2)", 2); + testJoinDynamicFilteringMultiJoin(joinDistributionType, "t0_part", "t1_part", "t2_part"); + } + + @Test(timeOut = 30_000, dataProvider = "joinDistributionTypes") + public void testJoinDynamicFilteringMultiJoinOnBucketedTables(JoinDistributionType joinDistributionType) + { + assertUpdate("DROP TABLE IF EXISTS t0_bucketed"); + assertUpdate("DROP TABLE IF EXISTS t1_bucketed"); + assertUpdate("DROP TABLE IF EXISTS t2_bucketed"); + createPartitionedAndBucketedTable("t0_bucketed", ImmutableList.of("v0 real", "k0 integer"), ImmutableList.of("k0"), ImmutableList.of("v0")); + createPartitionedAndBucketedTable("t1_bucketed", ImmutableList.of("v1 real", "i1 integer"), ImmutableList.of(), ImmutableList.of("v1")); + createPartitionedAndBucketedTable("t2_bucketed", ImmutableList.of("v2 real", "i2 integer", "k2 integer"), ImmutableList.of("k2"), ImmutableList.of("v2")); + assertUpdate("INSERT INTO t0_bucketed VALUES (1.0, 1), (1.0, 2)", 2); + assertUpdate("INSERT INTO t1_bucketed VALUES (2.0, 10), (2.0, 20)", 2); + 
assertUpdate("INSERT INTO t2_bucketed VALUES (3.0, 1, 1), (3.0, 2, 2)", 2); + testJoinDynamicFilteringMultiJoin(joinDistributionType, "t0_bucketed", "t1_bucketed", "t2_bucketed"); + } + + private void testJoinDynamicFilteringMultiJoin(JoinDistributionType joinDistributionType, String t0, String t1, String t2) + { + // queries should not deadlock + + // t0 table scan depends on DFs from t1 and t2 + assertDynamicFilters( + noJoinReordering(joinDistributionType), + format("SELECT v0, v1, v2 FROM (%s JOIN %s ON k0 = i2) JOIN %s ON k0 = i1", t0, t2, t1), + 0); + + // DF evaluation order is: t1 => t2 => t0 + assertDynamicFilters( + noJoinReordering(joinDistributionType), + format("SELECT v0, v1, v2 FROM (%s JOIN %s ON k0 = i2) JOIN %s ON k2 = i1", t0, t2, t1), + 0); + + // t2 table scan depends on t1 DFs, but t0 <-> t2 join is blocked on t2 data + // "(k0 * -1) + 2 = i2)" prevents DF to be used on t0 + assertDynamicFilters( + noJoinReordering(joinDistributionType), + format("SELECT v0, v1, v2 FROM (%s JOIN %s ON (k0 * -1) + 2 = i2) JOIN %s ON k2 = i1", t0, t2, t1), + 0); + } + + private void assertDynamicFilters(Session session, @Language("SQL") String query, int expectedRowCount) + { + long filteredInputPositions = getQueryInputPositions(session, query, expectedRowCount); + long unfilteredInputPositions = getQueryInputPositions(withDynamicFilteringDisabled(session), query, 0); + + assertThat(filteredInputPositions) + .as("filtered input positions") + .isLessThan(unfilteredInputPositions); + } + + private long getQueryInputPositions(Session session, @Language("SQL") String sql, int expectedRowCount) + { + DistributedQueryRunner runner = (DistributedQueryRunner) getQueryRunner(); + ResultWithQueryId result = runner.executeWithQueryId(session, sql); + assertThat(result.getResult().getRowCount()).isEqualTo(expectedRowCount); + QueryId queryId = result.getQueryId(); + QueryStats stats = runner.getCoordinator().getQueryManager().getFullQueryInfo(queryId).getQueryStats(); + 
return stats.getPhysicalInputPositions(); + } + + @DataProvider + public Object[][] joinDistributionTypes() + { + return Stream.of(JoinDistributionType.values()) + .collect(toDataProvider()); + } + private Session withDynamicFilteringDisabled() { - return Session.builder(getSession()) + return withDynamicFilteringDisabled(getSession()); + } + + private Session withDynamicFilteringDisabled(Session session) + { + return Session.builder(session) .setSystemProperty("enable_dynamic_filtering", "false") .build(); }