diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java b/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java index 81a4e8827858..3f1c2847e355 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryStateMachine.java @@ -1037,7 +1037,7 @@ private static boolean isScheduled(Optional rootStage) } return getAllStages(rootStage).stream() .map(StageInfo::getState) - .allMatch(state -> state == StageState.RUNNING || state == StageState.FLUSHING || state.isDone()); + .allMatch(state -> state == StageState.RUNNING || state == StageState.PENDING || state.isDone()); } public Optional getFailureInfo() diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java index 9f478b2fe4fe..9cdd301478f0 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlQueryExecution.java @@ -25,8 +25,6 @@ import io.trino.cost.StatsCalculator; import io.trino.execution.QueryPreparer.PreparedQuery; import io.trino.execution.StateMachine.StateChangeListener; -import io.trino.execution.buffer.OutputBuffers; -import io.trino.execution.buffer.OutputBuffers.OutputBufferId; import io.trino.execution.scheduler.ExecutionPolicy; import io.trino.execution.scheduler.NodeScheduler; import io.trino.execution.scheduler.SplitSchedulerStats; @@ -55,7 +53,6 @@ import io.trino.sql.planner.InputExtractor; import io.trino.sql.planner.LogicalPlanner; import io.trino.sql.planner.NodePartitioningManager; -import io.trino.sql.planner.PartitioningHandle; import io.trino.sql.planner.Plan; import io.trino.sql.planner.PlanFragmenter; import io.trino.sql.planner.PlanNodeIdAllocator; @@ -89,8 +86,6 @@ import static io.trino.SystemSessionProperties.isEnableDynamicFiltering; import static 
io.trino.execution.QueryState.FAILED; import static io.trino.execution.QueryState.PLANNING; -import static io.trino.execution.buffer.OutputBuffers.BROADCAST_PARTITION_ID; -import static io.trino.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; import static io.trino.execution.scheduler.SqlQueryScheduler.createSqlQueryScheduler; import static io.trino.server.DynamicFilterService.DynamicFiltersStats; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; @@ -105,8 +100,6 @@ public class SqlQueryExecution { private static final Logger log = Logger.get(SqlQueryExecution.class); - private static final OutputBufferId OUTPUT_BUFFER_ID = new OutputBufferId(0); - private final QueryStateMachine stateMachine; private final Slug slug; private final Metadata metadata; @@ -522,11 +515,6 @@ private void planDistribution(PlanRoot plan) // record output field stateMachine.setColumns(outputStageExecutionPlan.getFieldNames(), outputStageExecutionPlan.getFragment().getTypes()); - PartitioningHandle partitioningHandle = plan.getRoot().getFragment().getPartitioningScheme().getPartitioning().getHandle(); - OutputBuffers rootOutputBuffers = createInitialEmptyOutputBuffers(partitioningHandle) - .withBuffer(OUTPUT_BUFFER_ID, BROADCAST_PARTITION_ID) - .withNoMoreBufferIds(); - // build the stage execution objects (this doesn't schedule execution) SqlQueryScheduler scheduler = createSqlQueryScheduler( stateMachine, @@ -534,13 +522,11 @@ private void planDistribution(PlanRoot plan) nodePartitioningManager, nodeScheduler, remoteTaskFactory, - stateMachine.getSession(), plan.isSummarizeTaskInfos(), scheduleSplitBatchSize, queryExecutor, schedulerExecutor, failureDetector, - rootOutputBuffers, nodeTaskMap, executionPolicy, schedulerStats, diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlStageExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlStageExecution.java index 4184e2ad17c9..a05bed0f02e6 100644 --- 
a/core/trino-main/src/main/java/io/trino/execution/SqlStageExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlStageExecution.java @@ -13,103 +13,55 @@ */ package io.trino.execution; -import com.google.common.collect.HashMultimap; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Multimap; -import com.google.common.collect.Sets; import io.airlift.units.Duration; import io.trino.Session; import io.trino.execution.StateMachine.StateChangeListener; import io.trino.execution.buffer.OutputBuffers; import io.trino.execution.scheduler.SplitSchedulerStats; -import io.trino.failuredetector.FailureDetector; import io.trino.metadata.InternalNode; import io.trino.metadata.Split; -import io.trino.server.DynamicFilterService; -import io.trino.spi.TrinoException; -import io.trino.split.RemoteSplit; import io.trino.sql.planner.PlanFragment; import io.trino.sql.planner.plan.DynamicFilterId; -import io.trino.sql.planner.plan.PlanFragmentId; import io.trino.sql.planner.plan.PlanNodeId; -import io.trino.sql.planner.plan.RemoteSourceNode; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; -import java.net.URI; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; import static com.google.common.base.Preconditions.checkArgument; -import static 
com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Sets.newConcurrentHashSet; -import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; import static io.trino.SystemSessionProperties.isEnableCoordinatorDynamicFiltersDistribution; -import static io.trino.failuredetector.FailureDetector.State.GONE; -import static io.trino.operator.ExchangeOperator.REMOTE_CONNECTOR_ID; import static io.trino.server.DynamicFilterService.getOutboundDynamicFilters; -import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; -import static io.trino.spi.StandardErrorCode.REMOTE_HOST_GONE; import static java.util.Objects.requireNonNull; @ThreadSafe public final class SqlStageExecution { + private final Session session; private final StageStateMachine stateMachine; private final RemoteTaskFactory remoteTaskFactory; private final NodeTaskMap nodeTaskMap; private final boolean summarizeTaskInfo; - private final Executor executor; - private final FailureDetector failureDetector; - private final DynamicFilterService dynamicFilterService; - private final Map exchangeSources; - - private final Map> tasks = new ConcurrentHashMap<>(); + private final Set outboundDynamicFilterIds; - @GuardedBy("this") - private final AtomicInteger nextTaskId = new AtomicInteger(); + private final Map tasks = new ConcurrentHashMap<>(); @GuardedBy("this") private final Set allTasks = newConcurrentHashSet(); @GuardedBy("this") private final Set finishedTasks = newConcurrentHashSet(); @GuardedBy("this") - private final Set flushingTasks = newConcurrentHashSet(); - @GuardedBy("this") private final Set tasksWithFinalInfo = newConcurrentHashSet(); - @GuardedBy("this") - private final AtomicBoolean splitsScheduled = new AtomicBoolean(); - - @GuardedBy("this") - private final Multimap sourceTasks = HashMultimap.create(); - @GuardedBy("this") - private final Set completeSources = 
newConcurrentHashSet(); - @GuardedBy("this") - private final Set completeSourceFragments = newConcurrentHashSet(); - - private final AtomicReference outputBuffers = new AtomicReference<>(); - - private final ListenerManager> completedLifespansChangeListeners = new ListenerManager<>(); - - private final Set outboundDynamicFilterIds; public static SqlStageExecution createSqlStageExecution( StageId stageId, @@ -120,58 +72,42 @@ public static SqlStageExecution createSqlStageExecution( boolean summarizeTaskInfo, NodeTaskMap nodeTaskMap, ExecutorService executor, - FailureDetector failureDetector, - DynamicFilterService dynamicFilterService, SplitSchedulerStats schedulerStats) { requireNonNull(stageId, "stageId is null"); requireNonNull(fragment, "fragment is null"); + checkArgument(fragment.getPartitioningScheme().getBucketToPartition().isEmpty(), "bucket to partition is not expected to be set at this point"); requireNonNull(tables, "tables is null"); requireNonNull(remoteTaskFactory, "remoteTaskFactory is null"); requireNonNull(session, "session is null"); requireNonNull(nodeTaskMap, "nodeTaskMap is null"); requireNonNull(executor, "executor is null"); - requireNonNull(failureDetector, "failureDetector is null"); - requireNonNull(dynamicFilterService, "dynamicFilterService is null"); requireNonNull(schedulerStats, "schedulerStats is null"); SqlStageExecution sqlStageExecution = new SqlStageExecution( - new StageStateMachine(stageId, session, fragment, tables, executor, schedulerStats), + session, + new StageStateMachine(stageId, fragment, tables, executor, schedulerStats), remoteTaskFactory, nodeTaskMap, - summarizeTaskInfo, - executor, - failureDetector, - dynamicFilterService); + summarizeTaskInfo); sqlStageExecution.initialize(); return sqlStageExecution; } private SqlStageExecution( + Session session, StageStateMachine stateMachine, RemoteTaskFactory remoteTaskFactory, NodeTaskMap nodeTaskMap, - boolean summarizeTaskInfo, - Executor executor, - FailureDetector 
failureDetector, - DynamicFilterService dynamicFilterService) + boolean summarizeTaskInfo) { + this.session = requireNonNull(session, "session is null"); this.stateMachine = stateMachine; this.remoteTaskFactory = requireNonNull(remoteTaskFactory, "remoteTaskFactory is null"); this.nodeTaskMap = requireNonNull(nodeTaskMap, "nodeTaskMap is null"); this.summarizeTaskInfo = summarizeTaskInfo; - this.executor = requireNonNull(executor, "executor is null"); - this.failureDetector = requireNonNull(failureDetector, "failureDetector is null"); - this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); - - ImmutableMap.Builder fragmentToExchangeSource = ImmutableMap.builder(); - for (RemoteSourceNode remoteSourceNode : stateMachine.getFragment().getRemoteSourceNodes()) { - for (PlanFragmentId planFragmentId : remoteSourceNode.getSourceFragmentIds()) { - fragmentToExchangeSource.put(planFragmentId, remoteSourceNode); - } - } - this.exchangeSources = fragmentToExchangeSource.build(); - if (isEnableCoordinatorDynamicFiltersDistribution(stateMachine.getSession())) { + + if (isEnableCoordinatorDynamicFiltersDistribution(session)) { this.outboundDynamicFilterIds = getOutboundDynamicFilters(stateMachine.getFragment()); } else { @@ -183,11 +119,6 @@ private SqlStageExecution( private void initialize() { stateMachine.addStateChangeListener(newState -> checkAllTaskFinal()); - stateMachine.addStateChangeListener(newState -> { - if (!newState.canScheduleMoreTasks()) { - dynamicFilterService.stageCannotScheduleMoreTasks(stateMachine.getStageId(), getAllTasks().size()); - } - }); } public StageId getStageId() @@ -195,18 +126,26 @@ public StageId getStageId() return stateMachine.getStageId(); } - public StageState getState() + public synchronized boolean transitionToFinished() { - return stateMachine.getState(); + abortRunningTasks(); + return stateMachine.transitionToFinished(); } - /** - * Listener is always notified asynchronously using a 
dedicated notification thread pool so, care should - * be taken to avoid leaking {@code this} when adding a listener in a constructor. - */ - public void addStateChangeListener(StateChangeListener stateChangeListener) + public synchronized boolean transitionToFailed(Throwable throwable) + { + requireNonNull(throwable, "throwable is null"); + abortRunningTasks(); + return stateMachine.transitionToFailed(throwable); + } + + private synchronized void abortRunningTasks() { - stateMachine.addStateChangeListener(stateChangeListener); + for (RemoteTask task : tasks.values()) { + if (!task.getTaskStatus().getState().isDone()) { + task.abort(); + } + } } /** @@ -220,69 +159,11 @@ public void addFinalStageInfoListener(StateChangeListener stateChange stateMachine.addFinalStageInfoListener(stateChangeListener); } - public void addCompletedDriverGroupsChangedListener(Consumer> newlyCompletedDriverGroupConsumer) - { - completedLifespansChangeListeners.addListener(newlyCompletedDriverGroupConsumer); - } - public PlanFragment getFragment() { return stateMachine.getFragment(); } - public OutputBuffers getOutputBuffers() - { - return outputBuffers.get(); - } - - public void beginScheduling() - { - stateMachine.transitionToScheduling(); - } - - public synchronized void transitionToSchedulingSplits() - { - stateMachine.transitionToSchedulingSplits(); - } - - public synchronized void schedulingComplete() - { - if (!stateMachine.transitionToScheduled()) { - return; - } - - if (isFlushing()) { - stateMachine.transitionToFlushing(); - } - if (finishedTasks.containsAll(allTasks)) { - stateMachine.transitionToFinished(); - } - - for (PlanNodeId partitionedSource : stateMachine.getFragment().getPartitionedSources()) { - schedulingComplete(partitionedSource); - } - } - - public synchronized void schedulingComplete(PlanNodeId partitionedSource) - { - for (RemoteTask task : getAllTasks()) { - task.noMoreSplits(partitionedSource); - } - completeSources.add(partitionedSource); - } - - public 
synchronized void cancel() - { - stateMachine.transitionToCanceled(); - getAllTasks().forEach(RemoteTask::cancel); - } - - public synchronized void abort() - { - stateMachine.transitionToAborted(); - getAllTasks().forEach(RemoteTask::abort); - } - public long getUserMemoryReservation() { return stateMachine.getUserMemoryReservation(); @@ -295,7 +176,7 @@ public long getTotalMemoryReservation() public Duration getTotalCpuTime() { - long millis = getAllTasks().stream() + long millis = tasks.values().stream() .mapToLong(task -> task.getTaskInfo().getStats().getTotalCpuTime().toMillis()) .sum(); return new Duration(millis, TimeUnit.MILLISECONDS); @@ -313,65 +194,11 @@ public StageInfo getStageInfo() private Iterable getAllTaskInfo() { - return getAllTasks().stream() + return tasks.values().stream() .map(RemoteTask::getTaskInfo) .collect(toImmutableList()); } - public synchronized void addExchangeLocations(PlanFragmentId fragmentId, Set sourceTasks, boolean noMoreExchangeLocations) - { - requireNonNull(fragmentId, "fragmentId is null"); - requireNonNull(sourceTasks, "sourceTasks is null"); - - RemoteSourceNode remoteSource = exchangeSources.get(fragmentId); - checkArgument(remoteSource != null, "Unknown remote source %s. Known sources are %s", fragmentId, exchangeSources.keySet()); - - this.sourceTasks.putAll(remoteSource.getId(), sourceTasks); - - for (RemoteTask task : getAllTasks()) { - ImmutableMultimap.Builder newSplits = ImmutableMultimap.builder(); - for (RemoteTask sourceTask : sourceTasks) { - URI exchangeLocation = sourceTask.getTaskStatus().getSelf(); - newSplits.put(remoteSource.getId(), createRemoteSplitFor(task.getTaskId(), exchangeLocation)); - } - task.addSplits(newSplits.build()); - } - - if (noMoreExchangeLocations) { - completeSourceFragments.add(fragmentId); - - // is the source now complete? 
- if (completeSourceFragments.containsAll(remoteSource.getSourceFragmentIds())) { - completeSources.add(remoteSource.getId()); - for (RemoteTask task : getAllTasks()) { - task.noMoreSplits(remoteSource.getId()); - } - } - } - } - - public synchronized void setOutputBuffers(OutputBuffers outputBuffers) - { - requireNonNull(outputBuffers, "outputBuffers is null"); - - while (true) { - OutputBuffers currentOutputBuffers = this.outputBuffers.get(); - if (currentOutputBuffers != null) { - if (outputBuffers.getVersion() <= currentOutputBuffers.getVersion()) { - return; - } - currentOutputBuffers.checkValidTransition(outputBuffers); - } - - if (this.outputBuffers.compareAndSet(currentOutputBuffers, outputBuffers)) { - for (RemoteTask task : getAllTasks()) { - task.setOutputBuffers(outputBuffers); - } - return; - } - } - } - // do not synchronize // this is used for query info building which should be independent of scheduling work public boolean hasTasks() @@ -379,115 +206,47 @@ public boolean hasTasks() return !tasks.isEmpty(); } - // do not synchronize - // this is used for query info building which should be independent of scheduling work - public List getAllTasks() - { - return tasks.values().stream() - .flatMap(Set::stream) - .collect(toImmutableList()); - } - - public synchronized Optional scheduleTask(InternalNode node, int partition) + public synchronized Optional createTask( + InternalNode node, + int partition, + Optional bucketToPartition, + OutputBuffers outputBuffers, + Multimap splits, + Multimap noMoreSplitsForLifespan, + Set noMoreSplits) { - requireNonNull(node, "node is null"); - if (stateMachine.getState().isDone()) { return Optional.empty(); } - checkState(!splitsScheduled.get(), "scheduleTask cannot be called once splits have been scheduled"); - return Optional.of(scheduleTask(node, new TaskId(stateMachine.getStageId(), partition), ImmutableMultimap.of())); - } - public synchronized Set scheduleSplits(InternalNode node, Multimap splits, Multimap 
noMoreSplitsNotification) - { - requireNonNull(node, "node is null"); - requireNonNull(splits, "splits is null"); + TaskId taskId = new TaskId(stateMachine.getStageId(), partition); + checkArgument(!tasks.containsKey(taskId), "A task with id %s already exists", taskId); - if (stateMachine.getState().isDone()) { - return ImmutableSet.of(); - } - splitsScheduled.set(true); - - checkArgument(stateMachine.getFragment().getPartitionedSources().containsAll(splits.keySet()), "Invalid splits"); - - ImmutableSet.Builder newTasks = ImmutableSet.builder(); - Collection tasks = this.tasks.get(node); - RemoteTask task; - if (tasks == null) { - // The output buffer depends on the task id starting from 0 and being sequential, since each - // task is assigned a private buffer based on task id. - TaskId taskId = new TaskId(stateMachine.getStageId(), nextTaskId.getAndIncrement()); - task = scheduleTask(node, taskId, splits); - newTasks.add(task); - } - else { - task = tasks.iterator().next(); - task.addSplits(splits); - } - if (noMoreSplitsNotification.size() > 1) { - // The assumption that `noMoreSplitsNotification.size() <= 1` currently holds. - // If this assumption no longer holds, we should consider calling task.noMoreSplits with multiple entries in one shot. - // These kind of methods can be expensive since they are grabbing locks and/or sending HTTP requests on change. 
- throw new UnsupportedOperationException("This assumption no longer holds: noMoreSplitsNotification.size() < 1"); - } - for (Entry entry : noMoreSplitsNotification.entries()) { - task.noMoreSplits(entry.getKey(), entry.getValue()); - } - return newTasks.build(); - } - - private synchronized RemoteTask scheduleTask(InternalNode node, TaskId taskId, Multimap sourceSplits) - { - checkArgument(!allTasks.contains(taskId), "A task with id %s already exists", taskId); - - ImmutableMultimap.Builder initialSplits = ImmutableMultimap.builder(); - initialSplits.putAll(sourceSplits); - - sourceTasks.forEach((planNodeId, task) -> { - TaskStatus status = task.getTaskStatus(); - if (status.getState() != TaskState.FINISHED) { - initialSplits.put(planNodeId, createRemoteSplitFor(taskId, status.getSelf())); - } - }); - - OutputBuffers outputBuffers = this.outputBuffers.get(); - checkState(outputBuffers != null, "Initial output buffers must be set before a task can be scheduled"); + stateMachine.transitionToScheduling(); RemoteTask task = remoteTaskFactory.createRemoteTask( - stateMachine.getSession(), + session, taskId, node, - stateMachine.getFragment(), - initialSplits.build(), + stateMachine.getFragment().withBucketToPartition(bucketToPartition), + splits, outputBuffers, nodeTaskMap.createPartitionedSplitCountTracker(node, taskId), outboundDynamicFilterIds, summarizeTaskInfo); - completeSources.forEach(task::noMoreSplits); + noMoreSplitsForLifespan.forEach(task::noMoreSplits); + noMoreSplits.forEach(task::noMoreSplits); + tasks.put(taskId, task); allTasks.add(taskId); - tasks.computeIfAbsent(node, key -> newConcurrentHashSet()).add(task); nodeTaskMap.addTask(node, task); - task.addStateChangeListener(new StageTaskListener()); + task.addStateChangeListener(this::updateTaskStatus); + task.addStateChangeListener(new MemoryUsageListener()); task.addFinalTaskInfoListener(this::updateFinalTaskInfo); - if (!stateMachine.getState().isDone()) { - task.start(); - } - else { - // stage 
finished while we were scheduling this task - task.abort(); - } - - return task; - } - - public Set getScheduledNodes() - { - return ImmutableSet.copyOf(tasks.keySet()); + return Optional.of(task); } public void recordGetSplitTime(long start) @@ -495,71 +254,20 @@ public void recordGetSplitTime(long start) stateMachine.recordGetSplitTime(start); } - private static Split createRemoteSplitFor(TaskId taskId, URI taskLocation) + private synchronized void updateTaskStatus(TaskStatus status) { - // Fetch the results from the buffer assigned to the task based on id - URI splitLocation = uriBuilderFrom(taskLocation).appendPath("results").appendPath(String.valueOf(taskId.getId())).build(); - return new Split(REMOTE_CONNECTOR_ID, new RemoteSplit(splitLocation), Lifespan.taskWide()); - } - - private synchronized void updateTaskStatus(TaskStatus taskStatus) - { - try { - StageState stageState = getState(); - if (stageState.isDone()) { - return; - } - - TaskState taskState = taskStatus.getState(); - - switch (taskState) { - case FAILED: - RuntimeException failure = taskStatus.getFailures().stream() - .findFirst() - .map(this::rewriteTransportFailure) - .map(ExecutionFailureInfo::toException) - .orElse(new TrinoException(GENERIC_INTERNAL_ERROR, "A task failed for an unknown reason")); - stateMachine.transitionToFailed(failure); - break; - case ABORTED: - // A task should only be in the aborted state if the STAGE is done (ABORTED or FAILED) - stateMachine.transitionToFailed(new TrinoException(GENERIC_INTERNAL_ERROR, "A task is in the ABORTED state but stage is " + stageState)); - break; - case FLUSHING: - flushingTasks.add(taskStatus.getTaskId()); - break; - case FINISHED: - finishedTasks.add(taskStatus.getTaskId()); - flushingTasks.remove(taskStatus.getTaskId()); - break; - default: - } + if (status.getState().isDone()) { + finishedTasks.add(status.getTaskId()); + } - if (stageState == StageState.SCHEDULED || stageState == StageState.RUNNING || stageState == StageState.FLUSHING) 
{ - if (taskState == TaskState.RUNNING) { - stateMachine.transitionToRunning(); - } - if (isFlushing()) { - stateMachine.transitionToFlushing(); - } - if (finishedTasks.containsAll(allTasks)) { - stateMachine.transitionToFinished(); - } - } + if (!finishedTasks.containsAll(allTasks)) { + stateMachine.transitionToRunning(); } - finally { - // after updating state, check if all tasks have final status information - checkAllTaskFinal(); + else { + stateMachine.transitionToPending(); } } - private synchronized boolean isFlushing() - { - // to transition to flushing, there must be at least one flushing task, and all others must be flushing or finished. - return !flushingTasks.isEmpty() - && allTasks.stream().allMatch(taskId -> finishedTasks.contains(taskId) || flushingTasks.contains(taskId)); - } - private synchronized void updateFinalTaskInfo(TaskInfo finalTaskInfo) { tasksWithFinalInfo.add(finalTaskInfo.getTaskStatus().getTaskId()); @@ -568,70 +276,29 @@ private synchronized void updateFinalTaskInfo(TaskInfo finalTaskInfo) private synchronized void checkAllTaskFinal() { - if (stateMachine.getState().isDone() && tasksWithFinalInfo.containsAll(allTasks)) { - List finalTaskInfos = getAllTasks().stream() + if (stateMachine.getState().isDone() && tasksWithFinalInfo.containsAll(tasks.keySet())) { + List finalTaskInfos = tasks.values().stream() .map(RemoteTask::getTaskInfo) .collect(toImmutableList()); stateMachine.setAllTasksFinal(finalTaskInfos); } } - public List getTaskStatuses() - { - return getAllTasks().stream() - .map(RemoteTask::getTaskStatus) - .collect(toImmutableList()); - } - - public boolean isAnyTaskBlocked() - { - return getTaskStatuses().stream().anyMatch(TaskStatus::isOutputBufferOverutilized); - } - - private ExecutionFailureInfo rewriteTransportFailure(ExecutionFailureInfo executionFailureInfo) - { - if (executionFailureInfo.getRemoteHost() == null || failureDetector.getState(executionFailureInfo.getRemoteHost()) != GONE) { - return executionFailureInfo; 
- } - - return new ExecutionFailureInfo( - executionFailureInfo.getType(), - executionFailureInfo.getMessage(), - executionFailureInfo.getCause(), - executionFailureInfo.getSuppressed(), - executionFailureInfo.getStack(), - executionFailureInfo.getErrorLocation(), - REMOTE_HOST_GONE.toErrorCode(), - executionFailureInfo.getRemoteHost()); - } - @Override public String toString() { return stateMachine.toString(); } - private class StageTaskListener + private class MemoryUsageListener implements StateChangeListener { private long previousUserMemory; private long previousSystemMemory; private long previousRevocableMemory; - private final Set completedDriverGroups = new HashSet<>(); @Override - public void stateChanged(TaskStatus taskStatus) - { - try { - updateMemoryUsage(taskStatus); - updateCompletedDriverGroups(taskStatus); - } - finally { - updateTaskStatus(taskStatus); - } - } - - private synchronized void updateMemoryUsage(TaskStatus taskStatus) + public synchronized void stateChanged(TaskStatus taskStatus) { long currentUserMemory = taskStatus.getMemoryReservation().toBytes(); long currentSystemMemory = taskStatus.getSystemMemoryReservation().toBytes(); @@ -644,42 +311,5 @@ private synchronized void updateMemoryUsage(TaskStatus taskStatus) previousRevocableMemory = currentRevocableMemory; stateMachine.updateMemoryUsage(deltaUserMemoryInBytes, deltaRevocableMemoryInBytes, deltaTotalMemoryInBytes); } - - private synchronized void updateCompletedDriverGroups(TaskStatus taskStatus) - { - // Sets.difference returns a view. - // Once we add the difference into `completedDriverGroups`, the view will be empty. - // `completedLifespansChangeListeners.invoke` happens asynchronously. - // As a result, calling the listeners before updating `completedDriverGroups` doesn't make a difference. - // That's why a copy must be made here. 
- Set newlyCompletedDriverGroups = ImmutableSet.copyOf(Sets.difference(taskStatus.getCompletedDriverGroups(), this.completedDriverGroups)); - if (newlyCompletedDriverGroups.isEmpty()) { - return; - } - completedLifespansChangeListeners.invoke(newlyCompletedDriverGroups, executor); - // newlyCompletedDriverGroups is a view. - // Making changes to completedDriverGroups will change newlyCompletedDriverGroups. - completedDriverGroups.addAll(newlyCompletedDriverGroups); - } - } - - private static class ListenerManager - { - private final List> listeners = new ArrayList<>(); - private boolean frozen; - - public synchronized void addListener(Consumer listener) - { - checkState(!frozen, "Listeners have been invoked"); - listeners.add(listener); - } - - public synchronized void invoke(T payload, Executor executor) - { - frozen = true; - for (Consumer listener : listeners) { - executor.execute(() -> listener.accept(payload)); - } - } } } diff --git a/core/trino-main/src/main/java/io/trino/execution/StageState.java b/core/trino-main/src/main/java/io/trino/execution/StageState.java index 27013ba787a0..162a259931e7 100644 --- a/core/trino-main/src/main/java/io/trino/execution/StageState.java +++ b/core/trino-main/src/main/java/io/trino/execution/StageState.java @@ -31,36 +31,18 @@ public enum StageState * Stage tasks are being scheduled on nodes. */ SCHEDULING(false, false), - /** - * All stage tasks have been scheduled, but splits are still being scheduled. - */ - SCHEDULING_SPLITS(false, false), - /** - * Stage has been scheduled on nodes and ready to execute, but all tasks are still queued. - */ - SCHEDULED(false, false), /** * Stage is running. */ RUNNING(false, false), /** - * Stage has finished executing and output being consumed. - * In this state, at-least one of the tasks is flushing and the non-flushing tasks are finished + * Stage is finished running existing tasks but more tasks could be scheduled in the future. 
*/ - FLUSHING(false, false), + PENDING(false, false), /** * Stage has finished executing and all output has been consumed. */ FINISHED(true, false), - /** - * Stage was canceled by a user. - */ - CANCELED(true, false), - /** - * Stage was aborted due to a failure in the query. The failure - * was not in this stage. - */ - ABORTED(true, true), /** * Stage execution failed. */ @@ -93,29 +75,4 @@ public boolean isFailure() { return failureState; } - - public boolean canScheduleMoreTasks() - { - switch (this) { - case PLANNED: - case SCHEDULING: - // workers are still being added to the query - return true; - case SCHEDULING_SPLITS: - case SCHEDULED: - case RUNNING: - case FLUSHING: - case FINISHED: - case CANCELED: - // no more workers will be added to the query - return false; - case ABORTED: - case FAILED: - // DO NOT complete a FAILED or ABORTED stage. This will cause the - // stage above to finish normally, which will result in a query - // completing successfully when it should fail.. 
- return true; - } - throw new IllegalStateException("Unhandled state: " + this); - } } diff --git a/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java b/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java index 7798d8e53921..0665c916c0cf 100644 --- a/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java +++ b/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java @@ -18,7 +18,6 @@ import io.airlift.log.Logger; import io.airlift.stats.Distribution; import io.airlift.units.Duration; -import io.trino.Session; import io.trino.execution.StateMachine.StateChangeListener; import io.trino.execution.scheduler.SplitSchedulerStats; import io.trino.operator.BlockedReason; @@ -52,16 +51,12 @@ import static com.google.common.base.Preconditions.checkState; import static io.airlift.units.DataSize.succinctBytes; import static io.airlift.units.Duration.succinctDuration; -import static io.trino.execution.StageState.ABORTED; -import static io.trino.execution.StageState.CANCELED; import static io.trino.execution.StageState.FAILED; import static io.trino.execution.StageState.FINISHED; -import static io.trino.execution.StageState.FLUSHING; +import static io.trino.execution.StageState.PENDING; import static io.trino.execution.StageState.PLANNED; import static io.trino.execution.StageState.RUNNING; -import static io.trino.execution.StageState.SCHEDULED; import static io.trino.execution.StageState.SCHEDULING; -import static io.trino.execution.StageState.SCHEDULING_SPLITS; import static io.trino.execution.StageState.TERMINAL_STAGE_STATES; import static java.lang.Math.max; import static java.lang.Math.min; @@ -77,7 +72,6 @@ public class StageStateMachine private final StageId stageId; private final PlanFragment fragment; - private final Session session; private final Map tables; private final SplitSchedulerStats scheduledStats; @@ -96,14 +90,12 @@ public class StageStateMachine public StageStateMachine( StageId 
stageId, - Session session, PlanFragment fragment, Map tables, ExecutorService executor, SplitSchedulerStats schedulerStats) { this.stageId = requireNonNull(stageId, "stageId is null"); - this.session = requireNonNull(session, "session is null"); this.fragment = requireNonNull(fragment, "fragment is null"); this.tables = ImmutableMap.copyOf(requireNonNull(tables, "tables is null")); this.scheduledStats = requireNonNull(schedulerStats, "schedulerStats is null"); @@ -119,11 +111,6 @@ public StageId getStageId() return stageId; } - public Session getSession() - { - return session; - } - public StageState getState() { return stageState.get(); @@ -144,30 +131,20 @@ public void addStateChangeListener(StateChangeListener stateChangeLi stageState.addStateChangeListener(stateChangeListener); } - public synchronized boolean transitionToScheduling() + public boolean transitionToScheduling() { return stageState.compareAndSet(PLANNED, SCHEDULING); } - public synchronized boolean transitionToSchedulingSplits() - { - return stageState.setIf(SCHEDULING_SPLITS, currentState -> currentState == PLANNED || currentState == SCHEDULING); - } - - public synchronized boolean transitionToScheduled() - { - schedulingComplete.compareAndSet(null, DateTime.now()); - return stageState.setIf(SCHEDULED, currentState -> currentState == PLANNED || currentState == SCHEDULING || currentState == SCHEDULING_SPLITS); - } - public boolean transitionToRunning() { - return stageState.setIf(RUNNING, currentState -> currentState != RUNNING && currentState != FLUSHING && !currentState.isDone()); + schedulingComplete.compareAndSet(null, DateTime.now()); + return stageState.setIf(RUNNING, currentState -> currentState != RUNNING && !currentState.isDone()); } - public boolean transitionToFlushing() + public boolean transitionToPending() { - return stageState.setIf(FLUSHING, currentState -> currentState != FLUSHING && !currentState.isDone()); + return stageState.setIf(PENDING, currentState -> currentState != 
PENDING && !currentState.isDone()); } public boolean transitionToFinished() @@ -175,16 +152,6 @@ public boolean transitionToFinished() return stageState.setIf(FINISHED, currentState -> !currentState.isDone()); } - public boolean transitionToCanceled() - { - return stageState.setIf(CANCELED, currentState -> !currentState.isDone()); - } - - public boolean transitionToAborted() - { - return stageState.setIf(ABORTED, currentState -> !currentState.isDone()); - } - public boolean transitionToFailed(Throwable throwable) { requireNonNull(throwable, "throwable is null"); @@ -259,7 +226,7 @@ public BasicStageStats getBasicStageStats(Supplier> taskInfos // information, the stage could finish, and the task states would // never be visible. StageState state = stageState.get(); - boolean isScheduled = state == RUNNING || state == FLUSHING || state.isDone(); + boolean isScheduled = state == RUNNING || state == StageState.PENDING || state.isDone(); List taskInfos = ImmutableList.copyOf(taskInfosSupplier.get()); diff --git a/core/trino-main/src/main/java/io/trino/execution/StageStats.java b/core/trino-main/src/main/java/io/trino/execution/StageStats.java index 50d3613ac916..075132c286d7 100644 --- a/core/trino-main/src/main/java/io/trino/execution/StageStats.java +++ b/core/trino-main/src/main/java/io/trino/execution/StageStats.java @@ -32,7 +32,6 @@ import java.util.Set; import static com.google.common.base.Preconditions.checkArgument; -import static io.trino.execution.StageState.FLUSHING; import static io.trino.execution.StageState.RUNNING; import static java.lang.Math.min; import static java.util.Objects.requireNonNull; @@ -433,7 +432,7 @@ public List getOperatorSummaries() public BasicStageStats toBasicStageStats(StageState stageState) { - boolean isScheduled = stageState == RUNNING || stageState == FLUSHING || stageState.isDone(); + boolean isScheduled = stageState == RUNNING || stageState == StageState.PENDING || stageState.isDone(); OptionalDouble progressPercentage = 
OptionalDouble.empty(); if (isScheduled && totalDrivers != 0) { diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/AllAtOnceExecutionPolicy.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/AllAtOnceExecutionPolicy.java index 34e7ed2bdb9c..4acc446b10cc 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/AllAtOnceExecutionPolicy.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/AllAtOnceExecutionPolicy.java @@ -13,15 +13,13 @@ */ package io.trino.execution.scheduler; -import io.trino.execution.SqlStageExecution; - import java.util.Collection; public class AllAtOnceExecutionPolicy implements ExecutionPolicy { @Override - public ExecutionSchedule createExecutionSchedule(Collection stages) + public ExecutionSchedule createExecutionSchedule(Collection stages) { return new AllAtOnceExecutionSchedule(stages); } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/AllAtOnceExecutionSchedule.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/AllAtOnceExecutionSchedule.java index 8d46987a8bff..77f0a1ec4bfb 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/AllAtOnceExecutionSchedule.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/AllAtOnceExecutionSchedule.java @@ -17,8 +17,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Ordering; -import io.trino.execution.SqlStageExecution; -import io.trino.execution.StageState; import io.trino.sql.planner.PlanFragment; import io.trino.sql.planner.plan.ExchangeNode; import io.trino.sql.planner.plan.IndexJoinNode; @@ -43,35 +41,35 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Iterables.getOnlyElement; -import static io.trino.execution.StageState.FLUSHING; -import static 
io.trino.execution.StageState.RUNNING; -import static io.trino.execution.StageState.SCHEDULED; +import static io.trino.execution.scheduler.StreamingStageExecution.State.FLUSHING; +import static io.trino.execution.scheduler.StreamingStageExecution.State.RUNNING; +import static io.trino.execution.scheduler.StreamingStageExecution.State.SCHEDULED; import static java.util.Objects.requireNonNull; import static java.util.function.Function.identity; public class AllAtOnceExecutionSchedule implements ExecutionSchedule { - private final Set schedulingStages; + private final Set schedulingStages; - public AllAtOnceExecutionSchedule(Collection stages) + public AllAtOnceExecutionSchedule(Collection stages) { requireNonNull(stages, "stages is null"); List preferredScheduleOrder = getPreferredScheduleOrder(stages.stream() - .map(SqlStageExecution::getFragment) + .map(StreamingStageExecution::getFragment) .collect(toImmutableList())); - Ordering ordering = Ordering.explicit(preferredScheduleOrder) + Ordering ordering = Ordering.explicit(preferredScheduleOrder) .onResultOf(PlanFragment::getId) - .onResultOf(SqlStageExecution::getFragment); + .onResultOf(StreamingStageExecution::getFragment); schedulingStages = new LinkedHashSet<>(ordering.sortedCopy(stages)); } @Override - public Set getStagesToSchedule() + public Set getStagesToSchedule() { - for (Iterator iterator = schedulingStages.iterator(); iterator.hasNext(); ) { - StageState state = iterator.next().getState(); + for (Iterator iterator = schedulingStages.iterator(); iterator.hasNext(); ) { + StreamingStageExecution.State state = iterator.next().getState(); if (state == SCHEDULED || state == RUNNING || state == FLUSHING || state.isDone()) { iterator.remove(); } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/BroadcastOutputBufferManager.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/BroadcastOutputBufferManager.java index af4da83997e9..d66cce34772d 100644 --- 
a/core/trino-main/src/main/java/io/trino/execution/scheduler/BroadcastOutputBufferManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/BroadcastOutputBufferManager.java @@ -19,57 +19,46 @@ import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.ThreadSafe; -import java.util.List; -import java.util.function.Consumer; - import static io.trino.execution.buffer.OutputBuffers.BROADCAST_PARTITION_ID; import static io.trino.execution.buffer.OutputBuffers.BufferType.BROADCAST; import static io.trino.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; -import static java.util.Objects.requireNonNull; @ThreadSafe class BroadcastOutputBufferManager implements OutputBufferManager { - private final Consumer outputBufferTarget; - @GuardedBy("this") private OutputBuffers outputBuffers = createInitialEmptyOutputBuffers(BROADCAST); - public BroadcastOutputBufferManager(Consumer outputBufferTarget) - { - this.outputBufferTarget = requireNonNull(outputBufferTarget, "outputBufferTarget is null"); - outputBufferTarget.accept(outputBuffers); - } - @Override - public void addOutputBuffers(List newBuffers, boolean noMoreBuffers) + public synchronized void addOutputBuffer(OutputBufferId newBuffer) { - OutputBuffers newOutputBuffers; - synchronized (this) { - if (outputBuffers.isNoMoreBufferIds()) { - // a stage can move to a final state (e.g., failed) while scheduling, so ignore - // the new buffers - return; - } - - OutputBuffers originalOutputBuffers = outputBuffers; + if (outputBuffers.isNoMoreBufferIds()) { + // a stage can move to a final state (e.g., failed) while scheduling, so ignore + // the new buffers + return; + } - // Note: it does not matter which partition id the task is using, in broadcast all tasks read from the same partition - for (OutputBufferId newBuffer : newBuffers) { - outputBuffers = outputBuffers.withBuffer(newBuffer, BROADCAST_PARTITION_ID); - } + // Note: it does not matter which partition id the task 
is using, in broadcast all tasks read from the same partition + OutputBuffers newOutputBuffers = outputBuffers.withBuffer(newBuffer, BROADCAST_PARTITION_ID); - if (noMoreBuffers) { - outputBuffers = outputBuffers.withNoMoreBufferIds(); - } + // don't update if nothing changed + if (newOutputBuffers != outputBuffers) { + this.outputBuffers = newOutputBuffers; + } + } - // don't update if nothing changed - if (outputBuffers == originalOutputBuffers) { - return; - } - newOutputBuffers = this.outputBuffers; + @Override + public synchronized void noMoreBuffers() + { + if (!outputBuffers.isNoMoreBufferIds()) { + outputBuffers = outputBuffers.withNoMoreBufferIds(); } - outputBufferTarget.accept(newOutputBuffers); + } + + @Override + public synchronized OutputBuffers getOutputBuffers() + { + return outputBuffers; } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/ExecutionPolicy.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/ExecutionPolicy.java index 46c91e7e919e..4b1dc5300168 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/ExecutionPolicy.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/ExecutionPolicy.java @@ -13,11 +13,9 @@ */ package io.trino.execution.scheduler; -import io.trino.execution.SqlStageExecution; - import java.util.Collection; public interface ExecutionPolicy { - ExecutionSchedule createExecutionSchedule(Collection stages); + ExecutionSchedule createExecutionSchedule(Collection stages); } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/ExecutionSchedule.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/ExecutionSchedule.java index 1b5096d2f5dc..55c3edcf74d5 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/ExecutionSchedule.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/ExecutionSchedule.java @@ -13,13 +13,11 @@ */ package io.trino.execution.scheduler; -import 
io.trino.execution.SqlStageExecution; - import java.util.Set; public interface ExecutionSchedule { - Set getStagesToSchedule(); + Set getStagesToSchedule(); boolean isFinished(); } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountScheduler.java index 6f8bf8e4ed34..1d2466cc1229 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedCountScheduler.java @@ -14,8 +14,8 @@ package io.trino.execution.scheduler; import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMultimap; import io.trino.execution.RemoteTask; -import io.trino.execution.SqlStageExecution; import io.trino.metadata.InternalNode; import java.util.List; @@ -36,10 +36,10 @@ public interface TaskScheduler private final TaskScheduler taskScheduler; private final List partitionToNode; - public FixedCountScheduler(SqlStageExecution stage, List partitionToNode) + public FixedCountScheduler(StreamingStageExecution stageExecution, List partitionToNode) { - requireNonNull(stage, "stage is null"); - this.taskScheduler = stage::scheduleTask; + requireNonNull(stageExecution, "stage is null"); + this.taskScheduler = (node, partition) -> stageExecution.scheduleTask(node, partition, ImmutableMultimap.of(), ImmutableMultimap.of()); this.partitionToNode = requireNonNull(partitionToNode, "partitionToNode is null"); } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java index 8424f7c9a44d..2d33103ecefa 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java +++ 
b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java @@ -14,13 +14,12 @@ package io.trino.execution.scheduler; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Streams; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.log.Logger; import io.trino.execution.Lifespan; import io.trino.execution.RemoteTask; -import io.trino.execution.SqlStageExecution; import io.trino.execution.scheduler.ScheduleResult.BlockedReason; import io.trino.execution.scheduler.group.DynamicLifespanScheduler; import io.trino.execution.scheduler.group.FixedLifespanScheduler; @@ -34,22 +33,22 @@ import io.trino.sql.planner.plan.PlanNodeId; import java.util.ArrayList; +import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.OptionalInt; import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Supplier; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; -import static com.google.common.collect.ImmutableList.toImmutableList; import static io.airlift.concurrent.MoreFutures.whenAnyComplete; import static io.trino.execution.scheduler.SourcePartitionedScheduler.newSourcePartitionedSchedulerAsSourceScheduler; import static io.trino.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED; -import static java.lang.Math.toIntExact; import static java.util.Objects.requireNonNull; public class FixedSourcePartitionedScheduler @@ -57,15 +56,17 @@ public class FixedSourcePartitionedScheduler { private static final Logger log = Logger.get(FixedSourcePartitionedScheduler.class); - private final SqlStageExecution stage; + private final 
StreamingStageExecution stageExecution; private final List nodes; private final List sourceSchedulers; private final List partitionHandles; - private boolean scheduledTasks; private final Optional groupedLifespanScheduler; + private final AtomicInteger nextPartitionId; + private final Map scheduledTasks; + public FixedSourcePartitionedScheduler( - SqlStageExecution stage, + StreamingStageExecution stageExecution, Map splitSources, StageExecutionDescriptor stageExecutionDescriptor, List schedulingOrder, @@ -77,19 +78,19 @@ public FixedSourcePartitionedScheduler( List partitionHandles, DynamicFilterService dynamicFilterService) { - requireNonNull(stage, "stage is null"); + requireNonNull(stageExecution, "stageExecution is null"); requireNonNull(splitSources, "splitSources is null"); requireNonNull(bucketNodeMap, "bucketNodeMap is null"); checkArgument(!requireNonNull(nodes, "nodes is null").isEmpty(), "nodes is empty"); requireNonNull(partitionHandles, "partitionHandles is null"); - this.stage = stage; + this.stageExecution = stageExecution; this.nodes = ImmutableList.copyOf(nodes); this.partitionHandles = ImmutableList.copyOf(partitionHandles); checkArgument(splitSources.keySet().equals(ImmutableSet.copyOf(schedulingOrder))); - BucketedSplitPlacementPolicy splitPlacementPolicy = new BucketedSplitPlacementPolicy(nodeSelector, nodes, bucketNodeMap, stage::getAllTasks); + BucketedSplitPlacementPolicy splitPlacementPolicy = new BucketedSplitPlacementPolicy(nodeSelector, nodes, bucketNodeMap, stageExecution::getAllTasks); ArrayList sourceSchedulers = new ArrayList<>(); checkArgument( @@ -106,20 +107,25 @@ public FixedSourcePartitionedScheduler( boolean firstPlanNode = true; Optional groupedLifespanScheduler = Optional.empty(); + + nextPartitionId = new AtomicInteger(); + scheduledTasks = new HashMap<>(); for (PlanNodeId planNodeId : schedulingOrder) { SplitSource splitSource = splitSources.get(planNodeId); boolean groupedExecutionForScanNode = 
stageExecutionDescriptor.isScanGroupedExecution(planNodeId); // TODO : change anySourceTaskBlocked to accommodate the correct blocked status of source tasks // (ref : https://github.com/trinodb/trino/issues/4713) SourceScheduler sourceScheduler = newSourcePartitionedSchedulerAsSourceScheduler( - stage, + stageExecution, planNodeId, splitSource, splitPlacementPolicy, Math.max(splitBatchSize / concurrentLifespans, 1), groupedExecutionForScanNode, dynamicFilterService, - () -> true); + () -> true, + nextPartitionId, + scheduledTasks); if (stageExecutionDescriptor.isStageGroupedExecution() && !groupedExecutionForScanNode) { sourceScheduler = new AsGroupedSourceScheduler(sourceScheduler); @@ -149,7 +155,7 @@ public FixedSourcePartitionedScheduler( // Schedule the first few lifespans lifespanScheduler.scheduleInitial(sourceScheduler); // Schedule new lifespans for finished ones - stage.addCompletedDriverGroupsChangedListener(lifespanScheduler::onLifespanFinished); + stageExecution.addCompletedDriverGroupsChangedListener(lifespanScheduler::onLifespanFinished); groupedLifespanScheduler = Optional.of(lifespanScheduler); } } @@ -171,14 +177,16 @@ public ScheduleResult schedule() { // schedule a task on every node in the distribution List newTasks = ImmutableList.of(); - if (!scheduledTasks) { - newTasks = Streams.mapWithIndex( - nodes.stream(), - (node, id) -> stage.scheduleTask(node, toIntExact(id))) - .filter(Optional::isPresent) - .map(Optional::get) - .collect(toImmutableList()); - scheduledTasks = true; + if (scheduledTasks.isEmpty()) { + ImmutableList.Builder newTasksBuilder = ImmutableList.builder(); + for (InternalNode node : nodes) { + Optional task = stageExecution.scheduleTask(node, nextPartitionId.getAndIncrement(), ImmutableMultimap.of(), ImmutableMultimap.of()); + if (task.isPresent()) { + scheduledTasks.put(node, task.get()); + newTasksBuilder.add(task.get()); + } + } + newTasks = newTasksBuilder.build(); } boolean allBlocked = true; @@ -222,7 +230,7 @@ public 
ScheduleResult schedule() driverGroupsToStart = sourceScheduler.drainCompletedLifespans(); if (schedule.isFinished()) { - stage.schedulingComplete(sourceScheduler.getPlanNodeId()); + stageExecution.schedulingComplete(sourceScheduler.getPlanNodeId()); schedulerIterator.remove(); sourceScheduler.close(); shouldInvokeNoMoreDriverGroups = true; diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/OutputBufferManager.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/OutputBufferManager.java index c655e2487426..d3a98bcf3a0f 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/OutputBufferManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/OutputBufferManager.java @@ -13,11 +13,14 @@ */ package io.trino.execution.scheduler; +import io.trino.execution.buffer.OutputBuffers; import io.trino.execution.buffer.OutputBuffers.OutputBufferId; -import java.util.List; - interface OutputBufferManager { - void addOutputBuffers(List newBuffers, boolean noMoreBuffers); + void addOutputBuffer(OutputBufferId newBuffer); + + void noMoreBuffers(); + + OutputBuffers getOutputBuffers(); } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/PartitionedOutputBufferManager.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/PartitionedOutputBufferManager.java index bad4a9e1e162..04dc97ea0980 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/PartitionedOutputBufferManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/PartitionedOutputBufferManager.java @@ -20,10 +20,6 @@ import javax.annotation.concurrent.ThreadSafe; -import java.util.List; -import java.util.Map; -import java.util.function.Consumer; - import static com.google.common.base.Preconditions.checkArgument; import static io.trino.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; import static java.util.Objects.requireNonNull; @@ -32,9 +28,9 @@ public class 
PartitionedOutputBufferManager implements OutputBufferManager { - private final Map outputBuffers; + private final OutputBuffers outputBuffers; - public PartitionedOutputBufferManager(PartitioningHandle partitioningHandle, int partitionCount, Consumer outputBufferTarget) + public PartitionedOutputBufferManager(PartitioningHandle partitioningHandle, int partitionCount) { checkArgument(partitionCount >= 1, "partitionCount must be at least 1"); @@ -43,27 +39,33 @@ public PartitionedOutputBufferManager(PartitioningHandle partitioningHandle, int partitions.put(new OutputBufferId(partition), partition); } - OutputBuffers outputBuffers = createInitialEmptyOutputBuffers(requireNonNull(partitioningHandle, "partitioningHandle is null")) + outputBuffers = createInitialEmptyOutputBuffers(requireNonNull(partitioningHandle, "partitioningHandle is null")) .withBuffers(partitions.build()) .withNoMoreBufferIds(); - outputBufferTarget.accept(outputBuffers); - - this.outputBuffers = outputBuffers.getBuffers(); } @Override - public void addOutputBuffers(List newBuffers, boolean noMoreBuffers) + public void addOutputBuffer(OutputBufferId newBuffer) { // All buffers are created in the constructor, so just validate that this isn't // a request to add a new buffer - for (OutputBufferId newBuffer : newBuffers) { - Integer existingBufferId = outputBuffers.get(newBuffer); - if (existingBufferId == null) { - throw new IllegalStateException("Unexpected new output buffer " + newBuffer); - } - if (newBuffer.getId() != existingBufferId) { - throw new IllegalStateException("newOutputBuffers has changed the assignment for task " + newBuffer); - } + Integer existingBufferId = outputBuffers.getBuffers().get(newBuffer); + if (existingBufferId == null) { + throw new IllegalStateException("Unexpected new output buffer " + newBuffer); + } + if (newBuffer.getId() != existingBufferId) { + throw new IllegalStateException("newOutputBuffers has changed the assignment for task " + newBuffer); } } + + @Override 
+ public void noMoreBuffers() + { + } + + @Override + public OutputBuffers getOutputBuffers() + { + return outputBuffers; + } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/PhasedExecutionPolicy.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/PhasedExecutionPolicy.java index 99190392603d..6f5c48019d1b 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/PhasedExecutionPolicy.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/PhasedExecutionPolicy.java @@ -13,15 +13,13 @@ */ package io.trino.execution.scheduler; -import io.trino.execution.SqlStageExecution; - import java.util.Collection; public class PhasedExecutionPolicy implements ExecutionPolicy { @Override - public ExecutionSchedule createExecutionSchedule(Collection stages) + public ExecutionSchedule createExecutionSchedule(Collection stages) { return new PhasedExecutionSchedule(stages); } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/PhasedExecutionSchedule.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/PhasedExecutionSchedule.java index f1b14d289021..9489b1d3cd78 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/PhasedExecutionSchedule.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/PhasedExecutionSchedule.java @@ -16,8 +16,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; -import io.trino.execution.SqlStageExecution; -import io.trino.execution.StageState; import io.trino.sql.planner.PlanFragment; import io.trino.sql.planner.plan.ExchangeNode; import io.trino.sql.planner.plan.IndexJoinNode; @@ -50,9 +48,9 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; -import static 
io.trino.execution.StageState.FLUSHING; -import static io.trino.execution.StageState.RUNNING; -import static io.trino.execution.StageState.SCHEDULED; +import static io.trino.execution.scheduler.StreamingStageExecution.State.FLUSHING; +import static io.trino.execution.scheduler.StreamingStageExecution.State.RUNNING; +import static io.trino.execution.scheduler.StreamingStageExecution.State.SCHEDULED; import static io.trino.sql.planner.plan.ExchangeNode.Scope.LOCAL; import static java.util.function.Function.identity; @@ -60,14 +58,14 @@ public class PhasedExecutionSchedule implements ExecutionSchedule { - private final List> schedulePhases; - private final Set activeSources = new HashSet<>(); + private final List> schedulePhases; + private final Set activeSources = new HashSet<>(); - public PhasedExecutionSchedule(Collection stages) + public PhasedExecutionSchedule(Collection stages) { - List> phases = extractPhases(stages.stream().map(SqlStageExecution::getFragment).collect(toImmutableList())); + List> phases = extractPhases(stages.stream().map(StreamingStageExecution::getFragment).collect(toImmutableList())); - Map stagesByFragmentId = stages.stream().collect(toImmutableMap(stage -> stage.getFragment().getId(), identity())); + Map stagesByFragmentId = stages.stream().collect(toImmutableMap(stage -> stage.getFragment().getId(), identity())); // create a mutable list of mutable sets of stages, so we can remove completed stages schedulePhases = new ArrayList<>(); @@ -79,7 +77,7 @@ public PhasedExecutionSchedule(Collection stages) } @Override - public Set getStagesToSchedule() + public Set getStagesToSchedule() { removeCompletedStages(); addPhasesIfNecessary(); @@ -91,8 +89,8 @@ public Set getStagesToSchedule() private void removeCompletedStages() { - for (Iterator stageIterator = activeSources.iterator(); stageIterator.hasNext(); ) { - StageState state = stageIterator.next().getState(); + for (Iterator stageIterator = activeSources.iterator(); stageIterator.hasNext(); 
) { + StreamingStageExecution.State state = stageIterator.next().getState(); if (state == SCHEDULED || state == RUNNING || state == FLUSHING || state.isDone()) { stageIterator.remove(); } @@ -107,7 +105,7 @@ private void addPhasesIfNecessary() } while (!schedulePhases.isEmpty()) { - Set phase = schedulePhases.remove(0); + Set phase = schedulePhases.remove(0); activeSources.addAll(phase); if (hasSourceDistributedStage(phase)) { return; @@ -115,7 +113,7 @@ private void addPhasesIfNecessary() } } - private static boolean hasSourceDistributedStage(Set phase) + private static boolean hasSourceDistributedStage(Set phase) { return phase.stream().anyMatch(stage -> !stage.getFragment().getPartitionedSources().isEmpty()); } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/ResultsConsumer.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/ResultsConsumer.java new file mode 100644 index 000000000000..c066d08c8495 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/ResultsConsumer.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.execution.scheduler; + +import io.trino.execution.RemoteTask; +import io.trino.sql.planner.plan.PlanFragmentId; + +public interface ResultsConsumer +{ + void addSourceTask(PlanFragmentId fragmentId, RemoteTask sourceTask); + + void noMoreSourceTasks(PlanFragmentId fragmentId); +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledOutputBufferManager.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledOutputBufferManager.java index 4af58c3c7990..117477a30161 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledOutputBufferManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledOutputBufferManager.java @@ -18,55 +18,45 @@ import javax.annotation.concurrent.GuardedBy; -import java.util.List; -import java.util.function.Consumer; - import static io.trino.execution.buffer.OutputBuffers.BufferType.ARBITRARY; import static io.trino.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; -import static java.util.Objects.requireNonNull; public class ScaledOutputBufferManager implements OutputBufferManager { - private final Consumer outputBufferTarget; - @GuardedBy("this") private OutputBuffers outputBuffers = createInitialEmptyOutputBuffers(ARBITRARY); - public ScaledOutputBufferManager(Consumer outputBufferTarget) - { - this.outputBufferTarget = requireNonNull(outputBufferTarget, "outputBufferTarget is null"); - outputBufferTarget.accept(outputBuffers); - } - @SuppressWarnings("ObjectEquality") @Override - public void addOutputBuffers(List newBuffers, boolean noMoreBuffers) + public synchronized void addOutputBuffer(OutputBufferId newBuffer) { - OutputBuffers newOutputBuffers; - synchronized (this) { - if (outputBuffers.isNoMoreBufferIds()) { - // a stage can move to a final state (e.g., failed) while scheduling, - // so ignore the new buffers - return; - } - - OutputBuffers originalOutputBuffers = outputBuffers; + if 
(outputBuffers.isNoMoreBufferIds()) { + // a stage can move to a final state (e.g., failed) while scheduling, so ignore + // the new buffers + return; + } - for (OutputBufferId newBuffer : newBuffers) { - outputBuffers = outputBuffers.withBuffer(newBuffer, newBuffer.getId()); - } + // Note: it does not matter which partition id the task is using, in broadcast all tasks read from the same partition + OutputBuffers newOutputBuffers = outputBuffers.withBuffer(newBuffer, newBuffer.getId()); - if (noMoreBuffers) { - outputBuffers = outputBuffers.withNoMoreBufferIds(); - } + // don't update if nothing changed + if (newOutputBuffers != outputBuffers) { + this.outputBuffers = newOutputBuffers; + } + } - // don't update if nothing changed - if (outputBuffers == originalOutputBuffers) { - return; - } - newOutputBuffers = this.outputBuffers; + @Override + public synchronized void noMoreBuffers() + { + if (!outputBuffers.isNoMoreBufferIds()) { + outputBuffers = outputBuffers.withNoMoreBufferIds(); } - outputBufferTarget.accept(newOutputBuffers); + } + + @Override + public synchronized OutputBuffers getOutputBuffers() + { + return outputBuffers; } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledWriterScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledWriterScheduler.java index 2bfe6d619e8d..5e5106fce23a 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledWriterScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledWriterScheduler.java @@ -14,10 +14,10 @@ package io.trino.execution.scheduler; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMultimap; import com.google.common.util.concurrent.SettableFuture; import io.airlift.units.DataSize; import io.trino.execution.RemoteTask; -import io.trino.execution.SqlStageExecution; import io.trino.execution.TaskStatus; import io.trino.metadata.InternalNode; @@ -39,7 +39,7 @@ 
public class ScaledWriterScheduler implements StageScheduler { - private final SqlStageExecution stage; + private final StreamingStageExecution stage; private final Supplier> sourceTasksProvider; private final Supplier> writerTasksProvider; private final NodeSelector nodeSelector; @@ -50,7 +50,7 @@ public class ScaledWriterScheduler private volatile SettableFuture future = SettableFuture.create(); public ScaledWriterScheduler( - SqlStageExecution stage, + StreamingStageExecution stage, Supplier> sourceTasksProvider, Supplier> writerTasksProvider, NodeSelector nodeSelector, @@ -119,7 +119,7 @@ private List scheduleTasks(int count) ImmutableList.Builder tasks = ImmutableList.builder(); for (InternalNode node : nodes) { - Optional remoteTask = stage.scheduleTask(node, scheduledNodes.size()); + Optional remoteTask = stage.scheduleTask(node, scheduledNodes.size(), ImmutableMultimap.of(), ImmutableMultimap.of()); remoteTask.ifPresent(task -> { tasks.add(task); scheduledNodes.add(node); diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/SourcePartitionedScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/SourcePartitionedScheduler.java index 1697ade417b4..427267cb3fe2 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/SourcePartitionedScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/SourcePartitionedScheduler.java @@ -22,7 +22,6 @@ import com.google.common.util.concurrent.SettableFuture; import io.trino.execution.Lifespan; import io.trino.execution.RemoteTask; -import io.trino.execution.SqlStageExecution; import io.trino.execution.scheduler.FixedSourcePartitionedScheduler.BucketedSplitPlacementPolicy; import io.trino.metadata.InternalNode; import io.trino.metadata.Split; @@ -40,14 +39,15 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; import java.util.Set; import java.util.concurrent.ThreadLocalRandom; +import 
java.util.concurrent.atomic.AtomicInteger; import java.util.function.BooleanSupplier; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; -import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.util.concurrent.Futures.immediateVoidFuture; import static com.google.common.util.concurrent.Futures.nonCancellationPropagating; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; @@ -89,7 +89,7 @@ private enum State FINISHED } - private final SqlStageExecution stage; + private final StreamingStageExecution stageExecution; private final SplitSource splitSource; private final SplitPlacementPolicy splitPlacementPolicy; private final int splitBatchSize; @@ -97,6 +97,8 @@ private enum State private final boolean groupedExecution; private final DynamicFilterService dynamicFilterService; private final BooleanSupplier anySourceTaskBlocked; + private final AtomicInteger nextPartitionId; + private final Map scheduledTasks; private final Map scheduleGroups = new HashMap<>(); private boolean noMoreScheduleGroups; @@ -105,25 +107,28 @@ private enum State private SettableFuture whenFinishedOrNewLifespanAdded = SettableFuture.create(); private SourcePartitionedScheduler( - SqlStageExecution stage, + StreamingStageExecution stageExecution, PlanNodeId partitionedNode, SplitSource splitSource, SplitPlacementPolicy splitPlacementPolicy, int splitBatchSize, boolean groupedExecution, DynamicFilterService dynamicFilterService, - BooleanSupplier anySourceTaskBlocked) + BooleanSupplier anySourceTaskBlocked, + AtomicInteger nextPartitionId, + Map scheduledTasks) { - this.stage = requireNonNull(stage, "stage is null"); - this.partitionedNode = requireNonNull(partitionedNode, "partitionedNode is null"); + this.stageExecution = requireNonNull(stageExecution, "stageExecution is null"); this.splitSource = 
requireNonNull(splitSource, "splitSource is null"); this.splitPlacementPolicy = requireNonNull(splitPlacementPolicy, "splitPlacementPolicy is null"); - this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); - this.anySourceTaskBlocked = requireNonNull(anySourceTaskBlocked, "anySourceTaskBlocked is null"); - checkArgument(splitBatchSize > 0, "splitBatchSize must be at least one"); this.splitBatchSize = splitBatchSize; + this.partitionedNode = requireNonNull(partitionedNode, "partitionedNode is null"); this.groupedExecution = groupedExecution; + this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); + this.anySourceTaskBlocked = requireNonNull(anySourceTaskBlocked, "anySourceTaskBlocked is null"); + this.nextPartitionId = requireNonNull(nextPartitionId, "nextPartitionId is null"); + this.scheduledTasks = requireNonNull(scheduledTasks, "scheduledTasks is null"); } @Override @@ -140,7 +145,7 @@ public PlanNodeId getPlanNodeId() * minimal management from the caller, which is ideal for use as a stage scheduler. 
*/ public static StageScheduler newSourcePartitionedSchedulerAsStageScheduler( - SqlStageExecution stage, + StreamingStageExecution stageExecution, PlanNodeId partitionedNode, SplitSource splitSource, SplitPlacementPolicy splitPlacementPolicy, @@ -149,14 +154,16 @@ public static StageScheduler newSourcePartitionedSchedulerAsStageScheduler( BooleanSupplier anySourceTaskBlocked) { SourcePartitionedScheduler sourcePartitionedScheduler = new SourcePartitionedScheduler( - stage, + stageExecution, partitionedNode, splitSource, splitPlacementPolicy, splitBatchSize, false, dynamicFilterService, - anySourceTaskBlocked); + anySourceTaskBlocked, + new AtomicInteger(), + new HashMap<>()); sourcePartitionedScheduler.startLifespan(Lifespan.taskWide(), NOT_PARTITIONED); sourcePartitionedScheduler.noMoreLifespans(); @@ -190,24 +197,28 @@ public void close() * transitioning of the object will not work properly. */ public static SourceScheduler newSourcePartitionedSchedulerAsSourceScheduler( - SqlStageExecution stage, + StreamingStageExecution stageExecution, PlanNodeId partitionedNode, SplitSource splitSource, SplitPlacementPolicy splitPlacementPolicy, int splitBatchSize, boolean groupedExecution, DynamicFilterService dynamicFilterService, - BooleanSupplier anySourceTaskBlocked) + BooleanSupplier anySourceTaskBlocked, + AtomicInteger nextPartitionId, + Map scheduledTasks) { return new SourcePartitionedScheduler( - stage, + stageExecution, partitionedNode, splitSource, splitPlacementPolicy, splitBatchSize, groupedExecution, dynamicFilterService, - anySourceTaskBlocked); + anySourceTaskBlocked, + nextPartitionId, + scheduledTasks); } @Override @@ -257,7 +268,7 @@ else if (pendingSplits.isEmpty()) { scheduleGroup.nextSplitBatchFuture = splitSource.getNextBatch(scheduleGroup.partitionHandle, lifespan, splitBatchSize - pendingSplits.size()); long start = System.nanoTime(); - addSuccessCallback(scheduleGroup.nextSplitBatchFuture, () -> stage.recordGetSplitTime(start)); + 
addSuccessCallback(scheduleGroup.nextSplitBatchFuture, () -> stageExecution.recordGetSplitTime(start)); } if (scheduleGroup.nextSplitBatchFuture.isDone()) { @@ -377,17 +388,17 @@ else if (pendingSplits.isEmpty()) { } if (anyBlockedOnNextSplitBatch - && stage.getScheduledNodes().isEmpty() - && dynamicFilterService.isCollectingTaskNeeded(stage.getStageId().getQueryId(), stage.getFragment())) { + && scheduledTasks.isEmpty() + && dynamicFilterService.isCollectingTaskNeeded(stageExecution.getStageId().getQueryId(), stageExecution.getFragment())) { // schedule a task for collecting dynamic filters in case probe split generator is waiting for them - overallNewTasks.addAll(createTaskOnRandomNode()); + createTaskOnRandomNode().ifPresent(overallNewTasks::add); } boolean anySourceTaskBlocked = this.anySourceTaskBlocked.getAsBoolean(); if (anySourceTaskBlocked) { // Dynamic filters might not be collected due to build side source tasks being blocked on full buffer. // In such case probe split generation that is waiting for dynamic filters should be unblocked to prevent deadlock. 
- dynamicFilterService.unblockStageDynamicFilters(stage.getStageId().getQueryId(), stage.getFragment()); + dynamicFilterService.unblockStageDynamicFilters(stageExecution.getStageId().getQueryId(), stageExecution.getFragment()); } if (groupedExecution) { @@ -499,42 +510,55 @@ private Set assignSplits(Multimap splitAssignme if (noMoreSplitsNotification.containsKey(node)) { noMoreSplits.putAll(partitionedNode, noMoreSplitsNotification.get(node)); } - newTasks.addAll(stage.scheduleSplits( - node, - splits, - noMoreSplits.build())); + RemoteTask task = scheduledTasks.get(node); + if (task != null) { + task.addSplits(splits); + noMoreSplits.build().forEach(task::noMoreSplits); + } + else { + scheduleTask(node, splits, noMoreSplits.build()).ifPresent(newTasks::add); + } } return newTasks.build(); } - private Set createTaskOnRandomNode() + private Optional createTaskOnRandomNode() { - checkState(stage.getScheduledNodes().isEmpty(), "Stage task is already scheduled on node"); + checkState(scheduledTasks.isEmpty(), "Stage task is already scheduled on node"); List allNodes = splitPlacementPolicy.allNodes(); checkState(allNodes.size() > 0, "No nodes available"); InternalNode node = allNodes.get(ThreadLocalRandom.current().nextInt(0, allNodes.size())); - return stage.scheduleSplits(node, ImmutableMultimap.of(), ImmutableMultimap.of()); + return scheduleTask(node, ImmutableMultimap.of(), ImmutableMultimap.of()); } private Set finalizeTaskCreationIfNecessary() { // only lock down tasks if there is a sub stage that could block waiting for this stage to create all tasks - if (stage.getFragment().isLeaf()) { + if (stageExecution.getFragment().isLeaf()) { return ImmutableSet.of(); } splitPlacementPolicy.lockDownNodes(); - Set scheduledNodes = stage.getScheduledNodes(); - Set newTasks = splitPlacementPolicy.allNodes().stream() - .filter(node -> !scheduledNodes.contains(node)) - .flatMap(node -> stage.scheduleSplits(node, ImmutableMultimap.of(), ImmutableMultimap.of()).stream()) - 
.collect(toImmutableSet()); + ImmutableSet.Builder newTasks = ImmutableSet.builder(); + for (InternalNode node : splitPlacementPolicy.allNodes()) { + if (scheduledTasks.containsKey(node)) { + continue; + } + scheduleTask(node, ImmutableMultimap.of(), ImmutableMultimap.of()).ifPresent(newTasks::add); + } // notify listeners that we have scheduled all tasks so they can set no more buffers or exchange splits - stage.transitionToSchedulingSplits(); + stageExecution.transitionToSchedulingSplits(); - return newTasks; + return newTasks.build(); + } + + private Optional scheduleTask(InternalNode node, Multimap initialSplits, Multimap noMoreSplitsForLifespan) + { + Optional remoteTask = stageExecution.scheduleTask(node, nextPartitionId.getAndIncrement(), initialSplits, noMoreSplitsForLifespan); + remoteTask.ifPresent(task -> scheduledTasks.put(node, task)); + return remoteTask; } private static class ScheduleGroup diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java index d20407c34745..d85b1998a362 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/SqlQueryScheduler.java @@ -13,11 +13,13 @@ */ package io.trino.execution.scheduler; +import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Sets; +import com.google.common.graph.Traverser; import com.google.common.primitives.Ints; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; @@ -27,6 +29,7 @@ import io.trino.Session; import io.trino.connector.CatalogName; import io.trino.execution.BasicStageStats; +import 
io.trino.execution.ExecutionFailureInfo; import io.trino.execution.NodeTaskMap; import io.trino.execution.QueryState; import io.trino.execution.QueryStateMachine; @@ -35,38 +38,41 @@ import io.trino.execution.SqlStageExecution; import io.trino.execution.StageId; import io.trino.execution.StageInfo; -import io.trino.execution.StageState; import io.trino.execution.TaskStatus; -import io.trino.execution.buffer.OutputBuffers; -import io.trino.execution.buffer.OutputBuffers.OutputBufferId; import io.trino.failuredetector.FailureDetector; import io.trino.metadata.InternalNode; import io.trino.server.DynamicFilterService; +import io.trino.spi.QueryId; import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorPartitionHandle; import io.trino.split.SplitSource; import io.trino.sql.planner.NodePartitionMap; import io.trino.sql.planner.NodePartitioningManager; import io.trino.sql.planner.PartitioningHandle; +import io.trino.sql.planner.PlanFragment; import io.trino.sql.planner.StageExecutionPlan; import io.trino.sql.planner.plan.PlanFragmentId; +import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.sql.planner.plan.RemoteSourceNode; +import io.trino.sql.planner.plan.TableScanNode; import java.net.URI; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Map.Entry; import java.util.Optional; +import java.util.Queue; import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.function.Predicate; import java.util.function.Supplier; @@ -77,6 +83,7 @@ import static 
com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Sets.newConcurrentHashSet; import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.concurrent.MoreFutures.tryGetFutureValue; import static io.airlift.concurrent.MoreFutures.whenAnyComplete; @@ -86,22 +93,25 @@ import static io.trino.connector.CatalogName.isInternalSystemConnector; import static io.trino.execution.BasicStageStats.aggregateBasicStageStats; import static io.trino.execution.SqlStageExecution.createSqlStageExecution; -import static io.trino.execution.StageState.ABORTED; -import static io.trino.execution.StageState.CANCELED; -import static io.trino.execution.StageState.FAILED; -import static io.trino.execution.StageState.FINISHED; -import static io.trino.execution.StageState.FLUSHING; -import static io.trino.execution.StageState.RUNNING; -import static io.trino.execution.StageState.SCHEDULED; import static io.trino.execution.scheduler.SourcePartitionedScheduler.newSourcePartitionedSchedulerAsStageScheduler; +import static io.trino.execution.scheduler.StreamingStageExecution.State.ABORTED; +import static io.trino.execution.scheduler.StreamingStageExecution.State.CANCELED; +import static io.trino.execution.scheduler.StreamingStageExecution.State.FAILED; +import static io.trino.execution.scheduler.StreamingStageExecution.State.FINISHED; +import static io.trino.execution.scheduler.StreamingStageExecution.State.FLUSHING; +import static io.trino.execution.scheduler.StreamingStageExecution.State.RUNNING; +import static io.trino.execution.scheduler.StreamingStageExecution.State.SCHEDULED; +import static io.trino.execution.scheduler.StreamingStageExecution.createStreamingStageExecution; import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static 
io.trino.spi.StandardErrorCode.NO_NODES_AVAILABLE; import static io.trino.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; +import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE; import static io.trino.util.Failures.checkCondition; +import static java.lang.Integer.parseInt; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -112,16 +122,26 @@ public class SqlQueryScheduler { private final QueryStateMachine queryStateMachine; - private final ExecutionPolicy executionPolicy; - private final Map stages; + private final StageExecutionPlan plan; + private final NodePartitioningManager nodePartitioningManager; + private final NodeScheduler nodeScheduler; + private final int splitBatchSize; private final ExecutorService executor; - private final StageId rootStageId; - private final Map stageSchedulers; - private final Map stageLinkages; + private final ScheduledExecutorService schedulerExecutor; + private final FailureDetector failureDetector; + private final ExecutionPolicy executionPolicy; private final SplitSchedulerStats schedulerStats; - private final boolean summarizeTaskInfo; private final DynamicFilterService dynamicFilterService; + + private final Map stages; + private final StageId rootStageId; + private final Map> stageLineage; + + // all stages that could be scheduled remotely (excluding coordinator only stages) + private final Set remotelyScheduledStages; + private final AtomicBoolean started = new AtomicBoolean(); + private final AtomicReference scheduler = new AtomicReference<>(); public static 
SqlQueryScheduler createSqlQueryScheduler( QueryStateMachine queryStateMachine, @@ -129,13 +149,11 @@ public static SqlQueryScheduler createSqlQueryScheduler( NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler, RemoteTaskFactory remoteTaskFactory, - Session session, boolean summarizeTaskInfo, int splitBatchSize, ExecutorService queryExecutor, ScheduledExecutorService schedulerExecutor, FailureDetector failureDetector, - OutputBuffers rootOutputBuffers, NodeTaskMap nodeTaskMap, ExecutionPolicy executionPolicy, SplitSchedulerStats schedulerStats, @@ -147,13 +165,11 @@ public static SqlQueryScheduler createSqlQueryScheduler( nodePartitioningManager, nodeScheduler, remoteTaskFactory, - session, summarizeTaskInfo, splitBatchSize, queryExecutor, schedulerExecutor, failureDetector, - rootOutputBuffers, nodeTaskMap, executionPolicy, schedulerStats, @@ -168,67 +184,238 @@ private SqlQueryScheduler( NodePartitioningManager nodePartitioningManager, NodeScheduler nodeScheduler, RemoteTaskFactory remoteTaskFactory, - Session session, boolean summarizeTaskInfo, int splitBatchSize, ExecutorService queryExecutor, ScheduledExecutorService schedulerExecutor, FailureDetector failureDetector, - OutputBuffers rootOutputBuffers, NodeTaskMap nodeTaskMap, ExecutionPolicy executionPolicy, SplitSchedulerStats schedulerStats, DynamicFilterService dynamicFilterService) { this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null"); + this.plan = requireNonNull(plan, "plan is null"); + this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null"); + this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null"); + this.splitBatchSize = splitBatchSize; + this.executor = requireNonNull(queryExecutor, "queryExecutor is null"); + this.schedulerExecutor = requireNonNull(schedulerExecutor, "schedulerExecutor is null"); + this.failureDetector = requireNonNull(failureDetector, "failureDetector 
is null"); this.executionPolicy = requireNonNull(executionPolicy, "executionPolicy is null"); this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null"); - this.summarizeTaskInfo = summarizeTaskInfo; this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); - // todo come up with a better way to build this, or eliminate this map - ImmutableMap.Builder stageSchedulers = ImmutableMap.builder(); - ImmutableMap.Builder stageLinkages = ImmutableMap.builder(); - - // Only fetch a distribution once per query to assure all stages see the same machine assignments - Map partitioningCache = new HashMap<>(); - - OutputBufferId rootBufferId = Iterables.getOnlyElement(rootOutputBuffers.getBuffers().keySet()); - List stages = createStages( - (fragmentId, tasks, noMoreExchangeLocations) -> updateQueryOutputLocations(queryStateMachine, rootBufferId, tasks, noMoreExchangeLocations), - new AtomicInteger(), - plan.withBucketToPartition(Optional.of(new int[1])), - nodeScheduler, + stages = createStages( + queryStateMachine.getSession(), remoteTaskFactory, - session, - splitBatchSize, - partitioningHandle -> partitioningCache.computeIfAbsent(partitioningHandle, handle -> nodePartitioningManager.getNodePartitioningMap(session, handle)), - nodePartitioningManager, - queryExecutor, - schedulerExecutor, - failureDetector, nodeTaskMap, - stageSchedulers, - stageLinkages); + queryExecutor, + schedulerStats, + plan, + summarizeTaskInfo); + rootStageId = getStageId(queryStateMachine.getQueryId(), plan.getFragment().getId()); + stageLineage = getStageLineage(queryStateMachine.getQueryId(), plan); + remotelyScheduledStages = stages.values().stream() + .filter(stage -> !stage.getFragment().getPartitioning().isCoordinatorOnly()) + .map(SqlStageExecution::getStageId) + .collect(toImmutableSet()); + } - SqlStageExecution rootStage = stages.get(0); - rootStage.setOutputBuffers(rootOutputBuffers); - this.rootStageId = rootStage.getStageId(); + 
// this is a separate method to ensure that the `this` reference is not leaked during construction + private void initialize() + { + // when query is done or any time a stage completes, attempt to transition query to "final query info ready" + queryStateMachine.addStateChangeListener(newState -> { + if (newState.isDone()) { + stages.values().forEach(SqlStageExecution::transitionToFinished); + queryStateMachine.updateQueryInfo(Optional.ofNullable(getStageInfo())); + } + }); + for (SqlStageExecution stage : stages.values()) { + stage.addFinalStageInfoListener(status -> queryStateMachine.updateQueryInfo(Optional.ofNullable(getStageInfo()))); + } + } - this.stages = stages.stream() - .collect(toImmutableMap(SqlStageExecution::getStageId, identity())); + private static Map createStages( + Session session, + RemoteTaskFactory taskFactory, + NodeTaskMap nodeTaskMap, + ExecutorService executor, + SplitSchedulerStats schedulerStats, + StageExecutionPlan planTree, + boolean summarizeTaskInfo) + { + ImmutableMap.Builder result = ImmutableMap.builder(); + for (StageExecutionPlan planNode : Traverser.forTree(StageExecutionPlan::getSubStages).breadthFirst(planTree)) { + PlanFragment fragment = planNode.getFragment(); + SqlStageExecution stageExecution = createSqlStageExecution( + getStageId(session.getQueryId(), fragment.getId()), + fragment, + planNode.getTables(), + taskFactory, + session, + summarizeTaskInfo, + nodeTaskMap, + executor, + schedulerStats); + result.put(fragment.getId(), stageExecution); + } + return result.build(); + } - this.stageSchedulers = stageSchedulers.build(); - this.stageLinkages = stageLinkages.build(); + private static Map> getStageLineage(QueryId queryId, StageExecutionPlan planTree) + { + ImmutableMap.Builder> result = ImmutableMap.builder(); + for (StageExecutionPlan planNode : Traverser.forTree(StageExecutionPlan::getSubStages).breadthFirst(planTree)) { + result.put( + getStageId(queryId, planNode.getFragment().getId()), + 
planNode.getSubStages().stream() + .map(stage -> getStageId(queryId, stage.getFragment().getId())) + .collect(toImmutableSet())); + } + return result.build(); + } - this.executor = queryExecutor; + private static StageId getStageId(QueryId queryId, PlanFragmentId fragmentId) + { + // TODO: refactor fragment id to be based on an integer + return new StageId(queryId, parseInt(fragmentId.toString())); } - // this is a separate method to ensure that the `this` reference is not leaked during construction - private void initialize() + public synchronized void start() + { + if (started.compareAndSet(false, true)) { + if (queryStateMachine.isDone()) { + return; + } + StreamingScheduler streamingScheduler = createStreamingScheduler(new ResultsConsumer() + { + @Override + public void addSourceTask(PlanFragmentId fragmentId, RemoteTask task) + { + Set bufferLocations = ImmutableSet.of(uriBuilderFrom(task.getTaskStatus().getSelf()) + .appendPath("results") + .appendPath("0").build()); + queryStateMachine.updateOutputLocations(bufferLocations, false); + } + + @Override + public void noMoreSourceTasks(PlanFragmentId fragmentId) + { + queryStateMachine.updateOutputLocations(ImmutableSet.of(), true); + } + }); + scheduler.set(streamingScheduler); + executor.submit(streamingScheduler::schedule); + } + } + + public synchronized void cancelStage(StageId stageId) + { + try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) { + StreamingScheduler scheduler = this.scheduler.get(); + if (scheduler != null) { + scheduler.cancelStage(stageId); + } + } + } + + public synchronized void abort() + { + try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) { + StreamingScheduler scheduler = this.scheduler.get(); + if (scheduler != null) { + scheduler.abort(); + } + stages.values().forEach(SqlStageExecution::transitionToFinished); + } + } + + public BasicStageStats getBasicStageStats() + { + List stageStats = 
stages.values().stream() + .map(SqlStageExecution::getBasicStageStats) + .collect(toImmutableList()); + + return aggregateBasicStageStats(stageStats); + } + + public StageInfo getStageInfo() + { + Map stageInfos = stages.values().stream() + .map(SqlStageExecution::getStageInfo) + .collect(toImmutableMap(StageInfo::getStageId, identity())); + + return buildStageInfo(rootStageId, stageInfos); + } + + private StageInfo buildStageInfo(StageId stageId, Map stageInfos) { - SqlStageExecution rootStage = stages.get(rootStageId); - rootStage.addStateChangeListener(state -> { + StageInfo parent = stageInfos.get(stageId); + checkArgument(parent != null, "No stageInfo for %s", stageId); + List childStages = stageLineage.get(stageId).stream() + .map(childStageId -> buildStageInfo(childStageId, stageInfos)) + .collect(toImmutableList()); + if (childStages.isEmpty()) { + return parent; + } + return new StageInfo( + parent.getStageId(), + parent.getState(), + parent.getPlan(), + parent.getTypes(), + parent.getStageStats(), + parent.getTasks(), + childStages, + parent.getTables(), + parent.getFailureCause()); + } + + public long getUserMemoryReservation() + { + return stages.values().stream() + .mapToLong(SqlStageExecution::getUserMemoryReservation) + .sum(); + } + + public long getTotalMemoryReservation() + { + return stages.values().stream() + .mapToLong(SqlStageExecution::getTotalMemoryReservation) + .sum(); + } + + public Duration getTotalCpuTime() + { + long millis = stages.values().stream() + .mapToLong(stage -> stage.getTotalCpuTime().toMillis()) + .sum(); + return new Duration(millis, MILLISECONDS); + } + + private StreamingScheduler createStreamingScheduler(ResultsConsumer resultsConsumer) + { + Session session = queryStateMachine.getSession(); + Map partitioningCacheMap = new HashMap<>(); + Function partitioningCache = partitioningHandle -> + partitioningCacheMap.computeIfAbsent(partitioningHandle, handle -> nodePartitioningManager.getNodePartitioningMap(session, handle)); 
+ Map> bucketToPartitionMap = createBucketToPartitionMap(plan, partitioningCache); + Map outputBufferManagers = createOutputBufferManagers(stages.values(), bucketToPartitionMap); + ImmutableList.Builder executionsBuilder = ImmutableList.builder(); + ImmutableMap.Builder schedulersBuilder = ImmutableMap.builder(); + StreamingStageExecution rootExecution = createStreamingExecution( + plan, + outputBufferManagers, + bucketToPartitionMap, + partitioningCache, + resultsConsumer, + executionsBuilder, + schedulersBuilder); + List executions = executionsBuilder.build(); + ExecutionSchedule executionSchedule = executionPolicy.createExecutionSchedule(executions); + + rootExecution.addStateChangeListener(state -> { if (state == FINISHED) { queryStateMachine.transitionToFinishing(); } @@ -238,13 +425,21 @@ else if (state == CANCELED) { } }); - for (SqlStageExecution stage : stages.values()) { - stage.addStateChangeListener(state -> { + Set finishedStages = newConcurrentHashSet(); + for (StreamingStageExecution execution : executions) { + execution.addStateChangeListener(state -> { if (queryStateMachine.isDone()) { return; } + if (!state.canScheduleMoreTasks()) { + dynamicFilterService.stageCannotScheduleMoreTasks(execution.getStageId(), execution.getAllTasks().size()); + } if (state == FAILED) { - queryStateMachine.transitionToFailed(stage.getStageInfo().getFailureCause().toException()); + RuntimeException failureCause = execution.getFailureCause() + .map(ExecutionFailureInfo::toException) + .orElseGet(() -> new VerifyException(format("stage execution for stage %s is failed but failure cause is not present", execution.getStageId()))); + stages.get(execution.getFragment().getId()).transitionToFailed(failureCause); + queryStateMachine.transitionToFailed(failureCause); } else if (state == ABORTED) { // this should never happen, since abort can only be triggered in query clean up after the query is finished @@ -252,148 +447,137 @@ else if (state == ABORTED) { } else if 
(queryStateMachine.getQueryState() == QueryState.STARTING) { // if the stage has at least one task, we are running - if (stage.hasTasks()) { + if (!execution.getAllTasks().isEmpty()) { queryStateMachine.transitionToRunning(); } } + else if (state.isDone() && !state.isFailure()) { + finishedStages.add(execution.getStageId()); + // Once all remotely scheduled stages complete it should be safe to transition stage execution + // to the finished state as at this point no further task retries are expected + // This is needed to make explain analyze work that requires final stage info to be available before the + // explain analyze stage is finished + if (finishedStages.containsAll(remotelyScheduledStages)) { + stages.values().stream() + .filter(stage -> finishedStages.contains(stage.getStageId())) + .forEach(SqlStageExecution::transitionToFinished); + } + } }); } - // when query is done or any time a stage completes, attempt to transition query to "final query info ready" - queryStateMachine.addStateChangeListener(newState -> { - if (newState.isDone()) { - queryStateMachine.updateQueryInfo(Optional.ofNullable(getStageInfo())); - } - }); - for (SqlStageExecution stage : stages.values()) { - stage.addFinalStageInfoListener(status -> queryStateMachine.updateQueryInfo(Optional.ofNullable(getStageInfo()))); - } - } - - private static void updateQueryOutputLocations(QueryStateMachine queryStateMachine, OutputBufferId rootBufferId, Set tasks, boolean noMoreExchangeLocations) - { - Set bufferLocations = tasks.stream() - .map(task -> task.getTaskStatus().getSelf()) - .map(location -> uriBuilderFrom(location).appendPath("results").appendPath(rootBufferId.toString()).build()) - .collect(toImmutableSet()); - queryStateMachine.updateOutputLocations(bufferLocations, noMoreExchangeLocations); + return new StreamingScheduler( + queryStateMachine, + executionSchedule, + schedulersBuilder.build(), + schedulerStats, + executions); } - private List createStages( - ExchangeLocationsConsumer 
parent, - AtomicInteger nextStageId, + private StreamingStageExecution createStreamingExecution( StageExecutionPlan plan, - NodeScheduler nodeScheduler, - RemoteTaskFactory remoteTaskFactory, - Session session, - int splitBatchSize, + Map outputBufferManagers, + Map> bucketToPartitionMap, Function partitioningCache, - NodePartitioningManager nodePartitioningManager, - ExecutorService queryExecutor, - ScheduledExecutorService schedulerExecutor, - FailureDetector failureDetector, - NodeTaskMap nodeTaskMap, - ImmutableMap.Builder stageSchedulers, - ImmutableMap.Builder stageLinkages) + ResultsConsumer resultsConsumer, + ImmutableList.Builder executions, + ImmutableMap.Builder schedulers) { - ImmutableList.Builder stages = ImmutableList.builder(); - - StageId stageId = new StageId(queryStateMachine.getQueryId(), nextStageId.getAndIncrement()); - SqlStageExecution stage = createSqlStageExecution( - stageId, - plan.getFragment(), - plan.getTables(), - remoteTaskFactory, - session, - summarizeTaskInfo, - nodeTaskMap, - queryExecutor, + PlanFragment fragment = plan.getFragment(); + StreamingStageExecution execution = createStreamingStageExecution( + stages.get(fragment.getId()), + outputBufferManagers, + resultsConsumer, failureDetector, - dynamicFilterService, - schedulerStats); - stages.add(stage); - - // function to create child stages recursively by supplying the bucket partitioning (according to parent's partitioning) - Function, Set> createChildStages = bucketToPartition -> { - ImmutableSet.Builder childStagesBuilder = ImmutableSet.builder(); - for (StageExecutionPlan subStagePlan : plan.getSubStages()) { - List subTree = createStages( - stage::addExchangeLocations, - nextStageId, - subStagePlan.withBucketToPartition(bucketToPartition), - nodeScheduler, - remoteTaskFactory, - session, - splitBatchSize, - partitioningCache, - nodePartitioningManager, - queryExecutor, - schedulerExecutor, - failureDetector, - nodeTaskMap, - stageSchedulers, - stageLinkages); - 
stages.addAll(subTree); - - SqlStageExecution childStage = subTree.get(0); - childStagesBuilder.add(childStage); + executor, + bucketToPartitionMap.get(fragment.getId())); + executions.add(execution); + ImmutableList.Builder childExecutionsBuilder = ImmutableList.builder(); + for (StageExecutionPlan child : plan.getSubStages()) { + childExecutionsBuilder.add(createStreamingExecution( + child, + outputBufferManagers, + bucketToPartitionMap, + partitioningCache, + execution, + executions, + schedulers)); + } + ImmutableList childExecutions = childExecutionsBuilder.build(); + StageScheduler scheduler = createStageScheduler( + execution, + plan.getSplitSources(), + childExecutions, + partitioningCache); + schedulers.put(execution.getStageId(), scheduler); + + execution.addStateChangeListener(newState -> { + if (newState == FLUSHING || newState.isDone()) { + childExecutions.forEach(StreamingStageExecution::cancel); } - return childStagesBuilder.build(); - }; + }); - Set childStages; - PartitioningHandle partitioningHandle = plan.getFragment().getPartitioning(); + return execution; + } + + private StageScheduler createStageScheduler( + StreamingStageExecution stageExecution, + Map splitSources, + List childStages, + Function partitioningCache) + { + Session session = queryStateMachine.getSession(); + PlanFragment fragment = stageExecution.getFragment(); + PartitioningHandle partitioningHandle = fragment.getPartitioning(); if (partitioningHandle.equals(SOURCE_DISTRIBUTION)) { // nodes are selected dynamically based on the constraints of the splits and the system load - Entry entry = Iterables.getOnlyElement(plan.getSplitSources().entrySet()); + Entry entry = Iterables.getOnlyElement(splitSources.entrySet()); PlanNodeId planNodeId = entry.getKey(); SplitSource splitSource = entry.getValue(); Optional catalogName = Optional.of(splitSource.getCatalogName()) .filter(catalog -> !isInternalSystemConnector(catalog)); NodeSelector nodeSelector = 
nodeScheduler.createNodeSelector(session, catalogName); - SplitPlacementPolicy placementPolicy = new DynamicSplitPlacementPolicy(nodeSelector, stage::getAllTasks); + SplitPlacementPolicy placementPolicy = new DynamicSplitPlacementPolicy(nodeSelector, stageExecution::getAllTasks); - checkArgument(!plan.getFragment().getStageExecutionDescriptor().isStageGroupedExecution()); + checkArgument(!fragment.getStageExecutionDescriptor().isStageGroupedExecution()); - childStages = createChildStages.apply(Optional.of(new int[1])); - stageSchedulers.put(stageId, newSourcePartitionedSchedulerAsStageScheduler( - stage, + return newSourcePartitionedSchedulerAsStageScheduler( + stageExecution, planNodeId, splitSource, placementPolicy, splitBatchSize, dynamicFilterService, - () -> childStages.stream().anyMatch(SqlStageExecution::isAnyTaskBlocked))); + () -> childStages.stream().anyMatch(StreamingStageExecution::isAnyTaskBlocked)); } else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) { - childStages = createChildStages.apply(Optional.of(new int[1])); Supplier> sourceTasksProvider = () -> childStages.stream() - .map(SqlStageExecution::getTaskStatuses) + .map(StreamingStageExecution::getTaskStatuses) .flatMap(List::stream) .collect(toImmutableList()); - Supplier> writerTasksProvider = stage::getTaskStatuses; + Supplier> writerTasksProvider = stageExecution::getTaskStatuses; ScaledWriterScheduler scheduler = new ScaledWriterScheduler( - stage, + stageExecution, sourceTasksProvider, writerTasksProvider, nodeScheduler.createNodeSelector(session, Optional.empty()), schedulerExecutor, getWriterMinSize(session)); - whenAllStages(childStages, StageState::isDone) + + whenAllStages(childStages, StreamingStageExecution.State::isDone) .addListener(scheduler::finish, directExecutor()); - stageSchedulers.put(stageId, scheduler); + + return scheduler; } else { - Optional bucketToPartition; - Map splitSources = plan.getSplitSources(); if (!splitSources.isEmpty()) { // contains local 
source - List schedulingOrder = plan.getFragment().getPartitionedSources(); + List schedulingOrder = fragment.getPartitionedSources(); Optional catalogName = partitioningHandle.getConnectorId(); checkArgument(catalogName.isPresent(), "No connector ID for partitioning handle: %s", partitioningHandle); List connectorPartitionHandles; - boolean groupedExecutionForStage = plan.getFragment().getStageExecutionDescriptor().isStageGroupedExecution(); + boolean groupedExecutionForStage = fragment.getStageExecutionDescriptor().isStageGroupedExecution(); if (groupedExecutionForStage) { connectorPartitionHandles = nodePartitioningManager.listPartitionHandles(session, partitioningHandle); checkState(!ImmutableList.of(NOT_PARTITIONED).equals(connectorPartitionHandles)); @@ -404,9 +588,9 @@ else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) { BucketNodeMap bucketNodeMap; List stageNodeList; - if (plan.getFragment().getRemoteSourceNodes().stream().allMatch(node -> node.getExchangeType() == REPLICATE)) { + if (fragment.getRemoteSourceNodes().stream().allMatch(node -> node.getExchangeType() == REPLICATE)) { // no remote source - boolean dynamicLifespanSchedule = plan.getFragment().getStageExecutionDescriptor().isDynamicLifespanSchedule(); + boolean dynamicLifespanSchedule = fragment.getStageExecutionDescriptor().isDynamicLifespanSchedule(); bucketNodeMap = nodePartitioningManager.getBucketNodeMap(session, partitioningHandle, dynamicLifespanSchedule); // verify execution is consistent with planner's decision on dynamic lifespan schedule @@ -414,26 +598,24 @@ else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) { stageNodeList = new ArrayList<>(nodeScheduler.createNodeSelector(session, catalogName).allNodes()); Collections.shuffle(stageNodeList); - bucketToPartition = Optional.empty(); } else { // cannot use dynamic lifespan schedule - verify(!plan.getFragment().getStageExecutionDescriptor().isDynamicLifespanSchedule()); + 
verify(!fragment.getStageExecutionDescriptor().isDynamicLifespanSchedule()); // remote source requires nodePartitionMap - NodePartitionMap nodePartitionMap = partitioningCache.apply(plan.getFragment().getPartitioning()); + NodePartitionMap nodePartitionMap = partitioningCache.apply(partitioningHandle); if (groupedExecutionForStage) { checkState(connectorPartitionHandles.size() == nodePartitionMap.getBucketToPartition().length); } stageNodeList = nodePartitionMap.getPartitionToNode(); bucketNodeMap = nodePartitionMap.asBucketNodeMap(); - bucketToPartition = Optional.of(nodePartitionMap.getBucketToPartition()); } - stageSchedulers.put(stageId, new FixedSourcePartitionedScheduler( - stage, + return new FixedSourcePartitionedScheduler( + stageExecution, splitSources, - plan.getFragment().getStageExecutionDescriptor(), + fragment.getStageExecutionDescriptor(), schedulingOrder, stageNodeList, bucketNodeMap, @@ -441,287 +623,234 @@ else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) { getConcurrentLifespansPerNode(session), nodeScheduler.createNodeSelector(session, catalogName), connectorPartitionHandles, - dynamicFilterService)); + dynamicFilterService); } else { // all sources are remote - NodePartitionMap nodePartitionMap = partitioningCache.apply(plan.getFragment().getPartitioning()); + NodePartitionMap nodePartitionMap = partitioningCache.apply(partitioningHandle); List partitionToNode = nodePartitionMap.getPartitionToNode(); // todo this should asynchronously wait a standard timeout period before failing checkCondition(!partitionToNode.isEmpty(), NO_NODES_AVAILABLE, "No worker nodes available"); - stageSchedulers.put(stageId, new FixedCountScheduler(stage, partitionToNode)); - bucketToPartition = Optional.of(nodePartitionMap.getBucketToPartition()); + return new FixedCountScheduler(stageExecution, partitionToNode); } - childStages = createChildStages.apply(bucketToPartition); } - - stage.addStateChangeListener(newState -> { - if (newState == FLUSHING || 
newState.isDone()) { - childStages.forEach(SqlStageExecution::cancel); - } - }); - - stageLinkages.put(stageId, new StageLinkage(plan.getFragment().getId(), parent, childStages)); - - return stages.build(); } - public BasicStageStats getBasicStageStats() + private static ListenableFuture whenAllStages(Collection stages, Predicate predicate) { - List stageStats = stages.values().stream() - .map(SqlStageExecution::getBasicStageStats) - .collect(toImmutableList()); - - return aggregateBasicStageStats(stageStats); - } - - public StageInfo getStageInfo() - { - Map stageInfos = stages.values().stream() - .map(SqlStageExecution::getStageInfo) - .collect(toImmutableMap(StageInfo::getStageId, identity())); - - return buildStageInfo(rootStageId, stageInfos); - } + checkArgument(!stages.isEmpty(), "stages is empty"); + Set stageIds = stages.stream() + .map(StreamingStageExecution::getStageId) + .collect(toCollection(Sets::newConcurrentHashSet)); + SettableFuture future = SettableFuture.create(); - private StageInfo buildStageInfo(StageId stageId, Map stageInfos) - { - StageInfo parent = stageInfos.get(stageId); - checkArgument(parent != null, "No stageInfo for %s", parent); - List childStages = stageLinkages.get(stageId).getChildStageIds().stream() - .map(childStageId -> buildStageInfo(childStageId, stageInfos)) - .collect(toImmutableList()); - if (childStages.isEmpty()) { - return parent; + for (StreamingStageExecution stage : stages) { + stage.addStateChangeListener(state -> { + if (predicate.test(state) && stageIds.remove(stage.getStageId()) && stageIds.isEmpty()) { + future.set(null); + } + }); } - return new StageInfo( - parent.getStageId(), - parent.getState(), - parent.getPlan(), - parent.getTypes(), - parent.getStageStats(), - parent.getTasks(), - childStages, - parent.getTables(), - parent.getFailureCause()); - } - - public long getUserMemoryReservation() - { - return stages.values().stream() - .mapToLong(SqlStageExecution::getUserMemoryReservation) - .sum(); - } - - 
public long getTotalMemoryReservation() - { - return stages.values().stream() - .mapToLong(SqlStageExecution::getTotalMemoryReservation) - .sum(); - } - public Duration getTotalCpuTime() - { - long millis = stages.values().stream() - .mapToLong(stage -> stage.getTotalCpuTime().toMillis()) - .sum(); - return new Duration(millis, MILLISECONDS); - } - - public void start() - { - if (started.compareAndSet(false, true)) { - executor.submit(this::schedule); - } + return future; } - private void schedule() + private static Map createOutputBufferManagers( + Collection stageExecutions, + Map> bucketToPartitionMap) { - try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) { - Set completedStages = new HashSet<>(); - ExecutionSchedule executionSchedule = executionPolicy.createExecutionSchedule(stages.values()); - while (!executionSchedule.isFinished()) { - List> blockedStages = new ArrayList<>(); - for (SqlStageExecution stage : executionSchedule.getStagesToSchedule()) { - stage.beginScheduling(); - - // perform some scheduling work - ScheduleResult result = stageSchedulers.get(stage.getStageId()) - .schedule(); - - // modify parent and children based on the results of the scheduling - if (result.isFinished()) { - stage.schedulingComplete(); - } - else if (!result.getBlocked().isDone()) { - blockedStages.add(result.getBlocked()); - } - stageLinkages.get(stage.getStageId()) - .processScheduleResults(stage.getState(), result.getNewTasks()); - schedulerStats.getSplitsScheduledPerIteration().add(result.getSplitsScheduled()); - if (result.getBlockedReason().isPresent()) { - switch (result.getBlockedReason().get()) { - case WRITER_SCALING: - // no-op - break; - case WAITING_FOR_SOURCE: - schedulerStats.getWaitingForSource().update(1); - break; - case SPLIT_QUEUES_FULL: - schedulerStats.getSplitQueuesFull().update(1); - break; - case MIXED_SPLIT_QUEUES_FULL_AND_WAITING_FOR_SOURCE: - case NO_ACTIVE_DRIVER_GROUP: - break; - default: - throw new 
UnsupportedOperationException("Unknown blocked reason: " + result.getBlockedReason().get()); - } - } - } - - // make sure to update stage linkage at least once per loop to catch async state changes (e.g., partial cancel) - for (SqlStageExecution stage : stages.values()) { - if (!completedStages.contains(stage.getStageId()) && stage.getState().isDone()) { - stageLinkages.get(stage.getStageId()) - .processScheduleResults(stage.getState(), ImmutableSet.of()); - completedStages.add(stage.getStageId()); - } - } - - // wait for a state change and then schedule again - if (!blockedStages.isEmpty()) { - try (TimeStat.BlockTimer timer = schedulerStats.getSleepTime().time()) { - tryGetFutureValue(whenAnyComplete(blockedStages), 1, SECONDS); - } - for (ListenableFuture blockedStage : blockedStages) { - blockedStage.cancel(true); - } - } + ImmutableMap.Builder result = ImmutableMap.builder(); + for (SqlStageExecution stageExecution : stageExecutions) { + PlanFragmentId fragmentId = stageExecution.getFragment().getId(); + PartitioningHandle partitioningHandle = stageExecution.getFragment().getPartitioningScheme().getPartitioning().getHandle(); + OutputBufferManager outputBufferManager; + if (partitioningHandle.equals(FIXED_BROADCAST_DISTRIBUTION)) { + outputBufferManager = new BroadcastOutputBufferManager(); } - - for (SqlStageExecution stage : stages.values()) { - StageState state = stage.getState(); - if (state != SCHEDULED && state != RUNNING && state != FLUSHING && !state.isDone()) { - throw new TrinoException(GENERIC_INTERNAL_ERROR, format("Scheduling is complete, but stage %s is in state %s", stage.getStageId(), state)); - } - } - } - catch (Throwable t) { - queryStateMachine.transitionToFailed(t); - throw t; - } - finally { - RuntimeException closeError = new RuntimeException(); - for (StageScheduler scheduler : stageSchedulers.values()) { - try { - scheduler.close(); - } - catch (Throwable t) { - queryStateMachine.transitionToFailed(t); - // Self-suppression not 
permitted - if (closeError != t) { - closeError.addSuppressed(t); - } - } + else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) { + outputBufferManager = new ScaledOutputBufferManager(); } - if (closeError.getSuppressed().length > 0) { - throw closeError; + else { + Optional bucketToPartition = bucketToPartitionMap.get(fragmentId); + checkArgument(bucketToPartition.isPresent(), "bucketToPartition is expected to be present for fragment: %s", fragmentId); + int partitionCount = Ints.max(bucketToPartition.get()) + 1; + outputBufferManager = new PartitionedOutputBufferManager(partitioningHandle, partitionCount); } + result.put(fragmentId, outputBufferManager); } + return result.build(); } - public void cancelStage(StageId stageId) + private static Map> createBucketToPartitionMap( + StageExecutionPlan plan, + Function partitioningCache) { - try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) { - SqlStageExecution sqlStageExecution = stages.get(stageId); - SqlStageExecution stage = requireNonNull(sqlStageExecution, () -> format("Stage '%s' does not exist", stageId)); - stage.cancel(); + ImmutableMap.Builder> result = ImmutableMap.builder(); + // root fragment always has a single consumer + result.put(plan.getFragment().getId(), Optional.of(new int[] {0})); + Queue queue = new ArrayDeque<>(); + queue.add(plan); + while (!queue.isEmpty()) { + StageExecutionPlan executionPlan = queue.poll(); + PlanFragment fragment = executionPlan.getFragment(); + Optional bucketToPartition = getBucketToPartition(fragment.getPartitioning(), partitioningCache, fragment.getRoot(), fragment.getRemoteSourceNodes()); + for (StageExecutionPlan child : executionPlan.getSubStages()) { + result.put(child.getFragment().getId(), bucketToPartition); + queue.add(child); + } } + return result.build(); } - public void abort() + private static Optional getBucketToPartition( + PartitioningHandle partitioningHandle, + Function partitioningCache, + PlanNode 
fragmentRoot, + List remoteSourceNodes) { - try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) { - stages.values().forEach(SqlStageExecution::abort); + if (partitioningHandle.equals(SOURCE_DISTRIBUTION) || partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) { + return Optional.of(new int[1]); + } + else if (searchFrom(fragmentRoot).where(node -> node instanceof TableScanNode).findFirst().isPresent()) { + if (remoteSourceNodes.stream().allMatch(node -> node.getExchangeType() == REPLICATE)) { + return Optional.empty(); + } + else { + // remote source requires nodePartitionMap + NodePartitionMap nodePartitionMap = partitioningCache.apply(partitioningHandle); + return Optional.of(nodePartitionMap.getBucketToPartition()); + } + } + else { + NodePartitionMap nodePartitionMap = partitioningCache.apply(partitioningHandle); + List partitionToNode = nodePartitionMap.getPartitionToNode(); + // todo this should asynchronously wait a standard timeout period before failing + checkCondition(!partitionToNode.isEmpty(), NO_NODES_AVAILABLE, "No worker nodes available"); + return Optional.of(nodePartitionMap.getBucketToPartition()); } } - private static ListenableFuture whenAllStages(Collection stages, Predicate predicate) + private static class StreamingScheduler { - checkArgument(!stages.isEmpty(), "stages is empty"); - Set stageIds = stages.stream() - .map(SqlStageExecution::getStageId) - .collect(toCollection(Sets::newConcurrentHashSet)); - SettableFuture future = SettableFuture.create(); - - for (SqlStageExecution stage : stages) { - stage.addStateChangeListener(state -> { - if (predicate.test(state) && stageIds.remove(stage.getStageId()) && stageIds.isEmpty()) { - future.set(null); - } - }); + private final QueryStateMachine queryStateMachine; + private final ExecutionSchedule executionSchedule; + private final Map stageSchedulers; + private final SplitSchedulerStats schedulerStats; + private final List stageExecutions; + + private 
final AtomicBoolean started = new AtomicBoolean(); + + private StreamingScheduler( + QueryStateMachine queryStateMachine, + ExecutionSchedule executionSchedule, + Map stageSchedulers, + SplitSchedulerStats schedulerStats, + List stageExecutions) + { + this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null"); + this.executionSchedule = requireNonNull(executionSchedule, "executionSchedule is null"); + this.stageSchedulers = ImmutableMap.copyOf(requireNonNull(stageSchedulers, "stageSchedulers is null")); + this.schedulerStats = requireNonNull(schedulerStats, "schedulerStats is null"); + this.stageExecutions = ImmutableList.copyOf(requireNonNull(stageExecutions, "stageExecutions is null")); } - return future; - } + public void schedule() + { + checkState(started.compareAndSet(false, true), "already started"); - private interface ExchangeLocationsConsumer - { - void addExchangeLocations(PlanFragmentId fragmentId, Set tasks, boolean noMoreExchangeLocations); - } + try (SetThreadName ignored = new SetThreadName("Query-%s", queryStateMachine.getQueryId())) { + while (!executionSchedule.isFinished()) { + List> blockedStages = new ArrayList<>(); + for (StreamingStageExecution stage : executionSchedule.getStagesToSchedule()) { + stage.beginScheduling(); - private static class StageLinkage - { - private final PlanFragmentId currentStageFragmentId; - private final ExchangeLocationsConsumer parent; - private final Set childOutputBufferManagers; - private final Set childStageIds; + // perform some scheduling work + ScheduleResult result = stageSchedulers.get(stage.getStageId()) + .schedule(); - public StageLinkage(PlanFragmentId fragmentId, ExchangeLocationsConsumer parent, Set children) - { - this.currentStageFragmentId = fragmentId; - this.parent = parent; - this.childOutputBufferManagers = children.stream() - .map(childStage -> { - PartitioningHandle partitioningHandle = 
childStage.getFragment().getPartitioningScheme().getPartitioning().getHandle(); - if (partitioningHandle.equals(FIXED_BROADCAST_DISTRIBUTION)) { - return new BroadcastOutputBufferManager(childStage::setOutputBuffers); + // modify parent and children based on the results of the scheduling + if (result.isFinished()) { + stage.schedulingComplete(); } - else if (partitioningHandle.equals(SCALED_WRITER_DISTRIBUTION)) { - return new ScaledOutputBufferManager(childStage::setOutputBuffers); + else if (!result.getBlocked().isDone()) { + blockedStages.add(result.getBlocked()); } - else { - int partitionCount = Ints.max(childStage.getFragment().getPartitioningScheme().getBucketToPartition().get()) + 1; - return new PartitionedOutputBufferManager(partitioningHandle, partitionCount, childStage::setOutputBuffers); + schedulerStats.getSplitsScheduledPerIteration().add(result.getSplitsScheduled()); + if (result.getBlockedReason().isPresent()) { + switch (result.getBlockedReason().get()) { + case WRITER_SCALING: + // no-op + break; + case WAITING_FOR_SOURCE: + schedulerStats.getWaitingForSource().update(1); + break; + case SPLIT_QUEUES_FULL: + schedulerStats.getSplitQueuesFull().update(1); + break; + case MIXED_SPLIT_QUEUES_FULL_AND_WAITING_FOR_SOURCE: + case NO_ACTIVE_DRIVER_GROUP: + break; + default: + throw new UnsupportedOperationException("Unknown blocked reason: " + result.getBlockedReason().get()); + } } - }) - .collect(toImmutableSet()); + } - this.childStageIds = children.stream() - .map(SqlStageExecution::getStageId) - .collect(toImmutableSet()); - } + // wait for a state change and then schedule again + if (!blockedStages.isEmpty()) { + try (TimeStat.BlockTimer timer = schedulerStats.getSleepTime().time()) { + tryGetFutureValue(whenAnyComplete(blockedStages), 1, SECONDS); + } + for (ListenableFuture blockedStage : blockedStages) { + blockedStage.cancel(true); + } + } + } - public Set getChildStageIds() - { - return childStageIds; + for (StreamingStageExecution stage : 
stageExecutions) { + StreamingStageExecution.State state = stage.getState(); + if (state != SCHEDULED && state != RUNNING && state != FLUSHING && !state.isDone()) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, format("Scheduling is complete, but stage %s is in state %s", stage.getStageId(), state)); + } + } + } + catch (Throwable t) { + queryStateMachine.transitionToFailed(t); + throw t; + } + finally { + RuntimeException closeError = new RuntimeException(); + for (StageScheduler scheduler : stageSchedulers.values()) { + try { + scheduler.close(); + } + catch (Throwable t) { + queryStateMachine.transitionToFailed(t); + // Self-suppression not permitted + if (closeError != t) { + closeError.addSuppressed(t); + } + } + } + if (closeError.getSuppressed().length > 0) { + throw closeError; + } + } } - public void processScheduleResults(StageState newState, Set newTasks) + public void cancelStage(StageId stageId) { - boolean noMoreTasks = !newState.canScheduleMoreTasks(); - // Add an exchange location to the parent stage for each new task - parent.addExchangeLocations(currentStageFragmentId, newTasks, noMoreTasks); - - if (!childOutputBufferManagers.isEmpty()) { - // Add an output buffer to the child stages for each new task - List newOutputBuffers = newTasks.stream() - .map(task -> new OutputBufferId(task.getTaskId().getId())) - .collect(toImmutableList()); - for (OutputBufferManager child : childOutputBufferManagers) { - child.addOutputBuffers(newOutputBuffers, noMoreTasks); + for (StreamingStageExecution execution : stageExecutions) { + if (execution.getStageId().equals(stageId)) { + execution.cancel(); + break; } } } + + public void abort() + { + stageExecutions.forEach(StreamingStageExecution::abort); + } } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/StreamingStageExecution.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/StreamingStageExecution.java new file mode 100644 index 000000000000..81defd120850 --- 
/dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/StreamingStageExecution.java @@ -0,0 +1,700 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler; + +import com.google.common.collect.HashMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Multimap; +import com.google.common.collect.Sets; +import io.airlift.log.Logger; +import io.trino.execution.ExecutionFailureInfo; +import io.trino.execution.Lifespan; +import io.trino.execution.RemoteTask; +import io.trino.execution.SqlStageExecution; +import io.trino.execution.StageId; +import io.trino.execution.StateMachine; +import io.trino.execution.StateMachine.StateChangeListener; +import io.trino.execution.TaskId; +import io.trino.execution.TaskState; +import io.trino.execution.TaskStatus; +import io.trino.execution.buffer.OutputBuffers; +import io.trino.execution.buffer.OutputBuffers.OutputBufferId; +import io.trino.failuredetector.FailureDetector; +import io.trino.metadata.InternalNode; +import io.trino.metadata.Split; +import io.trino.spi.TrinoException; +import io.trino.split.RemoteSplit; +import io.trino.sql.planner.PlanFragment; +import io.trino.sql.planner.plan.PlanFragmentId; +import io.trino.sql.planner.plan.PlanNodeId; +import 
io.trino.sql.planner.plan.RemoteSourceNode; +import io.trino.util.Failures; +import org.joda.time.DateTime; + +import javax.annotation.concurrent.GuardedBy; + +import java.net.URI; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.stream.Stream; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Sets.newConcurrentHashSet; +import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; +import static io.trino.execution.scheduler.StreamingStageExecution.State.ABORTED; +import static io.trino.execution.scheduler.StreamingStageExecution.State.CANCELED; +import static io.trino.execution.scheduler.StreamingStageExecution.State.FAILED; +import static io.trino.execution.scheduler.StreamingStageExecution.State.FINISHED; +import static io.trino.execution.scheduler.StreamingStageExecution.State.FLUSHING; +import static io.trino.execution.scheduler.StreamingStageExecution.State.PLANNED; +import static io.trino.execution.scheduler.StreamingStageExecution.State.RUNNING; +import static io.trino.execution.scheduler.StreamingStageExecution.State.SCHEDULED; +import static io.trino.execution.scheduler.StreamingStageExecution.State.SCHEDULING; +import static io.trino.execution.scheduler.StreamingStageExecution.State.SCHEDULING_SPLITS; +import static io.trino.failuredetector.FailureDetector.State.GONE; +import static io.trino.operator.ExchangeOperator.REMOTE_CONNECTOR_ID; +import static 
io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static io.trino.spi.StandardErrorCode.REMOTE_HOST_GONE; +import static java.util.Objects.requireNonNull; + +public class StreamingStageExecution + implements ResultsConsumer +{ + private static final Logger log = Logger.get(StreamingStageExecution.class); + + private final StreamingStageStateMachine stateMachine; + private final SqlStageExecution stageExecution; + private final Map outputBufferManagers; + private final ResultsConsumer parent; + private final FailureDetector failureDetector; + private final Executor executor; + private final Optional bucketToPartition; + private final Map exchangeSources; + + private final Map tasks = new ConcurrentHashMap<>(); + + // current stage task tracking + @GuardedBy("this") + private final Set allTasks = newConcurrentHashSet(); + @GuardedBy("this") + private final Set finishedTasks = newConcurrentHashSet(); + @GuardedBy("this") + private final Set flushingTasks = newConcurrentHashSet(); + + // source task tracking + @GuardedBy("this") + private final Multimap sourceTasks = HashMultimap.create(); + @GuardedBy("this") + private final Set completeSourceFragments = newConcurrentHashSet(); + @GuardedBy("this") + private final Set completeSources = newConcurrentHashSet(); + + // lifespan tracking + private final Set completedDriverGroups = new HashSet<>(); + private final ListenerManager> completedLifespansChangeListeners = new ListenerManager<>(); + + public static StreamingStageExecution createStreamingStageExecution( + SqlStageExecution stageExecution, + Map outputBufferManagers, + ResultsConsumer parent, + FailureDetector failureDetector, + Executor executor, + Optional bucketToPartition) + { + StreamingStageStateMachine stateMachine = new StreamingStageStateMachine(stageExecution.getStageId(), executor); + ImmutableMap.Builder exchangeSources = ImmutableMap.builder(); + for (RemoteSourceNode remoteSourceNode : stageExecution.getFragment().getRemoteSourceNodes()) 
{ + for (PlanFragmentId planFragmentId : remoteSourceNode.getSourceFragmentIds()) { + exchangeSources.put(planFragmentId, remoteSourceNode); + } + } + StreamingStageExecution execution = new StreamingStageExecution( + stateMachine, + stageExecution, + outputBufferManagers, + parent, + failureDetector, + executor, + bucketToPartition, + exchangeSources.build()); + execution.initialize(); + return execution; + } + + private StreamingStageExecution( + StreamingStageStateMachine stateMachine, + SqlStageExecution stageExecution, + Map outputBufferManagers, + ResultsConsumer parent, + FailureDetector failureDetector, + Executor executor, + Optional bucketToPartition, + Map exchangeSources) + { + this.stateMachine = requireNonNull(stateMachine, "stateMachine is null"); + this.stageExecution = requireNonNull(stageExecution, "stageExecution is null"); + this.outputBufferManagers = ImmutableMap.copyOf(requireNonNull(outputBufferManagers, "outputBufferManagers is null")); + this.parent = requireNonNull(parent, "parent is null"); + this.failureDetector = requireNonNull(failureDetector, "failureDetector is null"); + this.executor = requireNonNull(executor, "executor is null"); + this.bucketToPartition = requireNonNull(bucketToPartition, "bucketToPartition is null"); + this.exchangeSources = ImmutableMap.copyOf(requireNonNull(exchangeSources, "exchangeSources is null")); + } + + private void initialize() + { + stateMachine.addStateChangeListener(state -> { + if (!state.canScheduleMoreTasks()) { + // notify parent stage + parent.noMoreSourceTasks(stageExecution.getFragment().getId()); + + // update output buffers + for (PlanFragmentId sourceFragment : exchangeSources.keySet()) { + OutputBufferManager outputBufferManager = outputBufferManagers.get(sourceFragment); + outputBufferManager.noMoreBuffers(); + for (RemoteTask sourceTask : sourceTasks.get(stageExecution.getFragment().getId())) { + sourceTask.setOutputBuffers(outputBufferManager.getOutputBuffers()); + } + } + } + }); + } 
+
+    /**
+     * Returns the current state of this stage execution.
+     */
+    public State getState()
+    {
+        return stateMachine.getState();
+    }
+
+    /**
+     * Listener is always notified asynchronously using a dedicated notification thread pool so, care should
+     * be taken to avoid leaking {@code this} when adding a listener in a constructor.
+     */
+    public void addStateChangeListener(StateChangeListener<State> stateChangeListener)
+    {
+        stateMachine.addStateChangeListener(stateChangeListener);
+    }
+
+    /**
+     * Registers a consumer for newly completed driver groups. Must be called before the first
+     * notification is dispatched; later registrations are rejected by {@link ListenerManager}.
+     */
+    public void addCompletedDriverGroupsChangedListener(Consumer<Set<Lifespan>> newlyCompletedDriverGroupConsumer)
+    {
+        completedLifespansChangeListeners.addListener(newlyCompletedDriverGroupConsumer);
+    }
+
+    /**
+     * Moves the stage from PLANNED to SCHEDULING; has no effect in any other state.
+     */
+    public synchronized void beginScheduling()
+    {
+        stateMachine.transitionToScheduling();
+    }
+
+    /**
+     * Moves the stage to SCHEDULING_SPLITS if it is currently PLANNED or SCHEDULING.
+     */
+    public synchronized void transitionToSchedulingSplits()
+    {
+        stateMachine.transitionToSchedulingSplits();
+    }
+
+    /**
+     * Marks scheduling as complete. If the transition to SCHEDULED succeeds, the stage is advanced
+     * further to FLUSHING or FINISHED when the task bookkeeping already allows it, and every
+     * partitioned source of the fragment is closed on all tasks.
+     */
+    public synchronized void schedulingComplete()
+    {
+        if (!stateMachine.transitionToScheduled()) {
+            return;
+        }
+
+        if (isFlushing()) {
+            stateMachine.transitionToFlushing();
+        }
+        if (finishedTasks.containsAll(allTasks)) {
+            stateMachine.transitionToFinished();
+        }
+
+        for (PlanNodeId partitionedSource : stageExecution.getFragment().getPartitionedSources()) {
+            schedulingComplete(partitionedSource);
+        }
+    }
+
+    private synchronized boolean isFlushing()
+    {
+        // to transition to flushing, there must be at least one flushing task, and all others must be flushing or finished.
+        return !flushingTasks.isEmpty()
+                && allTasks.stream().allMatch(taskId -> finishedTasks.contains(taskId) || flushingTasks.contains(taskId));
+    }
+
+    /**
+     * Tells every existing task that the given partitioned source has no more splits and records
+     * the source as complete so that tasks scheduled later are told the same.
+     */
+    public synchronized void schedulingComplete(PlanNodeId partitionedSource)
+    {
+        for (RemoteTask task : getAllTasks()) {
+            task.noMoreSplits(partitionedSource);
+        }
+        completeSources.add(partitionedSource);
+    }
+
+    /**
+     * Requests the CANCELED state and cancels every task created so far.
+     */
+    public synchronized void cancel()
+    {
+        stateMachine.transitionToCanceled();
+        getAllTasks().forEach(RemoteTask::cancel);
+    }
+
+    /**
+     * Requests the ABORTED state and aborts every task created so far.
+     */
+    public synchronized void abort()
+    {
+        stateMachine.transitionToAborted();
+        getAllTasks().forEach(RemoteTask::abort);
+    }
+
+    /**
+     * Creates and starts the task for {@code partition} on {@code node}.
+     * <p>
+     * The new task receives exchange splits for every known source task that has not finished,
+     * is told about all sources that are already complete, is announced to the parent stage, and
+     * gets an output buffer registered with every source fragment.
+     *
+     * @return the started task, or empty if the stage is already done or no task was created
+     * @throws IllegalArgumentException if a task for {@code partition} already exists
+     */
+    // NOTE(review): generic type arguments were missing from this text and are restored from usage - confirm against the real file
+    public synchronized Optional<RemoteTask> scheduleTask(
+            InternalNode node,
+            int partition,
+            Multimap<PlanNodeId, Split> initialSplits,
+            Multimap<PlanNodeId, Lifespan> noMoreSplitsForLifespan)
+    {
+        if (stateMachine.getState().isDone()) {
+            return Optional.empty();
+        }
+
+        checkArgument(!tasks.containsKey(partition), "A task for partition %s already exists", partition);
+
+        OutputBuffers outputBuffers = outputBufferManagers.get(stageExecution.getFragment().getId()).getOutputBuffers();
+
+        Optional<RemoteTask> optionalTask = stageExecution.createTask(
+                node,
+                partition,
+                bucketToPartition,
+                outputBuffers,
+                initialSplits,
+                ImmutableMultimap.of(),
+                ImmutableSet.of());
+
+        if (optionalTask.isEmpty()) {
+            return Optional.empty();
+        }
+
+        RemoteTask task = optionalTask.get();
+
+        tasks.put(partition, task);
+
+        ImmutableMultimap.Builder<PlanNodeId, Split> exchangeSplits = ImmutableMultimap.builder();
+        sourceTasks.forEach((fragmentId, sourceTask) -> {
+            TaskStatus status = sourceTask.getTaskStatus();
+            if (status.getState() != TaskState.FINISHED) {
+                PlanNodeId planNodeId = exchangeSources.get(fragmentId).getId();
+                exchangeSplits.put(planNodeId, createExchangeSplit(task, sourceTask));
+            }
+        });
+
+        allTasks.add(task.getTaskId());
+
+        task.addSplits(exchangeSplits.build());
+        noMoreSplitsForLifespan.forEach(task::noMoreSplits);
+        completeSources.forEach(task::noMoreSplits);
+
+        task.addStateChangeListener(this::updateTaskStatus);
+        task.addStateChangeListener(this::updateCompletedDriverGroups);
+
+        task.start();
+
+        // update parent stage
+        parent.addSourceTask(stageExecution.getFragment().getId(), task);
+
+        // update output buffers
+        OutputBufferId outputBufferId = new OutputBufferId(task.getTaskId().getId());
+        for (PlanFragmentId sourceFragment : exchangeSources.keySet()) {
+            OutputBufferManager outputBufferManager = outputBufferManagers.get(sourceFragment);
+            outputBufferManager.addOutputBuffer(outputBufferId);
+            // sourceTasks is keyed by the source fragment id (see addSourceTask); looking it up with
+            // this stage's own fragment id never matches, so the new buffer would not reach the
+            // source tasks that are already running
+            for (RemoteTask sourceTask : sourceTasks.get(sourceFragment)) {
+                sourceTask.setOutputBuffers(outputBufferManager.getOutputBuffers());
+            }
+        }
+
+        return Optional.of(task);
+    }
+
+    /**
+     * Task state listener: records flushing/finished tasks, fails the stage when a task failed or
+     * was aborted, and advances the stage to RUNNING, FLUSHING or FINISHED where applicable.
+     */
+    private synchronized void updateTaskStatus(TaskStatus taskStatus)
+    {
+        State stageState = stateMachine.getState();
+        if (stageState.isDone()) {
+            return;
+        }
+
+        TaskState taskState = taskStatus.getState();
+
+        switch (taskState) {
+            case FAILED:
+                RuntimeException failure = taskStatus.getFailures().stream()
+                        .findFirst()
+                        .map(this::rewriteTransportFailure)
+                        .map(ExecutionFailureInfo::toException)
+                        .orElse(new TrinoException(GENERIC_INTERNAL_ERROR, "A task failed for an unknown reason"));
+                stateMachine.transitionToFailed(failure);
+                break;
+            case ABORTED:
+                // A task should only be in the aborted state if the STAGE is done (ABORTED or FAILED)
+                stateMachine.transitionToFailed(new TrinoException(GENERIC_INTERNAL_ERROR, "A task is in the ABORTED state but stage is " + stageState));
+                break;
+            case FLUSHING:
+                flushingTasks.add(taskStatus.getTaskId());
+                break;
+            case FINISHED:
+                finishedTasks.add(taskStatus.getTaskId());
+                flushingTasks.remove(taskStatus.getTaskId());
+                break;
+            default:
+                // remaining task states need no bookkeeping here
+        }
+
+        // stageState was read before the switch; the transitions below are no-ops once the stage is done
+        if (stageState == SCHEDULED || stageState == RUNNING || stageState == FLUSHING) {
+            if (taskState == TaskState.RUNNING) {
+                stateMachine.transitionToRunning();
+            }
+            if (isFlushing()) {
+                stateMachine.transitionToFlushing();
+            }
+            if (finishedTasks.containsAll(allTasks)) {
+                stateMachine.transitionToFinished();
+            }
+        }
+    }
+
+    /**
+     * Task state listener: notifies the registered consumers about driver groups that this
+     * stage has not yet seen as completed, then records them.
+     */
+    private synchronized void updateCompletedDriverGroups(TaskStatus taskStatus)
+    {
+        // Sets.difference returns a view.
+        // Once we add the difference into `completedDriverGroups`, the view will be empty.
+        // `completedLifespansChangeListeners.invoke` happens asynchronously.
+        // As a result, calling the listeners before updating `completedDriverGroups` doesn't make a difference.
+        // That's why a copy must be made here.
+        Set<Lifespan> newlyCompletedDriverGroups = ImmutableSet.copyOf(Sets.difference(taskStatus.getCompletedDriverGroups(), this.completedDriverGroups));
+        if (newlyCompletedDriverGroups.isEmpty()) {
+            return;
+        }
+        completedLifespansChangeListeners.invoke(newlyCompletedDriverGroups, executor);
+        // newlyCompletedDriverGroups is an immutable copy, so updating completedDriverGroups
+        // below does not change what the asynchronously invoked listeners observe.
+        completedDriverGroups.addAll(newlyCompletedDriverGroups);
+    }
+
+    /**
+     * Replaces the error code of a failure with REMOTE_HOST_GONE when the failure names a remote
+     * host that the failure detector reports as GONE; otherwise returns the failure unchanged.
+     */
+    private ExecutionFailureInfo rewriteTransportFailure(ExecutionFailureInfo executionFailureInfo)
+    {
+        if (executionFailureInfo.getRemoteHost() == null || failureDetector.getState(executionFailureInfo.getRemoteHost()) != GONE) {
+            return executionFailureInfo;
+        }
+
+        return new ExecutionFailureInfo(
+                executionFailureInfo.getType(),
+                executionFailureInfo.getMessage(),
+                executionFailureInfo.getCause(),
+                executionFailureInfo.getSuppressed(),
+                executionFailureInfo.getStack(),
+                executionFailureInfo.getErrorLocation(),
+                REMOTE_HOST_GONE.toErrorCode(),
+                executionFailureInfo.getRemoteHost());
+    }
+
+    /**
+     * Registers a task of a source fragment: hands it the current output buffers for that fragment
+     * and adds an exchange split pointing at it to every task of this stage.
+     *
+     * @throws IllegalArgumentException if {@code fragmentId} is not a remote source of this stage
+     */
+    @Override
+    public synchronized void addSourceTask(PlanFragmentId fragmentId, RemoteTask sourceTask)
+    {
+        requireNonNull(fragmentId, "fragmentId is null");
+
+        RemoteSourceNode remoteSource = exchangeSources.get(fragmentId);
+        checkArgument(remoteSource != null, "Unknown remote source %s. Known sources are %s", fragmentId, exchangeSources.keySet());
+
+        sourceTasks.put(fragmentId, sourceTask);
+
+        OutputBufferManager outputBufferManager = outputBufferManagers.get(fragmentId);
+        sourceTask.setOutputBuffers(outputBufferManager.getOutputBuffers());
+
+        for (RemoteTask destinationTask : getAllTasks()) {
+            destinationTask.addSplits(ImmutableMultimap.of(remoteSource.getId(), createExchangeSplit(destinationTask, sourceTask)));
+        }
+    }
+
+    /**
+     * Records that a source fragment will produce no more tasks. Once all fragments feeding the
+     * same remote source node are complete, that node is closed on every task of this stage.
+     *
+     * @throws IllegalArgumentException if {@code fragmentId} is not a remote source of this stage
+     */
+    @Override
+    public synchronized void noMoreSourceTasks(PlanFragmentId fragmentId)
+    {
+        RemoteSourceNode remoteSource = exchangeSources.get(fragmentId);
+        checkArgument(remoteSource != null, "Unknown remote source %s. Known sources are %s", fragmentId, exchangeSources.keySet());
+
+        completeSourceFragments.add(fragmentId);
+
+        // is the source now complete?
+        if (completeSourceFragments.containsAll(remoteSource.getSourceFragmentIds())) {
+            completeSources.add(remoteSource.getId());
+            for (RemoteTask task : getAllTasks()) {
+                task.noMoreSplits(remoteSource.getId());
+            }
+        }
+    }
+
+    /**
+     * Returns an immutable snapshot of the tasks created so far.
+     */
+    public List<RemoteTask> getAllTasks()
+    {
+        return ImmutableList.copyOf(tasks.values());
+    }
+
+    /**
+     * Returns the current status of every task created so far.
+     */
+    public List<TaskStatus> getTaskStatuses()
+    {
+        return getAllTasks().stream()
+                .map(RemoteTask::getTaskStatus)
+                .collect(toImmutableList());
+    }
+
+    /**
+     * Returns true if at least one task reports an overutilized output buffer.
+     */
+    public boolean isAnyTaskBlocked()
+    {
+        return getTaskStatuses().stream().anyMatch(TaskStatus::isOutputBufferOverutilized);
+    }
+
+    public void recordGetSplitTime(long start)
+    {
+        stageExecution.recordGetSplitTime(start);
+    }
+
+    public StageId getStageId()
+    {
+        return stageExecution.getStageId();
+    }
+
+    public PlanFragment getFragment()
+    {
+        return stageExecution.getFragment();
+    }
+
+    /**
+     * Returns the first failure recorded for this stage execution, if any.
+     */
+    public Optional<ExecutionFailureInfo> getFailureCause()
+    {
+        return stateMachine.getFailureCause();
+    }
+
+    private static Split createExchangeSplit(RemoteTask destinationTask, RemoteTask sourceTask)
+    {
+        // Fetch the results from the buffer assigned to the task based on id
+        URI exchangeLocation = sourceTask.getTaskStatus().getSelf();
+        URI splitLocation = uriBuilderFrom(exchangeLocation).appendPath("results").appendPath(String.valueOf(destinationTask.getTaskId().getId())).build();
+        return new Split(REMOTE_CONNECTOR_ID, new RemoteSplit(splitLocation), Lifespan.taskWide());
+    }
+
+    public enum State
+    {
+        /**
+         * Stage is planned but has not been scheduled yet. A stage will
+         * be in the planned state until the dependencies of the stage
+         * have begun producing output.
+         */
+        PLANNED(false, false),
+        /**
+         * Stage tasks are being scheduled on nodes.
+         */
+        SCHEDULING(false, false),
+        /**
+         * All stage tasks have been scheduled, but splits are still being scheduled.
+         */
+        SCHEDULING_SPLITS(false, false),
+        /**
+         * Stage has been scheduled on nodes and ready to execute, but all tasks are still queued.
+         */
+        SCHEDULED(false, false),
+        /**
+         * Stage is running.
+         */
+        RUNNING(false, false),
+        /**
+         * Stage has finished executing and its output is being consumed.
+         * In this state, at least one of the tasks is flushing and the non-flushing tasks are finished.
+         */
+        FLUSHING(false, false),
+        /**
+         * Stage has finished executing and all output has been consumed.
+         */
+        FINISHED(true, false),
+        /**
+         * Stage was canceled by a user.
+         */
+        CANCELED(true, false),
+        /**
+         * Stage was aborted due to a failure in the query. The failure
+         * was not in this stage.
+         */
+        ABORTED(true, true),
+        /**
+         * Stage execution failed.
+         */
+        FAILED(true, true);
+
+        private final boolean doneState;
+        private final boolean failureState;
+
+        State(boolean doneState, boolean failureState)
+        {
+            checkArgument(!failureState || doneState, "%s is a non-done failure state", name());
+            this.doneState = doneState;
+            this.failureState = failureState;
+        }
+
+        /**
+         * Is this a terminal state.
+         */
+        public boolean isDone()
+        {
+            return doneState;
+        }
+
+        /**
+         * Is this a non-success terminal state.
+         */
+        public boolean isFailure()
+        {
+            return failureState;
+        }
+
+        /**
+         * Whether more tasks may still be added to a stage in this state. ABORTED and FAILED
+         * deliberately answer true so that a broken stage is never reported as complete.
+         */
+        public boolean canScheduleMoreTasks()
+        {
+            switch (this) {
+                case PLANNED:
+                case SCHEDULING:
+                    // workers are still being added to the query
+                    return true;
+                case SCHEDULING_SPLITS:
+                case SCHEDULED:
+                case RUNNING:
+                case FLUSHING:
+                case FINISHED:
+                case CANCELED:
+                    // no more workers will be added to the query
+                    return false;
+                case ABORTED:
+                case FAILED:
+                    // DO NOT complete a FAILED or ABORTED stage. This will cause the
+                    // stage above to finish normally, which will result in a query
+                    // completing successfully when it should fail.
+                    return true;
+            }
+            throw new IllegalStateException("Unhandled state: " + this);
+        }
+    }
+
+    /**
+     * Thin wrapper around {@link StateMachine} that encodes the legal {@link State} transitions
+     * and remembers the first failure cause.
+     */
+    private static class StreamingStageStateMachine
+    {
+        private static final Set<State> TERMINAL_STAGE_STATES = Stream.of(State.values()).filter(State::isDone).collect(toImmutableSet());
+
+        private final StageId stageId;
+        private final StateMachine<State> state;
+        private final AtomicReference<DateTime> schedulingComplete = new AtomicReference<>();
+        private final AtomicReference<ExecutionFailureInfo> failureCause = new AtomicReference<>();
+
+        private StreamingStageStateMachine(StageId stageId, Executor executor)
+        {
+            this.stageId = requireNonNull(stageId, "stageId is null");
+
+            state = new StateMachine<>("Streaming stage execution " + stageId, executor, PLANNED, TERMINAL_STAGE_STATES);
+            state.addStateChangeListener(state -> log.debug("Streaming stage execution %s is %s", stageId, state));
+        }
+
+        public State getState()
+        {
+            return state.get();
+        }
+
+        public boolean transitionToScheduling()
+        {
+            return state.compareAndSet(PLANNED, SCHEDULING);
+        }
+
+        public boolean transitionToSchedulingSplits()
+        {
+            return state.setIf(SCHEDULING_SPLITS, currentState -> currentState == PLANNED || currentState == SCHEDULING);
+        }
+
+        public boolean transitionToScheduled()
+        {
+            // the completion time is recorded on the first call, even if the transition itself is rejected
+            schedulingComplete.compareAndSet(null, DateTime.now());
+            return state.setIf(SCHEDULED, currentState -> currentState == PLANNED || currentState == SCHEDULING || currentState == SCHEDULING_SPLITS);
+        }
+
+        public boolean transitionToRunning()
+        {
+            return state.setIf(RUNNING, currentState -> currentState != RUNNING && currentState != FLUSHING && !currentState.isDone());
+        }
+
+        public boolean transitionToFlushing()
+        {
+            return state.setIf(FLUSHING, currentState -> currentState != FLUSHING && !currentState.isDone());
+        }
+
+        public boolean transitionToFinished()
+        {
+            return state.setIf(FINISHED, currentState -> !currentState.isDone());
+        }
+
+        public boolean transitionToCanceled()
+        {
+            return state.setIf(CANCELED, currentState -> !currentState.isDone());
+        }
+
+        public boolean transitionToAborted()
+        {
+            return state.setIf(ABORTED, currentState -> !currentState.isDone());
+        }
+
+        /**
+         * Records the failure cause (first one wins) and moves to FAILED unless already done.
+         *
+         * @return true if this call performed the transition to FAILED
+         */
+        public boolean transitionToFailed(Throwable throwable)
+        {
+            requireNonNull(throwable, "throwable is null");
+
+            failureCause.compareAndSet(null, Failures.toFailure(throwable));
+            boolean failed = state.setIf(FAILED, currentState -> !currentState.isDone());
+            if (failed) {
+                log.error(throwable, "Streaming stage execution for stage %s failed", stageId);
+            }
+            else {
+                log.debug(throwable, "Failure in streaming stage execution for stage %s after finished", stageId);
+            }
+            return failed;
+        }
+
+        public Optional<ExecutionFailureInfo> getFailureCause()
+        {
+            return Optional.ofNullable(failureCause.get());
+        }
+
+        /**
+         * Listener is always notified asynchronously using a dedicated notification thread pool so, care should
+         * be taken to avoid leaking {@code this} when adding a listener in a constructor. Additionally, it is
+         * possible notifications are observed out of order due to the asynchronous execution.
+         */
+        public void addStateChangeListener(StateChangeListener<State> stateChangeListener)
+        {
+            state.addStateChangeListener(stateChangeListener);
+        }
+    }
+
+    /**
+     * Holds a list of listeners that is frozen on first dispatch: once {@link #invoke} has been
+     * called, adding another listener fails with {@link IllegalStateException}.
+     */
+    private static class ListenerManager<T>
+    {
+        private final List<Consumer<T>> listeners = new ArrayList<>();
+        private boolean frozen;
+
+        public synchronized void addListener(Consumer<T> listener)
+        {
+            checkState(!frozen, "Listeners have been invoked");
+            listeners.add(listener);
+        }
+
+        public synchronized void invoke(T payload, Executor executor)
+        {
+            frozen = true;
+            for (Consumer<T> listener : listeners) {
+                executor.execute(() -> listener.accept(payload));
+            }
+        }
+    }
+}
diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlStageExecution.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlStageExecution.java index c6f76bf2949a..f058f938407c 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlStageExecution.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlStageExecution.java @@ -15,17 +15,15 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.SettableFuture; import io.trino.client.NodeVersion; import io.trino.cost.StatsAndCosts; -import io.trino.execution.MockRemoteTaskFactory.MockRemoteTask; import io.trino.execution.scheduler.SplitSchedulerStats; -import io.trino.failuredetector.NoOpFailureDetector; import io.trino.metadata.InternalNode; -import io.trino.server.DynamicFilterService; import io.trino.spi.QueryId; import io.trino.spi.type.Type; -import io.trino.spi.type.TypeOperators; import io.trino.sql.planner.Partitioning; import io.trino.sql.planner.PartitioningScheme; import io.trino.sql.planner.PlanFragment; @@ -51,7 +49,6 @@ import static io.trino.execution.SqlStageExecution.createSqlStageExecution; import static io.trino.execution.buffer.OutputBuffers.BufferType.ARBITRARY;
import static io.trino.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; -import static io.trino.metadata.MetadataManager.createTestMetadataManager; import static io.trino.operator.StageExecutionDescriptor.ungroupedExecution; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; @@ -111,10 +108,7 @@ private void testFinalStageInfoInternal() true, nodeTaskMap, executor, - new NoOpFailureDetector(), - new DynamicFilterService(createTestMetadataManager(), new TypeOperators(), new DynamicFilterConfig()), new SplitSchedulerStats()); - stage.setOutputBuffers(createInitialEmptyOutputBuffers(ARBITRARY)); // add listener that fetches stage info when the final status is available SettableFuture finalStageInfo = SettableFuture.create(); @@ -133,7 +127,14 @@ private void testFinalStageInfoInternal() URI.create("http://10.0.0." + (i / 10_000) + ":" + (i % 10_000)), NodeVersion.UNKNOWN, false); - stage.scheduleTask(node, i); + stage.createTask( + node, + i, + Optional.empty(), + createInitialEmptyOutputBuffers(ARBITRARY), + ImmutableMultimap.of(), + ImmutableMultimap.of(), + ImmutableSet.of()); latch.countDown(); } } @@ -147,7 +148,7 @@ private void testFinalStageInfoInternal() // wait for some tasks to be created, and then abort the query latch.await(1, MINUTES); assertFalse(stage.getStageInfo().getTasks().isEmpty()); - stage.abort(); + stage.transitionToFinished(); // once the final stage info is available, verify that it is complete StageInfo stageInfo = finalStageInfo.get(1, MINUTES); @@ -159,43 +160,6 @@ private void testFinalStageInfoInternal() addTasksTask.cancel(true); } - @Test - public void testIsAnyTaskBlocked() - { - NodeTaskMap nodeTaskMap = new NodeTaskMap(new FinalizerService()); - - StageId stageId = new StageId(new QueryId("query"), 0); - SqlStageExecution stage = createSqlStageExecution( - stageId, - createExchangePlanFragment(), - ImmutableMap.of(), - new 
MockRemoteTaskFactory(executor, scheduledExecutor), - TEST_SESSION, - true, - nodeTaskMap, - executor, - new NoOpFailureDetector(), - new DynamicFilterService(createTestMetadataManager(), new TypeOperators(), new DynamicFilterConfig()), - new SplitSchedulerStats()); - stage.setOutputBuffers(createInitialEmptyOutputBuffers(ARBITRARY)); - - InternalNode node1 = new InternalNode("other1", URI.create("http://127.0.0.1:11"), NodeVersion.UNKNOWN, false); - InternalNode node2 = new InternalNode("other2", URI.create("http://127.0.0.2:12"), NodeVersion.UNKNOWN, false); - MockRemoteTask task1 = (MockRemoteTask) stage.scheduleTask(node1, 1).get(); - MockRemoteTask task2 = (MockRemoteTask) stage.scheduleTask(node2, 2).get(); - - // both tasks' buffers are under utilized - assertFalse(stage.isAnyTaskBlocked()); - - // set one of the task's buffer to be over utilized - task1.setOutputBufferOverUtilized(true); - assertTrue(stage.isAnyTaskBlocked()); - - // set both the tasks' buffers to be over utilized - task2.setOutputBufferOverUtilized(true); - assertTrue(stage.isAnyTaskBlocked()); - } - private static PlanFragment createExchangePlanFragment() { PlanNode planNode = new RemoteSourceNode( diff --git a/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java b/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java index e3694ea733f4..9693c6d7b064 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java @@ -35,7 +35,6 @@ import java.util.concurrent.ExecutorService; import static io.airlift.concurrent.Threads.daemonThreadsNamed; -import static io.trino.SessionTestUtils.TEST_SESSION; import static io.trino.operator.StageExecutionDescriptor.ungroupedExecution; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; @@ -76,14 +75,14 @@ public void 
testBasicStateChanges() assertTrue(stateMachine.transitionToScheduling()); assertState(stateMachine, StageState.SCHEDULING); - assertTrue(stateMachine.transitionToScheduled()); - assertState(stateMachine, StageState.SCHEDULED); - assertTrue(stateMachine.transitionToRunning()); assertState(stateMachine, StageState.RUNNING); - assertTrue(stateMachine.transitionToFlushing()); - assertState(stateMachine, StageState.FLUSHING); + assertTrue(stateMachine.transitionToPending()); + assertState(stateMachine, StageState.PENDING); + + assertTrue(stateMachine.transitionToRunning()); + assertState(stateMachine, StageState.RUNNING); assertTrue(stateMachine.transitionToFinished()); assertState(stateMachine, StageState.FINISHED); @@ -103,10 +102,6 @@ public void testPlanned() assertTrue(stateMachine.transitionToRunning()); assertState(stateMachine, StageState.RUNNING); - stateMachine = createStageStateMachine(); - assertTrue(stateMachine.transitionToFlushing()); - assertState(stateMachine, StageState.FLUSHING); - stateMachine = createStageStateMachine(); assertTrue(stateMachine.transitionToFinished()); assertState(stateMachine, StageState.FINISHED); @@ -114,14 +109,6 @@ public void testPlanned() stateMachine = createStageStateMachine(); assertTrue(stateMachine.transitionToFailed(FAILED_CAUSE)); assertState(stateMachine, StageState.FAILED); - - stateMachine = createStageStateMachine(); - assertTrue(stateMachine.transitionToAborted()); - assertState(stateMachine, StageState.ABORTED); - - stateMachine = createStageStateMachine(); - assertTrue(stateMachine.transitionToCanceled()); - assertState(stateMachine, StageState.CANCELED); } @Test @@ -134,19 +121,11 @@ public void testScheduling() assertFalse(stateMachine.transitionToScheduling()); assertState(stateMachine, StageState.SCHEDULING); - assertTrue(stateMachine.transitionToScheduled()); - assertState(stateMachine, StageState.SCHEDULED); - stateMachine = createStageStateMachine(); stateMachine.transitionToScheduling(); 
assertTrue(stateMachine.transitionToRunning()); assertState(stateMachine, StageState.RUNNING); - stateMachine = createStageStateMachine(); - stateMachine.transitionToScheduling(); - assertTrue(stateMachine.transitionToFlushing()); - assertState(stateMachine, StageState.FLUSHING); - stateMachine = createStageStateMachine(); stateMachine.transitionToScheduling(); assertTrue(stateMachine.transitionToFinished()); @@ -156,56 +135,6 @@ public void testScheduling() stateMachine.transitionToScheduling(); assertTrue(stateMachine.transitionToFailed(FAILED_CAUSE)); assertState(stateMachine, StageState.FAILED); - - stateMachine = createStageStateMachine(); - stateMachine.transitionToScheduling(); - assertTrue(stateMachine.transitionToAborted()); - assertState(stateMachine, StageState.ABORTED); - - stateMachine = createStageStateMachine(); - stateMachine.transitionToScheduling(); - assertTrue(stateMachine.transitionToCanceled()); - assertState(stateMachine, StageState.CANCELED); - } - - @Test - public void testScheduled() - { - StageStateMachine stateMachine = createStageStateMachine(); - assertTrue(stateMachine.transitionToScheduled()); - assertState(stateMachine, StageState.SCHEDULED); - - assertFalse(stateMachine.transitionToScheduling()); - assertState(stateMachine, StageState.SCHEDULED); - - assertFalse(stateMachine.transitionToScheduled()); - assertState(stateMachine, StageState.SCHEDULED); - - assertTrue(stateMachine.transitionToRunning()); - assertState(stateMachine, StageState.RUNNING); - - assertTrue(stateMachine.transitionToFlushing()); - assertState(stateMachine, StageState.FLUSHING); - - stateMachine = createStageStateMachine(); - stateMachine.transitionToScheduled(); - assertTrue(stateMachine.transitionToFinished()); - assertState(stateMachine, StageState.FINISHED); - - stateMachine = createStageStateMachine(); - stateMachine.transitionToScheduled(); - assertTrue(stateMachine.transitionToFailed(FAILED_CAUSE)); - assertState(stateMachine, StageState.FAILED); - - 
stateMachine = createStageStateMachine(); - stateMachine.transitionToScheduled(); - assertTrue(stateMachine.transitionToAborted()); - assertState(stateMachine, StageState.ABORTED); - - stateMachine = createStageStateMachine(); - stateMachine.transitionToScheduled(); - assertTrue(stateMachine.transitionToCanceled()); - assertState(stateMachine, StageState.CANCELED); } @Test @@ -218,74 +147,24 @@ public void testRunning() assertFalse(stateMachine.transitionToScheduling()); assertState(stateMachine, StageState.RUNNING); - assertFalse(stateMachine.transitionToScheduled()); - assertState(stateMachine, StageState.RUNNING); - assertFalse(stateMachine.transitionToRunning()); assertState(stateMachine, StageState.RUNNING); - assertTrue(stateMachine.transitionToFlushing()); - assertState(stateMachine, StageState.FLUSHING); - - stateMachine = createStageStateMachine(); - stateMachine.transitionToRunning(); - assertTrue(stateMachine.transitionToFinished()); - assertState(stateMachine, StageState.FINISHED); - - stateMachine = createStageStateMachine(); - stateMachine.transitionToRunning(); - assertTrue(stateMachine.transitionToFailed(FAILED_CAUSE)); - assertState(stateMachine, StageState.FAILED); + assertTrue(stateMachine.transitionToPending()); + assertState(stateMachine, StageState.PENDING); - stateMachine = createStageStateMachine(); - stateMachine.transitionToRunning(); - assertTrue(stateMachine.transitionToAborted()); - assertState(stateMachine, StageState.ABORTED); + assertTrue(stateMachine.transitionToRunning()); + assertState(stateMachine, StageState.RUNNING); stateMachine = createStageStateMachine(); stateMachine.transitionToRunning(); - assertTrue(stateMachine.transitionToCanceled()); - assertState(stateMachine, StageState.CANCELED); - } - - @Test - public void testFlushing() - { - StageStateMachine stateMachine = createStageStateMachine(); - assertTrue(stateMachine.transitionToFlushing()); - assertState(stateMachine, StageState.FLUSHING); - - 
assertFalse(stateMachine.transitionToScheduling()); - assertState(stateMachine, StageState.FLUSHING); - - assertFalse(stateMachine.transitionToScheduled()); - assertState(stateMachine, StageState.FLUSHING); - - assertFalse(stateMachine.transitionToRunning()); - assertState(stateMachine, StageState.FLUSHING); - - assertFalse(stateMachine.transitionToFlushing()); - assertState(stateMachine, StageState.FLUSHING); - - stateMachine = createStageStateMachine(); - stateMachine.transitionToFlushing(); assertTrue(stateMachine.transitionToFinished()); assertState(stateMachine, StageState.FINISHED); stateMachine = createStageStateMachine(); - stateMachine.transitionToFlushing(); + stateMachine.transitionToRunning(); assertTrue(stateMachine.transitionToFailed(FAILED_CAUSE)); assertState(stateMachine, StageState.FAILED); - - stateMachine = createStageStateMachine(); - stateMachine.transitionToFlushing(); - assertTrue(stateMachine.transitionToAborted()); - assertState(stateMachine, StageState.ABORTED); - - stateMachine = createStageStateMachine(); - stateMachine.transitionToFlushing(); - assertTrue(stateMachine.transitionToCanceled()); - assertState(stateMachine, StageState.CANCELED); } @Test @@ -306,24 +185,6 @@ public void testFailed() assertFinalState(stateMachine, StageState.FAILED); } - @Test - public void testAborted() - { - StageStateMachine stateMachine = createStageStateMachine(); - - assertTrue(stateMachine.transitionToAborted()); - assertFinalState(stateMachine, StageState.ABORTED); - } - - @Test - public void testCanceled() - { - StageStateMachine stateMachine = createStageStateMachine(); - - assertTrue(stateMachine.transitionToCanceled()); - assertFinalState(stateMachine, StageState.CANCELED); - } - private static void assertFinalState(StageStateMachine stateMachine, StageState expectedState) { assertTrue(expectedState.isDone()); @@ -333,24 +194,18 @@ private static void assertFinalState(StageStateMachine stateMachine, StageState 
assertFalse(stateMachine.transitionToScheduling()); assertState(stateMachine, expectedState); - assertFalse(stateMachine.transitionToScheduled()); + assertFalse(stateMachine.transitionToPending()); assertState(stateMachine, expectedState); assertFalse(stateMachine.transitionToRunning()); assertState(stateMachine, expectedState); - assertFalse(stateMachine.transitionToFlushing()); - assertState(stateMachine, expectedState); - assertFalse(stateMachine.transitionToFinished()); assertState(stateMachine, expectedState); assertFalse(stateMachine.transitionToFailed(FAILED_CAUSE)); assertState(stateMachine, expectedState); - assertFalse(stateMachine.transitionToAborted()); - assertState(stateMachine, expectedState); - // attempt to fail with another exception, which will fail assertFalse(stateMachine.transitionToFailed(new IOException("failure after finish"))); assertState(stateMachine, expectedState); @@ -359,7 +214,6 @@ private static void assertFinalState(StageStateMachine stateMachine, StageState private static void assertState(StageStateMachine stateMachine, StageState expectedState) { assertEquals(stateMachine.getStageId(), STAGE_ID); - assertSame(stateMachine.getSession(), TEST_SESSION); StageInfo stageInfo = stateMachine.getStageInfo(ImmutableList::of); assertEquals(stageInfo.getStageId(), STAGE_ID); @@ -383,7 +237,7 @@ private static void assertState(StageStateMachine stateMachine, StageState expec private StageStateMachine createStageStateMachine() { - return new StageStateMachine(STAGE_ID, TEST_SESSION, PLAN_FRAGMENT, ImmutableMap.of(), executor, new SplitSchedulerStats()); + return new StageStateMachine(STAGE_ID, PLAN_FRAGMENT, ImmutableMap.of(), executor, new SplitSchedulerStats()); } private static PlanFragment createValuesPlan() diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestBroadcastOutputBufferManager.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestBroadcastOutputBufferManager.java index 
e951a17a3f7e..97b0e3fc27d4 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestBroadcastOutputBufferManager.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestBroadcastOutputBufferManager.java @@ -13,13 +13,10 @@ */ package io.trino.execution.scheduler; -import com.google.common.collect.ImmutableList; import io.trino.execution.buffer.OutputBuffers; import io.trino.execution.buffer.OutputBuffers.OutputBufferId; import org.testng.annotations.Test; -import java.util.concurrent.atomic.AtomicReference; - import static io.trino.execution.buffer.OutputBuffers.BROADCAST_PARTITION_ID; import static io.trino.execution.buffer.OutputBuffers.BufferType.BROADCAST; import static io.trino.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; @@ -30,34 +27,35 @@ public class TestBroadcastOutputBufferManager @Test public void test() { - AtomicReference outputBufferTarget = new AtomicReference<>(); - BroadcastOutputBufferManager hashOutputBufferManager = new BroadcastOutputBufferManager(outputBufferTarget::set); - assertEquals(outputBufferTarget.get(), createInitialEmptyOutputBuffers(BROADCAST)); + BroadcastOutputBufferManager hashOutputBufferManager = new BroadcastOutputBufferManager(); + assertEquals(hashOutputBufferManager.getOutputBuffers(), createInitialEmptyOutputBuffers(BROADCAST)); - hashOutputBufferManager.addOutputBuffers(ImmutableList.of(new OutputBufferId(0)), false); + hashOutputBufferManager.addOutputBuffer(new OutputBufferId(0)); OutputBuffers expectedOutputBuffers = createInitialEmptyOutputBuffers(BROADCAST).withBuffer(new OutputBufferId(0), BROADCAST_PARTITION_ID); - assertEquals(outputBufferTarget.get(), expectedOutputBuffers); + assertEquals(hashOutputBufferManager.getOutputBuffers(), expectedOutputBuffers); - hashOutputBufferManager.addOutputBuffers(ImmutableList.of(new OutputBufferId(1), new OutputBufferId(2)), false); + hashOutputBufferManager.addOutputBuffer(new OutputBufferId(1)); + 
hashOutputBufferManager.addOutputBuffer(new OutputBufferId(2)); expectedOutputBuffers = expectedOutputBuffers.withBuffer(new OutputBufferId(1), BROADCAST_PARTITION_ID); expectedOutputBuffers = expectedOutputBuffers.withBuffer(new OutputBufferId(2), BROADCAST_PARTITION_ID); - assertEquals(outputBufferTarget.get(), expectedOutputBuffers); + assertEquals(hashOutputBufferManager.getOutputBuffers(), expectedOutputBuffers); // set no more buffers - hashOutputBufferManager.addOutputBuffers(ImmutableList.of(new OutputBufferId(3)), true); + hashOutputBufferManager.addOutputBuffer(new OutputBufferId(3)); + hashOutputBufferManager.noMoreBuffers(); expectedOutputBuffers = expectedOutputBuffers.withBuffer(new OutputBufferId(3), BROADCAST_PARTITION_ID); expectedOutputBuffers = expectedOutputBuffers.withNoMoreBufferIds(); - assertEquals(outputBufferTarget.get(), expectedOutputBuffers); + assertEquals(hashOutputBufferManager.getOutputBuffers(), expectedOutputBuffers); // try to add another buffer, which should not result in an error // and output buffers should not change - hashOutputBufferManager.addOutputBuffers(ImmutableList.of(new OutputBufferId(5)), false); - assertEquals(outputBufferTarget.get(), expectedOutputBuffers); + hashOutputBufferManager.addOutputBuffer(new OutputBufferId(5)); + assertEquals(hashOutputBufferManager.getOutputBuffers(), expectedOutputBuffers); // try to set no more buffers again, which should not result in an error // and output buffers should not change - hashOutputBufferManager.addOutputBuffers(ImmutableList.of(new OutputBufferId(6)), true); - assertEquals(outputBufferTarget.get(), expectedOutputBuffers); + hashOutputBufferManager.addOutputBuffer(new OutputBufferId(6)); + assertEquals(hashOutputBufferManager.getOutputBuffers(), expectedOutputBuffers); } } diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestPartitionedOutputBufferManager.java 
b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestPartitionedOutputBufferManager.java index 0fe667dd32a5..e209078c9dda 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestPartitionedOutputBufferManager.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestPartitionedOutputBufferManager.java @@ -13,13 +13,11 @@ */ package io.trino.execution.scheduler; -import com.google.common.collect.ImmutableList; import io.trino.execution.buffer.OutputBuffers; import io.trino.execution.buffer.OutputBuffers.OutputBufferId; import org.testng.annotations.Test; import java.util.Map; -import java.util.concurrent.atomic.AtomicReference; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -32,30 +30,28 @@ public class TestPartitionedOutputBufferManager @Test public void test() { - AtomicReference outputBufferTarget = new AtomicReference<>(); - - PartitionedOutputBufferManager hashOutputBufferManager = new PartitionedOutputBufferManager(FIXED_HASH_DISTRIBUTION, 4, outputBufferTarget::set); + PartitionedOutputBufferManager hashOutputBufferManager = new PartitionedOutputBufferManager(FIXED_HASH_DISTRIBUTION, 4); // output buffers are set immediately when the manager is created - assertOutputBuffers(outputBufferTarget.get()); + assertOutputBuffers(hashOutputBufferManager.getOutputBuffers()); // add buffers, which does not cause an error - hashOutputBufferManager.addOutputBuffers(ImmutableList.of(new OutputBufferId(0)), false); - assertOutputBuffers(outputBufferTarget.get()); - hashOutputBufferManager.addOutputBuffers(ImmutableList.of(new OutputBufferId(3)), true); - assertOutputBuffers(outputBufferTarget.get()); + hashOutputBufferManager.addOutputBuffer(new OutputBufferId(0)); + assertOutputBuffers(hashOutputBufferManager.getOutputBuffers()); + hashOutputBufferManager.addOutputBuffer(new OutputBufferId(3)); + 
assertOutputBuffers(hashOutputBufferManager.getOutputBuffers()); // try to a buffer out side of the partition range, which should result in an error - assertThatThrownBy(() -> hashOutputBufferManager.addOutputBuffers(ImmutableList.of(new OutputBufferId(5)), false)) + assertThatThrownBy(() -> hashOutputBufferManager.addOutputBuffer(new OutputBufferId(5))) .isInstanceOf(IllegalStateException.class) .hasMessage("Unexpected new output buffer 5"); - assertOutputBuffers(outputBufferTarget.get()); + assertOutputBuffers(hashOutputBufferManager.getOutputBuffers()); // try to a buffer out side of the partition range, which should result in an error - assertThatThrownBy(() -> hashOutputBufferManager.addOutputBuffers(ImmutableList.of(new OutputBufferId(6)), true)) + assertThatThrownBy(() -> hashOutputBufferManager.addOutputBuffer(new OutputBufferId(6))) .isInstanceOf(IllegalStateException.class) .hasMessage("Unexpected new output buffer 6"); - assertOutputBuffers(outputBufferTarget.get()); + assertOutputBuffers(hashOutputBufferManager.getOutputBuffers()); } private static void assertOutputBuffers(OutputBuffers outputBuffers) diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java index d4f0c4a5e86b..10a5f10a3da8 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java @@ -30,7 +30,6 @@ import io.trino.execution.SqlStageExecution; import io.trino.execution.StageId; import io.trino.execution.TableInfo; -import io.trino.execution.buffer.OutputBuffers.OutputBufferId; import io.trino.failuredetector.NoOpFailureDetector; import io.trino.metadata.InMemoryNodeManager; import io.trino.metadata.InternalNode; @@ -83,10 +82,9 @@ import static com.google.common.base.Preconditions.checkArgument; 
import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.trino.SessionTestUtils.TEST_SESSION; -import static io.trino.execution.buffer.OutputBuffers.BufferType.PARTITIONED; -import static io.trino.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; import static io.trino.execution.scheduler.ScheduleResult.BlockedReason.SPLIT_QUEUES_FULL; import static io.trino.execution.scheduler.SourcePartitionedScheduler.newSourcePartitionedSchedulerAsStageScheduler; +import static io.trino.execution.scheduler.StreamingStageExecution.createStreamingStageExecution; import static io.trino.metadata.MetadataManager.createTestMetadataManager; import static io.trino.operator.StageExecutionDescriptor.ungroupedExecution; import static io.trino.spi.StandardErrorCode.NO_NODES_AVAILABLE; @@ -94,6 +92,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.DynamicFilters.createDynamicFilterExpression; +import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE; @@ -111,7 +110,6 @@ public class TestSourcePartitionedScheduler { - public static final OutputBufferId OUT = new OutputBufferId(0); private static final CatalogName CONNECTOR_ID = TEST_TABLE_HANDLE.getCatalogName(); private static final QueryId QUERY_ID = new QueryId("query"); private static final DynamicFilterId DYNAMIC_FILTER_ID = new DynamicFilterId("filter1"); @@ -151,7 +149,7 @@ public void testScheduleNoSplits() { StageExecutionPlan plan = createPlan(createFixedSplitSource(0, TestingSplit::createRemoteSplit)); NodeTaskMap nodeTaskMap = new NodeTaskMap(finalizerService); - SqlStageExecution stage = createSqlStageExecution(plan, nodeTaskMap); + 
StreamingStageExecution stage = createStageExecution(plan, nodeTaskMap); StageScheduler scheduler = getSourcePartitionedScheduler(plan, stage, nodeManager, nodeTaskMap, 1); @@ -168,7 +166,7 @@ public void testScheduleSplitsOneAtATime() { StageExecutionPlan plan = createPlan(createFixedSplitSource(60, TestingSplit::createRemoteSplit)); NodeTaskMap nodeTaskMap = new NodeTaskMap(finalizerService); - SqlStageExecution stage = createSqlStageExecution(plan, nodeTaskMap); + StreamingStageExecution stage = createStageExecution(plan, nodeTaskMap); StageScheduler scheduler = getSourcePartitionedScheduler(plan, stage, nodeManager, nodeTaskMap, 1); @@ -205,7 +203,7 @@ public void testScheduleSplitsBatched() { StageExecutionPlan plan = createPlan(createFixedSplitSource(60, TestingSplit::createRemoteSplit)); NodeTaskMap nodeTaskMap = new NodeTaskMap(finalizerService); - SqlStageExecution stage = createSqlStageExecution(plan, nodeTaskMap); + StreamingStageExecution stage = createStageExecution(plan, nodeTaskMap); StageScheduler scheduler = getSourcePartitionedScheduler(plan, stage, nodeManager, nodeTaskMap, 7); @@ -242,7 +240,7 @@ public void testScheduleSplitsBlock() { StageExecutionPlan plan = createPlan(createFixedSplitSource(80, TestingSplit::createRemoteSplit)); NodeTaskMap nodeTaskMap = new NodeTaskMap(finalizerService); - SqlStageExecution stage = createSqlStageExecution(plan, nodeTaskMap); + StreamingStageExecution stage = createStageExecution(plan, nodeTaskMap); StageScheduler scheduler = getSourcePartitionedScheduler(plan, stage, nodeManager, nodeTaskMap, 1); @@ -307,7 +305,7 @@ public void testScheduleSlowSplitSource() QueuedSplitSource queuedSplitSource = new QueuedSplitSource(TestingSplit::createRemoteSplit); StageExecutionPlan plan = createPlan(queuedSplitSource); NodeTaskMap nodeTaskMap = new NodeTaskMap(finalizerService); - SqlStageExecution stage = createSqlStageExecution(plan, nodeTaskMap); + StreamingStageExecution stage = createStageExecution(plan, 
nodeTaskMap); StageScheduler scheduler = getSourcePartitionedScheduler(plan, stage, nodeManager, nodeTaskMap, 1); @@ -331,7 +329,7 @@ public void testNoNodes() NodeScheduler nodeScheduler = new NodeScheduler(new UniformNodeSelectorFactory(nodeManager, new NodeSchedulerConfig().setIncludeCoordinator(false), nodeTaskMap)); StageExecutionPlan plan = createPlan(createFixedSplitSource(20, TestingSplit::createRemoteSplit)); - SqlStageExecution stage = createSqlStageExecution(plan, nodeTaskMap); + StreamingStageExecution stage = createStageExecution(plan, nodeTaskMap); StageScheduler scheduler = newSourcePartitionedSchedulerAsStageScheduler( stage, @@ -358,7 +356,7 @@ public void testBalancedSplitAssignment() // Schedule 15 splits - there are 3 nodes, each node should get 5 splits StageExecutionPlan firstPlan = createPlan(createFixedSplitSource(15, TestingSplit::createRemoteSplit)); - SqlStageExecution firstStage = createSqlStageExecution(firstPlan, nodeTaskMap); + StreamingStageExecution firstStage = createStageExecution(firstPlan, nodeTaskMap); StageScheduler firstScheduler = getSourcePartitionedScheduler(firstPlan, firstStage, nodeManager, nodeTaskMap, 200); ScheduleResult scheduleResult = firstScheduler.schedule(); @@ -376,7 +374,7 @@ public void testBalancedSplitAssignment() // Schedule 5 splits in another query. 
Since the new node does not have any splits, all 5 splits are assigned to the new node StageExecutionPlan secondPlan = createPlan(createFixedSplitSource(5, TestingSplit::createRemoteSplit)); - SqlStageExecution secondStage = createSqlStageExecution(secondPlan, nodeTaskMap); + StreamingStageExecution secondStage = createStageExecution(secondPlan, nodeTaskMap); StageScheduler secondScheduler = getSourcePartitionedScheduler(secondPlan, secondStage, nodeManager, nodeTaskMap, 200); scheduleResult = secondScheduler.schedule(); @@ -404,7 +402,7 @@ public void testNewTaskScheduledWhenChildStageBufferIsUnderutilized() NodeScheduler nodeScheduler = new NodeScheduler(new UniformNodeSelectorFactory(nodeManager, new NodeSchedulerConfig().setIncludeCoordinator(false), nodeTaskMap, new Duration(0, SECONDS))); StageExecutionPlan plan = createPlan(createFixedSplitSource(500, TestingSplit::createRemoteSplit)); - SqlStageExecution stage = createSqlStageExecution(plan, nodeTaskMap); + StreamingStageExecution stage = createStageExecution(plan, nodeTaskMap); // setting under utilized child output buffer StageScheduler scheduler = newSourcePartitionedSchedulerAsStageScheduler( @@ -443,7 +441,7 @@ public void testNoNewTaskScheduledWhenChildStageBufferIsOverutilized() NodeScheduler nodeScheduler = new NodeScheduler(new UniformNodeSelectorFactory(nodeManager, new NodeSchedulerConfig().setIncludeCoordinator(false), nodeTaskMap, new Duration(0, SECONDS))); StageExecutionPlan plan = createPlan(createFixedSplitSource(400, TestingSplit::createRemoteSplit)); - SqlStageExecution stage = createSqlStageExecution(plan, nodeTaskMap); + StreamingStageExecution stage = createStageExecution(plan, nodeTaskMap); // setting over utilized child output buffer StageScheduler scheduler = newSourcePartitionedSchedulerAsStageScheduler( @@ -474,7 +472,7 @@ public void testDynamicFiltersUnblockedOnBlockedBuildSource() { StageExecutionPlan plan = createPlan(createBlockedSplitSource()); NodeTaskMap nodeTaskMap = new 
NodeTaskMap(finalizerService); - SqlStageExecution stage = createSqlStageExecution(plan, nodeTaskMap); + StreamingStageExecution stage = createStageExecution(plan, nodeTaskMap); NodeScheduler nodeScheduler = new NodeScheduler(new UniformNodeSelectorFactory(nodeManager, new NodeSchedulerConfig().setIncludeCoordinator(false), nodeTaskMap)); DynamicFilterService dynamicFilterService = new DynamicFilterService(metadata, typeOperators, new DynamicFilterConfig()); dynamicFilterService.registerQuery( @@ -512,7 +510,7 @@ public void testDynamicFiltersUnblockedOnBlockedBuildSource() assertEquals(scheduleResult.getSplitsScheduled(), 0); } - private static void assertPartitionedSplitCount(SqlStageExecution stage, int expectedPartitionedSplitCount) + private static void assertPartitionedSplitCount(StreamingStageExecution stage, int expectedPartitionedSplitCount) { assertEquals(stage.getAllTasks().stream().mapToInt(RemoteTask::getPartitionedSplitCount).sum(), expectedPartitionedSplitCount); } @@ -534,7 +532,7 @@ private static void assertEffectivelyFinished(ScheduleResult scheduleResult, Sta private StageScheduler getSourcePartitionedScheduler( StageExecutionPlan plan, - SqlStageExecution stage, + StreamingStageExecution stage, InternalNodeManager nodeManager, NodeTaskMap nodeTaskMap, int splitBatchSize) @@ -643,7 +641,7 @@ private static ConnectorSplitSource createFixedSplitSource(int splitCount, Suppl return new FixedSplitSource(splits.build()); } - private SqlStageExecution createSqlStageExecution(StageExecutionPlan tableScanPlan, NodeTaskMap nodeTaskMap) + private StreamingStageExecution createStageExecution(StageExecutionPlan tableScanPlan, NodeTaskMap nodeTaskMap) { StageId stageId = new StageId(QUERY_ID, 0); SqlStageExecution stage = SqlStageExecution.createSqlStageExecution(stageId, @@ -654,15 +652,30 @@ private SqlStageExecution createSqlStageExecution(StageExecutionPlan tableScanPl true, nodeTaskMap, queryExecutor, - new NoOpFailureDetector(), - new 
DynamicFilterService(metadata, typeOperators, new DynamicFilterConfig()), new SplitSchedulerStats()); - - stage.setOutputBuffers(createInitialEmptyOutputBuffers(PARTITIONED) - .withBuffer(OUT, 0) - .withNoMoreBufferIds()); - - return stage; + ImmutableMap.Builder outputBuffers = ImmutableMap.builder(); + outputBuffers.put(tableScanPlan.getFragment().getId(), new PartitionedOutputBufferManager(FIXED_HASH_DISTRIBUTION, 1)); + tableScanPlan.getFragment().getRemoteSourceNodes().stream() + .flatMap(node -> node.getSourceFragmentIds().stream()) + .forEach(fragmentId -> outputBuffers.put(fragmentId, new PartitionedOutputBufferManager(FIXED_HASH_DISTRIBUTION, 10))); + return createStreamingStageExecution( + stage, + outputBuffers.build(), + new ResultsConsumer() + { + @Override + public void addSourceTask(PlanFragmentId fragmentId, RemoteTask sourceTask) + { + } + + @Override + public void noMoreSourceTasks(PlanFragmentId fragmentId) + { + } + }, + new NoOpFailureDetector(), + queryExecutor, + Optional.of(new int[] {0})); } private static class QueuedSplitSource diff --git a/testing/trino-tests/src/test/java/io/trino/execution/TestFlushingStageState.java b/testing/trino-tests/src/test/java/io/trino/execution/TestPendingStageState.java similarity index 87% rename from testing/trino-tests/src/test/java/io/trino/execution/TestFlushingStageState.java rename to testing/trino-tests/src/test/java/io/trino/execution/TestPendingStageState.java index 034aeb57ffeb..f5274623678f 100644 --- a/testing/trino-tests/src/test/java/io/trino/execution/TestFlushingStageState.java +++ b/testing/trino-tests/src/test/java/io/trino/execution/TestPendingStageState.java @@ -24,15 +24,13 @@ import static io.trino.SessionTestUtils.TEST_SESSION; import static io.trino.execution.QueryState.RUNNING; -import static io.trino.execution.StageState.CANCELED; -import static io.trino.execution.StageState.FLUSHING; import static io.trino.execution.TestQueryRunnerUtil.createQuery; import static 
io.trino.execution.TestQueryRunnerUtil.waitForQueryState; import static io.trino.testing.assertions.Assert.assertEventually; import static java.util.concurrent.TimeUnit.SECONDS; import static org.testng.Assert.assertEquals; -public class TestFlushingStageState +public class TestPendingStageState { private DistributedQueryRunner queryRunner; @@ -45,7 +43,7 @@ public void setup() } @Test(timeOut = 30_000) - public void testFlushingState() + public void testPendingState() throws Exception { QueryId queryId = createQuery(queryRunner, TEST_SESSION, "SELECT * FROM tpch.sf1000.lineitem limit 1"); @@ -54,18 +52,18 @@ public void testFlushingState() // wait for the query to finish producing results, but don't poll them assertEventually( new Duration(10, SECONDS), - () -> assertEquals(queryRunner.getCoordinator().getFullQueryInfo(queryId).getOutputStage().get().getState(), FLUSHING)); + () -> assertEquals(queryRunner.getCoordinator().getFullQueryInfo(queryId).getOutputStage().get().getState(), StageState.RUNNING)); - // wait for the sub stages to go to cancelled state + // wait for the sub stages to go to pending state assertEventually( new Duration(10, SECONDS), - () -> assertEquals(queryRunner.getCoordinator().getFullQueryInfo(queryId).getOutputStage().get().getSubStages().get(0).getState(), CANCELED)); + () -> assertEquals(queryRunner.getCoordinator().getFullQueryInfo(queryId).getOutputStage().get().getSubStages().get(0).getState(), StageState.PENDING)); QueryInfo queryInfo = queryRunner.getCoordinator().getFullQueryInfo(queryId); assertEquals(queryInfo.getState(), RUNNING); - assertEquals(queryInfo.getOutputStage().get().getState(), FLUSHING); + assertEquals(queryInfo.getOutputStage().get().getState(), StageState.RUNNING); assertEquals(queryInfo.getOutputStage().get().getSubStages().size(), 1); - assertEquals(queryInfo.getOutputStage().get().getSubStages().get(0).getState(), CANCELED); + assertEquals(queryInfo.getOutputStage().get().getSubStages().get(0).getState(), 
StageState.PENDING); } @AfterClass(alwaysRun = true)