[MRV2] Add shutdown() method #41297
Conversation
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Claude Code Review
This repository is configured for manual code reviews. Comment @claude review to trigger a review and subscribe this PR to future pushes, or @claude review once for a one-time review.
Tip: disable this comment in your organization's Code Review settings.
@claude review once
Code Review
This pull request introduces a shutdown method to the GPU ModelRunner and a free_before_shutdown utility to release GPU memory and resources. Feedback highlights that the shutdown implementation is incomplete because it fails to clear several objects that hold GPU memory, such as model_state, cudagraph_manager, and speculator, which prevents the model weights from being fully reclaimed.
def shutdown(self) -> None:
    """Release GPU tensors (model weights, KV caches, workspace) so that
    memory is reclaimable when running in the same process."""
    torch.accelerator.synchronize()
    if hasattr(self, "kv_caches"):
        self.kv_caches.clear()
    if hasattr(self, "attn_groups"):
        self.attn_groups.clear()
    if hasattr(self, "kv_cache_config"):
        del self.kv_cache_config
    free_before_shutdown(self.vllm_config)
    if hasattr(self, "model"):
        del self.model

    gc.collect()
    torch.accelerator.empty_cache()
    logger.debug("Cleaned up model weights, KV caches, and workspace")
The shutdown method is missing several critical components that hold significant GPU memory. Most importantly, self.model_state holds a reference to self.model, so del self.model will not actually free the model weights unless self.model_state is also deleted. Additionally, self.cudagraph_manager (which holds captured CUDA graphs), self.speculator, self.intermediate_tensors, self.encoder_cache, and self.pooling_runner should be explicitly deleted to ensure all memory is reclaimable when running in the same process.
def shutdown(self) -> None:
    """Release GPU tensors (model weights, KV caches, workspace) so that
    memory is reclaimable when running in the same process."""
    torch.accelerator.synchronize()
    if hasattr(self, "kv_caches"):
        self.kv_caches.clear()
    if hasattr(self, "attn_groups"):
        self.attn_groups.clear()
    if hasattr(self, "kv_cache_config"):
        del self.kv_cache_config
    if hasattr(self, "model_state"):
        del self.model_state
    if hasattr(self, "cudagraph_manager"):
        del self.cudagraph_manager
    if hasattr(self, "speculator"):
        del self.speculator
    if hasattr(self, "intermediate_tensors"):
        del self.intermediate_tensors
    if hasattr(self, "encoder_cache"):
        del self.encoder_cache
    if hasattr(self, "pooling_runner"):
        del self.pooling_runner
    free_before_shutdown(self.vllm_config)
    if hasattr(self, "model"):
        del self.model
    gc.collect()
    torch.accelerator.empty_cache()
    logger.debug("Cleaned up model weights, KV caches, and workspace")
| """Release GPU tensors (model weights, KV caches, workspace) so that | ||
| memory is reclaimable when running in the same process.""" | ||
| torch.accelerator.synchronize() | ||
| if hasattr(self, "kv_caches"): | ||
| self.kv_caches.clear() | ||
| if hasattr(self, "attn_groups"): | ||
| self.attn_groups.clear() | ||
| if hasattr(self, "kv_cache_config"): | ||
| del self.kv_cache_config | ||
| free_before_shutdown(self.vllm_config) | ||
| if hasattr(self, "model"): | ||
| del self.model | ||
|
|
||
| gc.collect() | ||
| torch.accelerator.empty_cache() | ||
| logger.debug("Cleaned up model weights, KV caches, and workspace") |
🔴 The new shutdown() does not actually release the bulk of GPU memory it claims to free: self.model_state, self.speculator, self.pooling_runner, and self.cudagraph_manager all hold strong references to the model weights and/or large GPU buffers (captured CUDA graphs are typically several GiB), so del self.model followed by gc.collect()/empty_cache() cannot reclaim them. Also missing are self.intermediate_tensors (sized at max_num_batched_tokens on non-first PP ranks) and self.encoder_cache (multimodal embeddings dict). Fix by explicitly del'ing these attributes before the final gc.collect() (matches the inline review already left by gemini-code-assist).
Extended reasoning...
What is leaked
The shutdown method only does del self.model plus clearing kv_caches/attn_groups/kv_cache_config. It does not touch the following attributes on the runner instance, each of which independently pins GPU memory:
- self.model_state: init_model_state(...) returns a DefaultModelState (or WhisperModelState); vllm/v1/worker/gpu/model_states/default.py:34 does self.model = model. So after del self.model on the runner, the model module is still strongly referenced via self.model_state.model; the refcount stays positive, gc cannot collect it, and empty_cache() cannot reclaim the weight memory.
- self.speculator: vllm/v1/worker/gpu/spec_decode/eagle/speculator.py:153 does self.model = load_eagle_model(target_model, self.vllm_config). The draft model is owned here, and the speculator also holds self.target_model after load_model(self.model). Without releasing self.speculator, both target and draft weights stay pinned.
- self.pooling_runner: vllm/v1/worker/gpu/pool/pooling_runner.py:20 stores self.model = cast(VllmModelForPooling, model). For pooling models this is the same anti-pattern: del self.model cannot free weights while the runner is alive.
- self.cudagraph_manager: ModelCudaGraphManager holds self.graphs: dict[..., torch.cuda.CUDAGraph] plus persistent self.hidden_states, self.aux_hidden_states, and self.intermediate_tensors GPU buffers. capture_model() itself logs the size as "took %.2f GiB", typically several GiB. Captured CUDA graphs additionally hold internal references to the model parameter tensors baked in at capture time.
- self.intermediate_tensors: for non-first PP ranks, load_model() (model_runner.py:319-323) creates a persistent IntermediateTensors sized at max_num_batched_tokens via make_empty_intermediate_tensors(...). It is not in static_forward_context or kv_caches, so nothing else releases it.
- self.encoder_cache: vllm/v1/worker/gpu/mm/encoder_cache.py:13 keeps encoder_outputs: dict[str, torch.Tensor] of multimodal vision embeddings on GPU.
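As a minimal sketch of the aliasing problem (not vLLM code; Runner and ModelStateStub are hypothetical stand-ins for GPUModelRunner and DefaultModelState), deleting only the runner's own attribute does nothing while a wrapper still holds the module:

import gc

import torch

class ModelStateStub:
    def __init__(self, model: torch.nn.Module) -> None:
        # Mirrors the pattern where a model-state wrapper keeps a strong
        # reference to the same module the runner owns.
        self.model = model

class Runner:
    def __init__(self) -> None:
        self.model = torch.nn.Linear(1024, 1024)       # stand-in for the real weights
        self.model_state = ModelStateStub(self.model)  # second strong reference

    def shutdown(self) -> None:
        # Deleting only self.model mirrors the PR as written: the module
        # stays reachable through self.model_state.model.
        del self.model
        gc.collect()

runner = Runner()
runner.shutdown()
print(runner.model_state.model.weight.shape)  # still alive: torch.Size([1024, 1024])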
Why the existing cleanup does not save this
The caller (vllm/v1/worker/gpu_worker.py) keeps the model_runner instance alive after shutdown() returns. Therefore any attribute not explicitly del'd (or cleared) remains attached to the runner; gc.collect() has nothing to collect because the refcounts are still > 0. torch.accelerator.empty_cache() only releases cached blocks that have actually been freed by PyTorch's allocator — it cannot release tensors still referenced by Python objects.
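A rough illustration of that allocator behavior (assuming a CUDA device, with a plain tensor as a stand-in for the weights; not vLLM code): a tensor that is still referenced keeps its GPU memory through gc.collect() and empty_cache().

import gc

import torch

assert torch.cuda.is_available()

weights = torch.empty(256, 1024, 1024, device="cuda")  # ~1 GiB stand-in for weights
alias = weights                                         # second reference, like model_state.model

del weights
gc.collect()
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated() / 2**30)  # still ~1.0 GiB: the alias keeps the block live

del alias
gc.collect()
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated() / 2**30)  # ~0.0 GiB once the last reference is gone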
Step-by-step proof (target model weights, generative non-pooling case)
1. load_model() runs: self.model = model_loader.load_model(...); refcount of the model = 1.
2. self.model_state = init_model_state(self.vllm_config, self.model, ...). Inside DefaultModelState.__init__ (line 34): self.model = model. Now refcount of the model = 2.
3. self.speculator.load_model(self.model) (when speculative_config is set). The Eagle speculator stores its own self.target_model reference. Refcount = 3.
4. capture_model() runs, and the captured CUDA graphs internally retain pointers into the model parameter tensors (this is how CUDA graphs work: the kernels are baked with the parameter device pointers).
5. shutdown() is called. It does del self.model, so the refcount drops to 2 (still alive via self.model_state.model and self.speculator.target_model).
6. gc.collect() cannot collect the model: its refcount is still positive and there are no unreachable cycles to break.
7. torch.accelerator.empty_cache() finds no freed parameter blocks to release. Model weights remain on GPU.
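The refcount arithmetic above can be checked with sys.getrefcount; the sketch below uses a plain Holder class as a hypothetical stand-in for the runner, model state, and speculator (getrefcount reports one more than the true count because the call itself holds a temporary reference):

import sys

import torch

class Holder:
    pass

model = torch.nn.Linear(8, 8)

runner = Holder()
runner.model = model              # reference held by the runner
model_state = Holder()
model_state.model = model         # reference held by the model-state wrapper
speculator = Holder()
speculator.target_model = model   # reference held by the speculator

# Local name `model` plus the three attributes = 4 references; getrefcount prints 5.
print(sys.getrefcount(model))
del runner.model                  # mirrors `del self.model` in shutdown()
# Still 3 references (prints 4): the wrappers keep the module alive.
print(sys.getrefcount(model))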
The same proof applies to (a) draft model via self.speculator.model, (b) pooling model weights via self.pooling_runner.model, and (c) all GPU buffers held inside self.cudagraph_manager / self.intermediate_tensors / self.encoder_cache.
Impact
This defeats the docstring's contract: "Release GPU tensors ... so that memory is reclaimable when running in the same process." In same-process scenarios (e.g., a hosting framework that loads/unloads vLLM engines), the bulk of GPU memory — model weights and the captured-graph memory pool, often the two largest allocations — is not actually reclaimable.
Fix
Add explicit del (or set to None) for self.model_state, self.speculator, self.cudagraph_manager, self.pooling_runner, self.intermediate_tensors, and self.encoder_cache before the final gc.collect(). This matches the inline patch the gemini-code-assist bot has already proposed on this PR. Order matters: drop the wrappers first so that the subsequent del self.model actually drops the last reference.
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai> Signed-off-by: Joachim Studnia <joachim@mistral.ai>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai> Co-authored-by: hongbolv <33214277+hongbolv@users.noreply.github.com>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai> Signed-off-by: Ifta Khairul Alam Adil <ikaadil007@gmail.com>
Add a missing shutdown method to MRV2