diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index b51934a3ab29..dbdfd5e81443 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 188be16520ceefdc625fdf71365585d2ee348fe2 + GIT_TAG 2adfc8c2177c5b0e8ddeedfd5a8990d80eb496ff GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 9275725314e4..232b0b0daff6 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -308,10 +308,15 @@ def __init__( self.compilation_config.cudagraph_mode.has_full_cudagraphs() ) self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size + max_num_seqs = vllm_config.scheduler_config.max_num_seqs if self.use_full_cuda_graph and self.aot_schedule: + # Times 4 due to: + # https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653 + # For some tests max_cudagraph_size > max_num_seqs, + # so we need to use the larger one. self.scheduler_metadata = torch.zeros( - vllm_config.scheduler_config.max_num_seqs + 1, + max(self.max_cudagraph_size or 0, max_num_seqs) * 4 + 1, dtype=torch.int32, device=self.device, ) diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py index e160d3255688..f0ba259362ff 100644 --- a/vllm/v1/attention/backends/mla/flashattn_mla.py +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -127,10 +127,15 @@ def __init__( self.compilation_config.cudagraph_mode.has_full_cudagraphs() ) self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size + max_num_seqs = vllm_config.scheduler_config.max_num_seqs if self.use_full_cuda_graph and self.fa_aot_schedule: + # Times 4 due to: + # https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653 + # For some tests max_cudagraph_size > max_num_seqs, + # so we need to use the larger one. self.scheduler_metadata = torch.zeros( - vllm_config.scheduler_config.max_num_seqs + 1, + max(self.max_cudagraph_size or 0, max_num_seqs) * 4 + 1, dtype=torch.int32, device=self.device, )