[MRV2] Add shutdown() method #41297
```diff
@@ -91,6 +91,7 @@
 from vllm.v1.worker.gpu.sample.output import SamplerOutput
 from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
 from vllm.v1.worker.gpu.sample.sampler import Sampler
+from vllm.v1.worker.gpu.shutdown import free_before_shutdown
 from vllm.v1.worker.gpu.spec_decode import init_speculator
 from vllm.v1.worker.gpu.spec_decode.eagle.eagle3_utils import (
     set_eagle3_aux_hidden_state_layers,
```

```diff
@@ -1339,6 +1340,24 @@ def postprocess_pool(self, input_batch: InputBatch) -> None:
             input_batch.num_scheduled_tokens
         )
 
+    def shutdown(self) -> None:
+        """Release GPU tensors (model weights, KV caches, workspace) so that
+        memory is reclaimable when running in the same process."""
+        torch.accelerator.synchronize()
+        if hasattr(self, "kv_caches"):
+            self.kv_caches.clear()
+        if hasattr(self, "attn_groups"):
+            self.attn_groups.clear()
+        if hasattr(self, "kv_cache_config"):
+            del self.kv_cache_config
+        free_before_shutdown(self.vllm_config)
+        if hasattr(self, "model"):
+            del self.model
+
+        gc.collect()
+        torch.accelerator.empty_cache()
+        logger.debug("Cleaned up model weights, KV caches, and workspace")
```
**Comment on lines +1343 to +1359**

🔴 **What is leaked.** `shutdown()` removes only some of the GPU-resident references; attributes that still alias the model weights and the captured CUDA graphs keep those tensors alive.

**Why the existing cleanup does not save this.** `gc.collect()` and `empty_cache()` can only reclaim memory once the last strong reference is gone; the caller and the surviving attributes still hold references, so the collection pass frees little.

**Impact.** This defeats the docstring's contract: "Release GPU tensors ... so that memory is reclaimable when running in the same process." In same-process scenarios (e.g., a hosting framework that loads/unloads vLLM engines), the bulk of GPU memory, namely the model weights and the captured-graph memory pool, often the two largest allocations, is not actually reclaimable.

**Fix.** Add explicit deletions for the remaining GPU-holding references.
```diff
     ########### EPLB methods start ###########
     @property
     def eplb_state(self):
```
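The release-then-collect sequence in `shutdown()` can be sketched with plain Python objects. Everything below (`ToyRunner`, `Weights`) is illustrative and not part of vLLM; the `torch.accelerator.synchronize()` and `empty_cache()` calls are omitted because the sketch has no GPU:

```python
import gc
import weakref


class Weights:
    """Stand-in for a large GPU allocation (illustrative only)."""


class ToyRunner:
    """Minimal stand-in for a model runner; not vLLM's actual class."""

    def __init__(self) -> None:
        self.model = Weights()
        self.kv_caches = [Weights(), Weights()]

    def shutdown(self) -> None:
        # Mirror the PR's sequence: clear containers, drop attributes,
        # then force a collection pass.
        self.kv_caches.clear()
        if hasattr(self, "model"):
            del self.model
        gc.collect()


runner = ToyRunner()
model_ref = weakref.ref(runner.model)
runner.shutdown()
assert model_ref() is None  # no strong references remain, object reclaimed
```

The key property demonstrated: once the last strong reference is dropped, the object becomes collectable, which is exactly what `empty_cache()` needs before cached device blocks can be returned.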
```diff
@@ -0,0 +1,20 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from vllm.config import VllmConfig
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def free_before_shutdown(vllm_config: VllmConfig) -> None:
+    from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT
+    from vllm.v1.worker.workspace import reset_workspace_manager
+
+    cache_config = vllm_config.cache_config
+    cache_config.num_gpu_blocks = None
+
+    compilation_config = vllm_config.compilation_config
+    compilation_config.static_forward_context.clear()
+
+    _ROPE_DICT.clear()
+    reset_workspace_manager()
```
The `shutdown` method is missing several critical components that hold significant GPU memory. Most importantly, `self.model_state` holds a reference to `self.model`, so `del self.model` will not actually free the model weights unless `self.model_state` is also deleted. Additionally, `self.cudagraph_manager` (which holds captured CUDA graphs), `self.speculator`, `self.intermediate_tensors`, `self.encoder_cache`, and `self.pooling_runner` should be explicitly deleted to ensure all memory is reclaimable when running in the same process.
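The aliasing problem this comment describes can be demonstrated with plain Python reference counting. The class names below are illustrative; only the attribute relationship mirrors the diagnosis:

```python
import weakref


class Weights:
    """Stand-in for a large GPU tensor (illustrative only)."""


class Runner:
    """Illustrative runner; the attributes mirror the comment's diagnosis."""

    def __init__(self) -> None:
        self.model = Weights()
        # A second attribute aliasing the same object, analogous to
        # self.model_state holding a reference to self.model.
        self.model_state = self.model


runner = Runner()
ref = weakref.ref(runner.model)

del runner.model
assert ref() is not None  # model_state still keeps the weights alive

del runner.model_state
assert ref() is None      # last strong reference gone; memory reclaimable
```

This is why deleting `self.model` alone leaves the weights resident: every aliasing attribute must be dropped before the allocator can see the memory as free.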