Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 88 additions & 9 deletions vllm/compilation/cuda_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from vllm.model_executor.offloader.base import get_offloader
from vllm.platforms import current_platform
from vllm.utils.torch_utils import current_stream, weak_ref_tensors
from vllm.v1.attention.backend import AttentionMetadata

logger = init_logger(__name__)

Expand Down Expand Up @@ -124,6 +125,33 @@ def log(self, log_fn: Callable[..., Any] = logger.info) -> None:
self.reset()


def _extract_metadata_addresses(
attn_metadata: dict[str, AttentionMetadata]
| list[dict[str, AttentionMetadata]]
| None,
) -> dict[str, int]:
addresses: dict[str, int] = {}
if attn_metadata is None:
return addresses

if isinstance(attn_metadata, list):
# This is for DBO. List of size two, one for each microbatch.
# We will check the metadata addresses for the first one,
# assuming the second one has the same addresses.
attn_metadata = attn_metadata[0]

for k, v in attn_metadata.items():
assert dataclasses.is_dataclass(v)
for field in dataclasses.fields(v):
if field.name.startswith("__"):
continue
tensor_candidate = getattr(v, field.name, None)
if isinstance(tensor_candidate, torch.Tensor) and tensor_candidate.is_cuda:
addresses[f"{k}.{field.name}"] = tensor_candidate.data_ptr()

return addresses


@dataclasses.dataclass
class CUDAGraphEntry:
batch_descriptor: BatchDescriptor
Expand All @@ -133,6 +161,7 @@ class CUDAGraphEntry:
# for cudagraph debugging, track the input addresses
# during capture, and check if they are the same during replay
input_addresses: list[int] | None = None
metadata_addresses: dict[str, int] | None = None


@dataclasses.dataclass
Expand Down Expand Up @@ -273,10 +302,18 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
# validate that cudagraph capturing is legal at this point.
validate_cudagraph_capturing_enabled()

input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
entry.input_addresses = input_addresses
if self.is_debugging_mode:
entry.input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
entry.input_addresses.extend(
v.data_ptr() for v in kwargs.values() if isinstance(v, torch.Tensor)
)
if self.runtime_mode == CUDAGraphMode.FULL:
entry.metadata_addresses = _extract_metadata_addresses(
forward_context.attn_metadata
)

cudagraph = torch.cuda.CUDAGraph()

with ExitStack() as stack:
Expand Down Expand Up @@ -336,15 +373,57 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
return output

if self.is_debugging_mode:
# check if the input addresses are the same
new_input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
]
assert new_input_addresses == entry.input_addresses, (
f"Input addresses for cudagraphs are different "
f"during replay. Expected {entry.input_addresses}, "
f"got {new_input_addresses}"
new_input_addresses.extend(
v.data_ptr() for v in kwargs.values() if isinstance(v, torch.Tensor)
)
args_match = new_input_addresses == entry.input_addresses

# Check metadata dict (only in FULL mode)
metadata_match = True
new_metadata_addresses = None
if self.runtime_mode == CUDAGraphMode.FULL:
new_metadata_addresses = _extract_metadata_addresses(
forward_context.attn_metadata
)
metadata_match = new_metadata_addresses == entry.metadata_addresses

if not (args_match and metadata_match):
error_msg = [
"Input addresses for cudagraphs are different during replay."
]

if not args_match:
error_msg.append(
f"Expected {entry.input_addresses}, got {new_input_addresses}"
)

if not metadata_match:
old_meta = entry.metadata_addresses or {}
new_meta = new_metadata_addresses or {}

old_keys = set(old_meta.keys())
new_keys = set(new_meta.keys())
missing = old_keys - new_keys
added = new_keys - old_keys
changed = {
k: (old_meta[k], new_meta[k])
for k in old_keys & new_keys
if old_meta[k] != new_meta[k]
}
error_msg.append("Differences in attn_metadata detected:")
if missing:
error_msg.append(f" Missing tensors: {missing}")
if added:
error_msg.append(f" New tensors: {added}")
if changed:
error_msg.append(" Changed addresses:")
for k, (old_addr, new_addr) in changed.items():
error_msg.append(f" {k}: {old_addr} -> {new_addr}")

raise AssertionError("\n".join(error_msg))

# Sync offloader before replay - ensures any external dependencies
# from pre-capture prefetches are satisfied.
Expand Down
Loading