diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java index b8ed401e3c29..a74941fd1c84 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java @@ -322,6 +322,12 @@ public AsGroupedSourceScheduler(SourceScheduler sourceScheduler) pendingCompleted = new ArrayList<>(); } + @Override + public void start() + { + sourceScheduler.start(); + } + @Override public ScheduleResult schedule() { diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/SourcePartitionedScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/SourcePartitionedScheduler.java index cee23ccc4636..42ab0682953a 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/SourcePartitionedScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/SourcePartitionedScheduler.java @@ -176,6 +176,12 @@ public static StageScheduler newSourcePartitionedSchedulerAsStageScheduler( return new StageScheduler() { + @Override + public void start() + { + sourcePartitionedScheduler.start(); + } + @Override public ScheduleResult schedule() { @@ -251,6 +257,18 @@ public synchronized void noMoreLifespans() whenFinishedOrNewLifespanAdded = SettableFuture.create(); } + @Override + public synchronized void start() + { + // Avoid deadlocks by immediately scheduling a task for collecting dynamic filters because: + // * there can be a task in another stage blocked waiting for the dynamic filters, or + // * the connector split source for this stage might be blocked waiting for the dynamic filters.
+ if (dynamicFilterService.isCollectingTaskNeeded(stageExecution.getStageId().getQueryId(), stageExecution.getFragment())) { + stageExecution.beginScheduling(); + createTaskOnRandomNode(); + } + } + @Override public synchronized ScheduleResult schedule() { @@ -406,13 +424,6 @@ else if (pendingSplits.isEmpty()) { return new ScheduleResult(false, overallNewTasks.build(), overallSplitAssignmentCount); } - if (anyBlockedOnNextSplitBatch - && scheduledTasks.isEmpty() - && dynamicFilterService.isCollectingTaskNeeded(stageExecution.getStageId().getQueryId(), stageExecution.getFragment())) { - // schedule a task for collecting dynamic filters in case probe split generator is waiting for them - createTaskOnRandomNode().ifPresent(overallNewTasks::add); - } - boolean anySourceTaskBlocked = this.anySourceTaskBlocked.getAsBoolean(); if (anySourceTaskBlocked) { // Dynamic filters might not be collected due to build side source tasks being blocked on full buffer. @@ -541,13 +552,13 @@ private Set assignSplits(Multimap splitAssignme return newTasks.build(); } - private Optional createTaskOnRandomNode() + private void createTaskOnRandomNode() { checkState(scheduledTasks.isEmpty(), "Stage task is already scheduled on node"); List allNodes = splitPlacementPolicy.allNodes(); checkState(allNodes.size() > 0, "No nodes available"); InternalNode node = allNodes.get(ThreadLocalRandom.current().nextInt(0, allNodes.size())); - return scheduleTask(node, ImmutableMultimap.of(), ImmutableMultimap.of()); + scheduleTask(node, ImmutableMultimap.of(), ImmutableMultimap.of()); } private Set finalizeTaskCreationIfNecessary() diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/SourceScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/SourceScheduler.java index e78817c17f8b..f675eeb59a99 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/SourceScheduler.java +++ 
b/core/trino-main/src/main/java/io/trino/execution/scheduler/SourceScheduler.java @@ -22,6 +22,8 @@ public interface SourceScheduler { + void start(); + ScheduleResult schedule(); void close(); diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java index c2a1a774ab53..444d8fe63c06 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java @@ -1543,6 +1543,7 @@ public void schedule() checkState(started.compareAndSet(false, true), "already started"); try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) { + stageSchedulers.values().forEach(StageScheduler::start); while (!executionSchedule.isFinished()) { List> blockedStages = new ArrayList<>(); StagesScheduleResult stagesScheduleResult = executionSchedule.getStagesToSchedule(); diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageScheduler.java index aeba9632946d..3f1d9537a040 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageScheduler.java @@ -18,6 +18,17 @@ public interface StageScheduler extends Closeable { + /** + * Called by the query scheduler when the scheduling process begins. + * This method is called before the ExecutionSchedule takes a decision + * to schedule a stage but after the query scheduling has been fully initialized. + * Within this method the scheduler may decide to schedule tasks that + * are necessary for query execution to make progress. + * For example the scheduler may decide to schedule a task without + * assigning any splits to unblock dynamic filter collection. 
+ */ + default void start() {} + /** * Schedules as much work as possible without blocking. * The schedule results is a hint to the query scheduler if and diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/policy/PhasedExecutionPolicy.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/policy/PhasedExecutionPolicy.java index 3e49e510814f..5f0ce506b5af 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/policy/PhasedExecutionPolicy.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/policy/PhasedExecutionPolicy.java @@ -14,15 +14,28 @@ package io.trino.execution.scheduler.policy; import io.trino.execution.scheduler.StageExecution; +import io.trino.server.DynamicFilterService; + +import javax.inject.Inject; import java.util.Collection; +import static java.util.Objects.requireNonNull; + public class PhasedExecutionPolicy implements ExecutionPolicy { + private final DynamicFilterService dynamicFilterService; + + @Inject + public PhasedExecutionPolicy(DynamicFilterService dynamicFilterService) + { + this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); + } + @Override public ExecutionSchedule createExecutionSchedule(Collection stages) { - return PhasedExecutionSchedule.forStages(stages); + return PhasedExecutionSchedule.forStages(stages, dynamicFilterService); } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/policy/PhasedExecutionSchedule.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/policy/PhasedExecutionSchedule.java index 6ada9758b5b0..2ef61b10dc1e 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/policy/PhasedExecutionSchedule.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/policy/PhasedExecutionSchedule.java @@ -20,6 +20,8 @@ import com.google.common.util.concurrent.SettableFuture; import io.trino.execution.scheduler.StageExecution; import 
io.trino.execution.scheduler.StageExecution.State; +import io.trino.server.DynamicFilterService; +import io.trino.spi.QueryId; import io.trino.sql.planner.PlanFragment; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.ExchangeNode; @@ -88,6 +90,7 @@ public class PhasedExecutionSchedule private final List sortedFragments = new ArrayList<>(); private final Map stagesByFragmentId; private final Set activeStages = new LinkedHashSet<>(); + private final DynamicFilterService dynamicFilterService; /** * Set by {@link PhasedExecutionSchedule#init(Collection)} method. @@ -97,28 +100,26 @@ public class PhasedExecutionSchedule @GuardedBy("this") private SettableFuture rescheduleFuture = SettableFuture.create(); - public static PhasedExecutionSchedule forStages(Collection stages) + public static PhasedExecutionSchedule forStages(Collection stages, DynamicFilterService dynamicFilterService) { - PhasedExecutionSchedule schedule = new PhasedExecutionSchedule(stages); + PhasedExecutionSchedule schedule = new PhasedExecutionSchedule(stages, dynamicFilterService); schedule.init(stages); return schedule; } - private PhasedExecutionSchedule(Collection stages) + private PhasedExecutionSchedule(Collection stages, DynamicFilterService dynamicFilterService) { fragmentDependency = new DefaultDirectedGraph<>(new FragmentsEdgeFactory()); fragmentTopology = new DefaultDirectedGraph<>(new FragmentsEdgeFactory()); stagesByFragmentId = stages.stream() .collect(toImmutableMap(stage -> stage.getFragment().getId(), identity())); + this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); } private void init(Collection stages) { ImmutableSet.Builder fragmentsToExecute = ImmutableSet.builder(); - fragmentsToExecute.addAll(extractDependenciesAndReturnNonLazyFragments( - stages.stream() - .map(StageExecution::getFragment) - .collect(toImmutableList()))); + 
fragmentsToExecute.addAll(extractDependenciesAndReturnNonLazyFragments(stages)); // start stages without any dependencies fragmentDependency.vertexSet().stream() .filter(fragmentId -> fragmentDependency.inDegreeOf(fragmentId) == 0) @@ -267,13 +268,24 @@ private boolean isStageCompleted(StageExecution stage) return state == SCHEDULED || state == RUNNING || state == FLUSHING || state.isDone(); } - private Set extractDependenciesAndReturnNonLazyFragments(Collection fragments) + private Set extractDependenciesAndReturnNonLazyFragments(Collection stages) { + if (stages.isEmpty()) { + return ImmutableSet.of(); + } + + QueryId queryId = stages.stream() + .map(stage -> stage.getStageId().getQueryId()) + .findAny().orElseThrow(); + List fragments = stages.stream() + .map(StageExecution::getFragment) + .collect(toImmutableList()); + // Build a graph where the plan fragments are vertexes and the edges represent // a before -> after relationship. Destination fragment should be started only // when source fragment is completed. For example, a join hash build has an edge // to the join probe. 
- Visitor visitor = new Visitor(fragments); + Visitor visitor = new Visitor(queryId, fragments); visitor.processAllFragments(); // Make sure there are no strongly connected components as it would mean circular dependency between stages @@ -286,12 +298,14 @@ private Set extractDependenciesAndReturnNonLazyFragments(Collect private class Visitor extends PlanVisitor { + private final QueryId queryId; private final Map fragments; private final ImmutableSet.Builder nonLazyFragments = ImmutableSet.builder(); private final Map fragmentSubGraphs = new HashMap<>(); - public Visitor(Collection fragments) + public Visitor(QueryId queryId, Collection fragments) { + this.queryId = requireNonNull(queryId, "queryId is null"); this.fragments = requireNonNull(fragments, "fragments is null").stream() .collect(toImmutableMap(PlanFragment::getId, identity())); } @@ -410,7 +424,7 @@ private FragmentSubGraph processJoin(boolean replicated, PlanNode probe, PlanNod addDependencyEdges(buildSubGraph.getUpstreamFragments(), probeSubGraph.getLazyUpstreamFragments()); boolean currentFragmentLazy = probeSubGraph.isCurrentFragmentLazy() && buildSubGraph.isCurrentFragmentLazy(); - if (replicated && currentFragmentLazy) { + if (replicated && currentFragmentLazy && !dynamicFilterService.isStageSchedulingNeededToCollectDynamicFilters(queryId, fragments.get(currentFragmentId))) { // Do not start join stage (which can also be a source stage with table scans) // for replicated join until build source stage enters FLUSHING state.
// Broadcast join limit for CBO is set in such a way that build source data should diff --git a/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java b/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java index c56c7184b3e5..e96cd496727b 100644 --- a/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java +++ b/core/trino-main/src/main/java/io/trino/server/DynamicFilterService.java @@ -232,7 +232,21 @@ public boolean isCollectingTaskNeeded(QueryId queryId, PlanFragment plan) return false; } - return !getSourceStageInnerLazyDynamicFilters(plan).isEmpty(); + // dynamic filters are collected by an additional task only for a non-fixed source stage + return plan.getPartitioning().equals(SOURCE_DISTRIBUTION) && !getLazyDynamicFilters(plan).isEmpty(); + } + + public boolean isStageSchedulingNeededToCollectDynamicFilters(QueryId queryId, PlanFragment plan) + { + DynamicFilterContext context = dynamicFilterContexts.get(queryId); + if (context == null) { + // query has been removed or not registered (e.g. dynamic filtering is disabled) + return false; + } + + // stage scheduling is not needed to collect dynamic filters for a non-fixed source stage, because + // a collecting task is created for such a stage + return !plan.getPartitioning().equals(SOURCE_DISTRIBUTION) && !getLazyDynamicFilters(plan).isEmpty(); } /** diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java index fa5182398b90..bdfc78801d52 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java @@ -88,6 +88,8 @@ import static io.trino.execution.scheduler.PipelinedStageExecution.createPipelinedStageExecution; import static
io.trino.execution.scheduler.ScheduleResult.BlockedReason.SPLIT_QUEUES_FULL; import static io.trino.execution.scheduler.SourcePartitionedScheduler.newSourcePartitionedSchedulerAsStageScheduler; +import static io.trino.execution.scheduler.StageExecution.State.PLANNED; +import static io.trino.execution.scheduler.StageExecution.State.SCHEDULING; import static io.trino.metadata.MetadataManager.createTestMetadataManager; import static io.trino.operator.StageExecutionDescriptor.ungroupedExecution; import static io.trino.spi.StandardErrorCode.NO_NODES_AVAILABLE; @@ -584,6 +586,12 @@ public void testDynamicFiltersUnblockedOnBlockedBuildSource() ImmutableMap.of(symbol, new TestingColumnHandle("probeColumnA")), symbolAllocator.getTypes()); + // make sure dynamic filtering collecting task was created immediately + assertEquals(stage.getState(), PLANNED); + scheduler.start(); + assertEquals(stage.getAllTasks().size(), 1); + assertEquals(stage.getState(), SCHEDULING); + // make sure dynamic filter is initially blocked assertFalse(dynamicFilter.isBlocked().isDone()); @@ -591,8 +599,7 @@ public void testDynamicFiltersUnblockedOnBlockedBuildSource() ScheduleResult scheduleResult = scheduler.schedule(); assertTrue(dynamicFilter.isBlocked().isDone()); - // make sure that an extra task for collecting dynamic filters was created - assertEquals(scheduleResult.getNewTasks().size(), 1); + // no new probe splits should be scheduled assertEquals(scheduleResult.getSplitsScheduled(), 0); } diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/TestPhasedExecutionSchedule.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/TestPhasedExecutionSchedule.java index 3cb966f3574c..a68d88372ced 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/TestPhasedExecutionSchedule.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/TestPhasedExecutionSchedule.java @@ -28,6 +28,9 @@ import 
io.trino.execution.scheduler.policy.PhasedExecutionSchedule.FragmentsEdge; import io.trino.metadata.InternalNode; import io.trino.metadata.Split; +import io.trino.server.DynamicFilterService; +import io.trino.spi.QueryId; +import io.trino.spi.type.TypeOperators; import io.trino.sql.planner.PlanFragment; import io.trino.sql.planner.plan.PlanFragmentId; import io.trino.sql.planner.plan.PlanNodeId; @@ -40,6 +43,7 @@ import java.util.function.Consumer; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService; import static io.trino.execution.scheduler.StageExecution.State.ABORTED; import static io.trino.execution.scheduler.StageExecution.State.FINISHED; import static io.trino.execution.scheduler.StageExecution.State.FLUSHING; @@ -48,6 +52,7 @@ import static io.trino.execution.scheduler.policy.PlanUtils.createBroadcastJoinPlanFragment; import static io.trino.execution.scheduler.policy.PlanUtils.createJoinPlanFragment; import static io.trino.execution.scheduler.policy.PlanUtils.createTableScanPlanFragment; +import static io.trino.metadata.MetadataManager.createTestMetadataManager; import static io.trino.sql.planner.plan.JoinNode.DistributionType.PARTITIONED; import static io.trino.sql.planner.plan.JoinNode.DistributionType.REPLICATED; import static io.trino.sql.planner.plan.JoinNode.Type.INNER; @@ -56,6 +61,8 @@ public class TestPhasedExecutionSchedule { + private final DynamicFilterService dynamicFilterService = new DynamicFilterService(createTestMetadataManager(), new TypeOperators(), newDirectExecutorService()); + @Test public void testPartitionedJoin() { @@ -67,7 +74,7 @@ public void testPartitionedJoin() TestingStageExecution probeStage = new TestingStageExecution(probeFragment); TestingStageExecution joinStage = new TestingStageExecution(joinFragment); - PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of(buildStage, probeStage, 
joinStage)); + PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of(buildStage, probeStage, joinStage), dynamicFilterService); assertThat(schedule.getSortedFragments()).containsExactly(buildFragment.getId(), probeFragment.getId(), joinFragment.getId()); // single dependency between build and probe stages @@ -105,7 +112,7 @@ public void testBroadcastSourceJoin() TestingStageExecution buildStage = new TestingStageExecution(buildFragment); TestingStageExecution joinSourceStage = new TestingStageExecution(joinSourceFragment); - PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of(joinSourceStage, buildStage)); + PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of(joinSourceStage, buildStage), dynamicFilterService); assertThat(schedule.getSortedFragments()).containsExactly(buildFragment.getId(), joinSourceFragment.getId()); // single dependency between build and join stages @@ -134,7 +141,7 @@ public void testAggregation() TestingStageExecution buildStage = new TestingStageExecution(buildFragment); TestingStageExecution joinStage = new TestingStageExecution(joinFragment); - PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of(sourceStage, aggregationStage, buildStage, joinStage)); + PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of(sourceStage, aggregationStage, buildStage, joinStage), dynamicFilterService); assertThat(schedule.getSortedFragments()).containsExactly(buildFragment.getId(), sourceFragment.getId(), aggregationFragment.getId(), joinFragment.getId()); // aggregation and source stage should start immediately, join stage should wait for build stage to complete @@ -156,7 +163,7 @@ public void testDependentStageAbortedBeforeStarted() TestingStageExecution buildStage = new TestingStageExecution(buildFragment); TestingStageExecution joinStage = new TestingStageExecution(joinFragment); - 
PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of(sourceStage, aggregationStage, buildStage, joinStage)); + PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of(sourceStage, aggregationStage, buildStage, joinStage), dynamicFilterService); assertThat(schedule.getSortedFragments()).containsExactly(buildFragment.getId(), sourceFragment.getId(), aggregationFragment.getId(), joinFragment.getId()); // aggregation and source stage should start immediately, join stage should wait for build stage to complete @@ -191,7 +198,7 @@ public void testStageWithBroadcastAndPartitionedJoin() TestingStageExecution joinStage = new TestingStageExecution(joinFragment); PhasedExecutionSchedule schedule = PhasedExecutionSchedule.forStages(ImmutableSet.of( - broadcastBuildStage, partitionedBuildStage, probeStage, joinStage)); + broadcastBuildStage, partitionedBuildStage, probeStage, joinStage), dynamicFilterService); // join stage should start immediately because partitioned join forces that DirectedGraph dependencies = schedule.getFragmentDependency(); @@ -272,7 +279,7 @@ public void addStateChangeListener(StateChangeListener stateChangeListene @Override public StageId getStageId() { - throw new UnsupportedOperationException(); + return new StageId(new QueryId("id"), 0); } @Override diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveDynamicPartitionPruningTest.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveDynamicPartitionPruningTest.java index ddaf94425e92..4767c987fd5e 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveDynamicPartitionPruningTest.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveDynamicPartitionPruningTest.java @@ -47,4 +47,27 @@ protected void createLineitemTable(String tableName, List columns, List< String.join(",", columns)); getQueryRunner().execute(sql); } + + @Override + protected void 
createPartitionedTable(String tableName, List columns, List partitionColumns) + { + @Language("SQL") String sql = format( + "CREATE TABLE %s (%s) WITH (partitioned_by=array[%s])", + tableName, + String.join(",", columns), + partitionColumns.stream().map(column -> "'" + column + "'").collect(joining(","))); + getQueryRunner().execute(sql); + } + + @Override + protected void createPartitionedAndBucketedTable(String tableName, List columns, List partitionColumns, List bucketColumns) + { + @Language("SQL") String sql = format( + "CREATE TABLE %s (%s) WITH (partitioned_by=array[%s], bucketed_by=array[%s], bucket_count=10)", + tableName, + String.join(",", columns), + partitionColumns.stream().map(column -> "'" + column + "'").collect(joining(",")), + bucketColumns.stream().map(column -> "'" + column + "'").collect(joining(","))); + getQueryRunner().execute(sql); + } } diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergDynamicPartitionPruningTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergDynamicPartitionPruningTest.java index b02965a40ef9..c0d1fa839534 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergDynamicPartitionPruningTest.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergDynamicPartitionPruningTest.java @@ -14,8 +14,11 @@ package io.trino.plugin.iceberg; import com.google.common.collect.ImmutableMap; +import io.trino.FeaturesConfig.JoinDistributionType; import io.trino.testing.BaseDynamicPartitionPruningTest; import io.trino.testing.QueryRunner; +import org.intellij.lang.annotations.Language; +import org.testng.SkipException; import java.util.List; @@ -35,14 +38,37 @@ protected QueryRunner createQueryRunner() REQUIRED_TABLES); } + @Override + public void testJoinDynamicFilteringMultiJoinOnBucketedTables(JoinDistributionType joinDistributionType) + { + throw new SkipException("Iceberg does not support bucketing"); + } + @Override 
protected void createLineitemTable(String tableName, List columns, List partitionColumns) { - String sql = format( + @Language("SQL") String sql = format( "CREATE TABLE %s WITH (partitioning=array[%s]) AS SELECT %s FROM tpch.tiny.lineitem", tableName, partitionColumns.stream().map(column -> "'" + column + "'").collect(joining(",")), String.join(",", columns)); getQueryRunner().execute(sql); } + + @Override + protected void createPartitionedTable(String tableName, List columns, List partitionColumns) + { + @Language("SQL") String sql = format( + "CREATE TABLE %s (%s) WITH (partitioning=array[%s])", + tableName, + String.join(",", columns), + partitionColumns.stream().map(column -> "'" + column + "'").collect(joining(","))); + getQueryRunner().execute(sql); + } + + @Override + protected void createPartitionedAndBucketedTable(String tableName, List columns, List partitionColumns, List bucketColumns) + { + throw new UnsupportedOperationException(); + } } diff --git a/testing/trino-testing/src/main/java/io/trino/testing/BaseDynamicPartitionPruningTest.java b/testing/trino-testing/src/main/java/io/trino/testing/BaseDynamicPartitionPruningTest.java index eecd18b97a05..76f6ea8b836b 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/BaseDynamicPartitionPruningTest.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/BaseDynamicPartitionPruningTest.java @@ -16,20 +16,25 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import io.trino.FeaturesConfig.JoinDistributionType; import io.trino.Session; +import io.trino.execution.QueryStats; import io.trino.operator.OperatorStats; import io.trino.server.DynamicFilterService.DynamicFilterDomainStats; import io.trino.server.DynamicFilterService.DynamicFiltersStats; +import io.trino.spi.QueryId; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.ValueSet; import io.trino.tpch.TpchTable; import 
org.intellij.lang.annotations.Language; import org.testng.annotations.BeforeClass; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Stream; import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.FeaturesConfig.JoinDistributionType.PARTITIONED; @@ -41,11 +46,13 @@ import static io.trino.spi.predicate.Domain.singleValue; import static io.trino.spi.predicate.Range.range; import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.testing.DataProviders.toDataProvider; import static io.trino.testing.QueryAssertions.assertEqualsIgnoreOrder; import static io.trino.tpch.TpchTable.LINE_ITEM; import static io.trino.tpch.TpchTable.ORDERS; import static io.trino.tpch.TpchTable.SUPPLIER; import static io.trino.util.DynamicFiltersTestUtil.getSimplifiedDomainString; +import static java.lang.String.format; import static org.assertj.core.api.Assertions.assertThat; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertTrue; @@ -73,6 +80,10 @@ public void initTables() protected abstract void createLineitemTable(String tableName, List columns, List partitionColumns); + protected abstract void createPartitionedTable(String tableName, List columns, List partitionColumns); + + protected abstract void createPartitionedAndBucketedTable(String tableName, List columns, List partitionColumns, List bucketColumns); + @Override protected Session getSession() { @@ -416,9 +427,95 @@ public void testRightJoinWithNonSelectiveBuildSide() .isEqualTo(getSimplifiedDomainString(1L, 100L, 100, BIGINT)); } + @Test(timeOut = 30_000, dataProvider = "joinDistributionTypes") + public void testJoinDynamicFilteringMultiJoinOnPartitionedTables(JoinDistributionType joinDistributionType) + { + assertUpdate("DROP TABLE IF EXISTS t0_part"); + assertUpdate("DROP TABLE IF EXISTS t1_part"); + assertUpdate("DROP TABLE IF 
EXISTS t2_part"); + createPartitionedTable("t0_part", ImmutableList.of("v0 real", "k0 integer"), ImmutableList.of("k0")); + createPartitionedTable("t1_part", ImmutableList.of("v1 real", "i1 integer"), ImmutableList.of()); + createPartitionedTable("t2_part", ImmutableList.of("v2 real", "i2 integer", "k2 integer"), ImmutableList.of("k2")); + assertUpdate("INSERT INTO t0_part VALUES (1.0, 1), (1.0, 2)", 2); + assertUpdate("INSERT INTO t1_part VALUES (2.0, 10), (2.0, 20)", 2); + assertUpdate("INSERT INTO t2_part VALUES (3.0, 1, 1), (3.0, 2, 2)", 2); + testJoinDynamicFilteringMultiJoin(joinDistributionType, "t0_part", "t1_part", "t2_part"); + } + + @Test(timeOut = 30_000, dataProvider = "joinDistributionTypes") + public void testJoinDynamicFilteringMultiJoinOnBucketedTables(JoinDistributionType joinDistributionType) + { + assertUpdate("DROP TABLE IF EXISTS t0_bucketed"); + assertUpdate("DROP TABLE IF EXISTS t1_bucketed"); + assertUpdate("DROP TABLE IF EXISTS t2_bucketed"); + createPartitionedAndBucketedTable("t0_bucketed", ImmutableList.of("v0 real", "k0 integer"), ImmutableList.of("k0"), ImmutableList.of("v0")); + createPartitionedAndBucketedTable("t1_bucketed", ImmutableList.of("v1 real", "i1 integer"), ImmutableList.of(), ImmutableList.of("v1")); + createPartitionedAndBucketedTable("t2_bucketed", ImmutableList.of("v2 real", "i2 integer", "k2 integer"), ImmutableList.of("k2"), ImmutableList.of("v2")); + assertUpdate("INSERT INTO t0_bucketed VALUES (1.0, 1), (1.0, 2)", 2); + assertUpdate("INSERT INTO t1_bucketed VALUES (2.0, 10), (2.0, 20)", 2); + assertUpdate("INSERT INTO t2_bucketed VALUES (3.0, 1, 1), (3.0, 2, 2)", 2); + testJoinDynamicFilteringMultiJoin(joinDistributionType, "t0_bucketed", "t1_bucketed", "t2_bucketed"); + } + + private void testJoinDynamicFilteringMultiJoin(JoinDistributionType joinDistributionType, String t0, String t1, String t2) + { + // queries should not deadlock + + // t0 table scan depends on DFs from t1 and t2 + assertDynamicFilters( + 
noJoinReordering(joinDistributionType), + format("SELECT v0, v1, v2 FROM (%s JOIN %s ON k0 = i2) JOIN %s ON k0 = i1", t0, t2, t1), + 0); + + // DF evaluation order is: t1 => t2 => t0 + assertDynamicFilters( + noJoinReordering(joinDistributionType), + format("SELECT v0, v1, v2 FROM (%s JOIN %s ON k0 = i2) JOIN %s ON k2 = i1", t0, t2, t1), + 0); + + // t2 table scan depends on t1 DFs, but t0 <-> t2 join is blocked on t2 data + // "(k0 * -1) + 2 = i2" prevents DF from being used on t0 + assertDynamicFilters( + noJoinReordering(joinDistributionType), + format("SELECT v0, v1, v2 FROM (%s JOIN %s ON (k0 * -1) + 2 = i2) JOIN %s ON k2 = i1", t0, t2, t1), + 0); + } + + private void assertDynamicFilters(Session session, @Language("SQL") String query, int expectedRowCount) + { + long filteredInputPositions = getQueryInputPositions(session, query, expectedRowCount); + long unfilteredInputPositions = getQueryInputPositions(withDynamicFilteringDisabled(session), query, expectedRowCount); + + assertThat(filteredInputPositions) + .as("filtered input positions") + .isLessThan(unfilteredInputPositions); + } + + private long getQueryInputPositions(Session session, @Language("SQL") String sql, int expectedRowCount) + { + DistributedQueryRunner runner = (DistributedQueryRunner) getQueryRunner(); + ResultWithQueryId result = runner.executeWithQueryId(session, sql); + assertThat(result.getResult().getRowCount()).isEqualTo(expectedRowCount); + QueryId queryId = result.getQueryId(); + QueryStats stats = runner.getCoordinator().getQueryManager().getFullQueryInfo(queryId).getQueryStats(); + return stats.getPhysicalInputPositions(); + } + + @DataProvider + public Object[][] joinDistributionTypes() + { + return Stream.of(JoinDistributionType.values()) + .collect(toDataProvider()); + } + private Session withDynamicFilteringDisabled() { - return Session.builder(getSession()) + return withDynamicFilteringDisabled(getSession()); + } + + private Session withDynamicFilteringDisabled(Session session) + { + return
Session.builder(session) .setSystemProperty("enable_dynamic_filtering", "false") .build(); }