diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py
index b3025987a4ad..dd873cbaa260 100644
--- a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py
+++ b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py
@@ -205,12 +205,19 @@ def _resolve_draft_backend_type(self) -> str:
 
     def _init_draft_attn_backend(self):
+        """Build the draft attention backend for the resolved backend type.
+
+        Raises:
+            ValueError: If the backend type is not "triton" or "trtllm_mha".
+        """
         backend_type = self._resolve_draft_backend_type()
-        if backend_type != "triton":
-            raise ValueError(
-                "Frozen-KV MTP currently supports only the triton attention "
-                f"backend, got {backend_type}."
-            )
-        return self._init_triton_draft_attn_backend()
+        if backend_type == "triton":
+            return self._init_triton_draft_attn_backend()
+        if backend_type == "trtllm_mha":
+            return self._init_trtllm_mha_draft_attn_backend()
+        raise ValueError(
+            "Frozen-KV MTP currently supports triton and trtllm_mha attention "
+            f"backends, got {backend_type}."
+        )
 
     def _init_triton_draft_attn_backend(self):
         from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
@@ -225,6 +232,25 @@ def _init_triton_draft_attn_backend(self):
             kv_indptr_buf=kv_indptr_buf,
         )
 
+    def _init_trtllm_mha_draft_attn_backend(self):
+        """Build the trtllm_mha draft attention backend.
+
+        Constructs it from ``self.draft_model_runner`` with ``skip_prefill=True``.
+        """
+        # TODO(kpham-sgl): trtllm_mha (Gemma4) known gaps:
+        #   1. target=trtllm_mha + num_draft_tokens>=6: flashinfer missing
+        #      tile (tileSizeQ=64, headDim=512, page=64) for target_verify
+        #      on full-attn layers.
+        #   2. target=triton + draft=trtllm_mha: needs --page-size in
+        #      {16,32,64} (triton defaults to 1; no auto-promote for
+        #      draft-only trtllm_mha).
+        #   3. topk>1 + page_size>1 + SWA: eagle verify hits
+        #      `move_kv_cache` without enable_kv_cache_copy (orthogonal to
+        #      trtllm_mha but blocks the only triton-target workaround).
+        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
+
+        return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=True)
+
     def _bind_kv_context(self) -> None:
         draft_model = self.draft_model_runner.model
         if not hasattr(draft_model, "build_frozen_kv_mtp_context") or not hasattr(