Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions vllm/v1/spec_decode/dflash.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,28 @@ def __init__(
@override
def _create_draft_vllm_config(self) -> VllmConfig:
    """Build the drafter's VllmConfig with non-causal attention enabled.

    When the target is Gemma 4 (heterogeneous head dimensions:
    head_dim=256, global_head_dim=512),
    ``Gemma4Config.verify_and_update_config`` force-locks
    ``attention_config.backend`` to TRITON_ATTN to prevent mixed-backend
    numerical divergence within the target's own forward. The base
    ``SpecDecodeBaseProposer._create_draft_vllm_config`` resets the
    backend to ``spec_cfg.attention_backend`` (default None), but in
    practice the drafter still ends up on TRITON_ATTN, which then
    rejects DFlash's non-causal drafter attention at engine init.

    The mixed-backend divergence concern does not apply to spec decode:
    target and drafter are separate models with separate KV caches and
    separate forwards, and rejection sampling tolerates numerical drift
    by design. Forcing ``backend=None`` here sends the drafter through
    the standard autoselect path, which picks FLEX_ATTENTION (or any
    other non-causal-capable backend) when ``use_non_causal=True``.

    Returns:
        A copy of the base draft config whose attention settings request
        non-causal attention with backend autoselection.
    """
    draft_config = super()._create_draft_vllm_config()
    # Only the attention settings change; every other field is inherited
    # from the base draft config as-is.
    drafter_attn = replace(
        draft_config.attention_config,
        use_non_causal=True,
        backend=None,
    )
    return replace(draft_config, attention_config=drafter_attn)

Expand Down
Loading