diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py
index 98d889cdbb88..bf882a8af311 100644
--- a/vllm/v1/worker/gpu/model_runner.py
+++ b/vllm/v1/worker/gpu/model_runner.py
@@ -91,6 +91,7 @@
 from vllm.v1.worker.gpu.sample.output import SamplerOutput
 from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
 from vllm.v1.worker.gpu.sample.sampler import Sampler
+from vllm.v1.worker.gpu.shutdown import free_before_shutdown
 from vllm.v1.worker.gpu.spec_decode import init_speculator
 from vllm.v1.worker.gpu.spec_decode.eagle.eagle3_utils import (
     set_eagle3_aux_hidden_state_layers,
@@ -1339,6 +1340,24 @@ def postprocess_pool(self, input_batch: InputBatch) -> None:
             input_batch.num_scheduled_tokens
         )
 
+    def shutdown(self) -> None:
+        """Release GPU tensors (model weights, KV caches, workspace) so that
+        memory is reclaimable when running in the same process."""
+        torch.accelerator.synchronize()
+        if hasattr(self, "kv_caches"):
+            self.kv_caches.clear()
+        if hasattr(self, "attn_groups"):
+            self.attn_groups.clear()
+        if hasattr(self, "kv_cache_config"):
+            del self.kv_cache_config
+        free_before_shutdown(self.vllm_config)
+        if hasattr(self, "model"):
+            del self.model
+
+        gc.collect()
+        torch.accelerator.empty_cache()
+        logger.debug("Cleaned up model weights, KV caches, and workspace")
+
     ########### EPLB methods start ###########
     @property
     def eplb_state(self):
diff --git a/vllm/v1/worker/gpu/shutdown.py b/vllm/v1/worker/gpu/shutdown.py
new file mode 100644
index 000000000000..830083962347
--- /dev/null
+++ b/vllm/v1/worker/gpu/shutdown.py
@@ -0,0 +1,20 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def free_before_shutdown(vllm_config: VllmConfig) -> None:
+    from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT
+    from vllm.v1.worker.workspace import reset_workspace_manager
+
+    cache_config = vllm_config.cache_config
+    cache_config.num_gpu_blocks = None
+
+    compilation_config = vllm_config.compilation_config
+    compilation_config.static_forward_context.clear()
+
+    _ROPE_DICT.clear()
+    reset_workspace_manager()