Skip to content

Commit 6b91fc3

Browse files
committed
Enable FCG decode-only for Mamba
Signed-off-by: Thomas Parnell <[email protected]>
1 parent 1928556 commit 6b91fc3

File tree

1 file changed

+24
-2
lines changed

vllm/v1/attention/backends/mamba_attn.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
import math
44
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, Optional
5+
from typing import TYPE_CHECKING, ClassVar, Optional
66

77
import torch
88

99
from vllm.attention.backends.abstract import AttentionBackend
1010
from vllm.config import VllmConfig
1111
from vllm.v1.attention.backends.utils import (
12-
AttentionMetadataBuilder, CommonAttentionMetadata,
12+
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
1313
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
1414
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
1515

@@ -87,6 +87,9 @@ class Mamba2AttentionMetadata:
8787
class Mamba2AttentionMetadataBuilder(
8888
AttentionMetadataBuilder[Mamba2AttentionMetadata]):
8989

90+
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
91+
AttentionCGSupport.PURE_DECODE_ONLY
92+
9093
def __init__(self, kv_cache_spec: AttentionSpec, vllm_config: VllmConfig,
9194
device: torch.device):
9295
assert isinstance(kv_cache_spec, MambaSpec)
@@ -168,3 +171,22 @@ def build(self,
168171
state_indices_tensor=state_indices_tensor,
169172
)
170173
return attn_metadata
174+
175+
def can_run_in_cudagraph(
        self, common_attn_metadata: CommonAttentionMetadata) -> bool:
    """Report whether this batch is eligible for full-CUDA-graph replay.

    Eligible iff every request contributes exactly one query token,
    i.e. the batch is pure decode.
    """
    is_pure_decode = common_attn_metadata.max_query_len == 1
    return is_pure_decode
178+
179+
def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata):
    """Build the attention metadata used for full cudagraph capture.

    Currently, only pure-decode batches (one query token per request)
    are supported for full cudagraphs with Mamba.

    Args:
        common_attn_metadata: per-batch metadata; mutated in place
            (``max_query_len`` is forced to 1) before delegating to
            :meth:`build`.

    Returns:
        The metadata object produced by ``self.build(0, ...)``.
    """
    m = common_attn_metadata

    # Decode-only capture: every request must contribute exactly one
    # token, so the token count equals the request count.
    assert m.num_reqs == m.num_actual_tokens, \
        "Mamba only supports decode-only full CUDAGraph capture. " \
        "Make sure all cudagraph capture sizes <= max_num_seq."

    m.max_query_len = 1  # decode-only

    return self.build(0, m)

0 commit comments

Comments
 (0)