fix CUDAGraph memory being counted twice#37426

Merged
MatthewBonanni merged 4 commits into vllm-project:main from panpan0000:fix_mem_cacu_bug
Mar 20, 2026

Conversation

@panpan0000
Contributor

@panpan0000 panpan0000 commented Mar 18, 2026

Purpose

Fix twice-counted CUDAGraph memory

When VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1, CUDAGraph memory appears to be counted twice in the KV cache recommendation path.

In determine_available_memory, self.peak_activation_memory already includes the estimated CUDAGraph memory:

cudagraph_memory_estimate_applied = (
    cudagraph_memory_estimate
    if envs.VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS
    else 0
)
self.peak_activation_memory = (
    profile_result.torch_peak_increase + cudagraph_memory_estimate_applied  # <-- here
)

Later, in compile_or_warm_up_model, non_kv_cache_memory adds self.peak_activation_memory and also adds cuda_graph_memory_bytes (the actual captured CUDAGraph memory), which double-counts the CUDAGraph memory.

This does not cause a real memory leak or runtime OOM, since the value is only used for logging and the --kv-cache-memory recommendation.
However, it overestimates non_kv_cache_memory, which leads to a smaller-than-necessary suggested KV cache upper bound.
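To make the overcount concrete, here is a minimal numeric sketch. The variable names mirror the vLLM snippets above, but all byte values are made up for illustration:

```python
# Hypothetical byte counts (GiB scale) illustrating the double counting;
# the names mirror the vLLM code above, the values are invented.
GiB = 1024**3
torch_peak_increase = 2 * GiB        # pure forward activation memory
cudagraph_memory_estimate = 1 * GiB  # estimate from the trial capture
cuda_graph_memory_bytes = 1 * GiB    # actual captured CUDAGraph memory
weights_memory = 10 * GiB
non_torch_increase = 1 * GiB

# Before the fix: the estimate is folded into peak_activation_memory...
peak_activation_memory = torch_peak_increase + cudagraph_memory_estimate
# ...and the actual capture is added again later in compile_or_warm_up_model.
non_kv_cache_buggy = (weights_memory + peak_activation_memory
                      + non_torch_increase + cuda_graph_memory_bytes)

# After the fix: CUDAGraph memory is counted exactly once.
non_kv_cache_fixed = (weights_memory + torch_peak_increase
                      + non_torch_increase + cuda_graph_memory_bytes)

# The overestimate equals the CUDAGraph estimate counted a second time.
assert non_kv_cache_buggy - non_kv_cache_fixed == cudagraph_memory_estimate
```

With these numbers the recommendation path would report 15 GiB instead of 14 GiB of non-KV-cache memory, shrinking the suggested KV cache bound by exactly the estimate.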

I found this problem by accident while fixing #37420.

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Peter Pan <Peter.Pan@daocloud.io>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a double-counting issue with CUDA graph memory estimation. When VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS is enabled, peak_activation_memory includes an estimated value for CUDA graph memory. The change correctly subtracts this estimate before adding the actual captured CUDA graph memory (cuda_graph_memory_bytes), preventing the overestimation of non_kv_cache_memory. The logic is sound and the implementation correctly fixes the issue. I have no high or critical severity issues to report.

@chaunceyjiang
Collaborator

/cc @MatthewBonanni PTAL.

Collaborator

@MatthewBonanni MatthewBonanni left a comment


Thanks for catching this! I think the correct fix would actually be to just not add cudagraph_memory_estimate_applied to self.peak_activation_memory in the first place
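A simplified sketch of what this suggestion amounts to, with hypothetical names and a flat dict standing in for the profiling result (not vLLM's actual code):

```python
# Sketch: keep the CUDAGraph estimate out of peak_activation_memory and
# apply it only once, when computing the available KV-cache budget.
# `profile` keys and the function signature are hypothetical.
def available_kv_cache(requested_bytes, profile, cudagraph_estimate,
                       use_estimate):
    # peak_activation_memory no longer folds the estimate in
    peak_activation_memory = profile["torch_peak_increase"]
    applied = cudagraph_estimate if use_estimate else 0
    non_kv = (peak_activation_memory
              + profile["non_torch_increase"]
              + profile["weights_memory"])
    return requested_bytes - non_kv - applied
```

Keeping the estimate as a separate term means later code that adds the actual cuda_graph_memory_bytes never sees the estimate baked into the activation number.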

Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Peter Pan <peter.pan@daocloud.io>
@mergify

mergify bot commented Mar 20, 2026

Hi @panpan0000, the pre-commit checks have failed. Please run:

uv pip install "pre-commit>=4.5.1"
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy failing?
mypy is run differently in CI. If the failure is related to this check, please use the following command to run it locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10

Signed-off-by: Peter Pan <Peter.Pan@daocloud.io>
@panpan0000
Contributor Author

panpan0000 commented Mar 20, 2026

Thank you @MatthewBonanni, all fixed. Please review again.

BTW, I used AI to create the diagram below so that others can get a better view of the whole calculation.

════════════════════════════════════════════════════════════════════
                   determine_available_memory()
════════════════════════════════════════════════════════════════════

  profile_run()
       │
       ├─→ torch_peak_increase          measured by allocated_bytes.all.peak
       │   (pure forward activation)    snapped BEFORE cudagraph profiling
       │
       ├─→ non_torch_increase           measured by mem_get_info() delta
       │   (NCCL / cuDNN / driver)      minus torch reserved portion
       │
       └─→ weights_memory               passed in (model params & buffers)

  profile_cudagraph_memory()            trial capture then discard
       │
       └─→ cudagraph_memory_estimate    measured by mem_get_info() delta
           × VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS
           = cudagraph_memory_estimate_applied

  ┌─────────────────────────────────────────────────────────────┐
  │  available_kv_cache_memory_bytes                            │
  │    = requested_memory                                       │
  │    - (torch_peak_increase + non_torch_increase              │
  │       + weights_memory)           ← non_kv_cache_memory     │
  │    - cudagraph_memory_estimate_applied                      │
  └─────────────────────────────────────────────────────────────┘


════════════════════════════════════════════════════════════════════
                   compile_or_warm_up_model()
════════════════════════════════════════════════════════════════════

  capture_model()
       │
       └─→ cuda_graph_memory_bytes      actual, via mem_get_info() delta

  ┌─────────────────────────────────────────────────────────────┐
  │  non_kv_cache_memory  (for --kv-cache-memory suggestion)    │
  │    = weights_memory                                         │
  │    + torch_peak_increase           (peak_activation_memory) │
  │    + non_torch_increase            (non_torch_memory)       │
  │    + cuda_graph_memory_bytes                                │
  └─────────────────────────────────────────────────────────────┘

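The two boxed formulas in the diagram can be transcribed as a small sketch (names mirror the diagram; this is illustrative, not vLLM's actual code):

```python
# determine_available_memory(): KV-cache budget at profiling time.
def available_kv_cache_memory_bytes(requested_memory, torch_peak_increase,
                                    non_torch_increase, weights_memory,
                                    cudagraph_memory_estimate_applied):
    non_kv_cache_memory = (torch_peak_increase + non_torch_increase
                           + weights_memory)
    return (requested_memory - non_kv_cache_memory
            - cudagraph_memory_estimate_applied)

# compile_or_warm_up_model(): total used for the --kv-cache-memory
# suggestion, now using the actual captured CUDAGraph bytes only.
def non_kv_cache_memory(weights_memory, torch_peak_increase,
                        non_torch_increase, cuda_graph_memory_bytes):
    return (weights_memory + torch_peak_increase
            + non_torch_increase + cuda_graph_memory_bytes)
```

Note that the estimate appears only in the first formula and the actual capture only in the second, which is exactly what removes the double count.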

Collaborator

@MatthewBonanni MatthewBonanni left a comment


LGTM

@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Mar 20, 2026
@MatthewBonanni MatthewBonanni added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 20, 2026
@MatthewBonanni MatthewBonanni enabled auto-merge (squash) March 20, 2026 15:22
@MatthewBonanni MatthewBonanni merged commit 79eb936 into vllm-project:main Mar 20, 2026
59 of 60 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Mar 20, 2026
chooper26 pushed a commit to intellistream/vllm-hust that referenced this pull request Mar 21, 2026
Signed-off-by: Peter Pan <Peter.Pan@daocloud.io>
Signed-off-by: Peter Pan <peter.pan@daocloud.io>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>

Labels

nvidia ready ONLY add when PR is ready to merge/full CI is needed v1

Projects

Status: Done


3 participants