diff --git a/benchmarks/attention_benchmarks/benchmark.py b/benchmarks/attention_benchmarks/benchmark.py index de56cbac8474..0329d110244c 100644 --- a/benchmarks/attention_benchmarks/benchmark.py +++ b/benchmarks/attention_benchmarks/benchmark.py @@ -59,7 +59,9 @@ def run_mla_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult: """Run MLA benchmark with appropriate backend.""" from mla_runner import run_mla_benchmark as run_mla - return run_mla(config.backend, config, **kwargs) + return run_mla( + config.backend, config, prefill_backend=config.prefill_backend, **kwargs + ) def run_benchmark(config: BenchmarkConfig, **kwargs) -> BenchmarkResult: @@ -440,14 +442,21 @@ def main(): # Backend selection parser.add_argument( "--backends", + "--decode-backends", nargs="+", - help="Backends to benchmark (flash, triton, flashinfer, cutlass_mla, " + help="Decode backends to benchmark (flash, triton, flashinfer, cutlass_mla, " "flashinfer_mla, flashattn_mla, flashmla)", ) parser.add_argument( "--backend", help="Single backend (alternative to --backends)", ) + parser.add_argument( + "--prefill-backends", + nargs="+", + help="Prefill backends to compare (fa2, fa3, fa4). " + "Uses the first decode backend for impl construction.", + ) # Batch specifications parser.add_argument( @@ -502,7 +511,7 @@ def main(): # Override args with YAML values, but CLI args take precedence # Check if CLI provided backends (they would be non-None and not default) - cli_backends_provided = args.backends is not None or args.backend is not None + cli_backends_provided = args.backend is not None or args.backends is not None # Backend(s) - only use YAML if CLI didn't specify if not cli_backends_provided: @@ -512,6 +521,12 @@ def main(): elif "backends" in yaml_config: args.backends = yaml_config["backends"] args.backend = None + elif "decode_backends" in yaml_config: + args.backends = yaml_config["decode_backends"] + args.backend = None + + # Prefill backends (e.g., ["fa3", "fa4"]) + args.prefill_backends = yaml_config.get("prefill_backends", None) # Check for special modes if "mode" in yaml_config: @@ -613,7 +628,10 @@ def main(): # Determine backends backends = args.backends or ([args.backend] if args.backend else ["flash"]) + prefill_backends = getattr(args, "prefill_backends", None) console.print(f"Backends: {', '.join(backends)}") + if prefill_backends: + console.print(f"Prefill backends: {', '.join(prefill_backends)}") console.print(f"Batch specs: {', '.join(args.batch_specs)}") console.print() @@ -850,37 +868,93 @@ def main(): else: # Normal mode: compare backends - total = len(backends) * len(args.batch_specs) + decode_results = [] + prefill_results = [] - with tqdm(total=total, desc="Benchmarking") as pbar: - for spec in args.batch_specs: - for backend in backends: - config = BenchmarkConfig( - backend=backend, - batch_spec=spec, - num_layers=args.num_layers, - head_dim=args.head_dim, - num_q_heads=args.num_q_heads, - num_kv_heads=args.num_kv_heads, - block_size=args.block_size, - device=args.device, - repeats=args.repeats, - warmup_iters=args.warmup_iters, - profile_memory=args.profile_memory, - ) + # Run decode backend comparison + if not prefill_backends: + # No prefill backends specified: compare decode backends as before + total = len(backends) * len(args.batch_specs) - result = run_benchmark(config) - all_results.append(result) + with tqdm(total=total, desc="Benchmarking") as pbar: + for spec in args.batch_specs: + for backend in backends: + config = BenchmarkConfig( + backend=backend, + batch_spec=spec, + num_layers=args.num_layers, + head_dim=args.head_dim, + num_q_heads=args.num_q_heads, + num_kv_heads=args.num_kv_heads, + block_size=args.block_size, + device=args.device, + repeats=args.repeats, + warmup_iters=args.warmup_iters, + profile_memory=args.profile_memory, + ) - if not result.success: - console.print(f"[red]Error {backend} {spec}: {result.error}[/]") + result = run_benchmark(config) + decode_results.append(result) - pbar.update(1) + if not result.success: + console.print( + f"[red]Error {backend} {spec}: {result.error}[/]" + ) - # Display results - console.print("\n[bold green]Results:[/]") - formatter = ResultsFormatter(console) - formatter.print_table(all_results, backends) + pbar.update(1) + + console.print("\n[bold green]Results:[/]") + formatter = ResultsFormatter(console) + formatter.print_table(decode_results, backends) + + # Run prefill backend comparison + if prefill_backends: + # Use first decode backend for impl construction + decode_backend = backends[0] + total = len(prefill_backends) * len(args.batch_specs) + + console.print( + f"[yellow]Prefill comparison mode: " + f"using {decode_backend} for decode impl[/]" + ) + + with tqdm(total=total, desc="Prefill benchmarking") as pbar: + for spec in args.batch_specs: + for pb in prefill_backends: + config = BenchmarkConfig( + backend=decode_backend, + batch_spec=spec, + num_layers=args.num_layers, + head_dim=args.head_dim, + num_q_heads=args.num_q_heads, + num_kv_heads=args.num_kv_heads, + block_size=args.block_size, + device=args.device, + repeats=args.repeats, + warmup_iters=args.warmup_iters, + profile_memory=args.profile_memory, + prefill_backend=pb, + ) + + result = run_benchmark(config) + + # Label result with prefill backend name for display + labeled_config = replace(result.config, backend=pb) + result = replace(result, config=labeled_config) + prefill_results.append(result) + + if not result.success: + console.print(f"[red]Error {pb} {spec}: {result.error}[/]") + + pbar.update(1) + + console.print("\n[bold green]Prefill Backend Results:[/]") + formatter = ResultsFormatter(console) + formatter.print_table( + prefill_results, prefill_backends, compare_to_fastest=True + ) + + all_results = decode_results + prefill_results # Save results if all_results: diff --git a/benchmarks/attention_benchmarks/common.py b/benchmarks/attention_benchmarks/common.py index 9fa22c8d54f0..208d6273c928 100644 --- a/benchmarks/attention_benchmarks/common.py +++ b/benchmarks/attention_benchmarks/common.py @@ -77,6 +77,7 @@ def __init__(self, num_heads: int, qk_nope_head_dim: int, v_head_dim: int): self.qk_nope_head_dim = qk_nope_head_dim self.v_head_dim = v_head_dim self.out_dim = qk_nope_head_dim + v_head_dim + self.weight = torch.empty(0, dtype=torch.bfloat16) def __call__(self, x: torch.Tensor) -> tuple[torch.Tensor]: """ @@ -213,6 +214,7 @@ class BenchmarkConfig: use_cuda_graphs: bool = False # MLA-specific + prefill_backend: str | None = None kv_lora_rank: int | None = None qk_nope_head_dim: int | None = None qk_rope_head_dim: int | None = None diff --git a/benchmarks/attention_benchmarks/configs/mla_prefill.yaml b/benchmarks/attention_benchmarks/configs/mla_prefill.yaml index ef6b2cb07dc7..122dbd783c5b 100644 --- a/benchmarks/attention_benchmarks/configs/mla_prefill.yaml +++ b/benchmarks/attention_benchmarks/configs/mla_prefill.yaml @@ -1,4 +1,19 @@ -# MLA prefill-only benchmark configuration for sparse backends +# MLA prefill backend comparison +# +# Compares all available MLA prefill backends: +# FA backends: fa2, fa3, fa4 (FlashAttention versions) +# Non-FA: flashinfer, cudnn, trtllm (Blackwell-only, require flashinfer) +# +# Uses cutlass_mla as the decode backend for impl construction +# (only the prefill path is exercised). +# +# Backends that aren't available on the current platform will report errors +# in the results table (e.g., fa3 on Blackwell, cudnn without artifactory). +# +# Usage: +# python benchmark.py --config configs/mla_prefill.yaml + +description: "MLA prefill backend comparison" model: name: "deepseek-v3" @@ -12,20 +27,25 @@ model: v_head_dim: 128 block_size: 128 -# Model parameter sweep: simulate tensor parallelism by varying num_q_heads -# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads -model_parameter_sweep: - param_name: "num_q_heads" - values: [128, 64, 32, 16] - label_format: "{backend}_{value}h" +# model: +# name: "deepseek-v2-lite" +# num_layers: 27 +# num_q_heads: 16 +# num_kv_heads: 1 +# head_dim: 576 +# kv_lora_rank: 512 +# qk_nope_head_dim: 128 +# qk_rope_head_dim: 64 +# v_head_dim: 128 +# block_size: 128 batch_specs: # Pure prefill - - "1q512" - - "1q1k" - - "1q2k" - - "1q4k" - - "1q8k" + - "q512" + - "q1k" + - "q2k" + - "q4k" + - "q8k" # Batched pure prefill - "2q512" @@ -44,19 +64,63 @@ batch_specs: - "8q4k" - "8q8k" - # Extend - - "1q512s4k" - - "1q512s8k" - - "1q1ks8k" - - "1q2ks8k" - - "1q2ks16k" - - "1q4ks16k" + # Chunked prefill / extend + # Short context + - "q128s1k" + - "q256s2k" + - "q512s4k" + - "q1ks4k" + - "q2ks8k" + - "2q128s1k" + - "2q256s2k" + - "2q512s4k" + - "2q1ks4k" + - "2q2ks8k" + - "4q128s1k" + - "4q256s2k" + - "4q512s4k" + - "4q1ks4k" + - "4q2ks8k" + - "8q128s1k" + - "8q256s2k" + - "8q512s4k" + - "8q1ks4k" + + # Medium context + - "q128s16k" + - "q512s16k" + - "q1ks16k" + - "q2ks16k" + - "2q128s16k" + - "2q512s16k" + - "2q1ks16k" + - "2q2ks16k" + - "4q128s16k" + - "4q512s16k" + - "4q1ks16k" + - "4q2ks16k" + + # Long context + - "q128s64k" + - "q512s64k" + - "q1ks64k" + - "q2ks64k" + - "2q128s64k" + - "2q512s64k" + - "2q1ks64k" + - "2q2ks64k" + +decode_backends: + - CUTLASS_MLA -backends: - - FLASHMLA_SPARSE - - FLASHINFER_MLA_SPARSE +prefill_backends: + - fa2 + - fa3 + - fa4 + - flashinfer + - cudnn + - trtllm device: "cuda:0" -repeats: 10 -warmup_iters: 3 -profile_memory: true +repeats: 20 +warmup_iters: 5 diff --git a/benchmarks/attention_benchmarks/configs/mla_sparse_prefill.yaml b/benchmarks/attention_benchmarks/configs/mla_sparse_prefill.yaml new file mode 100644 index 000000000000..ef6b2cb07dc7 --- /dev/null +++ b/benchmarks/attention_benchmarks/configs/mla_sparse_prefill.yaml @@ -0,0 +1,62 @@ +# MLA prefill-only benchmark configuration for sparse backends + +model: + name: "deepseek-v3" + num_layers: 60 + num_q_heads: 128 + num_kv_heads: 1 + head_dim: 576 + kv_lora_rank: 512 + qk_nope_head_dim: 128 + qk_rope_head_dim: 64 + v_head_dim: 128 + block_size: 128 + +# Model parameter sweep: simulate tensor parallelism by varying num_q_heads +# TP=1: 128 heads, TP=2: 64 heads, TP=4: 32 heads, TP=8: 16 heads +model_parameter_sweep: + param_name: "num_q_heads" + values: [128, 64, 32, 16] + label_format: "{backend}_{value}h" + +batch_specs: + # Pure prefill + - "1q512" + - "1q1k" + - "1q2k" + - "1q4k" + - "1q8k" + + # Batched pure prefill + - "2q512" + - "2q1k" + - "2q2k" + - "2q4k" + - "2q8k" + - "4q512" + - "4q1k" + - "4q2k" + - "4q4k" + - "4q8k" + - "8q512" + - "8q1k" + - "8q2k" + - "8q4k" + - "8q8k" + + # Extend + - "1q512s4k" + - "1q512s8k" + - "1q1ks8k" + - "1q2ks8k" + - "1q2ks16k" + - "1q4ks16k" + +backends: + - FLASHMLA_SPARSE + - FLASHINFER_MLA_SPARSE + +device: "cuda:0" +repeats: 10 +warmup_iters: 3 +profile_memory: true diff --git a/benchmarks/attention_benchmarks/mla_runner.py b/benchmarks/attention_benchmarks/mla_runner.py index 110f580fb7bd..34d3bf3766ff 100644 --- a/benchmarks/attention_benchmarks/mla_runner.py +++ b/benchmarks/attention_benchmarks/mla_runner.py @@ -62,6 +62,7 @@ def create_minimal_vllm_config( max_num_seqs: int = 256, mla_dims: dict | None = None, index_topk: int | None = None, + prefill_backend: str | None = None, ) -> VllmConfig: """ Create minimal VllmConfig for MLA benchmarks. @@ -75,6 +76,9 @@ def create_minimal_vllm_config( setup_mla_dims(model_name) index_topk: Optional topk value for sparse MLA backends. If provided, the config will include index_topk for sparse attention. + prefill_backend: Prefill backend name (e.g., "fa3", "fa4", "flashinfer", + "cudnn", "trtllm"). Configures the attention config to + force the specified prefill backend. Returns: VllmConfig for benchmarking @@ -163,7 +167,7 @@ def create_minimal_vllm_config( compilation_config = CompilationConfig() - return VllmConfig( + vllm_config = VllmConfig( model_config=model_config, cache_config=cache_config, parallel_config=parallel_config, @@ -171,9 +175,84 @@ def create_minimal_vllm_config( compilation_config=compilation_config, ) + if prefill_backend is not None: + prefill_cfg = get_prefill_backend_config(prefill_backend) + if prefill_cfg["flash_attn_version"] is not None: + vllm_config.attention_config.flash_attn_version = prefill_cfg[ + "flash_attn_version" + ] + vllm_config.attention_config.disable_flashinfer_prefill = prefill_cfg[ + "disable_flashinfer_prefill" + ] + vllm_config.attention_config.use_cudnn_prefill = prefill_cfg[ + "use_cudnn_prefill" + ] + vllm_config.attention_config.use_trtllm_ragged_deepseek_prefill = prefill_cfg[ + "use_trtllm_ragged_deepseek_prefill" + ] + + return vllm_config + # ============================================================================ -# Backend Configuration +# Prefill Backend Configuration +# ============================================================================ + +# Maps prefill backend names to attention config overrides. +# FA backends set flash_attn_version and disable non-FA paths. +# Non-FA backends enable their specific path and disable others. +_PREFILL_BACKEND_CONFIG: dict[str, dict] = { + "fa2": { + "flash_attn_version": 2, + "disable_flashinfer_prefill": True, + "use_cudnn_prefill": False, + "use_trtllm_ragged_deepseek_prefill": False, + }, + "fa3": { + "flash_attn_version": 3, + "disable_flashinfer_prefill": True, + "use_cudnn_prefill": False, + "use_trtllm_ragged_deepseek_prefill": False, + }, + "fa4": { + "flash_attn_version": 4, + "disable_flashinfer_prefill": True, + "use_cudnn_prefill": False, + "use_trtllm_ragged_deepseek_prefill": False, + }, + "flashinfer": { + "flash_attn_version": None, + "disable_flashinfer_prefill": False, + "use_cudnn_prefill": False, + "use_trtllm_ragged_deepseek_prefill": False, + }, + "cudnn": { + "flash_attn_version": None, + "disable_flashinfer_prefill": True, + "use_cudnn_prefill": True, + "use_trtllm_ragged_deepseek_prefill": False, + }, + "trtllm": { + "flash_attn_version": None, + "disable_flashinfer_prefill": True, + "use_cudnn_prefill": False, + "use_trtllm_ragged_deepseek_prefill": True, + }, +} + + +def get_prefill_backend_config(prefill_backend: str) -> dict: + """Get attention config overrides for a prefill backend.""" + if prefill_backend not in _PREFILL_BACKEND_CONFIG: + raise ValueError( + f"Unknown prefill backend: {prefill_backend!r}. " + f"Available: {list(_PREFILL_BACKEND_CONFIG.keys())}" + ) + return _PREFILL_BACKEND_CONFIG[prefill_backend] + + +# ============================================================================ +# Decode Backend Configuration # ============================================================================ @@ -203,6 +282,7 @@ def _get_backend_config(backend: str) -> dict: Returns: Dict with backend configuration """ + from vllm.v1.attention.backend import MultipleOf from vllm.v1.attention.backends.registry import AttentionBackendEnum try: @@ -219,8 +299,8 @@ def _get_backend_config(backend: str) -> dict: block_sizes = backend_class.get_supported_kernel_block_sizes() # Use first supported block size (backends typically support one for MLA) block_size = block_sizes[0] if block_sizes else None - if hasattr(block_size, "value"): - # Handle MultipleOf enum + if isinstance(block_size, MultipleOf): + # No fixed block size; fall back to config value block_size = None # Check if sparse via class method if available @@ -676,16 +756,11 @@ def _run_single_benchmark( if is_sparse and indexer is not None: indexer.fill_random_indices(total_q, max_kv_len) - # Determine which forward method to use - if is_sparse: - # Sparse backends use forward_mqa + # Determine which forward method to use based on metadata + if metadata.decode is not None: forward_fn = lambda: impl.forward_mqa(decode_inputs, kv_cache, metadata, layer) - elif metadata.decode is not None: - forward_fn = lambda: impl._forward_decode( - decode_inputs, kv_cache, metadata, layer - ) elif metadata.prefill is not None: - forward_fn = lambda: impl._forward_prefill( + forward_fn = lambda: impl.forward_mha( prefill_inputs["q"], prefill_inputs["k_c_normed"], prefill_inputs["k_pe"], @@ -732,6 +807,7 @@ def _run_mla_benchmark_batched( backend: str, configs_with_params: list[tuple], # [(config, threshold, num_splits), ...] index_topk: int = 2048, + prefill_backend: str | None = None, ) -> list[BenchmarkResult]: """ Unified batched MLA benchmark runner for all backends. @@ -743,11 +819,13 @@ def _run_mla_benchmark_batched( to avoid setup/teardown overhead. Args: - backend: Backend name + backend: Backend name (decode backend used for impl construction) configs_with_params: List of (config, threshold, num_splits) tuples - threshold: reorder_batch_threshold (FlashAttn/FlashMLA only) - num_splits: num_kv_splits (CUTLASS only) index_topk: Topk value for sparse MLA backends (default 2048) + prefill_backend: Prefill backend name (e.g., "fa3", "fa4"). + When set, forces the specified FlashAttention version for prefill. Returns: List of BenchmarkResult objects @@ -780,11 +858,25 @@ def _run_mla_benchmark_batched( block_size=block_size, mla_dims=mla_dims, # Use custom dims from config or default index_topk=index_topk if is_sparse else None, + prefill_backend=prefill_backend, ) results = [] with set_current_vllm_config(vllm_config): + # Clear cached prefill backend detection functions so they re-evaluate + # with the current VllmConfig. These are @functools.cache decorated and + # would otherwise return stale results from a previous backend's config. + from vllm.model_executor.layers.attention.mla_attention import ( + use_cudnn_prefill, + use_flashinfer_prefill, + use_trtllm_ragged_deepseek_prefill, + ) + + use_flashinfer_prefill.cache_clear() + use_cudnn_prefill.cache_clear() + use_trtllm_ragged_deepseek_prefill.cache_clear() + # Create backend impl, layer, builder, and indexer (reused across benchmarks) impl, layer, builder_instance, indexer = _create_backend_impl( backend_cfg, @@ -794,6 +886,38 @@ def _run_mla_benchmark_batched( index_topk=index_topk if is_sparse else None, ) + # Verify the actual prefill backend matches what was requested + if prefill_backend is not None: + prefill_cfg = get_prefill_backend_config(prefill_backend) + fa_version = prefill_cfg["flash_attn_version"] + + if fa_version is not None: + # FA backend: verify the impl's FA version + actual_fa_version = getattr(impl, "vllm_flash_attn_version", None) + if actual_fa_version != fa_version: + raise RuntimeError( + f"Prefill backend '{prefill_backend}' requested FA " + f"version {fa_version}, but the impl is using FA " + f"version {actual_fa_version}. Check " + f"vllm/v1/attention/backends/fa_utils.py." + ) + else: + # Non-FA backend: verify the builder picked the right path + expected_flags = { + "flashinfer": "_use_fi_prefill", + "cudnn": "_use_cudnn_prefill", + "trtllm": "_use_trtllm_ragged_prefill", + } + flag_name = expected_flags.get(prefill_backend) + if flag_name and not getattr(builder_instance, flag_name, False): + raise RuntimeError( + f"Prefill backend '{prefill_backend}' was requested " + f"but the metadata builder did not enable it. This " + f"usually means a dependency is missing (e.g., " + f"flashinfer not installed) or the platform doesn't " + f"support it." + ) + # Run each benchmark with the shared impl for config, threshold, num_splits in configs_with_params: # Set threshold for this benchmark (FlashAttn/FlashMLA only) @@ -844,6 +968,7 @@ def run_mla_benchmark( reorder_batch_threshold: int | None = None, num_kv_splits: int | None = None, index_topk: int = 2048, + prefill_backend: str | None = None, ) -> BenchmarkResult | list[BenchmarkResult]: """ Unified MLA benchmark runner for all backends. @@ -861,6 +986,8 @@ def run_mla_benchmark( (single config mode only) num_kv_splits: Number of KV splits for CUTLASS (single config mode only) index_topk: Topk value for sparse MLA backends (default 2048) + prefill_backend: Prefill backend name (e.g., "fa3", "fa4"). + When set, forces the specified FlashAttention version for prefill. Returns: BenchmarkResult (single mode) or list of BenchmarkResult (batched mode) @@ -884,7 +1011,9 @@ def run_mla_benchmark( return_single = True # Use unified batched execution - results = _run_mla_benchmark_batched(backend, configs_with_params, index_topk) + results = _run_mla_benchmark_batched( + backend, configs_with_params, index_topk, prefill_backend=prefill_backend + ) # Return single result or list based on input return results[0] if return_single else results diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index dd184e38eb5e..a7e9e6ff5545 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -39,7 +39,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 140c00c0241bb60cc6e44e7c1be9998d4b20d8d2 + GIT_TAG 1488682bb545f7d020e958a33116b1419d1cfc83 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/vllm/config/attention.py b/vllm/config/attention.py index e05544f08e10..85673f384adf 100644 --- a/vllm/config/attention.py +++ b/vllm/config/attention.py @@ -30,14 +30,14 @@ class AttentionConfig: use_cudnn_prefill: bool = False """Whether to use cudnn prefill.""" - use_trtllm_ragged_deepseek_prefill: bool = True + use_trtllm_ragged_deepseek_prefill: bool = False """Whether to use TRTLLM ragged deepseek prefill.""" use_trtllm_attention: bool | None = None """If set to True/False, use or don't use the TRTLLM attention backend in flashinfer. If None, auto-detect the attention backend in flashinfer.""" - disable_flashinfer_prefill: bool = False + disable_flashinfer_prefill: bool = True """Whether to disable flashinfer prefill.""" disable_flashinfer_q_quantization: bool = False diff --git a/vllm/model_executor/layers/attention/mla_attention.py b/vllm/model_executor/layers/attention/mla_attention.py index b1dc1a860501..5b7113dc82f5 100644 --- a/vllm/model_executor/layers/attention/mla_attention.py +++ b/vllm/model_executor/layers/attention/mla_attention.py @@ -1282,8 +1282,6 @@ def is_deepseek_r1_mla_compatible(vllm_config: VllmConfig) -> bool: @functools.cache def use_flashinfer_prefill() -> bool: - # For blackwell default to flashinfer prefill if it's available since - # it is faster than FA2. from vllm.config import get_current_vllm_config vllm_config = get_current_vllm_config() @@ -2154,13 +2152,16 @@ def __init__( # For MLA the v head dim is smaller than qk head dim so we pad out # v with 0s to match the qk head dim for attention backends that do - # not support different headdims - # We don't need to pad V if we are on a hopper system with FA3 + # not support different headdims. + # FA3 on Hopper (SM90) and FA4 natively handle diff headdims. device_capability = current_platform.get_device_capability() self._pad_v = self.vllm_flash_attn_version is None or not ( - self.vllm_flash_attn_version == 3 - and device_capability is not None - and device_capability[0] == 9 + ( + self.vllm_flash_attn_version == 3 + and device_capability is not None + and device_capability[0] == 9 + ) + or self.vllm_flash_attn_version == 4 ) self.dcp_world_size: int = -1 diff --git a/vllm/v1/attention/backends/fa_utils.py b/vllm/v1/attention/backends/fa_utils.py index 20502cbf0feb..cd8c46d032c0 100644 --- a/vllm/v1/attention/backends/fa_utils.py +++ b/vllm/v1/attention/backends/fa_utils.py @@ -125,11 +125,14 @@ def get_flash_attn_version( # FA4 on SM100 (Blackwell) has TMEM capacity limits that restrict # supported head dimensions. # See: https://github.com/Dao-AILab/flash-attention/issues/1959 + # Exception: hdim 192 is supported for MLA's diff-headdim case + # (qk=192, v=128), added upstream in commits 1a15733e/1b36ab19. if ( fa_version == 4 and device_capability.major >= 10 and head_size is not None and head_size > 128 + and head_size != 192 ): logger.warning_once( "FA4 on Blackwell does not support head_size=%d due to TMEM "