Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions vllm/v1/attention/backends/mla/indexer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import ClassVar

import torch

Expand All @@ -25,6 +24,7 @@
split_decodes_and_prefills,
split_prefill_chunks,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.cp_utils import get_total_cp_world_size

logger = init_logger(__name__)
Expand Down Expand Up @@ -202,10 +202,22 @@ def get_max_prefill_buffer_size(vllm_config: VllmConfig):


class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
_cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH

reorder_batch_threshold: int = 1

@classmethod
def get_cudagraph_support(
    cls,
    vllm_config: VllmConfig,
    kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
    """Report the CUDA graph support level for this metadata builder.

    The answer depends only on the result of ``is_deep_gemm_supported()``;
    ``vllm_config`` and ``kv_cache_spec`` are accepted but not read here.

    Returns:
        ``AttentionCGSupport.UNIFORM_BATCH`` when ``is_deep_gemm_supported()``
        is truthy. Otherwise a warning is emitted through
        ``logger.warning_once`` and ``AttentionCGSupport.NEVER`` is returned.
    """
    # Happy path first: DeepGEMM is usable, so keep uniform-batch support.
    if is_deep_gemm_supported():
        return AttentionCGSupport.UNIFORM_BATCH

    # Fallback: tell the user why CUDA graph support is being turned off.
    logger.warning_once(
        "DeepGEMM is not available. Disabling CUDA graph support "
        "for sparse attention indexer. This may reduce performance.",
    )
    return AttentionCGSupport.NEVER

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
scheduler_config = self.vllm_config.scheduler_config
Expand Down