diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index b51934a3ab29..41c4e308d0be 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 188be16520ceefdc625fdf71365585d2ee348fe2 + GIT_TAG 5824e6e2008271063c3229ab3e7032bd74abbbc6 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py index 2b0f8a95d49f..debf9aeaa4d7 100644 --- a/tests/v1/cudagraph/test_cudagraph_dispatch.py +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import replace from unittest.mock import MagicMock, patch import pytest @@ -132,36 +133,39 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config): # Test dispatch logic # 1. 
non-uniform batch, size in cudagraph size list - desc_full_exact = BatchDescriptor( - num_tokens=8, - uniform=False, - ) + # FULL mode uses exact keys with num_reqs set + desc_full_with_reqs = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=False) + # PIECEWISE mode uses relaxed keys with num_reqs=None + desc_piecewise = BatchDescriptor(num_tokens=8, num_reqs=None, uniform=False) rt_mode, key = dispatcher.dispatch( num_tokens=8, uniform_decode=False, has_lora=False ) if cudagraph_mode_str == "FULL": assert rt_mode == CUDAGraphMode.FULL - assert key == desc_full_exact + assert key == desc_full_with_reqs elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]: assert rt_mode == CUDAGraphMode.PIECEWISE - assert key == desc_full_exact + assert key == desc_piecewise else: assert rt_mode == CUDAGraphMode.NONE # 2. uniform decode batch, size in cudagraph size list desc_uniform_exact = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True) + desc_non_uniform = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=False) rt_mode, key = dispatcher.dispatch( num_tokens=8, uniform_decode=True, has_lora=False ) if cudagraph_mode_str == "FULL": + # Pure FULL mode uses non-uniform keys for all batches assert rt_mode == CUDAGraphMode.FULL - assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs() + assert key == desc_non_uniform elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]: + # These modes have separate uniform decode keys assert rt_mode == CUDAGraphMode.FULL assert key == desc_uniform_exact elif cudagraph_mode_str == "PIECEWISE": assert rt_mode == CUDAGraphMode.PIECEWISE - assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs() + assert key == replace(desc_uniform_exact, num_reqs=None, uniform=False) else: assert rt_mode == CUDAGraphMode.NONE @@ -180,7 +184,7 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config): if "PIECEWISE" in cudagraph_mode_str: # string contains check assert rt_mode == 
CUDAGraphMode.PIECEWISE - assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs() + assert key == replace(desc_full_with_reqs, num_reqs=None, uniform=False) else: assert rt_mode == CUDAGraphMode.NONE diff --git a/vllm/forward_context.py b/vllm/forward_context.py index d357c8929d83..a0753b19e434 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -5,7 +5,7 @@ from collections import defaultdict from contextlib import contextmanager from dataclasses import dataclass, field -from typing import Any, NamedTuple +from typing import Any import torch @@ -26,7 +26,8 @@ batchsize_forward_time: defaultdict = defaultdict(list) -class BatchDescriptor(NamedTuple): +@dataclass(frozen=True) +class BatchDescriptor: """ Batch descriptor for cudagraph dispatching. We should keep the num of items as minimal as possible to properly and uniquely describe the padded @@ -56,19 +57,6 @@ class BatchDescriptor(NamedTuple): to be properly captured. """ - def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor": - """ - Return a relaxed version of current batch descriptor that is still compatible - with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs). 
- """ - return BatchDescriptor( - self.num_tokens, - num_reqs=None, - uniform=False, - has_lora=self.has_lora, - num_active_loras=self.num_active_loras, - ) - def _compute_sp_num_tokens( num_tokens_across_dp_cpu: torch.Tensor, sequence_parallel_size: int diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index e786ab3bc056..ecd1b274c8ce 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -40,7 +40,7 @@ vllm_is_batch_invariant, ) from vllm.platforms.interface import DeviceCapability -from vllm.utils.math_utils import cdiv +from vllm.utils.math_utils import cdiv, round_up from vllm.v1.attention.backend import ( AttentionCGSupport, AttentionMetadataBuilder, @@ -310,8 +310,17 @@ def __init__( self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size if self.use_full_cuda_graph and self.aot_schedule: + # FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4 + # The +1 is for the tile_count_semaphore (synchronization). 
+ # The 4 slots per batch element (num_prepare_batch_vectors) are: + # prepare_varlen + dynamic_split + sort_batches + head_swizzle + # See: https://github.com/vllm-project/flash-attention/blob/5824e6e/hopper/flash_api.cpp#L664-L671 # noqa: E501 + max_batch_size = max( + vllm_config.scheduler_config.max_num_seqs, + self.max_cudagraph_size or 0, + ) self.scheduler_metadata = torch.zeros( - vllm_config.scheduler_config.max_num_seqs + 1, + 1 + round_up(max_batch_size, 4) * 4, dtype=torch.int32, device=self.device, ) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index e160d3255688..33f89603563e 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -21,6 +21,7 @@ vllm_is_batch_invariant, ) from vllm.platforms.interface import DeviceCapability +from vllm.utils.math_utils import round_up from vllm.v1.attention.backend import ( AttentionCGSupport, AttentionLayer, @@ -129,8 +130,17 @@ def __init__( self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size if self.use_full_cuda_graph and self.fa_aot_schedule: + # FA3 scheduler_metadata size: 1 + round_up(batch_size, 4) * 4 + # The +1 is for the tile_count_semaphore (synchronization). 
+ # The 4 slots per batch element (num_prepare_batch_vectors) are: + # prepare_varlen + dynamic_split + sort_batches + head_swizzle + # See: https://github.com/vllm-project/flash-attention/blob/5824e6e/hopper/flash_api.cpp#L664-L671 # noqa: E501 + max_batch_size = max( + vllm_config.scheduler_config.max_num_seqs, + self.max_cudagraph_size or 0, + ) self.scheduler_metadata = torch.zeros( - vllm_config.scheduler_config.max_num_seqs + 1, + 1 + round_up(max_batch_size, 4) * 4, dtype=torch.int32, device=self.device, ) diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index 6f3e029c793b..6817c571b8b7 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import replace from itertools import product from vllm.config import CUDAGraphMode, VllmConfig @@ -180,12 +181,14 @@ def initialize_cudagraph_keys( for bs, num_active_loras in product( self.compilation_config.cudagraph_capture_sizes, lora_cases ): - self.add_cudagraph_key( - cudagraph_mode.mixed_mode(), - self._create_padded_batch_descriptor( - bs, False, num_active_loras > 0, num_active_loras - ).relax_for_mixed_batch_cudagraphs(), + batch_desc = self._create_padded_batch_descriptor( + bs, False, num_active_loras > 0, num_active_loras ) + # Only relax for PIECEWISE mode. FULL mode needs exact num_reqs + # because FA3's scheduler_metadata computation depends on it. + if cudagraph_mode.mixed_mode() == CUDAGraphMode.PIECEWISE: + batch_desc = replace(batch_desc, num_reqs=None, uniform=False) + self.add_cudagraph_key(cudagraph_mode.mixed_mode(), batch_desc) # if decode cudagraph mode is FULL, and we don't already have mixed # mode full cudagraphs then add them here. 
@@ -264,21 +267,23 @@ def dispatch( batch_desc = self._create_padded_batch_descriptor( num_tokens, uniform_decode, has_lora, effective_num_active_loras ) - relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs() - - if not disable_full: - # check if key exists for full cudagraph - if batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]: - return CUDAGraphMode.FULL, batch_desc - # otherwise, check if the relaxed key exists - if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.FULL]: - return CUDAGraphMode.FULL, relaxed_batch_desc + # check if key exists for full cudagraph + # For pure FULL mode, keys are registered with uniform=False. + batch_desc_to_check = batch_desc + if self.cudagraph_mode == CUDAGraphMode.FULL: + batch_desc_to_check = replace(batch_desc, uniform=False) + if ( + not disable_full + and batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.FULL] + ): + return CUDAGraphMode.FULL, batch_desc_to_check # also check if the relaxed key exists for more "general" # piecewise cudagraph - if relaxed_batch_desc in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]: - return CUDAGraphMode.PIECEWISE, relaxed_batch_desc + batch_desc_to_check = replace(batch_desc, num_reqs=None, uniform=False) + if batch_desc_to_check in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]: + return CUDAGraphMode.PIECEWISE, batch_desc_to_check # finally, just return no cudagraphs and a trivial batch descriptor return CUDAGraphMode.NONE, BatchDescriptor(num_tokens)