diff --git a/core/trino-main/src/main/java/io/trino/connector/system/NodeSystemTable.java b/core/trino-main/src/main/java/io/trino/connector/system/NodeSystemTable.java index 174635c44c50..4f573a063779 100644 --- a/core/trino-main/src/main/java/io/trino/connector/system/NodeSystemTable.java +++ b/core/trino-main/src/main/java/io/trino/connector/system/NodeSystemTable.java @@ -33,6 +33,8 @@ import static io.trino.metadata.MetadataUtil.TableMetadataBuilder.tableMetadataBuilder; import static io.trino.metadata.NodeState.ACTIVE; +import static io.trino.metadata.NodeState.DRAINED; +import static io.trino.metadata.NodeState.DRAINING; import static io.trino.metadata.NodeState.INACTIVE; import static io.trino.metadata.NodeState.SHUTTING_DOWN; import static io.trino.spi.connector.SystemTable.Distribution.SINGLE_COORDINATOR; @@ -81,6 +83,9 @@ public RecordCursor cursor(ConnectorTransactionHandle transactionHandle, Connect addRows(table, allNodes.getActiveNodes(), ACTIVE); addRows(table, allNodes.getInactiveNodes(), INACTIVE); addRows(table, allNodes.getShuttingDownNodes(), SHUTTING_DOWN); + addRows(table, allNodes.getDrainingNodes(), DRAINING); + addRows(table, allNodes.getDrainedNodes(), DRAINED); + return table.build().cursor(); } diff --git a/core/trino-main/src/main/java/io/trino/memory/ClusterMemoryManager.java b/core/trino-main/src/main/java/io/trino/memory/ClusterMemoryManager.java index bae203ff5640..12744c32e582 100644 --- a/core/trino-main/src/main/java/io/trino/memory/ClusterMemoryManager.java +++ b/core/trino-main/src/main/java/io/trino/memory/ClusterMemoryManager.java @@ -80,6 +80,8 @@ import static io.trino.SystemSessionProperties.getRetryPolicy; import static io.trino.SystemSessionProperties.resourceOvercommit; import static io.trino.metadata.NodeState.ACTIVE; +import static io.trino.metadata.NodeState.DRAINED; +import static io.trino.metadata.NodeState.DRAINING; import static io.trino.metadata.NodeState.SHUTTING_DOWN; import static 
io.trino.spi.StandardErrorCode.CLUSTER_OUT_OF_MEMORY; import static java.lang.Math.min; @@ -433,6 +435,8 @@ private synchronized void updateNodes() Set aliveNodes = builder .addAll(nodeManager.getNodes(ACTIVE)) .addAll(nodeManager.getNodes(SHUTTING_DOWN)) + .addAll(nodeManager.getNodes(DRAINING)) + .addAll(nodeManager.getNodes(DRAINED)) .build(); ImmutableSet aliveNodeIds = aliveNodes.stream() diff --git a/core/trino-main/src/main/java/io/trino/metadata/AllNodes.java b/core/trino-main/src/main/java/io/trino/metadata/AllNodes.java index cecf16b9c9e7..75d34118a95e 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/AllNodes.java +++ b/core/trino-main/src/main/java/io/trino/metadata/AllNodes.java @@ -24,13 +24,22 @@ public class AllNodes { private final Set activeNodes; private final Set inactiveNodes; + private final Set drainingNodes; + private final Set drainedNodes; private final Set shuttingDownNodes; private final Set activeCoordinators; - public AllNodes(Set activeNodes, Set inactiveNodes, Set shuttingDownNodes, Set activeCoordinators) + public AllNodes(Set activeNodes, + Set inactiveNodes, + Set drainingNodes, + Set drainedNodes, + Set shuttingDownNodes, + Set activeCoordinators) { this.activeNodes = ImmutableSet.copyOf(requireNonNull(activeNodes, "activeNodes is null")); this.inactiveNodes = ImmutableSet.copyOf(requireNonNull(inactiveNodes, "inactiveNodes is null")); + this.drainedNodes = ImmutableSet.copyOf(requireNonNull(drainedNodes, "drainedNodes is null")); + this.drainingNodes = ImmutableSet.copyOf(requireNonNull(drainingNodes, "drainingNodes is null")); this.shuttingDownNodes = ImmutableSet.copyOf(requireNonNull(shuttingDownNodes, "shuttingDownNodes is null")); this.activeCoordinators = ImmutableSet.copyOf(requireNonNull(activeCoordinators, "activeCoordinators is null")); } @@ -50,6 +59,16 @@ public Set getShuttingDownNodes() return shuttingDownNodes; } + public Set getDrainedNodes() + { + return drainedNodes; + } + + public Set 
getDrainingNodes() + { + return drainingNodes; + } + public Set getActiveCoordinators() { return activeCoordinators; @@ -67,6 +86,8 @@ public boolean equals(Object o) AllNodes allNodes = (AllNodes) o; return Objects.equals(activeNodes, allNodes.activeNodes) && Objects.equals(inactiveNodes, allNodes.inactiveNodes) && + Objects.equals(drainedNodes, allNodes.drainedNodes) && + Objects.equals(drainingNodes, allNodes.drainingNodes) && Objects.equals(shuttingDownNodes, allNodes.shuttingDownNodes) && Objects.equals(activeCoordinators, allNodes.activeCoordinators); } @@ -74,6 +95,6 @@ public boolean equals(Object o) @Override public int hashCode() { - return Objects.hash(activeNodes, inactiveNodes, shuttingDownNodes, activeCoordinators); + return Objects.hash(activeNodes, inactiveNodes, drainingNodes, drainedNodes, shuttingDownNodes, activeCoordinators); } } diff --git a/core/trino-main/src/main/java/io/trino/metadata/DiscoveryNodeManager.java b/core/trino-main/src/main/java/io/trino/metadata/DiscoveryNodeManager.java index ebc32aa93729..4a6a26240fe5 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/DiscoveryNodeManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/DiscoveryNodeManager.java @@ -18,7 +18,6 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSetMultimap; import com.google.common.collect.SetMultimap; -import com.google.common.collect.Sets; import com.google.common.collect.Sets.SetView; import com.google.errorprone.annotations.ThreadSafe; import com.google.errorprone.annotations.concurrent.GuardedBy; @@ -56,7 +55,6 @@ import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; import static io.trino.connector.system.GlobalSystemConnector.CATALOG_HANDLE; -import static io.trino.metadata.NodeState.ACTIVE; import static io.trino.metadata.NodeState.INACTIVE; import static io.trino.metadata.NodeState.SHUTTING_DOWN; import static 
java.util.Locale.ENGLISH; @@ -169,6 +167,8 @@ private void pollWorkers() AllNodes allNodes = getAllNodes(); Set aliveNodes = ImmutableSet.builder() .addAll(allNodes.getActiveNodes()) + .addAll(allNodes.getDrainingNodes()) + .addAll(allNodes.getDrainedNodes()) .addAll(allNodes.getShuttingDownNodes()) .build(); @@ -216,6 +216,8 @@ private synchronized void refreshNodesInternal() ImmutableSet.Builder activeNodesBuilder = ImmutableSet.builder(); ImmutableSet.Builder inactiveNodesBuilder = ImmutableSet.builder(); + ImmutableSet.Builder drainingNodesBuilder = ImmutableSet.builder(); + ImmutableSet.Builder drainedNodesBuilder = ImmutableSet.builder(); ImmutableSet.Builder shuttingDownNodesBuilder = ImmutableSet.builder(); ImmutableSet.Builder coordinatorsBuilder = ImmutableSet.builder(); ImmutableSetMultimap.Builder byCatalogHandleBuilder = ImmutableSetMultimap.builder(); @@ -250,6 +252,12 @@ private synchronized void refreshNodesInternal() case INACTIVE: inactiveNodesBuilder.add(node); break; + case DRAINING: + drainingNodesBuilder.add(node); + break; + case DRAINED: + drainedNodesBuilder.add(node); + break; case SHUTTING_DOWN: shuttingDownNodesBuilder.add(node); break; @@ -260,12 +268,20 @@ private synchronized void refreshNodesInternal() } Set activeNodes = activeNodesBuilder.build(); + Set drainingNodes = drainingNodesBuilder.build(); + Set drainedNodes = drainedNodesBuilder.build(); Set inactiveNodes = inactiveNodesBuilder.build(); Set coordinators = coordinatorsBuilder.build(); Set shuttingDownNodes = shuttingDownNodesBuilder.build(); if (allNodes != null) { // log node that are no longer active (but not shutting down) - SetView missingNodes = difference(allNodes.getActiveNodes(), Sets.union(activeNodes, shuttingDownNodes)); + Set aliveNodes = ImmutableSet.builder() + .addAll(activeNodes) + .addAll(drainingNodes) + .addAll(drainedNodes) + .addAll(shuttingDownNodes) + .build(); + SetView missingNodes = difference(allNodes.getActiveNodes(), aliveNodes); for 
(InternalNode missingNode : missingNodes) { log.info("Previously active node is missing: %s (last seen at %s)", missingNode.getNodeIdentifier(), missingNode.getHost()); } @@ -276,7 +292,7 @@ private synchronized void refreshNodesInternal() activeNodesByCatalogHandle = Optional.of(byCatalogHandleBuilder.build()); } - AllNodes allNodes = new AllNodes(activeNodes, inactiveNodes, shuttingDownNodes, coordinators); + AllNodes allNodes = new AllNodes(activeNodes, inactiveNodes, drainingNodes, drainedNodes, shuttingDownNodes, coordinators); // only update if all nodes actually changed (note: this does not include the connectors registered with the nodes) if (!allNodes.equals(this.allNodes)) { // assign allNodes to a local variable for use in the callback below @@ -292,21 +308,17 @@ private synchronized void refreshNodesInternal() private NodeState getNodeState(InternalNode node) { if (expectedNodeVersion.equals(node.getNodeVersion())) { - if (isNodeShuttingDown(node.getNodeIdentifier())) { - return SHUTTING_DOWN; - } - return ACTIVE; + String nodeId = node.getNodeIdentifier(); + // The empty case that is being set to a default value of ACTIVE is limited to the case where a node + // has announced itself but no state has yet been successfully retrieved. RemoteNodeState will retain + // the previously known state if any has been reported. 
+ return Optional.ofNullable(nodeStates.get(nodeId)) + .flatMap(RemoteNodeState::getNodeState) + .orElse(NodeState.ACTIVE); } return INACTIVE; } - private boolean isNodeShuttingDown(String nodeId) - { - return Optional.ofNullable(nodeStates.get(nodeId)) - .flatMap(RemoteNodeState::getNodeState) - .orElse(NodeState.ACTIVE) == SHUTTING_DOWN; - } - @Override public synchronized AllNodes getAllNodes() { @@ -325,6 +337,18 @@ public int getInactiveNodeCount() return getAllNodes().getInactiveNodes().size(); } + @Managed + public int getDrainingNodeCount() + { + return getAllNodes().getDrainingNodes().size(); + } + + @Managed + public int getDrainedNodeCount() + { + return getAllNodes().getDrainedNodes().size(); + } + @Managed public int getShuttingDownNodeCount() { @@ -337,6 +361,8 @@ public Set getNodes(NodeState state) return switch (state) { case ACTIVE -> getAllNodes().getActiveNodes(); case INACTIVE -> getAllNodes().getInactiveNodes(); + case DRAINING -> getAllNodes().getDrainingNodes(); + case DRAINED -> getAllNodes().getDrainedNodes(); case SHUTTING_DOWN -> getAllNodes().getShuttingDownNodes(); }; } diff --git a/core/trino-main/src/main/java/io/trino/metadata/InMemoryNodeManager.java b/core/trino-main/src/main/java/io/trino/metadata/InMemoryNodeManager.java index 9312dbbf7a1c..31a1027b0bb9 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/InMemoryNodeManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/InMemoryNodeManager.java @@ -61,7 +61,7 @@ public Set getNodes(NodeState state) { return switch (state) { case ACTIVE -> ImmutableSet.copyOf(allNodes); - case INACTIVE, SHUTTING_DOWN -> ImmutableSet.of(); + case DRAINING, DRAINED, INACTIVE, SHUTTING_DOWN -> ImmutableSet.of(); }; } @@ -84,6 +84,8 @@ public AllNodes getAllNodes() ImmutableSet.copyOf(allNodes), ImmutableSet.of(), ImmutableSet.of(), + ImmutableSet.of(), + ImmutableSet.of(), ImmutableSet.of(CURRENT_NODE)); } diff --git 
a/core/trino-main/src/main/java/io/trino/metadata/NodeState.java b/core/trino-main/src/main/java/io/trino/metadata/NodeState.java index 4bf511968a31..e3370da594d6 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/NodeState.java +++ b/core/trino-main/src/main/java/io/trino/metadata/NodeState.java @@ -15,7 +15,24 @@ public enum NodeState { + /** + * Server is up and running, ready to handle tasks + */ ACTIVE, + /** + * Never used internally, might be used by DiscoveryNodeManager when a communication error occurs + */ INACTIVE, + /** + * A reversible graceful shutdown, can go forward to DRAINED or back to ACTIVE. + */ + DRAINING, + /** + * All tasks are finished, server can be safely and quickly stopped. Can also go back to ACTIVE. + */ + DRAINED, + /** + * Graceful shutdown, non-reversible, when observed will drain and terminate + */ SHUTTING_DOWN } diff --git a/core/trino-main/src/main/java/io/trino/server/GracefulShutdownHandler.java b/core/trino-main/src/main/java/io/trino/server/GracefulShutdownHandler.java deleted file mode 100644 index fbef6a2c215b..000000000000 --- a/core/trino-main/src/main/java/io/trino/server/GracefulShutdownHandler.java +++ /dev/null @@ -1,157 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package io.trino.server; - -import com.google.errorprone.annotations.concurrent.GuardedBy; -import com.google.inject.Inject; -import io.airlift.bootstrap.LifeCycleManager; -import io.airlift.log.Logger; -import io.airlift.units.Duration; -import io.trino.execution.SqlTaskManager; -import io.trino.execution.TaskInfo; - -import java.util.List; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Future; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeoutException; - -import static com.google.common.collect.ImmutableList.toImmutableList; -import static com.google.common.util.concurrent.Uninterruptibles.sleepUninterruptibly; -import static io.airlift.concurrent.Threads.threadsNamed; -import static java.lang.Thread.currentThread; -import static java.util.Objects.requireNonNull; -import static java.util.concurrent.Executors.newSingleThreadExecutor; -import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; -import static java.util.concurrent.TimeUnit.MILLISECONDS; -import static java.util.concurrent.TimeUnit.SECONDS; - -public class GracefulShutdownHandler -{ - private static final Logger log = Logger.get(GracefulShutdownHandler.class); - private static final Duration LIFECYCLE_STOP_TIMEOUT = new Duration(30, SECONDS); - - private final ScheduledExecutorService shutdownHandler = newSingleThreadScheduledExecutor(threadsNamed("shutdown-handler-%s")); - private final ExecutorService lifeCycleStopper = newSingleThreadExecutor(threadsNamed("lifecycle-stopper-%s")); - private final LifeCycleManager lifeCycleManager; - private final SqlTaskManager sqlTaskManager; - private final boolean isCoordinator; - private final ShutdownAction shutdownAction; - private final Duration gracePeriod; - - @GuardedBy("this") - private boolean shutdownRequested; - - @Inject - public GracefulShutdownHandler( - 
SqlTaskManager sqlTaskManager, - ServerConfig serverConfig, - ShutdownAction shutdownAction, - LifeCycleManager lifeCycleManager) - { - this.sqlTaskManager = requireNonNull(sqlTaskManager, "sqlTaskManager is null"); - this.shutdownAction = requireNonNull(shutdownAction, "shutdownAction is null"); - this.lifeCycleManager = requireNonNull(lifeCycleManager, "lifeCycleManager is null"); - this.isCoordinator = serverConfig.isCoordinator(); - this.gracePeriod = serverConfig.getGracePeriod(); - } - - public synchronized void requestShutdown() - { - log.info("Shutdown requested"); - - if (isCoordinator) { - throw new UnsupportedOperationException("Cannot shutdown coordinator"); - } - - if (shutdownRequested) { - return; - } - shutdownRequested = true; - - // wait for a grace period (so that shutting down state is observed by the coordinator) to start the shutdown sequence - shutdownHandler.schedule(this::shutdown, gracePeriod.toMillis(), MILLISECONDS); - } - - private void shutdown() - { - List activeTasks = getActiveTasks(); - - // At this point no new tasks should be scheduled by coordinator on this worker node. - // Wait for all remaining tasks to finish. 
- while (activeTasks.size() > 0) { - CountDownLatch countDownLatch = new CountDownLatch(activeTasks.size()); - - for (TaskInfo taskInfo : activeTasks) { - sqlTaskManager.addStateChangeListener(taskInfo.taskStatus().getTaskId(), newState -> { - if (newState.isDone()) { - countDownLatch.countDown(); - } - }); - } - - log.info("Waiting for all tasks to finish"); - - try { - countDownLatch.await(); - } - catch (InterruptedException e) { - log.warn("Interrupted while waiting for all tasks to finish"); - currentThread().interrupt(); - } - - activeTasks = getActiveTasks(); - } - - // wait for another grace period for all task states to be observed by the coordinator - sleepUninterruptibly(gracePeriod.toMillis(), MILLISECONDS); - - Future shutdownFuture = lifeCycleStopper.submit(() -> { - lifeCycleManager.stop(); - return null; - }); - - // terminate the jvm if life cycle cannot be stopped in a timely manner - try { - shutdownFuture.get(LIFECYCLE_STOP_TIMEOUT.toMillis(), MILLISECONDS); - } - catch (TimeoutException e) { - log.warn(e, "Timed out waiting for the life cycle to stop"); - } - catch (InterruptedException e) { - log.warn(e, "Interrupted while waiting for the life cycle to stop"); - currentThread().interrupt(); - } - catch (ExecutionException e) { - log.warn(e, "Problem stopping the life cycle"); - } - - shutdownAction.onShutdown(); - } - - private List getActiveTasks() - { - return sqlTaskManager.getAllTaskInfo() - .stream() - .filter(taskInfo -> !taskInfo.taskStatus().getState().isDone()) - .collect(toImmutableList()); - } - - public synchronized boolean isShutdownRequested() - { - return shutdownRequested; - } -} diff --git a/core/trino-main/src/main/java/io/trino/server/NodeStateManager.java b/core/trino-main/src/main/java/io/trino/server/NodeStateManager.java new file mode 100644 index 000000000000..8b9d3cdaa577 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/NodeStateManager.java @@ -0,0 +1,283 @@ +/* + * Licensed under the Apache License, 
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server; + +import com.google.inject.Inject; +import io.airlift.bootstrap.LifeCycleManager; +import io.airlift.log.Logger; +import io.airlift.units.Duration; +import io.trino.execution.SqlTaskManager; +import io.trino.execution.TaskInfo; +import io.trino.metadata.NodeState; + +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.util.concurrent.Uninterruptibles.sleepUninterruptibly; +import static io.airlift.concurrent.Threads.threadsNamed; +import static io.trino.metadata.NodeState.ACTIVE; +import static io.trino.metadata.NodeState.DRAINED; +import static io.trino.metadata.NodeState.DRAINING; +import static io.trino.metadata.NodeState.SHUTTING_DOWN; +import static java.lang.String.format; +import static java.lang.Thread.currentThread; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.Executors.newSingleThreadExecutor; +import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; +import static java.util.concurrent.TimeUnit.MILLISECONDS; 
+import static java.util.concurrent.TimeUnit.SECONDS; + +public class NodeStateManager +{ + private static final Logger log = Logger.get(NodeStateManager.class); + private static final Duration LIFECYCLE_STOP_TIMEOUT = new Duration(30, SECONDS); + + private final ScheduledExecutorService shutdownHandler = newSingleThreadScheduledExecutor(threadsNamed("shutdown-handler-%s")); + private final ExecutorService lifeCycleStopper = newSingleThreadExecutor(threadsNamed("lifecycle-stopper-%s")); + private final LifeCycleManager lifeCycleManager; + private final SqlTaskManager sqlTaskManager; + private final boolean isCoordinator; + private final ShutdownAction shutdownAction; + private final Duration gracePeriod; + + private final ScheduledExecutorService executor = newSingleThreadScheduledExecutor(threadsNamed("drain-handler-%s")); + private final AtomicReference nodeState = new AtomicReference<>(ACTIVE); + + @Inject + public NodeStateManager( + SqlTaskManager sqlTaskManager, + ServerConfig serverConfig, + ShutdownAction shutdownAction, + LifeCycleManager lifeCycleManager) + { + this.sqlTaskManager = requireNonNull(sqlTaskManager, "sqlTaskManager is null"); + this.shutdownAction = requireNonNull(shutdownAction, "shutdownAction is null"); + this.lifeCycleManager = requireNonNull(lifeCycleManager, "lifeCycleManager is null"); + this.isCoordinator = serverConfig.isCoordinator(); + this.gracePeriod = serverConfig.getGracePeriod(); + } + + public NodeState getServerState() + { + return nodeState.get(); + } + + /* + Below is a diagram with possible states and transitions + + @startuml + [*] --> ACTIVE + note "state INACTIVE is not used internally\nis only used when the service is unavailable " as a + ACTIVE --> SHUTTING_DOWN : shutdown + ACTIVE --> DRAINING : drain + DRAINING --> ACTIVE: reactivate + DRAINING --> DRAINED + DRAINING --> SHUTTING_DOWN : gracefulShutdown + DRAINED --> ACTIVE: reactivate + DRAINED --> SHUTTING_DOWN : terminate + SHUTTING_DOWN --> [*] + @enduml + + 
NOTE: SHUTTING_DOWN is treated as one-way transition to be 100% backwards compatible. + */ + public synchronized void transitionState(NodeState state) + { + NodeState currState = nodeState.get(); + if (currState == state) { + return; + } + + switch (state) { + case ACTIVE -> { + if (currState == DRAINING && nodeState.compareAndSet(DRAINING, ACTIVE)) { + return; + } + if (currState == DRAINED && nodeState.compareAndSet(DRAINED, ACTIVE)) { + return; + } + } + case SHUTTING_DOWN -> { + if (currState == DRAINED && nodeState.compareAndSet(DRAINED, SHUTTING_DOWN)) { + requestTerminate(); + return; + } + requestGracefulShutdown(); + nodeState.set(SHUTTING_DOWN); + return; + } + case DRAINING -> { + if (currState == ACTIVE && nodeState.compareAndSet(ACTIVE, DRAINING)) { + requestDrain(); + return; + } + } + case DRAINED -> throw new IllegalStateException(format("Invalid state transition from %s to %s, transition to DRAINED is internal only", currState, state)); + + case INACTIVE -> throw new IllegalStateException(format("Invalid state transition from %s to %s, INACTIVE is not a valid internal state", currState, state)); + } + + throw new IllegalStateException(format("Invalid state transition from %s to %s", currState, state)); + } + + private synchronized void requestDrain() + { + log.debug("Drain requested, NodeState: " + getServerState()); + if (isCoordinator) { + throw new UnsupportedOperationException("Cannot drain coordinator"); + } + + // wait for a grace period (so that draining state is observed by the coordinator) before starting draining + // when coordinator observes draining no new tasks are assigned to this worker + executor.schedule(this::drain, gracePeriod.toMillis(), MILLISECONDS); + } + + private void requestTerminate() + { + log.info("Immediate Shutdown requested"); + if (isCoordinator) { + throw new UnsupportedOperationException("Cannot shutdown coordinator"); + } + + shutdownHandler.schedule(this::terminate, 0, MILLISECONDS); + } + + private void 
requestGracefulShutdown() + { + log.info("Shutdown requested"); + if (isCoordinator) { + throw new UnsupportedOperationException("Cannot shutdown coordinator"); + } + + // wait for a grace period (so that shutting down state is observed by the coordinator) to start the shutdown sequence + shutdownHandler.schedule(this::shutdown, gracePeriod.toMillis(), MILLISECONDS); + } + + private void shutdown() + { + waitActiveTasksToFinish(); + + terminate(); + } + + private void terminate() + { + Future shutdownFuture = lifeCycleStopper.submit(() -> { + lifeCycleManager.stop(); + return null; + }); + // terminate the jvm if life cycle cannot be stopped in a timely manner + try { + shutdownFuture.get(LIFECYCLE_STOP_TIMEOUT.toMillis(), MILLISECONDS); + } + catch (TimeoutException e) { + log.warn(e, "Timed out waiting for the life cycle to stop"); + } + catch (InterruptedException e) { + log.warn(e, "Interrupted while waiting for the life cycle to stop"); + currentThread().interrupt(); + } + catch (ExecutionException e) { + log.warn(e, "Problem stopping the life cycle"); + } + shutdownAction.onShutdown(); + } + + private void drain() + { + if (nodeState.get() == DRAINING) { + waitActiveTasksToFinish(); + } + drainingComplete(); + } + + private void drainingComplete() + { + boolean success = nodeState.compareAndSet(DRAINING, DRAINED); + if (success) { + log.info("NodeState: DRAINED, server can be safely SHUT DOWN."); + } + else { + log.info("NodeState: " + nodeState.get() + ", will not transition to DRAINED"); + } + } + + private void waitActiveTasksToFinish() + { + // At this point no new tasks should be scheduled by coordinator on this worker node. + // Wait for all remaining tasks to finish. 
+ while (isShuttingDownOrDraining()) { + List activeTasks = getActiveTasks(); + log.info("Waiting for " + activeTasks.size() + " active tasks to finish"); + if (activeTasks.isEmpty()) { + break; + } + + waitTasksToFinish(activeTasks); + } + + // wait for another grace period for all task states to be observed by the coordinator + if (isShuttingDownOrDraining()) { + sleepUninterruptibly(gracePeriod.toMillis(), MILLISECONDS); + } + } + + private void waitTasksToFinish(List activeTasks) + { + final CountDownLatch countDownLatch = new CountDownLatch(activeTasks.size()); + + for (TaskInfo taskInfo : activeTasks) { + sqlTaskManager.addStateChangeListener(taskInfo.taskStatus().getTaskId(), newState -> { + if (newState.isDone()) { + countDownLatch.countDown(); + } + }); + } + + try { + while (!countDownLatch.await(1, TimeUnit.SECONDS)) { + if (!isShuttingDownOrDraining()) { + log.info("Wait for tasks interrupted, worker is no longer draining."); + + break; + } + } + } + catch (InterruptedException e) { + log.warn("Interrupted while waiting for all tasks to finish"); + currentThread().interrupt(); + } + } + + private boolean isShuttingDownOrDraining() + { + NodeState state = nodeState.get(); + return state == SHUTTING_DOWN || state == DRAINING; + } + + private List getActiveTasks() + { + return sqlTaskManager.getAllTaskInfo() + .stream() + .filter(taskInfo -> !taskInfo.taskStatus().getState().isDone()) + .collect(toImmutableList()); + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/GracefulShutdownModule.java b/core/trino-main/src/main/java/io/trino/server/NodeStateManagerModule.java similarity index 89% rename from core/trino-main/src/main/java/io/trino/server/GracefulShutdownModule.java rename to core/trino-main/src/main/java/io/trino/server/NodeStateManagerModule.java index cdb83a90e1f7..02422c2e47e6 100644 --- a/core/trino-main/src/main/java/io/trino/server/GracefulShutdownModule.java +++ 
b/core/trino-main/src/main/java/io/trino/server/NodeStateManagerModule.java @@ -17,13 +17,13 @@ import com.google.inject.Scopes; import io.airlift.configuration.AbstractConfigurationAwareModule; -public class GracefulShutdownModule +public class NodeStateManagerModule extends AbstractConfigurationAwareModule { @Override protected void setup(Binder binder) { binder.bind(ShutdownAction.class).to(DefaultShutdownAction.class).in(Scopes.SINGLETON); - binder.bind(GracefulShutdownHandler.class).in(Scopes.SINGLETON); + binder.bind(NodeStateManager.class).in(Scopes.SINGLETON); } } diff --git a/core/trino-main/src/main/java/io/trino/server/Server.java b/core/trino-main/src/main/java/io/trino/server/Server.java index f5ba8fbb1c55..fd7bf81b84bb 100644 --- a/core/trino-main/src/main/java/io/trino/server/Server.java +++ b/core/trino-main/src/main/java/io/trino/server/Server.java @@ -128,7 +128,7 @@ private void doStart(String trinoVersion) new CatalogManagerModule(), new TransactionManagerModule(), new ServerMainModule(trinoVersion), - new GracefulShutdownModule(), + new NodeStateManagerModule(), new WarningCollectorModule()); modules.addAll(getAdditionalModules()); diff --git a/core/trino-main/src/main/java/io/trino/server/ServerInfoResource.java b/core/trino-main/src/main/java/io/trino/server/ServerInfoResource.java index 2064080b3506..f85d7a5b6c23 100644 --- a/core/trino-main/src/main/java/io/trino/server/ServerInfoResource.java +++ b/core/trino-main/src/main/java/io/trino/server/ServerInfoResource.java @@ -14,6 +14,7 @@ package io.trino.server; import com.google.inject.Inject; +import io.airlift.log.Logger; import io.airlift.node.NodeInfo; import io.trino.client.NodeVersion; import io.trino.client.ServerInfo; @@ -31,8 +32,6 @@ import java.util.Optional; import static io.airlift.units.Duration.nanosSince; -import static io.trino.metadata.NodeState.ACTIVE; -import static io.trino.metadata.NodeState.SHUTTING_DOWN; import static 
io.trino.server.security.ResourceSecurity.AccessType.MANAGEMENT_WRITE; import static io.trino.server.security.ResourceSecurity.AccessType.PUBLIC; import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON; @@ -43,20 +42,22 @@ @Path("/v1/info") public class ServerInfoResource { + private static final Logger log = Logger.get(ServerInfoResource.class); + private final NodeVersion version; private final String environment; private final boolean coordinator; - private final GracefulShutdownHandler shutdownHandler; + private final NodeStateManager nodeStateManager; private final StartupStatus startupStatus; private final long startTime = System.nanoTime(); @Inject - public ServerInfoResource(NodeVersion nodeVersion, NodeInfo nodeInfo, ServerConfig serverConfig, GracefulShutdownHandler shutdownHandler, StartupStatus startupStatus) + public ServerInfoResource(NodeVersion nodeVersion, NodeInfo nodeInfo, ServerConfig serverConfig, NodeStateManager nodeStateManager, StartupStatus startupStatus) { this.version = requireNonNull(nodeVersion, "nodeVersion is null"); this.environment = nodeInfo.getEnvironment(); this.coordinator = serverConfig.isCoordinator(); - this.shutdownHandler = requireNonNull(shutdownHandler, "shutdownHandler is null"); + this.nodeStateManager = requireNonNull(nodeStateManager, "nodeStateManager is null"); this.startupStatus = requireNonNull(startupStatus, "startupStatus is null"); } @@ -77,13 +78,14 @@ public ServerInfo getInfo() public Response updateState(NodeState state) { requireNonNull(state, "state is null"); - return switch (state) { - case SHUTTING_DOWN -> { - shutdownHandler.requestShutdown(); - yield Response.ok().build(); - } - case ACTIVE, INACTIVE -> throw new BadRequestException(format("Invalid state transition to %s", state)); - }; + log.info("Worker State change requested: %s -> %s", nodeStateManager.getServerState().toString(), state.toString()); + try { + nodeStateManager.transitionState(state); + return Response.ok().build(); + } + 
catch (IllegalStateException e) { + throw new BadRequestException(format("Invalid state transition to %s", state)); + } } @ResourceSecurity(PUBLIC) @@ -92,10 +94,7 @@ public Response updateState(NodeState state) @Produces(APPLICATION_JSON) public NodeState getServerState() { - if (shutdownHandler.isShutdownRequested()) { - return SHUTTING_DOWN; - } - return ACTIVE; + return nodeStateManager.getServerState(); } @ResourceSecurity(PUBLIC) diff --git a/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java b/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java index 042ae70e3e7d..321628772c14 100644 --- a/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java +++ b/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java @@ -75,7 +75,7 @@ import io.trino.security.AccessControlConfig; import io.trino.security.AccessControlManager; import io.trino.security.GroupProviderManager; -import io.trino.server.GracefulShutdownHandler; +import io.trino.server.NodeStateManager; import io.trino.server.PluginInstaller; import io.trino.server.PrefixObjectNameGeneratorModule; import io.trino.server.QuerySessionSupplier; @@ -209,7 +209,7 @@ public static Builder builder() private final DispatchManager dispatchManager; private final SqlQueryManager queryManager; private final SqlTaskManager taskManager; - private final GracefulShutdownHandler gracefulShutdownHandler; + private final NodeStateManager nodeStateManager; private final ShutdownAction shutdownAction; private final MBeanServer mBeanServer; private final boolean coordinator; @@ -327,7 +327,7 @@ private TestingTrinoServer( binder.bind(AccessControl.class).annotatedWith(ForTracing.class).to(AccessControlManager.class).in(Scopes.SINGLETON); binder.bind(AccessControl.class).to(TracingAccessControl.class).in(Scopes.SINGLETON); binder.bind(ShutdownAction.class).to(TestShutdownAction.class).in(Scopes.SINGLETON); - 
binder.bind(GracefulShutdownHandler.class).in(Scopes.SINGLETON); + binder.bind(NodeStateManager.class).in(Scopes.SINGLETON); binder.bind(ProcedureTester.class).in(Scopes.SINGLETON); binder.bind(ExchangeManagerRegistry.class).in(Scopes.SINGLETON); spanProcessor.ifPresent(processor -> newSetBinder(binder, SpanProcessor.class).addBinding().toInstance(processor)); @@ -426,7 +426,7 @@ private TestingTrinoServer( localMemoryManager = injector.getInstance(LocalMemoryManager.class); nodeManager = injector.getInstance(InternalNodeManager.class); serviceSelectorManager = injector.getInstance(ServiceSelectorManager.class); - gracefulShutdownHandler = injector.getInstance(GracefulShutdownHandler.class); + nodeStateManager = injector.getInstance(NodeStateManager.class); taskManager = injector.getInstance(SqlTaskManager.class); shutdownAction = injector.getInstance(ShutdownAction.class); mBeanServer = injector.getInstance(MBeanServer.class); @@ -666,9 +666,9 @@ public MBeanServer getMbeanServer() return mBeanServer; } - public GracefulShutdownHandler getGracefulShutdownHandler() + public NodeStateManager getNodeStateManager() { - return gracefulShutdownHandler; + return nodeStateManager; } public SqlTaskManager getTaskManager() diff --git a/testing/trino-tests/src/test/java/io/trino/tests/TestGracefulShutdown.java b/testing/trino-tests/src/test/java/io/trino/tests/TestGracefulShutdown.java index ecb52a979e8c..06545ea9bffc 100644 --- a/testing/trino-tests/src/test/java/io/trino/tests/TestGracefulShutdown.java +++ b/testing/trino-tests/src/test/java/io/trino/tests/TestGracefulShutdown.java @@ -20,6 +20,7 @@ import com.google.common.util.concurrent.MoreExecutors; import io.trino.Session; import io.trino.execution.SqlTaskManager; +import io.trino.metadata.NodeState; import io.trino.server.BasicQueryInfo; import io.trino.server.testing.TestingTrinoServer; import io.trino.server.testing.TestingTrinoServer.TestShutdownAction; @@ -105,7 +106,7 @@ public void testShutdown() 
MILLISECONDS.sleep(500); } - worker.getGracefulShutdownHandler().requestShutdown(); + worker.getNodeStateManager().transitionState(NodeState.SHUTTING_DOWN); Futures.allAsList(queryFutures).get(); @@ -131,7 +132,7 @@ public void testCoordinatorShutdown() .filter(TestingTrinoServer::isCoordinator) .collect(onlyElement()); - assertThatThrownBy(coordinator.getGracefulShutdownHandler()::requestShutdown) + assertThatThrownBy(() -> coordinator.getNodeStateManager().transitionState(NodeState.SHUTTING_DOWN)) .isInstanceOf(UnsupportedOperationException.class) .hasMessage("Cannot shutdown coordinator"); }