Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
310 changes: 170 additions & 140 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,19 +163,23 @@ def log_warnings(cls):

@classmethod
def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
from vllm.v1.attention.backends.registry import AttentionBackendEnum

parallel_config = vllm_config.parallel_config
model_config = vllm_config.model_config

if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"

cache_config = vllm_config.cache_config
user_specified_block_size = (
cache_config is not None and cache_config.block_size is not None
)
if cache_config and cache_config.block_size is None:
cache_config.block_size = 16

# TODO(lucas): handle this more gracefully
# For MLA models, determine the backend that will be selected and set
# block_size based on that backend's requirements. This ensures
# consistency between the block_size set here and the backend that
# get_attn_backend_cls will select later.
# Note: model_config may be None during testing
# Note: block_size is initialized in
# HybridAttentionMambaModelConfig.verify_and_update_config
Expand All @@ -186,99 +190,70 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
and model_config.use_mla
and cache_config.block_size is not None
):
use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
# If `--attention-config.backend` is not set and we are using MLA,
# then we default to FlashMLA backend for non-blackwell GPUs,
# else we default to CutlassMLA. For each case, we force the
# required block_size.
use_flashmla = False
use_cutlass_mla = False
use_flashinfer_mla = False
use_flashmla_sparse = False
use_flashinfer_mla_sparse = False

from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported

if vllm_config.attention_config.backend is None:
# Default case
hf_text_config = model_config.hf_text_config
qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
if (
cls.is_device_capability_family(100)
and not use_sparse
and qk_nope_head_dim == 128
):
# Blackwell => Force FlashInfer MLA (unless sparse, i.e. DSv3.2)
# and only if qk_nope_head_dim == 128 (kernel constraint)
use_flashinfer_mla = True
# Set the backend in AttentionConfig so it's used during
# backend selection
vllm_config.attention_config.backend = (
AttentionBackendEnum.FLASHINFER_MLA
)
elif cls.is_device_capability_family(100) and not use_sparse:
# Fall back to CUTLASS_MLA as 2nd priority on Blackwell
use_cutlass_mla = True
elif is_flashmla_dense_supported()[0]:
# Non-Blackwell with FlashMLA support
use_flashmla = True
else:
# Fallback: will use Triton MLA or other compatible backend
pass
else:
# Forced case
backend = vllm_config.attention_config.backend
use_flashmla = backend == AttentionBackendEnum.FLASHMLA
use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA
use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA
use_flashmla_sparse = backend == AttentionBackendEnum.FLASHMLA_SPARSE
use_flashinfer_mla_sparse = (
backend == AttentionBackendEnum.FLASHINFER_MLA_SPARSE
)

if (
use_flashmla
and is_flashmla_dense_supported()[0]
and cache_config.block_size % 64 != 0
):
cache_config.block_size = 64
logger.info("Forcing kv cache block size to 64 for FlashMLA backend.")

if use_cutlass_mla and cache_config.block_size % 128 != 0:
cache_config.block_size = 128
logger.info(
"Forcing kv cache block size to 128 for CUTLASS_MLA backend."
from vllm.config.vllm import set_current_vllm_config
from vllm.v1.attention.selector import AttentionSelectorConfig

use_sparse = hasattr(model_config.hf_config, "index_topk")
device_capability = cls.get_device_capability()

# Only adjust block size if we can determine device capability
if device_capability is not None:
# Build a minimal selector config to find which backend will
# be selected. Note: model_config.dtype is torch.dtype after init
attn_selector_config = AttentionSelectorConfig(
head_size=model_config.get_head_size(),
dtype=model_config.dtype, # type: ignore[arg-type]
kv_cache_dtype=cache_config.cache_dtype,
block_size=None,
use_mla=True,
has_sink=False,
use_sparse=use_sparse,
use_mm_prefix=model_config.is_mm_prefix_lm,
)

if (
use_flashinfer_mla
and cache_config.block_size != 32
and cache_config.block_size % 64 != 0
):
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashInferMLA backend."
# Use the same selection logic as get_attn_backend_cls
attention_config = vllm_config.attention_config
requested_backend = (
attention_config.backend if attention_config else None
)

if use_sparse:
if not (use_flashmla_sparse or use_flashinfer_mla_sparse):
use_flashmla_sparse = True

if use_flashmla_sparse and cache_config.block_size != 64:
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashMLASparse backend."
)
elif use_flashinfer_mla_sparse and cache_config.block_size not in (
32,
64,
):
cache_config.block_size = 64
logger.info(
"Forcing kv cache block size to 64 for FlashInferMLASparse "
"backend."
num_heads = model_config.get_num_attention_heads(parallel_config)
with set_current_vllm_config(vllm_config):
chosen_backend = cls.select_attention_backend(
selected_backend=requested_backend,
attn_selector_config=attn_selector_config,
device_capability=device_capability,
raise_on_invalid=False,
num_heads=num_heads,
)

if chosen_backend is not None:
try:
backend_class = chosen_backend.get_class()
preferred_block_size = backend_class.get_preferred_block_size(
cache_config.block_size
)
if cache_config.block_size != preferred_block_size:
if user_specified_block_size:
raise ValueError(
f"User-specified block_size="
f"{cache_config.block_size} is "
f"incompatible with "
f"{chosen_backend.name} backend "
f"(requires block_size="
f"{preferred_block_size}). "
f"Either remove --block-size to "
f"auto-select, or use --block-size "
f"{preferred_block_size}."
)
logger.info(
"Setting kv cache block size to %d for %s backend.",
preferred_block_size,
chosen_backend.name,
)
cache_config.block_size = preferred_block_size # type: ignore[assignment]
except ImportError:
pass # Backend selection will fail later with a better error

scheduler_config = vllm_config.scheduler_config
# Note: model_config may be None during testing
if (
Expand Down Expand Up @@ -336,77 +311,132 @@ def get_valid_backends(
return valid_backends_priorities, invalid_reasons

@classmethod
def get_attn_backend_cls(
def select_attention_backend(
cls,
selected_backend: "AttentionBackendEnum",
selected_backend: "AttentionBackendEnum | None",
attn_selector_config: "AttentionSelectorConfig",
device_capability: "DeviceCapability",
raise_on_invalid: bool = True,
num_heads: int | None = None,
) -> str:
device_capability = cls.get_device_capability()
assert device_capability is not None

) -> "AttentionBackendEnum | None":
"""Select the best attention backend for the given configuration.

This is the single source of truth for backend selection, used by both
check_and_update_config (to set block_size) and get_attn_backend_cls
(to get the backend class path).

Args:
selected_backend: User-specified backend, or None for auto-selection
attn_selector_config: Configuration for attention selection
device_capability: Device capability info
raise_on_invalid: If True, raise ValueError when no valid backend
num_heads: Number of attention heads per GPU, used for backend
priority ordering on Blackwell GPUs

Returns:
The selected backend enum, or None if no valid backend found
and raise_on_invalid is False
"""
# Ensure we don't constrain by block_size during selection
attn_selector_config = attn_selector_config._replace(block_size=None)

# First try checking just the selected backend, if there is one.
if selected_backend is not None:
try:
backend_class = selected_backend.get_class()
invalid_reasons = backend_class.validate_configuration(
validation_errors = backend_class.validate_configuration(
device_capability=device_capability,
**attn_selector_config._asdict(),
)
except ImportError:
invalid_reasons = ["ImportError"]
if invalid_reasons:
raise ValueError(
f"Selected backend {selected_backend} is not valid for "
f"this configuration. Reason: {invalid_reasons}"
)
else:
logger.info("Using %s backend.", selected_backend)
return selected_backend.get_path()
validation_errors = ["ImportError"]
if validation_errors:
if raise_on_invalid:
raise ValueError(
f"Selected backend {selected_backend} is not valid for "
f"this configuration. Reason: {validation_errors}"
)
return None
return selected_backend

# No selected backend or the selected backend is invalid,
# so we try finding a valid backend.
# No selected backend, so find the best valid one.
valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
device_capability=device_capability,
attn_selector_config=attn_selector_config,
num_heads=num_heads,
)
reasons_str = (
"{"
+ ", ".join(
f"{backend.name}: [{', '.join(reasons)}]"
for backend, reasons in invalid_reasons.items()
)
+ "}"
)
config_str = attn_selector_config.__repr__()
logger.debug_once(
f"Some attention backends are not valid for {cls.device_name} with "
f"{config_str}. Reasons: {reasons_str}."
)

if len(valid_backends_priorities) == 0:
raise ValueError(
f"No valid attention backend found for {cls.device_name} "
f"with {config_str}. Reasons: {reasons_str}."
)
if raise_on_invalid:
reasons_str = (
"{"
+ ", ".join(
f"{backend.name}: [{', '.join(reasons)}]"
for backend, reasons in invalid_reasons.items()
)
+ "}"
)
config_str = attn_selector_config.__repr__()
raise ValueError(
f"No valid attention backend found for {cls.device_name} "
f"with {config_str}. Reasons: {reasons_str}."
)
return None

# We have found some valid backends. Select the one with the
# highest priority.
sorted_indices = sorted(
range(len(valid_backends_priorities)),
key=lambda i: valid_backends_priorities[i][1],
)
selected_index = sorted_indices[0]
selected_backend = valid_backends_priorities[selected_index][0]
logger.info_once(
"Using %s attention backend out of potential backends: %s.",
selected_backend.name,
"[" + ", ".join(f"'{b[0].name}'" for b in valid_backends_priorities) + "]",
scope="local",
# Select the one with the highest priority (lowest index).
sorted_backends = sorted(valid_backends_priorities, key=lambda x: x[1])
return sorted_backends[0][0]

@classmethod
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum | None",
attn_selector_config: "AttentionSelectorConfig",
num_heads: int | None = None,
) -> str:
device_capability = cls.get_device_capability()
assert device_capability is not None

chosen_backend = cls.select_attention_backend(
selected_backend=selected_backend,
attn_selector_config=attn_selector_config,
num_heads=num_heads,
device_capability=device_capability,
raise_on_invalid=True,
)
assert chosen_backend is not None # raise_on_invalid=True guarantees this

# Log the selection
if selected_backend is not None:
logger.info("Using %s backend.", chosen_backend)
else:
# Get all valid backends for logging
valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
device_capability=device_capability,
attn_selector_config=attn_selector_config._replace(block_size=None),
num_heads=num_heads,
)
reasons_str = (
"{"
+ ", ".join(
f"{backend.name}: [{', '.join(reasons)}]"
for backend, reasons in invalid_reasons.items()
)
+ "}"
)
config_str = attn_selector_config.__repr__()
logger.debug_once(
f"Some attention backends are not valid for {cls.device_name} with "
f"{config_str}. Reasons: {reasons_str}."
)
logger.info_once(
"Using %s attention backend out of potential backends: %s",
chosen_backend.name,
tuple(b[0].name for b in valid_backends_priorities),
scope="local",
)

return selected_backend.get_path()
return chosen_backend.get_path()

@classmethod
def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
Expand Down
Loading