 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import math
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, ClassVar, Optional
 
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
 from vllm.config import VllmConfig
 from vllm.v1.attention.backends.utils import (
-    AttentionMetadataBuilder, CommonAttentionMetadata,
+    AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
     reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
 
@@ -86,6 +86,8 @@ class Mamba2AttentionMetadata:
 
 class Mamba2AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba2AttentionMetadata]):
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.PURE_DECODE_ONLY
 
     def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                  vllm_config: VllmConfig, device: torch.device):
@@ -168,3 +170,23 @@ def build(self,
             state_indices_tensor=state_indices_tensor,
         )
         return attn_metadata
+
+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata):
+        """
+        This method builds the metadata for full cudagraph capture.
+        Currently, only decode is supported for full cudagraphs with Mamba.
+        """
+        m = common_attn_metadata
+
+        assert m.num_reqs == m.num_actual_tokens, \
+            "Mamba only supports decode-only full CUDAGraph capture. " \
+            "Make sure all cudagraph capture sizes <= max_num_seqs."
+
+        m.max_query_len = 1  # decode-only
+
+        return self.build(0, m)
+
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        return common_attn_metadata.max_query_len == 1
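For readers skimming the diff: the two new hooks enforce the same decode-only invariant at two points in the CUDA graph lifecycle. At capture time, `build_for_cudagraph_capture` asserts that every request contributes exactly one token (`num_reqs == num_actual_tokens`); at replay time, `can_run_in_cudagraph` checks that the longest query in the batch is a single token (`max_query_len == 1`). Below is a minimal, self-contained sketch of that gating logic; the `Meta` dataclass and helper functions are illustrative stand-ins, not vLLM APIs.

```python
from dataclasses import dataclass


@dataclass
class Meta:
    # Illustrative stand-in for CommonAttentionMetadata (not the vLLM class).
    num_reqs: int
    num_actual_tokens: int
    max_query_len: int


def can_capture(m: Meta) -> bool:
    # Full-graph capture is only valid for pure-decode batches:
    # every request contributes exactly one token.
    return m.num_reqs == m.num_actual_tokens


def can_replay(m: Meta) -> bool:
    # Mirrors can_run_in_cudagraph(): replay only when the batch is
    # decode-only, i.e. the longest query in the batch is a single token.
    return m.max_query_len == 1


if __name__ == "__main__":
    decode_batch = Meta(num_reqs=4, num_actual_tokens=4, max_query_len=1)
    prefill_batch = Meta(num_reqs=2, num_actual_tokens=37, max_query_len=32)
    assert can_capture(decode_batch) and can_replay(decode_batch)
    assert not (can_capture(prefill_batch) or can_replay(prefill_batch))
```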