From bc3efcc477e97cbf5117b3496d8b944a590ba3a9 Mon Sep 17 00:00:00 2001 From: charlifu Date: Wed, 4 Jun 2025 16:53:51 +0000 Subject: [PATCH 01/17] enable full graph capture for TritonAttn Signed-off-by: charlifu --- vllm/v1/worker/gpu_model_runner.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4a67e37781bf..a8b042dc7816 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2045,11 +2045,12 @@ def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: if self.vllm_config.compilation_config.full_cuda_graph: attn_backend_name = attn_backend_i.__name__ flash_attn_version = get_flash_attn_version() - if attn_backend_name != "FlashAttentionBackend" or \ - flash_attn_version != 3: + if not (attn_backend_name == "FlashAttentionBackend" and \ + flash_attn_version == 3) and \ + attn_backend_name != "TritonAttentionBackend": raise ValueError( f"full_cuda_graph is only supported with " - f"FA3. Current attention backend is " + f"FA3 or triton. 
Current attention backend is " f"{attn_backend_name}, FlashAttention version is " f"{flash_attn_version}.") From 78c13e7c33e34530927addaf287d18f0251083ac Mon Sep 17 00:00:00 2001 From: charlifu Date: Wed, 4 Jun 2025 19:08:20 +0000 Subject: [PATCH 02/17] overwrite metadata build func for triton_attn.py Signed-off-by: charlifu --- vllm/v1/attention/backends/triton_attn.py | 104 +++++++++++++++++++++- 1 file changed, 102 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 968f13701118..6cb721024d5f 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -16,7 +16,9 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import ( - FlashAttentionMetadata, FlashAttentionMetadataBuilder) + FlashAttentionMetadata, FlashAttentionMetadataBuilder, + make_local_attention_virtual_batches) +from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable @@ -30,9 +32,107 @@ class TritonAttentionMetadataBuilder(FlashAttentionMetadataBuilder): def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, block_table: BlockTable): - super().__init__(runner, kv_cache_spec, block_table) + model_config = runner.model_config + compilation_config = runner.vllm_config.compilation_config + + self.runner = runner + self.use_full_cuda_graph = compilation_config.full_cuda_graph + self.num_heads_q = model_config.get_num_attention_heads( + runner.parallel_config) + self.num_heads_kv = model_config.get_num_kv_heads( + runner.parallel_config) + self.headdim = model_config.get_head_size() + self.block_size = kv_cache_spec.block_size + self.kv_cache_spec = kv_cache_spec + self.block_table = block_table + + # Sliding window size to be used with the AOT scheduler will be + # 
populated on first build() call. + self.aot_sliding_window: Optional[tuple[int, int]] = None self.aot_schedule = False + def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata): + max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + block_table = self.block_table + block_table_tensor = block_table.get_device_tensor()[:num_reqs] + + block_table.slot_mapping[:num_actual_tokens].copy_( + block_table.slot_mapping_cpu[:num_actual_tokens], + non_blocking=True) + # Fill unused with -1. Needed for reshape_and_cache in full cuda graph + # mode. + block_table.slot_mapping[num_actual_tokens:].fill_(-1) + + slot_mapping = block_table.slot_mapping[:num_actual_tokens] + + # for local attention + local_attn_metadata = None + if self.runner.attention_chunk_size is not None: + seqlens_q_local_np, virt_q_cu_seqlens_np, virt_k_seqlens_np, \ + virt_block_table_tensor = make_local_attention_virtual_batches( + self.runner.attention_chunk_size, + self.runner.query_start_loc_np[:num_reqs + 1], + self.runner.seq_lens_np[:num_reqs], + block_table_tensor, + self.block_size, + ) + local_query_start_loc = torch.from_numpy(virt_q_cu_seqlens_np).to( + self.runner.device, non_blocking=True) + local_seqused_k = torch.from_numpy(virt_k_seqlens_np).to( + self.runner.device, non_blocking=True) + local_max_query_len = seqlens_q_local_np.max() + local_max_seq_len = virt_k_seqlens_np.max() + + local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata( + local_query_start_loc=local_query_start_loc, + local_seqused_k=local_seqused_k, + local_block_table=virt_block_table_tensor, + local_max_query_len=local_max_query_len, + local_max_seq_len=local_max_seq_len, + local_scheduler_metadata=None, + ) + + use_cascade = common_prefix_len > 0 + + if use_cascade: + cu_prefix_query_lens = torch.tensor([0, 
num_actual_tokens], + dtype=torch.int32, + device=self.runner.device) + prefix_kv_lens = torch.tensor([common_prefix_len], + dtype=torch.int32, + device=self.runner.device) + suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] - + common_prefix_len) + suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to( + self.runner.device) + else: + cu_prefix_query_lens = None + prefix_kv_lens = None + suffix_kv_lens = None + prefix_scheduler_metadata = None + + attn_metadata = FlashAttentionMetadata( + num_actual_tokens=num_actual_tokens, + max_query_len=max_query_len, + query_start_loc=query_start_loc, + max_seq_len=max_seq_len, + seq_lens=seq_lens, + block_table=block_table_tensor, + slot_mapping=slot_mapping, + use_cascade=use_cascade, + common_prefix_len=common_prefix_len, + cu_prefix_query_lens=cu_prefix_query_lens, + prefix_kv_lens=prefix_kv_lens, + suffix_kv_lens=suffix_kv_lens, + local_attn_metadata=local_attn_metadata, + prefix_scheduler_metadata=prefix_scheduler_metadata, + ) + return attn_metadata + class TritonAttentionBackend(AttentionBackend): From 113dbccba592327a375979f93ac8c5182a755865 Mon Sep 17 00:00:00 2001 From: charlifu Date: Thu, 5 Jun 2025 15:51:20 +0000 Subject: [PATCH 03/17] increase max seq len supportted by custom paged attention to 128 * 1024 Signed-off-by: charlifu --- vllm/platforms/rocm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index a929366db49c..445b24d72f07 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -141,7 +141,8 @@ def use_rocm_custom_paged_attention( and (head_size == 64 or head_size == 128) and (block_size == 16 or block_size == 32) and (gqa_ratio >= 1 and gqa_ratio <= 16) - and max_seq_len <= 32768 and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) + and max_seq_len <= 128 * 1024 + and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN and envs.VLLM_ROCM_USE_AITER)) @@ -151,7 +152,7 @@ def 
use_rocm_custom_paged_attention( and (qtype == torch.half or qtype == torch.bfloat16) and head_size == 128 and block_size == 16 and (gqa_ratio >= 3 and gqa_ratio <= 16) - and max_seq_len <= 32768 and alibi_slopes is None + and max_seq_len <= 128 * 1024 and alibi_slopes is None and kv_cache_dtype == "auto" and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) From 6412b8af1f1684ab87b980195e74e74a72f7eec3 Mon Sep 17 00:00:00 2001 From: charlifu Date: Thu, 5 Jun 2025 16:35:26 +0000 Subject: [PATCH 04/17] fix graph capture too slow Signed-off-by: charlifu --- vllm/v1/attention/backends/flash_attn.py | 8 ++++++-- vllm/v1/attention/backends/triton_attn.py | 14 ++++++++++++-- vllm/v1/worker/gpu_model_runner.py | 1 + 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index a9f748d026f4..977e4d54f773 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -324,9 +324,13 @@ def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: return False - def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, + def build(self, + num_reqs: int, + num_actual_tokens: int, + max_query_len: int, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata): + common_attn_metadata: CommonAttentionMetadata, + for_cudagraph_capture: Optional[bool] = False): max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 6cb721024d5f..16c2f3435c9d 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -51,15 +51,25 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, self.aot_sliding_window: Optional[tuple[int, int]] = None 
self.aot_schedule = False - def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, + def build(self, + num_reqs: int, + num_actual_tokens: int, + max_query_len: int, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata): + common_attn_metadata: CommonAttentionMetadata, + for_cudagraph_capture: Optional[bool] = False): max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens block_table = self.block_table block_table_tensor = block_table.get_device_tensor()[:num_reqs] + # When doing full graph capture, setting seq_lens to + # max_model_len will cause graph capture to be extremely + # slow, so here we set it to 1. + if self.use_full_cuda_graph and for_cudagraph_capture: + seq_lens.fill_(1) + block_table.slot_mapping[:num_actual_tokens].copy_( block_table.slot_mapping_cpu[:num_actual_tokens], non_blocking=True) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a8b042dc7816..d7550c3bcc65 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1770,6 +1770,7 @@ def _dummy_run( max_query_len=num_tokens, common_prefix_len=0, common_attn_metadata=common_attn_metadata, + for_cudagraph_capture=True, )) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i From 4ffb3dffe42e77479b71e1162f31712b5ebf543e Mon Sep 17 00:00:00 2001 From: charlifu Date: Thu, 5 Jun 2025 21:11:35 +0000 Subject: [PATCH 05/17] revert max seq len for gfx1* Signed-off-by: charlifu --- vllm/platforms/rocm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 445b24d72f07..a3404c6eafe5 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -152,7 +152,7 @@ def use_rocm_custom_paged_attention( and (qtype == torch.half or qtype == torch.bfloat16) and head_size == 128 and 
block_size == 16 and (gqa_ratio >= 3 and gqa_ratio <= 16) - and max_seq_len <= 128 * 1024 and alibi_slopes is None + and max_seq_len <= 32768 and alibi_slopes is None and kv_cache_dtype == "auto" and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) From 10701e7673dde393bc8df8389615046a9e78010f Mon Sep 17 00:00:00 2001 From: charlifu Date: Thu, 5 Jun 2025 22:01:04 +0000 Subject: [PATCH 06/17] guard full graph capture from unified attn Signed-off-by: charlifu --- vllm/v1/attention/backends/triton_attn.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 8330330b174f..1e182acb9d08 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -68,6 +68,9 @@ def build(self, # max_model_len will cause graph capture to be extremely # slow, so here we set it to 1. if self.use_full_cuda_graph and for_cudagraph_capture: + assert envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION, \ + "full_cuda_graph is only supported with " \ + "VLLM_V1_USE_PREFILL_DECODE_ATTENTION enabled" seq_lens.fill_(1) block_table.slot_mapping[:num_actual_tokens].copy_( From 7188dbf41aebebd70bb43d36d983b782c0aaaaad Mon Sep 17 00:00:00 2001 From: charlifu Date: Fri, 6 Jun 2025 15:45:20 +0000 Subject: [PATCH 07/17] use 128K for navi Signed-off-by: charlifu --- vllm/platforms/rocm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index a3404c6eafe5..445b24d72f07 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -152,7 +152,7 @@ def use_rocm_custom_paged_attention( and (qtype == torch.half or qtype == torch.bfloat16) and head_size == 128 and block_size == 16 and (gqa_ratio >= 3 and gqa_ratio <= 16) - and max_seq_len <= 32768 and alibi_slopes is None + and max_seq_len <= 128 * 1024 and alibi_slopes is None and kv_cache_dtype == "auto" and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN) From 
72a67af32bf6c970f3c72db913192938eec2f040 Mon Sep 17 00:00:00 2001 From: charlifu Date: Fri, 6 Jun 2025 15:54:52 +0000 Subject: [PATCH 08/17] unify metadata build function signature Signed-off-by: charlifu --- vllm/v1/attention/backends/cpu_attn.py | 8 ++++++-- vllm/v1/attention/backends/flash_attn.py | 2 +- vllm/v1/attention/backends/mla/common.py | 8 ++++++-- vllm/v1/attention/backends/triton_attn.py | 2 +- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index d7a580c2883c..5d777822ad18 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -118,9 +118,13 @@ def reorder_batch(self, input_batch: InputBatch, return True - def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, + def build(self, + num_reqs: int, + num_actual_tokens: int, + max_query_len: int, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata): + common_attn_metadata: CommonAttentionMetadata, + for_cudagraph_capture: bool = False): runner = self.runner block_table = self.block_table seq_lens_np = runner.seq_lens_np[:num_reqs] diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 495ac4578d61..68112c3d3dcc 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -346,7 +346,7 @@ def build(self, max_query_len: int, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, - for_cudagraph_capture: Optional[bool] = False): + for_cudagraph_capture: bool = False): max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index e6b4f6404632..937a6786f319 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ 
b/vllm/v1/attention/backends/mla/common.py @@ -450,9 +450,13 @@ def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens=seq_lens, ) - def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, + def build(self, + num_reqs: int, + num_actual_tokens: int, + max_query_len: int, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata) -> M: + common_attn_metadata: CommonAttentionMetadata, + for_cudagraph_capture: bool = False) -> M: assert self._num_decodes + self._num_prefills == num_reqs # Note(simon): be careful about the CPU <> GPU memory movement in this diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 1e182acb9d08..00bfb68dd0f9 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -57,7 +57,7 @@ def build(self, max_query_len: int, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, - for_cudagraph_capture: Optional[bool] = False): + for_cudagraph_capture: bool = False): max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens From 4a96aa383d658686896c9723426f58fa578a69ba Mon Sep 17 00:00:00 2001 From: charlifu Date: Mon, 9 Jun 2025 16:38:20 +0000 Subject: [PATCH 09/17] remove the guard of unified attention kernel Signed-off-by: charlifu --- vllm/v1/attention/backends/triton_attn.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 00bfb68dd0f9..77ddd245cee4 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -68,9 +68,6 @@ def build(self, # max_model_len will cause graph capture to be extremely # slow, so here we set it to 1. 
if self.use_full_cuda_graph and for_cudagraph_capture: - assert envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION, \ - "full_cuda_graph is only supported with " \ - "VLLM_V1_USE_PREFILL_DECODE_ATTENTION enabled" seq_lens.fill_(1) block_table.slot_mapping[:num_actual_tokens].copy_( From a189247a5462ba5cd721a40dfa8770dec5dbdb7e Mon Sep 17 00:00:00 2001 From: charlifu Date: Tue, 10 Jun 2025 16:34:57 +0000 Subject: [PATCH 10/17] use build_for_cudagraph_capture Signed-off-by: charlifu --- vllm/v1/attention/backends/cpu_attn.py | 8 ++----- vllm/v1/attention/backends/flash_attn.py | 15 +++++++------ vllm/v1/attention/backends/mla/common.py | 8 ++----- vllm/v1/attention/backends/triton_attn.py | 26 ++++++++++++----------- vllm/v1/worker/gpu_model_runner.py | 8 +++---- 5 files changed, 30 insertions(+), 35 deletions(-) diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 5d777822ad18..d7a580c2883c 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -118,13 +118,9 @@ def reorder_batch(self, input_batch: InputBatch, return True - def build(self, - num_reqs: int, - num_actual_tokens: int, - max_query_len: int, + def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - for_cudagraph_capture: bool = False): + common_attn_metadata: CommonAttentionMetadata): runner = self.runner block_table = self.block_table seq_lens_np = runner.seq_lens_np[:num_reqs] diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 68112c3d3dcc..d3e1e7c07768 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -340,13 +340,16 @@ def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: return False - def build(self, - num_reqs: int, - num_actual_tokens: int, - max_query_len: int, + def 
build_for_cudagraph_capture( + self, num_reqs: int, num_tokens: int, + common_attn_metadata: CommonAttentionMetadata + ) -> FlashAttentionMetadata: + return self.build(num_reqs, num_tokens, num_tokens, 0, + common_attn_metadata) + + def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - for_cudagraph_capture: bool = False): + common_attn_metadata: CommonAttentionMetadata): max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 937a6786f319..e6b4f6404632 100644 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -450,13 +450,9 @@ def _build_decode(self, block_table_tensor: torch.Tensor, seq_lens=seq_lens, ) - def build(self, - num_reqs: int, - num_actual_tokens: int, - max_query_len: int, + def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - for_cudagraph_capture: bool = False) -> M: + common_attn_metadata: CommonAttentionMetadata) -> M: assert self._num_decodes + self._num_prefills == num_reqs # Note(simon): be careful about the CPU <> GPU memory movement in this diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 77ddd245cee4..4b6b1e85889f 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -51,25 +51,27 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, self.aot_sliding_window: Optional[tuple[int, int]] = None self.aot_schedule = False - def build(self, - num_reqs: int, - num_actual_tokens: int, - max_query_len: int, + def build_for_cudagraph_capture( + self, num_reqs: int, num_tokens: int, + 
common_attn_metadata: CommonAttentionMetadata): + attn_metadata = self.build(num_reqs, num_tokens, num_tokens, 0, + common_attn_metadata) + # When doing full graph capture, setting seq_lens to + # max_model_len will cause graph capture to be extremely + # slow, so here we set it to 1. + if self.use_full_cuda_graph: + attn_metadata.seq_lens.fill_(1) + return attn_metadata + + def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata, - for_cudagraph_capture: bool = False): + common_attn_metadata: CommonAttentionMetadata): max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens block_table = self.block_table block_table_tensor = block_table.get_device_tensor()[:num_reqs] - # When doing full graph capture, setting seq_lens to - # max_model_len will cause graph capture to be extremely - # slow, so here we set it to 1. - if self.use_full_cuda_graph and for_cudagraph_capture: - seq_lens.fill_(1) - block_table.slot_mapping[:num_actual_tokens].copy_( block_table.slot_mapping_cpu[:num_actual_tokens], non_blocking=True) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d3cbf07dbf06..4a43ec7e2890 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1802,13 +1802,11 @@ def _dummy_run( for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): attn_metadata_i = ( - self.attn_metadata_builders[kv_cache_group_id].build( + self.attn_metadata_builders[kv_cache_group_id]. 
+ build_for_cudagraph_catpure( num_reqs=num_reqs, - num_actual_tokens=num_tokens, - max_query_len=num_tokens, - common_prefix_len=0, + num_tokens=num_tokens, common_attn_metadata=common_attn_metadata, - for_cudagraph_capture=True, )) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i From bf13c6c0ebb140f849543f415edae1b6004b72bd Mon Sep 17 00:00:00 2001 From: charlifu Date: Tue, 10 Jun 2025 16:43:08 +0000 Subject: [PATCH 11/17] fix typo Signed-off-by: charlifu --- vllm/v1/worker/gpu_model_runner.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 4a43ec7e2890..3f92b3a629fc 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1801,13 +1801,12 @@ def _dummy_run( attn_metadata = {} for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): - attn_metadata_i = ( - self.attn_metadata_builders[kv_cache_group_id]. 
- build_for_cudagraph_catpure( + attn_metadata_i = self.attn_metadata_builders[ + kv_cache_group_id].build_for_cudagraph_capture( num_reqs=num_reqs, num_tokens=num_tokens, common_attn_metadata=common_attn_metadata, - )) + ) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i From 35f611fc4abce75eeefa2691eb81202d2def91d1 Mon Sep 17 00:00:00 2001 From: charlifu Date: Tue, 10 Jun 2025 18:51:39 +0000 Subject: [PATCH 12/17] add unit test for rocm Signed-off-by: charlifu --- tests/compile/piecewise/test_full_cudagraph.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index 134bade48607..b35b64d88887 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -62,7 +62,8 @@ def generate_text(llm: LLM, batch_size: int, max_tokens: int): return llm.generate(prompts, sampling_params) -@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0), +@pytest.mark.skipif(current_platform.is_cuda() + and current_platform.get_device_capability() != (9, 0), reason="Only Hopper GPUs support FlashAttention 3") @pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10), (16, 10), (25, 10), @@ -91,6 +92,7 @@ def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm, 0].text == full_cudagraph_responses[i].outputs[0].text +@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_full_cudagraph_with_invalid_backend(): with temporary_environ({ "VLLM_USE_V1": "1", From d098636bb3aeefd30b1501de7dad38af3fa82cc6 Mon Sep 17 00:00:00 2001 From: charlifu Date: Wed, 11 Jun 2025 19:25:24 +0000 Subject: [PATCH 13/17] fix test Signed-off-by: charlifu --- tests/compile/piecewise/test_full_cudagraph.py | 9 +++++---- vllm/v1/attention/backends/triton_attn.py | 3 +-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git 
a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py index b35b64d88887..7ebb66d9db43 100644 --- a/tests/compile/piecewise/test_full_cudagraph.py +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -62,9 +62,10 @@ def generate_text(llm: LLM, batch_size: int, max_tokens: int): return llm.generate(prompts, sampling_params) -@pytest.mark.skipif(current_platform.is_cuda() - and current_platform.get_device_capability() != (9, 0), - reason="Only Hopper GPUs support FlashAttention 3") +@pytest.mark.skipif(not current_platform.is_cuda_alike() or + (current_platform.is_cuda() + and current_platform.get_device_capability() != (9, 0)), + reason="Only run for Hopper GPUs and AMD GPUs") @pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10), (16, 10), (25, 10), (32, 10), (45, 10), @@ -92,7 +93,7 @@ def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm, 0].text == full_cudagraph_responses[i].outputs[0].text -@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") +@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda") def test_full_cudagraph_with_invalid_backend(): with temporary_environ({ "VLLM_USE_V1": "1", diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 4b6b1e85889f..4c145f2720eb 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -59,8 +59,7 @@ def build_for_cudagraph_capture( # When doing full graph capture, setting seq_lens to # max_model_len will cause graph capture to be extremely # slow, so here we set it to 1. 
- if self.use_full_cuda_graph: - attn_metadata.seq_lens.fill_(1) + attn_metadata.seq_lens.fill_(1) return attn_metadata def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, From d03096e88e3adc057c61b864612270c7e0c7f41b Mon Sep 17 00:00:00 2001 From: charlifu Date: Wed, 11 Jun 2025 20:34:29 +0000 Subject: [PATCH 14/17] decouple TritonAttentionMetadata from FlashAttentionMetadata Signed-off-by: charlifu --- vllm/v1/attention/backends/triton_attn.py | 64 ++++++++++++++++++++--- 1 file changed, 58 insertions(+), 6 deletions(-) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 4c145f2720eb..c6d0fb129c07 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with PagedAttention and Triton prefix prefill.""" +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Optional import torch @@ -16,19 +17,62 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import ( - FlashAttentionMetadata, FlashAttentionMetadataBuilder, - make_local_attention_virtual_batches) + FlashAttentionMetadata, make_local_attention_virtual_batches) from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: + from vllm.v1.core.sched.output import SchedulerOutput + from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_model_runner import GPUModelRunner logger = init_logger(__name__) -class TritonAttentionMetadataBuilder(FlashAttentionMetadataBuilder): +@dataclass +class TritonAttentionMetadata: + # NOTE(sang): Definition of context_len, query_len, and seq_len. 
+ # |---------- N-1 iteration --------| + # |---------------- N iteration ---------------------| + # |- tokenA -|......................|-- newTokens ---| + # |---------- context_len ----------| + # |-------------------- seq_len ---------------------| + # |-- query_len ---| + + num_actual_tokens: int # Number of tokens excluding padding. + max_query_len: int + query_start_loc: torch.Tensor + max_seq_len: int + seq_lens: torch.Tensor + block_table: torch.Tensor + slot_mapping: torch.Tensor + + # For cascade attention. + use_cascade: bool + common_prefix_len: int + cu_prefix_query_lens: Optional[torch.Tensor] + prefix_kv_lens: Optional[torch.Tensor] + suffix_kv_lens: Optional[torch.Tensor] + + # Optional aot scheduling + scheduler_metadata: Optional[torch.Tensor] = None + prefix_scheduler_metadata: Optional[torch.Tensor] = None + + # for local attention + @dataclass + class LocalAttentionMetadata: + local_query_start_loc: torch.Tensor + local_seqused_k: torch.Tensor + local_block_table: torch.Tensor + local_max_query_len: int + local_max_seq_len: int + local_scheduler_metadata: Optional[torch.Tensor] + + local_attn_metadata: Optional[LocalAttentionMetadata] = None + + +class TritonAttentionMetadataBuilder: def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, block_table: BlockTable): @@ -62,6 +106,10 @@ def build_for_cudagraph_capture( attn_metadata.seq_lens.fill_(1) return attn_metadata + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + return False + def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata): @@ -98,7 +146,8 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, local_max_query_len = seqlens_q_local_np.max() local_max_seq_len = virt_k_seqlens_np.max() - local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata( + local_attn_metadata = TritonAttentionMetadata 
\ + .LocalAttentionMetadata( local_query_start_loc=local_query_start_loc, local_seqused_k=local_seqused_k, local_block_table=virt_block_table_tensor, @@ -126,7 +175,7 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, suffix_kv_lens = None prefix_scheduler_metadata = None - attn_metadata = FlashAttentionMetadata( + attn_metadata = TritonAttentionMetadata( num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, query_start_loc=query_start_loc, @@ -144,6 +193,9 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, ) return attn_metadata + def use_cascade_attention(self, *args, **kwargs) -> bool: + return False + class TritonAttentionBackend(AttentionBackend): @@ -163,7 +215,7 @@ def get_impl_cls() -> type["TritonAttentionImpl"]: @staticmethod def get_metadata_cls() -> type["AttentionMetadata"]: - return FlashAttentionMetadata + return TritonAttentionMetadata @staticmethod def get_kv_cache_shape( From 121853d14e47a45ef484aa742e992e032cfd8834 Mon Sep 17 00:00:00 2001 From: charlifu Date: Wed, 11 Jun 2025 20:42:34 +0000 Subject: [PATCH 15/17] move make_local_attention_virtual_batches into utils.py Signed-off-by: charlifu --- vllm/v1/attention/backends/flash_attn.py | 169 +--------------------- vllm/v1/attention/backends/triton_attn.py | 6 +- vllm/v1/attention/backends/utils.py | 169 ++++++++++++++++++++++ 3 files changed, 174 insertions(+), 170 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index d3e1e7c07768..227ec45097f4 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -21,7 +21,8 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils import cdiv -from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, make_local_attention_virtual_batches) from 
vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable @@ -129,172 +130,6 @@ class LocalAttentionMetadata: local_attn_metadata: Optional[LocalAttentionMetadata] = None -# -# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into -# local attention blocks, where each block is passed to the attention kernel -# as an independent local ("virtual") batch item. -# -# For example, if are performing a chunked prefill a batch of 3 sequences: -# q_seqlens = [4, 10, 5] -# kv_seqlens = [6, 17, 9] -# Then normally for regular attention we would compute with an attention mask -# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like: -# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6) -# k_toks > 0 1 2 3 4 5 -# q_toks v _____________ -# 0 | 1 1 1 -# 1 | 1 1 1 1 -# 2 | 1 1 1 1 1 -# 3 | 1 1 1 1 1 1 -# -# for local attention (with attn_chunk_size = 4) we would compute with an -# attention mask like: -# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4) -# k_toks > 0 1 2 3 4 5 -# q_toks v _____________ -# 0 | 1 1 1 -# 1 | 1 1 1 1 -# 2 | 1 -# 3 | 1 1 -# -# We can simulate this mask using standard flash-attention by breaking the -# sequences into local ("virtual") batches, where each local batch item is a -# local attention block, so in this case batch idx 0 would be broken up into: -# -# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0) -# k_toks > 0 1 2 3 -# q_toks v _____________ -# 0 | 1 1 1 -# 1 | 1 1 1 1 -# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0) -# k_toks > 4 5 -# q_toks v _____________ -# 2 | 1 -# 3 | 1 1 -# -# e.g. 
if we have: -# attn_chunk_size = 4 -# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5]) -# Then this function would return: -# __b0__ ______b1______ __b2__ < orig batch indices -# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1] -# cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24] -# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1] -# block_table_local : shape[local_virtual_batches, pages_per_local_batch] -def make_local_attention_virtual_batches( - attn_chunk_size: int, - query_start_loc_np: np.ndarray, - seq_lens_np: np.ndarray, - block_table: torch.Tensor, - block_size: int = 0, -) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]: - q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1] - actual_batch_size = seq_lens_np.shape[0] - - # Handle if we are starting in the middle of a local attention block, - # we assume q_seqlens > 0 (for all elements), for each batch idx we compute - # the number of tokens that are not in the first local attention block and - # then we can simply use a cdiv for the rest. - # For example if we have: - # attn_chunk_size = 4 - # q_seqlens = [4, 10, 5] - # k_seqlens = [6, 17, 9] - # Then we would get: - # new_tokens_in_first_block = [2, 1, 4] - # local_blocks = [2, 4, 2] - q_tokens_in_first_block = np.minimum( - attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), - q_seqlens).astype(np.int32) - tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size) - local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, - attn_chunk_size) - - # Once we know the number of local blocks we can compute the request spans - # for each batch idx, we can figure out the number of "virtual" requests we - # have to make, - # For the above example we would get: - # seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1] - # - # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1]) - # (TODO: max a utility to share this code with _prepare_inputs) - # arange step 1. 
[2, 4, 2] -> [2, 6, 8] - cu_num_blocks = np.cumsum(local_blocks) - virtual_batches = cu_num_blocks[-1] - # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6] - block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks) - # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1] - arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets - # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0]) - rarange = np.repeat(local_blocks, local_blocks) - arange - 1 - # Then we can compute the seqlens_q_local, handling the fact that the - # first and last blocks could be partial - seqlens_q_local = \ - np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks) - # set the first block since this may be a partial block - seqlens_q_local[arange == 0] = q_tokens_in_first_block - # set the remaining blocks - seqlens_q_local[arange > 0] = np.minimum( - seqlens_q_local - attn_chunk_size * (arange - 1), - attn_chunk_size)[arange > 0] - - # convert from q_seqlens to cu_seqlens_q - cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\ - .astype(np.int32) - - # compute the seqlens_k_local, - # basically a full local attention block for all but the last block in each - # batch - # For our example this will be: - # seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1] - seqlens_k_local = np.full(cu_num_blocks[-1], - attn_chunk_size, - dtype=np.int32) - seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block - - k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \ - (rarange * attn_chunk_size + \ - np.repeat(tokens_in_last_block, local_blocks)) - # For the example the local attention blocks start at: - # _b0_ _____b1_____ _b2_ - # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8] - block_starts = k_seqstarts_absolute // block_size - assert attn_chunk_size % block_size == 0, \ - f"attn_chunk_size {attn_chunk_size} is not " \ - f"divisible by block_size {block_size}" - pages_per_local_batch = attn_chunk_size // block_size - - # Create a block_table for the local 
attention blocks - # For out example if we have a block-table like (assuming block_size=2): - # block_table = [ - # [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0 - # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1 - # [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2 - # ] - # Then for the local batches we would want a block-table like - # block_table_local = [ - # [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0]) - # [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4]) - # [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4]) - # [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8]) - # [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12]) - # [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16]) - # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4]) - # [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8]) - # ] - block_indices= np.broadcast_to( - np.arange(pages_per_local_batch, dtype=np.int32), - (virtual_batches, pages_per_local_batch)) \ - + np.expand_dims(block_starts, axis=1) - block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1) - batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), - local_blocks * pages_per_local_batch) - block_table_local = block_table[batch_indices, block_indices]\ - .view(virtual_batches, -1) - - return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \ - block_table_local - - def _get_sliding_window_configs( vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]: """Get the set of all sliding window configs used in the model.""" diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index c6d0fb129c07..a3600740664b 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -16,9 +16,9 @@ from vllm.attention.ops.triton_unified_attention import unified_attention from vllm.logger import init_logger from vllm.platforms import current_platform -from 
vllm.v1.attention.backends.flash_attn import ( - FlashAttentionMetadata, make_local_attention_virtual_batches) -from vllm.v1.attention.backends.utils import CommonAttentionMetadata +from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata +from vllm.v1.attention.backends.utils import ( + CommonAttentionMetadata, make_local_attention_virtual_batches) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 72c764353927..977729bc08cf 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -2,8 +2,11 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from dataclasses import dataclass +import numpy as np import torch +from vllm.utils import cdiv + @dataclass class CommonAttentionMetadata: @@ -50,3 +53,169 @@ def validate_kv_sharing_target(current_layer_name, target_layer_name, raise ValueError( error_msg + f"must be the same type as the current layer ({expected}).") + + +# +# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into +# local attention blocks, where each block is passed to the attention kernel +# as an independent local ("virtual") batch item. 
+# +# For example, if are performing a chunked prefill a batch of 3 sequences: +# q_seqlens = [4, 10, 5] +# kv_seqlens = [6, 17, 9] +# Then normally for regular attention we would compute with an attention mask +# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like: +# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6) +# k_toks > 0 1 2 3 4 5 +# q_toks v _____________ +# 0 | 1 1 1 +# 1 | 1 1 1 1 +# 2 | 1 1 1 1 1 +# 3 | 1 1 1 1 1 1 +# +# for local attention (with attn_chunk_size = 4) we would compute with an +# attention mask like: +# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4) +# k_toks > 0 1 2 3 4 5 +# q_toks v _____________ +# 0 | 1 1 1 +# 1 | 1 1 1 1 +# 2 | 1 +# 3 | 1 1 +# +# We can simulate this mask using standard flash-attention by breaking the +# sequences into local ("virtual") batches, where each local batch item is a +# local attention block, so in this case batch idx 0 would be broken up into: +# +# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0) +# k_toks > 0 1 2 3 +# q_toks v _____________ +# 0 | 1 1 1 +# 1 | 1 1 1 1 +# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0) +# k_toks > 4 5 +# q_toks v _____________ +# 2 | 1 +# 3 | 1 1 +# +# e.g. 
 if we have:
+#   attn_chunk_size = 4
+#   query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
+# Then this function would return:
+#                           __b0__  ______b1______  __b2__ < orig batch indices
+#   q_seqlens_local    = [   2,  2,  1,  4,  4,  1,  4,  1]
+#   cu_seqlens_q_local = [0, 4,  6, 10, 14, 18, 19, 23, 24]
+#   seqlens_k_local    = [   4,  2,  4,  4,  4,  1,  4,  1]
+#   block_table_local  : shape[local_virtual_batches, pages_per_local_batch]
+def make_local_attention_virtual_batches(
+    attn_chunk_size: int,
+    query_start_loc_np: np.ndarray,
+    seq_lens_np: np.ndarray,
+    block_table: torch.Tensor,
+    block_size: int = 0,
+) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]:
+    q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1]
+    actual_batch_size = seq_lens_np.shape[0]
+
+    # Handle if we are starting in the middle of a local attention block,
+    #  we assume q_seqlens > 0 (for all elements), for each batch idx we compute
+    #  the number of tokens that are not in the first local attention block and
+    #  then we can simply use a cdiv for the rest.
+    # For example if we have:
+    #   attn_chunk_size = 4
+    #   q_seqlens = [4, 10, 5]
+    #   k_seqlens = [6, 17, 9]
+    # Then we would get:
+    #   q_tokens_in_first_block = [2, 1, 4]
+    #   local_blocks = [2, 4, 2]
+    q_tokens_in_first_block = np.minimum(
+        attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size),
+        q_seqlens).astype(np.int32)
+    tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size)
+    local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block,
+                            attn_chunk_size)
+
+    # Once we know the number of local blocks we can compute the request spans
+    # for each batch idx, we can figure out the number of "virtual" requests we
+    # have to make,
+    # For the above example we would get:
+    #   seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
+    #
+    # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
+    #   (TODO: make a utility to share this code with _prepare_inputs)
+    # arange step 1.
[2, 4, 2] -> [2, 6, 8] + cu_num_blocks = np.cumsum(local_blocks) + virtual_batches = cu_num_blocks[-1] + # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6] + block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks) + # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1] + arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets + # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0]) + rarange = np.repeat(local_blocks, local_blocks) - arange - 1 + # Then we can compute the seqlens_q_local, handling the fact that the + # first and last blocks could be partial + seqlens_q_local = \ + np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks) + # set the first block since this may be a partial block + seqlens_q_local[arange == 0] = q_tokens_in_first_block + # set the remaining blocks + seqlens_q_local[arange > 0] = np.minimum( + seqlens_q_local - attn_chunk_size * (arange - 1), + attn_chunk_size)[arange > 0] + + # convert from q_seqlens to cu_seqlens_q + cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\ + .astype(np.int32) + + # compute the seqlens_k_local, + # basically a full local attention block for all but the last block in each + # batch + # For our example this will be: + # seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1] + seqlens_k_local = np.full(cu_num_blocks[-1], + attn_chunk_size, + dtype=np.int32) + seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block + + k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \ + (rarange * attn_chunk_size + \ + np.repeat(tokens_in_last_block, local_blocks)) + # For the example the local attention blocks start at: + # _b0_ _____b1_____ _b2_ + # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8] + block_starts = k_seqstarts_absolute // block_size + assert attn_chunk_size % block_size == 0, \ + f"attn_chunk_size {attn_chunk_size} is not " \ + f"divisible by block_size {block_size}" + pages_per_local_batch = attn_chunk_size // block_size + + # Create a block_table for the local 
attention blocks + # For out example if we have a block-table like (assuming block_size=2): + # block_table = [ + # [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0 + # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1 + # [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2 + # ] + # Then for the local batches we would want a block-table like + # block_table_local = [ + # [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0]) + # [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4]) + # [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4]) + # [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8]) + # [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12]) + # [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16]) + # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4]) + # [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8]) + # ] + block_indices= np.broadcast_to( + np.arange(pages_per_local_batch, dtype=np.int32), + (virtual_batches, pages_per_local_batch)) \ + + np.expand_dims(block_starts, axis=1) + block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1) + batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), + local_blocks * pages_per_local_batch) + block_table_local = block_table[batch_indices, block_indices]\ + .view(virtual_batches, -1) + + return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \ + block_table_local From c3bf5a349dad193168ce531d3068257b8628112f Mon Sep 17 00:00:00 2001 From: charlifu Date: Fri, 13 Jun 2025 19:22:42 +0000 Subject: [PATCH 16/17] refactor tritionatten builder Signed-off-by: charlifu --- vllm/v1/attention/backends/triton_attn.py | 37 ++++++++++++----------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index ec7ac47088eb..cd5f9f66c7d3 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -2,7 +2,7 @@ # 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with PagedAttention and Triton prefix prefill.""" from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, ClassVar, Optional import torch @@ -18,13 +18,12 @@ from vllm.platforms import current_platform from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.utils import ( - CommonAttentionMetadata, make_local_attention_virtual_batches) + AttentionMetadataBuilder, CommonAttentionMetadata, + make_local_attention_virtual_batches) from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.worker.block_table import BlockTable if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput - from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_model_runner import GPUModelRunner logger = init_logger(__name__) @@ -72,7 +71,9 @@ class LocalAttentionMetadata: local_attn_metadata: Optional[LocalAttentionMetadata] = None -class TritonAttentionMetadataBuilder: +class TritonAttentionMetadataBuilder( + AttentionMetadataBuilder[TritonAttentionMetadata]): + full_cudagraph_supported: ClassVar[bool] = True def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, block_table: BlockTable): @@ -96,23 +97,23 @@ def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, self.aot_schedule = False def build_for_cudagraph_capture( - self, num_reqs: int, num_tokens: int, - common_attn_metadata: CommonAttentionMetadata): - attn_metadata = self.build(num_reqs, num_tokens, num_tokens, 0, - common_attn_metadata) + self, common_attn_metadata: CommonAttentionMetadata + ) -> TritonAttentionMetadata: + attn_metadata = self.build(0, common_attn_metadata) # When doing full graph capture, setting seq_lens to # max_model_len will cause graph capture to be extremely # slow, so here we set it to 1. 
attn_metadata.seq_lens.fill_(1) return attn_metadata - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: - return False + def build( + self, common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata + ) -> TritonAttentionMetadata: + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + max_query_len = common_attn_metadata.max_query_len - def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, - common_prefix_len: int, - common_attn_metadata: CommonAttentionMetadata): max_seq_len = int(self.runner.seq_lens_np[:num_reqs].max()) query_start_loc = common_attn_metadata.query_start_loc seq_lens = common_attn_metadata.seq_lens @@ -193,8 +194,10 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, ) return attn_metadata - def use_cascade_attention(self, *args, **kwargs) -> bool: - return False + def can_run_in_cudagraph( + self, common_attn_metadata: CommonAttentionMetadata) -> bool: + # Full CUDA Graph always supported + return True class TritonAttentionBackend(AttentionBackend): From 224e27ba799fca06214f5009f33fabbd9057e63e Mon Sep 17 00:00:00 2001 From: charlifu Date: Mon, 16 Jun 2025 18:59:03 +0000 Subject: [PATCH 17/17] remove unused vars Signed-off-by: charlifu --- vllm/v1/attention/backends/triton_attn.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 35929978b126..ecb92bb1e416 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -77,25 +77,11 @@ class TritonAttentionMetadataBuilder( def __init__(self, runner: "GPUModelRunner", kv_cache_spec: AttentionSpec, block_table: BlockTable): - model_config = runner.model_config - compilation_config = runner.vllm_config.compilation_config - self.runner = runner - self.use_full_cuda_graph = 
compilation_config.full_cuda_graph - self.num_heads_q = model_config.get_num_attention_heads( - runner.parallel_config) - self.num_heads_kv = model_config.get_num_kv_heads( - runner.parallel_config) - self.headdim = model_config.get_head_size() self.block_size = kv_cache_spec.block_size self.kv_cache_spec = kv_cache_spec self.block_table = block_table - # Sliding window size to be used with the AOT scheduler will be - # populated on first build() call. - self.aot_sliding_window: Optional[tuple[int, int]] = None - self.aot_schedule = False - def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata ) -> TritonAttentionMetadata: