Skip to content

Commit 01f10d6

Browse files
committed
Enable decode-only FCG for mamba
Signed-off-by: Thomas Parnell <[email protected]>
1 parent 067c34a commit 01f10d6

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

vllm/v1/attention/backends/mamba_attn.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
import math
44
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, Optional
5+
from typing import TYPE_CHECKING, ClassVar, Optional
66

77
import torch
88

99
from vllm.attention.backends.abstract import AttentionBackend
1010
from vllm.config import VllmConfig
1111
from vllm.v1.attention.backends.utils import (
12-
AttentionMetadataBuilder, CommonAttentionMetadata,
12+
AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata,
1313
reorder_batch_to_split_decodes_and_prefills, split_decodes_and_prefills)
1414
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
1515

@@ -86,6 +86,8 @@ class Mamba2AttentionMetadata:
8686

8787
class Mamba2AttentionMetadataBuilder(
8888
AttentionMetadataBuilder[Mamba2AttentionMetadata]):
89+
attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
90+
AttentionCGSupport.PURE_DECODE_ONLY
8991

9092
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
9193
vllm_config: VllmConfig, device: torch.device):
@@ -168,3 +170,23 @@ def build(self,
168170
state_indices_tensor=state_indices_tensor,
169171
)
170172
return attn_metadata
173+
174+
def build_for_cudagraph_capture(
175+
self, common_attn_metadata: CommonAttentionMetadata):
176+
"""
177+
This method builds the metadata for full cudagraph capture.
178+
Currently, only decode is supported for full cudagraphs with Mamba.
179+
"""
180+
m = common_attn_metadata
181+
182+
assert m.num_reqs == m.num_actual_tokens, \
183+
"Mamba only supports decode-only full CUDAGraph capture. " \
184+
"Make sure all cudagraph capture sizes <= max_num_seq."
185+
186+
m.max_query_len = 1 # decode-only
187+
188+
return self.build(0, m)
189+
190+
def can_run_in_cudagraph(
191+
self, common_attn_metadata: CommonAttentionMetadata) -> bool:
192+
return common_attn_metadata.max_query_len == 1

0 commit comments

Comments
 (0)