Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
82494de
Add cudagraph memory profiling
MatthewBonanni Dec 11, 2025
a7ffd72
Compare only FULL size
MatthewBonanni Dec 11, 2025
845c345
Use freeze_gc context manager
MatthewBonanni Dec 11, 2025
b41ab58
Factor out freeze_gc
MatthewBonanni Dec 11, 2025
bacd7c0
Include piecewise
MatthewBonanni Dec 11, 2025
ff34c42
Warm up separately
MatthewBonanni Dec 11, 2025
94095ef
Don't count first capture
MatthewBonanni Dec 11, 2025
b2a1165
Handle FULL cudagraph first pass allocations, e.g. FA split buffer
MatthewBonanni Dec 12, 2025
ac4cb05
Merge branch 'main' into cg_memory_profiling
MatthewBonanni Dec 15, 2025
401006e
Remove dangerous line
MatthewBonanni Dec 16, 2025
2eea04d
Empty cache when necessary and edit log message
MatthewBonanni Dec 16, 2025
74caa0b
Add empty_cache to cut down on memory fragmentation
MatthewBonanni Dec 16, 2025
5a4dd22
Add warmups
MatthewBonanni Dec 17, 2025
93a387f
Make sure self.cache_config.num_gpu_blocks is set
MatthewBonanni Dec 17, 2025
013611e
Merge branch 'main' into cg_memory_profiling
MatthewBonanni Dec 18, 2025
06cf23b
Bugfix
MatthewBonanni Dec 18, 2025
b9d0319
Don't increment counter for profiling captures
MatthewBonanni Dec 18, 2025
c6f00b3
Allocate enough blocks for mamba
MatthewBonanni Dec 18, 2025
868dd88
Bugfix
MatthewBonanni Dec 19, 2025
4bf0bf9
Clean up
MatthewBonanni Dec 19, 2025
7d2f264
PIECEWISE graphs are reused
MatthewBonanni Dec 19, 2025
e5ea128
Decrease max_model_len for test
MatthewBonanni Dec 19, 2025
febc94c
Clear references to kv cache in attention layers to allow garbage collection
MatthewBonanni Dec 19, 2025
41c4808
Merge branch 'main' into cg_memory_profiling
MatthewBonanni Jan 5, 2026
82d2c26
Merge branch 'main' into cg_memory_profiling
MatthewBonanni Jan 12, 2026
71e97f2
Add missing import
MatthewBonanni Jan 12, 2026
b929219
Add env variable, default false
MatthewBonanni Jan 12, 2026
a744491
Use set_current_vllm_config
MatthewBonanni Jan 12, 2026
4f05a7b
Clean up, fix edge case
MatthewBonanni Jan 12, 2026
5a0b6f0
Clean up
MatthewBonanni Jan 12, 2026
d2d5ca1
Handle FULL mode properly
MatthewBonanni Jan 12, 2026
925bb59
Merge branch 'main' into cg_memory_profiling
MatthewBonanni Jan 26, 2026
c946aae
Remove unused full_graph_memory_bytes
MatthewBonanni Jan 26, 2026
041c1fc
Cross-platform
MatthewBonanni Jan 26, 2026
8ec55af
Cleanup
MatthewBonanni Feb 18, 2026
50fc6d6
Simplify
MatthewBonanni Feb 18, 2026
9c9d5cf
Assume at least two capture sizes
MatthewBonanni Feb 18, 2026
acc816b
Cleanup
MatthewBonanni Feb 18, 2026
e33df31
Cleanup
MatthewBonanni Feb 18, 2026
10f319b
Merge branch 'main' into cg_memory_profiling
MatthewBonanni Feb 18, 2026
e9932c1
Reorder arguments
MatthewBonanni Feb 18, 2026
2b5fe67
Merge branch 'main' into cg_memory_profiling
MatthewBonanni Feb 23, 2026
4ce160f
Merge branch 'main' into cg_memory_profiling
LucasWilkinson Feb 24, 2026
afcedd3
Clean up using get_capture_descs
MatthewBonanni Feb 24, 2026
829ca46
Clean up LoRAs
MatthewBonanni Feb 24, 2026
70e7380
Merge branch 'main' into cg_memory_profiling
MatthewBonanni Feb 26, 2026
e89fdee
Refactor profile_cudagraph_memory to use capture_descs loop
LucasWilkinson Feb 26, 2026
0ddead6
Merge pull request #1 from neuralmagic/cg_memory_profiling_cleanup
MatthewBonanni Mar 3, 2026
0a85aaa
Fix
MatthewBonanni Mar 3, 2026
de21a1f
Merge branch 'main' into cg_memory_profiling
MatthewBonanni Mar 3, 2026
493d656
Merge branch 'main' into cg_memory_profiling
MatthewBonanni Mar 4, 2026
3fcb57e
Use temporary pool for profiling
MatthewBonanni Mar 4, 2026
d9d3cf8
Merge branch 'main' into cg_memory_profiling
MatthewBonanni Mar 5, 2026
aeb4c4a
Add transition info logging
MatthewBonanni Mar 5, 2026
9c19a67
Fix
MatthewBonanni Mar 5, 2026
7c67862
Move log info
MatthewBonanni Mar 5, 2026
0197626
Merge branch 'main' into cg_memory_profiling
MatthewBonanni Mar 5, 2026
87a67c6
Fix pre-commit
MatthewBonanni Mar 5, 2026
ec96d3f
Always warm up
MatthewBonanni Mar 5, 2026
ec346c9
Remove unnecessary line
MatthewBonanni Mar 5, 2026
196bfd2
Fix bad merge
MatthewBonanni Mar 5, 2026
14c3227
Don't double count
MatthewBonanni Mar 5, 2026
6d30ce2
Clear piecewise graphs too
MatthewBonanni Mar 5, 2026
423d498
Fix count
MatthewBonanni Mar 6, 2026
e9f9451
Merge branch 'main' into cg_memory_profiling
MatthewBonanni Mar 6, 2026
3cb44f3
Use temporary pools for all graphs
MatthewBonanni Mar 6, 2026
d5555d6
Merge branch 'main' into cg_memory_profiling
MatthewBonanni Mar 6, 2026
0038b17
Merge branch 'main' into cg_memory_profiling
MatthewBonanni Mar 6, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion vllm/compilation/cuda_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
import weakref
from collections import Counter
from collections.abc import Callable
from contextlib import ExitStack
from typing import Any
from typing import Any, ClassVar
from unittest.mock import patch

import torch
Expand Down Expand Up @@ -162,6 +163,14 @@ class CUDAGraphWrapper:
guaranteed when VLLM_LOGGING_LEVEL == "DEBUG".
"""

_all_instances: ClassVar[weakref.WeakSet["CUDAGraphWrapper"]] = weakref.WeakSet()

@classmethod
def clear_all_graphs(cls) -> None:
    """Drop the captured graphs held by every live CUDAGraphWrapper."""
    # Snapshot the WeakSet before iterating: members can be garbage
    # collected (and silently removed from the set) at any time.
    for wrapper in tuple(cls._all_instances):
        wrapper.clear_graphs()

def __init__(
self,
runnable: Callable[..., Any],
Expand Down Expand Up @@ -192,6 +201,8 @@ def __init__(
# cudagraphs for.
self.concrete_cudagraph_entries: dict[BatchDescriptor, CUDAGraphEntry] = {}

CUDAGraphWrapper._all_instances.add(self)

def __getattr__(self, key: str) -> Any:
# allow accessing the attributes of the runnable.
if hasattr(self.runnable, key):
Expand All @@ -205,6 +216,13 @@ def unwrap(self) -> Callable[..., Any]:
# in case we need to access the original runnable.
return self.runnable

@property
def cudagraph_wrapper(self) -> "CUDAGraphWrapper":
    # Identity accessor — presumably so callers can ask any runnable for
    # its wrapper without checking whether it is already a
    # CUDAGraphWrapper; verify against call sites.
    return self

def clear_graphs(self) -> None:
    """Forget all captured per-batch-descriptor cudagraph entries."""
    # Clears the dict in place (rather than rebinding) so any external
    # references to `concrete_cudagraph_entries` remain valid.
    self.concrete_cudagraph_entries.clear()

def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
forward_context = get_forward_context()
batch_descriptor = forward_context.batch_descriptor
Expand Down
7 changes: 7 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@
VLLM_CUDA_COMPATIBILITY_PATH: str | None = None
VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: bool = False
VLLM_ELASTIC_EP_DRAIN_REQUESTS: bool = False
VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS: bool = False


def get_default_cache_root():
Expand Down Expand Up @@ -1628,6 +1629,12 @@ def _get_or_set_default() -> str:
"VLLM_ELASTIC_EP_DRAIN_REQUESTS": lambda: bool(
int(os.getenv("VLLM_ELASTIC_EP_DRAIN_REQUESTS", "0"))
),
# If set to 1, enable CUDA graph memory estimation during memory profiling.
# This profiles CUDA graph memory usage to provide more accurate KV cache
# memory allocation. Disabled by default to preserve existing behavior.
"VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS": lambda: bool(
int(os.getenv("VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS", "0"))
),
}


Expand Down
7 changes: 5 additions & 2 deletions vllm/v1/cudagraph_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,11 @@ def get_capture_descs(self) -> list[tuple[CUDAGraphMode, list[BatchDescriptor]]]
for mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]:
descs = list(self.cudagraph_keys[mode])
if descs:
# Sort by num_tokens descending (largest first)
descs.sort(key=lambda d: d.num_tokens, reverse=True)
# Sort by (num_tokens, num_active_loras) descending
descs.sort(
key=lambda d: (d.num_tokens, d.num_active_loras),
reverse=True,
)
result.append((mode, descs))

return result
Loading