diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py index cdac67a2d650..2f539c9d397d 100644 --- a/tests/v1/cudagraph/test_cudagraph_dispatch.py +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -173,6 +173,68 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config): else: assert rt_mode == CUDAGraphMode.NONE + @pytest.mark.parametrize( + "cudagraph_mode_str,compilation_mode,expected_modes", + [ + # FULL mode: only FULL keys, no PIECEWISE + ("FULL", CompilationMode.NONE, [CUDAGraphMode.FULL]), + # PIECEWISE mode: only PIECEWISE keys + ("PIECEWISE", CompilationMode.VLLM_COMPILE, [CUDAGraphMode.PIECEWISE]), + # FULL_DECODE_ONLY: only FULL keys for uniform decode + ("FULL_DECODE_ONLY", CompilationMode.NONE, [CUDAGraphMode.FULL]), + # NONE mode: no keys + ("NONE", CompilationMode.NONE, []), + ], + ) + def test_get_capture_descs( + self, cudagraph_mode_str, compilation_mode, expected_modes + ): + """Test get_capture_descs returns correctly grouped and ordered descs.""" + comp_config = CompilationConfig( + cudagraph_mode=cudagraph_mode_str, + mode=compilation_mode, + cudagraph_capture_sizes=[1, 4, 8, 16], + ) + + config = _create_vllm_config(comp_config, max_num_seqs=16) + dispatcher = CudagraphDispatcher(config) + dispatcher.initialize_cudagraph_keys( + cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1 + ) + + capture_descs = dispatcher.get_capture_descs() + + # Verify we get the expected modes + actual_modes = [mode for mode, _ in capture_descs] + assert actual_modes == expected_modes + + # Verify each group is sorted largest-first + for mode, descs in capture_descs: + assert len(descs) > 0, "Each group should have at least one descriptor" + num_tokens_list = [d.num_tokens for d in descs] + assert num_tokens_list == sorted(num_tokens_list, reverse=True), ( + f"Descriptors for {mode} should be sorted largest-first" + ) + + # All descriptors in a group should have same uniform value + uniform_values = [d.uniform for d in descs] + assert len(set(uniform_values)) == 1, ( + "All descriptors in a group should have the same uniform value" + ) + + def test_get_capture_descs_empty_when_not_initialized(self): + """Test that get_capture_descs returns empty list when keys not initialized.""" + comp_config = CompilationConfig( + cudagraph_mode="FULL", + mode=CompilationMode.NONE, + cudagraph_capture_sizes=[1, 8], + ) + config = _create_vllm_config(comp_config, max_num_seqs=8) + dispatcher = CudagraphDispatcher(config) + # Don't initialize keys + + assert dispatcher.get_capture_descs() == [] + @pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") class TestCUDAGraphWrapper: diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 3e47e98c150c..f5738c6b3ca0 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -231,3 +231,26 @@ def dispatch( # finally, just return no cudagraphs and a trivial batch descriptor return CUDAGraphMode.NONE, BatchDescriptor(num_tokens) + + def get_capture_descs(self) -> list[tuple[CUDAGraphMode, list[BatchDescriptor]]]: + """ + Returns capture descriptors for cudagraph capturing. + + Returns: + List of (runtime_mode, batch_descriptors) tuples, ordered PIECEWISE + first then FULL. Batch descriptors are sorted largest-first for + memory efficiency. + """ + if not self.keys_initialized or self.cudagraph_mode == CUDAGraphMode.NONE: + return [] + + result = [] + # Return in order: PIECEWISE first, then FULL + for mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL]: + descs = list(self.cudagraph_keys[mode]) + if descs: + # Sort by num_tokens descending (largest first) + descs.sort(key=lambda d: d.num_tokens, reverse=True) + result.append((mode, descs)) + + return result diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 982ae44c2def..40be51682027 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -10,7 +10,6 @@ from contextlib import contextmanager from copy import copy, deepcopy from functools import reduce -from itertools import product from typing import TYPE_CHECKING, Any, NamedTuple, TypeAlias, cast import numpy as np @@ -4809,50 +4808,14 @@ def freeze_gc(): set_cudagraph_capturing_enabled(True) with freeze_gc(), graph_capture(device=self.device): start_free_gpu_memory = torch.cuda.mem_get_info()[0] - cudagraph_mode = self.compilation_config.cudagraph_mode - assert cudagraph_mode is not None - if self.lora_config: - if self.compilation_config.cudagraph_specialize_lora: - lora_cases = [True, False] - else: - lora_cases = [True] - else: - lora_cases = [False] - - if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: - cudagraph_runtime_mode = cudagraph_mode.mixed_mode() - # make sure we capture the largest batch size first - compilation_cases = list( - product(reversed(self.cudagraph_batch_sizes), lora_cases) - ) - self._capture_cudagraphs( - compilation_cases, - cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=False, - ) - - # Capture full cudagraph for uniform decode batches if we - # don't already have full mixed prefill-decode cudagraphs. - if ( - cudagraph_mode.decode_mode() == CUDAGraphMode.FULL - and cudagraph_mode.separate_routine() - ): - max_num_tokens = ( - self.scheduler_config.max_num_seqs * self.uniform_decode_query_len - ) - decode_cudagraph_batch_sizes = [ - x - for x in self.cudagraph_batch_sizes - if max_num_tokens >= x >= self.uniform_decode_query_len - ] - compilation_cases_decode = list( - product(reversed(decode_cudagraph_batch_sizes), lora_cases) - ) + for ( + runtime_mode, + batch_descs, + ) in self.cudagraph_dispatcher.get_capture_descs(): self._capture_cudagraphs( - compilation_cases=compilation_cases_decode, - cudagraph_runtime_mode=CUDAGraphMode.FULL, - uniform_decode=True, + batch_descriptors=batch_descs, + cudagraph_runtime_mode=runtime_mode, ) torch.cuda.synchronize() @@ -4883,19 +4846,32 @@ def freeze_gc(): def _capture_cudagraphs( self, - compilation_cases: list[tuple[int, bool]], + batch_descriptors: list[BatchDescriptor], cudagraph_runtime_mode: CUDAGraphMode, - uniform_decode: bool, ): assert ( cudagraph_runtime_mode != CUDAGraphMode.NONE and cudagraph_runtime_mode.valid_runtime_modes() ), f"Invalid cudagraph runtime mode: {cudagraph_runtime_mode}" + if not batch_descriptors: + return + + uniform_decode = batch_descriptors[0].uniform + force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL + + dummy_run = functools.partial( + self._dummy_run, + uniform_decode=uniform_decode, + skip_eplb=True, + remove_lora=False, + force_attention=force_attention, + ) + # Only rank 0 should print progress bar during capture if is_global_first_rank(): - compilation_cases = tqdm( - compilation_cases, + batch_descriptors = tqdm( + batch_descriptors, disable=not self.load_config.use_tqdm_on_load, desc="Capturing CUDA graphs ({}, {})".format( "decode" if uniform_decode else "mixed prefill-decode", @@ -4904,7 +4880,10 @@ def _capture_cudagraphs( ) # We skip EPLB here since we don't want to record dummy metrics - for num_tokens, activate_lora in compilation_cases: + for batch_desc in batch_descriptors: + num_tokens = batch_desc.num_tokens + activate_lora = batch_desc.has_lora + # We currently only capture ubatched graphs when its a FULL # cudagraph, a uniform decode batch, and the number of tokens # is above the threshold. Otherwise we just capture a non-ubatched @@ -4922,28 +4901,22 @@ def _capture_cudagraphs( for _ in range(self.compilation_config.cudagraph_num_of_warmups): # Use CUDAGraphRuntimeStyle.NONE (default) for warmup. - # But be careful, warm up with `NONE`is orthogonal to + # But be careful, warm up with `NONE` is orthogonal to # if we want to warm up attention or not. This is # different from the case where `FULL` implies capture # attention while `PIECEWISE` implies no attention. - force_attention = cudagraph_runtime_mode == CUDAGraphMode.FULL - self._dummy_run( + dummy_run( num_tokens, cudagraph_runtime_mode=CUDAGraphMode.NONE, - force_attention=force_attention, - uniform_decode=uniform_decode, allow_microbatching=allow_microbatching, - skip_eplb=True, - remove_lora=False, activate_lora=activate_lora, ) - self._dummy_run( + + # Capture run + dummy_run( num_tokens, cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=uniform_decode, allow_microbatching=allow_microbatching, - skip_eplb=True, - remove_lora=False, activate_lora=activate_lora, is_graph_capturing=True, )