|
2 | 2 | # SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
3 | 3 | import math |
4 | 4 | from dataclasses import dataclass |
5 | | -from typing import TYPE_CHECKING, Optional |
| 5 | +from typing import TYPE_CHECKING, ClassVar, Optional |
6 | 6 |
|
7 | 7 | import torch |
8 | 8 |
|
9 | 9 | from vllm.attention.backends.abstract import AttentionBackend |
10 | 10 | from vllm.config import VllmConfig |
11 | 11 | from vllm.v1.attention.backends.utils import ( |
12 | | - AttentionMetadataBuilder, CommonAttentionMetadata, |
| 12 | + AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, |
13 | 13 | reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills) |
14 | 14 | from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec |
15 | 15 |
|
@@ -87,6 +87,9 @@ class Mamba2AttentionMetadata: |
87 | 87 | class Mamba2AttentionMetadataBuilder( |
88 | 88 | AttentionMetadataBuilder[Mamba2AttentionMetadata]): |
89 | 89 |
|
| 90 | + attn_cudagraph_support: ClassVar[AttentionCGSupport] = \ |
| 91 | + AttentionCGSupport.PURE_DECODE_ONLY |
| 92 | + |
90 | 93 | def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig, |
91 | 94 | device: torch.device): |
92 | 95 | assert isinstance(kv_cache_spec, MambaSpec) |
@@ -168,3 +171,22 @@ def build(self, |
168 | 171 | state_indices_tensor=state_indices_tensor, |
169 | 172 | ) |
170 | 173 | return attn_metadata |
| 174 | + |
| 175 | + def can_run_in_cudagraph( |
| 176 | + self, common_attn_metadata: CommonAttentionMetadata) -> bool: |
| 177 | + return common_attn_metadata.max_query_len == 1 |
| 178 | + |
| 179 | + def build_for_cudagraph_capture( |
| 180 | + self, common_attn_metadata: CommonAttentionMetadata): |
| 181 | + """ |
| 182 | + This method builds the metadata for full cudagraph capture. |
| 183 | + Currently, only decode is supported for full cudagraphs with MLA. |
| 184 | + """ |
| 185 | + m = common_attn_metadata |
| 186 | + assert m.num_reqs == m.num_actual_tokens, \ |
| 187 | + "MLA only supports decode-only full CUDAGraph capture. " \ |
| 188 | + "Make sure all cudagraph capture sizes <= max_num_seq." |
| 189 | + |
| 190 | + m.max_query_len = 1 # decode-only |
| 191 | + |
| 192 | + return self.build(0, m) |
0 commit comments