Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions tests/v1/cudagraph/test_cudagraph_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,68 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
else:
assert rt_mode == CUDAGraphMode.NONE

@pytest.mark.parametrize(
    "cudagraph_mode_str,compilation_mode,expected_modes",
    [
        # FULL mode: only FULL keys, no PIECEWISE
        ("FULL", CompilationMode.NONE, [CUDAGraphMode.FULL]),
        # PIECEWISE mode: only PIECEWISE keys
        ("PIECEWISE", CompilationMode.VLLM_COMPILE, [CUDAGraphMode.PIECEWISE]),
        # FULL_DECODE_ONLY: only FULL keys for uniform decode
        ("FULL_DECODE_ONLY", CompilationMode.NONE, [CUDAGraphMode.FULL]),
        # NONE mode: no keys
        ("NONE", CompilationMode.NONE, []),
    ],
)
def test_get_capture_descs(
    self, cudagraph_mode_str, compilation_mode, expected_modes
):
    """Test get_capture_descs returns correctly grouped and ordered descs."""
    compilation_config = CompilationConfig(
        cudagraph_mode=cudagraph_mode_str,
        mode=compilation_mode,
        cudagraph_capture_sizes=[1, 4, 8, 16],
    )
    vllm_config = _create_vllm_config(compilation_config, max_num_seqs=16)

    dispatcher = CudagraphDispatcher(vllm_config)
    dispatcher.initialize_cudagraph_keys(
        cudagraph_mode=compilation_config.cudagraph_mode,
        uniform_decode_query_len=1,
    )

    groups = dispatcher.get_capture_descs()

    # The sequence of runtime modes must match the expectation exactly
    # (both membership and PIECEWISE-before-FULL ordering).
    assert [runtime_mode for runtime_mode, _ in groups] == expected_modes

    for runtime_mode, descs in groups:
        assert len(descs) > 0, "Each group should have at least one descriptor"

        # Within a group, the largest batch sizes must come first.
        sizes = [desc.num_tokens for desc in descs]
        assert sizes == sorted(sizes, reverse=True), (
            f"Descriptors for {runtime_mode} should be sorted largest-first"
        )

        # The uniform-decode flag must be homogeneous inside a group.
        assert len({desc.uniform for desc in descs}) == 1, (
            "All descriptors in a group should have the same uniform value"
        )

def test_get_capture_descs_empty_when_not_initialized(self):
    """Test that get_capture_descs returns empty list when keys not initialized."""
    compilation_config = CompilationConfig(
        cudagraph_mode="FULL",
        mode=CompilationMode.NONE,
        cudagraph_capture_sizes=[1, 8],
    )
    dispatcher = CudagraphDispatcher(
        _create_vllm_config(compilation_config, max_num_seqs=8)
    )

    # initialize_cudagraph_keys() is deliberately never called, so the
    # dispatcher must report that there is nothing to capture.
    assert dispatcher.get_capture_descs() == []


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCUDAGraphWrapper:
Expand Down
23 changes: 23 additions & 0 deletions vllm/v1/cudagraph_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,26 @@ def dispatch(

# finally, just return no cudagraphs and a trivial batch descriptor
return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)

def get_capture_descs(self) -> list[tuple[CUDAGraphMode, list[BatchDescriptor]]]:
    """
    Collect the batch descriptors to use for cudagraph capturing.

    Returns:
        A list of (runtime_mode, batch_descriptors) pairs. PIECEWISE
        entries precede FULL ones, and within each group the descriptors
        are ordered by num_tokens from largest to smallest so the biggest
        graphs are captured first (better for memory efficiency).
    """
    # Nothing to capture before keys exist or when cudagraphs are disabled.
    if self.cudagraph_mode == CUDAGraphMode.NONE or not self.keys_initialized:
        return []

    grouped: list[tuple[CUDAGraphMode, list[BatchDescriptor]]] = []
    # Emit PIECEWISE before FULL, matching the required capture order.
    for runtime_mode in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL):
        descriptors = sorted(
            self.cudagraph_keys[runtime_mode],
            key=lambda desc: desc.num_tokens,
            reverse=True,
        )
        if descriptors:
            grouped.append((runtime_mode, descriptors))

    return grouped
91 changes: 32 additions & 59 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from contextlib import contextmanager
from copy import copy, deepcopy
from functools import reduce
from itertools import product
from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast

import numpy as np
Expand Down Expand Up @@ -4809,50 +4808,14 @@ def freeze_gc():
set_cudagraph_capturing_enabled(True)
with freeze_gc(), graph_capture(device=self.device):
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
cudagraph_mode = self.compilation_config.cudagraph_mode
assert cudagraph_mode is not None

if self.lora_config:
if self.compilation_config.cudagraph_specialize_lora:
lora_cases = [True, False]
else:
lora_cases = [True]
else:
lora_cases = [False]

if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE:
cudagraph_runtime_mode = cudagraph_mode.mixed_mode()
# make sure we capture the largest batch size first
compilation_cases = list(
product(reversed(self.cudagraph_batch_sizes), lora_cases)
)
self._capture_cudagraphs(
compilation_cases,
cudagraph_runtime_mode=cudagraph_runtime_mode,
uniform_decode=False,
)

# Capture full cudagraph for uniform decode batches if we
# don't already have full mixed prefill-decode cudagraphs.
if (
cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
and cudagraph_mode.separate_routine()
):
max_num_tokens = (
self.scheduler_config.max_num_seqs * self.uniform_decode_query_len
)
decode_cudagraph_batch_sizes = [
x
for x in self.cudagraph_batch_sizes
if max_num_tokens >= x >= self.uniform_decode_query_len
]
compilation_cases_decode = list(
product(reversed(decode_cudagraph_batch_sizes), lora_cases)
)
for (
runtime_mode,
batch_descs,
) in self.cudagraph_dispatcher.get_capture_descs():
self._capture_cudagraphs(
compilation_cases=compilation_cases_decode,
cudagraph_runtime_mode=CUDAGraphMode.FULL,
uniform_decode=True,
batch_descriptors=batch_descs,
cudagraph_runtime_mode=runtime_mode,
)

torch.cuda.synchronize()
Expand Down Expand Up @@ -4883,19 +4846,32 @@ def freeze_gc():

def _capture_cudagraphs(
self,
compilation_cases: list[tuple[int, bool]],
batch_descriptors: list[BatchDescriptor],
cudagraph_runtime_mode: CUDAGraphMode,
uniform_decode: bool,
):
assert (
cudagraph_runtime_mode != CUDAGraphMode.NONE
and cudagraph_runtime_mode.valid_runtime_modes()
), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}"

if not batch_descriptors:
return

uniform_decode = batch_descriptors[0].uniform
force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL

dummy_run = functools.partial(
self._dummy_run,
uniform_decode=uniform_decode,
skip_eplb=True,
remove_lora=False,
force_attention=force_attention,
)

# Only rank 0 should print progress bar during capture
if is_global_first_rank():
compilation_cases = tqdm(
compilation_cases,
batch_descriptors = tqdm(
batch_descriptors,
disable=not self.load_config.use_tqdm_on_load,
desc="Capturing CUDA graphs ({}, {})".format(
"decode" if uniform_decode else "mixed prefill-decode",
Expand All @@ -4904,7 +4880,10 @@ def _capture_cudagraphs(
)

# We skip EPLB here since we don't want to record dummy metrics
for num_tokens, activate_lora in compilation_cases:
for batch_desc in batch_descriptors:
num_tokens = batch_desc.num_tokens
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like we're moving closer and closer to passing BatchDescriptor to dummy run directly...

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

next 😄

activate_lora = batch_desc.has_lora

# We currently only capture ubatched graphs when its a FULL
# cudagraph, a uniform decode batch, and the number of tokens
# is above the threshold. Otherwise we just capture a non-ubatched
Expand All @@ -4922,28 +4901,22 @@ def _capture_cudagraphs(

for _ in range(self.compilation_config.cudagraph_num_of_warmups):
# Use CUDAGraphRuntimeStyle.NONE (default) for warmup.
# But be careful, warm up with `NONE`is orthogonal to
# But be careful, warm up with `NONE` is orthogonal to
# if we want to warm up attention or not. This is
# different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention.
force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL
self._dummy_run(
dummy_run(
num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
force_attention=force_attention,
uniform_decode=uniform_decode,
allow_microbatching=allow_microbatching,
skip_eplb=True,
remove_lora=False,
activate_lora=activate_lora,
)
self._dummy_run(

# Capture run
dummy_run(
num_tokens,
cudagraph_runtime_mode=cudagraph_runtime_mode,
uniform_decode=uniform_decode,
allow_microbatching=allow_microbatching,
skip_eplb=True,
remove_lora=False,
activate_lora=activate_lora,
is_graph_capturing=True,
)
Expand Down