19 changes: 19 additions & 0 deletions vllm/v1/worker/gpu/model_runner.py
@@ -91,6 +91,7 @@
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.prompt_logprob import PromptLogprobsWorker
from vllm.v1.worker.gpu.sample.sampler import Sampler
from vllm.v1.worker.gpu.shutdown import free_before_shutdown
from vllm.v1.worker.gpu.spec_decode import init_speculator
from vllm.v1.worker.gpu.spec_decode.eagle.eagle3_utils import (
set_eagle3_aux_hidden_state_layers,
@@ -1339,6 +1340,24 @@ def postprocess_pool(self, input_batch: InputBatch) -> None:
            input_batch.num_scheduled_tokens
        )

    def shutdown(self) -> None:
        """Release GPU tensors (model weights, KV caches, workspace) so that
        memory is reclaimable when running in the same process."""
        torch.accelerator.synchronize()
        if hasattr(self, "kv_caches"):
            self.kv_caches.clear()
        if hasattr(self, "attn_groups"):
            self.attn_groups.clear()
        if hasattr(self, "kv_cache_config"):
            del self.kv_cache_config
        free_before_shutdown(self.vllm_config)
        if hasattr(self, "model"):
            del self.model

        gc.collect()
        torch.accelerator.empty_cache()
        logger.debug("Cleaned up model weights, KV caches, and workspace")
Comment on lines +1343 to +1359
Severity: high

The shutdown method fails to release several components that hold significant GPU memory. Most importantly, self.model_state holds a reference to self.model, so del self.model will not actually free the model weights unless self.model_state is also deleted. Additionally, self.cudagraph_manager (which holds captured CUDA graphs), self.speculator, self.intermediate_tensors, self.encoder_cache, and self.pooling_runner should be explicitly deleted to ensure all memory is reclaimable when running in the same process.

    def shutdown(self) -> None:
        """Release GPU tensors (model weights, KV caches, workspace) so that
        memory is reclaimable when running in the same process."""
        torch.accelerator.synchronize()
        if hasattr(self, "kv_caches"):
            self.kv_caches.clear()
        if hasattr(self, "attn_groups"):
            self.attn_groups.clear()
        if hasattr(self, "kv_cache_config"):
            del self.kv_cache_config
        if hasattr(self, "model_state"):
            del self.model_state
        if hasattr(self, "cudagraph_manager"):
            del self.cudagraph_manager
        if hasattr(self, "speculator"):
            del self.speculator
        if hasattr(self, "intermediate_tensors"):
            del self.intermediate_tensors
        if hasattr(self, "encoder_cache"):
            del self.encoder_cache
        if hasattr(self, "pooling_runner"):
            del self.pooling_runner

        free_before_shutdown(self.vllm_config)
        if hasattr(self, "model"):
            del self.model

        gc.collect()
        torch.accelerator.empty_cache()
        logger.debug("Cleaned up model weights, KV caches, and workspace")

Comment on lines +1343 to +1359

🔴 The new shutdown() does not actually release the bulk of GPU memory it claims to free: self.model_state, self.speculator, self.pooling_runner, and self.cudagraph_manager all hold strong references to the model weights and/or large GPU buffers (captured CUDA graphs are typically several GiB), so del self.model followed by gc.collect()/empty_cache() cannot reclaim them. Also missing are self.intermediate_tensors (sized at max_num_batched_tokens on non-first PP ranks) and self.encoder_cache (multimodal embeddings dict). Fix by explicitly del'ing these attributes before the final gc.collect() (matches the inline review already left by gemini-code-assist).

Extended reasoning...

What is leaked

The shutdown method only does del self.model plus clearing kv_caches/attn_groups/kv_cache_config. It does not touch the following attributes on the runner instance, each of which independently pins GPU memory:

  1. self.model_state: init_model_state(...) returns a DefaultModelState (or WhisperModelState); vllm/v1/worker/gpu/model_states/default.py:34 does self.model = model. So after del self.model on the runner, the model module is still strongly referenced via self.model_state.model; its refcount stays > 0, gc cannot collect it, and empty_cache() cannot reclaim the weight memory.
  2. self.speculator: vllm/v1/worker/gpu/spec_decode/eagle/speculator.py:153 does self.model = load_eagle_model(target_model, self.vllm_config), so the draft model is owned here, and the speculator also holds self.target_model after load_model(self.model). Without releasing self.speculator, both target and draft weights stay pinned.
  3. self.pooling_runner: vllm/v1/worker/gpu/pool/pooling_runner.py:20 stores self.model = cast(VllmModelForPooling, model). For pooling models this is the same anti-pattern: del self.model cannot free the weights while the runner is alive.
  4. self.cudagraph_manager: ModelCudaGraphManager holds self.graphs: dict[..., torch.cuda.CUDAGraph] plus the persistent self.hidden_states, self.aux_hidden_states, and self.intermediate_tensors GPU buffers. capture_model() itself logs the size as "took %.2f GiB", typically several GiB, and captured CUDA graphs additionally hold internal references to the model parameter tensors baked in at capture time.
  5. self.intermediate_tensors: for non-first PP ranks, load_model() (model_runner.py:319-323) creates a persistent IntermediateTensors sized at max_num_batched_tokens via make_empty_intermediate_tensors(...). It is not in static_forward_context or kv_caches, so nothing else releases it.
  6. self.encoder_cache: vllm/v1/worker/gpu/mm/encoder_cache.py:13 keeps encoder_outputs: dict[str, torch.Tensor] of multimodal vision embeddings on GPU.

Why the existing cleanup does not recover this memory

The caller (vllm/v1/worker/gpu_worker.py) keeps the model_runner instance alive after shutdown() returns. Therefore any attribute not explicitly del'd (or cleared) remains attached to the runner; gc.collect() has nothing to collect because the refcounts are still > 0. torch.accelerator.empty_cache() only releases cached blocks that have actually been freed by PyTorch's allocator — it cannot release tensors still referenced by Python objects.
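
A minimal standalone sketch of this behavior (illustrative only, not code from this PR): del on one alias cannot release a CUDA tensor while a second Python reference survives, and empty_cache() only returns blocks the allocator has actually seen freed.

```python
# Minimal sketch (not vLLM code): a second strong reference defeats `del`.
import gc

import torch

class Holder:
    # Stands in for a wrapper like DefaultModelState, which keeps
    # `self.model = model` alive after the runner drops its own alias.
    def __init__(self, weights: torch.Tensor) -> None:
        self.weights = weights

weights = torch.empty(1024, 1024, device="cuda")  # ~4 MiB stand-in "model"
holder = Holder(weights)                          # refcount is now 2

del weights                                       # refcount drops to 1
gc.collect()
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated())  # still ~4 MiB: holder pins the tensor

del holder                                        # last reference gone
gc.collect()
torch.cuda.empty_cache()  # cached (reserved) block can now return to driver
print(torch.cuda.memory_allocated())  # ~0
```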

Step-by-step proof (target model weights, generative non-pooling case)

  1. load_model() runs: self.model = model_loader.load_model(...) — refcount of model = 1.
  2. self.model_state = init_model_state(self.vllm_config, self.model, ...). Inside DefaultModelState.__init__ (line 34): self.model = model. Now refcount of model = 2.
  3. self.speculator.load_model(self.model) (when speculative_config is set). The Eagle speculator stores its own self.target_model reference. Refcount = 3.
  4. capture_model() runs and the captured CUDA graphs internally retain pointers into model parameter tensors (this is how CUDA graphs work — the kernels are baked with the parameter device pointers).
  5. shutdown() is called. It does del self.model → refcount drops to 2 (still alive via self.model_state.model and self.speculator.target_model).
  6. gc.collect() cannot collect the model — its refcount is still positive and there are no unreachable cycles to break.
  7. torch.accelerator.empty_cache() finds no freed parameter blocks to release. Model weights remain on GPU.

The same proof applies to (a) draft model via self.speculator.model, (b) pooling model weights via self.pooling_runner.model, and (c) all GPU buffers held inside self.cudagraph_manager / self.intermediate_tensors / self.encoder_cache.
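
One way to check the chain empirically is a hypothetical debugging helper (not part of this PR; the attribute names come from the discussion above) that walks the runner's attributes and reports which ones still alias the model:

```python
# Hypothetical helper: report which runner attributes still reference the
# model object, directly or one attribute-hop deep. Before the fix this
# should name model_state, speculator, pooling_runner, etc.; after the fix
# it should return an empty list.
def find_model_refs(runner) -> list[str]:
    model = runner.model
    holders: list[str] = []
    for name, value in vars(runner).items():
        if name == "model":
            continue
        if value is model:
            holders.append(name)
            continue
        # One hop deep catches model_state.model, speculator.target_model,
        # and pooling_runner.model.
        children = getattr(value, "__dict__", {})
        if any(child is model for child in children.values()):
            holders.append(name)
    return holders
```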

Impact

This defeats the docstring's contract: "Release GPU tensors ... so that memory is reclaimable when running in the same process." In same-process scenarios (e.g., a hosting framework that loads/unloads vLLM engines), the bulk of GPU memory — model weights and the captured-graph memory pool, often the two largest allocations — is not actually reclaimable.
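
A regression-style sketch of that scenario (build_model_runner and run_some_requests are assumed placeholders, not real vLLM APIs): repeated create/shutdown cycles should return allocated memory to roughly the baseline once shutdown() drops every alias.

```python
# Illustrative leak check; build_model_runner/run_some_requests are assumed
# placeholders. Memory should return near baseline after each shutdown()
# instead of growing by a full model's weights per cycle.
import torch

baseline = torch.cuda.memory_allocated()
for _ in range(3):
    runner = build_model_runner(vllm_config)  # placeholder factory
    run_some_requests(runner)                 # placeholder workload
    runner.shutdown()
    del runner
    leaked = torch.cuda.memory_allocated() - baseline
    assert leaked < 64 * 1024**2, f"leaked {leaked / 2**20:.1f} MiB"
```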

Fix

Add explicit del (or set to None) for self.model_state, self.speculator, self.cudagraph_manager, self.pooling_runner, self.intermediate_tensors, and self.encoder_cache before the final gc.collect(). This matches the inline patch the gemini-code-assist bot has already proposed on this PR. Order matters: drop the wrappers first so that the subsequent del self.model actually drops the last reference.


    ########### EPLB methods start ###########
    @property
    def eplb_state(self):
20 changes: 20 additions & 0 deletions vllm/v1/worker/gpu/shutdown.py
@@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config import VllmConfig
from vllm.logger import init_logger

logger = init_logger(__name__)


def free_before_shutdown(vllm_config: VllmConfig) -> None:
    # Local imports: these modules own the process-global caches cleared below.
    from vllm.model_executor.layers.rotary_embedding import _ROPE_DICT
    from vllm.v1.worker.workspace import reset_workspace_manager

    # Reset the block count so a re-created engine determines it afresh.
    cache_config = vllm_config.cache_config
    cache_config.num_gpu_blocks = None

    # Drop the per-layer references held in the forward context.
    compilation_config = vllm_config.compilation_config
    compilation_config.static_forward_context.clear()

    # Clear the module-level rotary-embedding cache and the shared workspace
    # buffers; both outlive any single runner instance.
    _ROPE_DICT.clear()
    reset_workspace_manager()
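
Why these module-level resets matter: caches like _ROPE_DICT and the shared workspace outlive any single engine instance, so without an explicit clear they keep pinning GPU tensors across create/destroy cycles in the same process. A generic illustration (not vLLM's actual cache):

```python
# Generic illustration (not vLLM code) of a module-level cache pinning GPU
# memory across engine lifetimes in a single process.
import torch

_CACHE: dict[str, torch.Tensor] = {}  # stands in for _ROPE_DICT

def get_table(key: str) -> torch.Tensor:
    # Populated on first use; entries survive after the engine is gone.
    if key not in _CACHE:
        _CACHE[key] = torch.empty(4096, 128, device="cuda")
    return _CACHE[key]

def reset_cache() -> None:
    # The shutdown-time clear: without it, a re-created engine leaks one
    # table per distinct key per lifecycle.
    _CACHE.clear()
```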