Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/external_projects/vllm_flash_attn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 188be16520ceefdc625fdf71365585d2ee348fe2
GIT_TAG 2adfc8c2177c5b0e8ddeedfd5a8990d80eb496ff
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
Expand Down
7 changes: 6 additions & 1 deletion vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,15 @@ def __init__(
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
max_num_seqs = vllm_config.scheduler_config.max_num_seqs

if self.use_full_cuda_graph and self.aot_schedule:
# Times 4 due to:
# https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653
# For some tests max_cudagraph_size > max_num_seqs,
# so we need to use the larger one.
self.scheduler_metadata = torch.zeros(
vllm_config.scheduler_config.max_num_seqs + 1,
max(self.max_cudagraph_size or 0, max_num_seqs) * 4 + 1,
dtype=torch.int32,
device=self.device,
)
Expand Down
7 changes: 6 additions & 1 deletion vllm/v1/attention/backends/mla/flashattn_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,15 @@ def __init__(
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
)
self.max_cudagraph_size = self.compilation_config.max_cudagraph_capture_size
max_num_seqs = vllm_config.scheduler_config.max_num_seqs

if self.use_full_cuda_graph and self.fa_aot_schedule:
# Times 4 due to:
# https://github.com/vllm-project/flash-attention/blob/3223650ccabe622a0fcae65eec706a50186a89f7/hopper/flash_api.cpp#L650-L653
# For some tests max_cudagraph_size > max_num_seqs,
# so we need to use the larger one.
self.scheduler_metadata = torch.zeros(
vllm_config.scheduler_config.max_num_seqs + 1,
max(self.max_cudagraph_size or 0, max_num_seqs) * 4 + 1,
dtype=torch.int32,
device=self.device,
)
Expand Down