Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions tests/compile/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,28 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from contextlib import nullcontext
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pytest
from pydantic import ValidationError

from vllm.compilation.counter import compilation_counter
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.config import CompilationConfig, CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.config import (
CompilationConfig,
CUDAGraphMode,
ParallelConfig,
SchedulerConfig,
VllmConfig,
)
from vllm.config.compilation import CompilationMode, PassConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import (
_is_torch_equal_or_newer,
is_torch_equal,
)
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher

# This import automatically registers `torch.ops.silly.attention`
from . import silly_attention # noqa: F401
Expand Down Expand Up @@ -472,6 +479,19 @@ def test_cached_compilation_config(default_vllm_config):
assert "torch.ops._C.static_scaled_fp8_quant.default(" in code


def _create_vllm_config_for_validation(
    compilation_config: CompilationConfig,
) -> MagicMock:
    """Build a mock ``VllmConfig`` exposing just the attributes that the
    cudagraph padding validation path reads."""
    vllm_config = MagicMock(spec=VllmConfig)
    vllm_config.compilation_config = compilation_config
    vllm_config.scheduler_config = SchedulerConfig.default_factory(max_num_seqs=8)
    vllm_config.parallel_config = ParallelConfig()
    # Features not exercised by padding validation are explicitly absent.
    vllm_config.speculative_config = None
    vllm_config.lora_config = None
    return vllm_config


def test_compile_sizes_padding_validation():
"""Test that compile_sizes with values that would be padded raises an error."""
# cudagraph_capture_sizes=[1, 2, 4, 8] means:
Expand All @@ -488,29 +508,39 @@ def test_compile_sizes_padding_validation():
cudagraph_capture_sizes=[1, 2, 4, 8],
max_cudagraph_capture_size=8,
compile_sizes=[3],
cudagraph_mode=CUDAGraphMode.FULL,
)
config.post_init_cudagraph_sizes()
dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config))
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.FULL)

with pytest.raises(ValueError, match="would be padded to"):
config = CompilationConfig(
cudagraph_capture_sizes=[1, 2, 4, 8],
max_cudagraph_capture_size=8,
compile_sizes=[5],
cudagraph_mode=CUDAGraphMode.FULL,
)
config.post_init_cudagraph_sizes()
dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config))
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.FULL)

config = CompilationConfig(
cudagraph_capture_sizes=[1, 2, 4, 8],
max_cudagraph_capture_size=8,
compile_sizes=[1, 2, 4, 8],
cudagraph_mode=CUDAGraphMode.FULL,
)
config.post_init_cudagraph_sizes()
assert sorted(config.compile_sizes) == [1, 2, 4, 8]
dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config))
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.FULL) # Should not raise

config = CompilationConfig(
cudagraph_capture_sizes=[1, 2, 4, 8],
max_cudagraph_capture_size=8,
compile_sizes=["cudagraph_capture_sizes"],
cudagraph_mode=CUDAGraphMode.FULL,
)
config.post_init_cudagraph_sizes()
assert sorted(config.compile_sizes) == [1, 2, 4, 8]
Expand All @@ -535,3 +565,5 @@ def test_compile_sizes_padding_validation():
)
config.post_init_cudagraph_sizes()
assert sorted(config.compile_sizes) == [3, 5, 7]
dispatcher = CudagraphDispatcher(_create_vllm_config_for_validation(config))
dispatcher.initialize_cudagraph_keys(CUDAGraphMode.NONE) # Should not raise
10 changes: 9 additions & 1 deletion tests/models/language/generation/test_hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tests.utils import multi_gpu_test
from vllm.engine.arg_utils import EngineArgs
from vllm.sampling_params import SamplingParams
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher

from ...utils import check_logprobs_close, check_outputs_equal

Expand Down Expand Up @@ -172,7 +173,14 @@ def test_mamba_cache_cg_padding(
tensor dimensions aren't compatible.
"""
vllm_config = EngineArgs(model=model, trust_remote_code=True).create_engine_config()
while len(example_prompts) == vllm_config.pad_for_cudagraph(len(example_prompts)):
cudagraph_dispatcher = CudagraphDispatcher(vllm_config)
cudagraph_dispatcher.initialize_cudagraph_keys(
vllm_config.compilation_config.cudagraph_mode
)
while (
len(example_prompts)
== cudagraph_dispatcher.dispatch(len(example_prompts))[1].num_tokens
):
Comment on lines 175 to +183

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Dispatch loop never terminates in hybrid padding test

In test_mamba_cache_cg_padding a CudagraphDispatcher is created and immediately used in the while condition without ever calling initialize_cudagraph_keys, so dispatch returns a BatchDescriptor with num_tokens unchanged when keys_initialized is False. The condition len(example_prompts) == ...num_tokens therefore stays true on every iteration and the loop appends forever, hanging the test before it exercises any logic. This test now never completes under any configuration.

Useful? React with 👍 / 👎.

example_prompts.append(example_prompts[0])

try:
Expand Down
10 changes: 4 additions & 6 deletions tests/v1/cudagraph/test_cudagraph_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,6 @@ def _create_vllm_config(
)

compilation_config.post_init_cudagraph_sizes()
mock_config.pad_for_cudagraph = (
lambda batch_size: compilation_config.bs_to_padded_graph_size[batch_size]
)

return mock_config

Expand Down Expand Up @@ -169,6 +166,7 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
rt_mode, key = dispatcher.dispatch(
num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True
)

if "PIECEWISE" in cudagraph_mode_str: # string contains check
assert rt_mode == CUDAGraphMode.PIECEWISE
assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs()
Expand Down Expand Up @@ -360,7 +358,7 @@ def test_capture_replay_bypass_logic(self):
):
full_wrapper(input_1)

rt_mode, key = self.dispatcher.dispatch(desc_1)
rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_1.num_tokens)
# 1. Capture first shape
action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
assert action == "capture_global"
Expand All @@ -369,7 +367,7 @@ def test_capture_replay_bypass_logic(self):
action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
assert action == "replay"

rt_mode, key = self.dispatcher.dispatch(desc_2)
rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_2.num_tokens)
# 3. Capture second shape
action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key)
assert action == "capture_global"
Expand All @@ -381,7 +379,7 @@ def test_capture_replay_bypass_logic(self):
assert action == "replay"

# 5. Bypass if no key match
rt_mode, key = self.dispatcher.dispatch(desc_3_unseen)
rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_3_unseen.num_tokens)
assert rt_mode == CUDAGraphMode.NONE
action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, key)
assert action == "bypass"
Expand Down
2 changes: 1 addition & 1 deletion vllm/compilation/piecewise_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.config.utils import Range
from vllm.logger import init_logger

logger = init_logger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion vllm/compilation/sequence_parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch._inductor.pattern_matcher import PatternMatcherPass

from vllm.config import VllmConfig
from vllm.config.compilation import Range
from vllm.config.utils import Range
from vllm.distributed import get_tp_group, tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
Expand Down
47 changes: 0 additions & 47 deletions vllm/config/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,15 +581,6 @@ class CompilationConfig:
local_cache_dir: str = field(default=None, init=False) # type: ignore
"""local cache dir for each rank"""

bs_to_padded_graph_size: list[int] = field(
default=None, # type: ignore
init=False,
)
"""optimization:
Intuitively, bs_to_padded_graph_size should be dict[int, int].
since we know all keys are in a range [0, max_cudagraph_capture_size],
we can optimize it to list[int] for better lookup performance."""

# keep track of enabled and disabled custom ops
enabled_custom_ops: Counter[str] = field(default_factory=Counter, init=False)
"""custom ops that are enabled"""
Expand Down Expand Up @@ -639,7 +630,6 @@ def compute_hash(self) -> str:
"debug_dump_path",
"cache_dir",
"local_cache_dir",
"bs_to_padded_graph_size",
"traced_files",
"compilation_time",
"static_forward_context",
Expand All @@ -661,7 +651,6 @@ def __repr__(self) -> str:
"enabled_custom_ops": True,
"disabled_custom_ops": True,
"compilation_time": True,
"bs_to_padded_graph_size": True,
"traced_files": True,
"inductor_compile_config": {
"post_grad_custom_post_pass": True,
Expand Down Expand Up @@ -882,7 +871,6 @@ def post_init_cudagraph_sizes(self) -> None:
"""To complete the initialization after cudagraph related
configs are set. This includes:
- initialize compile_sizes
- pre-compute the mapping bs_to_padded_graph_size
"""

computed_compile_sizes = []
Expand All @@ -906,23 +894,6 @@ def post_init_cudagraph_sizes(self) -> None:
if self.cudagraph_capture_sizes:
assert self.cudagraph_capture_sizes[-1] == self.max_cudagraph_capture_size

# May get recomputed in the model runner if adjustment is needed for spec-decode
self.compute_bs_to_padded_graph_size()

# Validate that compile_sizes won't be changed by padding.
# Only validate when cudagraphs are actually being used.
if self.compile_sizes and self.cudagraph_mode != CUDAGraphMode.NONE:
for size in self.compile_sizes:
if size <= self.max_cudagraph_capture_size:
padded = self.bs_to_padded_graph_size[size]
if padded != size:
raise ValueError(
f"compile_sizes contains {size} which would be "
f"padded to {padded}. All compile_sizes must be "
"values that won't be changed by cudagraph padding. "
"Use values from cudagraph_capture_sizes."
)

def set_splitting_ops_for_v1(
self, all2all_backend: str, data_parallel_size: int = 1
):
Expand Down Expand Up @@ -1134,24 +1105,6 @@ def adjust_cudagraph_sizes_for_spec_decode(
self.max_cudagraph_capture_size = rounded_sizes[-1]
self.cudagraph_capture_sizes = rounded_sizes

# Recompute after adjusting the cudagraph sizes
self.compute_bs_to_padded_graph_size()

def compute_bs_to_padded_graph_size(self):
# pre-compute the mapping from batch size to padded graph size
self.bs_to_padded_graph_size = [
0 for i in range(self.max_cudagraph_capture_size + 1)
]
for end, start in zip(
self.cudagraph_capture_sizes + [self.max_cudagraph_capture_size + 1],
[0] + self.cudagraph_capture_sizes,
):
for bs in range(start, end):
if bs == start:
self.bs_to_padded_graph_size[bs] = start
else:
self.bs_to_padded_graph_size[bs] = end

def get_compile_ranges(self) -> list[Range]:
"""Get the compile ranges for the compilation config."""
if self.compile_ranges_split_points is None:
Expand Down
60 changes: 55 additions & 5 deletions vllm/v1/cudagraph_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,47 @@ def __init__(self, vllm_config: VllmConfig):
)

self.keys_initialized = False
# Default cudagraph_mode to NONE until initialize_cudagraph_keys is called
self.cudagraph_mode = CUDAGraphMode.NONE

def _compute_bs_to_padded_graph_size(self) -> None:
"""Pre-compute the mapping from batch size to padded graph size."""
max_size = self.compilation_config.max_cudagraph_capture_size
capture_sizes = self.compilation_config.cudagraph_capture_sizes
self._bs_to_padded_graph_size: list[int] = [0] * (max_size + 1)
for end, start in zip(
capture_sizes + [max_size + 1],
[0] + capture_sizes,
):
for bs in range(start, end):
if bs == start:
self._bs_to_padded_graph_size[bs] = start
else:
self._bs_to_padded_graph_size[bs] = end

# Validate that compile_sizes won't be changed by padding.
# Only validate when cudagraphs are actually being used.
if (
self.compilation_config.compile_sizes
and self.cudagraph_mode != CUDAGraphMode.NONE
):
for size in self.compilation_config.compile_sizes:
if size <= self.compilation_config.max_cudagraph_capture_size:
padded = self._bs_to_padded_graph_size[size]
if padded != size:
raise ValueError(
f"compile_sizes contains {size} which would be "
f"padded to {padded}. All compile_sizes must be "
"values that won't be changed by cudagraph padding. "
"Use values from cudagraph_capture_sizes."
)

def _create_padded_batch_descriptor(
self, num_tokens: int, uniform_decode: bool, has_lora: bool
) -> BatchDescriptor:
max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs
uniform_decode_query_len = self.uniform_decode_query_len
num_tokens_padded = self.vllm_config.pad_for_cudagraph(num_tokens)
num_tokens_padded = self._bs_to_padded_graph_size[num_tokens]

if uniform_decode and self.cudagraph_mode.has_mode(CUDAGraphMode.FULL):
num_reqs = num_tokens_padded // uniform_decode_query_len
Expand All @@ -88,12 +122,19 @@ def add_cudagraph_key(
self.cudagraph_keys[runtime_mode].add(batch_descriptor)

def initialize_cudagraph_keys(
self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int
self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int = 1
):
# This should be called only after attention backend is initialized. So we can
# get the correct cudagraph mode after backend support is resolved.
self.cudagraph_mode = cudagraph_mode

# Early exit if cudagraphs are disabled
if cudagraph_mode == CUDAGraphMode.NONE:
self.keys_initialized = True
return

self._compute_bs_to_padded_graph_size()

# LoRA activation cases to specialize the cuda graphs on
if self.vllm_config.lora_config:
if self.compilation_config.cudagraph_specialize_lora:
Expand Down Expand Up @@ -143,15 +184,24 @@ def initialize_cudagraph_keys(
def dispatch(
self,
num_tokens: int,
uniform_decode: bool,
has_lora: bool,
uniform_decode: bool = False,
has_lora: bool = False,
disable_full: bool = False,
) -> tuple[CUDAGraphMode, BatchDescriptor]:
"""
Given conditions(e.g.,batch descriptor and if using cascade attention),
Given conditions(e.g.,batch descriptor and if using piecewise only),
dispatch to a cudagraph runtime mode and the valid batch descriptor.
A new batch descriptor is returned as we might dispatch a uniform batch
to a graph that supports a more general batch (uniform to non-uniform).

Args:
num_tokens: Number of tokens in the batch.
uniform_decode: Whether the batch is uniform decode (i.e. uniform and query
length is uniform_decode_query_len).
has_lora: Whether LoRA is active.
disable_full: If True, skip FULL cudagraph checks and
return PIECEWISE or NONE only. (can be used for features like
cascade attention that are not supported by full cudagraphs)
"""
if (
not self.keys_initialized
Expand Down
Loading