diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py
index e84312970989..d94055cbe46b 100644
--- a/vllm/v1/attention/backends/mla/indexer.py
+++ b/vllm/v1/attention/backends/mla/indexer.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from dataclasses import dataclass
-from typing import ClassVar
 
 import torch
 
@@ -25,6 +24,7 @@
     split_decodes_and_prefills,
     split_prefill_chunks,
 )
+from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.v1.worker.cp_utils import get_total_cp_world_size
 
 logger = init_logger(__name__)
@@ -202,10 +202,22 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig):
 
 
 class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
-    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
-    reorder_batch_threshold: int = 1
+    @classmethod
+    def get_cudagraph_support(
+        cls,
+        vllm_config: VllmConfig,
+        kv_cache_spec: AttentionSpec,
+    ) -> AttentionCGSupport:
+        if not is_deep_gemm_supported():
+            logger.warning_once(
+                "DeepGEMM is not available. Disabling CUDA graph support "
+                "for sparse attention indexer. This may reduce performance.",
+            )
+            return AttentionCGSupport.NEVER
+        return AttentionCGSupport.UNIFORM_BATCH
+
 
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         scheduler_config = self.vllm_config.scheduler_config