diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 4fe8aec31301..c1e2ea4fcdab 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -596,8 +596,24 @@ def init_forward_metadata_capture_cuda_graph( fast_decode_plan, decode_wrappers[i] ) elif forward_mode.is_target_verify(): + # FlashInfer's prefill wrapper decides mask mode based on whether + # `custom_mask_buf` is initialized (not whether a custom mask is provided). + # For cases like DFLASH draft (ENCODER_ONLY / non-causal) we do NOT use a + # custom mask, so we must avoid initializing `custom_mask_buf`, otherwise + # FlashInfer will treat the (zero) buffer as a real mask and block attention. + use_custom_mask = ( + spec_info is not None + and getattr(spec_info, "custom_mask", None) is not None + ) prefill_wrappers = [] for i in range(self.num_wrappers): + wrapper_kwargs = {} + if use_custom_mask: + wrapper_kwargs = { + "custom_mask_buf": self.cuda_graph_custom_mask, + "mask_indptr_buf": self.cuda_graph_qk_indptr[i][: bs + 1], + } + prefill_wrappers.append( BatchPrefillWithPagedKVCacheWrapper( self.workspace_buffer, @@ -608,8 +624,7 @@ def init_forward_metadata_capture_cuda_graph( paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1], paged_kv_indices_buf=self.cuda_graph_kv_indices[i], paged_kv_last_page_len_buf=self.kv_last_page_len[:bs], - custom_mask_buf=self.cuda_graph_custom_mask, - mask_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1], + **wrapper_kwargs, ) ) seq_lens_sum = seq_lens.sum().item() @@ -783,10 +798,14 @@ def forward_extend( layer, cache_loc, k, v, layer.k_scale, layer.v_scale ) + causal = ( + not layer.is_cross_attention + and layer.attn_type != AttentionType.ENCODER_ONLY + ) o = prefill_wrapper_paged.forward( q.view(-1, layer.tp_q_head_num, layer.head_dim), forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), - causal=not layer.is_cross_attention, + causal=causal, sm_scale=layer.scaling, # Disable sliding window attention for multi-item scoring: # - Sliding window could cut across item boundaries, breaking semantic coherence @@ -838,11 +857,6 @@ def forward_extend( ) else: - if not self.is_dllm_model: - # TODO: design a better interface - # For other models, use causal attention for the ragged part as previously - causal = True - o1, s1 = self.prefill_wrapper_ragged.forward_return_lse( q.view(-1, layer.tp_q_head_num, layer.head_dim), k.view(-1, layer.tp_k_head_num, layer.head_dim), diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 36c55826d821..377a7ec749fb 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -276,6 +276,24 @@ def copy_to_cpu(self): self.copy_done.record() +def validate_dflash_request(req: Req) -> Optional[str]: + if req.return_logprob: + return "DFLASH speculative decoding does not support return_logprob yet." + + if ( + req.sampling_params.json_schema is not None + or req.sampling_params.regex is not None + or req.sampling_params.ebnf is not None + or req.sampling_params.structural_tag is not None + ): + return ( + "DFLASH speculative decoding does not support " + "grammar-constrained decoding yet." 
+ ) + + return None + + class Scheduler( SchedulerOutputProcessorMixin, SchedulerUpdateWeightsMixin, @@ -1861,6 +1879,14 @@ def handle_generate_request( self._add_request_to_queue(req) return + if self.spec_algorithm.is_dflash(): + error_msg = validate_dflash_request(req) + if error_msg is not None: + req.set_finish_with_abort(error_msg) + self.init_req_max_new_tokens(req) + self._add_request_to_queue(req) + return + # Handle multimodal inputs if recv_req.mm_inputs is not None: image_inputs = self._get_multimodal_inputs(recv_req.mm_inputs) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index c7c7d6b5ec0b..69cb176efbdc 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -547,18 +547,15 @@ def __init__(self, model_runner: ModelRunner): self.capture_forward_mode = ForwardMode.DECODE self.capture_hidden_mode = CaptureHiddenMode.NULL self.num_tokens_per_bs = 1 - if ( - model_runner.spec_algorithm.is_eagle() - or model_runner.spec_algorithm.is_standalone() - or model_runner.spec_algorithm.is_ngram() - ): + if model_runner.spec_algorithm.is_speculative(): if self.model_runner.is_draft_worker: - raise RuntimeError("This should not happen") - else: - self.capture_forward_mode = ForwardMode.TARGET_VERIFY - self.num_tokens_per_bs = ( - self.model_runner.server_args.speculative_num_draft_tokens - ) + # DFLASH draft workers reuse this runner for TARGET_VERIFY mode. + if not self.model_runner.spec_algorithm.is_dflash(): + raise RuntimeError("This should not happen") + self.capture_forward_mode = ForwardMode.TARGET_VERIFY + self.num_tokens_per_bs = ( + self.model_runner.server_args.speculative_num_draft_tokens + ) elif self.is_dllm: self.capture_forward_mode = ForwardMode.DLLM_EXTEND self.num_tokens_per_bs = self.dllm_config.block_size @@ -646,6 +643,18 @@ def __init__(self, model_runner: ModelRunner): and model_runner.eagle_use_aux_hidden_state ): self.model_runner.model.set_eagle3_layers_to_capture() + if ( + model_runner.spec_algorithm.is_dflash() + and model_runner.dflash_use_aux_hidden_state + ): + if not hasattr(self.model_runner.model, "set_dflash_layers_to_capture"): + raise ValueError( + f"Model {self.model_runner.model.__class__.__name__} does not implement set_dflash_layers_to_capture, " + "which is required for DFLASH aux hidden capture." 
+ ) + self.model_runner.model.set_dflash_layers_to_capture( + self.model_runner.dflash_target_layer_ids + ) # Capture try: @@ -671,6 +680,7 @@ def can_run(self, forward_batch: ForwardBatch): max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() or self.model_runner.spec_algorithm.is_standalone() + or self.model_runner.spec_algorithm.is_dflash() else max(forward_batch.global_num_tokens_cpu) ) else: @@ -1007,6 +1017,12 @@ def run_once(): kwargs["pp_proxy_tensors"] = PPProxyTensors( {k: v.clone() for k, v in pp_proxy_tensors.tensors.items()} ) + if ( + self.model_runner.spec_algorithm.is_dflash() + and self.model_runner.is_draft_worker + and "input_embeds" in inspect.signature(forward).parameters + ): + kwargs["input_embeds"] = buffers.input_embeds[:num_tokens] logits_output_or_pp_proxy_tensors = forward( input_ids, @@ -1083,6 +1099,7 @@ def replay_prepare( max_num_tokens / self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() or self.model_runner.spec_algorithm.is_standalone() + or self.model_runner.spec_algorithm.is_dflash() else max_num_tokens ) index = bisect.bisect_left(self.capture_bs, max_batch_size) @@ -1104,6 +1121,13 @@ def replay_prepare( ), pp_proxy_tensors=pp_proxy_tensors, ) + if ( + self.model_runner.spec_algorithm.is_dflash() + and self.model_runner.is_draft_worker + and forward_batch.input_embeds is not None + ): + buffers.input_embeds[:raw_num_token].copy_(forward_batch.input_embeds) + # Padded tokens aren't read, so skip zeroing them. if self.enable_two_batch_overlap: self.tbo_plugin.replay_prepare( forward_mode=self.capture_forward_mode, @@ -1152,6 +1176,14 @@ def replay( # In speculative decoding, these two fields are still needed. self.buffers.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids) self.buffers.positions[: self.raw_num_token].copy_(forward_batch.positions) + if ( + self.model_runner.spec_algorithm.is_dflash() + and self.model_runner.is_draft_worker + and forward_batch.input_embeds is not None + ): + self.buffers.input_embeds[: self.raw_num_token].copy_( + forward_batch.input_embeds + ) # Replay if self.enable_pdmux: @@ -1164,10 +1196,18 @@ def replay( if isinstance(output, LogitsProcessorOutput): if self.is_dllm: next_token_logits = None - full_logits = output.full_logits[: self.raw_num_token] + full_logits = ( + output.full_logits[: self.raw_num_token] + if output.full_logits is not None + else None + ) else: full_logits = None - next_token_logits = output.next_token_logits[: self.raw_num_token] + next_token_logits = ( + output.next_token_logits[: self.raw_num_token] + if output.next_token_logits is not None + else None + ) return LogitsProcessorOutput( next_token_logits=next_token_logits, @@ -1209,6 +1249,32 @@ def get_spec_info(self, num_tokens: int): seq_lens_sum=None, seq_lens_cpu=None, ) + elif self.model_runner.spec_algorithm.is_dflash(): + from sglang.srt.speculative.dflash_info import DFlashVerifyInput + from sglang.srt.speculative.dflash_utils import ( + resolve_dflash_verify_mask_policy, + ) + + # Avoid enabling custom-mask modes during graph capture for backends that + # can express DFLASH verify via their built-in causal path. 
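+            # (e.g. the triton backend consumes this allow-mask, while the backends
+            # listed in _DFLASH_VERIFY_SKIP_CUSTOM_MASK_BACKENDS rely on their
+            # built-in causal path instead)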
+ _, build_custom_mask = resolve_dflash_verify_mask_policy( + self.model_runner.attn_backend + ) + spec_info = DFlashVerifyInput( + draft_token=None, + positions=None, + draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, + custom_mask=( + None + if (self.model_runner.is_draft_worker or not build_custom_mask) + else self.buffers.custom_mask + ), + capture_hidden_mode=( + CaptureHiddenMode.NULL + if self.model_runner.is_draft_worker + else CaptureHiddenMode.FULL + ), + ) elif self.model_runner.spec_algorithm.is_ngram(): from sglang.srt.speculative.ngram_info import NgramVerifyInput diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 26d8bd82a2f1..e2dadafe62c3 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -354,6 +354,9 @@ def __init__( self.remote_instance_transfer_engine_weight_info = None # auxiliary hidden capture mode. TODO: expose this to server args? self.eagle_use_aux_hidden_state = False + self.dflash_use_aux_hidden_state = False + self.dflash_target_layer_ids = None + self.dflash_draft_num_layers = None if self.spec_algorithm.is_eagle3() and not self.is_draft_worker: # load draft config draft_model_config = ModelConfig.from_server_args( @@ -379,6 +382,52 @@ def __init__( # if there is no aux layer, set to None self.eagle_aux_hidden_state_layer_ids = None + if self.spec_algorithm.is_dflash() and not self.is_draft_worker: + from sglang.srt.speculative.dflash_utils import ( + parse_dflash_draft_config, + ) + + # Select target layers to capture for building DFlash context features. + draft_model_config = ModelConfig.from_server_args( + server_args, + model_path=(server_args.speculative_draft_model_path), + model_revision=server_args.speculative_draft_model_revision, + is_draft_model=True, + ) + dflash_draft_config = parse_dflash_draft_config( + draft_hf_config=draft_model_config.hf_config + ) + draft_num_layers = dflash_draft_config.require_num_layers() + trained_target_layers = dflash_draft_config.num_target_layers + + target_num_layers = getattr( + self.model_config.hf_text_config, "num_hidden_layers", None + ) + if target_num_layers is None: + raise ValueError( + "DFLASH requires target num_hidden_layers in config. " + f"Got target={target_num_layers}." + ) + target_num_layers = int(target_num_layers) + + if ( + trained_target_layers is not None + and trained_target_layers != target_num_layers + ): + logger.warning( + "DFLASH draft config num_target_layers=%s differs from runtime target num_hidden_layers=%s; " + "selecting capture layers based on the runtime target model.", + trained_target_layers, + target_num_layers, + ) + + self.dflash_use_aux_hidden_state = True + self.dflash_draft_num_layers = int(draft_num_layers) + self.dflash_target_layer_ids = dflash_draft_config.resolve_target_layer_ids( + target_num_layers=int(target_num_layers), + draft_num_layers=int(draft_num_layers), + ) + # Apply the rank zero filter to logger if server_args.show_time_cost: enable_show_time_cost() @@ -670,6 +719,14 @@ def initialize(self, pre_model_load_memory: float): self.eagle_aux_hidden_state_layer_ids ) + if self.dflash_use_aux_hidden_state: + if not hasattr(self.model, "set_dflash_layers_to_capture"): + raise ValueError( + f"Model {self.model.__class__.__name__} does not implement set_dflash_layers_to_capture, " + "which is required for DFLASH." 
+ ) + self.model.set_dflash_layers_to_capture(self.dflash_target_layer_ids) + # Initialize piecewise CUDA graph self.init_piecewise_cuda_graphs() @@ -2100,11 +2157,7 @@ def _should_run_flashinfer_autotune(self) -> bool: if major < 9: return False - if ( - self.spec_algorithm.is_eagle() - or self.spec_algorithm.is_standalone() - or self.spec_algorithm.is_ngram() - ): + if self.spec_algorithm.is_speculative(): return not self.is_draft_worker return True @@ -2134,16 +2187,12 @@ def _dummy_run(self, batch_size: int, run_ctx=None): capture_forward_mode = ForwardMode.EXTEND capture_hidden_mode = CaptureHiddenMode.NULL num_tokens_per_bs = 1 - if ( - self.spec_algorithm.is_eagle() - or self.spec_algorithm.is_standalone() - or self.spec_algorithm.is_ngram() - ): + if self.spec_algorithm.is_speculative(): if self.is_draft_worker: - raise RuntimeError("This should not happen") - else: - capture_forward_mode = ForwardMode.TARGET_VERIFY - num_tokens_per_bs = self.server_args.speculative_num_draft_tokens + if not self.spec_algorithm.is_dflash(): + raise RuntimeError("This should not happen") + capture_forward_mode = ForwardMode.TARGET_VERIFY + num_tokens_per_bs = self.server_args.speculative_num_draft_tokens if self.server_args.enable_return_hidden_states: capture_hidden_mode = CaptureHiddenMode.FULL @@ -2173,6 +2222,8 @@ def _dummy_run(self, batch_size: int, run_ctx=None): if self.eagle_use_aux_hidden_state: self.model.set_eagle3_layers_to_capture() + if self.dflash_use_aux_hidden_state: + self.model.set_dflash_layers_to_capture(self.dflash_target_layer_ids) require_mlp_tp_gather_ = require_mlp_tp_gather(self.server_args) if require_gathered_buffer(self.server_args): @@ -2286,6 +2337,21 @@ def get_spec_info(): seq_lens_sum=None, seq_lens_cpu=None, ) + elif self.spec_algorithm.is_dflash(): + from sglang.srt.speculative.dflash_info import DFlashVerifyInput + + # Dummy warmup only needs shape metadata; avoid forcing custom-mask mode. 
+ spec_info = DFlashVerifyInput( + draft_token=None, + positions=None, + draft_token_num=self.server_args.speculative_num_draft_tokens, + custom_mask=None, + capture_hidden_mode=( + CaptureHiddenMode.NULL + if self.is_draft_worker + else CaptureHiddenMode.FULL + ), + ) elif self.spec_algorithm.is_ngram(): from sglang.srt.speculative.ngram_info import NgramVerifyInput diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py index a6baa4817ace..bca2baca64f9 100644 --- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py +++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py @@ -167,6 +167,22 @@ def profile_max_num_token(self: ModelRunner, pre_model_load_memory: int): num_layers = self.num_effective_layers cell_size = self.get_cell_size_per_token(num_layers) + if self.spec_algorithm.is_dflash() and not self.is_draft_worker: + from sglang.srt.speculative.dflash_utils import ( + scale_kv_cell_size_per_token_for_dflash, + ) + + draft_num_layers = getattr(self, "dflash_draft_num_layers", None) + if ( + draft_num_layers is not None + and int(draft_num_layers) > 0 + and int(num_layers) > 0 + ): + cell_size = scale_kv_cell_size_per_token_for_dflash( + target_cell_size_per_token=cell_size, + target_num_layers=int(num_layers), + draft_num_layers=int(draft_num_layers), + ) rest_memory = post_model_load_memory - pre_model_load_memory * ( 1 - self.mem_fraction_static diff --git a/python/sglang/srt/models/dflash.py b/python/sglang/srt/models/dflash.py new file mode 100644 index 000000000000..27f5cdbf539d --- /dev/null +++ b/python/sglang/srt/models/dflash.py @@ -0,0 +1,399 @@ +# Adapted from the DFlash reference implementation (HF) but implemented with +# SGLang primitives (RadixAttention + SGLang KV cache). This model intentionally +# does not include token embeddings or an LM head; DFlash uses the target model's +# embedding/lm_head. 
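+#
+# At runtime the draft forward receives token embeddings (computed with the
+# target model's embedding table) via `input_embeds`; see `DFlashDraftModel.forward`,
+# which rejects a missing `input_embeds`.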
+ +from __future__ import annotations + +import logging +from typing import Iterable, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.radix_attention import AttentionType, RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.utils import apply_qk_norm +from sglang.srt.speculative.dflash_utils import ( + can_dflash_slice_qkv_weight, + parse_dflash_draft_config, +) + +logger = logging.getLogger(__name__) + + +class DFlashAttention(nn.Module): + def __init__(self, config, layer_id: int) -> None: + super().__init__() + hidden_size = int(config.hidden_size) + tp_size = int(get_tensor_model_parallel_world_size()) + total_num_heads = int(config.num_attention_heads) + total_num_kv_heads = int( + getattr(config, "num_key_value_heads", total_num_heads) + ) + head_dim = int(getattr(config, "head_dim", hidden_size // total_num_heads)) + + self.hidden_size = hidden_size + self.total_num_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + assert self.total_num_heads % tp_size == 0, ( + f"DFlashAttention requires total_num_heads divisible by tp_size. " + f"total_num_heads={self.total_num_heads}, tp_size={tp_size}." + ) + self.num_heads = self.total_num_heads // tp_size + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0, ( + f"DFlashAttention requires total_num_kv_heads divisible by tp_size when >= tp_size. " + f"total_num_kv_heads={self.total_num_kv_heads}, tp_size={tp_size}." + ) + else: + assert tp_size % self.total_num_kv_heads == 0, ( + f"DFlashAttention requires tp_size divisible by total_num_kv_heads when total_num_kv_heads < tp_size. " + f"total_num_kv_heads={self.total_num_kv_heads}, tp_size={tp_size}." + ) + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * head_dim + self.kv_size = self.num_kv_heads * head_dim + + attention_bias = bool(getattr(config, "attention_bias", False)) + rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-6)) + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=attention_bias, + prefix="qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * head_dim, + hidden_size, + bias=attention_bias, + prefix="o_proj", + ) + + # Per-head Q/K RMSNorm, matching HF Qwen3. 
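+        # (each head's head_dim slice is normalized independently, so the norm
+        # weights have shape [head_dim] rather than [hidden_size])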
+ self.q_norm = RMSNorm(head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(head_dim, eps=rms_norm_eps) + + rope_theta = float(getattr(config, "rope_theta", 1000000)) + rope_scaling = getattr(config, "rope_scaling", None) + rope_is_neox_style = bool( + getattr( + config, "rope_is_neox_style", getattr(config, "is_neox_style", True) + ) + ) + max_position_embeddings = int(getattr(config, "max_position_embeddings", 32768)) + self.rotary_emb = get_rope( + head_dim, + rotary_dim=head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=rope_is_neox_style, + ) + + self.scaling = head_dim**-0.5 + # DFlash uses non-causal attention over the draft block. + self.attn = RadixAttention( + num_heads=self.num_heads, + head_dim=head_dim, + scaling=self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + attn_type=AttentionType.ENCODER_ONLY, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = apply_qk_norm(q, k, self.q_norm, self.k_norm, self.head_dim) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, forward_batch) + output, _ = self.o_proj(attn_output) + return output + + def kv_proj_only( + self, hidden_states: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Project hidden_states to K/V only (skip Q). + + This is used by DFlash to materialize ctx tokens into the draft KV cache: + we only need K/V for the cached tokens; Q is never consumed. + """ + # Fast path for unquantized weights: slice the fused QKV weight and run one GEMM. + can_slice_qkv_weight, _ = can_dflash_slice_qkv_weight(self.qkv_proj) + if can_slice_qkv_weight: + kv_slice = slice(self.q_size, self.q_size + 2 * self.kv_size) + weight = self.qkv_proj.weight[kv_slice] + bias = ( + self.qkv_proj.bias[kv_slice] if self.qkv_proj.bias is not None else None + ) + kv = F.linear(hidden_states, weight, bias) + k, v = kv.split([self.kv_size, self.kv_size], dim=-1) + return k, v + + # Fallback: compute full QKV and discard Q (keeps compatibility with quantized weights). + qkv, _ = self.qkv_proj(hidden_states) + _, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + return k, v + + def apply_k_norm(self, k: torch.Tensor) -> torch.Tensor: + k_by_head = k.reshape(-1, self.head_dim) + k_by_head = self.k_norm(k_by_head) + return k_by_head.view_as(k) + + def apply_k_rope(self, positions: torch.Tensor, k: torch.Tensor) -> torch.Tensor: + # Use a minimal dummy query (1 head) to avoid doing full-Q work. + dummy_q = k.new_empty((k.shape[0], self.head_dim)) + _, k = self.rotary_emb(positions, dummy_q, k) + return k + + +class DFlashMLP(nn.Module): + def __init__(self, config, quant_config=None, prefix: str = "") -> None: + super().__init__() + hidden_size = int(config.hidden_size) + intermediate_size = int(getattr(config, "intermediate_size", 0)) + if intermediate_size <= 0: + raise ValueError( + f"Invalid intermediate_size={intermediate_size} for DFlash MLP." 
+ ) + + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix="gate_up_proj" if not prefix else f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix="down_proj" if not prefix else f"{prefix}.down_proj", + ) + hidden_act = getattr(config, "hidden_act", "silu") + if hidden_act != "silu": + raise ValueError( + f"Unsupported DFlash activation: {hidden_act}. Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class DFlashDecoderLayer(nn.Module): + def __init__(self, config, layer_id: int) -> None: + super().__init__() + hidden_size = int(config.hidden_size) + rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-6)) + + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.self_attn = DFlashAttention(config=config, layer_id=layer_id) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.mlp = DFlashMLP(config=config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + if hidden_states.numel() == 0: + # Keep return types consistent for upstream callers. + if residual is None: + residual = hidden_states + return hidden_states, residual + + # Pre-norm attention with fused residual+norm when possible (Qwen3-style). + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + attn_out = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + hidden_states, residual = self.post_attention_layernorm(attn_out, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class DFlashDraftModel(nn.Module): + """SGLang DFlash draft model (no embedding / lm_head weights). + + The checkpoint provides: + - transformer weights for `layers.*` + - `fc.weight`, `hidden_norm.weight` for projecting target context features + - `norm.weight` for final normalization + """ + + def __init__(self, config, quant_config=None, prefix: str = "") -> None: + super().__init__() + self.config = config + + hidden_size = int(config.hidden_size) + num_layers = int(config.num_hidden_layers) + rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-6)) + + self.layers = nn.ModuleList( + [DFlashDecoderLayer(config=config, layer_id=i) for i in range(num_layers)] + ) + self.norm = RMSNorm(hidden_size, eps=rms_norm_eps) + + # Project per-token target context features: + # concat(K * hidden_size) -> hidden_size, where K is the number of target-layer + # feature tensors concatenated per token (not necessarily equal to num_layers). 
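+        # Illustrative sizing (assumed numbers): with K=4 captured layers and
+        # hidden_size=2048, fc is an 8192 -> 2048 projection and `target_hidden`
+        # arrives flattened as [N, 8192].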
+ draft_config = parse_dflash_draft_config(draft_hf_config=config) + target_num_layers = ( + int(draft_config.num_target_layers) + if draft_config.num_target_layers is not None + else num_layers + ) + target_layer_ids = draft_config.resolve_target_layer_ids( + target_num_layers=target_num_layers, draft_num_layers=num_layers + ) + num_context_features = len(target_layer_ids) + + self.num_context_features = int(num_context_features) + self.fc = nn.Linear( + self.num_context_features * hidden_size, hidden_size, bias=False + ) + self.hidden_norm = RMSNorm(hidden_size, eps=rms_norm_eps) + + self.block_size = draft_config.resolve_block_size(default=16) + + def project_target_hidden(self, target_hidden: torch.Tensor) -> torch.Tensor: + """Project concatenated target-layer hidden states into draft hidden_size.""" + expected = int(self.fc.in_features) + if target_hidden.ndim != 2 or int(target_hidden.shape[-1]) != expected: + raise ValueError( + "DFLASH target_hidden feature dim mismatch. " + f"Expected shape [N, {expected}] " + f"(num_context_features={self.num_context_features}, hidden_size={int(self.config.hidden_size)}), " + f"but got shape={tuple(target_hidden.shape)}. " + "This usually means the target model is capturing a different number of layer features than " + "the draft checkpoint/config expects." + ) + return self.hidden_norm(self.fc(target_hidden)) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: Optional[torch.Tensor] = None, + get_embedding: bool = False, + pp_proxy_tensors=None, + ) -> LogitsProcessorOutput: + if input_embeds is None: + raise ValueError( + "DFlashDraftModel requires `input_embeds` (use the target embedding)." + ) + hidden_states = input_embeds + residual: Optional[torch.Tensor] = None + + for layer in self.layers: + hidden_states, residual = layer( + positions, hidden_states, forward_batch, residual + ) + + if hidden_states.numel() != 0: + if residual is None: + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) + + return LogitsProcessorOutput( + next_token_logits=None, + hidden_states=hidden_states, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, weight_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + + def resolve_param_name(name: str) -> Optional[str]: + if name in params_dict: + return name + if name.startswith("model."): + stripped_name = name[len("model.") :] + if stripped_name in params_dict: + return stripped_name + else: + prefixed_name = f"model.{name}" + if prefixed_name in params_dict: + return prefixed_name + return None + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if f".{weight_name}." not in name: + continue + mapped_name = name.replace(weight_name, param_name) + resolved_name = resolve_param_name(mapped_name) + if resolved_name is None: + continue + param = params_dict[resolved_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight, shard_id) + break + else: + resolved_name = resolve_param_name(name) + if resolved_name is None: + # Ignore unexpected weights (e.g., HF rotary caches). 
+ continue + param = params_dict[resolved_name] + if resolved_name.endswith("fc.weight") and tuple( + loaded_weight.shape + ) != tuple(param.shape): + raise ValueError( + "DFLASH fc.weight shape mismatch. This usually means the draft checkpoint's " + "number of context features (K) does not match this config. " + f"Expected fc.weight.shape={tuple(param.shape)} " + f"(num_context_features={self.num_context_features}, hidden_size={int(self.config.hidden_size)}), " + f"but got {tuple(loaded_weight.shape)} for weight '{name}'." + ) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +EntryClass = DFlashDraftModel diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index f955ac750d34..b8ad74015c6e 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -794,6 +794,18 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None): # of the (i-1)th layer as aux hidden state self.model.layers_to_capture = [val + 1 for val in layer_ids] + def set_dflash_layers_to_capture(self, layer_ids: List[int]): + if not self.pp_group.is_last_rank: + return + + if layer_ids is None: + raise ValueError( + "DFLASH requires explicit layer_ids for aux hidden capture." + ) + + self.capture_aux_hidden_states = True + self.model.layers_to_capture = [val + 1 for val in layer_ids] + class Phi3ForCausalLM(LlamaForCausalLM): pass diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 86de1b32713c..9ec274d2a75c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -499,6 +499,8 @@ class ServerArgs: speculative_num_steps: Optional[int] = None speculative_eagle_topk: Optional[int] = None speculative_num_draft_tokens: Optional[int] = None + speculative_dflash_block_size: Optional[int] = None + speculative_dflash_draft_window_size: Optional[int] = None speculative_accept_threshold_single: float = 1.0 speculative_accept_threshold_acc: float = 1.0 speculative_token_map: Optional[str] = None @@ -3027,6 +3029,134 @@ def _handle_speculative_decoding(self): if self.speculative_algorithm == "NEXTN": self.speculative_algorithm = "EAGLE" + if self.speculative_algorithm == "DFLASH": + if self.enable_dp_attention: + raise ValueError( + "Currently DFLASH speculative decoding does not support dp attention." + ) + + if self.pp_size != 1: + raise ValueError( + "Currently DFLASH speculative decoding only supports pp_size == 1." + ) + + if self.speculative_draft_model_path is None: + raise ValueError( + "DFLASH speculative decoding requires setting --speculative-draft-model-path." + ) + + # DFLASH does not use EAGLE-style `num_steps`/`topk`, but those fields still + # affect generic scheduler/KV-cache accounting (buffer sizing, KV freeing, + # RoPE reservation). Force them to 1 to avoid surprising memory behavior. + # + # For DFlash, the natural unit is `block_size` (verify window length). 
+ if self.speculative_num_steps is None: + self.speculative_num_steps = 1 + elif int(self.speculative_num_steps) != 1: + logger.warning( + "DFLASH only supports speculative_num_steps == 1; overriding speculative_num_steps=%s to 1.", + self.speculative_num_steps, + ) + self.speculative_num_steps = 1 + + if self.speculative_eagle_topk is None: + self.speculative_eagle_topk = 1 + elif int(self.speculative_eagle_topk) != 1: + logger.warning( + "DFLASH only supports speculative_eagle_topk == 1; overriding speculative_eagle_topk=%s to 1.", + self.speculative_eagle_topk, + ) + self.speculative_eagle_topk = 1 + + if self.speculative_dflash_block_size is not None: + if int(self.speculative_dflash_block_size) <= 0: + raise ValueError( + "DFLASH requires --speculative-dflash-block-size to be positive, " + f"got {self.speculative_dflash_block_size}." + ) + if self.speculative_num_draft_tokens is not None and int( + self.speculative_num_draft_tokens + ) != int(self.speculative_dflash_block_size): + raise ValueError( + "Both --speculative-num-draft-tokens and --speculative-dflash-block-size are set " + "but they differ. For DFLASH they must match. " + f"speculative_num_draft_tokens={self.speculative_num_draft_tokens}, " + f"speculative_dflash_block_size={self.speculative_dflash_block_size}." + ) + self.speculative_num_draft_tokens = int( + self.speculative_dflash_block_size + ) + + window_size = None + if self.speculative_dflash_draft_window_size is not None: + window_size = int(self.speculative_dflash_draft_window_size) + if window_size <= 0: + raise ValueError( + "DFLASH requires --speculative-dflash-draft-window-size " + f"to be positive, got {window_size}." + ) + self.speculative_dflash_draft_window_size = window_size + + if self.speculative_num_draft_tokens is None: + from sglang.srt.speculative.dflash_utils import ( + parse_dflash_draft_config, + ) + + model_override_args = json.loads(self.json_model_override_args) + inferred_block_size = None + try: + from sglang.srt.utils.hf_transformers_utils import get_config + + draft_hf_config = get_config( + self.speculative_draft_model_path, + trust_remote_code=self.trust_remote_code, + revision=self.speculative_draft_model_revision, + model_override_args=model_override_args, + ) + inferred_block_size = parse_dflash_draft_config( + draft_hf_config=draft_hf_config + ).resolve_block_size(default=None) + except Exception as e: + logger.warning( + "Failed to infer DFLASH block_size from draft model config; " + "defaulting speculative_num_draft_tokens to 16. Error: %s", + e, + ) + + if inferred_block_size is None: + inferred_block_size = 16 + logger.warning( + "speculative_num_draft_tokens is not set; defaulting to %d for DFLASH.", + inferred_block_size, + ) + self.speculative_num_draft_tokens = inferred_block_size + + if window_size is not None: + draft_tokens = int(self.speculative_num_draft_tokens) + if window_size < draft_tokens: + raise ValueError( + "DFLASH --speculative-dflash-draft-window-size must be >= " + "--speculative-num-draft-tokens (block_size). " + f"window_size={window_size}, block_size={draft_tokens}." + ) + + if self.max_running_requests is None: + self.max_running_requests = 48 + logger.warning( + "Max running requests is reset to 48 for speculative decoding. You can override this by explicitly setting --max-running-requests." + ) + + self.disable_overlap_schedule = True + logger.warning( + "Overlap scheduler is disabled when using DFLASH speculative decoding (spec v2 is not supported yet)." 
+ ) + + if self.enable_mixed_chunk: + self.enable_mixed_chunk = False + logger.warning( + "Mixed chunked prefill is disabled because of using dflash speculative decoding." + ) + if self.speculative_algorithm in ("EAGLE", "EAGLE3", "STANDALONE"): if self.speculative_algorithm == "STANDALONE" and self.enable_dp_attention: # TODO: support dp attention for standalone speculative decoding @@ -4832,7 +4962,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--speculative-algorithm", type=str, - choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "NGRAM"], + choices=["DFLASH", "EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "NGRAM"], help="Speculative algorithm.", ) parser.add_argument( @@ -4876,6 +5006,21 @@ def add_cli_args(parser: argparse.ArgumentParser): help="The number of tokens sampled from the draft model in Speculative Decoding.", default=ServerArgs.speculative_num_draft_tokens, ) + parser.add_argument( + "--speculative-dflash-block-size", + type=int, + help="DFLASH only. Block size (verify window length). Alias of --speculative-num-draft-tokens for DFLASH.", + default=ServerArgs.speculative_dflash_block_size, + ) + parser.add_argument( + "--speculative-dflash-draft-window-size", + type=int, + help="DFLASH only. Sliding window size for the draft-model KV cache. " + "When set, the draft worker keeps a recent target-token window in its " + "local cache (paged backends may retain up to one extra page on the left " + "for alignment). Default is full context.", + default=ServerArgs.speculative_dflash_draft_window_size, + ) parser.add_argument( "--speculative-accept-threshold-single", type=float, diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py new file mode 100644 index 000000000000..fbb06cc70ee1 --- /dev/null +++ b/python/sglang/srt/speculative/dflash_info.py @@ -0,0 +1,501 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Tuple + +import torch + +from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.sampler import apply_custom_logit_processor +from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.mem_cache.common import ( + alloc_paged_token_slots_extend, + alloc_token_slots, + get_last_loc, +) +from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode +from sglang.srt.speculative.dflash_utils import ( + compute_dflash_accept_len_and_bonus, + compute_dflash_sampling_accept_len_and_bonus, + is_dflash_sampling_verify_available, +) +from sglang.srt.speculative.spec_info import SpecInput, SpecInputType +from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func + + +def _compute_paged_keep_slots( + *, + prefix_lens: torch.Tensor, + commit_lens: torch.Tensor, + draft_token_num: int, + page_size: int, +) -> torch.Tensor: + """Compute how many draft slots per request must remain allocated. + + The allocator frees at page granularity for paged mode, so we can only release + full pages from the tail after verify. 
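+
+    Illustrative example (assumed numbers): with page_size=4, prefix_len=10,
+    draft_token_num=8, commit_len=3, the committed length 13 aligns up to 16
+    while the extended length is 18, so keep_slots = 16 - 10 = 6 and only the
+    last two draft slots (a fully uncommitted tail page) are freed.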
+ """ + + if page_size <= 1: + raise ValueError(f"Expected page_size > 1, got {page_size}.") + + seq_dtype = prefix_lens.dtype + extended_lens = prefix_lens + int(draft_token_num) + new_lens = prefix_lens + commit_lens.to(seq_dtype) + aligned_new_lens = ((new_lens + page_size - 1) // page_size) * page_size + keep_lens = torch.minimum(aligned_new_lens, extended_lens) + keep_slots = (keep_lens - prefix_lens).to(torch.int64) + keep_slots.clamp_(min=0, max=int(draft_token_num)) + return keep_slots + + +@dataclass +class DFlashDraftInput(SpecInput): + """Per-batch DFlash draft state for spec-v1 (non-overlap) scheduling. + + This object is stored on `ScheduleBatch.spec_info` between decode iterations. + It is NOT sent to model attention backends; the DFlash worker uses it to run + the draft model and to track draft-side cache progress. + + When draft windowing is disabled, `draft_seq_lens` matches the committed target + prefix length already materialized in the draft KV cache. When windowing is + enabled, `draft_seq_lens` is the logical resident length in the draft worker's + compact req-to-token mapping. In paged mode this may exceed the requested + window by up to `page_size - 1` so the local page table remains valid. `ctx_lens` + tracks newly committed target tokens that still need draft KV materialization. + """ + + # Current token to start the next DFlash block (one per request). + verified_id: torch.Tensor + + # Flattened context features for tokens that need to be appended into the draft cache. + # Shape: [sum(ctx_lens), K * hidden_size], where K is the number of target-layer + # hidden-state features concatenated per token (len(dflash_config.target_layer_ids), + # or default K == draft_num_layers for existing checkpoints). + target_hidden: torch.Tensor + + # Context lengths per request, used to slice `target_hidden`. Device tensor (int32). + ctx_lens: torch.Tensor + + # How many committed tokens are visible to the draft worker per request. + draft_seq_lens: torch.Tensor + + def __post_init__(self): + super().__init__(spec_input_type=SpecInputType.DFLASH_DRAFT) + + def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]: + # Draft state does not change token accounting. + return (1, 1) + + def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True): + old_ctx_lens = self.ctx_lens + old_target_hidden = self.target_hidden + + self.verified_id = self.verified_id[new_indices] + self.ctx_lens = old_ctx_lens[new_indices] + self.draft_seq_lens = self.draft_seq_lens[new_indices] + + if old_target_hidden is None or old_target_hidden.numel() == 0: + self.target_hidden = old_target_hidden + return + + # Rebuild target_hidden for the filtered batch using vectorized indexing. 
+ old_bs = int(old_ctx_lens.shape[0]) + offsets = torch.zeros( + (old_bs + 1,), dtype=torch.int64, device=old_ctx_lens.device + ) + offsets[1:].copy_(old_ctx_lens.to(torch.int64).cumsum(0)) + + start = offsets[:-1] + seg_start = start[new_indices] + seg_lens = old_ctx_lens[new_indices].to(torch.int64) + + max_len = int(seg_lens.max().item()) if seg_lens.numel() > 0 else 0 + if max_len <= 0: + self.target_hidden = old_target_hidden[:0] + return + + r = torch.arange(max_len, device=old_ctx_lens.device, dtype=torch.int64)[ + None, : + ] + pos2d = seg_start[:, None] + r + mask = r < seg_lens[:, None] + flat_pos = pos2d[mask] + self.target_hidden = ( + old_target_hidden.index_select(0, flat_pos) + if flat_pos.numel() > 0 + else old_target_hidden[:0] + ) + + def merge_batch(self, spec_info: "DFlashDraftInput"): + self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], dim=0) + self.ctx_lens = torch.cat([self.ctx_lens, spec_info.ctx_lens], dim=0) + self.draft_seq_lens = torch.cat( + [self.draft_seq_lens, spec_info.draft_seq_lens], dim=0 + ) + if self.target_hidden is None or self.target_hidden.numel() == 0: + self.target_hidden = spec_info.target_hidden + elif ( + spec_info.target_hidden is not None and spec_info.target_hidden.numel() > 0 + ): + self.target_hidden = torch.cat( + [self.target_hidden, spec_info.target_hidden], dim=0 + ) + + +@dataclass +class DFlashVerifyInput(SpecInput): + """Inputs for a target-model verify forward in DFlash (spec-v1). + + The verify forward is run with `ForwardMode.TARGET_VERIFY` so that the target + model returns logits for all tokens in the block, enabling accept-length + computation. + """ + + draft_token: torch.Tensor + positions: torch.Tensor + draft_token_num: int + # Kept for compatibility with attention backends that gate tree metadata by `topk > 1`. + # DFLASH verify is linear (non-tree), so this is always 1. + topk: int = 1 + # Custom attention "allow mask" for TARGET_VERIFY in backends that require it (e.g. triton). + # Semantics follow SGLang speculative conventions: True means the (q, k) pair is allowed. + custom_mask: torch.Tensor | None = None + capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL + + # Shape info for padding (e.g., DP attention / CUDA graph). 
+ num_tokens_per_batch: int = -1 + + def __post_init__(self): + super().__init__(spec_input_type=SpecInputType.DFLASH_VERIFY) + if self.num_tokens_per_batch == -1: + self.num_tokens_per_batch = int(self.draft_token_num) + + def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]: + return self.draft_token_num, self.draft_token_num + + def prepare_for_verify( + self, + batch: ScheduleBatch, + page_size: int, + *, + build_custom_mask: bool = True, + ): + if batch.forward_mode.is_idle(): + return + + batch.input_ids = self.draft_token + + if page_size == 1: + batch.out_cache_loc = alloc_token_slots( + batch.tree_cache, len(batch.input_ids) + ) + end_offset = batch.seq_lens + self.draft_token_num + else: + prefix_lens = batch.seq_lens + prefix_lens_cpu = batch.seq_lens_cpu + end_offset = prefix_lens + self.draft_token_num + end_offset_cpu = prefix_lens_cpu + self.draft_token_num + last_loc = get_last_loc( + batch.req_to_token_pool.req_to_token, + batch.req_pool_indices, + prefix_lens, + ) + batch.out_cache_loc = alloc_paged_token_slots_extend( + batch.tree_cache, + prefix_lens, + prefix_lens_cpu, + end_offset, + end_offset_cpu, + last_loc, + len(batch.input_ids), + ) + self.last_loc = last_loc + + bs = batch.batch_size() + assign_req_to_token_pool_func( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + end_offset, + batch.out_cache_loc, + bs, + ) + + if not build_custom_mask: + self.custom_mask = None + return + + if self.draft_token_num <= 0: + raise ValueError( + f"DFLASH draft_token_num must be positive, got {self.draft_token_num}." + ) + mask_chunks: List[torch.Tensor] = [] + q_len = int(self.draft_token_num) + q_idx = torch.arange(q_len, device=batch.device, dtype=torch.int32).unsqueeze(1) + for prefix_len in batch.seq_lens_cpu.tolist(): + prefix_len_i = int(prefix_len) + kv_len = prefix_len_i + q_len + k_idx = torch.arange( + kv_len, device=batch.device, dtype=torch.int32 + ).unsqueeze(0) + # Allow attending to the full prefix and to tokens up to (and including) the + # current query position within the verify block (standard causal masking). 
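+            # e.g. (assumed sizes) prefix_len=2, draft_token_num=3: the three query
+            # rows may attend to kv ranges [0..2], [0..3], and [0..4] respectively.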
+ allow = k_idx <= (prefix_len_i + q_idx) + mask_chunks.append(allow.flatten()) + self.custom_mask = ( + torch.cat(mask_chunks, dim=0) + if mask_chunks + else torch.empty((0,), dtype=torch.bool, device=batch.device) + ) + + def generate_attn_arg_prefill( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + paged_kernel_lens_sum: int, + req_to_token: torch.Tensor, + ): + device = req_pool_indices.device + bs = len(req_pool_indices) + + qo_indptr = torch.arange( + 0, + (bs + 1) * self.draft_token_num, + step=self.draft_token_num, + dtype=torch.int32, + device=device, + ) + + cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device=device) + paged_kernel_lens = paged_kernel_lens + self.draft_token_num + cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) + + kv_indices = torch.empty( + paged_kernel_lens_sum + self.draft_token_num * bs, + dtype=torch.int32, + device=device, + ) + create_flashinfer_kv_indices_triton[(bs,)]( + req_to_token, + req_pool_indices, + paged_kernel_lens, + cum_kv_seq_len, + None, + kv_indices, + req_to_token.size(1), + ) + mask = self.custom_mask + if mask is not None: + mask_numel = ( + paged_kernel_lens_sum * self.draft_token_num + + (self.draft_token_num**2) * bs + ) + if mask.numel() < mask_numel: + # FIXME(attn): temporary fix for custom mask padding with cuda graph + mask = torch.cat( + [ + mask, + torch.full( + (mask_numel - mask.numel(),), + True, + dtype=torch.bool, + device=device, + ), + ], + dim=0, + ) + self.custom_mask = mask + return kv_indices, cum_kv_seq_len, qo_indptr, mask + + def verify( + self, + *, + batch: ScheduleBatch, + logits_output: LogitsProcessorOutput, + page_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: + """DFlash verification for greedy and non-greedy sampling. + + Returns: + new_verified_id: int64 tensor [bs] (the new current token per request) + commit_lens: int32 tensor [bs] (how many verify-input tokens are committed) + next_target_hidden: tensor [sum(commit_lens), feature_dim] + accept_length_per_req_cpu: list[int] (accepted draft tokens per request) + """ + if batch.forward_mode.is_idle(): + empty = torch.empty((0,), dtype=torch.int64, device=batch.device) + return empty, empty.to(torch.int32), empty, [] + + bs = batch.batch_size() + device = logits_output.next_token_logits.device + + sampling_info = batch.sampling_info + if sampling_info is not None: + if len(sampling_info) != bs: + raise RuntimeError( + "DFLASH verify sampling_info size mismatch: " + f"len(sampling_info)={len(sampling_info)}, bs={bs}." + ) + + # Keep speculative verify semantics consistent with normal sampling path. 
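+            # Penalties/logit bias are computed once per request and then repeated
+            # across all draft_token_num logits rows, matching what the normal
+            # sampler would apply one token at a time.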
+ if sampling_info.has_custom_logit_processor: + apply_custom_logit_processor( + logits_output.next_token_logits, + sampling_info, + num_tokens_in_batch=self.draft_token_num, + ) + + if ( + sampling_info.penalizer_orchestrator.is_required + or sampling_info.logit_bias is not None + ): + linear_penalty = torch.zeros( + (bs, logits_output.next_token_logits.shape[1]), + dtype=torch.float32, + device=device, + ) + sampling_info.apply_logits_bias(linear_penalty) + logits_output.next_token_logits.add_( + torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0) + ) + + candidates = self.draft_token.view(bs, self.draft_token_num) + if ( + sampling_info is not None + and not sampling_info.is_all_greedy + and is_dflash_sampling_verify_available() + ): + accept_len, bonus = compute_dflash_sampling_accept_len_and_bonus( + candidates=candidates, + next_token_logits=logits_output.next_token_logits, + sampling_info=sampling_info, + ) + else: + target_predict = torch.argmax(logits_output.next_token_logits, dim=-1).view( + bs, self.draft_token_num + ) + accept_len, bonus = compute_dflash_accept_len_and_bonus( + candidates=candidates, + target_predict=target_predict, + ) + + # Single D2H transfer: candidates[1:] + accept_len + bonus + packed = torch.cat( + [candidates[:, 1:], accept_len.unsqueeze(1), bonus.unsqueeze(1)], dim=1 + ).cpu() + + max_acc = self.draft_token_num - 1 + accept_length_per_req_cpu: List[int] = [] + commit_lens_cpu: List[int] = [] + new_verified_list: List[int] = [] + + for i, req in enumerate(batch.reqs): + acc_len = int(packed[i, max_acc].item()) + proposed = packed[i, :acc_len].tolist() + [ + int(packed[i, max_acc + 1].item()) + ] + + appended = 0 + for token_id in proposed: + token_id = int(token_id) + req.output_ids.append(token_id) + appended += 1 + req.check_finished() + if req.finished(): + break + if req.grammar is not None: + req.grammar.accept_token(token_id) + + if req.output_ids: + new_verified_token = int(req.output_ids[-1]) + elif req.origin_input_ids: + # If no token was appended in this verify step, keep the current token unchanged. + new_verified_token = int(req.origin_input_ids[-1]) + else: + raise RuntimeError( + "DFLASH verify cannot determine current token: both output_ids and origin_input_ids are empty." + ) + + commit_lens_cpu.append(appended) + new_verified_list.append(new_verified_token) + accept_length_per_req_cpu.append(max(0, appended - 1)) + req.spec_verify_ct += 1 + req.spec_accepted_tokens += accept_length_per_req_cpu[-1] + + commit_lens = torch.tensor(commit_lens_cpu, dtype=torch.int32, device=device) + new_verified_id = torch.tensor( + new_verified_list, dtype=torch.int64, device=device + ) + + # Free uncommitted KV cache slots and compact out_cache_loc. 
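+        # e.g. (assumed) draft_token_num=8, commit_len=3 with page_size == 1:
+        # slots 3..7 are returned to the allocator and out_cache_loc is compacted
+        # to the three committed slots.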
+ if page_size == 1: + out_cache_loc = batch.out_cache_loc.view(bs, self.draft_token_num) + keep_mask = ( + torch.arange(self.draft_token_num, device=device)[None, :] + < commit_lens[:, None] + ) + batch.token_to_kv_pool_allocator.free(out_cache_loc[~keep_mask]) + batch.out_cache_loc = out_cache_loc[keep_mask] + else: + out_cache_loc = batch.out_cache_loc.view(bs, self.draft_token_num) + row_offsets = torch.arange(self.draft_token_num, device=device)[None, :] + keep_slots = _compute_paged_keep_slots( + prefix_lens=batch.seq_lens, + commit_lens=commit_lens, + draft_token_num=self.draft_token_num, + page_size=page_size, + ) + free_mask = row_offsets >= keep_slots[:, None] + batch.token_to_kv_pool_allocator.free(out_cache_loc[free_mask]) + + keep_mask = row_offsets < commit_lens[:, None] + batch.out_cache_loc = out_cache_loc[keep_mask] + + # Update req-level KV cache accounting. + for req, commit_len in zip(batch.reqs, commit_lens_cpu, strict=True): + req.kv_committed_len += commit_len + req.kv_allocated_len = req.kv_committed_len + + # Update req_to_token pool mapping for newly committed tokens. + end_offset = batch.seq_lens + commit_lens.to(batch.seq_lens.dtype) + assign_req_to_token_pool_func( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + end_offset, + batch.out_cache_loc, + bs, + ) + + # Update batch seq lens. + batch.seq_lens.add_(commit_lens.to(batch.seq_lens.dtype)) + batch.seq_lens_cpu.add_( + torch.tensor(commit_lens_cpu, dtype=batch.seq_lens_cpu.dtype) + ) + # Keep seq_lens_sum in sync; flashinfer indices updaters rely on this for buffer sizing. + batch.seq_lens_sum += sum(commit_lens_cpu) + + # Build next-step context features from the committed verify-input tokens. + hidden = logits_output.hidden_states + if hidden is None: + raise RuntimeError( + "DFLASH verify requires target hidden states, but got None." + ) + hidden = hidden.view(bs, self.draft_token_num, -1) + segments: List[torch.Tensor] = [] + for i, ln in enumerate(commit_lens_cpu): + if ln > 0: + segments.append(hidden[i, :ln, :]) + next_target_hidden = torch.cat(segments, dim=0) if segments else hidden[:0] + + # Avoid confusing downstream consumers (spec-v1 decode doesn't use this). 
+ logits_output.hidden_states = None + + return ( + new_verified_id, + commit_lens, + next_target_hidden, + accept_length_per_req_cpu, + ) diff --git a/python/sglang/srt/speculative/dflash_utils.py b/python/sglang/srt/speculative/dflash_utils.py new file mode 100644 index 000000000000..ddec049e0a24 --- /dev/null +++ b/python/sglang/srt/speculative/dflash_utils.py @@ -0,0 +1,637 @@ +from __future__ import annotations + +from dataclasses import dataclass +from numbers import Integral +from typing import Any, List, Optional, Tuple + +import torch +import torch.nn.functional as F + +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod +from sglang.srt.utils import is_cuda + +DEFAULT_DFLASH_MASK_TOKEN = "<|MASK|>" + +_DFLASH_SAMPLING_VERIFY_AVAILABLE = False +_DFLASH_CHAIN_VERIFY_BUFFERS: dict[tuple[Optional[int], int], dict[str, Any]] = {} +_DFLASH_VERIFY_SKIP_CUSTOM_MASK_BACKENDS = frozenset( + { + "FlashInferAttnBackend", + "FlashInferMLAAttnBackend", + "FlashAttentionBackend", + "TRTLLMHAAttnBackend", + "TRTLLMMLABackend", + } +) + + +if is_cuda(): + try: + from sgl_kernel import ( + top_k_renorm_prob, + top_p_renorm_prob, + tree_speculative_sampling_target_only, + ) + + _DFLASH_SAMPLING_VERIFY_AVAILABLE = True + except Exception: + top_k_renorm_prob = None + top_p_renorm_prob = None + tree_speculative_sampling_target_only = None +else: + top_k_renorm_prob = None + top_p_renorm_prob = None + tree_speculative_sampling_target_only = None + + +def is_dflash_sampling_verify_available() -> bool: + return _DFLASH_SAMPLING_VERIFY_AVAILABLE + + +def scale_kv_cell_size_per_token_for_dflash( + *, + target_cell_size_per_token: int, + target_num_layers: int, + draft_num_layers: int, + draft_cell_size_per_token: Optional[int] = None, +) -> int: + """Compute bytes/token budget for combined target+draft KV pools (DFLASH). + + DFLASH runs a separate draft runner with its own KV pool. The target runner's + token capacity must fit both pools in aggregate. + + Returns: + Approximate per-token bytes for (target KV + draft KV), expressed as a + scaled version of `target_cell_size_per_token`, unless an explicit + `draft_cell_size_per_token` is provided (in which case we sum them). + """ + if target_cell_size_per_token <= 0: + raise ValueError( + "target_cell_size_per_token must be positive, " + f"got {target_cell_size_per_token}." + ) + + if draft_cell_size_per_token is not None: + draft_cell_size_per_token = int(draft_cell_size_per_token) + if draft_cell_size_per_token <= 0: + raise ValueError( + "draft_cell_size_per_token must be positive when provided, " + f"got {draft_cell_size_per_token}." 
+ ) + return int(target_cell_size_per_token) + int(draft_cell_size_per_token) + + if target_num_layers <= 0 or draft_num_layers <= 0: + return int(target_cell_size_per_token) + + total_layers = int(target_num_layers) + int(draft_num_layers) + return ( + int(target_cell_size_per_token) * int(total_layers) + int(target_num_layers) - 1 + ) // int(target_num_layers) + + +def resolve_dflash_verify_mask_policy(attn_backend: Any) -> tuple[str, bool]: + backend = attn_backend + for _ in range(4): + full_backend = getattr(backend, "full_attn_backend", None) + if full_backend is None: + break + backend = full_backend + backend_name = type(backend).__name__ + return backend_name, (backend_name not in _DFLASH_VERIFY_SKIP_CUSTOM_MASK_BACKENDS) + + +def _get_or_create_chain_verify_buffers( + *, + bs: int, + draft_token_num: int, + device: torch.device, +) -> tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor +]: + key = (device.index, int(draft_token_num)) + cached = _DFLASH_CHAIN_VERIFY_BUFFERS.get(key) + cap_bs = 0 if cached is None else int(cached["cap_bs"]) + if cap_bs < bs: + new_cap = max(int(bs), cap_bs * 2 if cap_bs > 0 else int(bs)) + retrieve_index = torch.arange( + new_cap * draft_token_num, dtype=torch.int64, device=device + ).view(new_cap, draft_token_num) + row_next = torch.arange( + 1, draft_token_num + 1, dtype=torch.int64, device=device + ) + row_next[-1] = -1 + retrieve_next_token = row_next.unsqueeze(0).expand(new_cap, -1).clone() + retrieve_next_sibling = torch.full( + (new_cap, draft_token_num), -1, dtype=torch.int64, device=device + ) + predicts = torch.empty( + (new_cap * draft_token_num,), dtype=torch.int32, device=device + ) + accept_index = torch.empty( + (new_cap, draft_token_num), dtype=torch.int32, device=device + ) + accept_token_num = torch.empty((new_cap,), dtype=torch.int32, device=device) + cached = { + "cap_bs": int(new_cap), + "retrieve_index": retrieve_index, + "retrieve_next_token": retrieve_next_token, + "retrieve_next_sibling": retrieve_next_sibling, + "predicts": predicts, + "accept_index": accept_index, + "accept_token_num": accept_token_num, + } + _DFLASH_CHAIN_VERIFY_BUFFERS[key] = cached + + assert cached is not None + retrieve_index = cached["retrieve_index"][:bs] + retrieve_next_token = cached["retrieve_next_token"][:bs] + retrieve_next_sibling = cached["retrieve_next_sibling"][:bs] + predicts = cached["predicts"][: bs * draft_token_num] + accept_index = cached["accept_index"][:bs] + accept_token_num = cached["accept_token_num"][:bs] + return ( + retrieve_index, + retrieve_next_token, + retrieve_next_sibling, + predicts, + accept_index, + accept_token_num, + ) + + +def build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> List[int]: + """Select target layer indices used to build DFlash context features. + + Args: + num_target_layers: Number of transformer layers in the runtime target model. + num_draft_layers: Number of layers in the DFlash draft model. + + Returns: + A list of 0-based target layer indices of length `num_draft_layers`. + + Notes: + - DFlash uses hidden states after each selected target layer (HF-style). + - SGLang captures "before layer i", so the model hook will typically add +1 + when mapping to capture points. + """ + if num_target_layers <= 0: + raise ValueError( + f"num_target_layers must be positive, got {num_target_layers}." 
+ ) + if num_draft_layers <= 0: + raise ValueError(f"num_draft_layers must be positive, got {num_draft_layers}.") + + if num_draft_layers == 1: + return [num_target_layers // 2] + + start = 1 + end = num_target_layers - 3 + if end < start: + raise ValueError( + "DFlash layer selection requires num_target_layers >= 4. " + f"Got num_target_layers={num_target_layers}." + ) + + span = end - start + return [ + int(round(start + (i * span) / (num_draft_layers - 1))) + for i in range(num_draft_layers) + ] + + +def _cfg_get(config: Any, key: str, default: Any = None) -> Any: + if isinstance(config, dict): + return config.get(key, default) + return getattr(config, key, default) + + +def _get_text_config(config: Any) -> Any: + if config is None: + return None + if isinstance(config, dict): + return config.get("text_config", config) + text_config = getattr(config, "text_config", None) + if text_config is not None: + return text_config + get_text_config = getattr(config, "get_text_config", None) + if callable(get_text_config): + try: + resolved = get_text_config() + if resolved is not None: + return resolved + except TypeError: + pass + return config + + +def _get_dflash_config(config: Any) -> dict: + if isinstance(config, dict): + cfg = config.get("dflash_config", None) + else: + cfg = getattr(config, "dflash_config", None) + if cfg is None: + return {} + if isinstance(cfg, dict): + return cfg + + try: + return dict(cfg) + except Exception: + return {} + + +def _parse_optional_int( + value: Any, + *, + field_name: str, + min_value: Optional[int] = None, +) -> Optional[int]: + if value is None: + return None + try: + parsed = int(value) + except Exception as e: + raise ValueError(f"Invalid {field_name}={value!r}.") from e + if min_value is not None and parsed < int(min_value): + comparator = "positive" if int(min_value) == 1 else f">= {int(min_value)}" + raise ValueError(f"{field_name} must be {comparator}, got {parsed}.") + return parsed + + +@dataclass(frozen=True) +class DFlashDraftConfig: + num_hidden_layers: Optional[int] + num_target_layers: Optional[int] + block_size: Optional[int] + target_layer_ids: Optional[List[int]] + mask_token: str + mask_token_id: Optional[int] + + def require_num_layers(self) -> int: + if self.num_hidden_layers is None: + raise ValueError( + "DFLASH requires draft num_hidden_layers in config. " + "Got config without num_hidden_layers." + ) + return int(self.num_hidden_layers) + + def resolve_block_size(self, *, default: Optional[int] = None) -> Optional[int]: + return self.block_size if self.block_size is not None else default + + def resolve_target_layer_ids( + self, + *, + target_num_layers: int, + draft_num_layers: Optional[int] = None, + ) -> List[int]: + target_num_layers = int(target_num_layers) + if target_num_layers <= 0: + raise ValueError( + f"target_num_layers must be positive, got {target_num_layers}." + ) + + if self.target_layer_ids is None: + if draft_num_layers is None: + draft_num_layers = self.require_num_layers() + return build_target_layer_ids(target_num_layers, int(draft_num_layers)) + + resolved = list(self.target_layer_ids) + if len(resolved) <= 0: + raise ValueError( + "DFLASH dflash_config.target_layer_ids must be non-empty. " + f"Got len(target_layer_ids)={len(resolved)}." + ) + for idx, val in enumerate(resolved): + if val < 0 or val >= target_num_layers: + raise ValueError( + "DFLASH target_layer_ids contains an out-of-range layer id. " + f"target_layer_ids[{idx}]={val}, target_num_layers={target_num_layers}." 
+ ) + return resolved + + +def parse_dflash_draft_config(*, draft_hf_config: Any) -> DFlashDraftConfig: + """Parse and validate DFLASH draft config fields from HF config/dict.""" + dflash_cfg = _get_dflash_config(draft_hf_config) + draft_text_config = _get_text_config(draft_hf_config) + + num_hidden_layers = _parse_optional_int( + _cfg_get(draft_text_config, "num_hidden_layers", None), + field_name="DFLASH draft num_hidden_layers", + min_value=1, + ) + raw_num_target_layers = dflash_cfg.get( + "num_target_layers", + _cfg_get(draft_hf_config, "num_target_layers", None), + ) + num_target_layers = _parse_optional_int( + raw_num_target_layers, + field_name="DFLASH draft num_target_layers", + min_value=1, + ) + + # Keep support for current checkpoints where block_size is top-level. + raw_block_size = dflash_cfg.get( + "block_size", + _cfg_get(draft_hf_config, "block_size", None), + ) + block_size = _parse_optional_int( + raw_block_size, + field_name="DFLASH block_size", + min_value=1, + ) + + layer_ids = dflash_cfg.get( + "target_layer_ids", + _cfg_get(draft_hf_config, "target_layer_ids", None), + ) + parsed_target_layer_ids: Optional[List[int]] + if layer_ids is None: + parsed_target_layer_ids = None + else: + if not isinstance(layer_ids, (list, tuple)): + raise ValueError( + "DFLASH dflash_config.target_layer_ids must be a list of ints, " + f"got type={type(layer_ids).__name__}." + ) + parsed_target_layer_ids = [int(x) for x in layer_ids] + if len(parsed_target_layer_ids) <= 0: + raise ValueError( + "DFLASH dflash_config.target_layer_ids must be non-empty. " + f"Got len(target_layer_ids)={len(parsed_target_layer_ids)}." + ) + + mask_token = dflash_cfg.get("mask_token", None) + if mask_token is None: + mask_token = DEFAULT_DFLASH_MASK_TOKEN + if not isinstance(mask_token, str) or not mask_token: + raise ValueError( + "DFLASH dflash_config.mask_token must be a non-empty string, " + f"got {mask_token!r}." + ) + + mask_token_id = dflash_cfg.get("mask_token_id", None) + if mask_token_id is not None: + if not isinstance(mask_token_id, Integral) or isinstance(mask_token_id, bool): + raise ValueError( + "DFLASH dflash_config.mask_token_id must be an integer, " + f"got {mask_token_id!r} (type={type(mask_token_id).__name__})." + ) + mask_token_id = int(mask_token_id) + if mask_token_id < 0: + raise ValueError( + "DFLASH dflash_config.mask_token_id must be non-negative, " + f"got {mask_token_id}." 
+ ) + + return DFlashDraftConfig( + num_hidden_layers=num_hidden_layers, + num_target_layers=num_target_layers, + block_size=block_size, + target_layer_ids=parsed_target_layer_ids, + mask_token=mask_token, + mask_token_id=mask_token_id, + ) + + +def can_dflash_slice_qkv_weight(qkv_proj: Any) -> Tuple[bool, str]: + """Validate whether DFlash can slice KV weights from a fused QKV linear layer.""" + quant_method = getattr(qkv_proj, "quant_method", None) + if not isinstance(quant_method, UnquantizedLinearMethod): + return ( + False, + "quantized qkv_proj is not supported for this path " + f"(quant_method={type(quant_method).__name__})", + ) + if not hasattr(qkv_proj, "weight"): + return False, "qkv weight tensor is missing" + return True, "" + + +def can_dflash_use_fused_qkv_proj(qkv_proj: Any) -> Tuple[bool, str]: + """Validate whether a QKV layer is eligible for DFlash fused KV materialization.""" + eligible, reason = can_dflash_slice_qkv_weight(qkv_proj) + if not eligible: + return False, reason + if getattr(qkv_proj, "bias", None) is not None: + return False, "qkv bias is not supported for fused KV path" + return True, "" + + +def compute_dflash_accept_len_and_bonus( + *, + candidates: torch.Tensor, + target_predict: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute DFlash accept lengths and bonus tokens (greedy verify rule). + + Args: + candidates: Token ids proposed by the DFlash draft, including the current token. + Shape: [bs, block_size]. candidates[:, 0] is the current token. + target_predict: Token ids predicted by the target model for each position in the block. + Shape: [bs, block_size]. target_predict[:, t] corresponds to argmax at position t. + + Returns: + accept_len: int32 tensor [bs], number of accepted *draft* tokens (excluding current token and bonus token). + bonus: int64 tensor [bs], the target-predicted token at index accept_len (the "bonus" token to append). + + Notes: + Matches the reference implementation rule: + accept while candidates[:, 1:] == target_predict[:, :-1] consecutively. + """ + if candidates.ndim != 2: + raise ValueError(f"candidates must be 2D, got shape={tuple(candidates.shape)}") + if target_predict.shape != candidates.shape: + raise ValueError( + "target_predict must have the same shape as candidates. " + f"candidates.shape={tuple(candidates.shape)}, target_predict.shape={tuple(target_predict.shape)}" + ) + + bs, block_size = candidates.shape + if bs <= 0: + raise ValueError(f"batch size must be positive, got {bs}.") + if block_size <= 0: + raise ValueError(f"block_size must be positive, got {block_size}.") + + matches = candidates[:, 1:] == target_predict[:, :-1] + accept_len = matches.to(torch.int32).cumprod(dim=1).sum(dim=1) + bonus = target_predict[torch.arange(bs, device=target_predict.device), accept_len] + return accept_len, bonus.to(torch.int64) + + +def compute_dflash_sampling_accept_len_and_bonus( + *, + candidates: torch.Tensor, + next_token_logits: torch.Tensor, + sampling_info: Any, + threshold_single: Optional[float] = None, + threshold_acc: Optional[float] = None, + uniform_samples: Optional[torch.Tensor] = None, + uniform_samples_for_final_sampling: Optional[torch.Tensor] = None, + use_sparse_topk: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute DFlash accept lengths and bonus tokens for non-greedy sampling. + + This is a chain-specialized variant of speculative target-only verification: + - DFlash proposals are linear (topk == 1), so each verify level has at most one candidate. 
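+ - Acceptance itself is target-only: `draft_probs` is passed as all zeros, so
+ each level's single candidate is evaluated directly against the target
+ distribution (subject to `threshold_single` / `threshold_acc`).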
+ - When a candidate is rejected at a level, the final token is sampled from + `relu(q - p)` where `p` has only the rejected candidate mass. + """ + if not _DFLASH_SAMPLING_VERIFY_AVAILABLE: + raise RuntimeError( + "DFLASH non-greedy verification is unavailable on this build/device." + ) + if candidates.ndim != 2: + raise ValueError(f"candidates must be 2D, got shape={tuple(candidates.shape)}") + if next_token_logits.ndim != 2: + raise ValueError( + "next_token_logits must be 2D, " + f"got shape={tuple(next_token_logits.shape)}." + ) + + bs, draft_token_num = candidates.shape + if bs <= 0: + raise ValueError(f"batch size must be positive, got {bs}.") + if draft_token_num <= 0: + raise ValueError(f"draft_token_num must be positive, got {draft_token_num}.") + if next_token_logits.shape[0] != bs * draft_token_num: + raise ValueError( + "next_token_logits row count mismatch. " + f"Expected {bs * draft_token_num}, got {next_token_logits.shape[0]}." + ) + if candidates.device != next_token_logits.device: + raise ValueError( + "candidates and next_token_logits must be on the same device, " + f"got {candidates.device} and {next_token_logits.device}." + ) + + if threshold_single is None: + from sglang.srt.server_args import get_global_server_args + + threshold_single = get_global_server_args().speculative_accept_threshold_single + if threshold_acc is None: + from sglang.srt.server_args import get_global_server_args + + threshold_acc = get_global_server_args().speculative_accept_threshold_acc + threshold_single = float(threshold_single) + threshold_acc = max(float(threshold_acc), 1e-9) + + device = next_token_logits.device + + if uniform_samples is None: + uniform_samples = torch.rand( + (bs, draft_token_num), dtype=torch.float32, device=device + ) + else: + if uniform_samples.shape != (bs, draft_token_num): + raise ValueError( + "uniform_samples shape mismatch. " + f"Expected {(bs, draft_token_num)}, got {tuple(uniform_samples.shape)}." + ) + uniform_samples = uniform_samples.to(device=device, dtype=torch.float32) + + if uniform_samples_for_final_sampling is None: + uniform_samples_for_final_sampling = torch.rand( + (bs,), dtype=torch.float32, device=device + ) + else: + if uniform_samples_for_final_sampling.shape != (bs,): + raise ValueError( + "uniform_samples_for_final_sampling shape mismatch. " + f"Expected {(bs,)}, got {tuple(uniform_samples_for_final_sampling.shape)}." + ) + uniform_samples_for_final_sampling = uniform_samples_for_final_sampling.to( + device=device, + dtype=torch.float32, + ) + + need_top_k = bool(getattr(sampling_info, "need_top_k_sampling", True)) + need_top_p = bool(getattr(sampling_info, "need_top_p_sampling", False)) + # Build target distribution once over all verify rows. + expanded_temperature = torch.repeat_interleave( + sampling_info.temperatures, draft_token_num, dim=0 + ) + scaled_logits = next_token_logits / expanded_temperature + sparse_topk_applied = False + + if use_sparse_topk and need_top_k: + repeated_top_ks = torch.repeat_interleave( + sampling_info.top_ks, draft_token_num, dim=0 + ).to(dtype=torch.int64) + vocab_size = int(scaled_logits.shape[-1]) + repeated_top_ks.clamp_(min=1, max=vocab_size) + max_top_k = int(repeated_top_ks.max().item()) + + # Sparse exact path for top-k/top-p (top-k-first semantics), then scatter to dense. 
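+ # Illustrative cost (hypothetical numbers): with 64 verify rows, a 128k vocab
+ # and max_top_k=50, the dense path would softmax and renormalize 64 x 128k
+ # floats, while this path softmaxes only 64 x 50 top-k logits and scatters
+ # them into a zeroed dense buffer; this is exact because every non-top-k
+ # probability is zero by definition.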
+ if 0 < max_top_k < vocab_size: + topk_logits, topk_indices = torch.topk(scaled_logits, k=max_top_k, dim=-1) + if not torch.all(repeated_top_ks == max_top_k): + ranks = torch.arange(max_top_k, device=device, dtype=torch.int64)[ + None, : + ] + valid = ranks < repeated_top_ks.unsqueeze(1) + topk_logits = topk_logits.masked_fill(~valid, float("-inf")) + + topk_probs = F.softmax(topk_logits, dim=-1) + if need_top_p: + repeated_top_ps = torch.repeat_interleave( + sampling_info.top_ps, draft_token_num, dim=0 + ) + topk_probs = top_p_renorm_prob(topk_probs, repeated_top_ps) + + target_probs = torch.zeros_like(scaled_logits, dtype=topk_probs.dtype) + target_probs.scatter_(1, topk_indices, topk_probs) + sparse_topk_applied = True + + if not sparse_topk_applied: + target_probs = F.softmax(scaled_logits, dim=-1) + if need_top_k: + target_probs = top_k_renorm_prob( + target_probs, + torch.repeat_interleave(sampling_info.top_ks, draft_token_num, dim=0), + ) + if need_top_p: + target_probs = top_p_renorm_prob( + target_probs, + torch.repeat_interleave(sampling_info.top_ps, draft_token_num, dim=0), + ) + target_probs = target_probs.view(bs, draft_token_num, -1).contiguous() + draft_probs = torch.zeros_like(target_probs) + + ( + retrieve_index, + retrieve_next_token, + retrieve_next_sibling, + predicts, + accept_index, + accept_token_num, + ) = _get_or_create_chain_verify_buffers( + bs=bs, + draft_token_num=draft_token_num, + device=device, + ) + candidates_i64 = ( + candidates if candidates.dtype == torch.int64 else candidates.to(torch.int64) + ) + tree_speculative_sampling_target_only( + predicts=predicts, + accept_index=accept_index, + accept_token_num=accept_token_num, + candidates=candidates_i64, + retrive_index=retrieve_index, + retrive_next_token=retrieve_next_token, + retrive_next_sibling=retrieve_next_sibling, + uniform_samples=uniform_samples, + uniform_samples_for_final_sampling=uniform_samples_for_final_sampling, + target_probs=target_probs, + draft_probs=draft_probs, + threshold_single=threshold_single, + threshold_acc=threshold_acc, + deterministic=True, + ) + + accept_len = accept_token_num + row_ids = torch.arange(bs, dtype=torch.long, device=device) + accept_pos = accept_index[row_ids, accept_len.to(torch.long)].to(torch.long) + bonus = predicts[accept_pos].to(torch.int64) + return accept_len, bonus diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py new file mode 100644 index 000000000000..030aa21e5b35 --- /dev/null +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -0,0 +1,1245 @@ +import logging +import math +from copy import deepcopy +from typing import Optional, Union + +import torch + +from sglang.srt.distributed import get_tp_group +from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch +from sglang.srt.managers.scheduler import GenerationBatchResult +from sglang.srt.managers.tp_worker import TpModelWorker +from sglang.srt.mem_cache.common import get_last_loc +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, +) +from sglang.srt.server_args import ( + ServerArgs, + get_global_server_args, + set_global_server_args_for_scheduler, +) +from sglang.srt.speculative.dflash_info import DFlashDraftInput, DFlashVerifyInput +from sglang.srt.speculative.dflash_utils import ( + can_dflash_use_fused_qkv_proj, + is_dflash_sampling_verify_available, + parse_dflash_draft_config, + resolve_dflash_verify_mask_policy, +) +from 
sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func +from sglang.srt.utils import is_cuda + +logger = logging.getLogger(__name__) + +_FusedKVMaterializeHelper = None + + +def _get_fused_kv_materialize_helper(): + global _FusedKVMaterializeHelper + if _FusedKVMaterializeHelper is None: + from sglang.srt.speculative.triton_ops.fused_kv_materialize import ( + FusedKVMaterializeHelper, + ) + + _FusedKVMaterializeHelper = FusedKVMaterializeHelper + return _FusedKVMaterializeHelper + + +class DFlashWorker: + """DFlash speculative decoding worker (spec-v1, tp>=1/pp=1).""" + + def __init__( + self, + server_args: ServerArgs, + gpu_id: int, + tp_rank: int, + dp_rank: Optional[int], + moe_ep_rank: int, + attn_cp_rank: int, + moe_dp_rank: int, + nccl_port: int, + target_worker: TpModelWorker, + ): + self.server_args = server_args + self.gpu_id = gpu_id + self.tp_rank = tp_rank + self.dp_rank = dp_rank + self.moe_ep_rank = moe_ep_rank + self.attn_cp_rank = attn_cp_rank + self.moe_dp_rank = moe_dp_rank + self.nccl_port = nccl_port + self.target_worker = target_worker + self.model_runner = target_worker.model_runner + self.page_size = server_args.page_size + self.draft_window_size: Optional[int] = ( + int(server_args.speculative_dflash_draft_window_size) + if server_args.speculative_dflash_draft_window_size is not None + else None + ) + self.use_compact_draft_cache = self.draft_window_size is not None + self.device = target_worker.device + + self._warned_sampling_fallback = False + self._logged_first_verify = False + + # Draft runner (separate KV cache + attention backend). + # Without draft windowing, the draft worker aliases the target request->token + # mapping and allocation state. With draft windowing enabled, the draft worker + # keeps a private compact req->token table over the same global KV index space, + # so radix-cache/prefix-hit KV remains reusable while draft attention sees only + # the recent window. + target_req_to_token_pool, target_token_to_kv_pool_allocator = ( + target_worker.get_memory_pool() + ) + shared_req_to_token_pool = ( + None if self.use_compact_draft_cache else target_req_to_token_pool + ) + draft_server_args = deepcopy(server_args) + draft_server_args.skip_tokenizer_init = True + draft_backend = draft_server_args.speculative_draft_attention_backend + supported_draft_backends = ("flashinfer", "fa3", "fa4") + if draft_backend is None: + draft_backend, _ = draft_server_args.get_attention_backends() + if draft_backend is None: + draft_backend = "flashinfer" + elif draft_backend == "trtllm_mha": + logger.warning( + "DFLASH draft worker does not support 'trtllm_mha' because the " + "draft path requires non-causal attention. Falling back to " + "'flashinfer'." + ) + draft_backend = "flashinfer" + elif draft_backend not in supported_draft_backends: + logger.warning( + "DFLASH draft worker only supports attention_backend in %s for now, " + "but got %r. Falling back to 'flashinfer'.", + supported_draft_backends, + draft_backend, + ) + draft_backend = "flashinfer" + # Make the draft worker backend explicit and self-contained (no further overrides). + draft_server_args.speculative_draft_attention_backend = None + draft_server_args.prefill_attention_backend = None + draft_server_args.decode_attention_backend = None + draft_server_args.attention_backend = draft_backend + # Keep draft context length aligned with the target. 
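+ # (Draft-side positions reuse absolute target positions for RoPE, so the
+ # draft model must accept sequence lengths up to the target context window.)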
+ draft_server_args.context_length = ( + target_worker.model_runner.model_config.context_len + ) + saved_server_args = get_global_server_args() + self.draft_worker = TpModelWorker( + server_args=draft_server_args, + gpu_id=gpu_id, + tp_rank=tp_rank, + moe_ep_rank=moe_ep_rank, + pp_rank=0, + attn_cp_rank=attn_cp_rank, + moe_dp_rank=moe_dp_rank, + dp_rank=dp_rank, + nccl_port=nccl_port, + is_draft_worker=True, + req_to_token_pool=shared_req_to_token_pool, + token_to_kv_pool_allocator=target_token_to_kv_pool_allocator, + memory_pool_config=target_worker.model_runner.memory_pool_config, + ) + set_global_server_args_for_scheduler(saved_server_args) + self.draft_model_runner = self.draft_worker.model_runner + self.draft_model = self.draft_model_runner.model + draft_config = parse_dflash_draft_config( + draft_hf_config=self.draft_model_runner.model_config.hf_config + ) + if server_args.speculative_num_draft_tokens is None: + # Should not happen (ServerArgs should have inferred it), but keep a fallback. + self.block_size = int(draft_config.resolve_block_size(default=16)) + else: + self.block_size = int(server_args.speculative_num_draft_tokens) + model_block_size = draft_config.block_size + if model_block_size is None: + model_block_size = getattr(self.draft_model, "block_size", None) + if model_block_size is not None and int(model_block_size) != int( + self.block_size + ): + logger.warning( + "DFLASH block size mismatch: using speculative_num_draft_tokens=%s but draft config block_size=%s.", + self.block_size, + model_block_size, + ) + + self._mask_token = draft_config.mask_token + self._mask_token_id_override = draft_config.mask_token_id + self._mask_token_id = self._resolve_mask_token_id( + mask_token=self._mask_token, + mask_token_id=self._mask_token_id_override, + ) + if self.tp_rank == 0: + logger.info( + "Initialized DFLASH draft runner. attention_backend=%s, model=%s, block_size=%s, draft_window_size=%s, compact_cache=%s", + getattr(draft_server_args, "attention_backend", None), + self.draft_model.__class__.__name__, + self.block_size, + self.draft_window_size, + self.use_compact_draft_cache, + ) + logger.info( + "DFLASH draft runner ready. 
mask_token=%s, mask_token_id=%s, mask_token_id_override=%s", + self._mask_token, + self._mask_token_id, + self._mask_token_id_override, + ) + + self._block_pos_offsets = torch.arange( + self.block_size, device=self.device, dtype=torch.int64 + ) + self._draft_block_ids_buf: Optional[torch.Tensor] = None # [cap_bs, block_size] + self._draft_block_positions_buf: Optional[torch.Tensor] = ( + None # [cap_bs, block_size] + ) + self._draft_block_tokens_buf: Optional[torch.Tensor] = ( + None # [cap_bs, block_size] + ) + self._draft_block_end_buf: Optional[torch.Tensor] = None # [cap_bs] + self._draft_seq_lens_cpu_buf: Optional[torch.Tensor] = None # [cap_bs] on CPU + self._draft_block_spec_info = DFlashVerifyInput( + draft_token=torch.empty((0,), dtype=torch.long, device=self.device), + positions=torch.empty((0,), dtype=torch.int64, device=self.device), + draft_token_num=int(self.block_size), + custom_mask=None, + capture_hidden_mode=CaptureHiddenMode.NULL, + ) + self._draft_greedy_gathered_max_buf: Optional[torch.Tensor] = None + self._draft_greedy_gathered_ids_buf: Optional[torch.Tensor] = None + self._draft_greedy_gather_cap: int = 0 + self._draft_greedy_best_rank_buf: Optional[torch.Tensor] = None + self._draft_greedy_rank_index_buf: Optional[torch.Tensor] = None + self._draft_greedy_selected_ids_buf: Optional[torch.Tensor] = None + self._draft_greedy_index_cap: int = 0 + + self._use_fused_kv_materialize = is_cuda() + self._fused_kv_helper: Optional[object] = None + if self._use_fused_kv_materialize: + self._init_fused_kv_helper() + + def _init_fused_kv_helper(self) -> None: + """Initialize the fused KV materialization helper with pre-stacked weights.""" + try: + layers = self.draft_model.layers + fused_disable_reason: Optional[str] = None + + if len(layers) == 0: + fused_disable_reason = "no layers found" + + for layer_idx, layer in enumerate(layers): + attn = layer.self_attn + eligible, reason = can_dflash_use_fused_qkv_proj(attn.qkv_proj) + if not eligible: + fused_disable_reason = f"{reason}: layer={layer_idx}" + break + + # Keep semantics aligned with set_kv_buffer scaling behavior. 
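+ # The fused path writes K/V verbatim, so it only matches
+ # set_kv_buffer(..., k_scale, v_scale) when both scales are unit and,
+ # as checked below, when RoPE is neox-style.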
+ k_scale = getattr(attn.attn, "k_scale", None) + v_scale = getattr(attn.attn, "v_scale", None) + if k_scale is not None and not math.isclose(float(k_scale), 1.0): + fused_disable_reason = ( + "non-unit k_scale is not supported for fused KV path: " + f"layer={layer_idx}, k_scale={k_scale}" + ) + break + if v_scale is not None and not math.isclose(float(v_scale), 1.0): + fused_disable_reason = ( + "non-unit v_scale is not supported for fused KV path: " + f"layer={layer_idx}, v_scale={v_scale}" + ) + break + + rope_is_neox_style = bool( + getattr(attn.rotary_emb, "is_neox_style", True) + ) + if not rope_is_neox_style: + fused_disable_reason = ( + "non-neox RoPE is not supported for fused KV path: " + f"layer={layer_idx}, rope_is_neox_style={rope_is_neox_style}" + ) + break + + if fused_disable_reason is not None: + if self.tp_rank == 0: + logger.info( + "DFLASH fused KV materialization disabled: %s", + fused_disable_reason, + ) + self._use_fused_kv_materialize = False + self._fused_kv_helper = None + return + + FusedKVMaterializeHelper = _get_fused_kv_materialize_helper() + first_attn = layers[0].self_attn + rotary_emb = first_attn.rotary_emb + + self._fused_kv_helper = FusedKVMaterializeHelper( + layers=layers, + rotary_emb=rotary_emb, + num_kv_heads=first_attn.num_kv_heads, + head_dim=first_attn.head_dim, + device=self.device, + ) + if self.tp_rank == 0: + logger.info( + "DFLASH fused KV materialization enabled. " + "n_layers=%d, num_kv_heads=%d, head_dim=%d", + len(layers), + first_attn.num_kv_heads, + first_attn.head_dim, + ) + except Exception as e: + logger.warning( + "DFLASH fused KV initialization failed, falling back to sequential path: %s", + e, + ) + self._use_fused_kv_materialize = False + self._fused_kv_helper = None + + def _ensure_draft_block_buffers(self, bs: int) -> None: + cap = ( + 0 + if self._draft_block_ids_buf is None + else int(self._draft_block_ids_buf.shape[0]) + ) + if cap >= int(bs): + return + + new_cap = max(int(bs), cap * 2 if cap > 0 else int(bs)) + device = self.device + block_size = int(self.block_size) + self._draft_block_ids_buf = torch.empty( + (new_cap, block_size), dtype=torch.long, device=device + ) + self._draft_block_positions_buf = torch.empty( + (new_cap, block_size), dtype=torch.int64, device=device + ) + self._draft_block_tokens_buf = torch.empty( + (new_cap, block_size), dtype=torch.long, device=device + ) + self._draft_block_end_buf = torch.empty( + (new_cap,), dtype=torch.int32, device=device + ) + self._draft_seq_lens_cpu_buf = torch.empty( + (new_cap,), dtype=torch.int32, device="cpu" + ) + + def __getattr__(self, name): + # Delegate anything not implemented yet to the target worker. + return getattr(self.target_worker, name) + + def clear_cache_pool(self): + # The target worker owns the shared KV allocator/cache. For the compact + # sliding-window path, the draft req->token view is rebuilt from committed + # target state before each draft forward, so there is nothing persistent + # to flush here. + pass + + def _gather_req_to_token_masked( + self, + *, + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + pos2d: torch.Tensor, + mask: torch.Tensor, + context: str, + ) -> torch.Tensor: + if pos2d.ndim != 2: + raise RuntimeError( + f"{context} expected 2D positions, got shape={tuple(pos2d.shape)}." + ) + if mask.shape != pos2d.shape: + raise RuntimeError( + f"{context} mask/position shape mismatch: {tuple(mask.shape)} vs {tuple(pos2d.shape)}." 
+ ) + + if req_pool_indices.dtype != torch.int64: + req_pool_indices = req_pool_indices.to(torch.int64) + if mask.dtype != torch.bool: + mask = mask.to(torch.bool) + + table_width = int(req_to_token.shape[1]) + if table_width <= 0: + if bool(mask.any().item()): + raise RuntimeError( + f"{context} req_to_token table is empty but gather mask is non-empty." + ) + return torch.empty((0,), dtype=torch.int64, device=self.device) + + # Only the masked-off rectangular padding can be out of range in the normal + # ragged-batch case. Replace those don't-care columns with a valid in-range + # position before the gather so the kernel only sees real positions. + safe_pos2d = pos2d.masked_fill(~mask, 0) + return req_to_token[req_pool_indices[:, None], safe_pos2d][mask].to(torch.int64) + + def _gather_req_to_token_segments( + self, + *, + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + start: torch.Tensor | None, + lengths: torch.Tensor, + ) -> torch.Tensor: + lengths = lengths.to(torch.int64) + if lengths.numel() == 0: + return torch.empty((0,), dtype=torch.int64, device=self.device) + max_len = int(lengths.max().item()) + if max_len <= 0: + return torch.empty((0,), dtype=torch.int64, device=self.device) + + if req_pool_indices.dtype != torch.int64: + req_pool_indices = req_pool_indices.to(torch.int64) + offsets = torch.arange( + max_len, device=self.device, dtype=torch.int64 + ).unsqueeze(0) + if start is None: + pos2d = offsets.expand(req_pool_indices.shape[0], -1) + else: + pos2d = start.to(torch.int64).unsqueeze(1) + offsets + mask = offsets < lengths.unsqueeze(1) + return self._gather_req_to_token_masked( + req_to_token=req_to_token, + req_pool_indices=req_pool_indices, + pos2d=pos2d, + mask=mask, + context="DFLASH req_to_token segment gather", + ) + + def _compute_compact_draft_seq_lens(self, seq_lens: torch.Tensor) -> torch.Tensor: + assert self.draft_window_size is not None + visible_lens = torch.clamp( + seq_lens.to(dtype=torch.int32, device=self.device), + max=int(self.draft_window_size), + ) + if self.page_size <= 1: + return visible_lens + + # Paged FA backends derive the page table from local token positions, so the + # compact suffix must start on a page boundary. Keep up to page_size - 1 extra + # tokens on the left to preserve valid local page structure. + seq_lens_i64 = seq_lens.to(torch.int64) + visible_lens_i64 = visible_lens.to(torch.int64) + visible_start = seq_lens_i64 - visible_lens_i64 + aligned_start = visible_start - torch.remainder(visible_start, self.page_size) + return (seq_lens_i64 - aligned_start).to(torch.int32) + + def _resolve_mask_token_id( + self, *, mask_token: str, mask_token_id: Optional[int] = None + ) -> int: + if not isinstance(mask_token, str) or not mask_token: + raise ValueError( + f"DFLASH mask_token must be a non-empty string, got {mask_token!r}." + ) + + vocab_size = int(self.target_worker.model_runner.model_config.vocab_size) + if mask_token_id is not None: + resolved_id = int(mask_token_id) + if resolved_id >= vocab_size: + raise ValueError( + "DFLASH mask_token_id is outside the target vocab size. " + f"mask_token_id={resolved_id}, vocab_size={vocab_size}. " + f"This likely means mask_token={mask_token!r} requires vocab expansion beyond the model's embedding size. " + "SGLang does not support resizing target embeddings for DFLASH yet." 
+ ) + + tokenizer = getattr(self.target_worker, "tokenizer", None) + if tokenizer is not None: + token_id_from_vocab = tokenizer.get_vocab().get(mask_token, None) + if ( + token_id_from_vocab is not None + and int(token_id_from_vocab) != resolved_id + ): + raise ValueError( + "DFLASH config mismatch: dflash_config.mask_token_id conflicts with tokenizer vocab id " + f"for dflash_config.mask_token. mask_token={mask_token!r}, " + f"mask_token_id={resolved_id}, tokenizer_vocab_id={int(token_id_from_vocab)}." + ) + return resolved_id + + tokenizer = getattr(self.target_worker, "tokenizer", None) + if tokenizer is None: + raise RuntimeError( + "DFLASH requires tokenizer initialization when dflash_config.mask_token_id is not set " + "(skip_tokenizer_init is not supported in this mode)." + ) + + resolved_id = None + if getattr(tokenizer, "mask_token", None) == mask_token: + resolved_id = getattr(tokenizer, "mask_token_id", None) + + if resolved_id is None: + # Prefer checking the explicit vocab mapping first. + vocab = tokenizer.get_vocab() + resolved_id = vocab.get(mask_token, None) + + if resolved_id is None: + # Mirror the reference DFlash HF demo by adding the mask token to the tokenizer. + # This is safe only when the resulting id stays within the target model vocab size. + added = tokenizer.add_special_tokens({"mask_token": mask_token}) + resolved_id = getattr(tokenizer, "mask_token_id", None) + if resolved_id is None: + resolved_id = tokenizer.convert_tokens_to_ids(mask_token) + + if added and self.tp_rank == 0: + logger.info( + "Added DFLASH mask token to tokenizer. token=%s, mask_token_id=%s, tokenizer_len=%s, model_vocab_size=%s", + mask_token, + resolved_id, + len(tokenizer), + vocab_size, + ) + + if resolved_id is None or int(resolved_id) < 0: + raise ValueError( + "DFLASH requires resolving a mask token id, but it could not be resolved. " + f"mask_token={mask_token!r}." + ) + + if resolved_id >= vocab_size: + raise ValueError( + "DFLASH mask_token_id is outside the target vocab size. " + f"mask_token_id={resolved_id}, vocab_size={vocab_size}. " + f"This likely means mask_token={mask_token!r} requires vocab expansion beyond the model's embedding size. " + "SGLang does not support resizing target embeddings for DFLASH yet." + ) + + return int(resolved_id) + + def _prepare_for_speculative_decoding( + self, batch: ScheduleBatch, draft_input: DFlashDraftInput + ): + if batch.forward_mode.is_extend() or batch.forward_mode.is_idle(): + return + + if batch.has_grammar: + raise RuntimeError( + "Invariant broken: DFLASH batch has grammar constraints, but scheduler should have rejected this request." + ) + if batch.sampling_info is not None and not batch.sampling_info.is_all_greedy: + if ( + not is_dflash_sampling_verify_available() + and not self._warned_sampling_fallback + and self.tp_rank == 0 + ): + logger.warning( + "DFLASH non-greedy verification is unavailable on this build/device; " + "falling back to greedy argmax verification." + ) + self._warned_sampling_fallback = True + + bs = batch.batch_size() + + # --- 1) Append any newly committed tokens into the draft KV cache. 
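+ # Until this runs, newly committed tokens exist only in the target KV cache;
+ # materializing them first keeps the draft cache in lockstep before drafting.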
+ self._append_target_hidden_to_draft_kv(batch, draft_input) + + target_model = self.target_worker.model_runner.model + embed_module = target_model.get_input_embeddings() + lm_head = getattr(target_model, "lm_head", None) + if ( + lm_head is None + or not hasattr(lm_head, "weight") + or not hasattr(lm_head, "shard_indices") + ): + raise RuntimeError( + "DFLASH requires the target model to expose a vocab-parallel `lm_head` with `weight` and " + "`shard_indices` attributes." + ) + + # --- 2) Draft a non-causal block with the draft model. + self._ensure_draft_block_buffers(bs) + assert self._draft_block_ids_buf is not None + assert self._draft_block_positions_buf is not None + assert self._draft_block_tokens_buf is not None + assert self._draft_block_end_buf is not None + assert self._draft_seq_lens_cpu_buf is not None + + block_ids = self._draft_block_ids_buf[:bs] + block_ids.fill_(int(self._mask_token_id)) + block_ids[:, 0].copy_(draft_input.verified_id.to(torch.long)) + + noise_embedding = embed_module(block_ids) + input_embeds = noise_embedding.view(-1, noise_embedding.shape[-1]) + + # For spec-v1, the draft KV cache is always materialized before drafting the + # next block. `target_prefix_lens` stay absolute for RoPE; `draft_prefix_lens` + # are the logical resident lengths in the draft-local cache. + target_prefix_lens = batch.seq_lens # int32, device + draft_prefix_lens = draft_input.draft_seq_lens + if draft_prefix_lens.dtype != torch.int32: + draft_prefix_lens = draft_prefix_lens.to(torch.int32) + if draft_prefix_lens.device != self.device: + draft_prefix_lens = draft_prefix_lens.to(self.device, non_blocking=True) + + positions_2d = self._draft_block_positions_buf[:bs] + torch.add( + target_prefix_lens.unsqueeze(1), self._block_pos_offsets, out=positions_2d + ) + positions = positions_2d.reshape(-1) + + block_start = draft_prefix_lens + block_end = self._draft_block_end_buf[:bs] + torch.add(block_start, int(self.block_size), out=block_end) + + seq_lens_cpu = self._draft_seq_lens_cpu_buf[:bs] + seq_lens_cpu.copy_(draft_prefix_lens.to(device="cpu", dtype=torch.int32)) + allocator = self.draft_model_runner.token_to_kv_pool_allocator + token_to_kv_pool_state_backup = allocator.backup_state() + try: + if self.page_size == 1: + block_cache_loc = allocator.alloc(bs * self.block_size) + else: + block_end_cpu = seq_lens_cpu + int(self.block_size) + last_loc = get_last_loc( + self.draft_model_runner.req_to_token_pool.req_to_token, + batch.req_pool_indices, + block_start, + ) + block_cache_loc = allocator.alloc_extend( + block_start, + seq_lens_cpu, + block_end, + block_end_cpu, + last_loc, + bs * self.block_size, + ) + if block_cache_loc is None: + raise RuntimeError( + f"DFLASH draft OOM when allocating {bs * self.block_size} block tokens." + ) + + assign_req_to_token_pool_func( + batch.req_pool_indices, + self.draft_model_runner.req_to_token_pool.req_to_token, + block_start, + block_end, + block_cache_loc, + bs, + ) + + # Use TARGET_VERIFY mode (cuda-graphable) to run a fixed-size draft block. + # In this mode, `seq_lens` stores the prefix lengths; attention backends + # derive kv_len by adding `draft_token_num`. 
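+ # Worked example (hypothetical numbers): with a resident draft prefix of 100
+ # tokens and block_size=16, seq_lens=100 here and attention runs with
+ # kv_len = 100 + 16 = 116 for each of the 16 non-causal query positions.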
+ draft_spec_info = self._draft_block_spec_info + seq_lens = draft_prefix_lens + seq_lens_sum = int(draft_prefix_lens.sum().item()) + forward_batch = ForwardBatch( + forward_mode=ForwardMode.TARGET_VERIFY, + batch_size=bs, + input_ids=block_ids.flatten(), + req_pool_indices=batch.req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=block_cache_loc, + seq_lens_sum=seq_lens_sum, + seq_lens_cpu=seq_lens_cpu, + positions=positions, + req_to_token_pool=self.draft_model_runner.req_to_token_pool, + token_to_kv_pool=self.draft_model_runner.token_to_kv_pool, + attn_backend=self.draft_model_runner.attn_backend, + input_embeds=input_embeds, + spec_algorithm=SpeculativeAlgorithm.DFLASH, + spec_info=draft_spec_info, + capture_hidden_mode=CaptureHiddenMode.NULL, + ) + + with torch.inference_mode(): + draft_logits_output = self.draft_model_runner.forward( + forward_batch + ).logits_output + finally: + # Drop the speculative block from the shared allocator (EAGLE3-style). + allocator.restore_state(token_to_kv_pool_state_backup) + + draft_hidden = draft_logits_output.hidden_states + if draft_hidden is None: + raise RuntimeError("DFLASH draft model returned no hidden states.") + draft_hidden = draft_hidden.view(bs, self.block_size, -1) + draft_next = self._greedy_sample_from_vocab_parallel_head( + hidden_states=draft_hidden[:, 1:, :].reshape(-1, draft_hidden.shape[-1]), + lm_head=lm_head, + ).view(bs, self.block_size - 1) + draft_tokens = self._draft_block_tokens_buf[:bs] + draft_tokens[:, 0].copy_(block_ids[:, 0]) + draft_tokens[:, 1:].copy_(draft_next) + positions = positions_2d.reshape(-1) + + verify_input = DFlashVerifyInput( + draft_token=draft_tokens.reshape(-1), + positions=positions, + draft_token_num=self.block_size, + ) + _, build_custom_mask = resolve_dflash_verify_mask_policy( + self.model_runner.attn_backend + ) + verify_input.prepare_for_verify( + batch, + self.page_size, + build_custom_mask=build_custom_mask, + ) + + batch.forward_mode = ( + ForwardMode.TARGET_VERIFY + if not batch.forward_mode.is_idle() + else ForwardMode.IDLE + ) + batch.spec_info = verify_input + batch.return_hidden_states = False + + def _greedy_sample_from_vocab_parallel_head( + self, + *, + hidden_states: torch.Tensor, + lm_head, + chunk_size: int = 256, + ) -> torch.Tensor: + """Greedy argmax over the target LM head in a TP-safe way. + + We cannot materialize full logits for large vocabularies efficiently, and with + TP>1 each rank only owns a shard of the LM head weight. This computes the + per-rank max, gathers candidates across TP ranks, and selects the global max. + """ + + if hidden_states.numel() == 0: + return torch.empty((0,), dtype=torch.long, device=hidden_states.device) + + tp_group = get_tp_group() + tp_size = int(tp_group.world_size) + + if not hasattr(lm_head, "weight") or not hasattr(lm_head, "shard_indices"): + raise RuntimeError( + "DFLASH greedy sampling requires a vocab-parallel head with `weight` and `shard_indices`." 
+ ) + + shard = lm_head.shard_indices + weight = lm_head.weight # [local_vocab_padded, hidden] + weight_dtype = weight.dtype + + # Valid ranges in the local shard (excluding padding): + # base vocab: [0, num_org) + # added vocab: [num_org_padded, num_org_padded + num_added) + num_org = int(shard.num_org_elements) + num_org_padded = int(shard.num_org_elements_padded) + num_added = int(shard.num_added_elements) + org_vocab_start = int(shard.org_vocab_start_index) + added_vocab_start = int(shard.added_vocab_start_index) + + num_tokens = int(hidden_states.shape[0]) + out_token_ids = torch.empty( + (num_tokens,), dtype=torch.long, device=hidden_states.device + ) + + def _cast_hs(x: torch.Tensor) -> torch.Tensor: + return x if x.dtype == weight_dtype else x.to(weight_dtype) + + # Fast path (common): single-rank greedy sampling over the base vocab shard. + # Avoids extra max/id bookkeeping that is only needed for TP sync or added vocab. + if tp_size == 1 and num_added == 0: + for start in range(0, num_tokens, int(chunk_size)): + end = min(num_tokens, start + int(chunk_size)) + hs = _cast_hs(hidden_states[start:end]) + if num_org > 0: + base_logits = torch.matmul(hs, weight[:num_org].T) + out_token_ids[start:end] = ( + torch.argmax(base_logits, dim=-1).to(torch.long) + + org_vocab_start + ) + else: + out_token_ids[start:end] = 0 + return out_token_ids + + for start in range(0, num_tokens, int(chunk_size)): + end = min(num_tokens, start + int(chunk_size)) + hs = _cast_hs(hidden_states[start:end]) + chunk_len = int(hs.shape[0]) + + # Base vocab logits. + if num_org > 0: + base_logits = torch.matmul(hs, weight[:num_org].T) + local_max, local_arg = torch.max(base_logits, dim=-1) + else: + local_max = torch.full( + (chunk_len,), + torch.finfo(weight_dtype).min, + dtype=weight_dtype, + device=hs.device, + ) + local_arg = torch.zeros( + (chunk_len,), dtype=torch.int64, device=hs.device + ) + + # Added vocab logits (e.g., LoRA-added embeddings), if present. + if num_added > 0: + added_slice_start = num_org_padded + added_slice_end = num_org_padded + num_added + added_logits = torch.matmul( + hs, weight[added_slice_start:added_slice_end].T + ) + added_max, added_arg = torch.max(added_logits, dim=-1) + use_added = added_max > local_max + local_max = torch.where(use_added, added_max, local_max) + # For base/added conversion below, keep local_arg expressed in the full local + # weight index space (base + padding + added), matching `lm_head.weight`. + local_arg = torch.where( + use_added, added_arg.to(local_arg.dtype) + num_org_padded, local_arg + ) + + # Convert local argmax indices to global token ids. + if num_added == 0: + local_arg.add_(org_vocab_start) + global_ids = local_arg + else: + global_ids = torch.empty( + (chunk_len,), dtype=torch.int64, device=hs.device + ) + is_base = local_arg < num_org + global_ids[is_base] = org_vocab_start + local_arg[is_base] + global_ids[~is_base] = added_vocab_start + ( + local_arg[~is_base] - num_org_padded + ) + + if tp_size == 1: + out_token_ids[start:end] = global_ids.to(torch.long) + continue + + # Gather per-rank maxima and associated global ids, then select the global max. 
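+ # E.g. (hypothetical numbers) with tp_size=2 and a 150k vocab sharded
+ # 75k/75k: rank 0 reports (max=8.1, id=42) and rank 1 (max=9.3, id=80123);
+ # argmax over the gathered maxima selects rank 1, so token 80123 wins.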
+ needed = tp_size * chunk_len + chunk_cap = int(chunk_size) + if ( + self._draft_greedy_gather_cap < needed + or self._draft_greedy_gathered_max_buf is None + or self._draft_greedy_gathered_ids_buf is None + or self._draft_greedy_gathered_max_buf.dtype != local_max.dtype + or self._draft_greedy_gathered_max_buf.device != hs.device + ): + # Allocate enough space for the max chunk size to avoid reallocations. + cap = tp_size * chunk_cap + self._draft_greedy_gathered_max_buf = torch.empty( + (cap,), dtype=local_max.dtype, device=hs.device + ) + self._draft_greedy_gathered_ids_buf = torch.empty( + (cap,), dtype=global_ids.dtype, device=hs.device + ) + self._draft_greedy_gather_cap = cap + + if ( + self._draft_greedy_index_cap < chunk_len + or self._draft_greedy_best_rank_buf is None + or self._draft_greedy_rank_index_buf is None + or self._draft_greedy_selected_ids_buf is None + or self._draft_greedy_best_rank_buf.device != hs.device + or self._draft_greedy_selected_ids_buf.device != hs.device + ): + self._draft_greedy_best_rank_buf = torch.empty( + (chunk_cap,), dtype=torch.int64, device=hs.device + ) + self._draft_greedy_rank_index_buf = torch.empty( + (1, chunk_cap), dtype=torch.int64, device=hs.device + ) + self._draft_greedy_selected_ids_buf = torch.empty( + (1, chunk_cap), dtype=torch.int64, device=hs.device + ) + self._draft_greedy_index_cap = chunk_cap + + gathered_max = self._draft_greedy_gathered_max_buf[:needed] + gathered_ids = self._draft_greedy_gathered_ids_buf[:needed] + + tp_group.all_gather_into_tensor(gathered_max, local_max.contiguous()) + tp_group.all_gather_into_tensor(gathered_ids, global_ids.contiguous()) + gathered_max = gathered_max.view(tp_size, chunk_len) + gathered_ids = gathered_ids.view(tp_size, chunk_len) + + best_rank = self._draft_greedy_best_rank_buf[:chunk_len] + torch.argmax(gathered_max, dim=0, out=best_rank) + + rank_index = self._draft_greedy_rank_index_buf[:, :chunk_len] + rank_index[0].copy_(best_rank) + selected_ids = self._draft_greedy_selected_ids_buf[:, :chunk_len] + torch.gather(gathered_ids, 0, rank_index, out=selected_ids) + out_token_ids[start:end].copy_(selected_ids.view(-1)) + + return out_token_ids + + def _append_target_hidden_to_draft_kv( + self, + batch: ScheduleBatch, + draft_input: DFlashDraftInput, + ) -> None: + """Materialize the target hidden-state features into the draft KV cache. + + This must be run before exposing new tokens to radix cache (prefix hits), otherwise + another request could reuse target KV indices without having draft KV values. + """ + + bs = batch.batch_size() + device = self.model_runner.device + + if draft_input.target_hidden is None: + raise RuntimeError( + "DFLASH draft state missing target_hidden context features." + ) + if draft_input.ctx_lens.numel() != bs: + raise RuntimeError( + f"DFLASH ctx_lens length mismatch: got {draft_input.ctx_lens.numel()} for bs={bs}." + ) + if draft_input.draft_seq_lens.numel() != bs: + raise RuntimeError( + f"DFLASH draft_seq_lens length mismatch: got {draft_input.draft_seq_lens.numel()} for bs={bs}." 
+ ) + + total_ctx = int(draft_input.target_hidden.shape[0]) + if total_ctx <= 0: + draft_input.ctx_lens = torch.zeros_like(draft_input.ctx_lens) + draft_input.target_hidden = draft_input.target_hidden[:0] + return + + target_req_to_token = batch.req_to_token_pool.req_to_token + draft_req_to_token = self.draft_model_runner.req_to_token_pool.req_to_token + + req_pool_indices = batch.req_pool_indices + if req_pool_indices.dtype != torch.int64: + req_pool_indices = req_pool_indices.to(torch.int64) + + ctx_lens = draft_input.ctx_lens + if ctx_lens.dtype != torch.int32: + ctx_lens = ctx_lens.to(torch.int32) + if ctx_lens.device != device: + ctx_lens = ctx_lens.to(device, non_blocking=True) + ctx_start = batch.seq_lens.to(torch.int64) - ctx_lens.to(torch.int64) + + if bs == 1: + # Fast path for single request. + max_ctx = int(total_ctx) + if max_ctx <= self._block_pos_offsets.numel(): + r = self._block_pos_offsets[:max_ctx] + else: + r = torch.arange(max_ctx, device=device, dtype=torch.int64) + pos2d = ctx_start[:, None] + r[None, :] # [1, ctx] + cache2d = target_req_to_token[req_pool_indices[:, None], pos2d] # [1, ctx] + ctx_cache_loc = cache2d.reshape(-1).to(torch.int64) # [ctx] + ctx_positions = pos2d.reshape(-1) # [ctx] + else: + # In decode mode, ctx_lens <= block_size so we can skip the .item() sync. + if batch.forward_mode.is_extend() or batch.is_extend_in_batch: + max_ctx = int(ctx_lens.max().item()) + else: + max_ctx = int(self.block_size) + if max_ctx <= 0: + raise RuntimeError(f"DFLASH invalid max_ctx={max_ctx} for KV append.") + + if max_ctx <= self._block_pos_offsets.numel(): + r = self._block_pos_offsets[:max_ctx] + else: + r = torch.arange(max_ctx, device=device, dtype=torch.int64) + r = r[None, :] # [1, max_ctx] + pos2d = ctx_start[:, None] + r # [bs, max_ctx] + mask = r < ctx_lens[:, None] + + # Batched gather of cache locations and positions. + ctx_cache_loc = self._gather_req_to_token_masked( + req_to_token=target_req_to_token, + req_pool_indices=req_pool_indices, + pos2d=pos2d, + mask=mask, + context="DFLASH target hidden KV append", + ) # [sum(ctx_lens)] + ctx_positions = pos2d[mask] # [sum(ctx_lens)] + + with torch.inference_mode(): + ctx_hidden = self.draft_model.project_target_hidden( + draft_input.target_hidden + ) # [sum(ctx), hidden] + if ctx_hidden.shape[0] != ctx_cache_loc.numel(): + raise RuntimeError( + f"DFLASH ctx_hidden/cache_loc mismatch: {ctx_hidden.shape[0]} vs {ctx_cache_loc.numel()}." 
+ ) + + if self._use_fused_kv_materialize and self._fused_kv_helper is not None: + try: + self._append_target_hidden_fused( + ctx_hidden, ctx_positions, ctx_cache_loc + ) + except Exception as e: + logger.warning( + "DFLASH fused KV append failed; falling back to sequential path: %s", + e, + ) + self._use_fused_kv_materialize = False + self._fused_kv_helper = None + self._append_target_hidden_sequential( + ctx_hidden, ctx_positions, ctx_cache_loc + ) + else: + self._append_target_hidden_sequential( + ctx_hidden, ctx_positions, ctx_cache_loc + ) + + if self.use_compact_draft_cache: + new_draft_seq_lens = self._compute_compact_draft_seq_lens(batch.seq_lens) + suffix_start = batch.seq_lens.to(torch.int64) - new_draft_seq_lens.to( + torch.int64 + ) + suffix_cache_loc = self._gather_req_to_token_segments( + req_to_token=target_req_to_token, + req_pool_indices=req_pool_indices, + start=suffix_start, + lengths=new_draft_seq_lens, + ) + assign_req_to_token_pool_func( + batch.req_pool_indices, + draft_req_to_token, + torch.zeros_like(new_draft_seq_lens), + new_draft_seq_lens, + suffix_cache_loc, + bs, + ) + draft_input.draft_seq_lens = new_draft_seq_lens + else: + draft_input.draft_seq_lens = batch.seq_lens.to(dtype=torch.int32) + draft_input.ctx_lens = torch.zeros_like(ctx_lens) + draft_input.target_hidden = draft_input.target_hidden[:0] + + def _append_target_hidden_sequential( + self, + ctx_hidden: torch.Tensor, + ctx_positions: torch.Tensor, + ctx_cache_loc: torch.Tensor, + ) -> None: + for layer in self.draft_model.layers: + attn = layer.self_attn + k, v = attn.kv_proj_only(ctx_hidden) + k = attn.apply_k_norm(k) + k = attn.apply_k_rope(ctx_positions, k) + k = k.view(-1, attn.num_kv_heads, attn.head_dim) + v = v.view(-1, attn.num_kv_heads, attn.head_dim) + self.draft_model_runner.token_to_kv_pool.set_kv_buffer( + attn.attn, + ctx_cache_loc, + k, + v, + attn.attn.k_scale, + attn.attn.v_scale, + ) + + def _append_target_hidden_fused( + self, + ctx_hidden: torch.Tensor, + ctx_positions: torch.Tensor, + ctx_cache_loc: torch.Tensor, + ) -> None: + """Fused KV materialization using batched projection + Triton kernel.""" + token_to_kv_pool = self.draft_model_runner.token_to_kv_pool + layers = self.draft_model.layers + + def _write_layer_kv( + layer_idx: int, cache_k: torch.Tensor, cache_v: torch.Tensor + ) -> None: + attn = layers[layer_idx].self_attn.attn + token_to_kv_pool.set_kv_buffer( + attn, + ctx_cache_loc, + cache_k, + cache_v, + attn.k_scale, + attn.v_scale, + ) + + self._fused_kv_helper.materialize( + ctx_hidden=ctx_hidden, + positions=ctx_positions, + write_layer_kv=_write_layer_kv, + ) + + def _update_target_mamba_state_after_verify( + self, + *, + batch: ScheduleBatch, + seq_lens_pre_verify: torch.Tensor, + commit_lens: torch.Tensor, + ) -> None: + """Commit Mamba intermediate states for accepted verify steps. + + During TARGET_VERIFY, Mamba kernels run with `disable_state_update=True` and + cache per-step intermediate states. After acceptance, we need to commit the + state corresponding to each request's last accepted step. 
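+
+ Concretely, `accepted_steps = commit_lens - 1` is that per-request step
+ index, and `mamba_steps_to_track` optionally marks an earlier step whose
+ state must also be checkpointed for the tracking interval.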
+ """ + attn_backend = self.target_worker.model_runner.attn_backend + if not hasattr(attn_backend, "update_mamba_state_after_mtp_verify"): + return + + accepted_steps = commit_lens.to(torch.int64) - 1 + mamba_steps_to_track = None + + if batch.mamba_track_indices is not None: + mamba_track_interval = self.server_args.mamba_track_interval + to_track_mask = ( + seq_lens_pre_verify // mamba_track_interval + != batch.seq_lens // mamba_track_interval + ) + tracking_point = ( + batch.seq_lens // mamba_track_interval * mamba_track_interval + ) + to_track_ith = torch.clamp(tracking_point - seq_lens_pre_verify - 1, min=0) + can_track_mask = to_track_mask & ( + to_track_ith < commit_lens.to(to_track_ith.dtype) + ) + mamba_steps_to_track = torch.where( + can_track_mask, + to_track_ith.to(torch.int64), + torch.full_like(to_track_ith, -1, dtype=torch.int64), + ) + + attn_backend.update_mamba_state_after_mtp_verify( + accepted_steps=accepted_steps, + mamba_track_indices=batch.mamba_track_indices, + mamba_steps_to_track=mamba_steps_to_track, + model=self.target_worker.model_runner.model, + ) + + def forward_batch_generation( + self, + batch: Union[ScheduleBatch, ModelWorkerBatch], + **kwargs, + ) -> GenerationBatchResult: + if getattr(batch, "return_logprob", False): + raise RuntimeError( + "Invariant broken: DFLASH batch requested return_logprob, but scheduler should have rejected this request." + ) + + if isinstance(batch, ModelWorkerBatch): + # Should not happen for spec-v1 (non-overlap) scheduling, but keep a sane fallback. + return self.target_worker.forward_batch_generation(batch, **kwargs) + + if batch.forward_mode.is_extend() or batch.is_extend_in_batch: + model_worker_batch = batch.get_model_worker_batch() + model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL + + batch_result = self.target_worker.forward_batch_generation( + model_worker_batch, **kwargs + ) + logits_output, next_token_ids = ( + batch_result.logits_output, + batch_result.next_token_ids, + ) + if logits_output.hidden_states is None: + raise RuntimeError( + "DFLASH requires target aux hidden capture for prefill, but got None. " + "Make sure the target model has DFlash layers-to-capture configured." + ) + + if ( + model_worker_batch.extend_seq_lens is None + or model_worker_batch.extend_prefix_lens is None + ): + raise RuntimeError( + "DFLASH expected extend_seq_lens / extend_prefix_lens to be populated in extend mode, but got None." + ) + + # Materialize the prompt tokens into the draft KV cache immediately. This is required + # for radix cache support, since the scheduler may update radix after prefill returns. 
+ device = next_token_ids.device + + def _to_int32_device_tensor(x, *, device=device): + if isinstance(x, torch.Tensor): + if x.device != device: + x = x.to(device, non_blocking=True) + return x if x.dtype == torch.int32 else x.to(torch.int32) + return torch.tensor(x, dtype=torch.int32, device=device) + + extend_seq_lens = _to_int32_device_tensor( + model_worker_batch.extend_seq_lens + ) + draft_input = DFlashDraftInput( + verified_id=next_token_ids.to(torch.int64), + target_hidden=logits_output.hidden_states, + ctx_lens=extend_seq_lens, + draft_seq_lens=( + torch.zeros_like(extend_seq_lens) + if self.use_compact_draft_cache + else _to_int32_device_tensor(model_worker_batch.extend_prefix_lens) + ), + ) + self._append_target_hidden_to_draft_kv(batch, draft_input) + batch.spec_info = draft_input + + return GenerationBatchResult( + logits_output=logits_output, + next_token_ids=next_token_ids, + num_accepted_tokens=0, + can_run_cuda_graph=batch_result.can_run_cuda_graph, + ) + + # Decode / target-verify stage. + draft_input = batch.spec_info + if not isinstance(draft_input, DFlashDraftInput): + raise RuntimeError( + "DFLASH decode requires DFlashDraftInput state on the running batch. " + "This usually means the request did not complete the prefill stage." + ) + + self._prepare_for_speculative_decoding(batch, draft_input) + + model_worker_batch = batch.get_model_worker_batch() + assert model_worker_batch.forward_mode.is_target_verify() + verify_input = model_worker_batch.spec_info + assert isinstance(verify_input, DFlashVerifyInput) + need_mamba_verify_commit = hasattr( + self.target_worker.model_runner.attn_backend, + "update_mamba_state_after_mtp_verify", + ) + seq_lens_pre_verify = ( + batch.seq_lens.clone() if need_mamba_verify_commit else None + ) + + batch_result = self.target_worker.forward_batch_generation( + model_worker_batch, is_verify=True, **kwargs + ) + logits_output, can_run_cuda_graph = ( + batch_result.logits_output, + batch_result.can_run_cuda_graph, + ) + + ( + new_verified_id, + commit_lens, + next_target_hidden, + accept_length_per_req_cpu, + ) = verify_input.verify( + batch=batch, + logits_output=logits_output, + page_size=self.page_size, + ) + if need_mamba_verify_commit: + assert seq_lens_pre_verify is not None + self._update_target_mamba_state_after_verify( + batch=batch, + seq_lens_pre_verify=seq_lens_pre_verify, + commit_lens=commit_lens, + ) + + # Update draft state for the next iteration. Also materialize the committed verify tokens + # into the draft KV cache immediately so radix cache entries are safe to reuse. + draft_input.verified_id = new_verified_id + draft_input.target_hidden = next_target_hidden + draft_input.ctx_lens = commit_lens + self._append_target_hidden_to_draft_kv(batch, draft_input) + batch.spec_info = draft_input + batch.forward_mode = ForwardMode.DECODE + + num_accepted_tokens = sum(accept_length_per_req_cpu) + if not self._logged_first_verify and self.tp_rank == 0: + logger.info( + "DFLASH verify completed. 
accept_length_per_req=%s", + accept_length_per_req_cpu, + ) + self._logged_first_verify = True + + return GenerationBatchResult( + logits_output=logits_output, + next_token_ids=new_verified_id, + num_accepted_tokens=num_accepted_tokens, + accept_length_per_req_cpu=accept_length_per_req_cpu, + can_run_cuda_graph=can_run_cuda_graph, + ) diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py index a40a8aa0dc33..3e5727187572 100644 --- a/python/sglang/srt/speculative/spec_info.py +++ b/python/sglang/srt/speculative/spec_info.py @@ -15,6 +15,7 @@ class SpeculativeAlgorithm(Enum): """Enumeration of speculative decoding algorithms.""" + DFLASH = auto() EAGLE = auto() EAGLE3 = auto() STANDALONE = auto() @@ -33,6 +34,9 @@ def from_string(cls, name: Optional[str]) -> SpeculativeAlgorithm: def is_none(self) -> bool: return self == SpeculativeAlgorithm.NONE + def is_speculative(self) -> bool: + return self != SpeculativeAlgorithm.NONE + def is_eagle(self) -> bool: # NOTE: EAGLE3 is a variant of EAGLE return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.EAGLE3 @@ -40,6 +44,9 @@ def is_eagle(self) -> bool: def is_eagle3(self) -> bool: return self == SpeculativeAlgorithm.EAGLE3 + def is_dflash(self) -> bool: + return self == SpeculativeAlgorithm.DFLASH + def is_standalone(self) -> bool: return self == SpeculativeAlgorithm.STANDALONE @@ -57,6 +64,16 @@ def create_worker( ), "Cannot create worker for NONE speculative algorithm." enable_overlap = not server_args.disable_overlap_schedule + + if self.is_dflash(): + if enable_overlap: + raise ValueError( + "DFLASH does not support overlap scheduling (spec v2)." + ) + from sglang.srt.speculative.dflash_worker import DFlashWorker + + return DFlashWorker + if self.is_eagle() and server_args.enable_multi_layer_eagle: # FIXME: migrate to EagleWorker if enable_overlap: @@ -110,6 +127,8 @@ class SpecInputType(IntEnum): # If all algorithms can share the same datastrucutre of draft_input and verify_input, consider simplify it EAGLE_DRAFT = auto() EAGLE_VERIFY = auto() + DFLASH_DRAFT = auto() + DFLASH_VERIFY = auto() NGRAM_VERIFY = auto() @@ -120,11 +139,15 @@ def __init__(self, spec_input_type: SpecInputType): def is_draft_input(self) -> bool: # FIXME: remove this function which is only used for assertion # or use another variable name like `draft_input` to substitute `spec_info` - return self.spec_input_type == SpecInputType.EAGLE_DRAFT + return self.spec_input_type in { + SpecInputType.EAGLE_DRAFT, + SpecInputType.DFLASH_DRAFT, + } def is_verify_input(self) -> bool: return self.spec_input_type in { SpecInputType.EAGLE_VERIFY, + SpecInputType.DFLASH_VERIFY, SpecInputType.NGRAM_VERIFY, } diff --git a/python/sglang/srt/speculative/triton_ops/__init__.py b/python/sglang/srt/speculative/triton_ops/__init__.py new file mode 100644 index 000000000000..a8ea8f4c704b --- /dev/null +++ b/python/sglang/srt/speculative/triton_ops/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Triton kernels for speculative decoding.""" + +from sglang.srt.speculative.triton_ops.fused_kv_materialize import ( + FusedKVMaterializeHelper, +) + +__all__ = ["FusedKVMaterializeHelper"] diff --git a/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py b/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py new file mode 100644 index 000000000000..e7dc4c05ddfc --- /dev/null +++ b/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py @@ -0,0 +1,303 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Fused Triton kernel for DFlash KV materialization. + +Combines: KV projection (cuBLAS) + RMSNorm + RoPE (Triton), then pool-managed KV writes. +""" + +from typing import Callable, List + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fused_norm_rope_kernel( + kv_ptr, # [total_ctx, kv_size * 2] + k_norm_weight_ptr, # [head_dim] + cos_sin_cache_ptr, # [max_pos, rotary_dim] + positions_ptr, # [total_ctx] + k_out_ptr, # [total_ctx, num_kv_heads, head_dim] + v_out_ptr, # [total_ctx, num_kv_heads, head_dim] + kv_stride_ctx, + cos_sin_stride_pos, + k_out_stride_ctx, + k_out_stride_head, + v_out_stride_ctx, + v_out_stride_head, + total_ctx, + num_kv_heads: tl.constexpr, + head_dim: tl.constexpr, + kv_size: tl.constexpr, + rotary_dim: tl.constexpr, + half_rotary_dim: tl.constexpr, + eps: tl.constexpr, + BLOCK_HD: tl.constexpr, +): + """Fused RMSNorm(K) + RoPE(K) materialization. 
Grid: (total_ctx, num_kv_heads).""" + ctx_id = tl.program_id(0) + head_id = tl.program_id(1) + if ctx_id >= total_ctx: + return + + # Load metadata + position = tl.load(positions_ptr + ctx_id) + + # Compute base pointers + kv_base = kv_ptr + ctx_id * kv_stride_ctx + k_base = kv_base + head_id * head_dim + v_base = kv_base + kv_size + head_id * head_dim + k_write = k_out_ptr + ctx_id * k_out_stride_ctx + head_id * k_out_stride_head + v_write = v_out_ptr + ctx_id * v_out_stride_ctx + head_id * v_out_stride_head + + # Load K and V + offs = tl.arange(0, BLOCK_HD) + mask_hd = offs < head_dim + mask_half = offs < half_rotary_dim + + k_raw = tl.load(k_base + offs, mask=mask_hd, other=0.0).to(tl.float32) + v_raw = tl.load(v_base + offs, mask=mask_hd, other=0.0) + + # RMSNorm on K + inv_rms = tl.rsqrt(tl.sum(k_raw * k_raw) / head_dim + eps) + norm_w = tl.load(k_norm_weight_ptr + offs, mask=mask_hd, other=1.0).to(tl.float32) + k_normed = k_raw * inv_rms * norm_w + + # RoPE (neox style): k_first, k_second -> rotated + cos_sin_base = cos_sin_cache_ptr + position * cos_sin_stride_pos + cos_v = tl.load(cos_sin_base + offs, mask=mask_half, other=1.0).to(tl.float32) + sin_v = tl.load( + cos_sin_base + half_rotary_dim + offs, mask=mask_half, other=0.0 + ).to(tl.float32) + + # Extract first/second halves of K for rotation + k_first = tl.where(mask_half, k_normed, 0.0) + k_second_raw = tl.load( + k_base + half_rotary_dim + offs, mask=mask_half, other=0.0 + ).to(tl.float32) + norm_w_second = tl.load( + k_norm_weight_ptr + half_rotary_dim + offs, mask=mask_half, other=1.0 + ).to(tl.float32) + k_second = k_second_raw * inv_rms * norm_w_second + + # Apply rotation + k_rot_first = k_first * cos_v - k_second * sin_v + k_rot_second = k_second * cos_v + k_first * sin_v + + # Store V (no transform) + tl.store(v_write + offs, v_raw, mask=mask_hd) + + # Store K: rotated halves + pass-through + tl.store(k_write + offs, k_rot_first.to(v_raw.dtype), mask=mask_half) + tl.store( + k_write + half_rotary_dim + offs, k_rot_second.to(v_raw.dtype), mask=mask_half + ) + mask_pass = (offs >= rotary_dim) & (offs < head_dim) + tl.store(k_write + offs, k_normed.to(v_raw.dtype), mask=mask_pass) + + +def _fused_norm_rope( + kv: torch.Tensor, # [total_ctx, kv_size*2] + k_norm_weight: torch.Tensor, # [head_dim] + cos_sin_cache: torch.Tensor, # [max_pos, rotary_dim] + positions: torch.Tensor, # [total_ctx] + num_kv_heads: int, + head_dim: int, + rotary_dim: int, + eps: float = 1e-6, +) -> tuple[torch.Tensor, torch.Tensor]: + """Fused RMSNorm + RoPE materialization for a single layer.""" + total_ctx = kv.shape[0] + if total_ctx == 0: + empty = torch.empty( + (0, num_kv_heads, head_dim), dtype=kv.dtype, device=kv.device + ) + return empty, empty + + kv_size = num_kv_heads * head_dim + if kv.shape[1] != kv_size * 2: + raise ValueError( + "Invalid fused KV projection shape: " + f"got {tuple(kv.shape)}, expected second dim {kv_size * 2}." + ) + if rotary_dim <= 0 or rotary_dim > head_dim or rotary_dim % 2 != 0: + raise ValueError( + "Invalid fused KV rotary/head dim pair: " + f"rotary_dim={rotary_dim}, head_dim={head_dim}." 
+ ) + + half_rotary_dim = rotary_dim // 2 + BLOCK_HD = triton.next_power_of_2(head_dim) + + # Ensure int64 for indexing + if positions.device != kv.device: + positions = positions.to(device=kv.device, dtype=torch.int64) + elif positions.dtype != torch.int64: + positions = positions.to(torch.int64) + + k_out = torch.empty( + (total_ctx, num_kv_heads, head_dim), dtype=kv.dtype, device=kv.device + ) + v_out = torch.empty_like(k_out) + + _fused_norm_rope_kernel[(total_ctx, num_kv_heads)]( + kv, + k_norm_weight, + cos_sin_cache, + positions, + k_out, + v_out, + kv.stride(0), + cos_sin_cache.stride(0), + k_out.stride(0), + k_out.stride(1), + v_out.stride(0), + v_out.stride(1), + total_ctx, + num_kv_heads, + head_dim, + kv_size, + rotary_dim, + half_rotary_dim, + eps, + BLOCK_HD, + ) + return k_out, v_out + + +class FusedKVMaterializeHelper: + """Fused KV materialization helper using batched projection. + + Uses torch.einsum for batched KV projection across all layers, + then a Triton kernel for fused RMSNorm + RoPE materialization per layer. + """ + + def __init__( + self, + layers: List, + rotary_emb, + num_kv_heads: int, + head_dim: int, + device: torch.device, + ): + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.rotary_emb = rotary_emb + self.n_layers = len(layers) + self.device = device + + self.rotary_dim = int(getattr(rotary_emb, "rotary_dim", head_dim)) + self.is_neox_style = bool(getattr(rotary_emb, "is_neox_style", True)) + + if not self.is_neox_style: + raise NotImplementedError("Only neox-style RoPE is supported.") + if self.rotary_dim <= 0 or self.rotary_dim > self.head_dim: + raise ValueError( + "Invalid fused KV rotary/head dim pair: " + f"rotary_dim={self.rotary_dim}, head_dim={self.head_dim}." + ) + + # Pre-extract and stack weights for batched projection. + kv_weights = [] + self.k_norm_weights = [] + self.eps_values = [] + + for layer_id, layer in enumerate(layers): + attn = layer.self_attn + if int(attn.num_kv_heads) != self.num_kv_heads: + raise ValueError( + "num_kv_heads mismatch across layers for fused KV path: " + f"expected {self.num_kv_heads}, got {int(attn.num_kv_heads)} at layer {layer_id}." + ) + if int(attn.head_dim) != self.head_dim: + raise ValueError( + "head_dim mismatch across layers for fused KV path: " + f"expected {self.head_dim}, got {int(attn.head_dim)} at layer {layer_id}." + ) + layer_rotary_dim = int( + getattr(attn.rotary_emb, "rotary_dim", self.head_dim) + ) + layer_is_neox = bool(getattr(attn.rotary_emb, "is_neox_style", True)) + if ( + layer_rotary_dim != self.rotary_dim + or layer_is_neox != self.is_neox_style + ): + raise ValueError( + "RoPE config mismatch across layers for fused KV path: " + f"expected (rotary_dim={self.rotary_dim}, neox={self.is_neox_style}), " + f"got (rotary_dim={layer_rotary_dim}, neox={layer_is_neox}) at layer {layer_id}." 
+ ) + + # Extract KV portion of QKV weight + qkv_w = attn.qkv_proj.weight + kv_weight = qkv_w[attn.q_size : attn.q_size + 2 * attn.kv_size] + kv_weights.append(kv_weight) + self.k_norm_weights.append(attn.k_norm.weight) + self.eps_values.append(attn.k_norm.variance_epsilon) + + # Stack for batched einsum: [n_layers, kv_size*2, hidden_size] + self.batched_kv_weight = torch.stack(kv_weights) + + def materialize( + self, + ctx_hidden: torch.Tensor, + positions: torch.Tensor, + write_layer_kv: Callable[[int, torch.Tensor, torch.Tensor], None], + ) -> None: + """Materialize KV cache for all layers using batched projection.""" + total_ctx = ctx_hidden.shape[0] + if total_ctx == 0: + return + + if positions.ndim != 1: + positions = positions.reshape(-1) + if positions.numel() != total_ctx: + raise ValueError( + "positions must match ctx_hidden token count for fused KV materialization: " + f"positions={positions.numel()}, total_ctx={total_ctx}." + ) + + max_position = int(positions.max().item()) + ensure_cos_sin_cache_length = getattr( + self.rotary_emb, "_ensure_cos_sin_cache_length", None + ) + if callable(ensure_cos_sin_cache_length): + ensure_cos_sin_cache_length(max_position) + + cos_sin_cache = self.rotary_emb.cos_sin_cache + if max_position >= int(cos_sin_cache.shape[0]): + raise RuntimeError( + "RoPE cos/sin cache is too short for fused KV materialization: " + f"max_position={max_position}, cache_len={int(cos_sin_cache.shape[0])}." + ) + if cos_sin_cache.device != ctx_hidden.device: + cos_sin_cache = cos_sin_cache.to(ctx_hidden.device) + + # Batched KV projection: [n_layers, total_ctx, kv_size*2] + kv_all = torch.einsum("th,loh->lto", ctx_hidden, self.batched_kv_weight) + + # Per-layer fused norm/RoPE/materialize, then delegate writes to the KV pool. 
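+        # Shape walkthrough (illustrative): ctx_hidden is [t, h] and
+        # batched_kv_weight is [l, 2 * kv_size, h], so einsum("th,loh->lto")
+        # yields kv_all of shape [l, t, 2 * kv_size]; per layer this equals
+        # ctx_hidden @ kv_weight.T, but it is issued as a single batched GEMM.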
+ for layer_id in range(self.n_layers): + cache_k, cache_v = _fused_norm_rope( + kv_all[layer_id], + self.k_norm_weights[layer_id], + cos_sin_cache, + positions, + self.num_kv_heads, + self.head_dim, + self.rotary_dim, + self.eps_values[layer_id], + ) + write_layer_kv(layer_id, cache_k, cache_v) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 9cbd2e59dc90..6022f602c3f4 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -107,6 +107,10 @@ DEFAULT_TARGET_MODEL_EAGLE3 = "meta-llama/Llama-3.1-8B-Instruct" DEFAULT_DRAFT_MODEL_EAGLE3 = "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B" +# DFLASH model +DEFAULT_TARGET_MODEL_DFLASH = "meta-llama/Llama-3.1-8B-Instruct" +DEFAULT_DRAFT_MODEL_DFLASH = "z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat" + # EAGLE2 with DP-Attention models DEFAULT_TARGET_MODEL_EAGLE_DP_ATTN = "Qwen/Qwen3-30B-A3B" DEFAULT_DRAFT_MODEL_EAGLE_DP_ATTN = "Tengyunw/qwen3_30b_moe_eagle3" diff --git a/test/registered/spec/dflash/test_dflash.py b/test/registered/spec/dflash/test_dflash.py new file mode 100644 index 000000000000..aa9ee2327d21 --- /dev/null +++ b/test/registered/spec/dflash/test_dflash.py @@ -0,0 +1,152 @@ +import os +import unittest + +import openai + +from sglang.srt.environ import envs +from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.kits.eval_accuracy_kit import GSM8KMixin +from sglang.test.kits.matched_stop_kit import MatchedStopMixin +from sglang.test.kits.radix_cache_server_kit import gen_radix_tree +from sglang.test.test_utils import ( + DEFAULT_DRAFT_MODEL_DFLASH, + DEFAULT_TARGET_MODEL_DFLASH, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +register_cuda_ci(est_time=300, suite="stage-b-test-1-gpu-small") + + +class TestDFlashServerBase(CustomTestCase, MatchedStopMixin, GSM8KMixin): + max_running_requests = 64 + attention_backend = "flashinfer" + page_size = 1 + other_launch_args = [] + model = DEFAULT_TARGET_MODEL_DFLASH + draft_model = DEFAULT_DRAFT_MODEL_DFLASH + gsm8k_accuracy_thres = 0.75 + gsm8k_accept_length_thres = 2.8 + + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + launch_args = [ + "--trust-remote-code", + "--attention-backend", + cls.attention_backend, + "--speculative-algorithm", + "DFLASH", + "--speculative-draft-model-path", + cls.draft_model, + "--page-size", + str(cls.page_size), + "--max-running-requests", + str(cls.max_running_requests), + "--cuda-graph-bs", + *[str(i) for i in range(1, cls.max_running_requests + 1)], + ] + launch_args.extend(cls.other_launch_args) + old_value = os.environ.get("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN") + os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1" + try: + with envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override( + 1 + ), envs.SGLANG_SPEC_NAN_DETECTION.override( + True + ), envs.SGLANG_SPEC_OOB_DETECTION.override( + True + ): + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=launch_args, + ) + finally: + if old_value is None: + del os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] + else: + os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = old_value + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_early_stop(self): + client = openai.Client(base_url=self.base_url + "/v1", api_key="EMPTY") + for i in range(8): + max_tokens = (i % 
3) + 1 + response = client.completions.create( + model=self.model, + prompt=f"There are {i} apples on the table. How to divide them equally?", + max_tokens=max_tokens, + temperature=0, + ) + text = response.choices[0].text + print(f"early_stop: max_tokens={max_tokens}, text={text!r}") + assert self.process.poll() is None + + def test_eos_handling(self): + client = openai.Client(base_url=self.base_url + "/v1", api_key="EMPTY") + response = client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": "Today is a sunny day and I like"}], + max_tokens=256, + temperature=0.1, + ) + text = response.choices[0].message.content + print(f"eos_handling: text={text!r}") + self.assertNotIn("<|eot_id|>", text) + self.assertNotIn("<|end_of_text|>", text) + assert self.process.poll() is None + + def test_greedy_determinism(self): + client = openai.Client(base_url=self.base_url + "/v1", api_key="EMPTY") + prompt = "The capital of France is" + outputs = [] + for _ in range(2): + response = client.completions.create( + model=self.model, + prompt=prompt, + max_tokens=32, + temperature=0, + ) + outputs.append(response.choices[0].text) + print(f"determinism: {outputs=}") + self.assertEqual(outputs[0], outputs[1]) + assert self.process.poll() is None + + +class TestDFlashServerPage256(TestDFlashServerBase): + page_size = 256 + + def test_radix_attention(self): + import requests + + nodes = gen_radix_tree(num_nodes=50) + data = { + "input_ids": [node["input_ids"] for node in nodes], + "sampling_params": [ + {"max_new_tokens": node["decode_len"], "temperature": 0} + for node in nodes + ], + } + res = requests.post(self.base_url + "/generate", json=data) + assert res.status_code == 200 + assert self.process.poll() is None + + +class TestDFlashServerChunkedPrefill(TestDFlashServerBase): + other_launch_args = ["--chunked-prefill-size", "4"] + + +class TestDFlashServerNoCudaGraph(TestDFlashServerBase): + other_launch_args = ["--disable-cuda-graph"] + + +if __name__ == "__main__": + unittest.main()
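+
+# Reviewer aid (not executed): an eager-mode sketch of the fused RMSNorm +
+# neox-style RoPE transform in fused_kv_materialize.py, assuming the
+# [cos | sin] per-position cos_sin_cache layout the Triton kernel reads.
+# A future unit test could compare _fused_norm_rope against this reference:
+#
+#   def ref_norm_rope(kv, k_norm_w, cos_sin_cache, positions,
+#                     num_kv_heads, head_dim, rotary_dim, eps=1e-6):
+#       t = kv.shape[0]
+#       k, v = kv.split(num_kv_heads * head_dim, dim=1)
+#       k = k.view(t, num_kv_heads, head_dim).float()
+#       v = v.view(t, num_kv_heads, head_dim)
+#       # RMSNorm over the head dimension, matching the kernel's inv_rms
+#       k = k * torch.rsqrt(k.pow(2).mean(-1, keepdim=True) + eps) * k_norm_w.float()
+#       half = rotary_dim // 2
+#       cos = cos_sin_cache[positions, :half].float()[:, None, :]
+#       sin = cos_sin_cache[positions, half:rotary_dim].float()[:, None, :]
+#       k1, k2 = k[..., :half], k[..., half:rotary_dim]
+#       k = torch.cat(
+#           [k1 * cos - k2 * sin, k2 * cos + k1 * sin, k[..., rotary_dim:]], dim=-1
+#       )
+#       return k.to(kv.dtype), v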