87 changes: 87 additions & 0 deletions tests/kernels/attention/test_attention_selector.py
@@ -427,3 +427,90 @@ def test_per_head_quant_scales_backend_selection(
        use_per_head_quant_scales=True,
    )
    assert backend_name in str(exc_info.value)


@pytest.mark.parametrize(
    "backend_name,use_non_causal,should_succeed",
    [
        ("FLASH_ATTN", True, True),  # FlashAttn supports non-causal
        ("FLASH_ATTN", False, True),  # FlashAttn also works with causal
        ("FLASHINFER", True, False),  # FlashInfer does not support non-causal
        ("FLASHINFER", False, True),  # FlashInfer works with causal
    ],
)
def test_non_causal_backend_selection(
    backend_name: str, use_non_causal: bool, should_succeed: bool
):
    """Test that use_non_causal on AttentionConfig controls backend filtering.

    DFlashProposer sets use_non_causal=True on the draft model's
    AttentionConfig so only non-causal-capable backends are selected.
    The target model keeps use_non_causal=False (default) and can use
    any backend.
    """
    _cached_get_attn_backend.cache_clear()

    attention_config = AttentionConfig(
        backend=AttentionBackendEnum[backend_name],
        use_non_causal=use_non_causal,
    )
    cache_config = CacheConfig(block_size=16)
    vllm_config = VllmConfig(
        attention_config=attention_config, cache_config=cache_config
    )

    if CudaPlatform is None:
        pytest.skip("CudaPlatform not available")
    with (
        set_current_vllm_config(vllm_config),
        patch("vllm.platforms.current_platform", CudaPlatform()),
    ):
        if should_succeed:
            backend = get_attn_backend(
                head_size=128,
                dtype=torch.float16,
                kv_cache_dtype=None,
            )
            assert backend.get_name() == backend_name
        else:
            with pytest.raises(ValueError) as exc_info:
                get_attn_backend(
                    head_size=128,
                    dtype=torch.float16,
                    kv_cache_dtype=None,
                )
            assert "non-causal" in str(exc_info.value).lower()


def test_non_causal_autoselect_backend():
    """Test that when backend=None with use_non_causal=True, auto-selection
    picks a compatible backend.

    This simulates the DFlash scenario where the user doesn't specify
    --attention-backend or --speculative-config.attention_backend.
    The drafter inherits backend=None and auto-selects a backend that
    supports non-causal attention.
    """
    _cached_get_attn_backend.cache_clear()

    attention_config = AttentionConfig(
        backend=None,
        use_non_causal=True,
    )
    cache_config = CacheConfig(block_size=16)
    vllm_config = VllmConfig(
        attention_config=attention_config, cache_config=cache_config
    )

    if CudaPlatform is None:
        pytest.skip("CudaPlatform not available")
    with (
        set_current_vllm_config(vllm_config),
        patch("vllm.platforms.current_platform", CudaPlatform()),
    ):
        backend = get_attn_backend(
            head_size=128,
            dtype=torch.float16,
            kv_cache_dtype=None,
        )
        assert backend.supports_non_causal()
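The final assertion relies on each backend class advertising whether it can run bidirectional attention. As a rough sketch of the capability hook these tests exercise (hypothetical class names, not the actual vLLM backend implementations, which live under `vllm/v1/attention/backends/`):

```python
# Hypothetical sketch of the supports_non_causal() capability hook.
class _SketchAttentionBackend:
    @classmethod
    def supports_non_causal(cls) -> bool:
        # Conservative default: a backend is assumed causal-only
        # unless it explicitly opts in.
        return False


class _SketchFlashAttnBackend(_SketchAttentionBackend):
    @classmethod
    def supports_non_causal(cls) -> bool:
        return True  # FlashAttention can run bidirectional attention


class _SketchFlashInferBackend(_SketchAttentionBackend):
    # Inherits the causal-only default, so selection with
    # use_non_causal=True would reject it, as the tests above expect.
    pass
```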
3 changes: 3 additions & 0 deletions vllm/config/attention.py
@@ -54,6 +54,9 @@ class AttentionConfig:
    use_fp4_indexer_cache: bool = False
    """If set, use fp4 indexer cache for the dsv32 model family (not supported yet)"""

    use_non_causal: bool = False
    """Whether to use non-causal (bidirectional) attention."""

    def compute_hash(self) -> str:
        """
        Provide a hash that uniquely identifies all the configs
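For context, here is a hedged sketch of how a selector can act on the new flag when filtering candidate backends (assumed helper name; the real filtering lives in vLLM's selector/registry code, not in this config class):

```python
# Sketch only: filter candidate backend classes by the new config flag.
def filter_by_causality(candidates: list[type], use_non_causal: bool) -> list[type]:
    if not use_non_causal:
        # Causal attention: every candidate backend qualifies.
        return candidates
    # Bidirectional attention: keep only backends that opt in.
    return [b for b in candidates if b.supports_non_causal()]
```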
16 changes: 15 additions & 1 deletion vllm/config/speculative.py
@@ -5,7 +5,7 @@
 import copy
 from typing import TYPE_CHECKING, Any, Literal, get_args

-from pydantic import Field, SkipValidation, model_validator
+from pydantic import Field, SkipValidation, field_validator, model_validator
 from typing_extensions import Self

 from vllm.config import LoadConfig
@@ -17,6 +17,7 @@
 from vllm.transformers_utils.config import get_hf_text_config
 from vllm.utils.hashing import safe_hash
 from vllm.utils.import_utils import LazyLoader, has_arctic_inference
+from vllm.v1.attention.backends.registry import AttentionBackendEnum

 if TYPE_CHECKING:
     from transformers import PretrainedConfig
@@ -106,6 +107,10 @@ class SpeculativeConfig:
    inherits the target model's `--moe-backend` setting. Useful when the
    drafter and generator require different MoE kernels (e.g. quantized
    generator with unquantized drafter)."""
    attention_backend: AttentionBackendEnum | None = None
    """Attention backend to use for the draft model. When `None`, the backend is
    automatically selected. Useful when the drafter requires a different attention
    backend (e.g. DFlash needs a non-causal-capable backend like FLASH_ATTN)."""
    max_model_len: int | None = Field(default=None, ge=1)
    """The maximum model length of the draft model. Used when testing the
    ability to skip speculation for some sequences."""
@@ -911,6 +916,15 @@ def create_draft_parallel_config(

        return draft_parallel_config

    @field_validator("attention_backend", mode="before")
    @classmethod
    def _parse_attention_backend(cls, value: Any) -> Any:
        if isinstance(value, str):
            if value.lower() == "auto":
                return None
            return AttentionBackendEnum[value.upper()]
        return value

    @model_validator(mode="after")
    def _verify_args(self) -> Self:
        if self.tensor_parallel_size is not None:
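The validator's contract is small enough to sanity-check in isolation. Below is a standalone mirror of `_parse_attention_backend` (the same logic outside its pydantic wiring), showing the accepted spellings:

```python
from typing import Any

from vllm.v1.attention.backends.registry import AttentionBackendEnum


def parse_attention_backend(value: Any) -> Any:
    """Mirror of SpeculativeConfig._parse_attention_backend: strings resolve
    case-insensitively, and "auto" maps to None (i.e. auto-selection)."""
    if isinstance(value, str):
        if value.lower() == "auto":
            return None
        return AttentionBackendEnum[value.upper()]
    return value  # enum members and None pass through unchanged


assert parse_attention_backend("auto") is None
assert parse_attention_backend("flash_attn") is AttentionBackendEnum.FLASH_ATTN
assert parse_attention_backend(None) is None
```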
7 changes: 1 addition & 6 deletions vllm/v1/attention/selector.py
@@ -81,11 +81,6 @@ def get_attn_backend(
     else:
         block_size = None

-    speculative_config = vllm_config.speculative_config
-    use_non_causal = (
-        speculative_config is not None and speculative_config.method == "dflash"
-    )
-
     attn_selector_config = AttentionSelectorConfig(
         head_size=head_size,
         dtype=dtype,
@@ -97,7 +92,7 @@
         use_mm_prefix=use_mm_prefix,
         use_per_head_quant_scales=use_per_head_quant_scales,
         attn_type=attn_type or AttentionType.DECODER,
-        use_non_causal=use_non_causal,
+        use_non_causal=vllm_config.attention_config.use_non_causal,
         use_batch_invariant=envs.VLLM_BATCH_INVARIANT,
     )

12 changes: 12 additions & 0 deletions vllm/v1/spec_decode/dflash.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import replace
from typing import Any

import torch
@@ -67,6 +68,17 @@ def __init__(
        # For DFlash we use the input embeddings to embed the mask token
        self.parallel_drafting_hidden_state_tensor = None

    @override
    def _create_draft_vllm_config(self) -> VllmConfig:
        base = super()._create_draft_vllm_config()
        return replace(
            base,
            attention_config=replace(
                base.attention_config,
                use_non_causal=True,
            ),
        )

    @override
    def _raise_if_multimodal(self):
        # Override to allow multimodal inputs since DFlash supports Qwen3.5 models
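The override leans on `dataclasses.replace` returning a modified copy, so flipping `use_non_causal` for the drafter never mutates the target model's shared config. A toy illustration of the nested-replace idiom (stand-in dataclasses, not the real `VllmConfig`):

```python
from dataclasses import dataclass, field, replace


@dataclass(frozen=True)
class ToyAttentionConfig:
    use_non_causal: bool = False


@dataclass(frozen=True)
class ToyVllmConfig:
    attention_config: ToyAttentionConfig = field(default_factory=ToyAttentionConfig)


target = ToyVllmConfig()
# Same shape as the override above: replace the nested config, then the outer one.
draft = replace(
    target,
    attention_config=replace(target.attention_config, use_non_causal=True),
)

assert target.attention_config.use_non_causal is False  # target config untouched
assert draft.attention_config.use_non_causal is True  # drafter is non-causal
```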
22 changes: 18 additions & 4 deletions vllm/v1/spec_decode/llm_base_proposer.py
@@ -1276,15 +1276,29 @@ def _create_draft_vllm_config(self) -> VllmConfig:
         Subclasses may override to apply additional config changes.
         """
         spec_cfg = self.speculative_config
+        base = self.vllm_config

         if spec_cfg.moe_backend is not None:
-            return replace(
-                self.vllm_config,
+            base = replace(
+                base,
                 kernel_config=replace(
-                    self.vllm_config.kernel_config,
+                    base.kernel_config,
                     moe_backend=spec_cfg.moe_backend,
                 ),
             )
-        return self.vllm_config
+
+        # Note (matt): never inherit the attention backend from the base config;
+        # there are too many opportunities for incompatibility. Always auto-select
+        # independently unless explicitly specified in the speculative config.
+        base = replace(
+            base,
+            attention_config=replace(
+                base.attention_config,
+                backend=spec_cfg.attention_backend,
+            ),
+        )
+
+        return base

     def _get_model(self) -> nn.Module:
         """