diff --git a/tests/models/language/generation/test_hybrid.py b/tests/models/language/generation/test_hybrid.py
index 4934da95179d..76f6c226bab7 100644
--- a/tests/models/language/generation/test_hybrid.py
+++ b/tests/models/language/generation/test_hybrid.py
@@ -384,3 +384,63 @@ def test_distributed_correctness(
         name_0="vllm_tp_1",
         name_1="vllm_tp_2",
     )
+
+
+@pytest.mark.parametrize("model", ["Zyphra/Zamba2-1.2B-instruct"])
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_full_cuda_graph(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    monkeypatch,
+    model: str,
+    max_tokens: int,
+    num_logprobs: int,
+) -> None:
+
+    try:
+        model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+        model_info.check_available_online(on_fail="skip")
+        model_info.check_transformers_version(on_fail="skip")
+    except ValueError:
+        pass
+
+    with hf_runner(model) as hf_model:
+        if model not in HF_UNSUPPORTED_MODELS:
+            hf_outputs = hf_model.generate_greedy_logprobs_limit(
+                example_prompts, max_tokens, num_logprobs)
+        else:
+            hf_outputs = None
+
+    with vllm_runner(model, max_num_seqs=MAX_NUM_SEQS) as vllm_model:
+        vllm_v0_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, num_logprobs)
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        if model in HYBRID_MODELS:
+            # required due to reorder_batch behaviour
+            m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
+        with vllm_runner(model,
+                         max_num_seqs=MAX_NUM_SEQS,
+                         compilation_config={'full_cuda_graph': True},
+                         enable_prefix_caching=False) as vllm_model:
+            vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, num_logprobs)
+
+    if hf_outputs is not None:
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_v0_outputs,
+            name_0="hf",
+            name_1="vllm-v0",
+        )
+
+    ref_outputs = hf_outputs if hf_outputs is not None else vllm_v0_outputs
+    check_logprobs_close(
+        outputs_0_lst=ref_outputs,
+        outputs_1_lst=vllm_v1_outputs,
+        name_0="hf" if hf_outputs is not None else "vllm-v0",
+        name_1="vllm-v1",
+    )
diff --git a/vllm/v1/attention/backends/mamba_attn.py b/vllm/v1/attention/backends/mamba_attn.py
index 66a8d91db89c..7c1226049f69 100644
--- a/vllm/v1/attention/backends/mamba_attn.py
+++ b/vllm/v1/attention/backends/mamba_attn.py
@@ -7,8 +7,10 @@
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
+from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
+from vllm.v1.attention.backends.utils import (AttentionCGSupport,
+                                              AttentionMetadataBuilder,
                                               CommonAttentionMetadata,
                                               split_decodes_and_prefills)
 from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
@@ -82,6 +84,8 @@ class Mamba2AttentionMetadata:
 
 class Mamba2AttentionMetadataBuilder(
         AttentionMetadataBuilder[Mamba2AttentionMetadata]):
+    attn_cudagraph_support: ClassVar[AttentionCGSupport] = \
+        AttentionCGSupport.PURE_DECODE_ONLY
 
     reorder_batch_threshold: ClassVar[int] = 1
 
@@ -90,8 +94,18 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         assert isinstance(kv_cache_spec, MambaSpec)
         self.kv_cache_spec = kv_cache_spec
         self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
+        self.vllm_config = vllm_config
+        self.compilation_config = vllm_config.compilation_config
         assert self.chunk_size is not None, (
             "chunk_size needs to be set in the model config for Mamba2 models")
+        self.decode_cudagraph_max_bs = min(
+            self.vllm_config.scheduler_config.max_num_seqs,
+            self.compilation_config.max_capture_size)
+        self.state_indices_tensor = torch.empty(
+            (self.decode_cudagraph_max_bs, ),
+            dtype=torch.int32,
+            device=device,
+        )
 
     def build(self,
               common_prefix_len: int,
@@ -144,6 +158,14 @@ def build(self,
                 _query_start_loc_to_chunk_indices_offsets(
                     query_start_loc_p, self.chunk_size, num_prefill_tokens))
 
+        elif num_decodes <= self.decode_cudagraph_max_bs:
+            # Pad state tensor for CUDA graph
+            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
+            self.state_indices_tensor[:num_decodes].copy_(state_indices_tensor,
+                                                          non_blocking=True)
+            state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
+            state_indices_tensor[num_decodes:] = PAD_SLOT_ID
+
         attn_metadata = Mamba2AttentionMetadata(
             num_prefills=num_prefills,
             num_prefill_tokens=num_prefill_tokens,
@@ -160,3 +182,23 @@ def build(self,
             state_indices_tensor=state_indices_tensor,
         )
         return attn_metadata
+
+    def build_for_cudagraph_capture(
+            self, common_attn_metadata: CommonAttentionMetadata):
+        """
+        This method builds the metadata for full cudagraph capture.
+        Currently, only decode is supported for full cudagraphs with Mamba.
+        """
+        m = common_attn_metadata
+
+        assert m.num_reqs == m.num_actual_tokens, \
+            "Mamba only supports decode-only full CUDAGraph capture. " \
+            "Make sure all cudagraph capture sizes <= max_num_seq."
+
+        m.max_query_len = 1  # decode-only
+
+        return self.build(0, m)
+
+    def can_run_in_cudagraph(
+            self, common_attn_metadata: CommonAttentionMetadata) -> bool:
+        return common_attn_metadata.max_query_len == 1