From da5f103f3be5a4a2bd3b6f6e9da92c5fbe8066a4 Mon Sep 17 00:00:00 2001 From: Andrii Rosa Date: Mon, 30 Jan 2023 12:11:46 -0500 Subject: [PATCH] Remove deprecated FTE scheduler --- .../io/trino/SystemSessionProperties.java | 37 - .../trino/execution/QueryManagerConfig.java | 42 - .../io/trino/execution/SqlQueryExecution.java | 77 +- .../FaultTolerantQueryScheduler.java | 519 -------- .../FaultTolerantStageScheduler.java | 799 ----------- .../scheduler/StageTaskSourceFactory.java | 1026 --------------- .../trino/execution/scheduler/TaskSource.java | 30 - .../scheduler/TaskSourceFactory.java | 36 - .../io/trino/server/CoordinatorModule.java | 3 - .../execution/TestQueryManagerConfig.java | 9 - .../TestFaultTolerantStageScheduler.java | 1168 ----------------- .../scheduler/TestStageTaskSourceFactory.java | 949 -------------- .../execution/scheduler/TestingExchange.java | 196 --- .../scheduler/TestingSplitSource.java | 116 -- .../TestingTaskLifecycleListener.java | 58 - .../scheduler/TestingTaskSourceFactory.java | 178 --- 16 files changed, 19 insertions(+), 5224 deletions(-) delete mode 100644 core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantQueryScheduler.java delete mode 100644 core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java delete mode 100644 core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java delete mode 100644 core/trino-main/src/main/java/io/trino/execution/scheduler/TaskSource.java delete mode 100644 core/trino-main/src/main/java/io/trino/execution/scheduler/TaskSourceFactory.java delete mode 100644 core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java delete mode 100644 core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java delete mode 100644 core/trino-main/src/test/java/io/trino/execution/scheduler/TestingExchange.java delete mode 100644 core/trino-main/src/test/java/io/trino/execution/scheduler/TestingSplitSource.java delete mode 100644 core/trino-main/src/test/java/io/trino/execution/scheduler/TestingTaskLifecycleListener.java delete mode 100644 core/trino-main/src/test/java/io/trino/execution/scheduler/TestingTaskSourceFactory.java diff --git a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java index f82a609b1d90..ed5560c9038c 100644 --- a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java +++ b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java @@ -152,7 +152,6 @@ public final class SystemSessionProperties public static final String MAX_PARTIAL_TOP_N_MEMORY = "max_partial_top_n_memory"; public static final String RETRY_POLICY = "retry_policy"; public static final String QUERY_RETRY_ATTEMPTS = "query_retry_attempts"; - public static final String TASK_RETRY_ATTEMPTS_OVERALL = "task_retry_attempts_overall"; public static final String TASK_RETRY_ATTEMPTS_PER_TASK = "task_retry_attempts_per_task"; public static final String MAX_TASKS_WAITING_FOR_EXECUTION_PER_QUERY = "max_tasks_waiting_for_execution_per_query"; public static final String MAX_TASKS_WAITING_FOR_NODE_PER_STAGE = "max_tasks_waiting_for_node_per_stage"; @@ -161,7 +160,6 @@ public final class SystemSessionProperties public static final String RETRY_DELAY_SCALE_FACTOR = "retry_delay_scale_factor"; public static final String HIDE_INACCESSIBLE_COLUMNS = "hide_inaccessible_columns"; public static final String FAULT_TOLERANT_EXECUTION_TARGET_TASK_INPUT_SIZE = "fault_tolerant_execution_target_task_input_size"; - public static final String FAULT_TOLERANT_EXECUTION_MIN_TASK_SPLIT_COUNT = "fault_tolerant_execution_min_task_split_count"; public static final String FAULT_TOLERANT_EXECUTION_TARGET_TASK_SPLIT_COUNT = "fault_tolerant_execution_target_task_split_count"; public static final String FAULT_TOLERANT_EXECUTION_MAX_TASK_SPLIT_COUNT = "fault_tolerant_execution_max_task_split_count"; public static final String FAULT_TOLERANT_EXECUTION_COORDINATOR_TASK_MEMORY = "fault_tolerant_execution_coordinator_task_memory"; @@ -175,7 +173,6 @@ public final class SystemSessionProperties public static final String JOIN_PARTITIONED_BUILD_MIN_ROW_COUNT = "join_partitioned_build_min_row_count"; public static final String USE_EXACT_PARTITIONING = "use_exact_partitioning"; public static final String FORCE_SPILLING_JOIN = "force_spilling_join"; - public static final String FAULT_TOLERANT_EXECUTION_EVENT_DRIVEN_SCHEDULER_ENABLED = "fault_tolerant_execution_event_driven_scheduler_enabled"; public static final String FAULT_TOLERANT_EXECUTION_FORCE_PREFERRED_WRITE_PARTITIONING_ENABLED = "fault_tolerant_execution_force_preferred_write_partitioning_enabled"; public static final String PAGE_PARTITIONING_BUFFER_POOL_SIZE = "page_partitioning_buffer_pool_size"; @@ -739,11 +736,6 @@ public SystemSessionProperties( "Maximum number of query retry attempts", queryManagerConfig.getQueryRetryAttempts(), false), - integerProperty( - TASK_RETRY_ATTEMPTS_OVERALL, - "Maximum number of task retry attempts overall", - queryManagerConfig.getTaskRetryAttemptsOverall(), - false), integerProperty( TASK_RETRY_ATTEMPTS_PER_TASK, "Maximum number of task retry attempts per single task", @@ -799,11 +791,6 @@ public SystemSessionProperties( "Target size in bytes of all task inputs for a single fault tolerant task", queryManagerConfig.getFaultTolerantExecutionTargetTaskInputSize(), false), - integerProperty( - FAULT_TOLERANT_EXECUTION_MIN_TASK_SPLIT_COUNT, - "Minimal number of splits for a single fault tolerant task (count based)", - queryManagerConfig.getFaultTolerantExecutionMinTaskSplitCount(), - false), integerProperty( FAULT_TOLERANT_EXECUTION_TARGET_TASK_SPLIT_COUNT, "Target number of splits for a single fault tolerant task (split weight aware)", @@ -871,11 +858,6 @@ public SystemSessionProperties( "Force the usage of spliing join operator in favor of the non-spilling one, even if spill is not enabled", featuresConfig.isForceSpillingJoin(), false), - booleanProperty( - FAULT_TOLERANT_EXECUTION_EVENT_DRIVEN_SCHEDULER_ENABLED, - "Enable event driven scheduler for fault tolerant execution", - queryManagerConfig.isFaultTolerantExecutionEventDrivenSchedulerEnabled(), - true), booleanProperty( FAULT_TOLERANT_EXECUTION_FORCE_PREFERRED_WRITE_PARTITIONING_ENABLED, "Force preferred write partitioning for fault tolerant execution", @@ -1448,11 +1430,6 @@ public static int getQueryRetryAttempts(Session session) return session.getSystemProperty(QUERY_RETRY_ATTEMPTS, Integer.class); } - public static int getTaskRetryAttemptsOverall(Session session) - { - return session.getSystemProperty(TASK_RETRY_ATTEMPTS_OVERALL, Integer.class); - } - public static int getTaskRetryAttemptsPerTask(Session session) { return session.getSystemProperty(TASK_RETRY_ATTEMPTS_PER_TASK, Integer.class); @@ -1493,11 +1470,6 @@ public static DataSize getFaultTolerantExecutionTargetTaskInputSize(Session sess return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_TARGET_TASK_INPUT_SIZE, DataSize.class); } - public static int getFaultTolerantExecutionMinTaskSplitCount(Session session) - { - return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_MIN_TASK_SPLIT_COUNT, Integer.class); - } - public static int getFaultTolerantExecutionTargetTaskSplitCount(Session session) { return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_TARGET_TASK_SPLIT_COUNT, Integer.class); @@ -1563,17 +1535,8 @@ public static boolean isForceSpillingOperator(Session session) return session.getSystemProperty(FORCE_SPILLING_JOIN, Boolean.class); } - public static boolean isFaultTolerantExecutionEventDriverSchedulerEnabled(Session session) - { - return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_EVENT_DRIVEN_SCHEDULER_ENABLED, Boolean.class); - } - public static boolean isFaultTolerantExecutionForcePreferredWritePartitioningEnabled(Session session) { - if (!isFaultTolerantExecutionEventDriverSchedulerEnabled(session)) { - // supported only in event driven scheduler - return false; - } return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_FORCE_PREFERRED_WRITE_PARTITIONING_ENABLED, Boolean.class); } diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java b/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java index e391706a4c45..a607bd1170fc 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java @@ -84,7 +84,6 @@ public class QueryManagerConfig private RetryPolicy retryPolicy = RetryPolicy.NONE; private int queryRetryAttempts = 4; private int taskRetryAttemptsPerTask = 4; - private int taskRetryAttemptsOverall = Integer.MAX_VALUE; private Duration retryInitialDelay = new Duration(10, SECONDS); private Duration retryMaxDelay = new Duration(1, MINUTES); private double retryDelayScaleFactor = 2.0; @@ -94,12 +93,10 @@ public class QueryManagerConfig private DataSize faultTolerantExecutionTargetTaskInputSize = DataSize.of(4, GIGABYTE); - private int faultTolerantExecutionMinTaskSplitCount = 16; private int faultTolerantExecutionTargetTaskSplitCount = 64; private int faultTolerantExecutionMaxTaskSplitCount = 256; private DataSize faultTolerantExecutionTaskDescriptorStorageMaxMemory = DataSize.ofBytes(Math.round(AVAILABLE_HEAP_MEMORY * 0.15)); private int faultTolerantExecutionPartitionCount = 50; - private boolean faultTolerantExecutionEventDrivenSchedulerEnabled = true; private boolean faultTolerantExecutionForcePreferredWritePartitioningEnabled = true; @Min(1) @@ -455,19 +452,6 @@ public QueryManagerConfig setQueryRetryAttempts(int queryRetryAttempts) return this; } - @Min(0) - public int getTaskRetryAttemptsOverall() - { - return taskRetryAttemptsOverall; - } - - @Config("task-retry-attempts-overall") - public QueryManagerConfig setTaskRetryAttemptsOverall(int taskRetryAttemptsOverall) - { - this.taskRetryAttemptsOverall = taskRetryAttemptsOverall; - return this; - } - @Min(0) @Max(MAX_TASK_RETRY_ATTEMPTS) public int getTaskRetryAttemptsPerTask() @@ -567,20 +551,6 @@ public QueryManagerConfig setFaultTolerantExecutionTargetTaskInputSize(DataSize return this; } - @Min(1) - public int getFaultTolerantExecutionMinTaskSplitCount() - { - return faultTolerantExecutionMinTaskSplitCount; - } - - @Config("fault-tolerant-execution-min-task-split-count") - @ConfigDescription("Minimal number of splits for a single fault tolerant task (count based)") - public QueryManagerConfig setFaultTolerantExecutionMinTaskSplitCount(int faultTolerantExecutionMinTaskSplitCount) - { - this.faultTolerantExecutionMinTaskSplitCount = faultTolerantExecutionMinTaskSplitCount; - return this; - } - @Min(1) public int getFaultTolerantExecutionTargetTaskSplitCount() { @@ -637,18 +607,6 @@ public QueryManagerConfig setFaultTolerantExecutionPartitionCount(int faultToler return this; } - public boolean isFaultTolerantExecutionEventDrivenSchedulerEnabled() - { - return faultTolerantExecutionEventDrivenSchedulerEnabled; - } - - @Config("experimental.fault-tolerant-execution-event-driven-scheduler-enabled") - public QueryManagerConfig setFaultTolerantExecutionEventDrivenSchedulerEnabled(boolean faultTolerantExecutionEventDrivenSchedulerEnabled) - { - this.faultTolerantExecutionEventDrivenSchedulerEnabled = faultTolerantExecutionEventDrivenSchedulerEnabled; - return this; - } - public boolean isFaultTolerantExecutionForcePreferredWritePartitioningEnabled() { return faultTolerantExecutionForcePreferredWritePartitioningEnabled; diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java index d54643f1a5e0..f81fa6c79397 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java @@ -26,7 +26,6 @@ import io.trino.execution.StateMachine.StateChangeListener; import io.trino.execution.scheduler.EventDrivenFaultTolerantQueryScheduler; import io.trino.execution.scheduler.EventDrivenTaskSourceFactory; -import io.trino.execution.scheduler.FaultTolerantQueryScheduler; import io.trino.execution.scheduler.NodeAllocatorService; import io.trino.execution.scheduler.NodeScheduler; import io.trino.execution.scheduler.PartitionMemoryEstimatorFactory; @@ -35,7 +34,6 @@ import io.trino.execution.scheduler.SplitSchedulerStats; import io.trino.execution.scheduler.TaskDescriptorStorage; import io.trino.execution.scheduler.TaskExecutionStats; -import io.trino.execution.scheduler.TaskSourceFactory; import io.trino.execution.scheduler.policy.ExecutionPolicy; import io.trino.execution.warnings.WarningCollector; import io.trino.failuredetector.FailureDetector; @@ -86,12 +84,8 @@ import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.units.DataSize.succinctBytes; -import static io.trino.SystemSessionProperties.getMaxTasksWaitingForNodePerStage; import static io.trino.SystemSessionProperties.getRetryPolicy; -import static io.trino.SystemSessionProperties.getTaskRetryAttemptsOverall; -import static io.trino.SystemSessionProperties.getTaskRetryAttemptsPerTask; import static io.trino.SystemSessionProperties.isEnableDynamicFiltering; -import static io.trino.SystemSessionProperties.isFaultTolerantExecutionEventDriverSchedulerEnabled; import static io.trino.execution.ParameterExtractor.bindParameters; import static io.trino.execution.QueryState.FAILED; import static io.trino.execution.QueryState.PLANNING; @@ -135,7 +129,6 @@ public class SqlQueryExecution private final TypeAnalyzer typeAnalyzer; private final SqlTaskManager coordinatorTaskManager; private final ExchangeManagerRegistry exchangeManagerRegistry; - private final TaskSourceFactory taskSourceFactory; private final EventDrivenTaskSourceFactory eventDrivenTaskSourceFactory; private final TaskDescriptorStorage taskDescriptorStorage; @@ -169,7 +162,6 @@ private SqlQueryExecution( TypeAnalyzer typeAnalyzer, SqlTaskManager coordinatorTaskManager, ExchangeManagerRegistry exchangeManagerRegistry, - TaskSourceFactory taskSourceFactory, EventDrivenTaskSourceFactory eventDrivenTaskSourceFactory, TaskDescriptorStorage taskDescriptorStorage) { @@ -217,7 +209,6 @@ private SqlQueryExecution( this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.coordinatorTaskManager = requireNonNull(coordinatorTaskManager, "coordinatorTaskManager is null"); this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null"); - this.taskSourceFactory = requireNonNull(taskSourceFactory, "taskSourceFactory is null"); this.eventDrivenTaskSourceFactory = requireNonNull(eventDrivenTaskSourceFactory, "taskSourceFactory is null"); this.taskDescriptorStorage = requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null"); } @@ -527,51 +518,25 @@ private void planDistribution(PlanRoot plan) coordinatorTaskManager); break; case TASK: - if (isFaultTolerantExecutionEventDriverSchedulerEnabled(stateMachine.getSession())) { - scheduler = new EventDrivenFaultTolerantQueryScheduler( - stateMachine, - plannerContext.getMetadata(), - remoteTaskFactory, - taskDescriptorStorage, - eventDrivenTaskSourceFactory, - plan.isSummarizeTaskInfos(), - nodeTaskMap, - queryExecutor, - schedulerExecutor, - schedulerStats, - partitionMemoryEstimatorFactory, - nodePartitioningManager, - exchangeManagerRegistry.getExchangeManager(), - nodeAllocatorService, - failureDetector, - dynamicFilterService, - taskExecutionStats, - plan.getRoot()); - } - else { - scheduler = new FaultTolerantQueryScheduler( - stateMachine, - queryExecutor, - schedulerStats, - failureDetector, - taskSourceFactory, - taskDescriptorStorage, - exchangeManagerRegistry.getExchangeManager(), - nodePartitioningManager, - getTaskRetryAttemptsOverall(getSession()), - getTaskRetryAttemptsPerTask(getSession()), - getMaxTasksWaitingForNodePerStage(getSession()), - schedulerExecutor, - nodeAllocatorService, - partitionMemoryEstimatorFactory, - taskExecutionStats, - dynamicFilterService, - plannerContext.getMetadata(), - remoteTaskFactory, - nodeTaskMap, - plan.getRoot(), - plan.isSummarizeTaskInfos()); - } + scheduler = new EventDrivenFaultTolerantQueryScheduler( + stateMachine, + plannerContext.getMetadata(), + remoteTaskFactory, + taskDescriptorStorage, + eventDrivenTaskSourceFactory, + plan.isSummarizeTaskInfos(), + nodeTaskMap, + queryExecutor, + schedulerExecutor, + schedulerStats, + partitionMemoryEstimatorFactory, + nodePartitioningManager, + exchangeManagerRegistry.getExchangeManager(), + nodeAllocatorService, + failureDetector, + dynamicFilterService, + taskExecutionStats, + plan.getRoot()); break; default: throw new IllegalArgumentException("Unexpected retry policy: " + retryPolicy); @@ -777,7 +742,6 @@ public static class SqlQueryExecutionFactory private final TypeAnalyzer typeAnalyzer; private final SqlTaskManager coordinatorTaskManager; private final ExchangeManagerRegistry exchangeManagerRegistry; - private final TaskSourceFactory taskSourceFactory; private final EventDrivenTaskSourceFactory eventDrivenTaskSourceFactory; private final TaskDescriptorStorage taskDescriptorStorage; @@ -808,7 +772,6 @@ public static class SqlQueryExecutionFactory TypeAnalyzer typeAnalyzer, SqlTaskManager coordinatorTaskManager, ExchangeManagerRegistry exchangeManagerRegistry, - TaskSourceFactory taskSourceFactory, EventDrivenTaskSourceFactory eventDrivenTaskSourceFactory, TaskDescriptorStorage taskDescriptorStorage) { @@ -837,7 +800,6 @@ public static class SqlQueryExecutionFactory this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); this.coordinatorTaskManager = requireNonNull(coordinatorTaskManager, "coordinatorTaskManager is null"); this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null"); - this.taskSourceFactory = requireNonNull(taskSourceFactory, "taskSourceFactory is null"); this.eventDrivenTaskSourceFactory = requireNonNull(eventDrivenTaskSourceFactory, "eventDrivenTaskSourceFactory is null"); this.taskDescriptorStorage = requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null"); } @@ -883,7 +845,6 @@ public QueryExecution createQueryExecution( typeAnalyzer, coordinatorTaskManager, exchangeManagerRegistry, - taskSourceFactory, eventDrivenTaskSourceFactory, taskDescriptorStorage); } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantQueryScheduler.java deleted file mode 100644 index 8f07c9b2f9bd..000000000000 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantQueryScheduler.java +++ /dev/null @@ -1,519 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.execution.scheduler; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; -import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.concurrent.SetThreadName; -import io.airlift.log.Logger; -import io.airlift.stats.TimeStat; -import io.airlift.units.Duration; -import io.trino.Session; -import io.trino.exchange.SpoolingExchangeInput; -import io.trino.execution.BasicStageStats; -import io.trino.execution.NodeTaskMap; -import io.trino.execution.QueryState; -import io.trino.execution.QueryStateMachine; -import io.trino.execution.RemoteTaskFactory; -import io.trino.execution.SqlStage; -import io.trino.execution.StageId; -import io.trino.execution.StageInfo; -import io.trino.execution.TaskId; -import io.trino.failuredetector.FailureDetector; -import io.trino.metadata.Metadata; -import io.trino.operator.RetryPolicy; -import io.trino.server.DynamicFilterService; -import io.trino.spi.exchange.Exchange; -import io.trino.spi.exchange.ExchangeContext; -import io.trino.spi.exchange.ExchangeId; -import io.trino.spi.exchange.ExchangeManager; -import io.trino.spi.exchange.ExchangeSourceHandle; -import io.trino.spi.exchange.ExchangeSourceOutputSelector; -import io.trino.sql.planner.NodePartitioningManager; -import io.trino.sql.planner.PlanFragment; -import io.trino.sql.planner.SubPlan; -import io.trino.sql.planner.plan.PlanFragmentId; - -import javax.annotation.concurrent.GuardedBy; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.CancellationException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.atomic.AtomicInteger; - -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Ticker.systemTicker; -import static com.google.common.base.Verify.verify; -import static com.google.common.collect.Lists.reverse; -import static io.airlift.concurrent.MoreFutures.addExceptionCallback; -import static io.airlift.concurrent.MoreFutures.addSuccessCallback; -import static io.airlift.concurrent.MoreFutures.tryGetFutureValue; -import static io.airlift.concurrent.MoreFutures.whenAnyComplete; -import static io.trino.SystemSessionProperties.getFaultTolerantExecutionPartitionCount; -import static io.trino.SystemSessionProperties.getRetryPolicy; -import static io.trino.execution.QueryState.FINISHING; -import static io.trino.execution.scheduler.Exchanges.getAllSourceHandles; -import static io.trino.operator.RetryPolicy.TASK; -import static java.util.Objects.requireNonNull; -import static java.util.concurrent.TimeUnit.MILLISECONDS; -import static java.util.concurrent.TimeUnit.SECONDS; - -/** - * Deprecated in favor of {@link EventDrivenFaultTolerantQueryScheduler} - */ -@Deprecated -public class FaultTolerantQueryScheduler - implements QueryScheduler -{ - private static final Logger log = Logger.get(FaultTolerantQueryScheduler.class); - - private final QueryStateMachine queryStateMachine; - private final ExecutorService queryExecutor; - private final SplitSchedulerStats schedulerStats; - private final FailureDetector failureDetector; - private final TaskSourceFactory taskSourceFactory; - private final TaskDescriptorStorage taskDescriptorStorage; - private final ExchangeManager exchangeManager; - private final NodePartitioningManager nodePartitioningManager; - private final int taskRetryAttemptsOverall; - private final int taskRetryAttemptsPerTask; - private final int maxTasksWaitingForNodePerStage; - private final ScheduledExecutorService scheduledExecutorService; - private final NodeAllocatorService nodeAllocatorService; - private final PartitionMemoryEstimatorFactory partitionMemoryEstimatorFactory; - private final TaskExecutionStats taskExecutionStats; - private final DynamicFilterService dynamicFilterService; - - private final StageManager stageManager; - - @GuardedBy("this") - private boolean started; - @GuardedBy("this") - private Scheduler scheduler; - - public FaultTolerantQueryScheduler( - QueryStateMachine queryStateMachine, - ExecutorService queryExecutor, - SplitSchedulerStats schedulerStats, - FailureDetector failureDetector, - TaskSourceFactory taskSourceFactory, - TaskDescriptorStorage taskDescriptorStorage, - ExchangeManager exchangeManager, - NodePartitioningManager nodePartitioningManager, - int taskRetryAttemptsOverall, - int taskRetryAttemptsPerTask, - int maxTasksWaitingForNodePerStage, - ScheduledExecutorService scheduledExecutorService, - NodeAllocatorService nodeAllocatorService, - PartitionMemoryEstimatorFactory partitionMemoryEstimatorFactory, - TaskExecutionStats taskExecutionStats, - DynamicFilterService dynamicFilterService, - Metadata metadata, - RemoteTaskFactory taskFactory, - NodeTaskMap nodeTaskMap, - SubPlan planTree, - boolean summarizeTaskInfo) - { - this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null"); - RetryPolicy retryPolicy = getRetryPolicy(queryStateMachine.getSession()); - verify(retryPolicy == TASK, "unexpected retry policy: %s", retryPolicy); - this.queryExecutor = requireNonNull(queryExecutor, "queryExecutor is null"); - this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null"); - this.failureDetector = requireNonNull(failureDetector, "failureDetector is null"); - this.taskSourceFactory = requireNonNull(taskSourceFactory, "taskSourceFactory is null"); - this.taskDescriptorStorage = requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null"); - this.exchangeManager = requireNonNull(exchangeManager, "exchangeManager is null"); - this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); - this.taskRetryAttemptsOverall = taskRetryAttemptsOverall; - this.taskRetryAttemptsPerTask = taskRetryAttemptsPerTask; - this.maxTasksWaitingForNodePerStage = maxTasksWaitingForNodePerStage; - this.scheduledExecutorService = requireNonNull(scheduledExecutorService, "scheduledExecutorService is null"); - this.nodeAllocatorService = requireNonNull(nodeAllocatorService, "nodeAllocatorService is null"); - this.partitionMemoryEstimatorFactory = requireNonNull(partitionMemoryEstimatorFactory, "partitionMemoryEstimatorFactory is null"); - this.taskExecutionStats = requireNonNull(taskExecutionStats, "taskExecutionStats is null"); - this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); - - stageManager = StageManager.create( - queryStateMachine, - metadata, - taskFactory, - nodeTaskMap, - queryExecutor, - schedulerStats, - planTree, - summarizeTaskInfo); - } - - @Override - public synchronized void start() - { - if (started) { - return; - } - started = true; - - if (queryStateMachine.isDone()) { - return; - } - - // when query is done or any time a stage completes, attempt to transition query to "final query info ready" - queryStateMachine.addStateChangeListener(state -> { - if (!state.isDone()) { - return; - } - Scheduler scheduler; - synchronized (this) { - scheduler = this.scheduler; - this.scheduler = null; - } - if (state == QueryState.FINISHED) { - if (scheduler != null) { - scheduler.cancel(); - } - stageManager.finish(); - } - else if (state == QueryState.FAILED) { - if (scheduler != null) { - scheduler.abort(); - } - stageManager.abort(); - } - queryStateMachine.updateQueryInfo(Optional.ofNullable(getStageInfo())); - }); - - scheduler = createScheduler(); - queryExecutor.submit(scheduler::schedule); - } - - private Scheduler createScheduler() - { - taskDescriptorStorage.initialize(queryStateMachine.getQueryId()); - queryStateMachine.addStateChangeListener(state -> { - if (state.isDone()) { - taskDescriptorStorage.destroy(queryStateMachine.getQueryId()); - } - }); - - Session session = queryStateMachine.getSession(); - FaultTolerantPartitioningSchemeFactory partitioningSchemeFactory = new FaultTolerantPartitioningSchemeFactory( - nodePartitioningManager, - session, - getFaultTolerantExecutionPartitionCount(session)); - - Map schedulers = new HashMap<>(); - Map exchanges = new HashMap<>(); - NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(session); - - try { - // root to children order - List stagesInTopologicalOrder = stageManager.getStagesInTopologicalOrder(); - // children to root order - List stagesInReverseTopologicalOrder = reverse(stagesInTopologicalOrder); - - checkArgument(taskRetryAttemptsOverall >= 0, "taskRetryAttemptsOverall must be greater than or equal to 0: %s", taskRetryAttemptsOverall); - AtomicInteger remainingTaskRetryAttemptsOverall = new AtomicInteger(taskRetryAttemptsOverall); - for (SqlStage stage : stagesInReverseTopologicalOrder) { - PlanFragment fragment = stage.getFragment(); - - boolean outputStage = stageManager.getOutputStage().getStageId().equals(stage.getStageId()); - ExchangeContext exchangeContext = new ExchangeContext(session.getQueryId(), new ExchangeId("external-exchange-" + stage.getStageId().getId())); - FaultTolerantPartitioningScheme sinkPartitioningScheme = partitioningSchemeFactory.get(fragment.getPartitioningScheme().getPartitioning().getHandle()); - Exchange exchange = exchangeManager.createExchange( - exchangeContext, - sinkPartitioningScheme.getPartitionCount(), - // order of output records for coordinator consumed stages must be preserved as the stage - // may produce sorted dataset (for example an output of a global OrderByOperator) - outputStage); - exchanges.put(fragment.getId(), exchange); - - ImmutableMap.Builder sourceExchanges = ImmutableMap.builder(); - ImmutableMap.Builder sourceSchedulers = ImmutableMap.builder(); - for (SqlStage childStage : stageManager.getChildren(fragment.getId())) { - PlanFragmentId childFragmentId = childStage.getFragment().getId(); - Exchange sourceExchange = exchanges.get(childFragmentId); - verify(sourceExchange != null, "exchange not found for fragment: %s", childFragmentId); - sourceExchanges.put(childFragmentId, sourceExchange); - FaultTolerantStageScheduler sourceScheduler = schedulers.get(childFragmentId); - verify(sourceScheduler != null, "scheduler not found for fragment: %s", childFragmentId); - sourceSchedulers.put(childFragmentId, sourceScheduler); - } - - FaultTolerantPartitioningScheme sourcePartitioningScheme = partitioningSchemeFactory.get(fragment.getPartitioning()); - FaultTolerantStageScheduler scheduler = new FaultTolerantStageScheduler( - session, - stage, - failureDetector, - taskSourceFactory, - nodeAllocator, - taskDescriptorStorage, - partitionMemoryEstimatorFactory.createPartitionMemoryEstimator(), - taskExecutionStats, - (future, delay) -> scheduledExecutorService.schedule(() -> future.set(null), delay.toMillis(), MILLISECONDS), - systemTicker(), - exchange, - sinkPartitioningScheme, - sourceSchedulers.buildOrThrow(), - sourceExchanges.buildOrThrow(), - sourcePartitioningScheme, - remainingTaskRetryAttemptsOverall, - taskRetryAttemptsPerTask, - maxTasksWaitingForNodePerStage, - dynamicFilterService); - - schedulers.put(fragment.getId(), scheduler); - - if (outputStage) { - ListenableFuture> sourceHandles = getAllSourceHandles(exchange.getSourceHandles()); - addSuccessCallback(sourceHandles, handles -> { - try { - ExchangeSourceOutputSelector.Builder selector = ExchangeSourceOutputSelector.builder(ImmutableSet.of(exchange.getId())); - Map successfulAttempts = scheduler.getSuccessfulAttempts(); - successfulAttempts.forEach((taskPartitionId, attemptId) -> - selector.include(exchange.getId(), taskPartitionId, attemptId)); - selector.setPartitionCount(exchange.getId(), successfulAttempts.size()); - selector.setFinal(); - SpoolingExchangeInput input = new SpoolingExchangeInput(handles, Optional.of(selector.build())); - queryStateMachine.updateInputsForQueryResults(ImmutableList.of(input), true); - } - catch (Throwable t) { - queryStateMachine.transitionToFailed(t); - } - }); - addExceptionCallback(sourceHandles, queryStateMachine::transitionToFailed); - } - } - - return new Scheduler( - queryStateMachine, - ImmutableList.copyOf(schedulers.values()), - stageManager, - schedulerStats, - nodeAllocator); - } - catch (Throwable t) { - for (FaultTolerantStageScheduler scheduler : schedulers.values()) { - try { - scheduler.abort(); - } - catch (Throwable closeFailure) { - if (t != closeFailure) { - t.addSuppressed(closeFailure); - } - } - } - - try { - nodeAllocator.close(); - } - catch (Throwable closeFailure) { - if (t != closeFailure) { - t.addSuppressed(closeFailure); - } - } - - for (Exchange exchange : exchanges.values()) { - try { - exchange.close(); - } - catch (Throwable closeFailure) { - if (t != closeFailure) { - t.addSuppressed(closeFailure); - } - } - } - throw t; - } - } - - @Override - public void cancelStage(StageId stageId) - { - throw new UnsupportedOperationException("partial cancel is not supported in fault tolerant mode"); - } - - @Override - public void failTask(TaskId taskId, Throwable failureCause) - { - stageManager.failTaskRemotely(taskId, failureCause); - } - - @Override - public BasicStageStats getBasicStageStats() - { - return stageManager.getBasicStageStats(); - } - - @Override - public StageInfo getStageInfo() - { - return stageManager.getStageInfo(); - } - - @Override - public long getUserMemoryReservation() - { - return stageManager.getUserMemoryReservation(); - } - - @Override - public long getTotalMemoryReservation() - { - return stageManager.getTotalMemoryReservation(); - } - - @Override - public Duration getTotalCpuTime() - { - return stageManager.getTotalCpuTime(); - } - - private static class Scheduler - { - private final QueryStateMachine queryStateMachine; - private final List schedulers; - private final StageManager stageManager; - private final SplitSchedulerStats schedulerStats; - private final NodeAllocator nodeAllocator; - - private Scheduler( - QueryStateMachine queryStateMachine, - List schedulers, - StageManager stageManager, - SplitSchedulerStats schedulerStats, - NodeAllocator nodeAllocator) - { - this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null"); - this.stageManager = requireNonNull(stageManager, "stageManager is null"); - this.schedulers = ImmutableList.copyOf(requireNonNull(schedulers, "schedulers is null")); - this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null"); - this.nodeAllocator = requireNonNull(nodeAllocator, "nodeAllocator is null"); - } - - public void schedule() - { - if (schedulers.isEmpty()) { - queryStateMachine.transitionToFinishing(); - return; - } - - queryStateMachine.transitionToRunning(); - - try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) { - List> blockedStages = new ArrayList<>(); - while (!isFinishingOrDone(queryStateMachine)) { - blockedStages.clear(); - boolean atLeastOneStageIsNotBlocked = false; - boolean allFinished = true; - for (FaultTolerantStageScheduler scheduler : schedulers) { - if (scheduler.isFinished()) { - stageManager.get(scheduler.getStageId()).finish(); - continue; - } - allFinished = false; - ListenableFuture blocked = scheduler.isBlocked(); - if (!blocked.isDone()) { - blockedStages.add(blocked); - continue; - } - try { - scheduler.schedule(); - } - catch (Throwable t) { - fail(t, Optional.of(scheduler.getStageId())); - return; - } - blocked = scheduler.isBlocked(); - if (!blocked.isDone()) { - blockedStages.add(blocked); - } - else { - atLeastOneStageIsNotBlocked = true; - } - } - if (allFinished) { - queryStateMachine.transitionToFinishing(); - return; - } - // wait for a state change and then schedule again - if (!atLeastOneStageIsNotBlocked) { - verify(!blockedStages.isEmpty(), "blockedStages is not expected to be empty here"); - try (TimeStat.BlockTimer timer = schedulerStats.getSleepTime().time()) { - try { - tryGetFutureValue(whenAnyComplete(blockedStages), 1, SECONDS); - } - catch (CancellationException e) { - log.debug( - "Scheduling has been cancelled for query %s. Query state: %s", - queryStateMachine.getQueryId(), - queryStateMachine.getQueryState()); - } - } - } - } - } - catch (Throwable t) { - fail(t, Optional.empty()); - } - } - - public void cancel() - { - schedulers.forEach(FaultTolerantStageScheduler::cancel); - closeNodeAllocator(); - } - - public void abort() - { - schedulers.forEach(FaultTolerantStageScheduler::abort); - closeNodeAllocator(); - } - - private void fail(Throwable t, Optional failedStageId) - { - abort(); - stageManager.getStagesInTopologicalOrder().forEach(stage -> { - if (failedStageId.isPresent() && failedStageId.get().equals(stage.getStageId())) { - stage.fail(t); - } - else { - stage.abort(); - } - }); - queryStateMachine.transitionToFailed(t); - } - - private void closeNodeAllocator() - { - try { - nodeAllocator.close(); - } - catch (Throwable t) { - log.warn(t, "Error closing node allocator for query: %s", queryStateMachine.getQueryId()); - } - } - } - - private static boolean isFinishingOrDone(QueryStateMachine queryStateMachine) - { - QueryState queryState = queryStateMachine.getQueryState(); - return queryState == FINISHING || queryState.isDone(); - } -} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java deleted file mode 100644 index 4e134801e38e..000000000000 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantStageScheduler.java +++ /dev/null @@ -1,799 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.execution.scheduler; - -import com.google.common.base.Stopwatch; -import com.google.common.base.Ticker; -import com.google.common.base.VerifyException; -import com.google.common.collect.ArrayListMultimap; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableListMultimap; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Multimap; -import com.google.common.collect.Ordering; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.SettableFuture; -import io.airlift.concurrent.MoreFutures; -import io.airlift.log.Logger; -import io.airlift.units.DataSize; -import io.trino.Session; -import io.trino.exchange.SpoolingExchangeInput; -import io.trino.execution.ExecutionFailureInfo; -import io.trino.execution.RemoteTask; -import io.trino.execution.SqlStage; -import io.trino.execution.StageId; -import io.trino.execution.TaskId; -import io.trino.execution.TaskState; -import io.trino.execution.TaskStatus; -import io.trino.execution.buffer.OutputBuffers; -import io.trino.execution.buffer.SpoolingOutputBuffers; -import io.trino.execution.scheduler.PartitionMemoryEstimator.MemoryRequirements; -import io.trino.failuredetector.FailureDetector; -import io.trino.metadata.InternalNode; -import io.trino.metadata.Split; -import io.trino.server.DynamicFilterService; -import io.trino.spi.ErrorCode; -import io.trino.spi.TrinoException; -import io.trino.spi.exchange.Exchange; -import io.trino.spi.exchange.ExchangeId; -import io.trino.spi.exchange.ExchangeSinkHandle; -import io.trino.spi.exchange.ExchangeSinkInstanceHandle; -import io.trino.spi.exchange.ExchangeSourceHandle; -import io.trino.spi.exchange.ExchangeSourceOutputSelector; -import io.trino.split.RemoteSplit; -import io.trino.sql.planner.plan.PlanFragmentId; -import io.trino.sql.planner.plan.PlanNodeId; -import io.trino.sql.planner.plan.RemoteSourceNode; - -import javax.annotation.concurrent.GuardedBy; - -import java.time.Duration; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Queue; -import java.util.Set; -import java.util.concurrent.atomic.AtomicInteger; - -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; -import static com.google.common.base.Throwables.propagateIfPossible; -import static com.google.common.base.Verify.verify; -import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.ImmutableListMultimap.flatteningToImmutableListMultimap; -import static com.google.common.collect.ImmutableMap.toImmutableMap; -import static com.google.common.collect.ImmutableSet.toImmutableSet; -import static com.google.common.util.concurrent.Futures.allAsList; -import static com.google.common.util.concurrent.Futures.immediateVoidFuture; -import static com.google.common.util.concurrent.Futures.nonCancellationPropagating; -import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static io.airlift.concurrent.MoreFutures.asVoid; -import static io.airlift.concurrent.MoreFutures.getFutureValue; -import static io.trino.SystemSessionProperties.getFaultTolerantExecutionDefaultCoordinatorTaskMemory; -import static io.trino.SystemSessionProperties.getFaultTolerantExecutionDefaultTaskMemory; -import static io.trino.SystemSessionProperties.getRetryDelayScaleFactor; -import static io.trino.SystemSessionProperties.getRetryInitialDelay; -import static io.trino.SystemSessionProperties.getRetryMaxDelay; -import static io.trino.execution.scheduler.ErrorCodes.isOutOfMemoryError; -import static io.trino.execution.scheduler.Exchanges.getAllSourceHandles; -import static io.trino.failuredetector.FailureDetector.State.GONE; -import static io.trino.operator.ExchangeOperator.REMOTE_CATALOG_HANDLE; -import static io.trino.spi.ErrorType.EXTERNAL; -import static io.trino.spi.ErrorType.INTERNAL_ERROR; -import static io.trino.spi.ErrorType.USER_ERROR; -import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; -import static io.trino.spi.StandardErrorCode.REMOTE_HOST_GONE; -import static io.trino.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; -import static io.trino.util.Failures.toFailure; -import static java.lang.String.format; -import static java.util.Objects.requireNonNull; - -/** - * Deprecated in favor of {@link EventDrivenFaultTolerantQueryScheduler} - */ -@Deprecated -public class FaultTolerantStageScheduler -{ - private static final Logger log = Logger.get(FaultTolerantStageScheduler.class); - - private final Session session; - private final SqlStage stage; - private final FailureDetector failureDetector; - private final TaskSourceFactory taskSourceFactory; - private final NodeAllocator nodeAllocator; - private final TaskDescriptorStorage taskDescriptorStorage; - private final PartitionMemoryEstimator partitionMemoryEstimator; - private final TaskExecutionStats taskExecutionStats; - private final int maxRetryAttemptsPerTask; - private final int maxTasksWaitingForNodePerStage; - - private final Exchange sinkExchange; - private final FaultTolerantPartitioningScheme sinkPartitioningScheme; - - private final Map sourceSchedulers; - private final Map sourceExchanges; - private final FaultTolerantPartitioningScheme sourcePartitioningScheme; - - private final DelayedFutureCompletor futureCompletor; - - @GuardedBy("this") - private ListenableFuture blocked = immediateVoidFuture(); - - @GuardedBy("this") - private ListenableFuture tasksPopulatedFuture = immediateVoidFuture(); - - @GuardedBy("this") - private SettableFuture taskFinishedFuture; - - private final Duration minRetryDelay; - private final Duration maxRetryDelay; - private final double retryDelayScaleFactor; - - @GuardedBy("this") - private Optional delaySchedulingDuration = Optional.empty(); - @GuardedBy("this") - private final Stopwatch delayStopwatch; - @GuardedBy("this") - private SettableFuture delaySchedulingFuture; - - @GuardedBy("this") - private TaskSource taskSource; - @GuardedBy("this") - private final Map partitionToExchangeSinkHandleMap = new HashMap<>(); - @GuardedBy("this") - private final Multimap partitionToRemoteTaskMap = ArrayListMultimap.create(); - @GuardedBy("this") - private final Map runningTasks = new HashMap<>(); - @GuardedBy("this") - private final Map runningNodes = new HashMap<>(); - @GuardedBy("this") - private final Set allPartitions = new HashSet<>(); - @GuardedBy("this") - private boolean noMorePartitions; - @GuardedBy("this") - private final Queue queuedPartitions = new ArrayDeque<>(); - @GuardedBy("this") - private final Queue pendingPartitions = new ArrayDeque<>(); - @GuardedBy("this") - private final Map finishedPartitions = new HashMap<>(); - @GuardedBy("this") - private final AtomicInteger remainingRetryAttemptsOverall; - @GuardedBy("this") - private final Map remainingAttemptsPerTask = new HashMap<>(); - @GuardedBy("this") - private final Map partitionMemoryRequirements = new HashMap<>(); - @GuardedBy("this") - private Multimap outputSelectorSplits; - - private final DynamicFilterService dynamicFilterService; - - @GuardedBy("this") - private Throwable failure; - @GuardedBy("this") - private boolean closed; - - public FaultTolerantStageScheduler( - Session session, - SqlStage stage, - FailureDetector failureDetector, - TaskSourceFactory taskSourceFactory, - NodeAllocator nodeAllocator, - TaskDescriptorStorage taskDescriptorStorage, - PartitionMemoryEstimator partitionMemoryEstimator, - TaskExecutionStats taskExecutionStats, - DelayedFutureCompletor futureCompletor, - Ticker ticker, - Exchange sinkExchange, - FaultTolerantPartitioningScheme sinkPartitioningScheme, - Map sourceSchedulers, - Map sourceExchanges, - FaultTolerantPartitioningScheme sourcePartitioningScheme, - AtomicInteger remainingRetryAttemptsOverall, - int taskRetryAttemptsPerTask, - int maxTasksWaitingForNodePerStage, - DynamicFilterService dynamicFilterService) - { - this.session = requireNonNull(session, "session is null"); - this.stage = requireNonNull(stage, "stage is null"); - this.failureDetector = requireNonNull(failureDetector, "failureDetector is null"); - this.taskSourceFactory = requireNonNull(taskSourceFactory, "taskSourceFactory is null"); - this.nodeAllocator = requireNonNull(nodeAllocator, "nodeAllocator is null"); - this.taskDescriptorStorage = requireNonNull(taskDescriptorStorage, "taskDescriptorStorage is null"); - this.partitionMemoryEstimator = requireNonNull(partitionMemoryEstimator, "partitionMemoryEstimator is null"); - this.taskExecutionStats = requireNonNull(taskExecutionStats, "taskExecutionStats is null"); - this.futureCompletor = requireNonNull(futureCompletor, "futureCompletor is null"); - this.sinkExchange = requireNonNull(sinkExchange, "sinkExchange is null"); - this.sinkPartitioningScheme = requireNonNull(sinkPartitioningScheme, "sinkPartitioningScheme is null"); - Set sourceFragments = stage.getFragment().getRemoteSourceNodes().stream() - .flatMap(remoteSource -> remoteSource.getSourceFragmentIds().stream()) - .collect(toImmutableSet()); - requireNonNull(sourceSchedulers, "sourceSchedulers is null"); - checkArgument(sourceSchedulers.keySet().containsAll(sourceFragments), "sourceSchedulers map is incomplete"); - this.sourceSchedulers = ImmutableMap.copyOf(sourceSchedulers); - requireNonNull(sourceExchanges, "sourceExchanges is null"); - checkArgument(sourceExchanges.keySet().containsAll(sourceFragments), "sourceExchanges map is incomplete"); - this.sourceExchanges = ImmutableMap.copyOf(sourceExchanges); - this.sourcePartitioningScheme = requireNonNull(sourcePartitioningScheme, "sourcePartitioningScheme is null"); - this.remainingRetryAttemptsOverall = requireNonNull(remainingRetryAttemptsOverall, "remainingRetryAttemptsOverall is null"); - this.maxRetryAttemptsPerTask = taskRetryAttemptsPerTask; - this.maxTasksWaitingForNodePerStage = maxTasksWaitingForNodePerStage; - this.minRetryDelay = Duration.ofMillis(getRetryInitialDelay(session).toMillis()); - this.maxRetryDelay = Duration.ofMillis(getRetryMaxDelay(session).toMillis()); - this.retryDelayScaleFactor = getRetryDelayScaleFactor(session); - this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); - this.delayStopwatch = Stopwatch.createUnstarted(ticker); - } - - public StageId getStageId() - { - return stage.getStageId(); - } - - public synchronized ListenableFuture isBlocked() - { - return nonCancellationPropagating(blocked); - } - - public synchronized void schedule() - throws Exception - { - if (failure != null) { - propagateIfPossible(failure, Exception.class); - throw new RuntimeException(failure); - } - - if (closed) { - return; - } - - if (isFinished()) { - return; - } - - if (!blocked.isDone()) { - return; - } - - if (delaySchedulingFuture != null && !delaySchedulingFuture.isDone()) { - // let's wait a bit more - blocked = delaySchedulingFuture; - return; - } - - if (taskSource == null) { - Map>> sourceHandles = sourceExchanges.entrySet().stream() - .collect(toImmutableMap(Map.Entry::getKey, entry -> getAllSourceHandles(entry.getValue().getSourceHandles()))); - - List>> blockedFutures = sourceHandles.values().stream() - .filter(future -> !future.isDone()) - .collect(toImmutableList()); - - if (!blockedFutures.isEmpty()) { - blocked = asVoid(allAsList(blockedFutures)); - return; - } - - Multimap exchangeSources = sourceHandles.entrySet().stream() - .collect(flatteningToImmutableListMultimap(Map.Entry::getKey, entry -> getFutureValue(entry.getValue()).stream())); - - taskSource = taskSourceFactory.create( - session, - stage.getFragment(), - exchangeSources, - stage::recordGetSplitTime, - sourcePartitioningScheme); - } - - while (!pendingPartitions.isEmpty() || !queuedPartitions.isEmpty() || !taskSource.isFinished()) { - while (queuedPartitions.isEmpty() && pendingPartitions.size() < maxTasksWaitingForNodePerStage && !taskSource.isFinished()) { - tasksPopulatedFuture = Futures.transform( - taskSource.getMoreTasks(), - tasks -> { - synchronized (this) { - for (TaskDescriptor task : tasks) { - queuedPartitions.add(task.getPartitionId()); - allPartitions.add(task.getPartitionId()); - taskDescriptorStorage.put(stage.getStageId(), task); - ExchangeSinkHandle exchangeSinkHandle = sinkExchange.addSink(task.getPartitionId()); - partitionToExchangeSinkHandleMap.put(task.getPartitionId(), exchangeSinkHandle); - } - if (taskSource.isFinished()) { - dynamicFilterService.stageCannotScheduleMoreTasks(stage.getStageId(), 0, allPartitions.size()); - sinkExchange.noMoreSinks(); - noMorePartitions = true; - } - if (noMorePartitions && finishedPartitions.keySet().containsAll(allPartitions)) { - sinkExchange.allRequiredSinksFinished(); - } - return null; - } - }, - directExecutor()); - if (!tasksPopulatedFuture.isDone()) { - blocked = tasksPopulatedFuture; - return; - } - } - - Iterator pendingPartitionsIterator = pendingPartitions.iterator(); - boolean startedTask = false; - while (pendingPartitionsIterator.hasNext()) { - PendingPartition pendingPartition = pendingPartitionsIterator.next(); - if (pendingPartition.getNodeLease().getNode().isDone()) { - MemoryRequirements memoryRequirements = partitionMemoryRequirements.get(pendingPartition.getPartition()); - verify(memoryRequirements != null, "no entry for %s.%s in partitionMemoryRequirements", stage.getStageId(), pendingPartition.getPartition()); - startTask(pendingPartition.getPartition(), pendingPartition.getNodeLease(), memoryRequirements); - startedTask = true; - pendingPartitionsIterator.remove(); - } - } - - if (!startedTask && (queuedPartitions.isEmpty() || pendingPartitions.size() >= maxTasksWaitingForNodePerStage)) { - break; - } - - while (pendingPartitions.size() < maxTasksWaitingForNodePerStage && !queuedPartitions.isEmpty()) { - int partition = queuedPartitions.poll(); - Optional taskDescriptorOptional = taskDescriptorStorage.get(stage.getStageId(), partition); - if (taskDescriptorOptional.isEmpty()) { - // query has been terminated - return; - } - TaskDescriptor taskDescriptor = taskDescriptorOptional.get(); - DataSize defaultTaskMemory = stage.getFragment().getPartitioning().equals(COORDINATOR_DISTRIBUTION) ? - getFaultTolerantExecutionDefaultCoordinatorTaskMemory(session) : - getFaultTolerantExecutionDefaultTaskMemory(session); - MemoryRequirements memoryRequirements = partitionMemoryRequirements.computeIfAbsent(partition, ignored -> partitionMemoryEstimator.getInitialMemoryRequirements(session, defaultTaskMemory)); - log.debug("Computed initial memory requirements for task from stage %s; requirements=%s; estimator=%s", stage.getStageId(), memoryRequirements, partitionMemoryEstimator); - NodeRequirements nodeRequirements = taskDescriptor.getNodeRequirements(); - NodeAllocator.NodeLease nodeLease = nodeAllocator.acquire(nodeRequirements, memoryRequirements.getRequiredMemory()); - - pendingPartitions.add(new PendingPartition(partition, nodeLease)); - } - } - - List> futures = new ArrayList<>(); - if (taskFinishedFuture != null && !taskFinishedFuture.isDone()) { - futures.add(taskFinishedFuture); - } - for (PendingPartition pendingPartition : pendingPartitions) { - futures.add(pendingPartition.getNodeLease().getNode()); - } - if (!futures.isEmpty()) { - blocked = asVoid(MoreFutures.whenAnyComplete(futures)); - } - } - - @GuardedBy("this") - private void startTask(int partition, NodeAllocator.NodeLease nodeLease, MemoryRequirements memoryRequirements) - { - Optional taskDescriptorOptional = taskDescriptorStorage.get(stage.getStageId(), partition); - if (taskDescriptorOptional.isEmpty()) { - // query has been terminated - return; - } - TaskDescriptor taskDescriptor = taskDescriptorOptional.get(); - - InternalNode node = getFutureValue(nodeLease.getNode()); - - int attemptId = getNextAttemptIdForPartition(partition); - - ExchangeSinkHandle sinkHandle = partitionToExchangeSinkHandleMap.get(partition); - ExchangeSinkInstanceHandle exchangeSinkInstanceHandle = sinkExchange.instantiateSink(sinkHandle, attemptId); - OutputBuffers outputBuffers = SpoolingOutputBuffers.createInitial(exchangeSinkInstanceHandle, sinkPartitioningScheme.getPartitionCount()); - - Set allSourcePlanNodeIds = ImmutableSet.builder() - .addAll(stage.getFragment().getPartitionedSources()) - .addAll(stage.getFragment() - .getRemoteSourceNodes().stream() - .map(RemoteSourceNode::getId) - .iterator()) - .build(); - - createOutputSelectorSplitsIfNecessary(); - - RemoteTask task = stage.createTask( - node, - partition, - attemptId, - sinkPartitioningScheme.getBucketToPartitionMap(), - outputBuffers, - ImmutableListMultimap.builder() - .putAll(outputSelectorSplits) - .putAll(taskDescriptor.getSplits()) - .build(), - allSourcePlanNodeIds, - Optional.of(memoryRequirements.getRequiredMemory())).orElseThrow(() -> new VerifyException("stage execution is expected to be active")); - - nodeLease.attachTaskId(task.getTaskId()); - partitionToRemoteTaskMap.put(partition, task); - runningTasks.put(task.getTaskId(), task); - runningNodes.put(task.getTaskId(), nodeLease); - - if (taskFinishedFuture == null) { - taskFinishedFuture = SettableFuture.create(); - } - - task.addStateChangeListener(taskStatus -> updateTaskStatus(taskStatus, sinkHandle)); - task.addFinalTaskInfoListener(taskExecutionStats::update); - task.start(); - } - - @GuardedBy("this") - private void createOutputSelectorSplitsIfNecessary() - { - if (outputSelectorSplits != null) { - return; - } - - ImmutableListMultimap.Builder selectors = ImmutableListMultimap.builder(); - for (RemoteSourceNode remoteSource : stage.getFragment().getRemoteSourceNodes()) { - List sourceFragmentIds = remoteSource.getSourceFragmentIds(); - Set sourceExchangeIds = sourceExchanges.entrySet().stream() - .filter(entry -> sourceFragmentIds.contains(entry.getKey())) - .map(entry -> entry.getValue().getId()) - .collect(toImmutableSet()); - ExchangeSourceOutputSelector.Builder selector = ExchangeSourceOutputSelector.builder(sourceExchangeIds); - for (PlanFragmentId sourceFragment : sourceFragmentIds) { - FaultTolerantStageScheduler sourceScheduler = sourceSchedulers.get(sourceFragment); - Exchange sourceExchange = sourceExchanges.get(sourceFragment); - Map successfulAttempts = sourceScheduler.getSuccessfulAttempts(); - successfulAttempts.forEach((taskPartitionId, attemptId) -> - selector.include(sourceExchange.getId(), taskPartitionId, attemptId)); - selector.setPartitionCount(sourceExchange.getId(), successfulAttempts.size()); - } - selector.setFinal(); - selectors.put(remoteSource.getId(), new Split(REMOTE_CATALOG_HANDLE, new RemoteSplit(new SpoolingExchangeInput(ImmutableList.of(), Optional.of(selector.build()))))); - } - outputSelectorSplits = selectors.build(); - } - - public synchronized boolean isFinished() - { - return failure == null && - taskSource != null && - taskSource.isFinished() && - tasksPopulatedFuture.isDone() && - queuedPartitions.isEmpty() && - allPartitions.stream().allMatch(finishedPartitions::containsKey); - } - - public synchronized Map getSuccessfulAttempts() - { - return ImmutableMap.copyOf(finishedPartitions); - } - - public void cancel() - { - close(false); - } - - public void abort() - { - close(true); - } - - private void fail(Throwable t) - { - synchronized (this) { - if (failure == null) { - failure = t; - } - } - close(true); - } - - private void close(boolean abort) - { - boolean closed; - synchronized (this) { - closed = this.closed; - this.closed = true; - } - if (!closed) { - cancelRunningTasks(abort); - cancelBlockedFuture(); - releasePendingNodes(); - closeTaskSource(); - closeSinkExchange(); - } - } - - private void cancelRunningTasks(boolean abort) - { - List tasks; - synchronized (this) { - tasks = ImmutableList.copyOf(runningTasks.values()); - } - if (abort) { - tasks.forEach(RemoteTask::abort); - } - else { - tasks.forEach(RemoteTask::cancel); - } - } - - private void cancelBlockedFuture() - { - verify(!Thread.holdsLock(this)); - ListenableFuture future; - synchronized (this) { - future = blocked; - } - if (future != null && !future.isDone()) { - future.cancel(true); - } - } - - private void releasePendingNodes() - { - verify(!Thread.holdsLock(this)); - List leases = new ArrayList<>(); - synchronized (this) { - for (PendingPartition pendingPartition : pendingPartitions) { - leases.add(pendingPartition.getNodeLease()); - } - pendingPartitions.clear(); - } - for (NodeAllocator.NodeLease lease : leases) { - lease.release(); - } - } - - private void closeTaskSource() - { - TaskSource taskSource; - synchronized (this) { - taskSource = this.taskSource; - } - if (taskSource != null) { - try { - taskSource.close(); - } - catch (RuntimeException e) { - log.warn(e, "Error closing task source for stage: %s", stage.getStageId()); - } - } - } - - private void closeSinkExchange() - { - try { - sinkExchange.close(); - } - catch (RuntimeException e) { - log.warn(e, "Error closing sink exchange for stage: %s", stage.getStageId()); - } - } - - private int getNextAttemptIdForPartition(int partition) - { - int latestAttemptId = partitionToRemoteTaskMap.get(partition).stream() - .mapToInt(task -> task.getTaskId().getAttemptId()) - .max() - .orElse(-1); - return latestAttemptId + 1; - } - - private void updateTaskStatus(TaskStatus taskStatus, ExchangeSinkHandle exchangeSinkHandle) - { - TaskState state = taskStatus.getState(); - if (!state.isDone()) { - return; - } - - try { - RuntimeException failure = null; - SettableFuture previousTaskFinishedFuture; - SettableFuture previousDelaySchedulingFuture = null; - synchronized (this) { - TaskId taskId = taskStatus.getTaskId(); - - runningTasks.remove(taskId); - previousTaskFinishedFuture = taskFinishedFuture; - if (!runningTasks.isEmpty()) { - taskFinishedFuture = SettableFuture.create(); - } - else { - taskFinishedFuture = null; - } - - NodeAllocator.NodeLease nodeLease = requireNonNull(runningNodes.remove(taskId), () -> "node not found for task id: " + taskId); - nodeLease.release(); - - int partitionId = taskId.getPartitionId(); - - if (!finishedPartitions.containsKey(partitionId) && !closed) { - MemoryRequirements memoryLimits = partitionMemoryRequirements.get(partitionId); - verify(memoryLimits != null); - switch (state) { - case FINISHED: - finishedPartitions.put(partitionId, taskId.getAttemptId()); - sinkExchange.sinkFinished(exchangeSinkHandle, taskId.getAttemptId()); - if (noMorePartitions && finishedPartitions.keySet().containsAll(allPartitions)) { - sinkExchange.allRequiredSinksFinished(); - } - partitionToRemoteTaskMap.get(partitionId).forEach(RemoteTask::abort); - partitionMemoryEstimator.registerPartitionFinished(session, memoryLimits, taskStatus.getPeakMemoryReservation(), true, Optional.empty()); - - if (delayStopwatch.isRunning() && delayStopwatch.elapsed().compareTo(delaySchedulingDuration.get()) > 0) { - // we are past delay period and task completed successfully; reset delay - previousDelaySchedulingFuture = delaySchedulingFuture; - delayStopwatch.reset(); - delaySchedulingDuration = Optional.empty(); - delaySchedulingFuture = null; - } - - // Remove taskDescriptor for finished partition to conserve memory - // We may revisit the approach when we support volatile exchanges, for which - // it may be needed to restart already finished task to recreate output it produced. - taskDescriptorStorage.remove(stage.getStageId(), partitionId); - - break; - case CANCELED: - log.debug("Task cancelled: %s", taskId); - // no need for partitionMemoryEstimator.registerPartitionFinished; task cancelled mid-way - break; - case ABORTED: - log.debug("Task aborted: %s", taskId); - // no need for partitionMemoryEstimator.registerPartitionFinished; task aborted mid-way - break; - case FAILED: - ExecutionFailureInfo failureInfo = taskStatus.getFailures().stream() - .findFirst() - .map(this::rewriteTransportFailure) - .orElseGet(() -> toFailure(new TrinoException(GENERIC_INTERNAL_ERROR, "A task failed for an unknown reason"))); - log.warn(failureInfo.toException(), "Task failed: %s", taskId); - ErrorCode errorCode = failureInfo.getErrorCode(); - partitionMemoryEstimator.registerPartitionFinished(session, memoryLimits, taskStatus.getPeakMemoryReservation(), false, Optional.ofNullable(errorCode)); - - boolean coordinatorStage = stage.getFragment().getPartitioning().equals(COORDINATOR_DISTRIBUTION); - // coordinator tasks cannot be retried - int taskRemainingAttempts = remainingAttemptsPerTask.getOrDefault(partitionId, coordinatorStage ? 0 : maxRetryAttemptsPerTask); - if (remainingRetryAttemptsOverall.get() > 0 - && taskRemainingAttempts > 0 - && (errorCode == null || errorCode.getType() != USER_ERROR)) { - remainingRetryAttemptsOverall.decrementAndGet(); - remainingAttemptsPerTask.put(partitionId, taskRemainingAttempts - 1); - - // update memory limits for next attempt - MemoryRequirements newMemoryLimits = partitionMemoryEstimator.getNextRetryMemoryRequirements(session, memoryLimits, taskStatus.getPeakMemoryReservation(), errorCode); - log.debug("Computed next memory requirements for task from stage %s; previous=%s; new=%s; peak=%s; estimator=%s", stage.getStageId(), memoryLimits, newMemoryLimits, taskStatus.getPeakMemoryReservation(), partitionMemoryEstimator); - - if (errorCode != null && isOutOfMemoryError(errorCode) && newMemoryLimits.getRequiredMemory().toBytes() * 0.99 <= taskStatus.getPeakMemoryReservation().toBytes()) { - String message = format( - "Cannot allocate enough memory for task %s. Reported peak memory reservation: %s. Maximum possible reservation: %s.", - taskId, - taskStatus.getPeakMemoryReservation(), - newMemoryLimits.getRequiredMemory()); - failure = new TrinoException(() -> errorCode, message, failureInfo.toException()); - break; - } - - partitionMemoryRequirements.put(partitionId, newMemoryLimits); - - // reschedule - queuedPartitions.add(partitionId); - log.debug("Retrying partition %s for stage %s", partitionId, stage.getStageId()); - - if (errorCode != null && shouldDelayScheduling(errorCode)) { - if (delayStopwatch.isRunning()) { - // we are currently delaying tasks scheduling - checkState(delaySchedulingDuration.isPresent()); - - if (delayStopwatch.elapsed().compareTo(delaySchedulingDuration.get()) > 0) { - // we are past previous delay period and still getting failures; let's make it longer - delayStopwatch.reset().start(); - delaySchedulingDuration = delaySchedulingDuration.map(duration -> - Ordering.natural().min( - Duration.ofMillis((long) (duration.toMillis() * retryDelayScaleFactor)), - maxRetryDelay)); - - // create new future - previousDelaySchedulingFuture = delaySchedulingFuture; - SettableFuture newDelaySchedulingFuture = SettableFuture.create(); - delaySchedulingFuture = newDelaySchedulingFuture; - futureCompletor.completeFuture(newDelaySchedulingFuture, delaySchedulingDuration.get()); - } - } - else { - // initialize delaying of tasks scheduling - delayStopwatch.start(); - delaySchedulingDuration = Optional.of(minRetryDelay); - delaySchedulingFuture = SettableFuture.create(); - futureCompletor.completeFuture(delaySchedulingFuture, delaySchedulingDuration.get()); - } - } - } - else { - failure = failureInfo.toException(); - } - break; - default: - throw new IllegalArgumentException("Unexpected task state: " + state); - } - } - } - if (failure != null) { - // must be called outside the lock - fail(failure); - } - if (previousTaskFinishedFuture != null && !previousTaskFinishedFuture.isDone()) { - previousTaskFinishedFuture.set(null); - } - if (previousDelaySchedulingFuture != null && !previousDelaySchedulingFuture.isDone()) { - previousDelaySchedulingFuture.set(null); - } - } - catch (Throwable t) { - fail(t); - } - } - - private boolean shouldDelayScheduling(ErrorCode errorCode) - { - return errorCode.getType() == INTERNAL_ERROR || errorCode.getType() == EXTERNAL; - } - - private ExecutionFailureInfo rewriteTransportFailure(ExecutionFailureInfo executionFailureInfo) - { - if (executionFailureInfo.getRemoteHost() == null || failureDetector.getState(executionFailureInfo.getRemoteHost()) != GONE) { - return executionFailureInfo; - } - - return new ExecutionFailureInfo( - executionFailureInfo.getType(), - executionFailureInfo.getMessage(), - executionFailureInfo.getCause(), - executionFailureInfo.getSuppressed(), - executionFailureInfo.getStack(), - executionFailureInfo.getErrorLocation(), - REMOTE_HOST_GONE.toErrorCode(), - executionFailureInfo.getRemoteHost()); - } - - private static class PendingPartition - { - private final int partition; - private final NodeAllocator.NodeLease nodeLease; - - public PendingPartition(int partition, NodeAllocator.NodeLease nodeLease) - { - this.partition = partition; - this.nodeLease = requireNonNull(nodeLease, "nodeLease is null"); - } - - public int getPartition() - { - return partition; - } - - public NodeAllocator.NodeLease getNodeLease() - { - return nodeLease; - } - } - - public interface DelayedFutureCompletor - { - void completeFuture(SettableFuture future, Duration delay); - } -} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java deleted file mode 100644 index 9e46ba93214e..000000000000 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageTaskSourceFactory.java +++ /dev/null @@ -1,1026 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.execution.scheduler; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ArrayListMultimap; -import com.google.common.collect.HashMultimap; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableListMultimap; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; -import com.google.common.collect.ListMultimap; -import com.google.common.collect.Multimap; -import com.google.common.collect.Multimaps; -import com.google.common.collect.SetMultimap; -import com.google.common.collect.Sets; -import com.google.common.util.concurrent.AbstractFuture; -import com.google.common.util.concurrent.FutureCallback; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import io.airlift.log.Logger; -import io.airlift.units.DataSize; -import io.trino.Session; -import io.trino.exchange.SpoolingExchangeInput; -import io.trino.execution.ForQueryExecution; -import io.trino.execution.QueryManagerConfig; -import io.trino.execution.TableExecuteContext; -import io.trino.execution.TableExecuteContextManager; -import io.trino.metadata.InternalNode; -import io.trino.metadata.InternalNodeManager; -import io.trino.metadata.Split; -import io.trino.spi.HostAddress; -import io.trino.spi.Node; -import io.trino.spi.QueryId; -import io.trino.spi.SplitWeight; -import io.trino.spi.connector.CatalogHandle; -import io.trino.spi.exchange.ExchangeSourceHandle; -import io.trino.split.RemoteSplit; -import io.trino.split.SplitSource; -import io.trino.split.SplitSource.SplitBatch; -import io.trino.sql.planner.MergePartitioningHandle; -import io.trino.sql.planner.PartitioningHandle; -import io.trino.sql.planner.PlanFragment; -import io.trino.sql.planner.SplitSourceFactory; -import io.trino.sql.planner.plan.PlanFragmentId; -import io.trino.sql.planner.plan.PlanNode; -import io.trino.sql.planner.plan.PlanNodeId; -import io.trino.sql.planner.plan.PlanVisitor; -import io.trino.sql.planner.plan.RemoteSourceNode; -import io.trino.sql.planner.plan.TableWriterNode; - -import javax.annotation.concurrent.GuardedBy; -import javax.inject.Inject; - -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.Executor; -import java.util.concurrent.ExecutorService; -import java.util.function.LongConsumer; - -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; -import static com.google.common.base.Verify.verify; -import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.ImmutableListMultimap.toImmutableListMultimap; -import static com.google.common.collect.Iterables.getOnlyElement; -import static com.google.common.collect.Sets.newIdentityHashSet; -import static com.google.common.collect.Sets.union; -import static com.google.common.util.concurrent.Futures.addCallback; -import static com.google.common.util.concurrent.Futures.allAsList; -import static com.google.common.util.concurrent.Futures.immediateFuture; -import static io.airlift.concurrent.MoreFutures.addSuccessCallback; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.trino.SystemSessionProperties.getFaultTolerantExecutionMaxTaskSplitCount; -import static io.trino.SystemSessionProperties.getFaultTolerantExecutionMinTaskSplitCount; -import static io.trino.SystemSessionProperties.getFaultTolerantExecutionTargetTaskInputSize; -import static io.trino.SystemSessionProperties.getFaultTolerantExecutionTargetTaskSplitCount; -import static io.trino.operator.ExchangeOperator.REMOTE_CATALOG_HANDLE; -import static io.trino.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; -import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; -import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; -import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION; -import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; -import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; -import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE; -import static java.util.Objects.requireNonNull; - -/** - * Deprecated in favor of {@link EventDrivenTaskSourceFactory} - */ -@Deprecated -public class StageTaskSourceFactory - implements TaskSourceFactory -{ - private static final Logger log = Logger.get(StageTaskSourceFactory.class); - - private final SplitSourceFactory splitSourceFactory; - private final TableExecuteContextManager tableExecuteContextManager; - private final int splitBatchSize; - private final Executor executor; - private final InternalNodeManager nodeManager; - - @Inject - public StageTaskSourceFactory( - SplitSourceFactory splitSourceFactory, - TableExecuteContextManager tableExecuteContextManager, - QueryManagerConfig queryManagerConfig, - @ForQueryExecution ExecutorService executor, - InternalNodeManager nodeManager) - { - this( - splitSourceFactory, - tableExecuteContextManager, - queryManagerConfig.getScheduleSplitBatchSize(), - executor, - nodeManager); - } - - public StageTaskSourceFactory( - SplitSourceFactory splitSourceFactory, - TableExecuteContextManager tableExecuteContextManager, - int splitBatchSize, - ExecutorService executor, - InternalNodeManager nodeManager) - { - this.splitSourceFactory = requireNonNull(splitSourceFactory, "splitSourceFactory is null"); - this.tableExecuteContextManager = requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null"); - this.splitBatchSize = splitBatchSize; - this.executor = requireNonNull(executor, "executor is null"); - this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); - } - - @Override - public TaskSource create( - Session session, - PlanFragment fragment, - Multimap exchangeSourceHandles, - LongConsumer getSplitTimeRecorder, - FaultTolerantPartitioningScheme sourcePartitioningScheme) - { - PartitioningHandle partitioning = fragment.getPartitioning(); - - if (partitioning.equals(SINGLE_DISTRIBUTION) || partitioning.equals(COORDINATOR_DISTRIBUTION)) { - return SingleDistributionTaskSource.create(fragment, exchangeSourceHandles, nodeManager, partitioning.equals(COORDINATOR_DISTRIBUTION)); - } - if (partitioning.equals(FIXED_ARBITRARY_DISTRIBUTION) || partitioning.equals(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION)) { - return ArbitraryDistributionTaskSource.create( - fragment, - exchangeSourceHandles, - getFaultTolerantExecutionTargetTaskInputSize(session)); - } - if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getCatalogHandle().isPresent() || - (partitioning.getConnectorHandle() instanceof MergePartitioningHandle)) { - return HashDistributionTaskSource.create( - session, - fragment, - splitSourceFactory, - exchangeSourceHandles, - splitBatchSize, - getSplitTimeRecorder, - sourcePartitioningScheme, - getFaultTolerantExecutionTargetTaskSplitCount(session) * SplitWeight.standard().getRawValue(), - getFaultTolerantExecutionTargetTaskInputSize(session), - executor); - } - if (partitioning.equals(SOURCE_DISTRIBUTION)) { - return SourceDistributionTaskSource.create( - session, - fragment, - splitSourceFactory, - exchangeSourceHandles, - tableExecuteContextManager, - splitBatchSize, - getSplitTimeRecorder, - getFaultTolerantExecutionMinTaskSplitCount(session), - getFaultTolerantExecutionTargetTaskSplitCount(session) * SplitWeight.standard().getRawValue(), - getFaultTolerantExecutionMaxTaskSplitCount(session), - executor); - } - - // other partitioning handles are not expected to be set as a fragment partitioning - throw new IllegalArgumentException("Unexpected partitioning: " + partitioning); - } - - public static class SingleDistributionTaskSource - implements TaskSource - { - private final ListMultimap splits; - private final InternalNodeManager nodeManager; - private final boolean coordinatorOnly; - - private boolean finished; - - public static SingleDistributionTaskSource create( - PlanFragment fragment, - Multimap exchangeSourceHandles, - InternalNodeManager nodeManager, - boolean coordinatorOnly) - { - checkArgument(fragment.getPartitionedSources().isEmpty(), "no partitioned sources (table scans) expected, got: %s", fragment.getPartitionedSources()); - return new SingleDistributionTaskSource( - createRemoteSplits(getInputsForRemoteSources(fragment.getRemoteSourceNodes(), exchangeSourceHandles)), - nodeManager, - coordinatorOnly); - } - - @VisibleForTesting - SingleDistributionTaskSource(ListMultimap splits, InternalNodeManager nodeManager, boolean coordinatorOnly) - { - this.splits = ImmutableListMultimap.copyOf(requireNonNull(splits, "splits is null")); - this.nodeManager = requireNonNull(nodeManager, "nodeManager"); - this.coordinatorOnly = coordinatorOnly; - } - - @Override - public ListenableFuture> getMoreTasks() - { - if (finished) { - return immediateFuture(ImmutableList.of()); - } - ImmutableSet hostRequirement = ImmutableSet.of(); - if (coordinatorOnly) { - Node currentNode = nodeManager.getCurrentNode(); - verify(currentNode.isCoordinator(), "current node is expected to be a coordinator"); - hostRequirement = ImmutableSet.of(currentNode.getHostAndPort()); - } - List result = ImmutableList.of(new TaskDescriptor( - 0, - splits, - new NodeRequirements(Optional.empty(), hostRequirement))); - finished = true; - return immediateFuture(result); - } - - @Override - public boolean isFinished() - { - return finished; - } - - @Override - public void close() - { - } - } - - public static class ArbitraryDistributionTaskSource - implements TaskSource - { - private final Multimap partitionedExchangeSourceHandles; - private final Multimap replicatedExchangeSourceHandles; - private final long targetPartitionSizeInBytes; - - private boolean finished; - - public static ArbitraryDistributionTaskSource create( - PlanFragment fragment, - Multimap exchangeSourceHandles, - DataSize targetPartitionSize) - { - checkArgument(fragment.getPartitionedSources().isEmpty(), "no partitioned sources (table scans) expected, got: %s", fragment.getPartitionedSources()); - return new ArbitraryDistributionTaskSource( - getPartitionedExchangeSourceHandles(fragment, exchangeSourceHandles), - getReplicatedExchangeSourceHandles(fragment, exchangeSourceHandles), - targetPartitionSize); - } - - @VisibleForTesting - ArbitraryDistributionTaskSource( - Multimap partitionedExchangeSourceHandles, - Multimap replicatedExchangeSourceHandles, - DataSize targetPartitionSize) - { - this.partitionedExchangeSourceHandles = ImmutableListMultimap.copyOf(requireNonNull(partitionedExchangeSourceHandles, "partitionedExchangeSourceHandles is null")); - this.replicatedExchangeSourceHandles = ImmutableListMultimap.copyOf(requireNonNull(replicatedExchangeSourceHandles, "replicatedExchangeSourceHandles is null")); - this.targetPartitionSizeInBytes = requireNonNull(targetPartitionSize, "targetPartitionSize is null").toBytes(); - } - - @Override - public ListenableFuture> getMoreTasks() - { - if (finished) { - return immediateFuture(ImmutableList.of()); - } - NodeRequirements nodeRequirements = new NodeRequirements(Optional.empty(), ImmutableSet.of()); - - ImmutableList.Builder result = ImmutableList.builder(); - int currentPartitionId = 0; - - ListMultimap assignedExchangeSourceHandles = ArrayListMultimap.create(); - long assignedExchangeDataSize = 0; - int assignedExchangeSourceHandleCount = 0; - - for (Map.Entry entry : partitionedExchangeSourceHandles.entries()) { - PlanNodeId remoteSourcePlanNodeId = entry.getKey(); - ExchangeSourceHandle handle = entry.getValue(); - long handleDataSizeInBytes = handle.getDataSizeInBytes(); - - if (assignedExchangeDataSize != 0 && assignedExchangeDataSize + handleDataSizeInBytes > targetPartitionSizeInBytes) { - assignedExchangeSourceHandles.putAll(replicatedExchangeSourceHandles); - result.add(new TaskDescriptor( - currentPartitionId++, - createRemoteSplits(assignedExchangeSourceHandles), - nodeRequirements)); - assignedExchangeSourceHandles.clear(); - assignedExchangeDataSize = 0; - assignedExchangeSourceHandleCount = 0; - } - - assignedExchangeSourceHandles.put(remoteSourcePlanNodeId, handle); - assignedExchangeDataSize += handleDataSizeInBytes; - assignedExchangeSourceHandleCount++; - } - - if (assignedExchangeSourceHandleCount > 0) { - assignedExchangeSourceHandles.putAll(replicatedExchangeSourceHandles); - result.add(new TaskDescriptor(currentPartitionId, createRemoteSplits(assignedExchangeSourceHandles), nodeRequirements)); - } - - finished = true; - return immediateFuture(result.build()); - } - - @Override - public boolean isFinished() - { - return finished; - } - - @Override - public void close() - { - } - } - - public static class HashDistributionTaskSource - implements TaskSource - { - private final Map splitSources; - private final ListMultimap partitionedExchangeSourceHandles; - private final ListMultimap replicatedExchangeSourceHandles; - - private final int splitBatchSize; - private final LongConsumer getSplitTimeRecorder; - private final FaultTolerantPartitioningScheme sourcePartitioningScheme; - private final Optional catalogRequirement; - private final long targetPartitionSourceSizeInBytes; // compared data read from ExchangeSources - private final long targetPartitionSplitWeight; // compared against splits from SplitSources - private final Executor executor; - - @GuardedBy("this") - private ListenableFuture> loadedSplitsFuture; - @GuardedBy("this") - private boolean finished; - @GuardedBy("this") - private boolean closed; - - public static HashDistributionTaskSource create( - Session session, - PlanFragment fragment, - SplitSourceFactory splitSourceFactory, - Multimap exchangeSourceHandles, - int splitBatchSize, - LongConsumer getSplitTimeRecorder, - FaultTolerantPartitioningScheme sourcePartitioningScheme, - long targetPartitionSplitWeight, - DataSize targetPartitionSourceSize, - Executor executor) - { - Map splitSources = splitSourceFactory.createSplitSources(session, fragment); - return new HashDistributionTaskSource( - splitSources, - getPartitionedExchangeSourceHandles(fragment, exchangeSourceHandles), - getReplicatedExchangeSourceHandles(fragment, exchangeSourceHandles), - splitBatchSize, - getSplitTimeRecorder, - sourcePartitioningScheme, - fragment.getPartitioning().getCatalogHandle(), - targetPartitionSplitWeight, - isWriteFragment(fragment) ? DataSize.of(0, BYTE) : targetPartitionSourceSize, - executor); - } - - private static boolean isWriteFragment(PlanFragment fragment) - { - PlanVisitor visitor = new PlanVisitor<>() - { - @Override - protected Boolean visitPlan(PlanNode node, Void context) - { - for (PlanNode child : node.getSources()) { - if (child.accept(this, context)) { - return true; - } - } - return false; - } - - @Override - public Boolean visitTableWriter(TableWriterNode node, Void context) - { - return true; - } - }; - - return fragment.getRoot().accept(visitor, null); - } - - @VisibleForTesting - HashDistributionTaskSource( - Map splitSources, - ListMultimap partitionedExchangeSourceHandles, - ListMultimap replicatedExchangeSourceHandles, - int splitBatchSize, - LongConsumer getSplitTimeRecorder, - FaultTolerantPartitioningScheme sourcePartitioningScheme, - Optional catalogRequirement, - long targetPartitionSplitWeight, - DataSize targetPartitionSourceSize, - Executor executor) - { - this.splitSources = ImmutableMap.copyOf(requireNonNull(splitSources, "splitSources is null")); - this.partitionedExchangeSourceHandles = ImmutableListMultimap.copyOf(requireNonNull(partitionedExchangeSourceHandles, "partitionedExchangeSourceHandles is null")); - this.replicatedExchangeSourceHandles = ImmutableListMultimap.copyOf(requireNonNull(replicatedExchangeSourceHandles, "replicatedExchangeSourceHandles is null")); - this.splitBatchSize = splitBatchSize; - this.getSplitTimeRecorder = requireNonNull(getSplitTimeRecorder, "getSplitTimeRecorder is null"); - this.sourcePartitioningScheme = requireNonNull(sourcePartitioningScheme, "sourcePartitioningScheme is null"); - this.catalogRequirement = requireNonNull(catalogRequirement, "catalogRequirement is null"); - this.targetPartitionSourceSizeInBytes = targetPartitionSourceSize.toBytes(); - this.targetPartitionSplitWeight = targetPartitionSplitWeight; - this.executor = requireNonNull(executor, "executor is null"); - } - - @Override - public synchronized ListenableFuture> getMoreTasks() - { - if (finished || closed) { - return immediateFuture(ImmutableList.of()); - } - checkState(loadedSplitsFuture == null, "getMoreTasks called again while splits are being loaded"); - - List> splitSourceCompletionFutures = splitSources.entrySet().stream() - .map(entry -> { - SplitLoadingFuture future = new SplitLoadingFuture(entry.getKey(), entry.getValue(), splitBatchSize, getSplitTimeRecorder, executor); - future.load(); - return future; - }) - .collect(toImmutableList()); - - loadedSplitsFuture = allAsList(splitSourceCompletionFutures); - return Futures.transform( - loadedSplitsFuture, - loadedSplitsList -> { - synchronized (this) { - Map> partitionToSplitsMap = new HashMap<>(); - SetMultimap partitionToNodeMap = HashMultimap.create(); - for (LoadedSplits loadedSplits : loadedSplitsList) { - for (Split split : loadedSplits.getSplits()) { - int partition = sourcePartitioningScheme.getPartition(split); - Optional assignedNode = sourcePartitioningScheme.getNodeRequirement(partition); - if (assignedNode.isPresent()) { - HostAddress requiredAddress = assignedNode.get().getHostAndPort(); - Set existingRequirement = partitionToNodeMap.get(partition); - if (existingRequirement.isEmpty()) { - existingRequirement.add(requiredAddress); - } - else { - checkState( - existingRequirement.contains(requiredAddress), - "Unable to satisfy host requirement for partition %s. Existing requirement %s; Current split requirement: %s;", - partition, - existingRequirement, - requiredAddress); - existingRequirement.removeIf(host -> !host.equals(requiredAddress)); - } - } - - if (!split.isRemotelyAccessible()) { - Set requiredAddresses = ImmutableSet.copyOf(split.getAddresses()); - verify(!requiredAddresses.isEmpty(), "split is not remotely accessible but the list of addresses is empty: %s", split); - Set existingRequirement = partitionToNodeMap.get(partition); - if (existingRequirement.isEmpty()) { - existingRequirement.addAll(requiredAddresses); - } - else { - Set intersection = Sets.intersection(requiredAddresses, existingRequirement); - checkState( - !intersection.isEmpty(), - "Unable to satisfy host requirement for partition %s. Existing requirement %s; Current split requirement: %s;", - partition, - existingRequirement, - requiredAddresses); - partitionToNodeMap.replaceValues(partition, ImmutableSet.copyOf(intersection)); - } - } - - Multimap partitionSplits = partitionToSplitsMap.computeIfAbsent(partition, (p) -> ArrayListMultimap.create()); - partitionSplits.put(loadedSplits.getPlanNodeId(), split); - } - } - - Map> partitionToExchangeSourceHandlesMap = new HashMap<>(); - for (Map.Entry entry : partitionedExchangeSourceHandles.entries()) { - PlanNodeId planNodeId = entry.getKey(); - ExchangeSourceHandle handle = entry.getValue(); - int partition = handle.getPartitionId(); - Multimap partitionSourceHandles = partitionToExchangeSourceHandlesMap.computeIfAbsent(partition, (p) -> ArrayListMultimap.create()); - partitionSourceHandles.put(planNodeId, handle); - } - - int taskPartitionId = 0; - ImmutableList.Builder partitionTasks = ImmutableList.builder(); - for (Integer partition : union(partitionToSplitsMap.keySet(), partitionToExchangeSourceHandlesMap.keySet())) { - ImmutableListMultimap.Builder splits = ImmutableListMultimap.builder(); - splits.putAll(partitionToSplitsMap.getOrDefault(partition, ImmutableListMultimap.of())); - // replicated exchange source will be added in postprocessTasks below - splits.putAll(createRemoteSplits(partitionToExchangeSourceHandlesMap.getOrDefault(partition, ImmutableListMultimap.of()))); - Set hostRequirement = partitionToNodeMap.get(partition); - partitionTasks.add(new TaskDescriptor(taskPartitionId++, splits.build(), new NodeRequirements(catalogRequirement, hostRequirement))); - } - - List result = postprocessTasks(partitionTasks.build()); - finished = true; - return result; - } - }, - executor); - } - - private List postprocessTasks(List tasks) - { - ListMultimap taskGroups = groupCompatibleTasks(tasks); - ImmutableList.Builder joinedTasks = ImmutableList.builder(); - long replicatedExchangeSourcesSize = replicatedExchangeSourceHandles.values().stream().mapToLong(ExchangeSourceHandle::getDataSizeInBytes).sum(); - int taskPartitionId = 0; - for (Map.Entry> taskGroup : taskGroups.asMap().entrySet()) { - NodeRequirements groupNodeRequirements = taskGroup.getKey(); - Collection groupTasks = taskGroup.getValue(); - - ImmutableListMultimap.Builder splits = ImmutableListMultimap.builder(); - long splitsWeight = 0; - long exchangeSourcesSize = 0; - - for (TaskDescriptor task : groupTasks) { - ListMultimap taskSplits = task.getSplits(); - long taskSplitWeight = 0; - long taskExchangeSourcesSize = 0; - for (Split split : taskSplits.values()) { - if (split.getCatalogHandle().equals(REMOTE_CATALOG_HANDLE)) { - RemoteSplit remoteSplit = (RemoteSplit) split.getConnectorSplit(); - SpoolingExchangeInput exchangeInput = (SpoolingExchangeInput) remoteSplit.getExchangeInput(); - taskExchangeSourcesSize += exchangeInput.getExchangeSourceHandles().stream().mapToLong(ExchangeSourceHandle::getDataSizeInBytes).sum(); - } - else { - taskSplitWeight += split.getSplitWeight().getRawValue(); - } - } - - if ((splitsWeight > 0 || exchangeSourcesSize > 0) - && ((splitsWeight + taskSplitWeight) > targetPartitionSplitWeight || (exchangeSourcesSize + taskExchangeSourcesSize + replicatedExchangeSourcesSize) > targetPartitionSourceSizeInBytes)) { - splits.putAll(createRemoteSplits(replicatedExchangeSourceHandles)); // add replicated exchanges - joinedTasks.add(new TaskDescriptor(taskPartitionId++, splits.build(), groupNodeRequirements)); - splits = ImmutableListMultimap.builder(); - splitsWeight = 0; - exchangeSourcesSize = 0; - } - - splits.putAll(taskSplits); - splitsWeight += taskSplitWeight; - exchangeSourcesSize += taskExchangeSourcesSize; - } - - ImmutableListMultimap remainderSplits = splits.build(); - if (!remainderSplits.isEmpty()) { - joinedTasks.add(new TaskDescriptor( - taskPartitionId++, - ImmutableListMultimap.builder() - .putAll(remainderSplits) - // add replicated exchanges - .putAll(createRemoteSplits(replicatedExchangeSourceHandles)) - .build(), - groupNodeRequirements)); - } - } - return joinedTasks.build(); - } - - private ListMultimap groupCompatibleTasks(List tasks) - { - return Multimaps.index(tasks, TaskDescriptor::getNodeRequirements); - } - - @Override - public synchronized boolean isFinished() - { - return finished; - } - - @Override - public synchronized void close() - { - if (closed) { - return; - } - closed = true; - for (SplitSource splitSource : splitSources.values()) { - try { - splitSource.close(); - } - catch (RuntimeException e) { - log.error(e, "Error closing split source"); - } - } - } - } - - public static class SourceDistributionTaskSource - implements TaskSource - { - private final QueryId queryId; - private final PlanNodeId partitionedSourceNodeId; - private final TableExecuteContextManager tableExecuteContextManager; - private final SplitSource splitSource; - private final ListMultimap replicatedSplits; - private final int splitBatchSize; - private final LongConsumer getSplitTimeRecorder; - private final Optional catalogRequirement; - private final int minPartitionSplitCount; - private final long targetPartitionSplitWeight; - private final int maxPartitionSplitCount; - private final Executor executor; - - @GuardedBy("this") - private final Set remotelyAccessibleSplitBuffer = newIdentityHashSet(); - @GuardedBy("this") - private final Map> locallyAccessibleSplitBuffer = new HashMap<>(); - - @GuardedBy("this") - private int currentPartitionId; - @GuardedBy("this") - private boolean finished; - @GuardedBy("this") - private boolean closed; - @GuardedBy("this") - private ListenableFuture currentSplitBatchFuture = immediateFuture(null); - - public static SourceDistributionTaskSource create( - Session session, - PlanFragment fragment, - SplitSourceFactory splitSourceFactory, - Multimap exchangeSourceHandles, - TableExecuteContextManager tableExecuteContextManager, - int splitBatchSize, - LongConsumer getSplitTimeRecorder, - int minPartitionSplitCount, - long targetPartitionSplitWeight, - int maxPartitionSplitCount, - Executor executor) - { - checkArgument(fragment.getPartitionedSources().size() == 1, "single partitioned source is expected, got: %s", fragment.getPartitionedSources()); - - List remoteSourceNodes = fragment.getRemoteSourceNodes(); - checkArgument(remoteSourceNodes.stream().allMatch(node -> node.getExchangeType() == REPLICATE), "only replicated exchanges are expected in source distributed stage, got: %s", remoteSourceNodes); - - PlanNodeId partitionedSourceNodeId = getOnlyElement(fragment.getPartitionedSources()); - Map splitSources = splitSourceFactory.createSplitSources(session, fragment); - SplitSource splitSource = splitSources.get(partitionedSourceNodeId); - - Optional catalogName = Optional.of(splitSource.getCatalogHandle()) - .filter(catalog -> !catalog.getType().isInternal()); - - return new SourceDistributionTaskSource( - session.getQueryId(), - partitionedSourceNodeId, - tableExecuteContextManager, - splitSource, - createRemoteSplits(getReplicatedExchangeSourceHandles(fragment, exchangeSourceHandles)), - splitBatchSize, - getSplitTimeRecorder, - catalogName, - minPartitionSplitCount, - targetPartitionSplitWeight, - maxPartitionSplitCount, - executor); - } - - @VisibleForTesting - SourceDistributionTaskSource( - QueryId queryId, - PlanNodeId partitionedSourceNodeId, - TableExecuteContextManager tableExecuteContextManager, - SplitSource splitSource, - ListMultimap replicatedSplits, - int splitBatchSize, - LongConsumer getSplitTimeRecorder, - Optional catalogRequirement, - int minPartitionSplitCount, - long targetPartitionSplitWeight, - int maxPartitionSplitCount, - Executor executor) - { - this.queryId = requireNonNull(queryId, "queryId is null"); - this.partitionedSourceNodeId = requireNonNull(partitionedSourceNodeId, "partitionedSourceNodeId is null"); - this.tableExecuteContextManager = requireNonNull(tableExecuteContextManager, "tableExecuteContextManager is null"); - this.splitSource = requireNonNull(splitSource, "splitSource is null"); - this.replicatedSplits = ImmutableListMultimap.copyOf(requireNonNull(replicatedSplits, "replicatedSplits is null")); - this.splitBatchSize = splitBatchSize; - this.getSplitTimeRecorder = requireNonNull(getSplitTimeRecorder, "getSplitTimeRecorder is null"); - this.catalogRequirement = requireNonNull(catalogRequirement, "catalogRequirement is null"); - checkArgument(targetPartitionSplitWeight > 0, "targetPartitionSplitCount must be greater than 0: %s", targetPartitionSplitWeight); - this.targetPartitionSplitWeight = targetPartitionSplitWeight; - checkArgument(minPartitionSplitCount >= 0, "minPartitionSplitCount must be greater than or equal to 0: %s", minPartitionSplitCount); - this.minPartitionSplitCount = minPartitionSplitCount; - checkArgument(maxPartitionSplitCount > 0, "maxPartitionSplitCount must be greater than 0: %s", maxPartitionSplitCount); - checkArgument(maxPartitionSplitCount >= minPartitionSplitCount, - "maxPartitionSplitCount(%s) must be greater than or equal to minPartitionSplitCount(%s)", - maxPartitionSplitCount, - minPartitionSplitCount); - this.maxPartitionSplitCount = maxPartitionSplitCount; - this.executor = requireNonNull(executor, "executor is null"); - } - - @Override - public synchronized ListenableFuture> getMoreTasks() - { - if (finished || closed) { - return immediateFuture(ImmutableList.of()); - } - - checkState(currentSplitBatchFuture.isDone(), "getMoreTasks called again before the previous batch of splits was ready"); - currentSplitBatchFuture = splitSource.getNextBatch(splitBatchSize); - - long start = System.nanoTime(); - addSuccessCallback(currentSplitBatchFuture, () -> getSplitTimeRecorder.accept(start)); - - return Futures.transform( - currentSplitBatchFuture, - splitBatch -> { - synchronized (this) { - for (Split split : splitBatch.getSplits()) { - if (split.isRemotelyAccessible()) { - remotelyAccessibleSplitBuffer.add(split); - } - else { - List addresses = split.getAddresses(); - checkArgument(!addresses.isEmpty(), "split is not remotely accessible but the list of addresses is empty"); - for (HostAddress hostAddress : addresses) { - locallyAccessibleSplitBuffer.computeIfAbsent(hostAddress, key -> newIdentityHashSet()).add(split); - } - } - } - - ImmutableList.Builder readyTasksBuilder = ImmutableList.builder(); - boolean isLastBatch = splitBatch.isLastBatch(); - readyTasksBuilder.addAll(getReadyTasks( - remotelyAccessibleSplitBuffer, - ImmutableList.of(), - new NodeRequirements(catalogRequirement, ImmutableSet.of()), - isLastBatch)); - for (HostAddress remoteHost : locallyAccessibleSplitBuffer.keySet()) { - readyTasksBuilder.addAll(getReadyTasks( - locallyAccessibleSplitBuffer.get(remoteHost), - locallyAccessibleSplitBuffer.entrySet().stream() - .filter(entry -> !entry.getKey().equals(remoteHost)) - .map(Map.Entry::getValue) - .collect(toImmutableList()), - new NodeRequirements(catalogRequirement, ImmutableSet.of(remoteHost)), - isLastBatch)); - } - List readyTasks = readyTasksBuilder.build(); - - if (isLastBatch) { - Optional> tableExecuteSplitsInfo = splitSource.getTableExecuteSplitsInfo(); - - // Here we assume that we can get non-empty tableExecuteSplitsInfo only for queries which facilitate single split source. - tableExecuteSplitsInfo.ifPresent(info -> { - TableExecuteContext tableExecuteContext = tableExecuteContextManager.getTableExecuteContextForQuery(queryId); - tableExecuteContext.setSplitsInfo(info); - }); - - try { - splitSource.close(); - } - catch (RuntimeException e) { - log.error(e, "Error closing split source"); - } - finished = true; - } - - return readyTasks; - } - }, - executor); - } - - private List getReadyTasks(Set splits, List> otherSplitSets, NodeRequirements nodeRequirements, boolean includeRemainder) - { - ImmutableList.Builder result = ImmutableList.builder(); - while (true) { - Optional readyTask = getReadyTask(splits, otherSplitSets, nodeRequirements); - if (readyTask.isEmpty()) { - break; - } - result.add(readyTask.get()); - } - - if (includeRemainder && !splits.isEmpty()) { - result.add(buildTaskDescriptor(splits, nodeRequirements)); - for (Set otherSplits : otherSplitSets) { - otherSplits.removeAll(splits); - } - splits.clear(); - } - return result.build(); - } - - private Optional getReadyTask(Set splits, List> otherSplitSets, NodeRequirements nodeRequirements) - { - ImmutableList.Builder chosenSplitsBuilder = ImmutableList.builder(); - int splitCount = 0; - int totalSplitWeight = 0; - for (Split split : splits) { - totalSplitWeight += split.getSplitWeight().getRawValue(); - splitCount++; - chosenSplitsBuilder.add(split); - - if (splitCount >= minPartitionSplitCount && (totalSplitWeight >= targetPartitionSplitWeight || splitCount >= maxPartitionSplitCount)) { - ImmutableList chosenSplits = chosenSplitsBuilder.build(); - for (Set otherSplits : otherSplitSets) { - chosenSplits.forEach(otherSplits::remove); - } - chosenSplits.forEach(splits::remove); - return Optional.of(buildTaskDescriptor(chosenSplits, nodeRequirements)); - } - } - return Optional.empty(); - } - - private synchronized TaskDescriptor buildTaskDescriptor(Collection splits, NodeRequirements nodeRequirements) - { - return new TaskDescriptor( - currentPartitionId++, - ImmutableListMultimap.builder() - .putAll(partitionedSourceNodeId, splits) - .putAll(replicatedSplits) - .build(), - nodeRequirements); - } - - @Override - public synchronized boolean isFinished() - { - return finished; - } - - @Override - public synchronized void close() - { - if (closed) { - return; - } - closed = true; - splitSource.close(); - } - } - - private static ListMultimap getReplicatedExchangeSourceHandles(PlanFragment fragment, Multimap handles) - { - return getInputsForRemoteSources( - fragment.getRemoteSourceNodes().stream() - .filter(remoteSource -> remoteSource.getExchangeType() == REPLICATE) - .collect(toImmutableList()), - handles); - } - - private static ListMultimap getPartitionedExchangeSourceHandles(PlanFragment fragment, Multimap handles) - { - return getInputsForRemoteSources( - fragment.getRemoteSourceNodes().stream() - .filter(remoteSource -> remoteSource.getExchangeType() != REPLICATE) - .collect(toImmutableList()), - handles); - } - - private static ListMultimap getInputsForRemoteSources( - List remoteSources, - Multimap exchangeSourceHandles) - { - ImmutableListMultimap.Builder result = ImmutableListMultimap.builder(); - for (RemoteSourceNode remoteSource : remoteSources) { - for (PlanFragmentId fragmentId : remoteSource.getSourceFragmentIds()) { - result.putAll(remoteSource.getId(), exchangeSourceHandles.get(fragmentId)); - } - } - return result.build(); - } - - @VisibleForTesting - static ListMultimap createRemoteSplits(ListMultimap handles) - { - return Multimaps.asMap(handles).entrySet().stream() - .collect(toImmutableListMultimap(Map.Entry::getKey, entry -> createRemoteSplit(entry.getValue()))); - } - - @VisibleForTesting - static Split createRemoteSplit(Collection exchangeSourceHandles) - { - return new Split(REMOTE_CATALOG_HANDLE, new RemoteSplit(new SpoolingExchangeInput(ImmutableList.copyOf(exchangeSourceHandles), Optional.empty()))); - } - - private static class LoadedSplits - { - private final PlanNodeId planNodeId; - private final List splits; - - private LoadedSplits(PlanNodeId planNodeId, List splits) - { - this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); - this.splits = ImmutableList.copyOf(requireNonNull(splits, "splits is null")); - } - - public PlanNodeId getPlanNodeId() - { - return planNodeId; - } - - public List getSplits() - { - return splits; - } - } - - private static class SplitLoadingFuture - extends AbstractFuture - { - private final PlanNodeId planNodeId; - private final SplitSource splitSource; - private final int splitBatchSize; - private final LongConsumer getSplitTimeRecorder; - private final Executor executor; - @GuardedBy("this") - private final List loadedSplits = new ArrayList<>(); - @GuardedBy("this") - private ListenableFuture currentSplitBatch = immediateFuture(null); - - SplitLoadingFuture(PlanNodeId planNodeId, SplitSource splitSource, int splitBatchSize, LongConsumer getSplitTimeRecorder, Executor executor) - { - this.planNodeId = requireNonNull(planNodeId, "planNodeId is null"); - this.splitSource = requireNonNull(splitSource, "splitSource is null"); - this.splitBatchSize = splitBatchSize; - this.getSplitTimeRecorder = requireNonNull(getSplitTimeRecorder, "getSplitTimeRecorder is null"); - this.executor = requireNonNull(executor, "executor is null"); - } - - // Called to initiate loading and to load next batch if not finished - public synchronized void load() - { - if (currentSplitBatch == null) { - checkState(isCancelled(), "SplitLoadingFuture should be in cancelled state"); - return; - } - checkState(currentSplitBatch.isDone(), "next batch of splits requested before previous batch is done"); - currentSplitBatch = splitSource.getNextBatch(splitBatchSize); - - long start = System.nanoTime(); - addCallback( - currentSplitBatch, - new FutureCallback<>() - { - @Override - public void onSuccess(SplitBatch splitBatch) - { - getSplitTimeRecorder.accept(start); - synchronized (SplitLoadingFuture.this) { - loadedSplits.addAll(splitBatch.getSplits()); - - if (splitBatch.isLastBatch()) { - set(new LoadedSplits(planNodeId, loadedSplits)); - try { - splitSource.close(); - } - catch (RuntimeException e) { - log.error(e, "Error closing split source"); - } - } - else { - load(); - } - } - } - - @Override - public void onFailure(Throwable throwable) - { - setException(throwable); - } - }, - executor); - } - - @Override - protected synchronized void interruptTask() - { - if (currentSplitBatch != null) { - currentSplitBatch.cancel(true); - currentSplitBatch = null; - } - } - } -} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskSource.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskSource.java deleted file mode 100644 index 7e93d3df7731..000000000000 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskSource.java +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.execution.scheduler; - -import com.google.common.util.concurrent.ListenableFuture; - -import java.io.Closeable; -import java.util.List; - -public interface TaskSource - extends Closeable -{ - ListenableFuture> getMoreTasks(); - - boolean isFinished(); - - @Override - void close(); -} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskSourceFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskSourceFactory.java deleted file mode 100644 index 34661234f083..000000000000 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/TaskSourceFactory.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.execution.scheduler; - -import com.google.common.collect.Multimap; -import io.trino.Session; -import io.trino.spi.exchange.ExchangeSourceHandle; -import io.trino.sql.planner.PlanFragment; -import io.trino.sql.planner.plan.PlanFragmentId; - -import java.util.function.LongConsumer; - -/** - * Deprecated in favor of {@link EventDrivenTaskSourceFactory} - */ -@Deprecated -public interface TaskSourceFactory -{ - TaskSource create( - Session session, - PlanFragment fragment, - Multimap exchangeSourceHandles, - LongConsumer getSplitTimeRecorder, - FaultTolerantPartitioningScheme sourcePartitioningScheme); -} diff --git a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java index 7f405e1fbc81..7a4076566372 100644 --- a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java +++ b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java @@ -69,10 +69,8 @@ import io.trino.execution.scheduler.NodeSchedulerConfig; import io.trino.execution.scheduler.PartitionMemoryEstimatorFactory; import io.trino.execution.scheduler.SplitSchedulerStats; -import io.trino.execution.scheduler.StageTaskSourceFactory; import io.trino.execution.scheduler.TaskDescriptorStorage; import io.trino.execution.scheduler.TaskExecutionStats; -import io.trino.execution.scheduler.TaskSourceFactory; import io.trino.execution.scheduler.policy.AllAtOnceExecutionPolicy; import io.trino.execution.scheduler.policy.ExecutionPolicy; import io.trino.execution.scheduler.policy.PhasedExecutionPolicy; @@ -325,7 +323,6 @@ protected void setup(Binder binder) binder.bind(SplitSchedulerStats.class).in(Scopes.SINGLETON); newExporter(binder).export(SplitSchedulerStats.class).withGeneratedName(); - binder.bind(TaskSourceFactory.class).to(StageTaskSourceFactory.class).in(Scopes.SINGLETON); binder.bind(EventDrivenTaskSourceFactory.class).in(Scopes.SINGLETON); binder.bind(TaskDescriptorStorage.class).in(Scopes.SINGLETON); newExporter(binder).export(TaskDescriptorStorage.class).withGeneratedName(); diff --git a/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java b/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java index d0d2d5c663b9..0fbbe82c4348 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java @@ -65,7 +65,6 @@ public void testDefaults() .setRequiredWorkersMaxWait(new Duration(5, MINUTES)) .setRetryPolicy(RetryPolicy.NONE) .setQueryRetryAttempts(4) - .setTaskRetryAttemptsOverall(Integer.MAX_VALUE) .setTaskRetryAttemptsPerTask(4) .setRetryInitialDelay(new Duration(10, SECONDS)) .setRetryMaxDelay(new Duration(1, MINUTES)) @@ -73,12 +72,10 @@ public void testDefaults() .setMaxTasksWaitingForExecutionPerQuery(10) .setMaxTasksWaitingForNodePerStage(5) .setFaultTolerantExecutionTargetTaskInputSize(DataSize.of(4, GIGABYTE)) - .setFaultTolerantExecutionMinTaskSplitCount(16) .setFaultTolerantExecutionTargetTaskSplitCount(64) .setFaultTolerantExecutionMaxTaskSplitCount(256) .setFaultTolerantExecutionTaskDescriptorStorageMaxMemory(DataSize.ofBytes(Math.round(AVAILABLE_HEAP_MEMORY * 0.15))) .setFaultTolerantExecutionPartitionCount(50) - .setFaultTolerantExecutionEventDrivenSchedulerEnabled(true) .setFaultTolerantExecutionForcePreferredWritePartitioningEnabled(true)); } @@ -112,7 +109,6 @@ public void testExplicitPropertyMappings() .put("query-manager.required-workers-max-wait", "33m") .put("retry-policy", "QUERY") .put("query-retry-attempts", "0") - .put("task-retry-attempts-overall", "17") .put("task-retry-attempts-per-task", "9") .put("retry-initial-delay", "1m") .put("retry-max-delay", "1h") @@ -120,12 +116,10 @@ public void testExplicitPropertyMappings() .put("max-tasks-waiting-for-execution-per-query", "22") .put("max-tasks-waiting-for-node-per-stage", "3") .put("fault-tolerant-execution-target-task-input-size", "222MB") - .put("fault-tolerant-execution-min-task-split-count", "2") .put("fault-tolerant-execution-target-task-split-count", "3") .put("fault-tolerant-execution-max-task-split-count", "22") .put("fault-tolerant-execution-task-descriptor-storage-max-memory", "3GB") .put("fault-tolerant-execution-partition-count", "123") - .put("experimental.fault-tolerant-execution-event-driven-scheduler-enabled", "false") .put("experimental.fault-tolerant-execution-force-preferred-write-partitioning-enabled", "false") .buildOrThrow(); @@ -156,7 +150,6 @@ public void testExplicitPropertyMappings() .setRequiredWorkersMaxWait(new Duration(33, MINUTES)) .setRetryPolicy(RetryPolicy.QUERY) .setQueryRetryAttempts(0) - .setTaskRetryAttemptsOverall(17) .setTaskRetryAttemptsPerTask(9) .setRetryInitialDelay(new Duration(1, MINUTES)) .setRetryMaxDelay(new Duration(1, HOURS)) @@ -164,12 +157,10 @@ public void testExplicitPropertyMappings() .setMaxTasksWaitingForExecutionPerQuery(22) .setMaxTasksWaitingForNodePerStage(3) .setFaultTolerantExecutionTargetTaskInputSize(DataSize.of(222, MEGABYTE)) - .setFaultTolerantExecutionMinTaskSplitCount(2) .setFaultTolerantExecutionTargetTaskSplitCount(3) .setFaultTolerantExecutionMaxTaskSplitCount(22) .setFaultTolerantExecutionTaskDescriptorStorageMaxMemory(DataSize.of(3, GIGABYTE)) .setFaultTolerantExecutionPartitionCount(123) - .setFaultTolerantExecutionEventDrivenSchedulerEnabled(false) .setFaultTolerantExecutionForcePreferredWritePartitioningEnabled(false); assertFullMapping(properties, expected); diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java deleted file mode 100644 index 7935ade157ce..000000000000 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestFaultTolerantStageScheduler.java +++ /dev/null @@ -1,1168 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.execution.scheduler; - -import com.google.common.base.Stopwatch; -import com.google.common.base.Ticker; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Sets; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.SettableFuture; -import io.airlift.testing.TestingTicker; -import io.airlift.units.DataSize; -import io.trino.Session; -import io.trino.client.NodeVersion; -import io.trino.cost.StatsAndCosts; -import io.trino.execution.DynamicFilterConfig; -import io.trino.execution.NodeTaskMap; -import io.trino.execution.RemoteTaskFactory; -import io.trino.execution.SqlStage; -import io.trino.execution.StageId; -import io.trino.execution.TaskId; -import io.trino.execution.TaskState; -import io.trino.execution.TestingRemoteTaskFactory; -import io.trino.execution.TestingRemoteTaskFactory.TestingRemoteTask; -import io.trino.execution.scheduler.TestingExchange.TestingExchangeSinkHandle; -import io.trino.execution.scheduler.TestingNodeSelectorFactory.TestingNodeSupplier; -import io.trino.failuredetector.NoOpFailureDetector; -import io.trino.metadata.InternalNode; -import io.trino.metadata.Split; -import io.trino.server.DynamicFilterService; -import io.trino.spi.QueryId; -import io.trino.spi.StandardErrorCode; -import io.trino.spi.TrinoException; -import io.trino.spi.exchange.Exchange; -import io.trino.spi.predicate.TupleDomain; -import io.trino.sql.PlannerContext; -import io.trino.sql.planner.Partitioning; -import io.trino.sql.planner.PartitioningScheme; -import io.trino.sql.planner.PlanFragment; -import io.trino.sql.planner.Symbol; -import io.trino.sql.planner.plan.JoinNode; -import io.trino.sql.planner.plan.PlanFragmentId; -import io.trino.sql.planner.plan.PlanNodeId; -import io.trino.sql.planner.plan.RemoteSourceNode; -import io.trino.sql.planner.plan.TableScanNode; -import io.trino.testing.TestingMetadata.TestingColumnHandle; -import io.trino.testing.TestingSplit; -import io.trino.util.FinalizerService; -import org.testng.annotations.AfterClass; -import org.testng.annotations.AfterMethod; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.Test; - -import java.net.URI; -import java.time.Duration; -import java.util.Collection; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; - -import static com.google.common.collect.ImmutableMap.toImmutableMap; -import static com.google.common.collect.Iterables.cycle; -import static com.google.common.collect.Iterables.limit; -import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.airlift.units.DataSize.Unit.MEGABYTE; -import static io.trino.operator.RetryPolicy.TASK; -import static io.trino.spi.type.VarcharType.VARCHAR; -import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; -import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; -import static io.trino.sql.planner.TestingPlannerContext.plannerContextBuilder; -import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE; -import static io.trino.sql.planner.plan.JoinNode.DistributionType.REPLICATED; -import static io.trino.sql.planner.plan.JoinNode.Type.INNER; -import static io.trino.testing.TestingHandles.TEST_CATALOG_HANDLE; -import static io.trino.testing.TestingHandles.TEST_TABLE_HANDLE; -import static io.trino.testing.TestingSession.testSessionBuilder; -import static io.trino.testing.TestingSplit.createRemoteSplit; -import static java.util.concurrent.TimeUnit.MILLISECONDS; -import static java.util.concurrent.TimeUnit.SECONDS; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.testng.Assert.assertEquals; -import static org.testng.Assert.assertFalse; -import static org.testng.Assert.assertTrue; - -@Test(singleThreaded = true) -public class TestFaultTolerantStageScheduler -{ - private static final QueryId QUERY_ID = new QueryId("query"); - private static final Session SESSION = testSessionBuilder() - .setQueryId(QUERY_ID) - .build(); - private static final StageId STAGE_ID = new StageId(QUERY_ID, 0); - private static final PlanFragmentId FRAGMENT_ID = new PlanFragmentId("0"); - private static final PlanFragmentId SOURCE_FRAGMENT_ID_1 = new PlanFragmentId("1"); - private static final PlanFragmentId SOURCE_FRAGMENT_ID_2 = new PlanFragmentId("2"); - private static final PlanNodeId TABLE_SCAN_NODE_ID = new PlanNodeId("table_scan_id"); - - private static final InternalNode NODE_1 = new InternalNode("node-1", URI.create("local://127.0.0.1:8080"), NodeVersion.UNKNOWN, false); - private static final InternalNode NODE_2 = new InternalNode("node-2", URI.create("local://127.0.0.1:8081"), NodeVersion.UNKNOWN, false); - private static final InternalNode NODE_3 = new InternalNode("node-3", URI.create("local://127.0.0.1:8082"), NodeVersion.UNKNOWN, false); - - private static final PlannerContext PLANNER_CONTEXT = plannerContextBuilder().build(); - - private FinalizerService finalizerService; - private NodeTaskMap nodeTaskMap; - private FixedCountNodeAllocatorService nodeAllocatorService; - - private TestingTicker ticker; - private TestFutureCompletor futureCompletor; - - @BeforeClass - public void beforeClass() - { - finalizerService = new FinalizerService(); - finalizerService.start(); - nodeTaskMap = new NodeTaskMap(finalizerService); - ticker = new TestingTicker(); - futureCompletor = new TestFutureCompletor(ticker); - } - - @AfterClass(alwaysRun = true) - public void afterClass() - { - nodeTaskMap = null; - if (finalizerService != null) { - finalizerService.destroy(); - finalizerService = null; - } - } - - private void setupNodeAllocatorService(TestingNodeSupplier nodeSupplier) - { - shutdownNodeAllocatorService(); // just in case - nodeAllocatorService = new FixedCountNodeAllocatorService(new NodeScheduler(new TestingNodeSelectorFactory(NODE_1, nodeSupplier))); - } - - @AfterMethod(alwaysRun = true) - public void shutdownNodeAllocatorService() - { - if (nodeAllocatorService != null) { - nodeAllocatorService.stop(); - } - nodeAllocatorService = null; - } - - @Test - public void testHappyPath() - throws Exception - { - TestingRemoteTaskFactory remoteTaskFactory = new TestingRemoteTaskFactory(); - TestingTaskSourceFactory taskSourceFactory = createTaskSourceFactory(5, 2); - TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of( - NODE_1, ImmutableList.of(TEST_CATALOG_HANDLE), - NODE_2, ImmutableList.of(TEST_CATALOG_HANDLE), - NODE_3, ImmutableList.of(TEST_CATALOG_HANDLE))); - setupNodeAllocatorService(nodeSupplier); - - TestingExchange sinkExchange = new TestingExchange(); - - TestingExchange sourceExchange1 = new TestingExchange(); - TestingExchange sourceExchange2 = new TestingExchange(); - - try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( - remoteTaskFactory, - taskSourceFactory, - nodeAllocator, - sinkExchange, - ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), - 2, - 1); - - ListenableFuture blocked = scheduler.isBlocked(); - assertUnblocked(blocked); - - scheduler.schedule(); - - blocked = scheduler.isBlocked(); - // blocked on first source exchange - assertBlocked(blocked); - - sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 0, 1))); - // still blocked on the second source exchange - assertBlocked(blocked); - assertFalse(scheduler.isBlocked().isDone()); - - sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(1, 0, 1))); - // now unblocked - assertUnblocked(blocked); - assertUnblocked(scheduler.isBlocked()); - - scheduler.schedule(); - - blocked = scheduler.isBlocked(); - // blocked on node allocation - assertBlocked(blocked); - - // not all tasks have been enumerated yet - assertFalse(sinkExchange.isNoMoreSinks()); - - Map tasks = remoteTaskFactory.getTasks(); - // one task per node - assertThat(tasks).hasSize(3); - assertThat(tasks).containsKey(getTaskId(0, 0)); - assertThat(tasks).containsKey(getTaskId(1, 0)); - assertThat(tasks).containsKey(getTaskId(2, 0)); - - TestingRemoteTask task = tasks.get(getTaskId(0, 0)); - // fail task for partition 0 - task.fail(new RuntimeException("some failure")); - - assertUnblocked(blocked); - assertUnblocked(scheduler.isBlocked()); - - // schedule more tasks - moveTime(10, SECONDS); // skip retry delay - scheduler.schedule(); - - tasks = remoteTaskFactory.getTasks(); - assertThat(tasks).hasSize(4); - assertThat(tasks).containsKey(getTaskId(3, 0)); - - blocked = scheduler.isBlocked(); - // blocked on task scheduling - assertBlocked(blocked); - - // finish some task - assertThat(tasks).containsKey(getTaskId(1, 0)); - tasks.get(getTaskId(1, 0)).finish(); - - assertUnblocked(blocked); - assertUnblocked(scheduler.isBlocked()); - assertThat(sinkExchange.getFinishedSinkHandles()).contains(new TestingExchangeSinkHandle(1)); - - // this will schedule failed task - scheduler.schedule(); - - blocked = scheduler.isBlocked(); - // blocked on task scheduling - assertBlocked(blocked); - - tasks = remoteTaskFactory.getTasks(); - assertThat(tasks).hasSize(5); - assertThat(tasks).containsKey(getTaskId(0, 1)); - - // finish some task - tasks = remoteTaskFactory.getTasks(); - assertThat(tasks).containsKey(getTaskId(3, 0)); - tasks.get(getTaskId(3, 0)).finish(); - assertThat(sinkExchange.getFinishedSinkHandles()).contains(new TestingExchangeSinkHandle(1), new TestingExchangeSinkHandle(3)); - - assertUnblocked(blocked); - - // schedule the last task - scheduler.schedule(); - - tasks = remoteTaskFactory.getTasks(); - assertThat(tasks).hasSize(6); - assertThat(tasks).containsKey(getTaskId(4, 0)); - assertTrue(sinkExchange.isNoMoreSinks()); - - // not finished yet, will be finished when all tasks succeed - assertFalse(scheduler.isFinished()); - - blocked = scheduler.isBlocked(); - // blocked on task scheduling - assertBlocked(blocked); - - tasks = remoteTaskFactory.getTasks(); - assertThat(tasks).containsKey(getTaskId(4, 0)); - // finish remaining tasks - tasks.get(getTaskId(0, 1)).finish(); - tasks.get(getTaskId(2, 0)).finish(); - tasks.get(getTaskId(4, 0)).finish(); - - // now it's not blocked and finished - assertUnblocked(blocked); - assertUnblocked(scheduler.isBlocked()); - - assertThat(sinkExchange.getFinishedSinkHandles()).contains( - new TestingExchangeSinkHandle(0), - new TestingExchangeSinkHandle(1), - new TestingExchangeSinkHandle(2), - new TestingExchangeSinkHandle(3), - new TestingExchangeSinkHandle(4)); - - assertTrue(sinkExchange.isAllRequiredSinksFinished()); - - assertTrue(scheduler.isFinished()); - } - } - - @Test - public void testTasksWaitingForNodes() - throws Exception - { - TestingRemoteTaskFactory remoteTaskFactory = new TestingRemoteTaskFactory(); - List splits = ImmutableList.of( - new Split(TEST_CATALOG_HANDLE, new TestingSplit(false, ImmutableList.of(NODE_1.getHostAndPort()))), // 0 - new Split(TEST_CATALOG_HANDLE, new TestingSplit(false, ImmutableList.of(NODE_1.getHostAndPort()))), // 1 - new Split(TEST_CATALOG_HANDLE, new TestingSplit(false, ImmutableList.of(NODE_1.getHostAndPort()))), // 2 - new Split(TEST_CATALOG_HANDLE, new TestingSplit(false, ImmutableList.of(NODE_2.getHostAndPort()))), // 3 - new Split(TEST_CATALOG_HANDLE, new TestingSplit(false, ImmutableList.of(NODE_1.getHostAndPort()))), // 4 - new Split(TEST_CATALOG_HANDLE, new TestingSplit(false, ImmutableList.of(NODE_3.getHostAndPort())))); // 5 - TestingTaskSourceFactory taskSourceFactory = new TestingTaskSourceFactory(Optional.of(TEST_CATALOG_HANDLE), splits, 2); - TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of( - NODE_1, ImmutableList.of(TEST_CATALOG_HANDLE), - NODE_2, ImmutableList.of(TEST_CATALOG_HANDLE), - NODE_3, ImmutableList.of(TEST_CATALOG_HANDLE))); - setupNodeAllocatorService(nodeSupplier); - - TestingExchange sinkExchange = new TestingExchange(); - TestingExchange sourceExchange1 = new TestingExchange(); - TestingExchange sourceExchange2 = new TestingExchange(); - try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( - remoteTaskFactory, - taskSourceFactory, - nodeAllocator, - sinkExchange, - ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), - 2, - 3); // allow for 3 tasks waiting for nodes before blocking - - sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 0, 1))); - sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(1, 0, 1))); - scheduler.schedule(); - - Map tasks; - - // we reached max pending tasks count (3) on split 4 and blocked; task for split 5 will not be allocated even though NODE_3 is free - assertBlocked(scheduler.isBlocked()); - tasks = remoteTaskFactory.getTasks(); - assertThat(tasks).hasSize(2); - assertThat(tasks).containsKey(getTaskId(0, 0)); - assertThat(tasks).containsKey(getTaskId(3, 0)); - - // unblocking NODE_2 does not help - tasks.get(getTaskId(3, 0)).finish(); - scheduler.schedule(); - assertBlocked(scheduler.isBlocked()); - tasks = remoteTaskFactory.getTasks(); - assertThat(tasks).hasSize(2); - assertThat(tasks).containsKey(getTaskId(0, 0)); - assertThat(tasks).containsKey(getTaskId(3, 0)); - - // unblocking NODE_1 allows for scheduling next pending split for NODE_1 and NODE_3 - tasks.get(getTaskId(0, 0)).finish(); - scheduler.schedule(); - assertBlocked(scheduler.isBlocked()); - tasks = remoteTaskFactory.getTasks(); - assertThat(tasks).hasSize(4); - assertThat(tasks).containsKey(getTaskId(0, 0)); - assertThat(tasks).containsKey(getTaskId(1, 0)); // NEW (NODE_1) - assertThat(tasks).containsKey(getTaskId(3, 0)); - assertThat(tasks).containsKey(getTaskId(5, 0)); // NEW (NODE_3) - - // finish all remaining tasks until scheduler is finished - - tasks.get(getTaskId(1, 0)).finish(); - scheduler.schedule(); - tasks = remoteTaskFactory.getTasks(); - assertThat(tasks).containsKey(getTaskId(2, 0)); // NEW (NODE_1) - - tasks.get(getTaskId(2, 0)).finish(); - scheduler.schedule(); - tasks = remoteTaskFactory.getTasks(); - assertThat(tasks).containsKey(getTaskId(4, 0)); // NEW (NODE_1) - - tasks.get(getTaskId(4, 0)).finish(); - tasks.get(getTaskId(3, 0)).finish(); - tasks.get(getTaskId(5, 0)).finish(); - scheduler.schedule(); - assertUnblocked(scheduler.isBlocked()); - assertTrue(scheduler.isFinished()); - } - } - - @Test - public void testTaskFailure() - throws Exception - { - TestingRemoteTaskFactory remoteTaskFactory = new TestingRemoteTaskFactory(); - TestingTaskSourceFactory taskSourceFactory = createTaskSourceFactory(3, 1); - TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of( - NODE_1, ImmutableList.of(TEST_CATALOG_HANDLE), - NODE_2, ImmutableList.of(TEST_CATALOG_HANDLE))); - setupNodeAllocatorService(nodeSupplier); - - TestingExchange sourceExchange1 = new TestingExchange(); - TestingExchange sourceExchange2 = new TestingExchange(); - - try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( - remoteTaskFactory, - taskSourceFactory, - nodeAllocator, - new TestingExchange(), - ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), - 0, - 1); - - sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 0, 1))); - sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(1, 0, 1))); - assertUnblocked(scheduler.isBlocked()); - - scheduler.schedule(); - - ListenableFuture blocked = scheduler.isBlocked(); - // waiting on node acquisition - assertBlocked(blocked); - - NodeAllocator.NodeLease acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of()), DataSize.of(1, GIGABYTE)); - NodeAllocator.NodeLease acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of()), DataSize.of(1, GIGABYTE)); - - remoteTaskFactory.getTasks().get(getTaskId(0, 0)).fail(new RuntimeException("some failure")); - - assertUnblocked(blocked); - assertUnblocked(acquireNode1.getNode()); - assertUnblocked(acquireNode2.getNode()); - - assertThatThrownBy(scheduler::schedule) - .hasMessageContaining("some failure"); - - assertUnblocked(scheduler.isBlocked()); - assertFalse(scheduler.isFinished()); - } - } - - @Test - public void testRetryDelay() - throws Exception - { - TestingRemoteTaskFactory remoteTaskFactory = new TestingRemoteTaskFactory(); - TestingTaskSourceFactory taskSourceFactory = createTaskSourceFactory(3, 1); - TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of( - NODE_1, ImmutableList.of(TEST_CATALOG_HANDLE), - NODE_2, ImmutableList.of(TEST_CATALOG_HANDLE), - NODE_3, ImmutableList.of(TEST_CATALOG_HANDLE))); - setupNodeAllocatorService(nodeSupplier); - - TestingExchange sourceExchange1 = new TestingExchange(); - TestingExchange sourceExchange2 = new TestingExchange(); - - Session session = testSessionBuilder() - .setQueryId(QUERY_ID) - .setSystemProperty("retry_initial_delay", "1s") - .setSystemProperty("retry_max_delay", "3s") - .setSystemProperty("retry_delay_scale_factor", "2.0") - .build(); - - try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(session, 1)) { - FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( - session, - remoteTaskFactory, - taskSourceFactory, - nodeAllocator, - new TestingExchange(), - ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), - 6, - 1); - - sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 0, 1))); - sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(1, 0, 1))); - assertUnblocked(scheduler.isBlocked()); - scheduler.schedule(); - - ListenableFuture blocked = scheduler.isBlocked(); - - // T+0.0 all tasks are running - assertBlocked(blocked); - assertThat(remoteTaskFactory.getTasks()).hasSize(3); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.RUNNING); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.RUNNING); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING); - - // T+0.0 fail task 0.0 - remoteTaskFactory.getTasks().get(getTaskId(0, 0)).fail(new RuntimeException("some failure")); - assertUnblocked(blocked); - scheduler.schedule(); - blocked = scheduler.isBlocked(); - assertBlocked(blocked); - assertThat(remoteTaskFactory.getTasks()).hasSize(3); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.RUNNING); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING); - - // T+0.9 retry should not trigger yet - moveTime(900, MILLISECONDS); - assertBlocked(blocked); - assertThat(remoteTaskFactory.getTasks()).hasSize(3); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.RUNNING); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING); - - // T+1.4s past retry delay for task 0 - moveTime(500, MILLISECONDS); - assertUnblocked(blocked); - scheduler.schedule(); - blocked = scheduler.isBlocked(); - assertBlocked(blocked); - assertThat(remoteTaskFactory.getTasks()).hasSize(4); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.RUNNING); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.RUNNING); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING); - - // T+1.4 fail task 0.1 - remoteTaskFactory.getTasks().get(getTaskId(0, 1)).fail(new RuntimeException("some other failure")); - assertUnblocked(blocked); - scheduler.schedule(); - blocked = scheduler.isBlocked(); - assertBlocked(blocked); - assertThat(remoteTaskFactory.getTasks()).hasSize(4); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.RUNNING); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING); - - // T+3.3 another retry should not happen yet (delay is 2s on second failure) - moveTime(1900, MILLISECONDS); - assertBlocked(blocked); - assertThat(remoteTaskFactory.getTasks()).hasSize(4); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.RUNNING); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING); - - // T+3.5s past retry delay for task 0.1 - moveTime(200, MILLISECONDS); - assertUnblocked(blocked); - scheduler.schedule(); - blocked = scheduler.isBlocked(); - assertBlocked(blocked); - assertThat(remoteTaskFactory.getTasks()).hasSize(5); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.RUNNING); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.RUNNING); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING); - - // T+3.5 fail task 0.2 - remoteTaskFactory.getTasks().get(getTaskId(0, 2)).fail(new RuntimeException("some other failure")); - assertUnblocked(blocked); - scheduler.schedule(); - blocked = scheduler.isBlocked(); - assertBlocked(blocked); - assertThat(remoteTaskFactory.getTasks()).hasSize(5); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.RUNNING); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING); - - // T+6.4 another retry should not happen yet (delay is 3s on thirf failure (we reached limit) - moveTime(2900, MILLISECONDS); - assertBlocked(blocked); - assertThat(remoteTaskFactory.getTasks()).hasSize(5); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.RUNNING); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING); - - // T+6.6s past retry delay for task 0.2 - moveTime(200, MILLISECONDS); - assertUnblocked(blocked); - scheduler.schedule(); - blocked = scheduler.isBlocked(); - assertBlocked(blocked); - assertThat(remoteTaskFactory.getTasks()).hasSize(6); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 3)).getTaskStatus().getState(), TaskState.RUNNING); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.RUNNING); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING); - - // T+6.6 task 1 failure - remoteTaskFactory.getTasks().get(getTaskId(1, 0)).fail(new RuntimeException("some other failure")); - assertUnblocked(blocked); - scheduler.schedule(); - blocked = scheduler.isBlocked(); - assertBlocked(blocked); - assertThat(remoteTaskFactory.getTasks()).hasSize(6); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 3)).getTaskStatus().getState(), TaskState.RUNNING); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING); - - // T+9.0 task1 still not retried (delay is 3s) - moveTime(2400, MILLISECONDS); - assertBlocked(blocked); - assertThat(remoteTaskFactory.getTasks()).hasSize(6); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 3)).getTaskStatus().getState(), TaskState.RUNNING); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING); - - // T+9.0 task 0.3 completes successfully - we should not reset delay; backoff still in progress - remoteTaskFactory.getTasks().get(getTaskId(0, 3)).finish(); - assertBlocked(blocked); - assertThat(remoteTaskFactory.getTasks()).hasSize(6); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 3)).getTaskStatus().getState(), TaskState.FINISHED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING); - - // T+9.7 more than 3s passed since task 1.0 was killed; should restart now - moveTime(700, MILLISECONDS); - assertUnblocked(blocked); - scheduler.schedule(); - blocked = scheduler.isBlocked(); - assertBlocked(blocked); - assertThat(remoteTaskFactory.getTasks()).hasSize(7); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 3)).getTaskStatus().getState(), TaskState.FINISHED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 1)).getTaskStatus().getState(), TaskState.RUNNING); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.RUNNING); - - // T+9.7 task 2.0 completes successfully - delay should be reset (we are not in backoff now) - remoteTaskFactory.getTasks().get(getTaskId(2, 0)).finish(); - assertUnblocked(blocked); - scheduler.schedule(); - blocked = scheduler.isBlocked(); - assertBlocked(blocked); - assertThat(remoteTaskFactory.getTasks()).hasSize(7); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 3)).getTaskStatus().getState(), TaskState.FINISHED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 1)).getTaskStatus().getState(), TaskState.RUNNING); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.FINISHED); - - // T+9.7 kill task 1.1; delay should be 1s now - remoteTaskFactory.getTasks().get(getTaskId(1, 1)).fail(new RuntimeException("some other failure")); - assertUnblocked(blocked); - scheduler.schedule(); - blocked = scheduler.isBlocked(); - assertBlocked(blocked); - assertThat(remoteTaskFactory.getTasks()).hasSize(7); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 3)).getTaskStatus().getState(), TaskState.FINISHED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 1)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.FINISHED); - - // T+10.6 task 1.2 should not start yet - moveTime(900, MILLISECONDS); - assertBlocked(blocked); - assertThat(remoteTaskFactory.getTasks()).hasSize(7); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 3)).getTaskStatus().getState(), TaskState.FINISHED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 1)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.FINISHED); - - // T+10.8 more than 1s passed; task 1.2 should start now - moveTime(200, MILLISECONDS); - assertUnblocked(blocked); - scheduler.schedule(); - blocked = scheduler.isBlocked(); - assertBlocked(blocked); - assertThat(remoteTaskFactory.getTasks()).hasSize(8); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 3)).getTaskStatus().getState(), TaskState.FINISHED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 1)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 2)).getTaskStatus().getState(), TaskState.RUNNING); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.FINISHED); - - // T+10.8 if we kill task with out of memory error next try should be started right away - remoteTaskFactory.getTasks().get(getTaskId(1, 2)).fail(new TrinoException(StandardErrorCode.CLUSTER_OUT_OF_MEMORY, "oom")); - assertUnblocked(blocked); - scheduler.schedule(); - blocked = scheduler.isBlocked(); - assertBlocked(blocked); - assertThat(remoteTaskFactory.getTasks()).hasSize(9); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 1)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 2)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(0, 3)).getTaskStatus().getState(), TaskState.FINISHED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 0)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 1)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 2)).getTaskStatus().getState(), TaskState.FAILED); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(1, 3)).getTaskStatus().getState(), TaskState.RUNNING); - assertEquals(remoteTaskFactory.getTasks().get(getTaskId(2, 0)).getTaskStatus().getState(), TaskState.FINISHED); - } - } - - @Test - public void testCancellation() - throws Exception - { - testCancellation(true); - testCancellation(false); - } - - private void testCancellation(boolean abort) - throws Exception - { - TestingRemoteTaskFactory remoteTaskFactory = new TestingRemoteTaskFactory(); - TestingTaskSourceFactory taskSourceFactory = createTaskSourceFactory(3, 1); - TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of( - NODE_1, ImmutableList.of(TEST_CATALOG_HANDLE), - NODE_2, ImmutableList.of(TEST_CATALOG_HANDLE))); - setupNodeAllocatorService(nodeSupplier); - - TestingExchange sourceExchange1 = new TestingExchange(); - TestingExchange sourceExchange2 = new TestingExchange(); - - try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( - remoteTaskFactory, - taskSourceFactory, - nodeAllocator, - new TestingExchange(), - ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), - 0, - 1); - - sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 0, 1))); - sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(1, 0, 1))); - assertUnblocked(scheduler.isBlocked()); - - scheduler.schedule(); - - ListenableFuture blocked = scheduler.isBlocked(); - // waiting on node acquisition - assertBlocked(blocked); - - NodeAllocator.NodeLease acquireNode1 = nodeAllocator.acquire(new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of()), DataSize.of(1, GIGABYTE)); - NodeAllocator.NodeLease acquireNode2 = nodeAllocator.acquire(new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of()), DataSize.of(1, GIGABYTE)); - - if (abort) { - scheduler.abort(); - } - else { - scheduler.cancel(); - } - - assertUnblocked(blocked); - assertUnblocked(acquireNode1.getNode()); - assertUnblocked(acquireNode2.getNode()); - - scheduler.schedule(); - - assertUnblocked(scheduler.isBlocked()); - assertFalse(scheduler.isFinished()); - } - } - - @Test - public void testAsyncTaskSource() - throws Exception - { - TestingRemoteTaskFactory remoteTaskFactory = new TestingRemoteTaskFactory(); - SettableFuture> splitsFuture = SettableFuture.create(); - TestingTaskSourceFactory taskSourceFactory = new TestingTaskSourceFactory(Optional.of(TEST_CATALOG_HANDLE), splitsFuture, 1); - TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of( - NODE_1, ImmutableList.of(TEST_CATALOG_HANDLE), - NODE_2, ImmutableList.of(TEST_CATALOG_HANDLE))); - setupNodeAllocatorService(nodeSupplier); - - TestingExchange sourceExchange1 = new TestingExchange(); - TestingExchange sourceExchange2 = new TestingExchange(); - - try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( - remoteTaskFactory, - taskSourceFactory, - nodeAllocator, - new TestingExchange(), - ImmutableMap.of(SOURCE_FRAGMENT_ID_1, sourceExchange1, SOURCE_FRAGMENT_ID_2, sourceExchange2), - 2, - 1); - - sourceExchange1.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(0, 0, 1))); - sourceExchange2.setSourceHandles(ImmutableList.of(new TestingExchangeSourceHandle(1, 0, 1))); - assertUnblocked(scheduler.isBlocked()); - - scheduler.schedule(); - assertBlocked(scheduler.isBlocked()); - - splitsFuture.set(createSplits(2)); - assertUnblocked(scheduler.isBlocked()); - scheduler.schedule(); - assertThat(remoteTaskFactory.getTasks()).hasSize(2); - remoteTaskFactory.getTasks().values().forEach(task -> { - Collection splits = task.getSplits().values(); - // 2 normal splits + 1 split containing an output selector - assertThat(splits).hasSize(3); - task.finish(); - }); - assertThat(scheduler.isFinished()).isTrue(); - } - } - - @Test - public void testIsFinished() - throws Exception - { - TestingNodeSupplier nodeSupplier = TestingNodeSupplier.create(ImmutableMap.of(NODE_1, ImmutableList.of(TEST_CATALOG_HANDLE))); - setupNodeAllocatorService(nodeSupplier); - - // scheduler is not finished if the task future is not finished - TestingRemoteTaskFactory remoteTaskFactory = new TestingRemoteTaskFactory(); - SettableFuture> future = SettableFuture.create(); - AtomicBoolean taskSourceCreated = new AtomicBoolean(); - AtomicBoolean taskSourceFinished = new AtomicBoolean(); - TaskSource taskSource = new TaskSource() - { - @Override - public ListenableFuture> getMoreTasks() - { - return future; - } - - @Override - public boolean isFinished() - { - return taskSourceFinished.get(); - } - - @Override - public void close() {} - }; - try (NodeAllocator nodeAllocator = nodeAllocatorService.getNodeAllocator(SESSION, 1)) { - TestingExchange sourceExchange1 = new TestingExchange(); - sourceExchange1.setSourceHandles(ImmutableList.of()); - TestingExchange sourceExchange2 = new TestingExchange(); - sourceExchange2.setSourceHandles(ImmutableList.of()); - TestingExchange sinkExchange = new TestingExchange(); - FaultTolerantStageScheduler scheduler = createFaultTolerantTaskScheduler( - remoteTaskFactory, - (session, fragment, exchangeSourceHandles, getSplitTimeRecorder, bucketToPartition) -> { - taskSourceCreated.set(true); - return taskSource; - }, - nodeAllocator, - sinkExchange, - ImmutableMap.of( - SOURCE_FRAGMENT_ID_1, sourceExchange1, - SOURCE_FRAGMENT_ID_2, sourceExchange2), - 1, - 1); - - // ensure task source is created - assertFalse(taskSourceCreated.get()); - scheduler.schedule(); - assertTrue(taskSourceCreated.get()); - - // ensure scheduler is initially not finished - scheduler.schedule(); - assertFalse(scheduler.isFinished()); - - // transition task source to finished - taskSourceFinished.set(true); - scheduler.schedule(); - // task source is not finished as the future returned from task source is not yet competed - assertFalse(scheduler.isFinished()); - - future.set(ImmutableList.of()); - assertTrue(scheduler.isFinished()); - assertTrue(sinkExchange.isNoMoreSinks()); - assertTrue(sinkExchange.isAllRequiredSinksFinished()); - } - } - - private FaultTolerantStageScheduler createFaultTolerantTaskScheduler( - RemoteTaskFactory remoteTaskFactory, - TaskSourceFactory taskSourceFactory, - NodeAllocator nodeAllocator, - Exchange sinkExchange, - Map sourceExchanges, - int retryAttempts, - int maxTasksWaitingForNodePerStage) - { - return createFaultTolerantTaskScheduler( - SESSION, - remoteTaskFactory, - taskSourceFactory, - nodeAllocator, - sinkExchange, - sourceExchanges, - retryAttempts, - maxTasksWaitingForNodePerStage); - } - - private FaultTolerantStageScheduler createFaultTolerantTaskScheduler( - Session session, - RemoteTaskFactory remoteTaskFactory, - TaskSourceFactory taskSourceFactory, - NodeAllocator nodeAllocator, - Exchange sinkExchange, - Map sourceExchanges, - int retryAttempts, - int maxTasksWaitingForNodePerStage) - { - TaskDescriptorStorage taskDescriptorStorage = new TaskDescriptorStorage(DataSize.of(10, MEGABYTE)); - taskDescriptorStorage.initialize(SESSION.getQueryId()); - DynamicFilterService dynamicFilterService = new DynamicFilterService(PLANNER_CONTEXT.getMetadata(), PLANNER_CONTEXT.getFunctionManager(), PLANNER_CONTEXT.getTypeOperators(), new DynamicFilterConfig()); - return createStageScheduler( - session, - createSqlStage(createIntermediatePlanFragment(), remoteTaskFactory), - nodeAllocator, - retryAttempts, - maxTasksWaitingForNodePerStage, - taskDescriptorStorage, - taskSourceFactory, - dynamicFilterService, - sinkExchange, - sourceExchanges.entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, entry -> { - FaultTolerantStageScheduler sourceScheduler = createStageScheduler( - session, - createSqlStage(createLeafPlanFragment(entry.getKey()), remoteTaskFactory), - nodeAllocator, - retryAttempts, - maxTasksWaitingForNodePerStage, - taskDescriptorStorage, - new TestingTaskSourceFactory(Optional.empty(), ImmutableList.of(), 1), - dynamicFilterService, - entry.getValue(), - ImmutableMap.of(), - ImmutableMap.of()); - while (!sourceScheduler.isFinished()) { - try { - sourceScheduler.schedule(); - } - catch (Exception e) { - throw new RuntimeException(e); - } - } - return sourceScheduler; - })), - sourceExchanges); - } - - private FaultTolerantStageScheduler createStageScheduler( - Session session, - SqlStage stage, - NodeAllocator nodeAllocator, - int retryAttempts, - int maxTasksWaitingForNodePerStage, - TaskDescriptorStorage taskDescriptorStorage, - TaskSourceFactory taskSourceFactory, - DynamicFilterService dynamicFilterService, - Exchange sinkExchange, - Map sourceSchedulers, - Map sourceExchanges) - { - FaultTolerantPartitioningScheme partitioningScheme = new FaultTolerantPartitioningScheme(3, Optional.empty(), Optional.empty(), Optional.empty()); - return new FaultTolerantStageScheduler( - session, - stage, - new NoOpFailureDetector(), - taskSourceFactory, - nodeAllocator, - taskDescriptorStorage, - new ConstantPartitionMemoryEstimator(), - new TaskExecutionStats(), - futureCompletor, - ticker, - sinkExchange, - partitioningScheme, - sourceSchedulers, - sourceExchanges, - partitioningScheme, - new AtomicInteger(retryAttempts), - retryAttempts, - maxTasksWaitingForNodePerStage, - dynamicFilterService); - } - - private SqlStage createSqlStage(PlanFragment fragment, RemoteTaskFactory remoteTaskFactory) - { - return SqlStage.createSqlStage( - STAGE_ID, - fragment, - ImmutableMap.of(), - remoteTaskFactory, - SESSION, - false, - nodeTaskMap, - directExecutor(), - new SplitSchedulerStats()); - } - - private PlanFragment createIntermediatePlanFragment() - { - Symbol probeColumnSymbol = new Symbol("probe_column"); - Symbol buildColumnSymbol = new Symbol("build_column"); - TableScanNode tableScan = new TableScanNode( - TABLE_SCAN_NODE_ID, - TEST_TABLE_HANDLE, - ImmutableList.of(probeColumnSymbol), - ImmutableMap.of(probeColumnSymbol, new TestingColumnHandle("column")), - TupleDomain.none(), - Optional.empty(), - false, - Optional.empty()); - RemoteSourceNode remoteSource = new RemoteSourceNode( - new PlanNodeId("remote_source_id"), - ImmutableList.of(SOURCE_FRAGMENT_ID_1, SOURCE_FRAGMENT_ID_2), - ImmutableList.of(buildColumnSymbol), - Optional.empty(), - REPLICATE, - TASK); - return new PlanFragment( - FRAGMENT_ID, - new JoinNode( - new PlanNodeId("join_id"), - INNER, - tableScan, - remoteSource, - ImmutableList.of(), - tableScan.getOutputSymbols(), - remoteSource.getOutputSymbols(), - false, - Optional.empty(), - Optional.empty(), - Optional.empty(), - Optional.of(REPLICATED), - Optional.empty(), - ImmutableMap.of(), - Optional.empty()), - ImmutableMap.of(probeColumnSymbol, VARCHAR, buildColumnSymbol, VARCHAR), - SOURCE_DISTRIBUTION, - ImmutableList.of(TABLE_SCAN_NODE_ID), - new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(probeColumnSymbol, buildColumnSymbol)), - StatsAndCosts.empty(), - ImmutableList.of(), - Optional.empty()); - } - - private PlanFragment createLeafPlanFragment(PlanFragmentId fragmentId) - { - Symbol outputColumn = new Symbol("output_column"); - return new PlanFragment( - fragmentId, - new TableScanNode( - TABLE_SCAN_NODE_ID, - TEST_TABLE_HANDLE, - ImmutableList.of(outputColumn), - ImmutableMap.of(outputColumn, new TestingColumnHandle("column")), - TupleDomain.none(), - Optional.empty(), - false, - Optional.empty()), - ImmutableMap.of(outputColumn, VARCHAR), - SOURCE_DISTRIBUTION, - ImmutableList.of(TABLE_SCAN_NODE_ID), - new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(outputColumn)), - StatsAndCosts.empty(), - ImmutableList.of(), - Optional.empty()); - } - - private static TestingTaskSourceFactory createTaskSourceFactory(int splitCount, int taskPerBatch) - { - return new TestingTaskSourceFactory(Optional.of(TEST_CATALOG_HANDLE), createSplits(splitCount), taskPerBatch); - } - - private static List createSplits(int count) - { - return ImmutableList.copyOf(limit(cycle(new Split(TEST_CATALOG_HANDLE, createRemoteSplit())), count)); - } - - private static TaskId getTaskId(int partitionId, int attemptId) - { - return new TaskId(STAGE_ID, partitionId, attemptId); - } - - private static void assertBlocked(ListenableFuture blocked) - { - assertFalse(blocked.isDone()); - } - - private static void assertUnblocked(ListenableFuture blocked) - { - assertTrue(blocked.isDone()); - } - - private void moveTime(int delta, TimeUnit unit) - { - ticker.increment(delta, unit); - futureCompletor.trigger(); - } - - private static class TestFutureCompletor - implements FaultTolerantStageScheduler.DelayedFutureCompletor - { - private final Stopwatch stopwatch; - private final Set entries = Sets.newConcurrentHashSet(); - - private TestFutureCompletor(Ticker ticker) - { - this.stopwatch = Stopwatch.createStarted(ticker); - } - - @Override - public void completeFuture(SettableFuture future, Duration delay) - { - entries.add(new Entry(future, stopwatch.elapsed().plus(delay))); - } - - public void trigger() - { - Duration now = stopwatch.elapsed(); - Iterator iterator = entries.iterator(); - while (iterator.hasNext()) { - Entry entry = iterator.next(); - if (entry.completionTime.compareTo(now) <= 0) { - entry.future.set(null); - iterator.remove(); - } - } - } - - private static class Entry - { - private final SettableFuture future; - private final Duration completionTime; - - public Entry(SettableFuture future, Duration completionTime) - { - this.future = future; - this.completionTime = completionTime; - } - } - } -} diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java deleted file mode 100644 index 413922938be9..000000000000 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestStageTaskSourceFactory.java +++ /dev/null @@ -1,949 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.execution.scheduler; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableListMultimap; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableMultimap; -import com.google.common.collect.ImmutableSet; -import com.google.common.collect.ListMultimap; -import com.google.common.collect.Multimap; -import com.google.common.collect.Multimaps; -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.SettableFuture; -import io.airlift.units.DataSize; -import io.trino.client.NodeVersion; -import io.trino.exchange.SpoolingExchangeInput; -import io.trino.execution.TableExecuteContextManager; -import io.trino.execution.scheduler.StageTaskSourceFactory.ArbitraryDistributionTaskSource; -import io.trino.execution.scheduler.StageTaskSourceFactory.HashDistributionTaskSource; -import io.trino.execution.scheduler.StageTaskSourceFactory.SingleDistributionTaskSource; -import io.trino.execution.scheduler.StageTaskSourceFactory.SourceDistributionTaskSource; -import io.trino.metadata.InMemoryNodeManager; -import io.trino.metadata.InternalNode; -import io.trino.metadata.InternalNodeManager; -import io.trino.metadata.Split; -import io.trino.spi.HostAddress; -import io.trino.spi.QueryId; -import io.trino.spi.SplitWeight; -import io.trino.spi.exchange.ExchangeSourceHandle; -import io.trino.split.RemoteSplit; -import io.trino.split.SplitSource; -import io.trino.sql.planner.plan.PlanNodeId; -import org.testng.annotations.Test; - -import java.net.URI; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.OptionalInt; -import java.util.stream.IntStream; - -import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.collect.ImmutableListMultimap.toImmutableListMultimap; -import static com.google.common.collect.Iterables.getOnlyElement; -import static com.google.common.collect.Streams.findLast; -import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static io.airlift.concurrent.MoreFutures.getDone; -import static io.airlift.concurrent.MoreFutures.getFutureValue; -import static io.airlift.units.DataSize.Unit.BYTE; -import static io.airlift.units.DataSize.Unit.GIGABYTE; -import static io.trino.execution.scheduler.StageTaskSourceFactory.createRemoteSplits; -import static io.trino.operator.ExchangeOperator.REMOTE_CATALOG_HANDLE; -import static io.trino.testing.TestingHandles.TEST_CATALOG_HANDLE; -import static java.util.Collections.nCopies; -import static java.util.Objects.requireNonNull; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.guava.api.Assertions.assertThat; -import static org.testng.Assert.assertEquals; -import static org.testng.Assert.assertFalse; -import static org.testng.Assert.assertTrue; - -public class TestStageTaskSourceFactory -{ - private static final HostAddress NODE_ADDRESS = HostAddress.fromString("testaddress"); - private static final PlanNodeId PLAN_NODE_1 = new PlanNodeId("planNode1"); - private static final PlanNodeId PLAN_NODE_2 = new PlanNodeId("planNode2"); - private static final PlanNodeId PLAN_NODE_3 = new PlanNodeId("planNode3"); - private static final PlanNodeId PLAN_NODE_4 = new PlanNodeId("planNode4"); - private static final PlanNodeId PLAN_NODE_5 = new PlanNodeId("planNode5"); - public static final long STANDARD_WEIGHT = SplitWeight.standard().getRawValue(); - - @Test - public void testSingleDistributionTaskSource() - { - ListMultimap sources = ImmutableListMultimap.builder() - .put(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 123)) - .put(PLAN_NODE_2, new TestingExchangeSourceHandle(1, 0, 321)) - .put(PLAN_NODE_1, new TestingExchangeSourceHandle(2, 0, 222)) - .build(); - TaskSource taskSource = new SingleDistributionTaskSource(createRemoteSplits(sources), new InMemoryNodeManager(), false); - - assertFalse(taskSource.isFinished()); - - List tasks = getFutureValue(taskSource.getMoreTasks()); - assertThat(tasks).hasSize(1); - assertTrue(taskSource.isFinished()); - - TaskDescriptor task = tasks.get(0); - assertThat(task.getNodeRequirements().getCatalogHandle()).isEmpty(); - assertThat(task.getNodeRequirements().getAddresses()).isEmpty(); - assertEquals(task.getPartitionId(), 0); - assertEquals(extractSourceHandles(task.getSplits()), sources); - assertEquals(extractCatalogSplits(task.getSplits()), ImmutableListMultimap.of()); - } - - @Test - public void testCoordinatorDistributionTaskSource() - { - ListMultimap sources = ImmutableListMultimap.builder() - .put(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 123)) - .put(PLAN_NODE_2, new TestingExchangeSourceHandle(1, 0, 321)) - .put(PLAN_NODE_1, new TestingExchangeSourceHandle(2, 0, 222)) - .build(); - InternalNodeManager nodeManager = new InMemoryNodeManager(); - TaskSource taskSource = new SingleDistributionTaskSource(createRemoteSplits(sources), nodeManager, true); - - assertFalse(taskSource.isFinished()); - - List tasks = getFutureValue(taskSource.getMoreTasks()); - assertThat(tasks).hasSize(1); - assertTrue(taskSource.isFinished()); - - TaskDescriptor task = tasks.get(0); - assertThat(task.getNodeRequirements().getCatalogHandle()).isEmpty(); - assertThat(task.getNodeRequirements().getAddresses()).containsExactly(nodeManager.getCurrentNode().getHostAndPort()); - assertEquals(task.getPartitionId(), 0); - assertEquals(extractSourceHandles(task.getSplits()), sources); - assertEquals(extractCatalogSplits(task.getSplits()), ImmutableListMultimap.of()); - } - - @Test - public void testArbitraryDistributionTaskSource() - { - TaskSource taskSource = new ArbitraryDistributionTaskSource( - ImmutableListMultimap.of(), - ImmutableListMultimap.of(), - DataSize.of(3, BYTE)); - assertFalse(taskSource.isFinished()); - List tasks = getFutureValue(taskSource.getMoreTasks()); - assertThat(tasks).isEmpty(); - assertTrue(taskSource.isFinished()); - - TestingExchangeSourceHandle sourceHandle1 = new TestingExchangeSourceHandle(0, 0, 1); - TestingExchangeSourceHandle sourceHandle2 = new TestingExchangeSourceHandle(1, 0, 2); - TestingExchangeSourceHandle sourceHandle3 = new TestingExchangeSourceHandle(2, 0, 3); - TestingExchangeSourceHandle sourceHandle4 = new TestingExchangeSourceHandle(3, 0, 4); - TestingExchangeSourceHandle sourceHandle123 = new TestingExchangeSourceHandle(4, 0, 123); - TestingExchangeSourceHandle sourceHandle321 = new TestingExchangeSourceHandle(5, 0, 321); - Multimap nonReplicatedSources = ImmutableListMultimap.of(PLAN_NODE_1, sourceHandle3); - taskSource = new ArbitraryDistributionTaskSource( - nonReplicatedSources, - ImmutableListMultimap.of(), - DataSize.of(3, BYTE)); - tasks = getFutureValue(taskSource.getMoreTasks()); - assertTrue(taskSource.isFinished()); - assertThat(tasks).hasSize(1); - assertEquals(tasks.get(0).getPartitionId(), 0); - assertEquals(tasks.get(0).getNodeRequirements(), new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of(PLAN_NODE_1, sourceHandle3)); - - nonReplicatedSources = ImmutableListMultimap.of(PLAN_NODE_1, sourceHandle123); - taskSource = new ArbitraryDistributionTaskSource( - nonReplicatedSources, - ImmutableListMultimap.of(), - DataSize.of(3, BYTE)); - tasks = getFutureValue(taskSource.getMoreTasks()); - assertThat(tasks).hasSize(1); - assertEquals(tasks.get(0).getPartitionId(), 0); - assertEquals(tasks.get(0).getNodeRequirements(), new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of(PLAN_NODE_1, sourceHandle123)); - - nonReplicatedSources = ImmutableListMultimap.of( - PLAN_NODE_1, sourceHandle123, - PLAN_NODE_2, sourceHandle321); - taskSource = new ArbitraryDistributionTaskSource( - nonReplicatedSources, - ImmutableListMultimap.of(), - DataSize.of(3, BYTE)); - tasks = getFutureValue(taskSource.getMoreTasks()); - assertThat(tasks).hasSize(2); - assertEquals(tasks.get(0).getPartitionId(), 0); - assertEquals(tasks.get(0).getNodeRequirements(), new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of(PLAN_NODE_1, sourceHandle123)); - assertEquals(tasks.get(1).getPartitionId(), 1); - assertEquals(tasks.get(1).getNodeRequirements(), new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertEquals(extractSourceHandles(tasks.get(1).getSplits()), ImmutableListMultimap.of(PLAN_NODE_2, sourceHandle321)); - - nonReplicatedSources = ImmutableListMultimap.of( - PLAN_NODE_1, sourceHandle1, - PLAN_NODE_1, sourceHandle2, - PLAN_NODE_2, sourceHandle4); - taskSource = new ArbitraryDistributionTaskSource( - nonReplicatedSources, - ImmutableListMultimap.of(), - DataSize.of(3, BYTE)); - tasks = getFutureValue(taskSource.getMoreTasks()); - assertThat(tasks).hasSize(2); - assertEquals(tasks.get(0).getPartitionId(), 0); - assertEquals(tasks.get(0).getNodeRequirements(), new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertEquals( - extractSourceHandles(tasks.get(0).getSplits()), - ImmutableListMultimap.of( - PLAN_NODE_1, sourceHandle1, - PLAN_NODE_1, sourceHandle2)); - assertEquals(tasks.get(1).getPartitionId(), 1); - assertEquals(tasks.get(1).getNodeRequirements(), new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertEquals(extractSourceHandles(tasks.get(1).getSplits()), ImmutableListMultimap.of(PLAN_NODE_2, sourceHandle4)); - - nonReplicatedSources = ImmutableListMultimap.of( - PLAN_NODE_1, sourceHandle1, - PLAN_NODE_1, sourceHandle3, - PLAN_NODE_2, sourceHandle4); - taskSource = new ArbitraryDistributionTaskSource( - nonReplicatedSources, - ImmutableListMultimap.of(), - DataSize.of(3, BYTE)); - tasks = getFutureValue(taskSource.getMoreTasks()); - assertThat(tasks).hasSize(3); - assertEquals(tasks.get(0).getPartitionId(), 0); - assertEquals(tasks.get(0).getNodeRequirements(), new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of(PLAN_NODE_1, sourceHandle1)); - assertEquals(tasks.get(1).getPartitionId(), 1); - assertEquals(tasks.get(1).getNodeRequirements(), new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertEquals(extractSourceHandles(tasks.get(1).getSplits()), ImmutableListMultimap.of(PLAN_NODE_1, sourceHandle3)); - assertEquals(tasks.get(2).getPartitionId(), 2); - assertEquals(tasks.get(2).getNodeRequirements(), new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertEquals(extractSourceHandles(tasks.get(2).getSplits()), ImmutableListMultimap.of(PLAN_NODE_2, sourceHandle4)); - - // with replicated sources - nonReplicatedSources = ImmutableListMultimap.of( - PLAN_NODE_1, sourceHandle1, - PLAN_NODE_1, sourceHandle2, - PLAN_NODE_1, sourceHandle4); - Multimap replicatedSources = ImmutableListMultimap.of( - PLAN_NODE_2, sourceHandle321); - taskSource = new ArbitraryDistributionTaskSource( - nonReplicatedSources, - replicatedSources, - DataSize.of(3, BYTE)); - tasks = getFutureValue(taskSource.getMoreTasks()); - assertThat(tasks).hasSize(2); - assertEquals(tasks.get(0).getPartitionId(), 0); - assertEquals(tasks.get(0).getNodeRequirements(), new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertEquals( - extractSourceHandles(tasks.get(0).getSplits()), - ImmutableListMultimap.of( - PLAN_NODE_1, sourceHandle1, - PLAN_NODE_1, sourceHandle2, - PLAN_NODE_2, sourceHandle321)); - assertEquals(tasks.get(1).getPartitionId(), 1); - assertEquals(tasks.get(1).getNodeRequirements(), new NodeRequirements(Optional.empty(), ImmutableSet.of())); - assertEquals( - extractSourceHandles(tasks.get(1).getSplits()), - ImmutableListMultimap.of( - PLAN_NODE_1, sourceHandle4, - PLAN_NODE_2, sourceHandle321)); - } - - @Test - public void testHashDistributionTaskSource() - { - TaskSource taskSource = createHashDistributionTaskSource( - ImmutableMap.of(), - ImmutableListMultimap.of(), - ImmutableListMultimap.of(), - 1, - createPartitioningScheme(4), - 0, - DataSize.of(3, BYTE)); - assertFalse(taskSource.isFinished()); - assertEquals(getFutureValue(taskSource.getMoreTasks()), ImmutableList.of()); - assertTrue(taskSource.isFinished()); - - taskSource = createHashDistributionTaskSource( - ImmutableMap.of(), - ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 1), - PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(2, 0, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(3, 3, 1)), - ImmutableListMultimap.of( - PLAN_NODE_3, new TestingExchangeSourceHandle(4, 0, 1)), - 1, - createPartitioningScheme(4), - 0, - DataSize.of(0, BYTE)); - assertFalse(taskSource.isFinished()); - List tasks = getFutureValue(taskSource.getMoreTasks()); - assertTrue(taskSource.isFinished()); - assertThat(tasks).hasSize(3); - assertEquals(tasks.get(0).getPartitionId(), 0); - assertEquals(tasks.get(0).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of())); - assertEquals(extractSourceHandles( - tasks.get(0).getSplits()), - ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(2, 0, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(4, 0, 1))); - assertEquals(tasks.get(1).getPartitionId(), 1); - assertEquals(tasks.get(1).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of())); - assertEquals( - extractSourceHandles(tasks.get(1).getSplits()), - ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(4, 0, 1))); - assertEquals(tasks.get(2).getPartitionId(), 2); - assertEquals(tasks.get(2).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of())); - assertEquals( - extractSourceHandles(tasks.get(2).getSplits()), - ImmutableListMultimap.of( - PLAN_NODE_2, new TestingExchangeSourceHandle(3, 3, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(4, 0, 1))); - - Split bucketedSplit1 = createBucketedSplit(0, 0); - Split bucketedSplit2 = createBucketedSplit(0, 2); - Split bucketedSplit3 = createBucketedSplit(0, 3); - Split bucketedSplit4 = createBucketedSplit(0, 1); - - taskSource = createHashDistributionTaskSource( - ImmutableMap.of( - PLAN_NODE_4, new TestingSplitSource(TEST_CATALOG_HANDLE, ImmutableList.of(bucketedSplit1, bucketedSplit2, bucketedSplit3)), - PLAN_NODE_5, new TestingSplitSource(TEST_CATALOG_HANDLE, ImmutableList.of(bucketedSplit4))), - ImmutableListMultimap.of(), - ImmutableListMultimap.of( - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 0, 1)), - 1, - createPartitioningScheme(4, 4), - 0, - DataSize.of(0, BYTE)); - assertFalse(taskSource.isFinished()); - tasks = getFutureValue(taskSource.getMoreTasks()); - assertTrue(taskSource.isFinished()); - assertThat(tasks).hasSize(4); - assertEquals(tasks.get(0).getPartitionId(), 0); - assertEquals(tasks.get(0).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); - assertEquals(extractCatalogSplits(tasks.get(0).getSplits()), ImmutableListMultimap.of(PLAN_NODE_4, bucketedSplit1)); - assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of(PLAN_NODE_3, new TestingExchangeSourceHandle(0, 0, 1))); - assertEquals(tasks.get(1).getPartitionId(), 1); - assertEquals(tasks.get(1).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); - assertEquals(extractCatalogSplits(tasks.get(1).getSplits()), ImmutableListMultimap.of(PLAN_NODE_5, bucketedSplit4)); - assertEquals(extractSourceHandles(tasks.get(1).getSplits()), ImmutableListMultimap.of(PLAN_NODE_3, new TestingExchangeSourceHandle(0, 0, 1))); - assertEquals(tasks.get(2).getPartitionId(), 2); - assertEquals(tasks.get(2).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); - assertEquals(extractCatalogSplits(tasks.get(2).getSplits()), ImmutableListMultimap.of(PLAN_NODE_4, bucketedSplit2)); - assertEquals(extractSourceHandles(tasks.get(2).getSplits()), ImmutableListMultimap.of(PLAN_NODE_3, new TestingExchangeSourceHandle(0, 0, 1))); - assertEquals(tasks.get(3).getPartitionId(), 3); - assertEquals(tasks.get(3).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); - assertEquals(extractCatalogSplits(tasks.get(3).getSplits()), ImmutableListMultimap.of(PLAN_NODE_4, bucketedSplit3)); - assertEquals(extractSourceHandles(tasks.get(3).getSplits()), ImmutableListMultimap.of(PLAN_NODE_3, new TestingExchangeSourceHandle(0, 0, 1))); - - taskSource = createHashDistributionTaskSource( - ImmutableMap.of( - PLAN_NODE_4, new TestingSplitSource(TEST_CATALOG_HANDLE, ImmutableList.of(bucketedSplit1, bucketedSplit2, bucketedSplit3)), - PLAN_NODE_5, new TestingSplitSource(TEST_CATALOG_HANDLE, ImmutableList.of(bucketedSplit4))), - ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 1), - PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(2, 0, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(3, 3, 1)), - ImmutableListMultimap.of( - PLAN_NODE_3, new TestingExchangeSourceHandle(4, 0, 1)), - 1, - createPartitioningScheme(4, 4), - 0, - DataSize.of(0, BYTE)); - assertFalse(taskSource.isFinished()); - tasks = getFutureValue(taskSource.getMoreTasks()); - assertTrue(taskSource.isFinished()); - assertThat(tasks).hasSize(4); - assertEquals(tasks.get(0).getPartitionId(), 0); - assertEquals(tasks.get(0).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); - assertEquals(extractCatalogSplits(tasks.get(0).getSplits()), ImmutableListMultimap.of(PLAN_NODE_4, bucketedSplit1)); - assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(2, 0, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(4, 0, 1))); - assertEquals(tasks.get(1).getPartitionId(), 1); - assertEquals(tasks.get(1).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); - assertEquals(extractCatalogSplits(tasks.get(1).getSplits()), ImmutableListMultimap.of(PLAN_NODE_5, bucketedSplit4)); - assertEquals(extractSourceHandles(tasks.get(1).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(4, 0, 1))); - assertEquals(tasks.get(2).getPartitionId(), 2); - assertEquals(tasks.get(2).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); - assertEquals(extractCatalogSplits(tasks.get(2).getSplits()), ImmutableListMultimap.of(PLAN_NODE_4, bucketedSplit2)); - assertEquals(extractSourceHandles(tasks.get(2).getSplits()), ImmutableListMultimap.of(PLAN_NODE_3, new TestingExchangeSourceHandle(4, 0, 1))); - assertEquals(tasks.get(3).getPartitionId(), 3); - assertEquals(tasks.get(3).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); - assertEquals(extractCatalogSplits(tasks.get(3).getSplits()), ImmutableListMultimap.of(PLAN_NODE_4, bucketedSplit3)); - assertEquals(extractSourceHandles(tasks.get(3).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_2, new TestingExchangeSourceHandle(3, 3, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(4, 0, 1))); - - taskSource = createHashDistributionTaskSource( - ImmutableMap.of( - PLAN_NODE_4, new TestingSplitSource(TEST_CATALOG_HANDLE, ImmutableList.of(bucketedSplit1, bucketedSplit2, bucketedSplit3)), - PLAN_NODE_5, new TestingSplitSource(TEST_CATALOG_HANDLE, ImmutableList.of(bucketedSplit4))), - ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 1), - PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(2, 0, 1)), - ImmutableListMultimap.of( - PLAN_NODE_3, new TestingExchangeSourceHandle(3, 0, 1)), - 2, - createPartitioningScheme(2, 4), - 0, DataSize.of(0, BYTE)); - assertFalse(taskSource.isFinished()); - tasks = getFutureValue(taskSource.getMoreTasks()); - assertTrue(taskSource.isFinished()); - assertThat(tasks).hasSize(2); - assertEquals(tasks.get(0).getPartitionId(), 0); - assertEquals(tasks.get(0).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); - assertEquals(extractCatalogSplits(tasks.get(0).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_4, bucketedSplit1, - PLAN_NODE_4, bucketedSplit2)); - assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(2, 0, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(3, 0, 1))); - assertEquals(tasks.get(1).getPartitionId(), 1); - assertEquals(tasks.get(1).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); - assertEquals(extractCatalogSplits(tasks.get(1).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_4, bucketedSplit3, - PLAN_NODE_5, bucketedSplit4)); - assertEquals(extractSourceHandles(tasks.get(1).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(3, 0, 1))); - - // join based on split target split weight - taskSource = createHashDistributionTaskSource( - ImmutableMap.of( - PLAN_NODE_4, new TestingSplitSource(TEST_CATALOG_HANDLE, ImmutableList.of(bucketedSplit1, bucketedSplit2, bucketedSplit3)), - PLAN_NODE_5, new TestingSplitSource(TEST_CATALOG_HANDLE, ImmutableList.of(bucketedSplit4))), - ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 1), - PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(2, 1, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(3, 2, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(4, 3, 1)), - ImmutableListMultimap.of( - PLAN_NODE_3, new TestingExchangeSourceHandle(5, 17, 1)), - 2, - createPartitioningScheme(4, 4), - 2 * STANDARD_WEIGHT, - DataSize.of(100, GIGABYTE)); - assertFalse(taskSource.isFinished()); - tasks = getFutureValue(taskSource.getMoreTasks()); - assertTrue(taskSource.isFinished()); - assertThat(tasks).hasSize(2); - assertEquals(tasks.get(0).getPartitionId(), 0); - assertEquals(tasks.get(0).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); - assertEquals(extractCatalogSplits(tasks.get(0).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_4, bucketedSplit1, - PLAN_NODE_5, bucketedSplit4)); - assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 1), - PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(2, 1, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(5, 17, 1))); - assertEquals(tasks.get(1).getPartitionId(), 1); - assertEquals(tasks.get(1).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); - assertEquals(extractCatalogSplits(tasks.get(1).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_4, bucketedSplit2, - PLAN_NODE_4, bucketedSplit3)); - assertEquals(extractSourceHandles(tasks.get(1).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_2, new TestingExchangeSourceHandle(3, 2, 1), - PLAN_NODE_2, new TestingExchangeSourceHandle(4, 3, 1), - PLAN_NODE_3, new TestingExchangeSourceHandle(5, 17, 1))); - - // join based on target exchange size - taskSource = createHashDistributionTaskSource( - ImmutableMap.of( - PLAN_NODE_4, new TestingSplitSource(TEST_CATALOG_HANDLE, ImmutableList.of(bucketedSplit1, bucketedSplit2, bucketedSplit3)), - PLAN_NODE_5, new TestingSplitSource(TEST_CATALOG_HANDLE, ImmutableList.of(bucketedSplit4))), - ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 20), - PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1, 30), - PLAN_NODE_2, new TestingExchangeSourceHandle(2, 1, 20), - PLAN_NODE_2, new TestingExchangeSourceHandle(3, 2, 99), - PLAN_NODE_2, new TestingExchangeSourceHandle(4, 3, 30)), - ImmutableListMultimap.of( - PLAN_NODE_3, new TestingExchangeSourceHandle(5, 17, 1)), - 2, - createPartitioningScheme(4, 4), - 100 * STANDARD_WEIGHT, - DataSize.of(100, BYTE)); - assertFalse(taskSource.isFinished()); - tasks = getFutureValue(taskSource.getMoreTasks()); - assertTrue(taskSource.isFinished()); - assertThat(tasks).hasSize(3); - assertEquals(tasks.get(0).getPartitionId(), 0); - assertEquals(tasks.get(0).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); - assertEquals(extractCatalogSplits(tasks.get(0).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_4, bucketedSplit1, - PLAN_NODE_5, bucketedSplit4)); - assertEquals(extractSourceHandles(tasks.get(0).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_1, new TestingExchangeSourceHandle(0, 0, 20), - PLAN_NODE_1, new TestingExchangeSourceHandle(1, 1, 30), - PLAN_NODE_2, new TestingExchangeSourceHandle(2, 1, 20), - PLAN_NODE_3, new TestingExchangeSourceHandle(5, 17, 1))); - assertEquals(tasks.get(1).getPartitionId(), 1); - assertEquals(tasks.get(1).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); - assertEquals(extractCatalogSplits(tasks.get(1).getSplits()), ImmutableListMultimap.of(PLAN_NODE_4, bucketedSplit2)); - assertEquals(extractSourceHandles(tasks.get(1).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_2, new TestingExchangeSourceHandle(3, 2, 99), - PLAN_NODE_3, new TestingExchangeSourceHandle(5, 17, 1))); - assertEquals(tasks.get(2).getPartitionId(), 2); - assertEquals(tasks.get(2).getNodeRequirements(), new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of(NODE_ADDRESS))); - assertEquals(extractCatalogSplits(tasks.get(2).getSplits()), ImmutableListMultimap.of(PLAN_NODE_4, bucketedSplit3)); - assertEquals(extractSourceHandles(tasks.get(2).getSplits()), ImmutableListMultimap.of( - PLAN_NODE_2, new TestingExchangeSourceHandle(4, 3, 30), - PLAN_NODE_3, new TestingExchangeSourceHandle(5, 17, 1))); - } - - private static HashDistributionTaskSource createHashDistributionTaskSource( - Map splitSources, - ListMultimap partitionedExchangeSources, - ListMultimap replicatedExchangeSources, - int splitBatchSize, - FaultTolerantPartitioningScheme sourcePartitioningScheme, - long targetPartitionSplitWeight, - DataSize targetPartitionSourceSize) - { - return new HashDistributionTaskSource( - splitSources, - partitionedExchangeSources, - replicatedExchangeSources, - splitBatchSize, - getSplitsTime -> {}, - sourcePartitioningScheme, - Optional.of(TEST_CATALOG_HANDLE), - targetPartitionSplitWeight, - targetPartitionSourceSize, - directExecutor()); - } - - @Test - public void testSourceDistributionTaskSource() - { - TaskSource taskSource = createSourceDistributionTaskSource(ImmutableList.of(), ImmutableListMultimap.of(), 2, 0, 3 * STANDARD_WEIGHT, 1000); - assertFalse(taskSource.isFinished()); - assertEquals(getFutureValue(taskSource.getMoreTasks()), ImmutableList.of()); - assertTrue(taskSource.isFinished()); - - Split split1 = createSplit(1); - Split split2 = createSplit(2); - Split split3 = createSplit(3); - - taskSource = createSourceDistributionTaskSource( - ImmutableList.of(split1), - ImmutableListMultimap.of(), - 2, - 0, - 2 * STANDARD_WEIGHT, - 1000); - assertEquals(getFutureValue(taskSource.getMoreTasks()), ImmutableList.of(new TaskDescriptor( - 0, - ImmutableListMultimap.of(PLAN_NODE_1, split1), - new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of())))); - assertTrue(taskSource.isFinished()); - - taskSource = createSourceDistributionTaskSource( - ImmutableList.of(split1, split2, split3), - ImmutableListMultimap.of(), - 3, - 0, - 2 * STANDARD_WEIGHT, - 1000); - - List tasks = readAllTasks(taskSource); - assertThat(tasks).hasSize(2); - assertThat(tasks.get(0).getSplits().values()).hasSize(2); - assertThat(tasks.get(1).getSplits().values()).hasSize(1); - assertThat(tasks).allMatch(taskDescriptor -> taskDescriptor.getNodeRequirements().equals(new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of()))); - assertThat(tasks).allMatch(taskDescriptor -> extractSourceHandles(taskDescriptor.getSplits()).isEmpty()); - assertThat(flattenSplits(tasks)).hasSameEntriesAs(ImmutableMultimap.of( - PLAN_NODE_1, split1, - PLAN_NODE_1, split2, - PLAN_NODE_1, split3)); - assertTrue(taskSource.isFinished()); - - ImmutableListMultimap replicatedSources = ImmutableListMultimap.of(PLAN_NODE_2, new TestingExchangeSourceHandle(0, 0, 1)); - taskSource = createSourceDistributionTaskSource( - ImmutableList.of(split1, split2, split3), - replicatedSources, - 2, - 0, - 2 * STANDARD_WEIGHT, - 1000); - - tasks = readAllTasks(taskSource); - assertThat(tasks).hasSize(2); - assertThat(tasks.get(0).getSplits().values()).hasSize(3); - assertThat(tasks.get(1).getSplits().values()).hasSize(2); - assertThat(tasks).allMatch(taskDescriptor -> taskDescriptor.getNodeRequirements().equals(new NodeRequirements(Optional.of(TEST_CATALOG_HANDLE), ImmutableSet.of()))); - assertThat(tasks).allMatch(taskDescriptor -> extractSourceHandles(taskDescriptor.getSplits()).equals(replicatedSources)); - assertThat(extractCatalogSplits(flattenSplits(tasks))).hasSameEntriesAs(ImmutableMultimap.of( - PLAN_NODE_1, split1, - PLAN_NODE_1, split2, - PLAN_NODE_1, split3)); - assertTrue(taskSource.isFinished()); - - // non remotely accessible splits - ImmutableList splits = ImmutableList.of( - createSplit(1, "host1:8080", "host2:8080"), - createSplit(2, "host2:8080"), - createSplit(3, "host1:8080", "host3:8080"), - createSplit(4, "host3:8080", "host1:8080"), - createSplit(5, "host1:8080", "host2:8080"), - createSplit(6, "host2:8080", "host3:8080"), - createSplit(7, "host3:8080", "host4:8080")); - taskSource = createSourceDistributionTaskSource(splits, ImmutableListMultimap.of(), 3, 0, 2 * STANDARD_WEIGHT, 1000); - - tasks = readAllTasks(taskSource); - - assertThat(tasks).hasSize(4); - assertThat(tasks.stream()).allMatch(taskDescriptor -> extractSourceHandles(taskDescriptor.getSplits()).isEmpty()); - assertThat(flattenSplits(tasks)).hasSameEntriesAs(Multimaps.index(splits, split -> PLAN_NODE_1)); - assertThat(tasks).allMatch(task -> task.getSplits().values().stream().allMatch(split -> { - HostAddress requiredAddress = getOnlyElement(task.getNodeRequirements().getAddresses()); - return split.getAddresses().contains(requiredAddress); - })); - assertTrue(taskSource.isFinished()); - } - - @Test - public void testSourceDistributionTaskSourceWithWeights() - { - Split split1 = createWeightedSplit(1, STANDARD_WEIGHT); - long heavyWeight = 2 * STANDARD_WEIGHT; - Split heavySplit1 = createWeightedSplit(11, heavyWeight); - Split heavySplit2 = createWeightedSplit(12, heavyWeight); - Split heavySplit3 = createWeightedSplit(13, heavyWeight); - long lightWeight = (long) (0.5 * STANDARD_WEIGHT); - Split lightSplit1 = createWeightedSplit(21, lightWeight); - Split lightSplit2 = createWeightedSplit(22, lightWeight); - Split lightSplit3 = createWeightedSplit(23, lightWeight); - Split lightSplit4 = createWeightedSplit(24, lightWeight); - - // no limits - TaskSource taskSource = createSourceDistributionTaskSource( - ImmutableList.of(lightSplit1, lightSplit2, split1, heavySplit1, heavySplit2, lightSplit4), - ImmutableListMultimap.of(), - 1, // single split per batch for predictable results - 0, - (long) (1.9 * STANDARD_WEIGHT), - 1000); - List tasks = readAllTasks(taskSource); - assertThat(tasks).hasSize(4); - assertThat(tasks).allMatch(task -> getOnlyElement(task.getSplits().keySet()).equals(PLAN_NODE_1)); - assertThat(tasks.get(0).getSplits().values()).containsExactlyInAnyOrder(lightSplit1, lightSplit2, split1); - assertThat(tasks.get(1).getSplits().values()).containsExactlyInAnyOrder(heavySplit1); - assertThat(tasks.get(2).getSplits().values()).containsExactlyInAnyOrder(heavySplit2); - assertThat(tasks.get(3).getSplits().values()).containsExactlyInAnyOrder(lightSplit4); // remainder - assertTrue(taskSource.isFinished()); - - // min splits == 2 - taskSource = createSourceDistributionTaskSource( - ImmutableList.of(heavySplit1, heavySplit2, heavySplit3, lightSplit1, lightSplit2, lightSplit3, lightSplit4), - ImmutableListMultimap.of(), - 1, // single split per batch for predictable results - 2, - 2 * STANDARD_WEIGHT, - 1000); - - tasks = readAllTasks(taskSource); - assertThat(tasks).hasSize(3); - assertThat(tasks).allMatch(task -> getOnlyElement(task.getSplits().keySet()).equals(PLAN_NODE_1)); - assertThat(tasks.get(0).getSplits().values()).containsExactlyInAnyOrder(heavySplit1, heavySplit2); - assertThat(tasks.get(1).getSplits().values()).containsExactlyInAnyOrder(heavySplit3, lightSplit1); - assertThat(tasks.get(2).getSplits().values()).containsExactlyInAnyOrder(lightSplit2, lightSplit3, lightSplit4); - assertTrue(taskSource.isFinished()); - - // max splits == 3 - taskSource = createSourceDistributionTaskSource( - ImmutableList.of(lightSplit1, lightSplit2, lightSplit3, heavySplit1, lightSplit4), - ImmutableListMultimap.of(), - 1, // single split per batch for predictable results - 0, - 2 * STANDARD_WEIGHT, - 3); - - tasks = readAllTasks(taskSource); - assertThat(tasks).hasSize(3); - assertThat(tasks).allMatch(task -> getOnlyElement(task.getSplits().keySet()).equals(PLAN_NODE_1)); - assertThat(tasks.get(0).getSplits().values()).containsExactlyInAnyOrder(lightSplit1, lightSplit2, lightSplit3); - assertThat(tasks.get(1).getSplits().values()).containsExactlyInAnyOrder(heavySplit1); - assertThat(tasks.get(2).getSplits().values()).containsExactlyInAnyOrder(lightSplit4); - assertTrue(taskSource.isFinished()); - - // with addresses - Split split1a1 = createWeightedSplit(1, STANDARD_WEIGHT, "host1:8080"); - Split split2a2 = createWeightedSplit(2, STANDARD_WEIGHT, "host2:8080"); - Split split3a1 = createWeightedSplit(3, STANDARD_WEIGHT, "host1:8080"); - Split split3a12 = createWeightedSplit(3, STANDARD_WEIGHT, "host1:8080", "host2:8080"); - Split heavySplit2a2 = createWeightedSplit(12, heavyWeight, "host2:8080"); - Split lightSplit1a1 = createWeightedSplit(21, lightWeight, "host1:8080"); - - taskSource = createSourceDistributionTaskSource( - ImmutableList.of(split1a1, heavySplit2a2, split3a1, lightSplit1a1), - ImmutableListMultimap.of(), - 1, // single split per batch for predictable results - 0, - 2 * STANDARD_WEIGHT, - 3); - - tasks = readAllTasks(taskSource); - assertThat(tasks).hasSize(3); - assertThat(tasks).allMatch(task -> getOnlyElement(task.getSplits().keySet()).equals(PLAN_NODE_1)); - assertThat(tasks.get(0).getSplits().values()).containsExactlyInAnyOrder(heavySplit2a2); - assertThat(tasks.get(1).getSplits().values()).containsExactlyInAnyOrder(split1a1, split3a1); - assertThat(tasks.get(2).getSplits().values()).containsExactlyInAnyOrder(lightSplit1a1); - assertTrue(taskSource.isFinished()); - - // with addresses with multiple matching - taskSource = createSourceDistributionTaskSource( - ImmutableList.of(split1a1, split3a12, split2a2), - ImmutableListMultimap.of(), - 1, // single split per batch for predictable results - 0, - 2 * STANDARD_WEIGHT, - 3); - - tasks = readAllTasks(taskSource); - assertThat(tasks).hasSize(2); - assertThat(tasks).allMatch(task -> getOnlyElement(task.getSplits().keySet()).equals(PLAN_NODE_1)); - assertThat(tasks.get(0).getSplits().values()).containsExactlyInAnyOrder(split1a1, split3a12); - assertThat(tasks.get(1).getSplits().values()).containsExactlyInAnyOrder(split2a2); - assertTrue(taskSource.isFinished()); - } - - @Test - public void testSourceDistributionTaskSourceLastIncompleteTaskAlwaysCreated() - { - for (int targetSplitsPerTask = 1; targetSplitsPerTask <= 21; targetSplitsPerTask++) { - List splits = new ArrayList<>(); - for (int i = 0; i < targetSplitsPerTask + 1 /* to make last task incomplete with only a single split */; i++) { - splits.add(createWeightedSplit(i, STANDARD_WEIGHT)); - } - for (int finishDelayIterations = 1; finishDelayIterations < 20; finishDelayIterations++) { - for (int splitBatchSize = 1; splitBatchSize <= 5; splitBatchSize++) { - TaskSource taskSource = createSourceDistributionTaskSource( - new TestingSplitSource(TEST_CATALOG_HANDLE, splits, finishDelayIterations), - ImmutableListMultimap.of(), - splitBatchSize, - targetSplitsPerTask, - STANDARD_WEIGHT * targetSplitsPerTask, - targetSplitsPerTask); - List tasks = readAllTasks(taskSource); - assertThat(tasks).hasSize(2); - TaskDescriptor lastTask = findLast(tasks.stream()).orElseThrow(); - assertThat(lastTask.getSplits()).hasSize(1); - } - } - } - } - - @Test - public void testSourceDistributionTaskSourceWithAsyncSplitSource() - { - SettableFuture> splitsFuture = SettableFuture.create(); - TaskSource taskSource = createSourceDistributionTaskSource( - new TestingSplitSource(TEST_CATALOG_HANDLE, splitsFuture, 0), - ImmutableListMultimap.of(), - 2, - 0, - 2 * STANDARD_WEIGHT, - 1000); - ListenableFuture> tasksFuture = taskSource.getMoreTasks(); - assertThat(tasksFuture).isNotDone(); - - splitsFuture.set(ImmutableList.of(createSplit(1), createSplit(2), createSplit(3))); - List tasks = getDone(tasksFuture); - assertThat(tasks).hasSize(1); - assertThat(tasks.get(0).getSplits()).hasSize(2); - - tasksFuture = taskSource.getMoreTasks(); - assertThat(tasksFuture).isDone(); - tasks = getDone(tasksFuture); - assertThat(tasks).hasSize(1); - assertThat(tasks.get(0).getSplits()).hasSize(1); - assertThat(taskSource.isFinished()).isTrue(); - } - - @Test - public void testHashDistributionTaskSourceWithAsyncSplitSource() - { - SettableFuture> splitsFuture1 = SettableFuture.create(); - SettableFuture> splitsFuture2 = SettableFuture.create(); - TaskSource taskSource = createHashDistributionTaskSource( - ImmutableMap.of( - PLAN_NODE_1, new TestingSplitSource(TEST_CATALOG_HANDLE, splitsFuture1, 0), - PLAN_NODE_2, new TestingSplitSource(TEST_CATALOG_HANDLE, splitsFuture2, 0)), - ImmutableListMultimap.of(), - ImmutableListMultimap.of( - PLAN_NODE_3, new TestingExchangeSourceHandle(0, 0, 1)), - 1, - createPartitioningScheme(4, 4), - 0, - DataSize.of(0, BYTE)); - ListenableFuture> tasksFuture = taskSource.getMoreTasks(); - assertThat(tasksFuture).isNotDone(); - - Split bucketedSplit1 = createBucketedSplit(0, 0); - Split bucketedSplit2 = createBucketedSplit(0, 2); - Split bucketedSplit3 = createBucketedSplit(0, 3); - splitsFuture1.set(ImmutableList.of(bucketedSplit1, bucketedSplit2, bucketedSplit3)); - assertThat(tasksFuture).isNotDone(); - - Split bucketedSplit4 = createBucketedSplit(0, 1); - splitsFuture2.set(ImmutableList.of(bucketedSplit4)); - List tasks = getDone(tasksFuture); - assertThat(tasks).hasSize(4); - tasks.forEach(task -> assertThat(task.getSplits()).hasSize(2)); - assertThat(taskSource.isFinished()).isTrue(); - } - - private static SourceDistributionTaskSource createSourceDistributionTaskSource( - List splits, - ListMultimap replicatedSources, - int splitBatchSize, - int minSplitsPerTask, - long splitWeightPerTask, - int maxSplitsPerTask) - { - return createSourceDistributionTaskSource( - new TestingSplitSource(TEST_CATALOG_HANDLE, splits), - replicatedSources, - splitBatchSize, - minSplitsPerTask, - splitWeightPerTask, - maxSplitsPerTask); - } - - private static SourceDistributionTaskSource createSourceDistributionTaskSource( - SplitSource splitSource, - ListMultimap replicatedSources, - int splitBatchSize, - int minSplitsPerTask, - long splitWeightPerTask, - int maxSplitsPerTask) - { - return new SourceDistributionTaskSource( - new QueryId("query"), - PLAN_NODE_1, - new TableExecuteContextManager(), - splitSource, - createRemoteSplits(replicatedSources), - splitBatchSize, - getSplitsTime -> {}, - Optional.of(TEST_CATALOG_HANDLE), - minSplitsPerTask, - splitWeightPerTask, - maxSplitsPerTask, - directExecutor()); - } - - private static Split createSplit(int id, String... addresses) - { - return new Split(TEST_CATALOG_HANDLE, new TestingConnectorSplit(id, OptionalInt.empty(), addressesList(addresses))); - } - - private static Split createWeightedSplit(int id, long weight, String... addresses) - { - return new Split(TEST_CATALOG_HANDLE, new TestingConnectorSplit(id, OptionalInt.empty(), addressesList(addresses), weight)); - } - - private static Split createBucketedSplit(int id, int bucket) - { - return new Split(TEST_CATALOG_HANDLE, new TestingConnectorSplit(id, OptionalInt.of(bucket), Optional.empty())); - } - - private List readAllTasks(TaskSource taskSource) - { - ImmutableList.Builder tasks = ImmutableList.builder(); - while (!taskSource.isFinished()) { - tasks.addAll(getFutureValue(taskSource.getMoreTasks())); - } - return tasks.build(); - } - - private ListMultimap flattenSplits(List tasks) - { - return tasks.stream() - .flatMap(taskDescriptor -> taskDescriptor.getSplits().entries().stream()) - .collect(toImmutableListMultimap(Map.Entry::getKey, Map.Entry::getValue)); - } - - private static Optional> addressesList(String... addresses) - { - requireNonNull(addresses, "addresses is null"); - if (addresses.length == 0) { - return Optional.empty(); - } - return Optional.of(Arrays.stream(addresses) - .map(HostAddress::fromString) - .collect(toImmutableList())); - } - - private static FaultTolerantPartitioningScheme createPartitioningScheme(int partitionCount) - { - return new FaultTolerantPartitioningScheme( - partitionCount, - Optional.of(IntStream.range(0, partitionCount).toArray()), - Optional.empty(), - Optional.empty()); - } - - private static FaultTolerantPartitioningScheme createPartitioningScheme(int partitionCount, int bucketCount) - { - int[] bucketToPartitionMap = new int[bucketCount]; - for (int i = 0; i < bucketCount; i++) { - bucketToPartitionMap[i] = i % partitionCount; - } - return new FaultTolerantPartitioningScheme( - partitionCount, - Optional.of(bucketToPartitionMap), - Optional.of(split -> ((TestingConnectorSplit) split.getConnectorSplit()).getBucket().orElseThrow()), - Optional.of(nCopies(partitionCount, new InternalNode("local", URI.create("local://" + NODE_ADDRESS), NodeVersion.UNKNOWN, true)))); - } - - private static ListMultimap extractSourceHandles(ListMultimap splits) - { - ImmutableListMultimap.Builder result = ImmutableListMultimap.builder(); - splits.forEach(((planNodeId, split) -> { - if (split.getCatalogHandle().equals(REMOTE_CATALOG_HANDLE)) { - RemoteSplit remoteSplit = (RemoteSplit) split.getConnectorSplit(); - SpoolingExchangeInput input = (SpoolingExchangeInput) remoteSplit.getExchangeInput(); - result.putAll(planNodeId, input.getExchangeSourceHandles()); - } - })); - return result.build(); - } - - private static ListMultimap extractCatalogSplits(ListMultimap splits) - { - ImmutableListMultimap.Builder result = ImmutableListMultimap.builder(); - splits.forEach((planNodeId, split) -> { - if (!split.getCatalogHandle().equals(REMOTE_CATALOG_HANDLE)) { - result.put(planNodeId, split); - } - }); - return result.build(); - } -} diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingExchange.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingExchange.java deleted file mode 100644 index c5bdaaf61f2a..000000000000 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingExchange.java +++ /dev/null @@ -1,196 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.execution.scheduler; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; -import io.trino.spi.exchange.Exchange; -import io.trino.spi.exchange.ExchangeId; -import io.trino.spi.exchange.ExchangeSinkHandle; -import io.trino.spi.exchange.ExchangeSinkInstanceHandle; -import io.trino.spi.exchange.ExchangeSourceHandle; -import io.trino.spi.exchange.ExchangeSourceHandleSource; - -import java.util.List; -import java.util.Objects; -import java.util.Set; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.atomic.AtomicBoolean; - -import static com.google.common.base.MoreObjects.toStringHelper; -import static com.google.common.collect.Sets.newConcurrentHashSet; -import static io.trino.spi.exchange.ExchangeId.createRandomExchangeId; -import static java.util.Objects.requireNonNull; - -public class TestingExchange - implements Exchange -{ - private final ExchangeId exchangeId = createRandomExchangeId(); - private final Set finishedSinks = newConcurrentHashSet(); - private final Set allSinks = newConcurrentHashSet(); - private final AtomicBoolean noMoreSinks = new AtomicBoolean(); - private final CompletableFuture> sourceHandles = new CompletableFuture<>(); - private final AtomicBoolean allRequiredSinksFinished = new AtomicBoolean(); - - @Override - public ExchangeId getId() - { - return exchangeId; - } - - @Override - public ExchangeSinkHandle addSink(int taskPartitionId) - { - TestingExchangeSinkHandle sinkHandle = new TestingExchangeSinkHandle(taskPartitionId); - allSinks.add(sinkHandle); - return sinkHandle; - } - - @Override - public void noMoreSinks() - { - noMoreSinks.set(true); - } - - public boolean isNoMoreSinks() - { - return noMoreSinks.get(); - } - - @Override - public ExchangeSinkInstanceHandle instantiateSink(ExchangeSinkHandle sinkHandle, int taskAttemptId) - { - return new TestingExchangeSinkInstanceHandle((TestingExchangeSinkHandle) sinkHandle, taskAttemptId); - } - - @Override - public ExchangeSinkInstanceHandle updateSinkInstanceHandle(ExchangeSinkHandle sinkHandle, int taskAttemptId) - { - throw new UnsupportedOperationException(); - } - - @Override - public void sinkFinished(ExchangeSinkHandle sinkHandle, int taskAttemptId) - { - finishedSinks.add((TestingExchangeSinkHandle) sinkHandle); - } - - @Override - public void allRequiredSinksFinished() - { - allRequiredSinksFinished.set(true); - } - - public boolean isAllRequiredSinksFinished() - { - return allRequiredSinksFinished.get(); - } - - public Set getFinishedSinkHandles() - { - return ImmutableSet.copyOf(finishedSinks); - } - - @Override - public ExchangeSourceHandleSource getSourceHandles() - { - return new ExchangeSourceHandleSource() - { - @Override - public CompletableFuture getNextBatch() - { - return sourceHandles.thenApply(handles -> new ExchangeSourceHandleBatch(handles, true)); - } - - @Override - public void close() {} - }; - } - - public void setSourceHandles(List handles) - { - sourceHandles.complete(ImmutableList.copyOf(handles)); - } - - @Override - public void close() - { - } - - public static class TestingExchangeSinkInstanceHandle - implements ExchangeSinkInstanceHandle - { - private final TestingExchangeSinkHandle sinkHandle; - private final int attemptId; - - public TestingExchangeSinkInstanceHandle(TestingExchangeSinkHandle sinkHandle, int attemptId) - { - this.sinkHandle = requireNonNull(sinkHandle, "sinkHandle is null"); - this.attemptId = attemptId; - } - - public TestingExchangeSinkHandle getSinkHandle() - { - return sinkHandle; - } - - public int getAttemptId() - { - return attemptId; - } - } - - public static class TestingExchangeSinkHandle - implements ExchangeSinkHandle - { - private final int taskPartitionId; - - public TestingExchangeSinkHandle(int taskPartitionId) - { - this.taskPartitionId = taskPartitionId; - } - - public int getTaskPartitionId() - { - return taskPartitionId; - } - - @Override - public boolean equals(Object o) - { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - TestingExchangeSinkHandle sinkHandle = (TestingExchangeSinkHandle) o; - return taskPartitionId == sinkHandle.taskPartitionId; - } - - @Override - public int hashCode() - { - return Objects.hash(taskPartitionId); - } - - @Override - public String toString() - { - return toStringHelper(this) - .add("taskPartitionId", taskPartitionId) - .toString(); - } - } -} diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingSplitSource.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingSplitSource.java deleted file mode 100644 index 17958293e2be..000000000000 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingSplitSource.java +++ /dev/null @@ -1,116 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.execution.scheduler; - -import com.google.common.collect.ImmutableList; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import io.trino.metadata.Split; -import io.trino.spi.connector.CatalogHandle; -import io.trino.split.SplitSource; - -import java.util.Iterator; -import java.util.List; -import java.util.Optional; - -import static com.google.common.base.Preconditions.checkState; -import static com.google.common.util.concurrent.Futures.immediateFuture; -import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static java.util.Objects.requireNonNull; - -public class TestingSplitSource - implements SplitSource -{ - private final CatalogHandle catalogHandle; - private final ListenableFuture> splitsFuture; - private int finishDelayRemainingIterations; - private Iterator splits; - - public TestingSplitSource(CatalogHandle catalogHandle, List splits) - { - this(catalogHandle, splits, 0); - } - - public TestingSplitSource(CatalogHandle catalogHandle, List splits, int finishDelayIterations) - { - this( - catalogHandle, - immediateFuture(ImmutableList.copyOf(requireNonNull(splits, "splits is null"))), - finishDelayIterations); - } - - public TestingSplitSource(CatalogHandle catalogHandle, ListenableFuture> splitsFuture, int finishDelayIterations) - { - this.catalogHandle = requireNonNull(catalogHandle, "catalogHandle is null"); - this.splitsFuture = requireNonNull(splitsFuture, "splitsFuture is null"); - this.finishDelayRemainingIterations = finishDelayIterations; - } - - @Override - public CatalogHandle getCatalogHandle() - { - return catalogHandle; - } - - @Override - public ListenableFuture getNextBatch(int maxSize) - { - if (isFinished()) { - return immediateFuture(new SplitBatch(ImmutableList.of(), true)); - } - - if (splits == null) { - return Futures.transform( - splitsFuture, - splits -> { - checkState(this.splits == null, "splits should be null"); - this.splits = splits.iterator(); - return populateSplitBatch(maxSize); - }, - directExecutor()); - } - checkState(splitsFuture.isDone(), "splitsFuture should be completed"); - return immediateFuture(populateSplitBatch(maxSize)); - } - - @Override - public void close() - { - } - - @Override - public boolean isFinished() - { - return (splits != null && !splits.hasNext()) - && finishDelayRemainingIterations-- <= 0; - } - - @Override - public Optional> getTableExecuteSplitsInfo() - { - return Optional.empty(); - } - - private SplitBatch populateSplitBatch(int maxSize) - { - ImmutableList.Builder result = ImmutableList.builder(); - for (int i = 0; i < maxSize; i++) { - if (!splits.hasNext()) { - break; - } - result.add(splits.next()); - } - return new SplitBatch(result.build(), isFinished()); - } -} diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingTaskLifecycleListener.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingTaskLifecycleListener.java deleted file mode 100644 index 6a5bd6e381bd..000000000000 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingTaskLifecycleListener.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.execution.scheduler; - -import com.google.common.collect.ArrayListMultimap; -import com.google.common.collect.ImmutableListMultimap; -import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Multimap; -import io.trino.execution.RemoteTask; -import io.trino.execution.TaskId; -import io.trino.sql.planner.plan.PlanFragmentId; - -import javax.annotation.concurrent.GuardedBy; - -import java.util.Set; - -import static com.google.common.collect.Sets.newConcurrentHashSet; - -public class TestingTaskLifecycleListener - implements TaskLifecycleListener -{ - @GuardedBy("this") - private final Multimap tasks = ArrayListMultimap.create(); - private final Set noMoreTasks = newConcurrentHashSet(); - - @Override - public synchronized void taskCreated(PlanFragmentId fragmentId, RemoteTask task) - { - tasks.put(fragmentId, task.getTaskId()); - } - - public synchronized Multimap getTasks() - { - return ImmutableListMultimap.copyOf(tasks); - } - - @Override - public void noMoreTasks(PlanFragmentId fragmentId) - { - noMoreTasks.add(fragmentId); - } - - public Set getNoMoreTasks() - { - return ImmutableSet.copyOf(noMoreTasks); - } -} diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingTaskSourceFactory.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingTaskSourceFactory.java deleted file mode 100644 index 0aef687f8ce2..000000000000 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestingTaskSourceFactory.java +++ /dev/null @@ -1,178 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.trino.execution.scheduler; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableListMultimap; -import com.google.common.collect.ImmutableSet; -import com.google.common.collect.ListMultimap; -import com.google.common.collect.Multimap; -import com.google.common.util.concurrent.Futures; -import com.google.common.util.concurrent.ListenableFuture; -import io.trino.Session; -import io.trino.metadata.Split; -import io.trino.spi.connector.CatalogHandle; -import io.trino.spi.exchange.ExchangeSourceHandle; -import io.trino.sql.planner.PlanFragment; -import io.trino.sql.planner.plan.PlanFragmentId; -import io.trino.sql.planner.plan.PlanNodeId; -import io.trino.sql.planner.plan.RemoteSourceNode; - -import java.util.Collection; -import java.util.Iterator; -import java.util.List; -import java.util.Optional; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.LongConsumer; - -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkState; -import static com.google.common.collect.Iterables.getOnlyElement; -import static com.google.common.util.concurrent.Futures.immediateFuture; -import static com.google.common.util.concurrent.MoreExecutors.directExecutor; -import static io.trino.execution.scheduler.StageTaskSourceFactory.createRemoteSplits; -import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE; -import static java.util.Objects.requireNonNull; - -public class TestingTaskSourceFactory - implements TaskSourceFactory -{ - private final Optional catalog; - private final ListenableFuture> splitsFuture; - private final int tasksPerBatch; - - public TestingTaskSourceFactory(Optional catalog, List splits, int tasksPerBatch) - { - this(catalog, immediateFuture(ImmutableList.copyOf(requireNonNull(splits, "splits is null"))), tasksPerBatch); - } - - public TestingTaskSourceFactory(Optional catalog, ListenableFuture> splitsFuture, int tasksPerBatch) - { - this.catalog = requireNonNull(catalog, "catalog is null"); - this.splitsFuture = requireNonNull(splitsFuture, "splitsFuture is null"); - this.tasksPerBatch = tasksPerBatch; - } - - @Override - public TaskSource create( - Session session, - PlanFragment fragment, - Multimap exchangeSourceHandles, - LongConsumer getSplitTimeRecorder, - FaultTolerantPartitioningScheme sourcePartitioningScheme) - { - List partitionedSources = fragment.getPartitionedSources(); - checkArgument(partitionedSources.size() == 1, "single partitioned source is expected"); - - return new TestingTaskSource( - catalog, - splitsFuture, - tasksPerBatch, - getOnlyElement(partitionedSources), - getHandlesForRemoteSources(fragment.getRemoteSourceNodes(), exchangeSourceHandles)); - } - - private static ListMultimap getHandlesForRemoteSources( - List remoteSources, - Multimap exchangeSourceHandles) - { - ImmutableListMultimap.Builder result = ImmutableListMultimap.builder(); - for (RemoteSourceNode remoteSource : remoteSources) { - checkArgument(remoteSource.getExchangeType() == REPLICATE, "expected exchange type to be REPLICATE, got: %s", remoteSource.getExchangeType()); - for (PlanFragmentId fragmentId : remoteSource.getSourceFragmentIds()) { - Collection handles = exchangeSourceHandles.get(fragmentId); - checkArgument(handles.size() == 1, "single exchange source handle is expected, got: %s", handles); - result.putAll(remoteSource.getId(), handles); - } - } - return result.build(); - } - - public static class TestingTaskSource - implements TaskSource - { - private final Optional catalogRequirement; - private final ListenableFuture> splitsFuture; - private final int tasksPerBatch; - private final PlanNodeId tableScanPlanNodeId; - private final ListMultimap exchangeSourceHandles; - - private final AtomicInteger nextPartitionId = new AtomicInteger(); - private Iterator splits; - - public TestingTaskSource( - Optional catalogRequirement, - ListenableFuture> splitsFuture, - int tasksPerBatch, - PlanNodeId tableScanPlanNodeId, - ListMultimap exchangeSourceHandles) - { - this.catalogRequirement = requireNonNull(catalogRequirement, "catalogRequirement is null"); - this.splitsFuture = requireNonNull(splitsFuture, "splitsFuture is null"); - this.tasksPerBatch = tasksPerBatch; - this.tableScanPlanNodeId = requireNonNull(tableScanPlanNodeId, "tableScanPlanNodeId is null"); - this.exchangeSourceHandles = ImmutableListMultimap.copyOf(requireNonNull(exchangeSourceHandles, "exchangeSourceHandles is null")); - } - - @Override - public ListenableFuture> getMoreTasks() - { - checkState(!isFinished(), "already finished"); - - if (splits == null) { - return Futures.transform( - splitsFuture, - loadedSplits -> { - checkState(this.splits == null, "splits should be null"); - splits = loadedSplits.iterator(); - return getTasksBatch(); - }, - directExecutor()); - } - checkState(splitsFuture.isDone(), "splitsFuture should be completed"); - return immediateFuture(getTasksBatch()); - } - - @Override - public boolean isFinished() - { - return splits != null && !splits.hasNext(); - } - - @Override - public void close() - { - } - - private List getTasksBatch() - { - ImmutableList.Builder result = ImmutableList.builder(); - for (int i = 0; i < tasksPerBatch; i++) { - if (isFinished()) { - break; - } - Split split = splits.next(); - ImmutableListMultimap.Builder splits = ImmutableListMultimap.builder(); - splits.put(tableScanPlanNodeId, split); - splits.putAll(createRemoteSplits(exchangeSourceHandles)); - TaskDescriptor task = new TaskDescriptor( - nextPartitionId.getAndIncrement(), - splits.build(), - new NodeRequirements(catalogRequirement, ImmutableSet.copyOf(split.getAddresses()))); - result.add(task); - } - return result.build(); - } - } -}