diff --git a/vllm/config/cache.py b/vllm/config/cache.py
index daceaa6c2bb4..0823b00a351c 100644
--- a/vllm/config/cache.py
+++ b/vllm/config/cache.py
@@ -19,7 +19,6 @@
 
 logger = init_logger(__name__)
 
-BlockSize = Literal[1, 8, 16, 32, 64, 128, 256]
 CacheDType = Literal[
     "auto",
     "bfloat16",
@@ -39,13 +38,11 @@ class CacheConfig:
     """Configuration for the KV cache."""
 
-    block_size: SkipValidation[BlockSize] = None  # type: ignore[assignment]
-    """Size of a contiguous cache block in number of tokens. On CUDA devices,
-    only block sizes up to 32 are supported.
+    block_size: SkipValidation[int] = None  # type: ignore[assignment]
+    """Size of a contiguous cache block in number of tokens.
 
-    This config has no static default. If left unspecified by the user, it will
-    be set in `Platform.check_and_update_config()` based on the current
-    platform."""
+    This is None until `Platform.check_and_update_config()` sets it based on
+    the current platform. It is always an int by the time the engine starts."""
     gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1)
     """The fraction of GPU memory to be used for the model executor, which can
     range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 8ea96de4913e..1d9a924bdcee 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -59,7 +59,6 @@
     get_attr_docs,
 )
 from vllm.config.cache import (
-    BlockSize,
     CacheDType,
     KVOffloadingBackend,
     MambaCacheMode,
@@ -431,7 +430,7 @@ class EngineArgs:
     max_parallel_loading_workers: int | None = (
         ParallelConfig.max_parallel_loading_workers
     )
-    block_size: BlockSize = CacheConfig.block_size
+    block_size: int = None  # type: ignore[assignment]
     enable_prefix_caching: bool | None = None
     prefix_caching_hash_algo: PrefixCachingHashAlgo = (
         CacheConfig.prefix_caching_hash_algo
diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py
index c2fcde4ab1cf..2314d0a8b675 100644
--- a/vllm/platforms/cuda.py
+++ b/vllm/platforms/cuda.py
@@ -163,8 +163,6 @@ def log_warnings(cls):
 
     @classmethod
     def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
-        from vllm.v1.attention.backends.registry import AttentionBackendEnum
-
         parallel_config = vllm_config.parallel_config
         model_config = vllm_config.model_config
 
@@ -172,112 +170,19 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
             parallel_config.worker_cls = "vllm.v1.worker.gpu_worker.Worker"
 
         cache_config = vllm_config.cache_config
-        if cache_config and cache_config.block_size is None:
+        user_specified_block_size = cache_config.block_size is not None
+        if not user_specified_block_size:
             cache_config.block_size = 16
 
-        # TODO(lucas): handle this more gracefully
-        # Note: model_config may be None during testing
-        # Note: block_size is initialized in
-        # HybridAttentionMambaModelConfig.verify_and_update_config
-        # for models with both attention and mamba,
-        # and doesn't need to be reinitialized here
-        if (
-            model_config is not None
-            and model_config.use_mla
-            and cache_config.block_size is not None
-        ):
-            use_sparse = hasattr(vllm_config.model_config.hf_config, "index_topk")
-            # If `--attention-config.backend` is not set and we are using MLA,
-            # then we default to FlashMLA backend for non-blackwell GPUs,
-            # else we default to CutlassMLA. For each case, we force the
-            # required block_size.
-            use_flashmla = False
-            use_cutlass_mla = False
-            use_flashinfer_mla = False
-            use_flashmla_sparse = False
-            use_flashinfer_mla_sparse = False
-
-            from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
-
-            if vllm_config.attention_config.backend is None:
-                # Default case
-                hf_text_config = model_config.hf_text_config
-                qk_nope_head_dim = getattr(hf_text_config, "qk_nope_head_dim", 1)
-                if (
-                    cls.is_device_capability_family(100)
-                    and not use_sparse
-                    and qk_nope_head_dim == 128
-                ):
-                    # Blackwell => Force FlashInfer MLA (unless sparse, i.e. DSv3.2)
-                    # and only if qk_nope_head_dim == 128 (kernel constraint)
-                    use_flashinfer_mla = True
-                    # Set the backend in AttentionConfig so it's used during
-                    # backend selection
-                    vllm_config.attention_config.backend = (
-                        AttentionBackendEnum.FLASHINFER_MLA
-                    )
-                elif cls.is_device_capability_family(100) and not use_sparse:
-                    # Fall back to CUTLASS_MLA as 2nd priority on Blackwell
-                    use_cutlass_mla = True
-                elif is_flashmla_dense_supported()[0]:
-                    # Non-Blackwell with FlashMLA support
-                    use_flashmla = True
-                else:
-                    # Fallback: will use Triton MLA or other compatible backend
-                    pass
-            else:
-                # Forced case
-                backend = vllm_config.attention_config.backend
-                use_flashmla = backend == AttentionBackendEnum.FLASHMLA
-                use_cutlass_mla = backend == AttentionBackendEnum.CUTLASS_MLA
-                use_flashinfer_mla = backend == AttentionBackendEnum.FLASHINFER_MLA
-                use_flashmla_sparse = backend == AttentionBackendEnum.FLASHMLA_SPARSE
-                use_flashinfer_mla_sparse = (
-                    backend == AttentionBackendEnum.FLASHINFER_MLA_SPARSE
-                )
-
-            if (
-                use_flashmla
-                and is_flashmla_dense_supported()[0]
-                and cache_config.block_size % 64 != 0
-            ):
-                cache_config.block_size = 64
-                logger.info("Forcing kv cache block size to 64 for FlashMLA backend.")
-
-            if use_cutlass_mla and cache_config.block_size % 128 != 0:
-                cache_config.block_size = 128
-                logger.info(
-                    "Forcing kv cache block size to 128 for CUTLASS_MLA backend."
-                )
-
-            if (
-                use_flashinfer_mla
-                and cache_config.block_size != 32
-                and cache_config.block_size % 64 != 0
-            ):
-                cache_config.block_size = 64
-                logger.info(
-                    "Forcing kv cache block size to 64 for FlashInferMLA backend."
-                )
-
-            if use_sparse:
-                if not (use_flashmla_sparse or use_flashinfer_mla_sparse):
-                    use_flashmla_sparse = True
-
-                if use_flashmla_sparse and cache_config.block_size != 64:
-                    cache_config.block_size = 64
-                    logger.info(
-                        "Forcing kv cache block size to 64 for FlashMLASparse backend."
-                    )
-                elif use_flashinfer_mla_sparse and cache_config.block_size not in (
-                    32,
-                    64,
-                ):
-                    cache_config.block_size = 64
-                    logger.info(
-                        "Forcing kv cache block size to 64 for FlashInferMLASparse "
-                        "backend."
-                    )
+        # Ensure block_size is compatible with the attention backend.
+        # Note: model_config may be None during testing.
+        # Skip hybrid (attention+mamba) models; their block_size is
+        # managed by HybridAttentionMambaModelConfig.
+        if model_config is not None and not model_config.is_hybrid:
+            cls._update_block_size_for_backend(
+                vllm_config,
+                user_specified_block_size,
+            )
 
         scheduler_config = vllm_config.scheduler_config
         # Note: model_config may be None during testing
@@ -293,6 +198,150 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None:
             )
             scheduler_config.disable_chunked_mm_input = True
 
+    @classmethod
+    def _update_block_size_for_backend(
+        cls,
+        vllm_config: "VllmConfig",
+        user_specified_block_size: bool,
+    ) -> None:
+        """Ensure block_size is compatible with the attention backend.
+
+        If the user specified --block-size, the selector validates/filters
+        backends by that block size (raising on incompatibility). Otherwise,
+        the backend is selected unconstrained and block_size is set to the
+        backend's preferred value.
+        """
+        from vllm.config.vllm import set_current_vllm_config
+        from vllm.v1.attention.selector import AttentionSelectorConfig
+
+        model_config = vllm_config.model_config
+        cache_config = vllm_config.cache_config
+
+        device_capability = cls.get_device_capability()
+        if device_capability is None:
+            return
+
+        use_mla = model_config.use_mla
+        attn_selector_config = AttentionSelectorConfig(
+            head_size=model_config.get_head_size(),
+            dtype=model_config.dtype,  # type: ignore[arg-type]
+            kv_cache_dtype=cache_config.cache_dtype,
+            block_size=cache_config.block_size if user_specified_block_size else None,
+            use_mla=use_mla,
+            has_sink=False,
+            use_sparse=use_mla and hasattr(model_config.hf_config, "index_topk"),
+            use_mm_prefix=model_config.is_mm_prefix_lm,
+        )
+
+        user_specified_backend = vllm_config.attention_config.backend
+        num_heads = model_config.get_num_attention_heads(
+            vllm_config.parallel_config,
+        )
+        with set_current_vllm_config(vllm_config):
+            chosen_backend = cls.select_attention_backend(
+                selected_backend=user_specified_backend,
+                attn_selector_config=attn_selector_config,
+                device_capability=device_capability,
+                # Don't raise here; we produce better errors below.
+                raise_on_invalid=False,
+                num_heads=num_heads,
+            )
+
+        # If the user's --block-size forced a non-optimal backend,
+        # warn them. Only relevant when the user didn't also specify
+        # --attention-backend (in which case the choice is explicit).
+        if (
+            chosen_backend is not None
+            and user_specified_block_size
+            and user_specified_backend is None
+        ):
+            optimal = cls.select_attention_backend(
+                selected_backend=None,
+                attn_selector_config=attn_selector_config._replace(
+                    block_size=None,
+                ),
+                device_capability=device_capability,
+                raise_on_invalid=False,
+                num_heads=num_heads,
+            )
+            if optimal is not None and optimal != chosen_backend:
+                logger.warning(
+                    "--block-size %d is not supported by the preferred "
+                    "%s backend. Using %s instead, which may result "
+                    "in reduced performance. Consider removing "
+                    "--block-size to auto-select the optimal "
+                    "block size.",
+                    cache_config.block_size,
+                    optimal.name,
+                    chosen_backend.name,
+                )
+
+        if chosen_backend is not None:
+            if user_specified_block_size:
+                # User's block_size is compatible with the chosen
+                # backend.
+                return
+            # User didn't specify --block-size, so auto-select the
+            # preferred block size for the chosen backend.
+            try:
+                backend_class = chosen_backend.get_class()
+            except ImportError:
+                return  # Will fail later with a better error
+            preferred = backend_class.get_preferred_block_size(
+                cache_config.block_size,
+            )
+            if cache_config.block_size != preferred:
+                logger.info(
+                    "Setting kv cache block size to %d for %s backend.",
+                    preferred,
+                    chosen_backend.name,
+                )
+                cache_config.block_size = preferred
+            return
+
+        # No valid backend found. If the user didn't specify
+        # --block-size, defer the error to get_attn_backend_cls where
+        # the full config (including per-layer settings) is
+        # available.
+        if not user_specified_block_size:
+            return
+
+        if user_specified_backend is not None:
+            # User specified --block-size and --attention-backend,
+            # and they are incompatible.
+            try:
+                backend_class = user_specified_backend.get_class()
+                supported = backend_class.get_supported_kernel_block_sizes()
+            except ImportError:
+                supported = None
+            raise ValueError(
+                f"User-specified --block-size "
+                f"{cache_config.block_size} is incompatible with "
+                f"the specified --attention-backend "
+                f"{user_specified_backend.name} (supported kernel "
+                f"block sizes: {supported}). Either remove "
+                f"--block-size to auto-select, or choose a "
+                f"compatible value."
+            )
+        else:
+            # User specified --block-size but no backend supports
+            # it.
+            _, invalid_reasons = cls.get_valid_backends(
+                device_capability=device_capability,
+                attn_selector_config=attn_selector_config,
+                num_heads=num_heads,
+            )
+            reasons_str = ", ".join(
+                f"{b.name}: [{', '.join(r)}]" for b, r in invalid_reasons.items()
+            )
+            raise ValueError(
+                f"No valid attention backend found for "
+                f"--block-size {cache_config.block_size}. "
+                f"Reasons: {{{reasons_str}}}. Either remove "
+                f"--block-size to auto-select, or choose a "
+                f"compatible value."
+            )
+
     @classmethod
     def get_current_memory_usage(
         cls, device: torch.types.Device | None = None
@@ -336,77 +385,125 @@ def get_valid_backends(
         return valid_backends_priorities, invalid_reasons
 
     @classmethod
-    def get_attn_backend_cls(
+    def select_attention_backend(
         cls,
-        selected_backend: "AttentionBackendEnum",
+        selected_backend: "AttentionBackendEnum | None",
         attn_selector_config: "AttentionSelectorConfig",
+        device_capability: "DeviceCapability",
+        raise_on_invalid: bool = True,
         num_heads: int | None = None,
-    ) -> str:
-        device_capability = cls.get_device_capability()
-        assert device_capability is not None
-
-        attn_selector_config = attn_selector_config._replace(block_size=None)
+    ) -> "AttentionBackendEnum | None":
+        """Select the best attention backend for the given configuration.
+
+        Args:
+            selected_backend: User-specified backend, or None for auto-selection
+            attn_selector_config: Configuration for attention selection
+            device_capability: Device capability info
+            raise_on_invalid: If True, raise ValueError when no valid backend is found
+            num_heads: Number of attention heads per GPU, used for backend
+                priority ordering on Blackwell GPUs
+
+        Returns:
+            The selected backend enum, or None if no valid backend is found
+            and raise_on_invalid is False
+        """
 
         # First try checking just the selected backend, if there is one.
         if selected_backend is not None:
             try:
                 backend_class = selected_backend.get_class()
-                invalid_reasons = backend_class.validate_configuration(
+                validation_errors = backend_class.validate_configuration(
                     device_capability=device_capability,
                     **attn_selector_config._asdict(),
                 )
             except ImportError:
-                invalid_reasons = ["ImportError"]
-            if invalid_reasons:
-                raise ValueError(
-                    f"Selected backend {selected_backend} is not valid for "
-                    f"this configuration. Reason: {invalid_reasons}"
-                )
-            else:
-                logger.info("Using %s backend.", selected_backend)
-                return selected_backend.get_path()
+                validation_errors = ["ImportError"]
+            if validation_errors:
+                if raise_on_invalid:
+                    raise ValueError(
+                        f"Selected backend {selected_backend} is not valid for "
+                        f"this configuration. Reason: {validation_errors}"
+                    )
+                return None
+            return selected_backend
 
-        # No selected backend or the selected backend is invalid,
-        # so we try finding a valid backend.
+        # No selected backend, so find the best valid one.
         valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
             device_capability=device_capability,
             attn_selector_config=attn_selector_config,
             num_heads=num_heads,
         )
-        reasons_str = (
-            "{"
-            + ", ".join(
-                f"{backend.name}: [{', '.join(reasons)}]"
-                for backend, reasons in invalid_reasons.items()
-            )
-            + "}"
-        )
-        config_str = attn_selector_config.__repr__()
-        logger.debug_once(
-            f"Some attention backends are not valid for {cls.device_name} with "
-            f"{config_str}. Reasons: {reasons_str}."
-        )
+
         if len(valid_backends_priorities) == 0:
-            raise ValueError(
-                f"No valid attention backend found for {cls.device_name} "
-                f"with {config_str}. Reasons: {reasons_str}."
-            )
+            if raise_on_invalid:
+                reasons_str = (
+                    "{"
+                    + ", ".join(
+                        f"{backend.name}: [{', '.join(reasons)}]"
+                        for backend, reasons in invalid_reasons.items()
+                    )
+                    + "}"
+                )
+                config_str = attn_selector_config.__repr__()
+                raise ValueError(
+                    f"No valid attention backend found for {cls.device_name} "
+                    f"with {config_str}. Reasons: {reasons_str}."
+                )
+            return None
 
-        # We have found some valid backends. Select the one with the
-        # highest priority.
-        sorted_indices = sorted(
-            range(len(valid_backends_priorities)),
-            key=lambda i: valid_backends_priorities[i][1],
-        )
-        selected_index = sorted_indices[0]
-        selected_backend = valid_backends_priorities[selected_index][0]
-        logger.info_once(
-            "Using %s attention backend out of potential backends: %s.",
-            selected_backend.name,
-            "[" + ", ".join(f"'{b[0].name}'" for b in valid_backends_priorities) + "]",
-            scope="local",
+        # Select the backend with the highest priority (lowest priority value).
+        sorted_backends = sorted(valid_backends_priorities, key=lambda x: x[1])
+        return sorted_backends[0][0]
+
+    @classmethod
+    def get_attn_backend_cls(
+        cls,
+        selected_backend: "AttentionBackendEnum | None",
+        attn_selector_config: "AttentionSelectorConfig",
+        num_heads: int | None = None,
+    ) -> str:
+        device_capability = cls.get_device_capability()
+        assert device_capability is not None
+
+        chosen_backend = cls.select_attention_backend(
+            selected_backend=selected_backend,
+            attn_selector_config=attn_selector_config,
+            num_heads=num_heads,
+            device_capability=device_capability,
+            raise_on_invalid=True,
         )
+        assert chosen_backend is not None  # raise_on_invalid=True guarantees this
+
+        # Log the selection.
+        if selected_backend is not None:
+            logger.info("Using %s backend.", chosen_backend)
+        else:
+            # Get all valid backends for logging.
+            valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
+                device_capability=device_capability,
+                attn_selector_config=attn_selector_config,
+                num_heads=num_heads,
+            )
+            reasons_str = (
+                "{"
+                + ", ".join(
+                    f"{backend.name}: [{', '.join(reasons)}]"
+                    for backend, reasons in invalid_reasons.items()
+                )
+                + "}"
+            )
+            config_str = attn_selector_config.__repr__()
+            logger.debug_once(
+                f"Some attention backends are not valid for {cls.device_name} with "
+                f"{config_str}. Reasons: {reasons_str}."
+            )
+            logger.info_once(
+                "Using %s attention backend out of potential backends: %s.",
+                chosen_backend.name,
+                tuple(b[0].name for b in valid_backends_priorities),
+                scope="local",
+            )
 
-        return selected_backend.get_path()
+        return chosen_backend.get_path()
 
     @classmethod
     def get_supported_vit_attn_backends(cls) -> list["AttentionBackendEnum"]:
diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py
index 9c004d7724dd..f31e2635a0f1 100644
--- a/vllm/v1/attention/backend.py
+++ b/vllm/v1/attention/backend.py
@@ -4,7 +4,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, replace
 from enum import Enum
-from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar, get_args
+from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, TypeVar
 
 import numpy as np
 import torch
@@ -144,15 +144,9 @@ def supports_kv_cache_dtype(cls, kv_cache_dtype: "CacheDType | None") -> bool:
 
     @classmethod
     def supports_block_size(cls, block_size: int | None) -> bool:
-        from vllm.config.cache import BlockSize
-
         if block_size is None:
             return True
 
-        valid_sizes = get_args(BlockSize)
-        if block_size not in valid_sizes:
-            return False
-
         supported_kernel_block_sizes = cls.get_supported_kernel_block_sizes()
         if not supported_kernel_block_sizes:
             return True
@@ -167,6 +161,19 @@ def supports_block_size(cls, block_size: int | None) -> bool:
             return True
         return False
 
+    @classmethod
+    def get_preferred_block_size(cls, default_block_size: int = 16) -> int:
+        """Return `default_block_size` if this backend's kernels support it;
+        otherwise return the smallest block size they do support."""
+        supported_sizes = cls.get_supported_kernel_block_sizes()
+        if not supported_sizes:
+            return default_block_size
+
+        if cls.supports_block_size(default_block_size):
+            return default_block_size
+
+        return min(s.base if isinstance(s, MultipleOf) else s for s in supported_sizes)
+
     @classmethod
     def is_mla(cls) -> bool:
         return False
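
Illustration for reviewers (not part of the patch): a minimal, standalone Python sketch of the block-size resolution rule that `get_preferred_block_size` implements. `MultipleOf` and `ToyFlashMLALikeBackend` below are simplified, hypothetical stand-ins for the real vLLM classes; the `supports_block_size` in `vllm/v1/attention/backend.py` above is the authoritative version. The rule: keep the platform default (16) when the backend's kernels accept it, otherwise fall back to the smallest size they do accept.

# Hypothetical stand-ins mirroring the patch; illustration only.
class MultipleOf:
    def __init__(self, base: int) -> None:
        self.base = base  # any block size divisible by `base` is accepted


class ToyFlashMLALikeBackend:
    @classmethod
    def get_supported_kernel_block_sizes(cls) -> list:
        # A FlashMLA-style constraint: kernels want multiples of 64.
        return [MultipleOf(64)]

    @classmethod
    def supports_block_size(cls, block_size: int | None) -> bool:
        if block_size is None:
            return True
        for size in cls.get_supported_kernel_block_sizes():
            if isinstance(size, MultipleOf):
                if block_size % size.base == 0:
                    return True
            elif block_size == size:
                return True
        return False

    @classmethod
    def get_preferred_block_size(cls, default_block_size: int = 16) -> int:
        supported = cls.get_supported_kernel_block_sizes()
        if not supported:
            return default_block_size  # unconstrained backend
        if cls.supports_block_size(default_block_size):
            return default_block_size  # default already compatible
        # Smallest acceptable size: the base of a MultipleOf constraint,
        # or the explicit size itself.
        return min(s.base if isinstance(s, MultipleOf) else s for s in supported)


assert ToyFlashMLALikeBackend.get_preferred_block_size(16) == 64
assert ToyFlashMLALikeBackend.get_preferred_block_size(128) == 128

Running this bumps the platform default of 16 up to 64 while leaving an already-compatible value like 128 untouched, which is the behavior `_update_block_size_for_backend` relies on when the user did not pass --block-size.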