Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions python/sglang/srt/speculative/frozen_kv_mtp_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,14 @@ def _resolve_draft_backend_type(self) -> str:

def _init_draft_attn_backend(self):
    """Build the draft attention backend for the resolved backend type.

    The backend type comes from ``self._resolve_draft_backend_type()`` and
    is dispatched to the matching ``_init_*_draft_attn_backend`` helper.

    Returns:
        Whatever the selected helper returns.

    Raises:
        ValueError: If the resolved backend type is neither ``"triton"``
            nor ``"trtllm_mha"``.
    """
    backend_type = self._resolve_draft_backend_type()
    # No up-front ``backend_type != "triton"`` rejection: such a guard would
    # raise before the trtllm_mha branch below could ever be reached.
    if backend_type == "triton":
        return self._init_triton_draft_attn_backend()
    if backend_type == "trtllm_mha":
        return self._init_trtllm_mha_draft_attn_backend()
    raise ValueError(
        "Frozen-KV MTP currently supports triton and trtllm_mha attention "
        f"backends, got {backend_type}."
    )

def _init_triton_draft_attn_backend(self):
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
Expand All @@ -225,6 +227,21 @@ def _init_triton_draft_attn_backend(self):
kv_indptr_buf=kv_indptr_buf,
)

def _init_trtllm_mha_draft_attn_backend(self):
    """Create the ``TRTLLMHAAttnBackend`` used for the draft model runner.

    The backend is built around ``self.draft_model_runner`` with
    ``skip_prefill=True``.
    """
    # TODO(kpham-sgl): trtllm_mha (Gemma4) known gaps:
    #   1. target=trtllm_mha + num_draft_tokens>=6: flashinfer is missing the
    #      tile (tileSizeQ=64, headDim=512, page=64) for target_verify on
    #      full-attn layers.
    #   2. target=triton + draft=trtllm_mha: needs --page-size in {16,32,64}
    #      (triton defaults to 1; there is no auto-promote for a draft-only
    #      trtllm_mha).
    #   3. topk>1 + page_size>1 + SWA: eagle verify hits `move_kv_cache`
    #      w/o enable_kv_cache_copy (orthogonal to trtllm_mha, but it blocks
    #      the only triton-target workaround).
    # Imported lazily so the dependency is only loaded when this backend is
    # actually selected.
    from sglang.srt.layers.attention.trtllm_mha_backend import (
        TRTLLMHAAttnBackend as _DraftBackend,
    )

    draft_runner = self.draft_model_runner
    return _DraftBackend(draft_runner, skip_prefill=True)

def _bind_kv_context(self) -> None:
draft_model = self.draft_model_runner.model
if not hasattr(draft_model, "build_frozen_kv_mtp_context") or not hasattr(
Expand Down
Loading