diff --git a/python/sglang/srt/models/gemma4_causal.py b/python/sglang/srt/models/gemma4_causal.py index 1cf761d31039..38debdb5dc68 100644 --- a/python/sglang/srt/models/gemma4_causal.py +++ b/python/sglang/srt/models/gemma4_causal.py @@ -878,6 +878,9 @@ def __init__( def get_input_embeddings(self) -> nn.Embedding: return self.model.embed_tokens + def get_embed_and_head(self) -> Tuple[torch.Tensor, torch.Tensor]: + return self.model.embed_tokens.weight, self.lm_head.weight + def get_attention_sliding_window_size(self): return get_attention_sliding_window_size(self.config) diff --git a/python/sglang/srt/models/gemma4_mm.py b/python/sglang/srt/models/gemma4_mm.py index a1f0fe2feca1..a9d0ca083ed2 100644 --- a/python/sglang/srt/models/gemma4_mm.py +++ b/python/sglang/srt/models/gemma4_mm.py @@ -256,6 +256,11 @@ def pad_input_ids( def get_input_embeddings(self) -> nn.Embedding: return self.language_model.get_input_embeddings() + def get_embed_and_head(self) -> Tuple[torch.Tensor, torch.Tensor]: + # Gemma 4 multimodal ties its LM head to the text embed_tokens + embed = self.language_model.embed_tokens.weight + return embed, embed + def get_attention_sliding_window_size(self): return getattr(self.config.text_config, "sliding_window", -1) - 1 diff --git a/python/sglang/srt/models/gemma4_mtp.py b/python/sglang/srt/models/gemma4_mtp.py new file mode 100644 index 000000000000..1cb87b7c2e99 --- /dev/null +++ b/python/sglang/srt/models/gemma4_mtp.py @@ -0,0 +1,398 @@ +# Copyright 2026 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import annotations + +import copy +import logging +from typing import Dict, Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig, PreTrainedModel + +from sglang.srt.layers.linear import ReplicatedLinear +from sglang.srt.layers.logits_processor import ( + LogitsMetadata, + LogitsProcessor, + LogitsProcessorOutput, +) +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.mem_cache.memory_pool import KVCache +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.models.gemma4_causal import Gemma4ForCausalLM, Gemma4TextModel +from sglang.srt.speculative.frozen_kv_mtp_info import FrozenKVMTPContext +from sglang.srt.utils import add_prefix + +logger = logging.getLogger(__name__) + + +def _get_text_config(model_or_config) -> PretrainedConfig: + """Normalize either a model or a (possibly wrapped) config to ``Gemma4TextConfig``.""" + cfg = getattr(model_or_config, "config", model_or_config) + return getattr(cfg, "text_config", cfg) + + +def _resolve_target_text_model(target_model): + for attr in ("language_model", "model"): + candidate = getattr(target_model, attr, None) + if candidate is not None and hasattr(candidate, "layers"): + return candidate + raise AttributeError( + f"Frozen-KV MTP cannot locate the target trunk on " + f"{type(target_model).__name__}; expected ``.language_model`` " + "(multimodal) or ``.model`` (text-only) with a ``.layers`` attribute." + ) + + +class Gemma4AssistantForCausalLM(Gemma4ForCausalLM): + """Gemma 4 MTP assistant: target embed + recurrent hidden through pre/post projection; own ``lm_head``.""" + + base_model_prefix = "model" + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + text_config = copy.deepcopy(_get_text_config(config)) + text_config.num_kv_shared_layers = 0 + PreTrainedModel.__init__(self, config=text_config) + self.assistant_config = config + self.config = text_config + self.quant_config = quant_config + + self.vocab_size = text_config.vocab_size + self.hidden_size = text_config.hidden_size + self.backbone_hidden_size = config.backbone_hidden_size + self.target_embed_scale = self.backbone_hidden_size**0.5 + self.use_ordered_embeddings = bool( + getattr(config, "use_ordered_embeddings", False) + ) + self.centroid_intermediate_top_k = int( + getattr(config, "centroid_intermediate_top_k", 32) + ) + + self.target_embed_weight: Optional[torch.Tensor] = None + self.pre_projection = ReplicatedLinear( + 2 * self.backbone_hidden_size, + self.hidden_size, + bias=False, + quant_config=None, + prefix=add_prefix("pre_projection", prefix), + ) + self.model = Gemma4TextModel( + config=text_config, + quant_config=quant_config, + prefix=add_prefix("model", prefix), + ) + self.post_projection = ReplicatedLinear( + self.hidden_size, + self.backbone_hidden_size, + bias=False, + quant_config=None, + prefix=add_prefix("post_projection", prefix), + ) + + if text_config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = nn.Linear(self.hidden_size, self.vocab_size, bias=False) + self.logits_processor = LogitsProcessor(text_config, skip_all_gather=True) + + if self.use_ordered_embeddings: + self.num_centroids = int(config.num_centroids) + self.vocab_size_per_centroid, rem = divmod( + self.vocab_size, self.num_centroids + ) + if rem: + raise ValueError( + "Frozen-KV MTP centroid head requires vocab_size to be a " + f"multiple of num_centroids (vocab={self.vocab_size}, " + f"num_centroids={self.num_centroids})." + ) + self.centroids = nn.Linear(self.hidden_size, self.num_centroids, bias=False) + self.register_buffer( + "token_ordering", + torch.zeros(self.vocab_size, dtype=torch.long), + persistent=True, + ) + else: + self.num_centroids = self.vocab_size_per_centroid = self.centroids = None + self.register_buffer("token_ordering", None, persistent=False) + + self.kv_context: Optional[FrozenKVMTPContext] = None + self.post_init() + + def bind_frozen_kv_context(self, ctx: FrozenKVMTPContext) -> None: + """Bind assistant attention to target-owned KV and suppress assistant KV writes.""" + for assistant_logical, layer in enumerate(self.model.layers): + target_phys = ctx.get_physical_layer_id(assistant_logical) + layer.self_attn.is_kv_shared_layer = True + layer.self_attn.kv_shared_layer_index = target_phys + layer.self_attn.attn.layer_id = target_phys + layer.self_attn.layer_id = assistant_logical + self.kv_context = ctx + + def build_frozen_kv_mtp_context( + self, + target_model, + target_token_to_kv_pool: KVCache, + ) -> FrozenKVMTPContext: + """Map each assistant layer to the target physical layer that owns its K/V. + + HF Gemma 4 ties each typed (sliding/full) assistant layer to the target's + last layer of the same type; that layer is itself KV-shared with an + earlier non-shared layer (via ``kv_shared_layer_index``). We collapse + those two hops once so attention can hand a direct ``layer_id`` to + ``RadixAttention`` at bind time. + """ + target_text = _get_text_config(target_model) + assistant_text = _get_text_config(self) + layers = _resolve_target_text_model(target_model).layers + + def kv_owner(idx: int) -> int: + attn = layers[idx].self_attn + owner = ( + getattr(attn, "kv_shared_layer_index", None) + if getattr(attn, "is_kv_shared_layer", False) + else idx + ) + if owner is None or getattr( + layers[owner].self_attn, "is_kv_shared_layer", False + ): + raise RuntimeError( + f"Frozen-KV MTP: target layer {idx} resolved to physical " + f"{owner!r}, which is missing or itself KV-shared " + "(HF invariant changed?)." + ) + return owner + + L = target_text.num_hidden_layers + by_type = {target_text.layer_types[i]: kv_owner(i) for i in (L - 2, L - 1)} + + physical: Dict[int, int] = {} + for i, t in enumerate(assistant_text.layer_types): + if t not in by_type: + raise ValueError( + f"Frozen-KV MTP assistant layer {i} has type {t!r}, " + f"expected one of {sorted(by_type)}." + ) + physical[i] = by_type[t] + + return FrozenKVMTPContext( + target_token_to_kv_pool=target_token_to_kv_pool, + physical_layer_ids=physical, + ) + + def get_embed_and_head(self) -> Tuple[torch.Tensor, torch.Tensor]: + if self.target_embed_weight is None: + raise RuntimeError( + "Gemma4AssistantForCausalLM target embedding is not bound yet." + ) + return self.target_embed_weight, self.lm_head.weight + + def set_embed_and_head(self, embed: torch.Tensor, head: torch.Tensor) -> None: + """Rebind target embedding; ``head`` ignored (assistant keeps ``lm_head``).""" + del head + self.target_embed_weight = embed + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def get_attention_sliding_window_size(self) -> int: + # Gemma 4 config treats the bound as inclusive; SGLang attention metadata + # uses an exclusive window size, matching the target Gemma 4 models. + return self.config.sliding_window - 1 + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: Optional[torch.Tensor] = None, + **kwargs, + ) -> LogitsProcessorOutput: + if input_embeds is None: + if self.target_embed_weight is None: + raise RuntimeError( + "Gemma4AssistantForCausalLM requires set_embed_and_head() " + "before token-id forward." + ) + token_embed = ( + torch.nn.functional.embedding(input_ids, self.target_embed_weight) + * self.target_embed_scale + ) + else: + token_embed = input_embeds + + if forward_batch.spec_info is None or not hasattr( + forward_batch.spec_info, "hidden_states" + ): + raise RuntimeError( + "Frozen-KV MTP forward requires forward_batch.spec_info." + "hidden_states to carry the recurrent state. The worker's " + "_frozen_kv_target_view context manager must be exited " + "before model forward, leaving spec_info populated." + ) + prev_hidden = forward_batch.spec_info.hidden_states + if token_embed.shape != prev_hidden.shape: + raise ValueError( + "Frozen-KV MTP forward: token_embed and prev_hidden must have " + f"the same shape (got {token_embed.shape} vs {prev_hidden.shape})." + ) + + z, _ = self.pre_projection(torch.cat([token_embed, prev_hidden], dim=-1)) + hidden_states = self.model( + input_ids=None, + positions=positions, + forward_batch=forward_batch, + input_embeds=z, + per_layer_inputs=None, + **kwargs, + ) + projected_states, _ = self.post_projection(hidden_states) + + if self.use_ordered_embeddings: + return self._centroid_logits_processor( + input_ids, hidden_states, projected_states, forward_batch + ) + + return self.logits_processor( + input_ids, + hidden_states, + self.lm_head, + forward_batch, + hidden_states_before_norm=projected_states, + ) + + def _apply_centroid_masking(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Centroid-masked logits for E2B/E4B assistant heads.""" + if self.centroids is None or self.token_ordering is None: + raise RuntimeError( + "Frozen-KV MTP centroid head invoked but centroid weights " + "are not initialized." + ) + prefix_shape = hidden_states.shape[:-1] + flat_hidden = hidden_states.reshape(-1, hidden_states.shape[-1]) + num_tokens = flat_hidden.shape[0] + + _, top_k_indices = torch.topk( + self.centroids(flat_hidden), + k=self.centroid_intermediate_top_k, + dim=-1, + ) + + # Contiguous gather: [C, vpc, H] indexed by centroid IDs. + num_selected = self.centroid_intermediate_top_k * self.vocab_size_per_centroid + selected_embeddings = self.lm_head.weight.view( + self.num_centroids, + self.vocab_size_per_centroid, + self.hidden_size, + )[top_k_indices].reshape(num_tokens, num_selected, self.hidden_size) + + selected_logits = torch.bmm( + flat_hidden.unsqueeze(1), + selected_embeddings.transpose(1, 2), + ).squeeze(1) + + # Scatter to real vocab positions via token_ordering. + centroid_vocab_indices = ( + self.token_ordering.long() + .view(self.num_centroids, self.vocab_size_per_centroid)[top_k_indices] + .view(num_tokens, -1) + ) + mask_value = torch.finfo(selected_logits.dtype).min / 2 + output = torch.full( + (num_tokens, self.vocab_size), + mask_value, + dtype=selected_logits.dtype, + device=selected_logits.device, + ) + output.scatter_(dim=-1, index=centroid_vocab_indices, src=selected_logits) + return output.view(*prefix_shape, self.vocab_size) + + def _centroid_logits_processor( + self, + input_ids: torch.Tensor, + hidden_states: torch.Tensor, + projected_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> LogitsProcessorOutput: + logits_metadata = LogitsMetadata.from_forward_batch(forward_batch) + if logits_metadata.extend_return_logprob: + raise NotImplementedError( + "Frozen-KV MTP centroid head does not support input logprobs yet." + ) + + ( + pruned_states, + pruned_states_before_norm, + aux_pruned_states, + sample_indices, + *_, + ) = self.logits_processor._get_pruned_states( + hidden_states, projected_states, None, logits_metadata + ) + hidden_states_to_store = self.logits_processor._get_hidden_states_to_store( + hidden_states, + projected_states, + None, + pruned_states, + pruned_states_before_norm, + aux_pruned_states, + sample_indices, + logits_metadata, + ) + del input_ids, hidden_states, projected_states + + logits = self._apply_centroid_masking(pruned_states) + sampled_logits = ( + logits[sample_indices] if sample_indices is not None else logits + ) + return LogitsProcessorOutput( + next_token_logits=sampled_logits, + hidden_states=hidden_states_to_store, + mm_input_embeds=logits_metadata.mm_input_embeds, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def remap_assistant_weights(): + for name, weight in weights: + if name.startswith("masked_embedding."): + name = name.removeprefix("masked_embedding.") + yield name, weight + + result = super().load_weights(remap_assistant_weights()) + if self.use_ordered_embeddings: + self._reorder_embedding_to_centroid_order() + return result + + @torch.no_grad() + def _reorder_embedding_to_centroid_order(self) -> None: + """Reorder lm_head.weight from natural vocab order to centroid order.""" + if self.token_ordering is None: + return + ordering = self.token_ordering.long() + lm_head_w = self.lm_head.weight + reordered = lm_head_w.data[ordering] + lm_head_w.data.copy_(reordered) + logger.info( + "Reordered lm_head/embed_tokens (%s) to centroid order " + "for contiguous centroid masking.", + list(lm_head_w.shape), + ) + + +EntryClass = Gemma4AssistantForCausalLM diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 30a3612d9b2f..61cf6a35151a 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -301,6 +301,43 @@ def add_rl_on_policy_target_choices(choices): RL_ON_POLICY_TARGET_CHOICES.extend(choices) +def _resolve_speculative_algorithm_alias( + speculative_algorithm: Optional[str], + speculative_draft_model_path: Optional[str], + trust_remote_code: bool = False, +) -> Optional[str]: + """Resolve CLI speculative algorithm; NEXTN/EAGLE may become FROZEN_KV_MTP for Gemma4 assistant drafts.""" + + is_gemma4_draft = False + if speculative_draft_model_path: + from transformers import AutoConfig + + cfg = AutoConfig.from_pretrained( + speculative_draft_model_path, trust_remote_code=trust_remote_code + ) + is_gemma4_draft = "Gemma4AssistantForCausalLM" in ( + getattr(cfg, "architectures", None) or [] + ) + + if speculative_algorithm == "EAGLE3" and is_gemma4_draft: + raise ValueError( + "Gemma4AssistantForCausalLM draft requires " + "--speculative-algorithm NEXTN or EAGLE; EAGLE3 is " + "not supported for this draft architecture." + ) + + if speculative_algorithm == "NEXTN" or speculative_algorithm == "EAGLE": + if is_gemma4_draft: + logger.info( + "Detected Gemma4AssistantForCausalLM draft; " + f"promoting --speculative-algorithm {speculative_algorithm} to FROZEN_KV_MTP." + ) + return "FROZEN_KV_MTP" + return "EAGLE" + + return speculative_algorithm + + @dataclasses.dataclass class ServerArgs: """ @@ -3283,8 +3320,11 @@ def _handle_speculative_decoding(self): self.speculative_moe_runner_backend ).is_flashinfer_trtllm(), "Currently speculative MoE runner backend doesn't support flashinfer_trtllm, please use triton or auto backend for speculative moe runner instead." - if self.speculative_algorithm == "NEXTN": - self.speculative_algorithm = "EAGLE" + self.speculative_algorithm = _resolve_speculative_algorithm_alias( + self.speculative_algorithm, + self.speculative_draft_model_path, + trust_remote_code=self.trust_remote_code, + ) if self.speculative_skip_dp_mlp_sync: assert self.speculative_algorithm == "EAGLE", ( @@ -3420,6 +3460,25 @@ def _handle_speculative_decoding(self): "Mixed chunked prefill is disabled because of using dflash speculative decoding." ) + if self.speculative_algorithm == "FROZEN_KV_MTP": + if self.max_running_requests is None: + self.max_running_requests = 48 + logger.warning( + "Max running requests is reset to 48 for speculative decoding. You can override this by explicitly setting --max-running-requests." + ) + + self.disable_overlap_schedule = True + logger.warning( + "Overlap scheduler is disabled when using Frozen-KV MTP speculative decoding (spec v2 is not supported yet)." + ) + + if self.enable_mixed_chunk: + self.enable_mixed_chunk = False + logger.warning( + "Mixed chunked prefill is disabled because of using " + "Frozen-KV MTP speculative decoding." + ) + if self.speculative_algorithm in ("EAGLE", "EAGLE3", "STANDALONE"): if self.speculative_algorithm == "STANDALONE" and self.enable_dp_attention: # TODO: support dp attention for standalone speculative decoding diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py b/python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py new file mode 100644 index 000000000000..56e17906b97b --- /dev/null +++ b/python/sglang/srt/speculative/frozen_kv_mtp_cuda_graph_runner.py @@ -0,0 +1,400 @@ +from __future__ import annotations + +import bisect +from dataclasses import dataclass +from typing import TYPE_CHECKING, Callable, Optional + +import torch + +from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len +from sglang.srt.model_executor.cuda_graph_runner import ( + CUDA_GRAPH_CAPTURE_FAILED_MSG, + CudaGraphRunner, + DeepEPCudaGraphRunnerAdapter, + get_batch_sizes_to_capture, + get_global_graph_memory_pool, + model_capture_mode, + set_global_graph_memory_pool, + set_is_extend_in_batch, + set_torch_compile_config, +) +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, +) +from sglang.srt.model_executor.input_buffers import ForwardInputBuffers +from sglang.srt.speculative.frozen_kv_mtp_info import FrozenKVMTPDraftInput +from sglang.srt.utils import ( + require_attn_tp_gather, + require_gathered_buffer, + require_mlp_sync, + require_mlp_tp_gather, +) + +if TYPE_CHECKING: + from sglang.srt.speculative.frozen_kv_mtp_worker import FrozenKVMTPWorker + + +@dataclass +class FrozenKVMTPInputBuffers(ForwardInputBuffers): + req_pool_indices: torch.Tensor + positions: torch.Tensor + mrope_positions: torch.Tensor + seq_lens: torch.Tensor + seq_lens_cpu: torch.Tensor + topk_p: torch.Tensor + topk_index: torch.Tensor + hidden_states: torch.Tensor + global_num_tokens_gpu: Optional[torch.Tensor] + global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] + + +class FrozenKVMTPCudaGraphRunner: + """CUDA graph runner for the Frozen-KV MTP recurrent draft-loop step.""" + + def __init__(self, frozen_kv_mtp_worker: FrozenKVMTPWorker): + self.frozen_kv_mtp_worker = frozen_kv_mtp_worker + self.model_runner = model_runner = frozen_kv_mtp_worker.draft_model_runner + self.graphs = {} + self.output_buffers = {} + self.enable_torch_compile = model_runner.server_args.enable_torch_compile + self.disable_padding = model_runner.server_args.disable_cuda_graph_padding + self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args) + self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args) + self.require_mlp_sync = require_mlp_sync(model_runner.server_args) + self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args) + self.tp_size = self.model_runner.tp_size + self.dp_size = self.model_runner.dp_size + self.speculative_num_steps = model_runner.server_args.speculative_num_steps + self.topk = model_runner.server_args.speculative_eagle_topk + self.draft_attn_backend = frozen_kv_mtp_worker.draft_attn_backend + self.enable_profile_cuda_graph = ( + model_runner.server_args.enable_profile_cuda_graph + ) + self.enable_pdmux = False + self.deepep_adapter = DeepEPCudaGraphRunnerAdapter() + + self.num_tokens_per_bs = self.topk + self.capture_bs, self.compile_bs = get_batch_sizes_to_capture( + model_runner, self.num_tokens_per_bs + ) + self.max_bs = max(self.capture_bs) + self.max_num_token = self.max_bs * self.num_tokens_per_bs + + self.draft_attn_backend.init_cuda_graph_state(self.max_bs, self.max_num_token) + self.seq_len_fill_value = ( + self.draft_attn_backend.get_cuda_graph_seq_len_fill_value() + ) + seq_lens_cpu = torch.full( + (self.max_num_token,), self.seq_len_fill_value, dtype=torch.int32 + ) + + if self.enable_torch_compile: + set_torch_compile_config() + + with torch.device(model_runner.device): + req_pool_indices = torch.zeros((self.max_num_token,), dtype=torch.int64) + positions = torch.zeros((self.max_num_token,), dtype=torch.int64) + mrope_positions = torch.zeros((3, self.max_num_token), dtype=torch.int64) + seq_lens = torch.full( + (self.max_num_token,), self.seq_len_fill_value, dtype=torch.int32 + ) + topk_p = torch.zeros((self.max_bs, self.topk), dtype=torch.float32) + topk_index = torch.zeros((self.max_bs, self.topk), dtype=torch.int64) + hidden_states = torch.zeros( + (self.max_bs, frozen_kv_mtp_worker._recurrent_hidden_size), + dtype=self.model_runner.dtype, + ) + + if self.require_gathered_buffer: + if self.require_mlp_tp_gather: + global_num_tokens_gpu = torch.zeros( + (self.dp_size,), dtype=torch.int32 + ) + global_num_tokens_for_logprob_gpu = torch.zeros( + (self.dp_size,), dtype=torch.int32 + ) + else: + assert self.require_attn_tp_gather + global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32) + global_num_tokens_for_logprob_gpu = torch.zeros( + (1,), dtype=torch.int32 + ) + else: + global_num_tokens_gpu = None + global_num_tokens_for_logprob_gpu = None + + self.buffers = FrozenKVMTPInputBuffers( + req_pool_indices=req_pool_indices, + positions=positions, + mrope_positions=mrope_positions, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + topk_p=topk_p, + topk_index=topk_index, + hidden_states=hidden_states, + global_num_tokens_gpu=global_num_tokens_gpu, + global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob_gpu, + ) + self.buffers.share_buffers() + + try: + with model_capture_mode(): + self.capture() + except RuntimeError as e: + raise Exception( + f"Capture frozen-KV MTP cuda graph failed: {e}\n" + f"{CUDA_GRAPH_CAPTURE_FAILED_MSG}" + ) + + def can_run(self, forward_batch: ForwardBatch): + if self.require_mlp_tp_gather: + cuda_graph_bs = max(forward_batch.global_num_tokens_cpu) // ( + self.topk * self.topk + ) + else: + cuda_graph_bs = ( + forward_batch.batch_size // self.topk + if self.topk > 1 + else forward_batch.batch_size + ) + + is_bs_supported = ( + cuda_graph_bs in self.graphs + if self.disable_padding + else cuda_graph_bs <= self.max_bs + ) + if self.require_mlp_sync: + is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph + return is_bs_supported + + def _create_graph(self): + return torch.cuda.CUDAGraph() + + def _capture_init(self, run_once_fn): + for _ in range(2): + torch.cuda.synchronize() + self.model_runner.tp_group.barrier() + run_once_fn() + + def _capture_graph(self, graph, pool, stream, run_once_fn): + with torch.cuda.graph(graph, pool=pool, stream=stream): + out = run_once_fn() + return out + + def _replay(self): + self.graphs[self.bs].replay() + + def capture(self): + CudaGraphRunner.capture(self) + + def capture_one_batch_size( + self, num_seqs: int, forward: Callable, stream_idx: int = 0 + ): + del forward, stream_idx + buffers = self.buffers + graph = self._create_graph() + stream = self.stream + request_bs = num_seqs + expanded_bs = request_bs * self.num_tokens_per_bs + + req_pool_indices = buffers.req_pool_indices[:expanded_bs] + positions = buffers.positions[:expanded_bs] + mrope_positions = buffers.mrope_positions[:, :expanded_bs] + seq_lens = buffers.seq_lens[:expanded_bs] + seq_lens_cpu = buffers.seq_lens_cpu[:expanded_bs] + topk_p = buffers.topk_p[:request_bs] + topk_index = buffers.topk_index[:request_bs] + hidden_states = buffers.hidden_states[:request_bs] + + if self.require_mlp_tp_gather: + buffers.global_num_tokens_gpu.copy_( + torch.tensor( + [expanded_bs] * self.dp_size, + dtype=torch.int32, + device=buffers.positions.device, + ) + ) + buffers.global_num_tokens_for_logprob_gpu.copy_( + torch.tensor( + [expanded_bs] * self.dp_size, + dtype=torch.int32, + device=buffers.positions.device, + ) + ) + global_num_tokens = buffers.global_num_tokens_gpu + global_num_tokens_for_logprob = buffers.global_num_tokens_for_logprob_gpu + global_dp_buffer_len = expanded_bs * self.dp_size + elif self.require_attn_tp_gather: + buffers.global_num_tokens_gpu.copy_( + torch.tensor( + [expanded_bs], + dtype=torch.int32, + device=buffers.positions.device, + ) + ) + buffers.global_num_tokens_for_logprob_gpu.copy_( + torch.tensor( + [expanded_bs], + dtype=torch.int32, + device=buffers.positions.device, + ) + ) + global_num_tokens = buffers.global_num_tokens_gpu + global_num_tokens_for_logprob = buffers.global_num_tokens_for_logprob_gpu + global_dp_buffer_len = expanded_bs + else: + global_num_tokens = None + global_num_tokens_for_logprob = None + global_dp_buffer_len = None + + spec_info = FrozenKVMTPDraftInput( + topk_p=topk_p, + topk_index=topk_index, + hidden_states=hidden_states, + capture_hidden_mode=CaptureHiddenMode.LAST, + ) + spec_info.num_tokens_per_req = self.topk + spec_info.num_tokens_for_logprob_per_req = self.topk + spec_info.positions = positions + + forward_batch = ForwardBatch( + forward_mode=ForwardMode.DECODE, + batch_size=expanded_bs, + input_ids=None, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + seq_lens_cpu=seq_lens_cpu, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.frozen_kv_mtp_worker.kv_context.target_token_to_kv_pool, + attn_backend=self.draft_attn_backend, + out_cache_loc=None, + seq_lens_sum=seq_lens.sum().item(), + return_logprob=False, + positions=positions, + mrope_positions=mrope_positions, + global_num_tokens_gpu=global_num_tokens, + global_num_tokens_for_logprob_gpu=global_num_tokens_for_logprob, + dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(), + global_dp_buffer_len=global_dp_buffer_len, + spec_algorithm=self.model_runner.spec_algorithm, + spec_info=spec_info, + capture_hidden_mode=CaptureHiddenMode.LAST, + ) + + self.frozen_kv_mtp_worker._init_frozen_kv_metadata_capture_cuda_graph( + forward_batch + ) + + def run_once(): + forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None + set_dp_buffer_len( + global_dp_buffer_len, + expanded_bs, + forward_batch.dp_padding_mode.is_max_len(), + ) + set_is_extend_in_batch(False) + + hidden_states_backup = forward_batch.spec_info.hidden_states + ret = self.frozen_kv_mtp_worker.draft_forward( + forward_batch, skip_attn_backend_init=True + ) + forward_batch.spec_info.hidden_states = hidden_states_backup + return ret + + self.deepep_adapter.capture(is_extend_in_batch=False) + self._capture_init(run_once) + out = self._capture_graph( + graph, get_global_graph_memory_pool(), stream, run_once + ) + set_global_graph_memory_pool(graph.pool()) + return graph, out + + def _postprocess_output_to_raw_bs(self, out, raw_bs): + parent_list, top_scores_index, draft_tokens = (t[:raw_bs] for t in out) + return parent_list, top_scores_index, draft_tokens + + def replay(self, forward_batch: ForwardBatch): + self.deepep_adapter.replay() + buffers = self.buffers + + raw_expanded_bs = forward_batch.batch_size + raw_bs = ( + raw_expanded_bs // self.num_tokens_per_bs + if self.topk > 1 + else raw_expanded_bs + ) + raw_num_token = raw_expanded_bs + + if self.require_mlp_tp_gather: + max_num_tokens = max(forward_batch.global_num_tokens_cpu) + max_batch_size = max_num_tokens // ( + self.num_tokens_per_bs * self.num_tokens_per_bs + ) + index = bisect.bisect_left(self.capture_bs, max_batch_size) + else: + index = bisect.bisect_left(self.capture_bs, raw_bs) + + bs = self.capture_bs[index] + expanded_bs = bs * self.num_tokens_per_bs + if bs != raw_bs: + buffers.seq_lens.fill_(self.seq_len_fill_value) + buffers.positions.zero_() + + num_tokens = expanded_bs + buffers.seq_lens[:raw_expanded_bs].copy_(forward_batch.seq_lens) + buffers.positions[:raw_num_token].copy_(forward_batch.positions) + if forward_batch.mrope_positions is not None: + buffers.mrope_positions[:, :raw_num_token].copy_( + forward_batch.mrope_positions + ) + buffers.topk_p[:raw_bs].copy_(forward_batch.spec_info.topk_p) + buffers.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index) + buffers.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states) + buffers.req_pool_indices[:raw_expanded_bs].copy_(forward_batch.req_pool_indices) + + if self.require_gathered_buffer: + buffers.global_num_tokens_gpu.fill_(expanded_bs) + buffers.global_num_tokens_for_logprob_gpu.fill_(expanded_bs) + + if bs != raw_bs: + forward_batch.batch_size = expanded_bs + forward_batch.seq_lens = buffers.seq_lens[:expanded_bs] + forward_batch.req_pool_indices = buffers.req_pool_indices[:expanded_bs] + forward_batch.positions = buffers.positions[:num_tokens] + if forward_batch.mrope_positions is not None: + forward_batch.mrope_positions = buffers.mrope_positions[:, :num_tokens] + + if forward_batch.seq_lens_cpu is not None: + if bs != raw_bs: + buffers.seq_lens_cpu.fill_(self.seq_len_fill_value) + buffers.seq_lens_cpu[:raw_expanded_bs].copy_(forward_batch.seq_lens_cpu) + forward_batch.seq_lens_cpu = buffers.seq_lens_cpu[:expanded_bs] + + self.frozen_kv_mtp_worker._init_frozen_kv_metadata_replay_cuda_graph( + forward_batch, + expanded_bs, + forward_batch.seq_lens_sum + + (expanded_bs - raw_expanded_bs) * self.seq_len_fill_value, + ) + + self.raw_bs = raw_bs + self.bs = bs + self._replay() + out = self.output_buffers[bs] + + if bs != raw_bs: + out = self._postprocess_output_to_raw_bs(out, raw_bs) + forward_batch.batch_size = raw_expanded_bs + forward_batch.positions = buffers.positions[:raw_num_token] + forward_batch.seq_lens = buffers.seq_lens[:raw_expanded_bs] + forward_batch.req_pool_indices = buffers.req_pool_indices[:raw_expanded_bs] + if forward_batch.mrope_positions is not None: + forward_batch.mrope_positions = buffers.mrope_positions[ + :, :raw_num_token + ] + if forward_batch.seq_lens_cpu is not None: + forward_batch.seq_lens_cpu = buffers.seq_lens_cpu[:raw_expanded_bs] + + return out diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_info.py b/python/sglang/srt/speculative/frozen_kv_mtp_info.py new file mode 100644 index 000000000000..27a7249b07e8 --- /dev/null +++ b/python/sglang/srt/speculative/frozen_kv_mtp_info.py @@ -0,0 +1,82 @@ +# Copyright 2026 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import annotations + +from dataclasses import dataclass, fields +from typing import Dict + +from sglang.srt.mem_cache.memory_pool import KVCache +from sglang.srt.speculative.eagle_info import ( + EagleDraftInput, + EagleVerifyInput, + EagleVerifyOutput, +) +from sglang.srt.speculative.spec_info import SpecInput, SpecInputType + + +@dataclass(frozen=True) +class FrozenKVMTPContext: + """Target KV pool + assistant-logical -> target-physical layer map.""" + + target_token_to_kv_pool: KVCache + physical_layer_ids: Dict[int, int] + + def get_physical_layer_id(self, idx: int) -> int: + if idx not in self.physical_layer_ids: + raise KeyError( + f"FrozenKVMTPContext has no physical layer id for assistant " + f"logical index {idx}; available: {sorted(self.physical_layer_ids)}" + ) + return self.physical_layer_ids[idx] + + +@dataclass +class FrozenKVMTPDraftInput(EagleDraftInput): + """Draft input for Frozen-KV MTP. + + Frozen-KV MTP currently reuses the EAGLE scheduler/attention contract, but + has a dedicated type so algorithm-specific behavior can move here over time. + """ + + def __post_init__(self): + SpecInput.__init__(self, SpecInputType.FROZEN_KV_MTP_DRAFT) + + +@dataclass +class FrozenKVMTPVerifyInput(EagleVerifyInput): + """Verify input for Frozen-KV MTP.""" + + def __post_init__(self): + SpecInput.__init__(self, SpecInputType.FROZEN_KV_MTP_VERIFY) + + def verify(self, *args, **kwargs) -> EagleVerifyOutput: + output = super().verify(*args, **kwargs) + output.draft_input = _to_frozen_kv_mtp_draft_input(output.draft_input) + return output + + +FrozenKVMTPVerifyOutput = EagleVerifyOutput + + +def _to_frozen_kv_mtp_draft_input( + draft_input: EagleDraftInput, +) -> FrozenKVMTPDraftInput: + if isinstance(draft_input, FrozenKVMTPDraftInput): + return draft_input + return FrozenKVMTPDraftInput( + **{ + field.name: getattr(draft_input, field.name) + for field in fields(EagleDraftInput) + } + ) diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_utils.py b/python/sglang/srt/speculative/frozen_kv_mtp_utils.py new file mode 100644 index 000000000000..dc74d801bed3 --- /dev/null +++ b/python/sglang/srt/speculative/frozen_kv_mtp_utils.py @@ -0,0 +1,155 @@ +# Copyright 2026 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import annotations + +from contextlib import contextmanager +from typing import Tuple + +import torch + +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.speculative.frozen_kv_mtp_info import ( + FrozenKVMTPContext, + FrozenKVMTPDraftInput, +) +from sglang.srt.speculative.spec_utils import fast_topk + + +@contextmanager +def frozen_kv_target_view(forward_batch: ForwardBatch, kv_context: FrozenKVMTPContext): + """Build attention metadata against committed target-prefix geometry.""" + if kv_context is None: + raise RuntimeError( + "Frozen-KV MTP target view called before the model was bound; " + "bind the frozen KV context first." + ) + saved_spec_info = forward_batch.spec_info + saved_kv_pool = forward_batch.token_to_kv_pool + forward_batch.spec_info = None + forward_batch.token_to_kv_pool = kv_context.target_token_to_kv_pool + try: + yield + finally: + forward_batch.spec_info = saved_spec_info + forward_batch.token_to_kv_pool = saved_kv_pool + + +@contextmanager +def target_kv_pool_view(forward_batch: ForwardBatch, kv_context: FrozenKVMTPContext): + if kv_context is None: + raise RuntimeError( + "Frozen-KV MTP target KV pool view called before the model was bound; " + "bind the frozen KV context first." + ) + saved_kv_pool = forward_batch.token_to_kv_pool + forward_batch.token_to_kv_pool = kv_context.target_token_to_kv_pool + try: + yield + finally: + forward_batch.token_to_kv_pool = saved_kv_pool + + +def set_frozen_kv_positions(forward_batch: ForwardBatch, topk: int) -> None: + """Rope phase = last written target slot, not advanced per draft step.""" + seq_lens = forward_batch.seq_lens + positions = torch.clamp(seq_lens - 1, min=0).to(torch.int64) + if ( + topk > 1 + and forward_batch.positions is not None + and forward_batch.positions.numel() == positions.numel() * topk + ): + positions = positions.repeat_interleave(topk, dim=0) + if forward_batch.positions is None: + forward_batch.positions = positions + else: + if forward_batch.positions.shape == positions.shape: + forward_batch.positions.copy_(positions) + else: + forward_batch.positions = positions + + +def expand_for_topk_draft(forward_batch: ForwardBatch, topk: int) -> None: + """Repeat committed-prefix metadata for the active ``B * topk`` frontier.""" + if topk == 1 or forward_batch.batch_size == 0: + return + + if forward_batch.batch_size != forward_batch.seq_lens.shape[0]: + raise RuntimeError( + "Frozen-KV MTP topk expansion expects an unexpanded forward " + "batch where batch_size == len(seq_lens)." + ) + + forward_batch.batch_size *= topk + forward_batch.req_pool_indices = forward_batch.req_pool_indices.repeat_interleave( + topk, dim=0 + ) + forward_batch.seq_lens = forward_batch.seq_lens.repeat_interleave(topk, dim=0) + if forward_batch.seq_lens_cpu is not None: + forward_batch.seq_lens_cpu = forward_batch.seq_lens_cpu.repeat_interleave( + topk, dim=0 + ) + forward_batch.seq_lens_sum = forward_batch.seq_lens_cpu.sum().item() + else: + forward_batch.seq_lens_sum = torch.sum(forward_batch.seq_lens).item() + + positions = torch.clamp(forward_batch.seq_lens - 1, min=0).to(torch.int64) + forward_batch.positions = positions + forward_batch.num_token_non_padded_cpu = positions.numel() + if forward_batch.num_token_non_padded is not None: + forward_batch.num_token_non_padded.fill_(positions.numel()) + if ( + forward_batch.mrope_positions is not None + and forward_batch.mrope_positions.shape[-1] * topk == positions.numel() + ): + forward_batch.mrope_positions = forward_batch.mrope_positions.repeat_interleave( + topk, dim=-1 + ) + + +def position_for_batch(batch: ScheduleBatch) -> torch.Tensor: + return torch.clamp(batch.seq_lens - 1, min=0).to(torch.int64) + + +def select_last_extend_hidden( + batch: ScheduleBatch, hidden_states: torch.Tensor +) -> torch.Tensor: + if hidden_states.shape[0] == batch.batch_size(): + return hidden_states + lens = torch.tensor(batch.extend_lens, device=hidden_states.device) + last_indices = torch.cumsum(lens, dim=0) - 1 + return hidden_states[last_indices.to(torch.long)] + + +def select_last_verified_seed( + draft_input: FrozenKVMTPDraftInput, +) -> Tuple[torch.Tensor, torch.Tensor]: + if draft_input.num_accepted_tokens is None: + return draft_input.verified_id, draft_input.hidden_states + + counts = draft_input.num_accepted_tokens.to(torch.long) + last_indices = torch.cumsum(counts, dim=0) - 1 + return ( + draft_input.verified_id[last_indices], + draft_input.hidden_states[last_indices], + ) + + +def capture_for_decode( + logits_output: LogitsProcessorOutput, draft_input: FrozenKVMTPDraftInput, topk: int +) -> None: + probs = torch.softmax(logits_output.next_token_logits, dim=-1) + draft_input.topk_p, draft_input.topk_index = fast_topk(probs, topk, dim=-1) + draft_input.hidden_states = logits_output.hidden_states diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_worker.py b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py new file mode 100644 index 000000000000..8174816ae05e --- /dev/null +++ b/python/sglang/srt/speculative/frozen_kv_mtp_worker.py @@ -0,0 +1,772 @@ +# Copyright 2026 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Frozen-KV MTP draft worker. + +The assistant reads target KV only. It reuses EAGLE's verify input/output +contract, but owns the seed and recurrent draft loop because there is no +assistant-side KV extension. +""" + +from __future__ import annotations + +import logging +from typing import List, Optional, Tuple + +import torch + +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.moe.utils import ( + speculative_moe_a2a_backend_context, + speculative_moe_backend_context, +) +from sglang.srt.layers.utils.logprob import add_output_logprobs_for_spec_v1 +from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.managers.scheduler import GenerationBatchResult +from sglang.srt.managers.tp_worker import TpModelWorker +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, +) +from sglang.srt.model_executor.pool_configurator import MemoryPoolConfig +from sglang.srt.observability.req_time_stats import set_time_batch +from sglang.srt.observability.trace import get_global_tracing_enabled +from sglang.srt.server_args import ServerArgs +from sglang.srt.speculative.eagle_utils import ( + build_tree_kernel_efficient, + organize_draft_results, +) +from sglang.srt.speculative.frozen_kv_mtp_info import ( + FrozenKVMTPContext, + FrozenKVMTPDraftInput, + FrozenKVMTPVerifyInput, + FrozenKVMTPVerifyOutput, +) +from sglang.srt.speculative.frozen_kv_mtp_utils import ( + capture_for_decode, + expand_for_topk_draft, + frozen_kv_target_view, + position_for_batch, + select_last_extend_hidden, + select_last_verified_seed, + set_frozen_kv_positions, + target_kv_pool_view, +) +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.speculative.spec_utils import ( + draft_tp_context, + fast_topk, + generate_token_bitmask, + maybe_detect_nan, + maybe_detect_oob, + select_top_k_tokens, +) +from sglang.srt.utils import empty_context + +logger = logging.getLogger(__name__) + + +class FrozenKVMTPWorker(TpModelWorker): + """Frozen-KV MTP worker; same constructor shape as EAGLEWorker. Entry: + :meth:`forward_batch_generation` (stubs for now). + """ + + def __init__( + self, + server_args: ServerArgs, + gpu_id: int, + tp_rank: int, + dp_rank: Optional[int], + moe_ep_rank: int, + attn_cp_rank: int, + moe_dp_rank: int, + nccl_port: int, + target_worker: TpModelWorker, + ): + self.server_args = server_args + self.topk = server_args.speculative_eagle_topk + self.speculative_num_steps = server_args.speculative_num_steps + self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens + self.gpu_id = gpu_id + self.device = server_args.device + self.target_worker = target_worker + self.page_size = server_args.page_size + self.speculative_algorithm = SpeculativeAlgorithm.from_string( + server_args.speculative_algorithm + ) + assert self.speculative_algorithm.is_frozen_kv_mtp(), ( + "FrozenKVMTPWorker should only be instantiated for " + "SpeculativeAlgorithm.FROZEN_KV_MTP, got " + f"{self.speculative_algorithm.name}. The dispatch happens in " + "server_args._handle_speculative_decoding -> " + "_resolve_speculative_algorithm_alias." + ) + + # Assistant reads target KV directly, so its context length must match the target. + server_args.context_length = target_worker.model_runner.model_config.context_len + + # Defer cuda graph capture; we do it ourselves below. + backup_disable_cuda_graph = server_args.disable_cuda_graph + server_args.disable_cuda_graph = True + + # Draft attention uses target req_to_token + KV allocator (read-only). + self.req_to_token_pool, self.token_to_kv_pool_allocator = ( + target_worker.get_memory_pool() + ) + + target_cfg = target_worker.model_runner.memory_pool_config + draft_pool_config = MemoryPoolConfig( + max_total_num_tokens=64, # Dummy value + max_running_requests=target_cfg.max_running_requests, + ) + + self.hot_token_id = None + + with ( + empty_context() + ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + super().__init__( + server_args=server_args, + gpu_id=gpu_id, + tp_rank=tp_rank, + pp_rank=0, + dp_rank=dp_rank, + moe_ep_rank=moe_ep_rank, + attn_cp_rank=attn_cp_rank, + moe_dp_rank=moe_dp_rank, + nccl_port=nccl_port, + is_draft_worker=True, + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, + memory_pool_config=draft_pool_config, + ) + + embed, head = self.target_worker.model_runner.model.get_embed_and_head() + if hasattr(self.draft_model_runner.model, "set_embed_and_head"): + self.draft_model_runner.model.set_embed_and_head(embed, head) + else: + logger.debug( + "Draft model %s does not implement set_embed_and_head; " + "skipping target-embedding bind in Frozen-KV MTP skeleton.", + type(self.draft_model_runner.model).__name__, + ) + + self.kv_context: Optional["FrozenKVMTPContext"] = None + if hasattr(self.draft_model_runner.model, "bind_frozen_kv_context"): + self._bind_kv_context() + + self.draft_model_runner.server_args.disable_cuda_graph = ( + backup_disable_cuda_graph + ) + + self.draft_tp_context = ( + draft_tp_context if server_args.enable_dp_attention else empty_context + ) + + self.draft_attn_backend = self._init_draft_attn_backend() + self.draft_model_runner.draft_attn_backend = self.draft_attn_backend + self.cuda_graph_runner = None + + with self.draft_tp_context( + self.draft_model_runner.tp_group + ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + self.init_cuda_graphs() + + @property + def draft_model_runner(self): + return self.model_runner + + def get_attn_backend(self): # pragma: no cover - exposed for adaptive + return self.draft_attn_backend + + def clear_cache_pool(self): + pass + + def _resolve_draft_backend_type(self) -> str: + return ( + self.server_args.speculative_draft_attention_backend + or self.server_args.decode_attention_backend + or self.server_args.attention_backend + ) + + def _init_draft_attn_backend(self): + if self.topk == 1: + return self.draft_model_runner.attn_backend + + backend_type = self._resolve_draft_backend_type() + if backend_type != "triton": + raise ValueError( + "Frozen-KV MTP topk > 1 currently supports only the triton " + f"attention backend, got {backend_type}." + ) + return self._init_triton_draft_attn_backend() + + def _init_triton_draft_attn_backend(self): + from sglang.srt.layers.attention.triton_backend import TritonAttnBackend + + max_bs = self.req_to_token_pool.size * self.topk + kv_indptr_buf = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=self.draft_model_runner.device + ) + return TritonAttnBackend( + self.draft_model_runner, + skip_prefill=True, + kv_indptr_buf=kv_indptr_buf, + ) + + def _bind_kv_context(self) -> None: + draft_model = self.draft_model_runner.model + if not hasattr(draft_model, "build_frozen_kv_mtp_context") or not hasattr( + draft_model, "bind_frozen_kv_context" + ): + logger.debug( + "Draft model %s does not implement Frozen-KV MTP context hooks; " + "skipping frozen-kv bind.", + type(draft_model).__name__, + ) + return + + ctx = draft_model.build_frozen_kv_mtp_context( + target_model=self.target_worker.model_runner.model, + target_token_to_kv_pool=self.target_worker.model_runner.token_to_kv_pool, + ) + draft_model.bind_frozen_kv_context(ctx) + self.kv_context = ctx + + def _frozen_kv_target_view(self, forward_batch: ForwardBatch): + return frozen_kv_target_view(forward_batch, self.kv_context) + + def _target_kv_pool_view(self, forward_batch: ForwardBatch): + return target_kv_pool_view(forward_batch, self.kv_context) + + def _set_positions(self, forward_batch: ForwardBatch) -> None: + set_frozen_kv_positions(forward_batch, self.topk) + + def _expand_for_topk_draft(self, forward_batch: ForwardBatch) -> None: + expand_for_topk_draft(forward_batch, self.topk) + + def _position_for_batch(self, batch: ScheduleBatch) -> torch.Tensor: + return position_for_batch(batch) + + @property + def _recurrent_hidden_size(self) -> int: + return int(self.draft_model_runner.model.backbone_hidden_size) + + def _init_frozen_kv_metadata(self, forward_batch: ForwardBatch) -> None: + if forward_batch.forward_mode.is_idle(): + return + if forward_batch.seq_lens_cpu is not None: + forward_batch.seq_lens_sum = forward_batch.seq_lens_cpu.sum().item() + else: + forward_batch.seq_lens_sum = torch.sum(forward_batch.seq_lens).item() + with self._frozen_kv_target_view(forward_batch): + self.draft_attn_backend.init_forward_metadata(forward_batch) + forward_batch.attn_backend = self.draft_attn_backend + + def _init_frozen_kv_metadata_capture_cuda_graph( + self, forward_batch: ForwardBatch + ) -> None: + with self._frozen_kv_target_view(forward_batch): + self.draft_attn_backend.init_forward_metadata_capture_cuda_graph( + forward_batch.batch_size, + forward_batch.positions.numel(), + forward_batch.req_pool_indices, + forward_batch.seq_lens, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=None, + ) + forward_batch.attn_backend = self.draft_attn_backend + + def _init_frozen_kv_metadata_replay_cuda_graph( + self, forward_batch: ForwardBatch, bs: int, seq_lens_sum: int + ) -> None: + with self._frozen_kv_target_view(forward_batch): + self.draft_attn_backend.init_forward_metadata_replay_cuda_graph( + bs, + forward_batch.req_pool_indices[:bs], + forward_batch.seq_lens[:bs], + seq_lens_sum, + encoder_lens=None, + forward_mode=ForwardMode.DECODE, + spec_info=None, + seq_lens_cpu=( + forward_batch.seq_lens_cpu[:bs] + if forward_batch.seq_lens_cpu is not None + else None + ), + ) + forward_batch.attn_backend = self.draft_attn_backend + + def init_cuda_graphs(self) -> None: + if self.server_args.disable_cuda_graph or self.speculative_num_steps <= 1: + return + if self.target_worker.device != "cuda": + logger.info( + "Frozen-KV MTP draft CUDA graph is only supported on CUDA; " + "running the draft loop eagerly on %s.", + self.target_worker.device, + ) + return + + from sglang.srt.speculative.frozen_kv_mtp_cuda_graph_runner import ( + FrozenKVMTPCudaGraphRunner, + ) + + logger.info("Capture Frozen-KV MTP draft cuda graph begin.") + self.cuda_graph_runner = FrozenKVMTPCudaGraphRunner(self) + logger.info("Capture Frozen-KV MTP draft cuda graph end.") + + def _select_last_extend_hidden( + self, batch: ScheduleBatch, hidden_states: torch.Tensor + ) -> torch.Tensor: + return select_last_extend_hidden(batch, hidden_states) + + def _select_last_verified_seed( + self, draft_input: FrozenKVMTPDraftInput + ) -> Tuple[torch.Tensor, torch.Tensor]: + return select_last_verified_seed(draft_input) + + def _capture_for_decode( + self, logits_output: LogitsProcessorOutput, draft_input: FrozenKVMTPDraftInput + ) -> None: + capture_for_decode(logits_output, draft_input, self.topk) + + def _run_assistant_seed_step( + self, + batch: ScheduleBatch, + last_token_ids: torch.Tensor, + last_hidden_states: torch.Tensor, + seq_lens_cpu: Optional[torch.Tensor] = None, + mm_input_embeds: Optional[torch.Tensor] = None, + draft_input: Optional[FrozenKVMTPDraftInput] = None, + ) -> None: + """Run the one-token assistant seed step against frozen target KV.""" + if batch.forward_mode.is_idle() or last_token_ids.numel() == 0: + batch.spec_info = FrozenKVMTPDraftInput.create_idle_input( + device=batch.device, + hidden_size=self._recurrent_hidden_size, + dtype=self.model_config.dtype, + topk=self.topk, + capture_hidden_mode=CaptureHiddenMode.LAST, + ) + return + + if draft_input is None: + draft_input = FrozenKVMTPDraftInput() + + draft_input.verified_id = last_token_ids.to(torch.int64) + draft_input.hidden_states = last_hidden_states + draft_input.capture_hidden_mode = CaptureHiddenMode.LAST + draft_input.num_tokens_per_req = 1 + draft_input.num_tokens_for_logprob_per_req = 1 + draft_input.positions = self._position_for_batch(batch) + + forward_mode_backup = batch.forward_mode + input_ids_backup = batch.input_ids + return_hidden_states_backup = batch.return_hidden_states + return_logprob_backup = batch.return_logprob + spec_info_backup = batch.spec_info + + batch.forward_mode = ForwardMode.DECODE + batch.input_ids = draft_input.verified_id + batch.return_hidden_states = False + batch.return_logprob = False + batch.spec_info = draft_input + + try: + model_worker_batch = batch.get_model_worker_batch( + seq_lens_cpu_cache=seq_lens_cpu + ) + forward_batch = ForwardBatch.init_new( + model_worker_batch, self.draft_model_runner + ) + forward_batch.return_logprob = False + if mm_input_embeds is not None: + forward_batch.mm_input_embeds = mm_input_embeds + self._set_positions(forward_batch) + self._init_frozen_kv_metadata(forward_batch) + with self._target_kv_pool_view(forward_batch): + logits_output = self.draft_model_runner.forward( + forward_batch, skip_attn_backend_init=True + ).logits_output + maybe_detect_nan(logits_output.next_token_logits, "frozen_kv_mtp_seed") + self._capture_for_decode(logits_output, draft_input) + finally: + batch.forward_mode = forward_mode_backup + batch.input_ids = input_ids_backup + batch.return_hidden_states = return_hidden_states_backup + batch.return_logprob = return_logprob_backup + # Keep the seeded draft state; only restore the old object on error paths + # before the assignment above could have happened. + if batch.spec_info is not draft_input: + batch.spec_info = spec_info_backup + + def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResult: + if batch.forward_mode.is_extend() or batch.is_extend_in_batch: + ( + logits_output, + next_token_ids, + seq_lens_cpu, + can_run_cuda_graph, + ) = self.forward_target_extend(batch) + with self.draft_tp_context( + self.draft_model_runner.tp_group + ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + self.forward_draft_extend( + batch, + logits_output.hidden_states, + next_token_ids, + seq_lens_cpu, + logits_output.mm_input_embeds, + ) + return GenerationBatchResult( + logits_output=logits_output, + next_token_ids=next_token_ids, + num_accepted_drafts=0, + can_run_cuda_graph=can_run_cuda_graph, + ) + + set_time_batch(batch.reqs, "set_spec_draft_start_time", trace_only=True) + with self.draft_tp_context( + self.draft_model_runner.tp_group + ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + spec_info = self.draft(batch) + set_time_batch(batch.reqs, "set_spec_draft_end_time", trace_only=True) + set_time_batch(batch.reqs, "set_spec_verify_start_time", trace_only=True) + + logits_output, verify_output, _, can_run_cuda_graph = self.verify( + batch, spec_info + ) + + if get_global_tracing_enabled(): + for idx, req in enumerate(batch.reqs): + accepted = verify_output.num_accepted_drafts_per_req_cpu[idx] + req.time_stats.set_spec_verify_end_time(accepted_tokens=accepted) + + set_time_batch(batch.reqs, "set_spec_draft_extend_start_time", trace_only=True) + with self.draft_tp_context( + self.draft_model_runner.tp_group + ), speculative_moe_backend_context(), speculative_moe_a2a_backend_context(): + if ( + self.server_args.enable_dp_attention + or batch.spec_info.verified_id.numel() + ): + self.forward_draft_extend_after_decode(batch) + set_time_batch(batch.reqs, "set_spec_draft_extend_end_time", trace_only=True) + + return GenerationBatchResult( + logits_output=logits_output, + next_token_ids=verify_output.verified_id, + num_accepted_drafts=sum(verify_output.num_accepted_drafts_per_req_cpu), + num_accepted_drafts_per_req_cpu=verify_output.num_accepted_drafts_per_req_cpu, + can_run_cuda_graph=can_run_cuda_graph, + ) + + def forward_target_extend( + self, batch: ScheduleBatch + ) -> Tuple[LogitsProcessorOutput, torch.Tensor, Optional[torch.Tensor], bool]: + model_worker_batch = batch.get_model_worker_batch() + model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL + batch_result = self.target_worker.forward_batch_generation(model_worker_batch) + return ( + batch_result.logits_output, + batch_result.next_token_ids, + model_worker_batch.seq_lens_cpu, + batch_result.can_run_cuda_graph, + ) + + def forward_draft_extend( + self, + batch: ScheduleBatch, + hidden_states: torch.Tensor, + next_token_ids: torch.Tensor, + seq_lens_cpu: Optional[torch.Tensor], + mm_input_embeds: Optional[torch.Tensor] = None, + ) -> None: + last_hidden = self._select_last_extend_hidden(batch, hidden_states) + self._run_assistant_seed_step( + batch, + next_token_ids, + last_hidden, + seq_lens_cpu=seq_lens_cpu, + mm_input_embeds=mm_input_embeds, + ) + + def forward_draft_extend_after_decode(self, batch: ScheduleBatch) -> None: + assert isinstance(batch.spec_info, FrozenKVMTPDraftInput) + input_is_idle = batch.forward_mode.is_idle() + if not input_is_idle and batch.spec_info.verified_id.numel() == 0: + batch = batch.copy() + batch.prepare_for_idle() + batch.spec_info = FrozenKVMTPDraftInput.create_idle_input( + device=self.device, + hidden_size=self._recurrent_hidden_size, + dtype=self.model_config.dtype, + topk=self.topk, + capture_hidden_mode=CaptureHiddenMode.LAST, + ) + + if batch.forward_mode.is_idle(): + return + + draft_input = batch.spec_info + seq_lens_backup = batch.seq_lens.clone() + seq_lens_cpu_backup = batch.seq_lens_cpu.clone() + req_pool_indices_backup = batch.req_pool_indices + + try: + if draft_input.seq_lens_for_draft_extend is not None: + # Verify may leave finished requests in ScheduleBatch; seed only + # the unfinished requests carried by draft_input. + batch.seq_lens = draft_input.seq_lens_for_draft_extend + batch.seq_lens_cpu = draft_input.seq_lens_for_draft_extend_cpu + batch.req_pool_indices = draft_input.req_pool_indices_for_draft_extend + + last_token_ids, last_hidden = self._select_last_verified_seed(draft_input) + self._run_assistant_seed_step( + batch, + last_token_ids, + last_hidden, + seq_lens_cpu=draft_input.seq_lens_for_draft_extend_cpu, + draft_input=draft_input, + ) + finally: + batch.seq_lens = seq_lens_backup + batch.seq_lens_cpu = seq_lens_cpu_backup + batch.req_pool_indices = req_pool_indices_backup + + def draft(self, batch: ScheduleBatch): + if batch.forward_mode.is_idle(): + return FrozenKVMTPVerifyInput.create_idle_input( + self.topk, + self.speculative_num_steps, + self.speculative_num_draft_tokens, + ) + + batch.maybe_evict_swa() + for req in batch.reqs: + req.decode_batch_idx += 1 + + spec_info = batch.spec_info + assert isinstance(spec_info, FrozenKVMTPDraftInput) + + if batch.sampling_info.penalizer_orchestrator.is_required: + batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( + spec_info.verified_id.to(torch.int64) + ) + + spec_info.capture_hidden_mode = CaptureHiddenMode.LAST + spec_info.num_tokens_per_req = self.topk + spec_info.num_tokens_for_logprob_per_req = self.topk + spec_info.positions = self._position_for_batch(batch) + batch.seq_lens_sum = torch.sum(batch.seq_lens).item() + batch.return_hidden_states = False + + model_worker_batch = batch.get_model_worker_batch() + assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST + forward_batch = ForwardBatch.init_new( + model_worker_batch, self.draft_model_runner + ) + self._set_positions(forward_batch) + self._expand_for_topk_draft(forward_batch) + + can_run_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run( + forward_batch + ) + if can_run_cuda_graph: + parent_list, top_scores_index, draft_tokens = self.cuda_graph_runner.replay( + forward_batch + ) + else: + forward_batch.can_run_dp_cuda_graph = False + parent_list, top_scores_index, draft_tokens = self.draft_forward( + forward_batch + ) + + ( + tree_mask, + position, + retrieve_index, + retrieve_next_token, + retrieve_next_sibling, + draft_tokens, + ) = build_tree_kernel_efficient( + spec_info.verified_id, + parent_list, + top_scores_index, + draft_tokens, + batch.seq_lens, + batch.seq_lens_sum, + self.topk, + self.speculative_num_steps, + self.speculative_num_draft_tokens, + ) + + return FrozenKVMTPVerifyInput( + draft_token=draft_tokens, + custom_mask=tree_mask, + positions=position, + retrieve_index=retrieve_index, + retrieve_next_token=retrieve_next_token, + retrieve_next_sibling=retrieve_next_sibling, + retrieve_cum_len=None, + spec_steps=self.speculative_num_steps, + topk=self.topk, + draft_token_num=self.speculative_num_draft_tokens, + capture_hidden_mode=CaptureHiddenMode.FULL, + seq_lens_sum=batch.seq_lens_sum, + seq_lens_cpu=batch.seq_lens_cpu, + ) + + def draft_forward( + self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False + ): + spec_info = forward_batch.spec_info + assert isinstance(spec_info, FrozenKVMTPDraftInput) + topk_p, topk_index, hidden_states = ( + spec_info.topk_p, + spec_info.topk_index, + spec_info.hidden_states, + ) + maybe_detect_nan(topk_p, "frozen_kv_mtp_draft: initial topk_p") + + score_list: List[torch.Tensor] = [] + token_list: List[torch.Tensor] = [] + parents_list: List[torch.Tensor] = [] + + if not skip_attn_backend_init and self.speculative_num_steps > 1: + self._init_frozen_kv_metadata(forward_batch) + + scores = None + for i in range(self.speculative_num_steps): + input_ids, hidden_states, scores, tree_info = select_top_k_tokens( + i, topk_p, topk_index, hidden_states, scores, self.topk + ) + score_list.append(tree_info[0]) + token_list.append(tree_info[1]) + parents_list.append(tree_info[2]) + + if i == self.speculative_num_steps - 1: + break + + forward_batch.input_ids = input_ids + forward_batch.spec_info.hidden_states = hidden_states + self._set_positions(forward_batch) + + with self._target_kv_pool_view(forward_batch): + logits_output = self.draft_model_runner.forward( + forward_batch, skip_attn_backend_init=True + ).logits_output + + maybe_detect_nan( + logits_output.next_token_logits, f"frozen_kv_mtp_draft step {i}" + ) + probs = torch.softmax(logits_output.next_token_logits, dim=-1) + topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) + maybe_detect_oob( + topk_index, + 0, + logits_output.next_token_logits.shape[-1], + "frozen_kv_mtp_draft: topk_index OOB", + ) + hidden_states = logits_output.hidden_states + + return organize_draft_results( + score_list, token_list, parents_list, self.speculative_num_draft_tokens + ) + + def verify(self, batch: ScheduleBatch, spec_info: FrozenKVMTPVerifyInput): + seq_lens_pre_verify = batch.seq_lens.clone() + spec_info.prepare_for_verify(batch, self.page_size) + spec_info.num_tokens_per_req = self.speculative_num_steps + 1 + batch.return_hidden_states = False + batch.forward_mode = ( + ForwardMode.TARGET_VERIFY + if not batch.forward_mode.is_idle() + else ForwardMode.IDLE + ) + batch.spec_info = spec_info + + model_worker_batch = batch.get_model_worker_batch( + seq_lens_cpu_cache=spec_info.seq_lens_cpu + ) + assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode + + if batch.has_grammar: + retrieve_next_token_cpu = spec_info.retrieve_next_token.cpu() + retrieve_next_sibling_cpu = spec_info.retrieve_next_sibling.cpu() + draft_tokens_cpu = spec_info.draft_token.view( + spec_info.retrieve_next_token.shape + ).cpu() + + batch_result = self.target_worker.forward_batch_generation( + model_worker_batch, is_verify=True + ) + logits_output, can_run_cuda_graph = ( + batch_result.logits_output, + batch_result.can_run_cuda_graph, + ) + + vocab_mask = None + if batch.has_grammar: + vocab_mask = generate_token_bitmask( + batch.reqs, + spec_info, + retrieve_next_token_cpu, + retrieve_next_sibling_cpu, + draft_tokens_cpu, + batch.sampling_info.vocab_size, + ) + if vocab_mask is not None: + assert spec_info.grammar is not None + vocab_mask = vocab_mask.to(spec_info.retrieve_next_token.device) + batch.sampling_info.vocab_mask = None + + maybe_detect_nan(logits_output.next_token_logits, "frozen_kv_mtp_verify") + + spec_info.hidden_states = logits_output.hidden_states + res: FrozenKVMTPVerifyOutput = spec_info.verify( + batch, + logits_output, + self.token_to_kv_pool_allocator, + self.page_size, + vocab_mask, + ) + + logits_output.next_token_logits = logits_output.next_token_logits[ + res.accepted_indices + ] + logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices] + + if ( + self.target_worker.model_runner.hybrid_gdn_config is not None + or self.target_worker.model_runner.mamba2_config is not None + or self.target_worker.model_runner.hybrid_lightning_config is not None + ): + logger.warning( + "Frozen-KV MTP does not implement mamba state updates; " + "targets with recurrent state should not use this path." + ) + + if batch.return_logprob: + add_output_logprobs_for_spec_v1(batch, res, logits_output) + + batch.forward_mode = ( + ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE + ) + batch.spec_info = res.draft_input + + del seq_lens_pre_verify + return logits_output, res, model_worker_batch, can_run_cuda_graph diff --git a/python/sglang/srt/speculative/frozen_kv_mtp_worker_v2.py b/python/sglang/srt/speculative/frozen_kv_mtp_worker_v2.py new file mode 100644 index 000000000000..958d09fb4aa7 --- /dev/null +++ b/python/sglang/srt/speculative/frozen_kv_mtp_worker_v2.py @@ -0,0 +1,42 @@ +# Copyright 2026 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Overlap-scheduling placeholder for frozen-KV MTP (raises until implemented).""" + +from __future__ import annotations + +from typing import Optional + +from sglang.srt.managers.tp_worker import TpModelWorker +from sglang.srt.server_args import ServerArgs +from sglang.srt.speculative.frozen_kv_mtp_worker import FrozenKVMTPWorker + + +class FrozenKVMTPWorkerV2(FrozenKVMTPWorker): + def __init__( + self, + server_args: ServerArgs, + gpu_id: int, + tp_rank: int, + dp_rank: Optional[int], + moe_ep_rank: int, + attn_cp_rank: int, + moe_dp_rank: int, + nccl_port: int, + target_worker: TpModelWorker, + ): + raise NotImplementedError( + "FrozenKVMTPWorkerV2 (overlap scheduling for Frozen-KV MTP) is " + "not yet implemented. Pass --disable-overlap-schedule to use " + "FrozenKVMTPWorker." + ) diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py index 3e5727187572..a0e6c9c24b7c 100644 --- a/python/sglang/srt/speculative/spec_info.py +++ b/python/sglang/srt/speculative/spec_info.py @@ -13,11 +13,12 @@ class SpeculativeAlgorithm(Enum): - """Enumeration of speculative decoding algorithms.""" + """Speculative decoding algorithms.""" DFLASH = auto() EAGLE = auto() EAGLE3 = auto() + FROZEN_KV_MTP = auto() STANDALONE = auto() NGRAM = auto() NONE = auto() @@ -38,12 +39,20 @@ def is_speculative(self) -> bool: return self != SpeculativeAlgorithm.NONE def is_eagle(self) -> bool: - # NOTE: EAGLE3 is a variant of EAGLE - return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.EAGLE3 + # FIXME(kpham_sgl): Remove FROZEN_KV_MTP here once we + # have established support for it in the scheduler. + return self in ( + SpeculativeAlgorithm.EAGLE, + SpeculativeAlgorithm.EAGLE3, + SpeculativeAlgorithm.FROZEN_KV_MTP, + ) def is_eagle3(self) -> bool: return self == SpeculativeAlgorithm.EAGLE3 + def is_frozen_kv_mtp(self) -> bool: + return self == SpeculativeAlgorithm.FROZEN_KV_MTP + def is_dflash(self) -> bool: return self == SpeculativeAlgorithm.DFLASH @@ -54,7 +63,7 @@ def is_ngram(self) -> bool: return self == SpeculativeAlgorithm.NGRAM def supports_spec_v2(self) -> bool: - return self.is_eagle() or self.is_standalone() + return (self.is_eagle() and not self.is_frozen_kv_mtp()) or self.is_standalone() def create_worker( self, server_args: ServerArgs @@ -74,6 +83,19 @@ def create_worker( return DFlashWorker + if self.is_frozen_kv_mtp(): + if enable_overlap: + raise ValueError( + "FROZEN_KV_MTP does not support spec v2. Disable overlap " + "scheduling to use FrozenKVMTPWorker." + ) + + from sglang.srt.speculative.frozen_kv_mtp_worker import ( + FrozenKVMTPWorker, + ) + + return FrozenKVMTPWorker + if self.is_eagle() and server_args.enable_multi_layer_eagle: # FIXME: migrate to EagleWorker if enable_overlap: @@ -127,6 +149,8 @@ class SpecInputType(IntEnum): # If all algorithms can share the same datastrucutre of draft_input and verify_input, consider simplify it EAGLE_DRAFT = auto() EAGLE_VERIFY = auto() + FROZEN_KV_MTP_DRAFT = auto() + FROZEN_KV_MTP_VERIFY = auto() DFLASH_DRAFT = auto() DFLASH_VERIFY = auto() NGRAM_VERIFY = auto() @@ -141,12 +165,14 @@ def is_draft_input(self) -> bool: # or use another variable name like `draft_input` to substitute `spec_info` return self.spec_input_type in { SpecInputType.EAGLE_DRAFT, + SpecInputType.FROZEN_KV_MTP_DRAFT, SpecInputType.DFLASH_DRAFT, } def is_verify_input(self) -> bool: return self.spec_input_type in { SpecInputType.EAGLE_VERIFY, + SpecInputType.FROZEN_KV_MTP_VERIFY, SpecInputType.DFLASH_VERIFY, SpecInputType.NGRAM_VERIFY, } diff --git a/python/sglang/srt/utils/hf_transformers/config.py b/python/sglang/srt/utils/hf_transformers/config.py index 479199669418..448113fa4fd4 100644 --- a/python/sglang/srt/utils/hf_transformers/config.py +++ b/python/sglang/srt/utils/hf_transformers/config.py @@ -141,7 +141,7 @@ def get_config( if config.model_type == "multi_modality": _set_architectures(config, "MultiModalityCausalLM") - if config.model_type == "gemma4": + if config.model_type in ("gemma4", "gemma4_assistant"): # Gemma4 configs use base attributes for SWA layers and `global_*` # variants for full-attention layers. SGLang expects the opposite: # base = full-attention, `swa_*` = sliding-window overrides.