From 10e563feef9488000c1ed257749510afb85c3ce3 Mon Sep 17 00:00:00 2001 From: David Wang Date: Tue, 6 Jan 2026 23:24:03 +0000 Subject: [PATCH 01/73] starting dflash impl --- .../sglang/srt/model_executor/model_runner.py | 40 +++ python/sglang/srt/models/qwen3.py | 12 + python/sglang/srt/server_args.py | 67 +++- .../srt/speculative/dflash_draft_model.py | 273 +++++++++++++++ python/sglang/srt/speculative/dflash_info.py | 325 +++++++++++++++++ python/sglang/srt/speculative/dflash_utils.py | 87 +++++ .../sglang/srt/speculative/dflash_worker.py | 328 ++++++++++++++++++ python/sglang/srt/speculative/spec_info.py | 16 +- .../models/test_qwen3_dflash_correctness.py | 132 +++++++ test/srt/test_dflash_acceptance_unit.py | 55 +++ 10 files changed, 1333 insertions(+), 2 deletions(-) create mode 100644 python/sglang/srt/speculative/dflash_draft_model.py create mode 100644 python/sglang/srt/speculative/dflash_info.py create mode 100644 python/sglang/srt/speculative/dflash_utils.py create mode 100644 python/sglang/srt/speculative/dflash_worker.py create mode 100644 test/manual/models/test_qwen3_dflash_correctness.py create mode 100644 test/srt/test_dflash_acceptance_unit.py diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 1d69c0582781..d50668ff950e 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -138,6 +138,7 @@ set_global_server_args_for_scheduler, ) from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.speculative.dflash_utils import build_target_layer_ids from sglang.srt.utils import ( MultiprocessingSerializer, cpu_has_amx_support, @@ -316,6 +317,7 @@ def __init__( self.remote_instance_transfer_engine_weight_info = None # auxiliary hidden capture mode. TODO: expose this to server args? self.eagle_use_aux_hidden_state = False + self.dflash_use_aux_hidden_state = False if self.spec_algorithm.is_eagle3() and not self.is_draft_worker: # load draft config draft_model_config = ModelConfig.from_server_args( @@ -341,6 +343,36 @@ def __init__( # if there is no aux layer, set to None self.eagle_aux_hidden_state_layer_ids = None + if self.spec_algorithm.is_dflash() and not self.is_draft_worker: + # Select target layers to capture for building DFlash context features. + draft_model_config = ModelConfig.from_server_args( + server_args, + model_path=(server_args.speculative_draft_model_path), + model_revision=server_args.speculative_draft_model_revision, + is_draft_model=True, + ) + draft_num_layers = getattr(draft_model_config.hf_config, "num_hidden_layers", None) + target_num_layers = getattr(self.model_config.hf_config, "num_hidden_layers", None) + if draft_num_layers is None or target_num_layers is None: + raise ValueError( + "DFLASH requires both draft and target to expose num_hidden_layers in config. " + f"Got draft={draft_num_layers}, target={target_num_layers}." + ) + + trained_target_layers = getattr(draft_model_config.hf_config, "num_target_layers", None) + if trained_target_layers is not None and trained_target_layers != target_num_layers: + logger.warning( + "DFLASH draft config num_target_layers=%s differs from runtime target num_hidden_layers=%s; " + "selecting capture layers based on the runtime target model.", + trained_target_layers, + target_num_layers, + ) + + self.dflash_use_aux_hidden_state = True + self.dflash_aux_hidden_state_layer_ids = build_target_layer_ids( + int(target_num_layers), int(draft_num_layers) + ) + # Apply the rank zero filter to logger if server_args.show_time_cost: enable_show_time_cost() @@ -581,6 +613,14 @@ def initialize(self, min_per_gpu_memory: float): self.eagle_aux_hidden_state_layer_ids ) + if self.dflash_use_aux_hidden_state: + if not hasattr(self.model, "set_dflash_layers_to_capture"): + raise ValueError( + f"Model {self.model.__class__.__name__} does not implement set_dflash_layers_to_capture, " + "which is required for DFLASH." + ) + self.model.set_dflash_layers_to_capture(self.dflash_aux_hidden_state_layer_ids) + # Initialize piecewise CUDA graph self.init_piecewise_cuda_graphs() diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index 9220831f6c6a..70e8dad75701 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -587,5 +587,17 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None): else: self.model.layers_to_capture = [val + 1 for val in layer_ids] + def set_dflash_layers_to_capture(self, layer_ids: List[int]): + if not self.pp_group.is_last_rank: + return + + if layer_ids is None: + raise ValueError("DFLASH requires explicit layer_ids for aux hidden capture.") + + self.capture_aux_hidden_states = True + # SGLang captures "before layer i". To capture the hidden state after target + # layer `k` (HF-style), we capture before layer `k + 1`. + self.model.layers_to_capture = [val + 1 for val in layer_ids] + EntryClass = Qwen3ForCausalLM diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index a2b26e0e0deb..67ad21690d0d 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2037,6 +2037,71 @@ def _handle_speculative_decoding(self): if self.speculative_algorithm == "NEXTN": self.speculative_algorithm = "EAGLE" + if self.speculative_algorithm == "DFLASH": + if self.enable_dp_attention: + raise ValueError( + "Currently DFLASH speculative decoding does not support dp attention." + ) + + if self.tp_size != 1 or self.pp_size != 1: + raise ValueError( + "Currently DFLASH speculative decoding only supports tp_size == 1 and pp_size == 1." + ) + + if self.speculative_draft_model_path is None: + raise ValueError( + "DFLASH speculative decoding requires setting --speculative-draft-model-path." + ) + + # Set default spec params expected by generic spec-v1 plumbing. + # For DFlash, the natural unit is `block_size` (verify window length). + if self.speculative_num_steps is None: + self.speculative_num_steps = 1 + if self.speculative_eagle_topk is None: + self.speculative_eagle_topk = 1 + if self.speculative_num_draft_tokens is None: + inferred_block_size = None + try: + if os.path.isdir(self.speculative_draft_model_path): + draft_config_path = os.path.join( + self.speculative_draft_model_path, "config.json" + ) + if os.path.isfile(draft_config_path): + with open(draft_config_path, "r") as f: + draft_config_json = json.load(f) + inferred_block_size = draft_config_json.get("block_size") + except Exception as e: + logger.warning( + "Failed to infer DFlash block_size from draft config.json; " + "defaulting speculative_num_draft_tokens to 16. Error: %s", + e, + ) + + if inferred_block_size is None: + inferred_block_size = 16 + logger.warning( + "speculative_num_draft_tokens is not set; defaulting to %d for DFLASH.", + inferred_block_size, + ) + self.speculative_num_draft_tokens = int(inferred_block_size) + + if self.max_running_requests is None: + self.max_running_requests = 48 + logger.warning( + "Max running requests is reset to 48 for speculative decoding. You can override this by explicitly setting --max-running-requests." + ) + + self.disable_overlap_schedule = True + logger.warning( + "Overlap scheduler is disabled when using DFLASH speculative decoding (spec v2 is not supported yet)." + ) + + if self.enable_mixed_chunk: + self.enable_mixed_chunk = False + logger.warning( + "Mixed chunked prefill is disabled because of using dflash speculative decoding." + ) + if self.speculative_algorithm in ("EAGLE", "EAGLE3", "STANDALONE"): if self.speculative_algorithm == "STANDALONE" and self.enable_dp_attention: # TODO: support dp attention for standalone speculative decoding @@ -3427,7 +3492,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--speculative-algorithm", type=str, - choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "NGRAM"], + choices=["DFLASH", "EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "NGRAM"], help="Speculative algorithm.", ) parser.add_argument( diff --git a/python/sglang/srt/speculative/dflash_draft_model.py b/python/sglang/srt/speculative/dflash_draft_model.py new file mode 100644 index 000000000000..bd6b22132c09 --- /dev/null +++ b/python/sglang/srt/speculative/dflash_draft_model.py @@ -0,0 +1,273 @@ +from __future__ import annotations + +import logging +import os +from typing import Optional + +import torch +from safetensors.torch import safe_open +from torch import nn +from transformers import AutoConfig, DynamicCache +from transformers.cache_utils import Cache +from transformers.models.qwen3.modeling_qwen3 import ( + ALL_ATTENTION_FUNCTIONS, + Qwen3MLP, + Qwen3RMSNorm, + Qwen3RotaryEmbedding, + FlashAttentionKwargs, + eager_attention_forward, + rotate_half, +) +from typing_extensions import Unpack + +logger = logging.getLogger(__name__) + + +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + unsqueeze_dim: int = 1, +) -> tuple[torch.Tensor, torch.Tensor]: + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_len = q.size(-2) + q_embed = (q * cos[..., -q_len:, :]) + (rotate_half(q) * sin[..., -q_len:, :]) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen3DFlashAttention(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = False + + self.q_proj = nn.Linear( + config.hidden_size, + config.num_attention_heads * self.head_dim, + bias=config.attention_bias, + ) + self.k_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + config.hidden_size, + config.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + config.num_attention_heads * self.head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + + self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + target_hidden: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + bsz, q_len = hidden_states.shape[:-1] + ctx_len = target_hidden.shape[1] + + q = self.q_proj(hidden_states) + q = q.view(bsz, q_len, -1, self.head_dim) + q = self.q_norm(q).transpose(1, 2) + + k_ctx = self.k_proj(target_hidden) + k_noise = self.k_proj(hidden_states) + v_ctx = self.v_proj(target_hidden) + v_noise = self.v_proj(hidden_states) + k = torch.cat([k_ctx, k_noise], dim=1).view(bsz, ctx_len + q_len, -1, self.head_dim) + v = torch.cat([v_ctx, v_noise], dim=1).view(bsz, ctx_len + q_len, -1, self.head_dim) + k = self.k_norm(k).transpose(1, 2) + v = v.transpose(1, 2) + + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb(q, k, cos, sin) + + if past_key_values is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + k, v = past_key_values.update(k, v, self.layer_idx, cache_kwargs) + + attn_fn = eager_attention_forward + if getattr(self.config, "_attn_implementation", "eager") != "eager": + attn_fn = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attn_fn( + self, + q, + k, + v, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=None, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class Qwen3DFlashDecoderLayer(nn.Module): + def __init__(self, config, layer_idx: int): + super().__init__() + self.self_attn = Qwen3DFlashAttention(config=config, layer_idx=layer_idx) + self.mlp = Qwen3MLP(config) + self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + *, + target_hidden: torch.Tensor, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: bool = False, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + target_hidden=target_hidden, + attention_mask=attention_mask, + past_key_values=past_key_value, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + )[0] + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class DFlashDraftModel(nn.Module): + """Local (non-trust_remote_code) DFlash draft model implementation. + + This is adapted from the DFlash reference `modeling_dflash.py` shipped with + the draft checkpoint, but is loaded as first-party code in SGLang. + + The model intentionally does NOT include embedding or lm_head weights; the + DFlash algorithm uses the target model's embedding and lm_head. + """ + + def __init__(self, config) -> None: + super().__init__() + self.config = config + + self.layers = nn.ModuleList( + [ + Qwen3DFlashDecoderLayer(config, layer_idx=i) + for i in range(config.num_hidden_layers) + ] + ) + self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = Qwen3RotaryEmbedding(config) + + # DFlash context feature projection: concat(draft_num_layers x hidden_size) -> hidden_size. + self.fc = nn.Linear( + config.num_hidden_layers * config.hidden_size, + config.hidden_size, + bias=False, + ) + self.hidden_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.block_size = config.block_size + + def forward( + self, + *, + noise_embedding: torch.Tensor, + target_hidden: torch.Tensor, + position_ids: torch.LongTensor, + past_key_values: Optional[Cache] = None, + use_cache: bool = False, + attention_mask: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> torch.Tensor: + hidden_states = noise_embedding + target_hidden = self.hidden_norm(self.fc(target_hidden)) + + position_embeddings = self.rotary_emb(hidden_states, position_ids) + for layer in self.layers: + hidden_states = layer( + hidden_states=hidden_states, + target_hidden=target_hidden, + attention_mask=attention_mask, + past_key_value=past_key_values, + cache_position=cache_position, + use_cache=use_cache, + position_embeddings=position_embeddings, + **kwargs, + ) + + return self.norm(hidden_states) + + def make_cache(self) -> DynamicCache: + return DynamicCache() + + +def load_dflash_draft_model( + model_path: str, + *, + device: torch.device, + dtype: torch.dtype, +) -> tuple[DFlashDraftModel, object]: + """Load DFlash draft model weights from a local folder.""" + config = AutoConfig.from_pretrained(model_path, trust_remote_code=False) + # Ensure we don't accidentally select optional FlashAttention implementations. + setattr(config, "_attn_implementation", "eager") + + model = DFlashDraftModel(config).to(device=device, dtype=dtype) + + weights_path = os.path.join(model_path, "model.safetensors") + if not os.path.isfile(weights_path): + raise FileNotFoundError(f"DFLASH draft weights not found: {weights_path}") + + model_state = model.state_dict() + unexpected: list[str] = [] + with safe_open(weights_path, framework="pt", device=str(device)) as f: + for key in f.keys(): + if key not in model_state: + unexpected.append(key) + continue + model_state[key].copy_(f.get_tensor(key)) + + if unexpected: + logger.warning( + "DFLASH draft checkpoint has %d unexpected keys (ignored). Example: %s", + len(unexpected), + unexpected[0], + ) + + model.eval() + model.requires_grad_(False) + return model, config + diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py new file mode 100644 index 000000000000..e9a0c93924e0 --- /dev/null +++ b/python/sglang/srt/speculative/dflash_info.py @@ -0,0 +1,325 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Tuple + +import torch +from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.mem_cache.common import ( + alloc_paged_token_slots_extend, + alloc_token_slots, + get_last_loc, +) +from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode +from sglang.srt.speculative.spec_info import SpecInput, SpecInputType +from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func +from sglang.srt.speculative.dflash_utils import compute_dflash_accept_len_and_bonus + + +@dataclass +class DFlashDraftInput(SpecInput): + """Per-batch DFlash draft state for spec-v1 (non-overlap) scheduling. + + This object is stored on `ScheduleBatch.spec_info` between decode iterations. + It is NOT sent to model attention backends; the DFlash worker uses it to run + the draft model and to carry draft-side caches. + + Invariant (per request): + - `draft_cache.get_seq_length() + ctx_len == batch.seq_lens[i]` + where `ctx_len` is the number of target context-feature tokens carried in + `target_hidden` for that request. + """ + + # Current token to start the next DFlash block (one per request). + verified_id: torch.Tensor + + # Flattened context features for tokens that need to be appended into the draft cache. + # Shape: [sum(ctx_lens), num_draft_layers * hidden_size] + target_hidden: torch.Tensor + + # Context lengths on CPU, one per request. Used to slice `target_hidden`. + ctx_lens_cpu: List[int] + + # Per-request transformers DynamicCache objects for the draft model. + draft_caches: List[object] + + def __post_init__(self): + super().__init__(spec_input_type=SpecInputType.DFLASH_DRAFT) + + def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]: + # Draft state does not change token accounting. + return (1, 1) + + def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True): + keep_indices = new_indices.tolist() + + old_ctx_lens_cpu = self.ctx_lens_cpu + old_target_hidden = self.target_hidden + + self.verified_id = self.verified_id[new_indices] + self.draft_caches = [self.draft_caches[i] for i in keep_indices] + self.ctx_lens_cpu = [old_ctx_lens_cpu[i] for i in keep_indices] + + if old_target_hidden is None or old_target_hidden.numel() == 0: + self.target_hidden = old_target_hidden + return + + old_offsets: List[int] = [0] + for ln in old_ctx_lens_cpu: + old_offsets.append(old_offsets[-1] + int(ln)) + + segments: List[torch.Tensor] = [] + for idx in keep_indices: + ln = int(old_ctx_lens_cpu[idx]) + if ln == 0: + continue + segments.append(old_target_hidden[old_offsets[idx] : old_offsets[idx + 1]]) + + self.target_hidden = torch.cat(segments, dim=0) if segments else old_target_hidden[:0] + + def merge_batch(self, spec_info: "DFlashDraftInput"): + self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], dim=0) + self.draft_caches.extend(spec_info.draft_caches) + self.ctx_lens_cpu.extend(spec_info.ctx_lens_cpu) + if self.target_hidden is None or self.target_hidden.numel() == 0: + self.target_hidden = spec_info.target_hidden + elif spec_info.target_hidden is not None and spec_info.target_hidden.numel() > 0: + self.target_hidden = torch.cat([self.target_hidden, spec_info.target_hidden], dim=0) + + +@dataclass +class DFlashVerifyInput(SpecInput): + """Inputs for a target-model verify forward in DFlash (spec-v1). + + The verify forward is run with `ForwardMode.TARGET_VERIFY` so that the target + model returns logits for all tokens in the block, enabling accept-length + computation. + """ + + draft_token: torch.Tensor + positions: torch.Tensor + draft_token_num: int + capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL + + def __post_init__(self): + super().__init__(spec_input_type=SpecInputType.DFLASH_VERIFY) + + def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]: + return self.draft_token_num, self.draft_token_num + + def prepare_for_verify(self, batch: ScheduleBatch, page_size: int): + if batch.forward_mode.is_idle(): + return + + batch.input_ids = self.draft_token + + if page_size == 1: + batch.out_cache_loc = alloc_token_slots(batch.tree_cache, len(batch.input_ids)) + end_offset = batch.seq_lens + self.draft_token_num + else: + prefix_lens = batch.seq_lens + prefix_lens_cpu = batch.seq_lens_cpu + end_offset = prefix_lens + self.draft_token_num + end_offset_cpu = prefix_lens_cpu + self.draft_token_num + last_loc = get_last_loc( + batch.req_to_token_pool.req_to_token, + batch.req_pool_indices, + prefix_lens, + ) + batch.out_cache_loc = alloc_paged_token_slots_extend( + batch.tree_cache, + prefix_lens, + prefix_lens_cpu, + end_offset, + end_offset_cpu, + last_loc, + len(batch.input_ids), + ) + self.last_loc = last_loc + + bs = batch.batch_size() + assign_req_to_token_pool_func( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + end_offset, + batch.out_cache_loc, + bs, + ) + + def generate_attn_arg_prefill( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + paged_kernel_lens_sum: int, + req_to_token: torch.Tensor, + ): + device = req_pool_indices.device + bs = len(req_pool_indices) + + qo_indptr = torch.arange( + 0, + (bs + 1) * self.draft_token_num, + step=self.draft_token_num, + dtype=torch.int32, + device=device, + ) + + cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device=device) + paged_kernel_lens = paged_kernel_lens + self.draft_token_num + cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) + + kv_indices = torch.empty( + paged_kernel_lens_sum + self.draft_token_num * bs, + dtype=torch.int32, + device=device, + ) + create_flashinfer_kv_indices_triton[(bs,)]( + req_to_token, + req_pool_indices, + paged_kernel_lens, + cum_kv_seq_len, + None, + kv_indices, + req_to_token.size(1), + ) + + # Causal custom mask for the verify block: + # - all query tokens can attend to all prefix tokens + # - within the block, use a lower-triangular mask + prefix_mask = torch.full( + (paged_kernel_lens_sum * self.draft_token_num,), + True, + dtype=torch.bool, + device=device, + ) + tri = torch.tril( + torch.ones( + (self.draft_token_num, self.draft_token_num), + dtype=torch.bool, + device=device, + ) + ).flatten() + custom_mask = torch.cat([prefix_mask, tri.repeat(bs)], dim=0) + + return kv_indices, cum_kv_seq_len, qo_indptr, custom_mask + + def verify( + self, + *, + batch: ScheduleBatch, + logits_output: LogitsProcessorOutput, + page_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: + """Greedy DFlash verification. + + Returns: + new_verified_id: int64 tensor [bs] (the new current token per request) + commit_lens: int32 tensor [bs] (how many verify-input tokens are committed) + next_target_hidden: tensor [sum(commit_lens), feature_dim] + accept_length_per_req_cpu: list[int] (accepted draft tokens per request) + """ + if batch.forward_mode.is_idle(): + empty = torch.empty((0,), dtype=torch.int64, device=batch.device) + return empty, empty.to(torch.int32), empty, [] + + bs = batch.batch_size() + device = logits_output.next_token_logits.device + + candidates = self.draft_token.view(bs, self.draft_token_num) + target_predict = torch.argmax(logits_output.next_token_logits, dim=-1).view( + bs, self.draft_token_num + ) + accept_len, bonus = compute_dflash_accept_len_and_bonus( + candidates=candidates, + target_predict=target_predict, + ) + + candidates_cpu = candidates.cpu().tolist() + accept_len_cpu = accept_len.cpu().tolist() + bonus_cpu = bonus.cpu().tolist() + + commit_lens_cpu: List[int] = [] + new_verified_cpu: List[int] = [] + accept_length_per_req_cpu: List[int] = [] + + for i, req in enumerate(batch.reqs): + # Proposed: accepted draft tokens, then the bonus token. + proposed = candidates_cpu[i][1 : 1 + accept_len_cpu[i]] + [bonus_cpu[i]] + + appended = 0 + for tok in proposed: + req.output_ids.append(int(tok)) + appended += 1 + req.check_finished() + if req.finished(): + break + if req.grammar is not None: + req.grammar.accept_token(int(tok)) + + # DFlash always treats the last appended token as the new "current token" + # (uncommitted); therefore we commit exactly `appended` verify-input tokens. + if appended <= 0: + raise RuntimeError("DFLASH verify unexpectedly appended 0 tokens.") + commit_lens_cpu.append(appended) + new_verified_cpu.append(req.output_ids[-1]) + accept_length_per_req_cpu.append(max(0, appended - 1)) + + req.spec_verify_ct += 1 + req.spec_accepted_tokens += accept_length_per_req_cpu[-1] + + commit_lens = torch.tensor(commit_lens_cpu, dtype=torch.int32, device=device) + + # Free uncommitted KV cache slots and compact out_cache_loc. + if page_size == 1: + out_cache_loc = batch.out_cache_loc.view(bs, self.draft_token_num) + keep_mask = ( + torch.arange(self.draft_token_num, device=device)[None, :] + < commit_lens[:, None] + ) + batch.token_to_kv_pool_allocator.free(out_cache_loc[~keep_mask]) + batch.out_cache_loc = out_cache_loc[keep_mask] + else: + # Page-size > 1 is not supported in the initial DFlash implementation. + raise NotImplementedError("DFLASH verify with page_size > 1 is not supported yet.") + + # Update req-level KV cache accounting. + for req, commit_len in zip(batch.reqs, commit_lens_cpu, strict=True): + req.kv_committed_len += commit_len + req.kv_allocated_len = req.kv_committed_len + + # Update req_to_token pool mapping for newly committed tokens. + end_offset = batch.seq_lens + commit_lens.to(batch.seq_lens.dtype) + assign_req_to_token_pool_func( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + end_offset, + batch.out_cache_loc, + bs, + ) + + # Update batch seq lens. + batch.seq_lens.add_(commit_lens.to(batch.seq_lens.dtype)) + batch.seq_lens_cpu.add_(torch.tensor(commit_lens_cpu, dtype=batch.seq_lens_cpu.dtype)) + # Keep seq_lens_sum in sync; flashinfer indices updaters rely on this for buffer sizing. + batch.seq_lens_sum += sum(commit_lens_cpu) + + # Build next-step context features from the committed verify-input tokens. + hidden = logits_output.hidden_states + if hidden is None: + raise RuntimeError("DFLASH verify requires target hidden states, but got None.") + hidden = hidden.view(bs, self.draft_token_num, -1) + segments: List[torch.Tensor] = [] + for i, ln in enumerate(commit_lens_cpu): + if ln > 0: + segments.append(hidden[i, :ln, :]) + next_target_hidden = torch.cat(segments, dim=0) if segments else hidden[:0] + + # Avoid confusing downstream consumers (spec-v1 decode doesn't use this). + logits_output.hidden_states = None + + new_verified_id = torch.tensor(new_verified_cpu, dtype=torch.int64, device=device) + return new_verified_id, commit_lens, next_target_hidden, accept_length_per_req_cpu diff --git a/python/sglang/srt/speculative/dflash_utils.py b/python/sglang/srt/speculative/dflash_utils.py new file mode 100644 index 000000000000..f1a9b191bc72 --- /dev/null +++ b/python/sglang/srt/speculative/dflash_utils.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from typing import List, Tuple + +import torch + + +def build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> List[int]: + """Select target layer indices used to build DFlash context features. + + Mirrors the upstream DFlash helper in `docs/dflash/model/utils.py`, but keeps the + logic local to SGLang. + + Args: + num_target_layers: Number of transformer layers in the runtime target model. + num_draft_layers: Number of layers in the DFlash draft model. + + Returns: + A list of 0-based target layer indices of length `num_draft_layers`. + + Notes: + - DFlash uses hidden states after each selected target layer (HF-style). + - SGLang captures "before layer i", so the model hook will typically add +1 + when mapping to capture points. + """ + if num_target_layers <= 0: + raise ValueError(f"num_target_layers must be positive, got {num_target_layers}.") + if num_draft_layers <= 0: + raise ValueError(f"num_draft_layers must be positive, got {num_draft_layers}.") + + if num_draft_layers == 1: + return [num_target_layers // 2] + + start = 1 + end = num_target_layers - 3 + if end < start: + raise ValueError( + "DFlash layer selection requires num_target_layers >= 4. " + f"Got num_target_layers={num_target_layers}." + ) + + span = end - start + return [ + int(round(start + (i * span) / (num_draft_layers - 1))) + for i in range(num_draft_layers) + ] + + +def compute_dflash_accept_len_and_bonus( + *, + candidates: torch.Tensor, + target_predict: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute DFlash accept lengths and bonus tokens (greedy verify rule). + + Args: + candidates: Token ids proposed by the DFlash draft, including the current token. + Shape: [bs, block_size]. candidates[:, 0] is the current token. + target_predict: Token ids predicted by the target model for each position in the block. + Shape: [bs, block_size]. target_predict[:, t] corresponds to argmax at position t. + + Returns: + accept_len: int32 tensor [bs], number of accepted *draft* tokens (excluding current token and bonus token). + bonus: int64 tensor [bs], the target-predicted token at index accept_len (the "bonus" token to append). + + Notes: + Matches the reference implementation rule: + accept while candidates[:, 1:] == target_predict[:, :-1] consecutively. + """ + if candidates.ndim != 2: + raise ValueError(f"candidates must be 2D, got shape={tuple(candidates.shape)}") + if target_predict.shape != candidates.shape: + raise ValueError( + "target_predict must have the same shape as candidates. " + f"candidates.shape={tuple(candidates.shape)}, target_predict.shape={tuple(target_predict.shape)}" + ) + + bs, block_size = candidates.shape + if bs <= 0: + raise ValueError(f"batch size must be positive, got {bs}.") + if block_size <= 0: + raise ValueError(f"block_size must be positive, got {block_size}.") + + matches = candidates[:, 1:] == target_predict[:, :-1] + accept_len = matches.to(torch.int32).cumprod(dim=1).sum(dim=1) + bonus = target_predict[torch.arange(bs, device=target_predict.device), accept_len] + return accept_len, bonus.to(torch.int64) diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py new file mode 100644 index 000000000000..fb788bd0a54e --- /dev/null +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -0,0 +1,328 @@ +import logging +from typing import List, Optional, Union + +import torch +import torch.nn.functional as F + +from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch +from sglang.srt.managers.scheduler import GenerationBatchResult +from sglang.srt.managers.tp_worker import TpModelWorker +from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode +from sglang.srt.server_args import ServerArgs +from sglang.srt.speculative.dflash_draft_model import load_dflash_draft_model +from sglang.srt.speculative.dflash_info import DFlashDraftInput, DFlashVerifyInput +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm + +logger = logging.getLogger(__name__) + + +class DFlashWorker: + """DFlash speculative decoding worker (spec-v1, tp=1/pp=1).""" + + def __init__( + self, + server_args: ServerArgs, + gpu_id: int, + tp_rank: int, + dp_rank: Optional[int], + moe_ep_rank: int, + nccl_port: int, + target_worker: TpModelWorker, + ): + self.server_args = server_args + self.gpu_id = gpu_id + self.tp_rank = tp_rank + self.dp_rank = dp_rank + self.moe_ep_rank = moe_ep_rank + self.nccl_port = nccl_port + self.target_worker = target_worker + self.model_runner = target_worker.model_runner + self.tp_rank = tp_rank + self.page_size = server_args.page_size + self.device = target_worker.device + + self._mask_token_id = self._resolve_mask_token_id() + self._warned_forced_greedy = False + self._logged_first_verify = False + + # Load the DFlash draft model (weights are separate from the target model). + # This is kept as a standalone module (not a TpModelWorker) since its forward + # is non-causal and differs from standard decoder-only models. + draft_device = torch.device(target_worker.device) + draft_dtype = target_worker.model_runner.dtype + self.draft_model, self.draft_config = load_dflash_draft_model( + server_args.speculative_draft_model_path, + device=draft_device, + dtype=draft_dtype, + ) + self.block_size = int(getattr(self.draft_config, "block_size", 16)) + if self.tp_rank == 0: + logger.info( + "Loaded DFLASH draft model. path=%s, dtype=%s, device=%s, block_size=%s, num_hidden_layers=%s, mask_token_id=%s", + server_args.speculative_draft_model_path, + draft_dtype, + draft_device, + self.block_size, + getattr(self.draft_config, "num_hidden_layers", None), + self._mask_token_id, + ) + + def __getattr__(self, name): + # Delegate anything not implemented yet to the target worker. + return getattr(self.target_worker, name) + + def clear_cache_pool(self): + # No draft-side pools to clear in the stub implementation. + return + + def _resolve_mask_token_id(self) -> int: + tokenizer = getattr(self.target_worker, "tokenizer", None) + if tokenizer is None: + raise RuntimeError("DFLASH requires tokenizer initialization (skip_tokenizer_init is not supported).") + + vocab_size = int(self.target_worker.model_runner.model_config.vocab_size) + mask_token_id = getattr(tokenizer, "mask_token_id", None) + if mask_token_id is None: + # `convert_tokens_to_ids` can return `None` (or an unk id) depending on tokenizer. + # Prefer checking the explicit vocab mapping first. + vocab = tokenizer.get_vocab() + mask_token_id = vocab.get("<|MASK|>", None) + + if mask_token_id is None: + # Mirror the reference DFlash HF demo by adding the mask token to the tokenizer. + # This is safe only when the resulting id stays within the target model vocab size. + added = tokenizer.add_special_tokens({"mask_token": "<|MASK|>"}) + mask_token_id = getattr(tokenizer, "mask_token_id", None) + if mask_token_id is None: + mask_token_id = tokenizer.convert_tokens_to_ids("<|MASK|>") + + if added and self.tp_rank == 0: + logger.info( + "Added DFLASH mask token to tokenizer. token=%s, mask_token_id=%s, tokenizer_len=%s, model_vocab_size=%s", + "<|MASK|>", + mask_token_id, + len(tokenizer), + vocab_size, + ) + + if mask_token_id is None or int(mask_token_id) < 0: + raise ValueError("DFLASH requires a `<|MASK|>` token id, but it could not be resolved.") + + if mask_token_id >= vocab_size: + raise ValueError( + "DFLASH mask_token_id is outside the target vocab size. " + f"mask_token_id={mask_token_id}, vocab_size={vocab_size}. " + "This likely means `<|MASK|>` requires vocab expansion beyond the model's embedding size. " + "SGLang does not support resizing target embeddings for DFLASH yet." + ) + + return int(mask_token_id) + + def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: DFlashDraftInput): + if batch.forward_mode.is_extend() or batch.forward_mode.is_idle(): + return + + if batch.has_grammar: + raise ValueError("DFLASH does not support grammar-constrained decoding yet.") + if batch.sampling_info is not None and not batch.sampling_info.is_all_greedy: + if not self._warned_forced_greedy and self.tp_rank == 0: + logger.warning( + "DFLASH currently supports greedy verification only; " + "ignoring non-greedy sampling params (e.g. temperature/top_p/top_k) and using argmax." + ) + self._warned_forced_greedy = True + + bs = batch.batch_size() + device = self.model_runner.device + + if draft_input.target_hidden is None: + raise RuntimeError("DFLASH draft state missing target_hidden context features.") + if len(draft_input.ctx_lens_cpu) != bs: + raise RuntimeError( + f"DFLASH ctx_lens_cpu length mismatch: got {len(draft_input.ctx_lens_cpu)} for bs={bs}." + ) + + embed_weight, head_weight = self.target_worker.model_runner.model.get_embed_and_head() + + # Slice ragged target_hidden on CPU for simplicity. + offsets: List[int] = [0] + for ln in draft_input.ctx_lens_cpu: + offsets.append(offsets[-1] + int(ln)) + + candidates: List[torch.Tensor] = [] + for i, (req, ctx_len) in enumerate(zip(batch.reqs, draft_input.ctx_lens_cpu, strict=True)): + start_pos = int(batch.seq_lens_cpu[i].item()) + cache = draft_input.draft_caches[i] + cache_len = int(cache.get_seq_length()) + + if cache_len + int(ctx_len) != start_pos: + raise RuntimeError( + "DFLASH draft cache length mismatch. " + f"{cache_len=} + {ctx_len=} != {start_pos=}. " + "This can happen if prefix caching is enabled; start with `--disable-radix-cache` for now." + ) + + target_hidden = draft_input.target_hidden[offsets[i] : offsets[i + 1]] + target_hidden = target_hidden.unsqueeze(0) # [1, ctx_len, feat] + + block_ids = torch.full( + (1, self.block_size), + self._mask_token_id, + dtype=torch.long, + device=device, + ) + block_ids[0, 0] = draft_input.verified_id[i].to(torch.long) + + noise_embedding = F.embedding(block_ids, embed_weight) + position_ids = torch.arange( + cache_len, + start_pos + self.block_size, + dtype=torch.long, + device=device, + ).unsqueeze(0) + + with torch.inference_mode(): + hidden = self.draft_model( + noise_embedding=noise_embedding, + target_hidden=target_hidden, + position_ids=position_ids, + past_key_values=cache, + use_cache=True, + ) + cache.crop(start_pos) + + draft_hidden = hidden[:, -self.block_size + 1 :, :] + draft_logits = F.linear(draft_hidden, head_weight) + draft_tokens = torch.argmax(draft_logits, dim=-1).to(torch.long) + + candidate = torch.cat( + [block_ids[0, 0].view(1), draft_tokens.view(-1)], + dim=0, + ) + candidates.append(candidate) + + draft_tokens = torch.stack(candidates, dim=0) # [bs, block_size] + positions = ( + batch.seq_lens.to(torch.long).unsqueeze(1) + + torch.arange(self.block_size, device=device, dtype=torch.long)[None, :] + ).flatten() + + verify_input = DFlashVerifyInput( + draft_token=draft_tokens.flatten(), + positions=positions, + draft_token_num=self.block_size, + ) + verify_input.prepare_for_verify(batch, self.page_size) + + batch.forward_mode = ForwardMode.TARGET_VERIFY if not batch.forward_mode.is_idle() else ForwardMode.IDLE + batch.spec_info = verify_input + batch.return_hidden_states = False + + def forward_batch_generation( + self, + batch: Union[ScheduleBatch, ModelWorkerBatch], + **kwargs, + ) -> GenerationBatchResult: + if getattr(batch, "return_logprob", False): + raise ValueError("DFLASH speculative decoding does not support return_logprob yet.") + + if isinstance(batch, ModelWorkerBatch): + # Should not happen for spec-v1 (non-overlap) scheduling, but keep a sane fallback. + return self.target_worker.forward_batch_generation(batch, **kwargs) + + if batch.forward_mode.is_extend() or batch.is_extend_in_batch: + if any(len(req.prefix_indices) > 0 for req in batch.reqs): + raise ValueError( + "DFLASH currently does not support radix/prefix cache hits (prefix_indices != 0). " + "Start with `--disable-radix-cache` for now." + ) + + model_worker_batch = batch.get_model_worker_batch() + model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL + + batch_result = self.target_worker.forward_batch_generation( + model_worker_batch, **kwargs + ) + logits_output, next_token_ids = ( + batch_result.logits_output, + batch_result.next_token_ids, + ) + if logits_output.hidden_states is None: + raise RuntimeError( + "DFLASH requires target aux hidden capture for prefill, but got None. " + "Make sure the target model has DFlash layers-to-capture configured." + ) + + draft_caches = [self.draft_model.make_cache() for _ in batch.reqs] + ctx_lens_cpu = model_worker_batch.seq_lens_cpu.tolist() + + batch.spec_info = DFlashDraftInput( + verified_id=next_token_ids.to(torch.int64), + target_hidden=logits_output.hidden_states, + ctx_lens_cpu=ctx_lens_cpu, + draft_caches=draft_caches, + ) + + return GenerationBatchResult( + logits_output=logits_output, + next_token_ids=next_token_ids, + num_accepted_tokens=0, + can_run_cuda_graph=batch_result.can_run_cuda_graph, + ) + + # Decode / target-verify stage. + draft_input = batch.spec_info + if not isinstance(draft_input, DFlashDraftInput): + raise RuntimeError( + "DFLASH decode requires DFlashDraftInput state on the running batch. " + "This usually means the request did not complete the prefill stage." + ) + + self._prepare_for_speculative_decoding(batch, draft_input) + + model_worker_batch = batch.get_model_worker_batch() + assert model_worker_batch.forward_mode.is_target_verify() + verify_input = model_worker_batch.spec_info + assert isinstance(verify_input, DFlashVerifyInput) + + batch_result = self.target_worker.forward_batch_generation( + model_worker_batch, is_verify=True, **kwargs + ) + logits_output, can_run_cuda_graph = ( + batch_result.logits_output, + batch_result.can_run_cuda_graph, + ) + + ( + new_verified_id, + commit_lens, + next_target_hidden, + accept_length_per_req_cpu, + ) = verify_input.verify( + batch=batch, + logits_output=logits_output, + page_size=self.page_size, + ) + + # Update draft state for the next iteration. + draft_input.verified_id = new_verified_id + draft_input.target_hidden = next_target_hidden + draft_input.ctx_lens_cpu = commit_lens.cpu().tolist() + batch.spec_info = draft_input + batch.forward_mode = ForwardMode.DECODE + + num_accepted_tokens = sum(accept_length_per_req_cpu) + if not self._logged_first_verify and self.tp_rank == 0: + logger.info( + "DFLASH verify completed. accept_length_per_req=%s", + accept_length_per_req_cpu, + ) + self._logged_first_verify = True + + return GenerationBatchResult( + logits_output=logits_output, + next_token_ids=new_verified_id, + num_accepted_tokens=num_accepted_tokens, + accept_length_per_req_cpu=accept_length_per_req_cpu, + can_run_cuda_graph=can_run_cuda_graph, + ) diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py index 855e02d9d2d4..d448bab9c66e 100644 --- a/python/sglang/srt/speculative/spec_info.py +++ b/python/sglang/srt/speculative/spec_info.py @@ -14,6 +14,7 @@ class SpeculativeAlgorithm(Enum): """Enumeration of speculative decoding algorithms.""" + DFLASH = auto() EAGLE = auto() EAGLE3 = auto() STANDALONE = auto() @@ -39,6 +40,9 @@ def is_eagle(self) -> bool: def is_eagle3(self) -> bool: return self == SpeculativeAlgorithm.EAGLE3 + def is_dflash(self) -> bool: + return self == SpeculativeAlgorithm.DFLASH + def is_standalone(self) -> bool: return self == SpeculativeAlgorithm.STANDALONE @@ -54,6 +58,13 @@ def create_worker( if self.is_none(): return None + if self.is_dflash(): + if enable_overlap: + raise ValueError("DFLASH does not support overlap scheduling (spec v2).") + from sglang.srt.speculative.dflash_worker import DFlashWorker + + return DFlashWorker + if self.is_eagle(): if enable_overlap: from sglang.srt.speculative.eagle_worker_v2 import EAGLEWorkerV2 @@ -92,6 +103,8 @@ class SpecInputType(IntEnum): # If all algorithms can share the same datastrucutre of draft_input and verify_input, consider simplify it EAGLE_DRAFT = auto() EAGLE_VERIFY = auto() + DFLASH_DRAFT = auto() + DFLASH_VERIFY = auto() NGRAM_VERIFY = auto() @@ -102,11 +115,12 @@ def __init__(self, spec_input_type: SpecInputType): def is_draft_input(self) -> bool: # FIXME: remove this function which is only used for assertion # or use another variable name like `draft_input` to substitute `spec_info` - return self.spec_input_type == SpecInputType.EAGLE_DRAFT + return self.spec_input_type in {SpecInputType.EAGLE_DRAFT, SpecInputType.DFLASH_DRAFT} def is_verify_input(self) -> bool: return self.spec_input_type in { SpecInputType.EAGLE_VERIFY, + SpecInputType.DFLASH_VERIFY, SpecInputType.NGRAM_VERIFY, } diff --git a/test/manual/models/test_qwen3_dflash_correctness.py b/test/manual/models/test_qwen3_dflash_correctness.py new file mode 100644 index 000000000000..2fc3e876ecbd --- /dev/null +++ b/test/manual/models/test_qwen3_dflash_correctness.py @@ -0,0 +1,132 @@ +import os +import unittest + +import requests +import torch + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + CustomTestCase, + find_available_port, + is_in_ci, + popen_launch_server, +) + + +def _send_generate(base_url: str, prompt: str, *, max_new_tokens: int) -> dict: + resp = requests.post( + base_url + "/generate", + json={ + "text": prompt, + "sampling_params": { + "temperature": 0.0, + "top_p": 1.0, + "top_k": 1, + "max_new_tokens": max_new_tokens, + }, + }, + timeout=600, + ) + resp.raise_for_status() + return resp.json() + + +class TestQwen3DFlashCorrectness(CustomTestCase): + def test_qwen3_dflash_matches_target_only_greedy(self): + if is_in_ci(): + self.skipTest("Manual test; skipped in CI.") + if not torch.cuda.is_available(): + self.skipTest("CUDA is required for this manual DFlash integration test.") + + target_model = os.getenv("SGLANG_DFLASH_TARGET_MODEL", "Qwen/Qwen3-8B") + draft_model_path = os.getenv( + "SGLANG_DFLASH_DRAFT_MODEL_PATH", "/tmp/Qwen3-8B-DFlash-bf16" + ) + if not os.path.isdir(draft_model_path): + self.skipTest( + f"Draft model folder not found: {draft_model_path}. " + "Set SGLANG_DFLASH_DRAFT_MODEL_PATH to run this test." + ) + + max_new_tokens = int(os.getenv("SGLANG_DFLASH_MAX_NEW_TOKENS", "128")) + prompt = os.getenv( + "SGLANG_DFLASH_PROMPT", + "How many positive whole-number divisors does 196 have?", + ) + + baseline_port = find_available_port(20000) + dflash_port = find_available_port(baseline_port + 1) + baseline_url = f"http://127.0.0.1:{baseline_port}" + dflash_url = f"http://127.0.0.1:{dflash_port}" + + # 1) Target-only baseline. + baseline_proc = popen_launch_server( + target_model, + baseline_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--disable-radix-cache", + ], + ) + try: + baseline = _send_generate( + baseline_url, prompt, max_new_tokens=max_new_tokens + ) + finally: + kill_process_tree(baseline_proc.pid) + + # 2) DFLASH speculative decoding. + dflash_proc = popen_launch_server( + target_model, + dflash_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--disable-radix-cache", + "--speculative-algorithm", + "DFLASH", + "--speculative-draft-model-path", + draft_model_path, + ], + ) + try: + dflash = _send_generate(dflash_url, prompt, max_new_tokens=max_new_tokens) + finally: + kill_process_tree(dflash_proc.pid) + + self.assertEqual( + baseline["output_ids"], + dflash["output_ids"], + f"Token IDs mismatch.\nbaseline={baseline['output_ids']}\ndflash={dflash['output_ids']}", + ) + self.assertEqual( + baseline["text"], + dflash["text"], + "Decoded text mismatch for greedy decoding.", + ) + + meta = dflash.get("meta_info", {}) + self.assertIn("spec_verify_ct", meta, f"Missing spec metrics: {meta.keys()}") + self.assertGreater(meta["spec_verify_ct"], 0, "DFLASH did not run verify steps.") + self.assertIn("spec_accept_length", meta, f"Missing spec_accept_length: {meta.keys()}") + self.assertGreaterEqual( + float(meta["spec_accept_length"]), + 1.0, + "Spec accept length should be >= 1.0 (bonus token).", + ) + print( + "DFLASH metrics:", + { + "spec_verify_ct": meta.get("spec_verify_ct"), + "spec_accept_length": meta.get("spec_accept_length"), + "spec_accept_rate": meta.get("spec_accept_rate"), + "spec_accept_token_num": meta.get("spec_accept_token_num"), + "spec_draft_token_num": meta.get("spec_draft_token_num"), + "completion_tokens": meta.get("completion_tokens"), + }, + flush=True, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_dflash_acceptance_unit.py b/test/srt/test_dflash_acceptance_unit.py new file mode 100644 index 000000000000..021df085575f --- /dev/null +++ b/test/srt/test_dflash_acceptance_unit.py @@ -0,0 +1,55 @@ +import unittest + +import torch + +from sglang.srt.speculative.dflash_utils import compute_dflash_accept_len_and_bonus + + +class TestDFlashAcceptanceUnit(unittest.TestCase): + def test_accept_len_and_bonus_basic(self): + candidates = torch.tensor( + [ + [10, 11, 12, 13], + [20, 21, 22, 23], + ], + dtype=torch.long, + ) + target_predict = torch.tensor( + [ + [11, 12, 55, 0], # accept 11,12 then bonus=55 + [99, 21, 22, 0], # accept none then bonus=99 + ], + dtype=torch.long, + ) + + accept_len, bonus = compute_dflash_accept_len_and_bonus( + candidates=candidates, + target_predict=target_predict, + ) + self.assertEqual(accept_len.tolist(), [2, 0]) + self.assertEqual(bonus.tolist(), [55, 99]) + + def test_accept_len_all_accepted(self): + candidates = torch.tensor([[10, 11, 12, 13]], dtype=torch.long) + target_predict = torch.tensor([[11, 12, 13, 77]], dtype=torch.long) + + accept_len, bonus = compute_dflash_accept_len_and_bonus( + candidates=candidates, + target_predict=target_predict, + ) + self.assertEqual(accept_len.tolist(), [3]) + self.assertEqual(bonus.tolist(), [77]) + + def test_shape_mismatch_raises(self): + candidates = torch.zeros((2, 4), dtype=torch.long) + target_predict = torch.zeros((2, 5), dtype=torch.long) + with self.assertRaises(ValueError): + compute_dflash_accept_len_and_bonus( + candidates=candidates, + target_predict=target_predict, + ) + + +if __name__ == "__main__": + unittest.main() + From 289e74855352b900307b0712b29344ba0e22be3f Mon Sep 17 00:00:00 2001 From: David Wang Date: Wed, 7 Jan 2026 01:15:18 +0000 Subject: [PATCH 02/73] fix verify mismatch --- .../layers/attention/trtllm_mha_backend.py | 58 +++++++------------ python/sglang/srt/speculative/dflash_info.py | 23 +------- .../models/test_qwen3_dflash_correctness.py | 5 ++ 3 files changed, 30 insertions(+), 56 deletions(-) diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py index aa418b7af669..7ef1d116a97d 100644 --- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py @@ -698,42 +698,28 @@ def forward_extend( bmm1_scale = q_scale * k_scale * layer.scaling bmm2_scale = 1.0 - if forward_batch.forward_mode.is_target_verify(): - o = flashinfer.decode.trtllm_batch_decode_with_kv_cache( - query=q, - kv_cache=kv_cache, - workspace_buffer=self.workspace_buffer, - block_tables=self.forward_metadata.page_table, - seq_lens=self.forward_metadata.cache_seqlens_int32, - max_seq_len=self.max_context_len, - bmm1_scale=bmm1_scale, - bmm2_scale=bmm2_scale, - window_left=layer.sliding_window_size, - # TODO: add attention_sink operation or nvfp4 scale factor if needed - sinks=attention_sink, - out_dtype=self.q_data_type, # model_runner.dtype - q_len_per_req=self.forward_metadata.max_seq_len_q, - ) - else: - - o = flashinfer.prefill.trtllm_batch_context_with_kv_cache( - query=q, - kv_cache=kv_cache, - workspace_buffer=self.workspace_buffer, - block_tables=self.forward_metadata.page_table, - seq_lens=self.forward_metadata.cache_seqlens_int32, - max_q_len=self.forward_metadata.max_seq_len_q, - max_kv_len=self.max_context_len, - bmm1_scale=bmm1_scale, - bmm2_scale=bmm2_scale, - batch_size=forward_batch.batch_size, - cum_seq_lens_q=self.forward_metadata.cu_seqlens_q, - cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k, - window_left=layer.sliding_window_size, - # TODO: add attention_sink operation or nvfp4 scale factor if needed - sinks=attention_sink, - out_dtype=self.q_data_type, # model_runner.dtype - ) + # Target-verify runs a multi-token causal "block prefill" over draft tokens. + # Use the context (prefill) kernel so each query token only attends to the + # prefix and earlier tokens in the verify block. The decode kernel treats + # each query token independently and can leak future tokens when q_len > 1. + o = flashinfer.prefill.trtllm_batch_context_with_kv_cache( + query=q, + kv_cache=kv_cache, + workspace_buffer=self.workspace_buffer, + block_tables=self.forward_metadata.page_table, + seq_lens=self.forward_metadata.cache_seqlens_int32, + max_q_len=self.forward_metadata.max_seq_len_q, + max_kv_len=self.max_context_len, + bmm1_scale=bmm1_scale, + bmm2_scale=bmm2_scale, + batch_size=forward_batch.batch_size, + cum_seq_lens_q=self.forward_metadata.cu_seqlens_q, + cum_seq_lens_kv=self.forward_metadata.cu_seqlens_k, + window_left=layer.sliding_window_size, + # TODO: add attention_sink operation or nvfp4 scale factor if needed + sinks=attention_sink, + out_dtype=self.q_data_type, # model_runner.dtype + ) return o.view(-1, layer.tp_q_head_num * layer.head_dim) diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py index e9a0c93924e0..72835c7fa58c 100644 --- a/python/sglang/srt/speculative/dflash_info.py +++ b/python/sglang/srt/speculative/dflash_info.py @@ -185,26 +185,9 @@ def generate_attn_arg_prefill( kv_indices, req_to_token.size(1), ) - - # Causal custom mask for the verify block: - # - all query tokens can attend to all prefix tokens - # - within the block, use a lower-triangular mask - prefix_mask = torch.full( - (paged_kernel_lens_sum * self.draft_token_num,), - True, - dtype=torch.bool, - device=device, - ) - tri = torch.tril( - torch.ones( - (self.draft_token_num, self.draft_token_num), - dtype=torch.bool, - device=device, - ) - ).flatten() - custom_mask = torch.cat([prefix_mask, tri.repeat(bs)], dim=0) - - return kv_indices, cum_kv_seq_len, qo_indptr, custom_mask + # DFlash verify is a standard causal block prefill. Let the attention backend + # apply its built-in causal masking; no custom mask is required here. + return kv_indices, cum_kv_seq_len, qo_indptr, None def verify( self, diff --git a/test/manual/models/test_qwen3_dflash_correctness.py b/test/manual/models/test_qwen3_dflash_correctness.py index 2fc3e876ecbd..8bef7afb63a6 100644 --- a/test/manual/models/test_qwen3_dflash_correctness.py +++ b/test/manual/models/test_qwen3_dflash_correctness.py @@ -54,6 +54,7 @@ def test_qwen3_dflash_matches_target_only_greedy(self): "SGLANG_DFLASH_PROMPT", "How many positive whole-number divisors does 196 have?", ) + attention_backend = os.getenv("SGLANG_DFLASH_ATTENTION_BACKEND", "flashinfer") baseline_port = find_available_port(20000) dflash_port = find_available_port(baseline_port + 1) @@ -67,6 +68,8 @@ def test_qwen3_dflash_matches_target_only_greedy(self): timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ "--disable-radix-cache", + "--attention-backend", + attention_backend, ], ) try: @@ -83,6 +86,8 @@ def test_qwen3_dflash_matches_target_only_greedy(self): timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ "--disable-radix-cache", + "--attention-backend", + attention_backend, "--speculative-algorithm", "DFLASH", "--speculative-draft-model-path", From f1efc03c105c380fe528088aea6d6d0ee7611f74 Mon Sep 17 00:00:00 2001 From: David Wang Date: Wed, 7 Jan 2026 01:59:43 +0000 Subject: [PATCH 03/73] add gsm8k bench --- .../models/test_qwen3_dflash_gsm8k_bench.py | 276 ++++++++++++++++++ 1 file changed, 276 insertions(+) create mode 100644 test/manual/models/test_qwen3_dflash_gsm8k_bench.py diff --git a/test/manual/models/test_qwen3_dflash_gsm8k_bench.py b/test/manual/models/test_qwen3_dflash_gsm8k_bench.py new file mode 100644 index 000000000000..3ceb99c781af --- /dev/null +++ b/test/manual/models/test_qwen3_dflash_gsm8k_bench.py @@ -0,0 +1,276 @@ +import ast +import json +import os +import re +import statistics +import time +import unittest +from concurrent.futures import ThreadPoolExecutor, as_completed + +import requests +import torch + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + CustomTestCase, + find_available_port, + is_in_ci, + popen_launch_server, +) +from sglang.utils import download_and_cache_file, read_jsonl + +INVALID = -9999999 + + +def _get_one_example(lines, i: int, include_answer: bool) -> str: + ret = "Question: " + lines[i]["question"] + "\nAnswer:" + if include_answer: + ret += " " + lines[i]["answer"] + return ret + + +def _get_few_shot_examples(lines, k: int) -> str: + ret = "" + for i in range(k): + ret += _get_one_example(lines, i, True) + "\n\n" + return ret + + +def _get_answer_value(answer_str: str) -> int: + answer_str = answer_str.replace(",", "") + numbers = re.findall(r"\d+", answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +def _send_generate(base_url: str, prompt: str, *, max_new_tokens: int) -> dict: + resp = requests.post( + base_url + "/generate", + json={ + "text": prompt, + "sampling_params": { + "temperature": 0.0, + "top_p": 1.0, + "top_k": 1, + "max_new_tokens": max_new_tokens, + # Avoid extra decoding after an answer. + "stop": ["Question", "Assistant:", "<|separator|>"], + }, + }, + timeout=600, + ) + resp.raise_for_status() + return resp.json() + + +def _run_generate_batch( + base_url: str, + prompts: list[str], + *, + max_new_tokens: int, + parallel: int, +) -> tuple[float, list[dict]]: + start = time.perf_counter() + outputs: list[dict] = [None for _ in range(len(prompts))] # type: ignore[list-item] + with ThreadPoolExecutor(max_workers=parallel) as pool: + futures = { + pool.submit(_send_generate, base_url, prompt, max_new_tokens=max_new_tokens): i + for i, prompt in enumerate(prompts) + } + for fut in as_completed(futures): + idx = futures[fut] + outputs[idx] = fut.result() + latency = time.perf_counter() - start + return latency, outputs + + +def _summarize(values: list[float]) -> dict: + if not values: + return {"mean": None, "p50": None, "p90": None} + values_sorted = sorted(values) + p50 = values_sorted[int(0.50 * (len(values_sorted) - 1))] + p90 = values_sorted[int(0.90 * (len(values_sorted) - 1))] + return { + "mean": float(statistics.mean(values_sorted)), + "p50": float(p50), + "p90": float(p90), + } + + +class TestQwen3DFlashGSM8KBench(CustomTestCase): + def test_qwen3_dflash_gsm8k_speedup_and_acceptance(self): + if is_in_ci(): + self.skipTest("Manual benchmark; skipped in CI.") + if not torch.cuda.is_available(): + self.skipTest("CUDA is required for this manual DFlash benchmark.") + + target_model = os.getenv("SGLANG_DFLASH_TARGET_MODEL", "Qwen/Qwen3-8B") + draft_model_path = os.getenv( + "SGLANG_DFLASH_DRAFT_MODEL_PATH", "/tmp/Qwen3-8B-DFlash-bf16" + ) + if not os.path.isdir(draft_model_path): + self.skipTest( + f"Draft model folder not found: {draft_model_path}. " + "Set SGLANG_DFLASH_DRAFT_MODEL_PATH to run this benchmark." + ) + + attention_backend = os.getenv("SGLANG_DFLASH_ATTENTION_BACKEND", "flashinfer") + max_new_tokens = int(os.getenv("SGLANG_DFLASH_MAX_NEW_TOKENS", "2048")) + parallel = int(os.getenv("SGLANG_DFLASH_PARALLEL", "1")) + num_questions = int(os.getenv("SGLANG_DFLASH_NUM_QUESTIONS", "100")) + num_shots = int(os.getenv("SGLANG_DFLASH_NUM_SHOTS", "1")) + disable_radix_cache = os.getenv("SGLANG_DFLASH_DISABLE_RADIX_CACHE", "1") != "0" + + # Read GSM8K data (download if absent). + data_path = os.getenv("SGLANG_DFLASH_GSM8K_PATH", "test.jsonl") + url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + if not os.path.isfile(data_path): + data_path = download_and_cache_file(url) + lines = list(read_jsonl(data_path)) + + few_shot = _get_few_shot_examples(lines, num_shots) + prompts: list[str] = [] + labels: list[int] = [] + for i in range(len(lines[:num_questions])): + prompts.append(few_shot + _get_one_example(lines, i, False)) + labels.append(_get_answer_value(lines[i]["answer"])) + self.assertTrue(all(l != INVALID for l in labels), "Invalid labels in GSM8K data") + + common_server_args = ["--attention-backend", attention_backend] + if disable_radix_cache: + common_server_args.append("--disable-radix-cache") + + baseline_port = find_available_port(20000) + dflash_port = find_available_port(baseline_port + 1) + baseline_url = f"http://127.0.0.1:{baseline_port}" + dflash_url = f"http://127.0.0.1:{dflash_port}" + + # 1) Target-only baseline. + baseline_proc = popen_launch_server( + target_model, + baseline_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=common_server_args, + ) + try: + _send_generate(baseline_url, "Hello", max_new_tokens=8) # warmup + baseline_latency, baseline_outputs = _run_generate_batch( + baseline_url, + prompts, + max_new_tokens=max_new_tokens, + parallel=parallel, + ) + finally: + kill_process_tree(baseline_proc.pid) + try: + baseline_proc.wait(timeout=30) + except Exception: + pass + + # 2) DFLASH speculative decoding. + dflash_proc = popen_launch_server( + target_model, + dflash_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + *common_server_args, + "--speculative-algorithm", + "DFLASH", + "--speculative-draft-model-path", + draft_model_path, + ], + ) + try: + _send_generate(dflash_url, "Hello", max_new_tokens=8) # warmup + dflash_latency, dflash_outputs = _run_generate_batch( + dflash_url, + prompts, + max_new_tokens=max_new_tokens, + parallel=parallel, + ) + finally: + kill_process_tree(dflash_proc.pid) + try: + dflash_proc.wait(timeout=30) + except Exception: + pass + + def _collect_common_metrics(outputs: list[dict]) -> tuple[int, list[int]]: + completion_tokens = [] + preds = [] + total_completion_tokens = 0 + for out in outputs: + meta = out.get("meta_info", {}) + total_completion_tokens += int(meta.get("completion_tokens", 0)) + completion_tokens.append(int(meta.get("completion_tokens", 0))) + preds.append(_get_answer_value(out.get("text", ""))) + return total_completion_tokens, preds + + baseline_total_tokens, baseline_preds = _collect_common_metrics(baseline_outputs) + dflash_total_tokens, dflash_preds = _collect_common_metrics(dflash_outputs) + + baseline_throughput = baseline_total_tokens / max(baseline_latency, 1e-6) + dflash_throughput = dflash_total_tokens / max(dflash_latency, 1e-6) + speedup = dflash_throughput / max(baseline_throughput, 1e-6) + + baseline_acc = sum( + int(p == l) for p, l in zip(baseline_preds, labels, strict=True) + ) / len(labels) + dflash_acc = sum( + int(p == l) for p, l in zip(dflash_preds, labels, strict=True) + ) / len(labels) + + spec_accept_lengths: list[float] = [] + spec_accept_rates: list[float] = [] + spec_verify_cts: list[int] = [] + for out in dflash_outputs: + meta = out.get("meta_info", {}) + if "spec_verify_ct" in meta: + spec_verify_cts.append(int(meta["spec_verify_ct"])) + if "spec_accept_length" in meta: + spec_accept_lengths.append(float(meta["spec_accept_length"])) + if "spec_accept_rate" in meta: + spec_accept_rates.append(float(meta["spec_accept_rate"])) + + # Basic sanity checks that DFLASH actually ran. + self.assertTrue(spec_verify_cts, "Missing spec_verify_ct in DFLASH responses.") + self.assertGreater(sum(spec_verify_cts), 0, "DFLASH did not run verify steps.") + + report = { + "settings": { + "target_model": target_model, + "draft_model_path": draft_model_path, + "attention_backend": attention_backend, + "max_new_tokens": max_new_tokens, + "parallel": parallel, + "num_questions": num_questions, + "num_shots": num_shots, + "disable_radix_cache": disable_radix_cache, + }, + "baseline": { + "latency_s": round(baseline_latency, 3), + "completion_tokens": baseline_total_tokens, + "throughput_tok_s": round(baseline_throughput, 3), + "accuracy": round(baseline_acc, 3), + }, + "dflash": { + "latency_s": round(dflash_latency, 3), + "completion_tokens": dflash_total_tokens, + "throughput_tok_s": round(dflash_throughput, 3), + "accuracy": round(dflash_acc, 3), + "spec_accept_length": _summarize(spec_accept_lengths), + "spec_accept_rate": _summarize(spec_accept_rates), + "spec_verify_ct_mean": float(statistics.mean(spec_verify_cts)), + }, + "speedup": round(speedup, 3), + } + print(json.dumps(report, indent=2), flush=True) + + +if __name__ == "__main__": + unittest.main() From e807216a8b05bd4b33cbc5ea060d4207557fb620 Mon Sep 17 00:00:00 2001 From: David Wang Date: Wed, 7 Jan 2026 03:06:21 +0000 Subject: [PATCH 04/73] support more backends, investigate accuracy --- python/sglang/srt/speculative/dflash_info.py | 25 +++++- .../models/test_qwen3_dflash_gsm8k_bench.py | 79 ++++++++++++++++--- 2 files changed, 91 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py index 72835c7fa58c..70ef09dcb30e 100644 --- a/python/sglang/srt/speculative/dflash_info.py +++ b/python/sglang/srt/speculative/dflash_info.py @@ -101,6 +101,9 @@ class DFlashVerifyInput(SpecInput): draft_token: torch.Tensor positions: torch.Tensor draft_token_num: int + # Custom attention "allow mask" for TARGET_VERIFY in backends that require it (e.g. triton). + # Semantics follow SGLang speculative conventions: True means the (q, k) pair is allowed. + custom_mask: torch.Tensor | None = None capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL def __post_init__(self): @@ -149,6 +152,24 @@ def prepare_for_verify(self, batch: ScheduleBatch, page_size: int): bs, ) + # Build a standard causal attention *allow* mask over [prefix + verify_block] for each request. + # Layout matches other speculative inputs: flatten per request, row-major over + # [q_len=draft_token_num, kv_len=prefix_len + draft_token_num]. + if self.draft_token_num <= 0: + raise ValueError(f"DFLASH draft_token_num must be positive, got {self.draft_token_num}.") + mask_chunks: List[torch.Tensor] = [] + q_len = int(self.draft_token_num) + q_idx = torch.arange(q_len, device=batch.device, dtype=torch.int32).unsqueeze(1) + for prefix_len in batch.seq_lens_cpu.tolist(): + prefix_len_i = int(prefix_len) + kv_len = prefix_len_i + q_len + k_idx = torch.arange(kv_len, device=batch.device, dtype=torch.int32).unsqueeze(0) + # Allow attending to the full prefix and to tokens up to (and including) the + # current query position within the verify block (standard causal masking). + allow = k_idx <= (prefix_len_i + q_idx) + mask_chunks.append(allow.flatten()) + self.custom_mask = torch.cat(mask_chunks, dim=0) if mask_chunks else torch.empty((0,), dtype=torch.bool, device=batch.device) + def generate_attn_arg_prefill( self, req_pool_indices: torch.Tensor, @@ -185,9 +206,7 @@ def generate_attn_arg_prefill( kv_indices, req_to_token.size(1), ) - # DFlash verify is a standard causal block prefill. Let the attention backend - # apply its built-in causal masking; no custom mask is required here. - return kv_indices, cum_kv_seq_len, qo_indptr, None + return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask def verify( self, diff --git a/test/manual/models/test_qwen3_dflash_gsm8k_bench.py b/test/manual/models/test_qwen3_dflash_gsm8k_bench.py index 3ceb99c781af..a46d38419591 100644 --- a/test/manual/models/test_qwen3_dflash_gsm8k_bench.py +++ b/test/manual/models/test_qwen3_dflash_gsm8k_bench.py @@ -1,7 +1,22 @@ +"""Manual GSM8K benchmark for DFLASH (vs target-only baseline). + +Notes / known limitations (as of this initial integration): + - Prompting style matters a lot for acceptance length. The upstream DFlash HF demo/bench + uses a Qwen chat-template prompt; use `SGLANG_DFLASH_PROMPT_STYLE=dflash_chat` and + typically `SGLANG_DFLASH_STOP=` (empty) to get closer acceptance numbers. + - DFLASH may *diverge* from target-only greedy decoding on some prompts. This is because + DFLASH verifies a whole block with `ForwardMode.TARGET_VERIFY` (prefill-style kernels), + while the baseline uses the normal decode path. Some attention backends can produce + different argmax tokens across these modes (numerical differences), which makes direct + "accuracy" comparisons misleading. + Use `SGLANG_DFLASH_ASSERT_MATCH=1` to detect any token-level divergence. +""" + import ast import json import os import re +import shlex import statistics import time import unittest @@ -9,6 +24,7 @@ import requests import torch +from transformers import AutoTokenizer from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( @@ -49,18 +65,21 @@ def _get_answer_value(answer_str: str) -> int: def _send_generate(base_url: str, prompt: str, *, max_new_tokens: int) -> dict: + stop = os.getenv("SGLANG_DFLASH_STOP", "Question,Assistant:,<|separator|>") + stop_list = [s for s in stop.split(",") if s] if stop else [] + sampling_params = { + "temperature": 0.0, + "top_p": 1.0, + "top_k": 1, + "max_new_tokens": max_new_tokens, + } + if stop_list: + sampling_params["stop"] = stop_list resp = requests.post( base_url + "/generate", json={ "text": prompt, - "sampling_params": { - "temperature": 0.0, - "top_p": 1.0, - "top_k": 1, - "max_new_tokens": max_new_tokens, - # Avoid extra decoding after an answer. - "stop": ["Question", "Assistant:", "<|separator|>"], - }, + "sampling_params": sampling_params, }, timeout=600, ) @@ -125,6 +144,8 @@ def test_qwen3_dflash_gsm8k_speedup_and_acceptance(self): num_questions = int(os.getenv("SGLANG_DFLASH_NUM_QUESTIONS", "100")) num_shots = int(os.getenv("SGLANG_DFLASH_NUM_SHOTS", "1")) disable_radix_cache = os.getenv("SGLANG_DFLASH_DISABLE_RADIX_CACHE", "1") != "0" + prompt_style = os.getenv("SGLANG_DFLASH_PROMPT_STYLE", "fewshot_qa") + assert_match = os.getenv("SGLANG_DFLASH_ASSERT_MATCH", "0") != "0" # Read GSM8K data (download if absent). data_path = os.getenv("SGLANG_DFLASH_GSM8K_PATH", "test.jsonl") @@ -133,17 +154,41 @@ def test_qwen3_dflash_gsm8k_speedup_and_acceptance(self): data_path = download_and_cache_file(url) lines = list(read_jsonl(data_path)) - few_shot = _get_few_shot_examples(lines, num_shots) + tokenizer = None + if prompt_style == "dflash_chat": + tokenizer = AutoTokenizer.from_pretrained(target_model) + + few_shot = _get_few_shot_examples(lines, num_shots) if prompt_style == "fewshot_qa" else "" prompts: list[str] = [] labels: list[int] = [] for i in range(len(lines[:num_questions])): - prompts.append(few_shot + _get_one_example(lines, i, False)) + if prompt_style == "fewshot_qa": + prompts.append(few_shot + _get_one_example(lines, i, False)) + elif prompt_style == "dflash_chat": + assert tokenizer is not None + user_content = ( + lines[i]["question"] + + "\nPlease reason step by step, and put your final answer within \\boxed{}." + ) + prompts.append( + tokenizer.apply_chat_template( + [{"role": "user", "content": user_content}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + ) + else: + raise ValueError(f"Unsupported SGLANG_DFLASH_PROMPT_STYLE: {prompt_style}") labels.append(_get_answer_value(lines[i]["answer"])) self.assertTrue(all(l != INVALID for l in labels), "Invalid labels in GSM8K data") common_server_args = ["--attention-backend", attention_backend] if disable_radix_cache: common_server_args.append("--disable-radix-cache") + extra_server_args = os.getenv("SGLANG_DFLASH_EXTRA_SERVER_ARGS", "").strip() + if extra_server_args: + common_server_args.extend(shlex.split(extra_server_args)) baseline_port = find_available_port(20000) dflash_port = find_available_port(baseline_port + 1) @@ -214,10 +259,23 @@ def _collect_common_metrics(outputs: list[dict]) -> tuple[int, list[int]]: baseline_total_tokens, baseline_preds = _collect_common_metrics(baseline_outputs) dflash_total_tokens, dflash_preds = _collect_common_metrics(dflash_outputs) + if assert_match: + for i, (baseline_out, dflash_out) in enumerate( + zip(baseline_outputs, dflash_outputs, strict=True) + ): + if baseline_out.get("output_ids") != dflash_out.get("output_ids"): + raise AssertionError( + "Baseline and DFLASH outputs diverged at index " + f"{i}.\nbaseline={baseline_out.get('output_ids')}\ndflash={dflash_out.get('output_ids')}" + ) + baseline_throughput = baseline_total_tokens / max(baseline_latency, 1e-6) dflash_throughput = dflash_total_tokens / max(dflash_latency, 1e-6) speedup = dflash_throughput / max(baseline_throughput, 1e-6) + # WARNING: Until baseline-vs-DFLASH greedy outputs are guaranteed identical, these + # "accuracy" numbers are not strictly comparable. Prefer asserting matches via + # `SGLANG_DFLASH_ASSERT_MATCH=1` when debugging correctness. baseline_acc = sum( int(p == l) for p, l in zip(baseline_preds, labels, strict=True) ) / len(labels) @@ -250,6 +308,7 @@ def _collect_common_metrics(outputs: list[dict]) -> tuple[int, list[int]]: "parallel": parallel, "num_questions": num_questions, "num_shots": num_shots, + "prompt_style": prompt_style, "disable_radix_cache": disable_radix_cache, }, "baseline": { From 99e140af07943fec8bb4f482bb7b8925411afc3f Mon Sep 17 00:00:00 2001 From: David Wang Date: Thu, 8 Jan 2026 05:53:11 +0000 Subject: [PATCH 05/73] native sglang backend --- python/sglang/srt/managers/scheduler.py | 4 +- .../scheduler_output_processor_mixin.py | 19 +- python/sglang/srt/models/dflash.py | 217 ++++++++++ python/sglang/srt/speculative/dflash_info.py | 27 +- .../sglang/srt/speculative/dflash_worker.py | 380 ++++++++++++++++-- .../models/test_qwen3_dflash_gsm8k_bench.py | 115 ++++++ 6 files changed, 722 insertions(+), 40 deletions(-) create mode 100644 python/sglang/srt/models/dflash.py diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 92d2868977d1..ae5efea66f7e 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -2607,14 +2607,14 @@ def abort_request(self, recv_req: AbortReq): self.send_to_tokenizer.send_output(AbortReq(rid=req.rid), req) # For disaggregation decode mode, the request in the waiting queue has KV cache allocated. if self.disaggregation_mode == DisaggregationMode.DECODE: - release_kv_cache(req, self.tree_cache) + self._release_kv_cache_and_draft(req) # For mamba radix cache if ( req.mamba_pool_idx is not None and self.disaggregation_mode != DisaggregationMode.DECODE ): - release_kv_cache(req, self.tree_cache, is_insert=False) + self._release_kv_cache_and_draft(req, is_insert=False) logger.debug(f"Abort queued request. {req.rid=}") # Delete the requests in the grammar queue diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index e40586c24cc1..d6a8d1589e6b 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -44,6 +44,13 @@ class SchedulerOutputProcessorMixin: We put them into a separate file to make the `scheduler.py` shorter. """ + def _release_kv_cache_and_draft(self: Scheduler, req: Req, *, is_insert: bool = True): + release_kv_cache(req, self.tree_cache, is_insert=is_insert) + draft_worker = getattr(self, "draft_worker", None) + hook = getattr(draft_worker, "on_req_finished", None) if draft_worker else None + if hook is not None: + hook(req) + def process_batch_result_prebuilt(self: Scheduler, batch: ScheduleBatch): assert self.disaggregation_mode == DisaggregationMode.DECODE for req in batch.reqs: @@ -57,7 +64,7 @@ def process_batch_result_prebuilt(self: Scheduler, batch: ScheduleBatch): req.rid, thread_finish_flag=True, ) - release_kv_cache(req, self.tree_cache) + self._release_kv_cache_and_draft(req) # Note: Logprobs should be handled on the prefill engine. trace_slice_batch(RequestStage.DECODE_FAKE_OUTPUT, batch.reqs) @@ -137,7 +144,7 @@ def process_batch_result_prefill( if req.finished(): self.maybe_collect_routed_experts(req) - release_kv_cache(req, self.tree_cache) + self._release_kv_cache_and_draft(req) req.time_stats.completion_time = time.perf_counter() elif not batch.decoding_reqs or req not in batch.decoding_reqs: # This updates radix so others can match @@ -271,7 +278,7 @@ def process_batch_result_prefill( req.check_finished() if req.finished(): - release_kv_cache(req, self.tree_cache) + self._release_kv_cache_and_draft(req) else: self.tree_cache.cache_unfinished_req(req) else: @@ -345,7 +352,7 @@ def process_batch_result_dllm( req.check_finished() if req.finished(): - release_kv_cache(req, self.tree_cache) + self._release_kv_cache_and_draft(req) req.time_stats.completion_time = time.perf_counter() break @@ -415,9 +422,9 @@ def process_batch_result_decode( if self.server_args.disaggregation_decode_enable_offload_kvcache: # Asynchronously offload KV cache; release_kv_cache will be called after Device->Host transfer completes if not self.decode_offload_manager.offload_kv_cache(req): - release_kv_cache(req, self.tree_cache) + self._release_kv_cache_and_draft(req) else: - release_kv_cache(req, self.tree_cache) + self._release_kv_cache_and_draft(req) req.time_stats.completion_time = time.perf_counter() diff --git a/python/sglang/srt/models/dflash.py b/python/sglang/srt/models/dflash.py new file mode 100644 index 000000000000..68286140e613 --- /dev/null +++ b/python/sglang/srt/models/dflash.py @@ -0,0 +1,217 @@ +# Adapted from the DFlash reference implementation (HF) but implemented with +# SGLang primitives (RadixAttention + SGLang KV cache). This model intentionally +# does not include token embeddings or an LM head; DFlash uses the target model's +# embedding/lm_head. + +from __future__ import annotations + +import logging +from typing import Iterable, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.radix_attention import AttentionType, RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.utils import apply_qk_norm + +logger = logging.getLogger(__name__) + + +class DFlashAttention(nn.Module): + def __init__(self, config, layer_id: int) -> None: + super().__init__() + hidden_size = int(config.hidden_size) + num_heads = int(config.num_attention_heads) + num_kv_heads = int(getattr(config, "num_key_value_heads", num_heads)) + head_dim = int(getattr(config, "head_dim", hidden_size // num_heads)) + + self.hidden_size = hidden_size + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.q_size = num_heads * head_dim + self.kv_size = num_kv_heads * head_dim + + attention_bias = bool(getattr(config, "attention_bias", False)) + rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-6)) + + self.q_proj = nn.Linear(hidden_size, self.q_size, bias=attention_bias) + self.k_proj = nn.Linear(hidden_size, self.kv_size, bias=attention_bias) + self.v_proj = nn.Linear(hidden_size, self.kv_size, bias=attention_bias) + self.o_proj = nn.Linear(self.q_size, hidden_size, bias=attention_bias) + + # Per-head Q/K RMSNorm, matching HF Qwen3. + self.q_norm = RMSNorm(head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(head_dim, eps=rms_norm_eps) + + rope_theta = float(getattr(config, "rope_theta", 1000000)) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = int(getattr(config, "max_position_embeddings", 32768)) + self.rotary_emb = get_rope( + head_dim, + rotary_dim=head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + + self.scaling = head_dim**-0.5 + # DFlash uses non-causal attention over the draft block. + self.attn = RadixAttention( + num_heads=num_heads, + head_dim=head_dim, + scaling=self.scaling, + num_kv_heads=num_kv_heads, + layer_id=layer_id, + attn_type=AttentionType.ENCODER_ONLY, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + q, k = apply_qk_norm( + q=q, + k=k, + q_norm=self.q_norm, + k_norm=self.k_norm, + head_dim=self.head_dim, + ) + q, k = self.rotary_emb(positions, q, k) + + attn_output = self.attn(q, k, v, forward_batch) + return self.o_proj(attn_output) + + +class DFlashMLP(nn.Module): + def __init__(self, config) -> None: + super().__init__() + hidden_size = int(config.hidden_size) + intermediate_size = int(getattr(config, "intermediate_size", 0)) + if intermediate_size <= 0: + raise ValueError(f"Invalid intermediate_size={intermediate_size} for DFlash MLP.") + + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class DFlashDecoderLayer(nn.Module): + def __init__(self, config, layer_id: int) -> None: + super().__init__() + hidden_size = int(config.hidden_size) + rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-6)) + + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.self_attn = DFlashAttention(config=config, layer_id=layer_id) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.mlp = DFlashMLP(config=config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class DFlashDraftModel(nn.Module): + """SGLang-native DFlash draft model (no embedding / lm_head weights). + + The checkpoint provides: + - transformer weights for `layers.*` + - `fc.weight`, `hidden_norm.weight` for projecting target context features + - `norm.weight` for final normalization + """ + + def __init__(self, config, quant_config=None, prefix: str = "") -> None: + super().__init__() + self.config = config + + hidden_size = int(config.hidden_size) + num_layers = int(config.num_hidden_layers) + rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-6)) + + self.layers = nn.ModuleList( + [DFlashDecoderLayer(config=config, layer_id=i) for i in range(num_layers)] + ) + self.norm = RMSNorm(hidden_size, eps=rms_norm_eps) + + # Project per-token target context features: + # concat(num_layers * hidden_size) -> hidden_size + self.fc = nn.Linear(num_layers * hidden_size, hidden_size, bias=False) + self.hidden_norm = RMSNorm(hidden_size, eps=rms_norm_eps) + + self.block_size = int(getattr(config, "block_size", 16)) + + def project_target_hidden(self, target_hidden: torch.Tensor) -> torch.Tensor: + """Project concatenated target-layer hidden states into draft hidden_size.""" + return self.hidden_norm(self.fc(target_hidden)) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: Optional[torch.Tensor] = None, + get_embedding: bool = False, + pp_proxy_tensors=None, + ) -> torch.Tensor: + if input_embeds is None: + raise ValueError( + "DFlashDraftModel requires `input_embeds` (use the target embedding)." + ) + hidden_states = input_embeds + + for layer in self.layers: + hidden_states = layer(positions, hidden_states, forward_batch) + + if hidden_states.numel() == 0: + return hidden_states + return self.norm(hidden_states) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if name.endswith(".bias") and name not in params_dict: + # Some quantized checkpoints may have extra biases. + continue + param = params_dict.get(name) + if param is None: + # Ignore unexpected weights (e.g., HF rotary caches). + continue + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +EntryClass = DFlashDraftModel + diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py index 70ef09dcb30e..e5a7bf012262 100644 --- a/python/sglang/srt/speculative/dflash_info.py +++ b/python/sglang/srt/speculative/dflash_info.py @@ -24,10 +24,11 @@ class DFlashDraftInput(SpecInput): This object is stored on `ScheduleBatch.spec_info` between decode iterations. It is NOT sent to model attention backends; the DFlash worker uses it to run - the draft model and to carry draft-side caches. + the draft model and to track draft-side cache progress. Invariant (per request): - - `draft_cache.get_seq_length() + ctx_len == batch.seq_lens[i]` + - Native path: `draft_seq_len + ctx_len == batch.seq_lens[i]` + - HF path: `draft_cache.get_seq_length() + ctx_len == batch.seq_lens[i]` where `ctx_len` is the number of target context-feature tokens carried in `target_hidden` for that request. """ @@ -42,8 +43,12 @@ class DFlashDraftInput(SpecInput): # Context lengths on CPU, one per request. Used to slice `target_hidden`. ctx_lens_cpu: List[int] - # Per-request transformers DynamicCache objects for the draft model. - draft_caches: List[object] + # Native implementation: how many tokens are already materialized in the draft KV cache. + # The next draft step appends `ctx_lens_cpu[i]` tokens starting at `draft_seq_lens_cpu[i]`. + draft_seq_lens_cpu: List[int] | None = None + + # HF-style baseline implementation: per-request transformers DynamicCache objects. + draft_caches: List[object] | None = None def __post_init__(self): super().__init__(spec_input_type=SpecInputType.DFLASH_DRAFT) @@ -59,8 +64,11 @@ def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True old_target_hidden = self.target_hidden self.verified_id = self.verified_id[new_indices] - self.draft_caches = [self.draft_caches[i] for i in keep_indices] self.ctx_lens_cpu = [old_ctx_lens_cpu[i] for i in keep_indices] + if self.draft_seq_lens_cpu is not None: + self.draft_seq_lens_cpu = [self.draft_seq_lens_cpu[i] for i in keep_indices] + if self.draft_caches is not None: + self.draft_caches = [self.draft_caches[i] for i in keep_indices] if old_target_hidden is None or old_target_hidden.numel() == 0: self.target_hidden = old_target_hidden @@ -81,8 +89,15 @@ def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True def merge_batch(self, spec_info: "DFlashDraftInput"): self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], dim=0) - self.draft_caches.extend(spec_info.draft_caches) self.ctx_lens_cpu.extend(spec_info.ctx_lens_cpu) + if self.draft_seq_lens_cpu is not None or spec_info.draft_seq_lens_cpu is not None: + if self.draft_seq_lens_cpu is None or spec_info.draft_seq_lens_cpu is None: + raise ValueError("Cannot merge DFLASH draft batches with mismatched draft_seq_lens_cpu presence.") + self.draft_seq_lens_cpu.extend(spec_info.draft_seq_lens_cpu) + if self.draft_caches is not None or spec_info.draft_caches is not None: + if self.draft_caches is None or spec_info.draft_caches is None: + raise ValueError("Cannot merge DFLASH draft batches with mismatched draft_caches presence.") + self.draft_caches.extend(spec_info.draft_caches) if self.target_hidden is None or self.target_hidden.numel() == 0: self.target_hidden = spec_info.target_hidden elif spec_info.target_hidden is not None and spec_info.target_hidden.numel() > 0: diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index fb788bd0a54e..cf9d4963ccd1 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -1,4 +1,6 @@ import logging +import os +from copy import deepcopy from typing import List, Optional, Union import torch @@ -7,11 +9,17 @@ from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch from sglang.srt.managers.scheduler import GenerationBatchResult from sglang.srt.managers.tp_worker import TpModelWorker -from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, + compute_position, +) from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.dflash_draft_model import load_dflash_draft_model from sglang.srt.speculative.dflash_info import DFlashDraftInput, DFlashVerifyInput from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func logger = logging.getLogger(__name__) @@ -45,25 +53,73 @@ def __init__( self._warned_forced_greedy = False self._logged_first_verify = False - # Load the DFlash draft model (weights are separate from the target model). - # This is kept as a standalone module (not a TpModelWorker) since its forward - # is non-causal and differs from standard decoder-only models. + self.draft_impl = os.getenv("SGLANG_DFLASH_DRAFT_IMPL", "native").strip().lower() + if self.draft_impl not in ("native", "hf"): + raise ValueError( + "Invalid SGLANG_DFLASH_DRAFT_IMPL. " + f"Expected 'native' or 'hf', got: {self.draft_impl!r}" + ) + + self.native_draft_worker: TpModelWorker | None = None + self.native_draft_model_runner = None + self.native_draft_model = None + self.draft_model = None + self.draft_config = None + draft_device = torch.device(target_worker.device) draft_dtype = target_worker.model_runner.dtype - self.draft_model, self.draft_config = load_dflash_draft_model( - server_args.speculative_draft_model_path, - device=draft_device, - dtype=draft_dtype, - ) - self.block_size = int(getattr(self.draft_config, "block_size", 16)) + + if self.draft_impl == "native": + # Native (SGLang) draft runner (separate KV cache + attention backend). + draft_server_args = deepcopy(server_args) + draft_server_args.skip_tokenizer_init = True + draft_server_args.disable_cuda_graph = True + # Force FA3 for draft (Hopper-friendly, and supports ENCODER_ONLY attention). + draft_server_args.attention_backend = "fa3" + # Keep draft context length aligned with the target. + draft_server_args.context_length = target_worker.model_runner.model_config.context_len + self.native_draft_worker = TpModelWorker( + server_args=draft_server_args, + gpu_id=gpu_id, + tp_rank=tp_rank, + moe_ep_rank=moe_ep_rank, + pp_rank=0, + dp_rank=dp_rank, + nccl_port=nccl_port, + is_draft_worker=True, + ) + self.native_draft_model_runner = self.native_draft_worker.model_runner + self.native_draft_model = self.native_draft_model_runner.model + self.block_size = int(getattr(self.native_draft_model, "block_size", 16)) + if self.tp_rank == 0: + logger.info( + "Initialized native DFLASH draft runner. attention_backend=%s, model=%s, block_size=%s", + getattr(draft_server_args, "attention_backend", None), + self.native_draft_model.__class__.__name__, + self.block_size, + ) + else: + # HF-style baseline draft implementation (DynamicCache). + self.draft_model, self.draft_config = load_dflash_draft_model( + server_args.speculative_draft_model_path, + device=draft_device, + dtype=draft_dtype, + ) + self.block_size = int(getattr(self.draft_config, "block_size", 16)) + if self.tp_rank == 0: + logger.info( + "Initialized HF-style DFLASH draft model. path=%s, dtype=%s, device=%s, block_size=%s, num_hidden_layers=%s", + server_args.speculative_draft_model_path, + draft_dtype, + draft_device, + self.block_size, + getattr(self.draft_config, "num_hidden_layers", None), + ) + if self.tp_rank == 0: logger.info( - "Loaded DFLASH draft model. path=%s, dtype=%s, device=%s, block_size=%s, num_hidden_layers=%s, mask_token_id=%s", - server_args.speculative_draft_model_path, - draft_dtype, - draft_device, - self.block_size, - getattr(self.draft_config, "num_hidden_layers", None), + "DFLASH draft impl selected. impl=%s, mask_token_id=%s", + self.draft_impl, self._mask_token_id, ) @@ -72,8 +128,37 @@ def __getattr__(self, name): return getattr(self.target_worker, name) def clear_cache_pool(self): - # No draft-side pools to clear in the stub implementation. - return + if self.native_draft_model_runner is None: + return + self.native_draft_model_runner.req_to_token_pool.clear() + self.native_draft_model_runner.token_to_kv_pool_allocator.clear() + + def on_req_finished(self, req): + """Release native-draft KV cache for a finished request. + + The native draft path uses a separate KV pool that is not managed by the + scheduler/tree-cache. We must explicitly free per-request draft KV slots + when the request completes to avoid leaking draft KV memory across + requests. + """ + if self.draft_impl != "native": + return + if self.native_draft_model_runner is None: + return + req_pool_idx = getattr(req, "req_pool_idx", None) + if req_pool_idx is None: + return + draft_len = getattr(req, "dflash_draft_seq_len", None) + if draft_len is None: + return + draft_len = int(draft_len) + if draft_len <= 0: + return + kv_indices = self.native_draft_model_runner.req_to_token_pool.req_to_token[ + req_pool_idx, :draft_len + ] + self.native_draft_model_runner.token_to_kv_pool_allocator.free(kv_indices) + req.dflash_draft_seq_len = 0 def _resolve_mask_token_id(self) -> int: tokenizer = getattr(self.target_worker, "tokenizer", None) @@ -119,6 +204,13 @@ def _resolve_mask_token_id(self) -> int: return int(mask_token_id) def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: DFlashDraftInput): + if self.draft_impl == "native": + return self._prepare_for_speculative_decoding_native(batch, draft_input) + return self._prepare_for_speculative_decoding_hf(batch, draft_input) + + def _prepare_for_speculative_decoding_native( + self, batch: ScheduleBatch, draft_input: DFlashDraftInput + ): if batch.forward_mode.is_extend() or batch.forward_mode.is_idle(): return @@ -141,6 +233,228 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D raise RuntimeError( f"DFLASH ctx_lens_cpu length mismatch: got {len(draft_input.ctx_lens_cpu)} for bs={bs}." ) + if draft_input.draft_seq_lens_cpu is None: + raise RuntimeError("DFLASH native draft state missing draft_seq_lens_cpu.") + if len(draft_input.draft_seq_lens_cpu) != bs: + raise RuntimeError( + "DFLASH draft_seq_lens_cpu length mismatch: " + f"got {len(draft_input.draft_seq_lens_cpu)} for bs={bs}." + ) + if self.native_draft_model_runner is None or self.native_draft_model is None: + raise RuntimeError("DFLASH native draft runner is not initialized.") + + embed_weight, head_weight = self.target_worker.model_runner.model.get_embed_and_head() + + # --- 1) Append new context tokens into the native draft KV cache. + start_pos_cpu = batch.seq_lens_cpu.tolist() + for cache_len, ctx_len, start_pos in zip( + draft_input.draft_seq_lens_cpu, + draft_input.ctx_lens_cpu, + start_pos_cpu, + strict=True, + ): + if int(cache_len) + int(ctx_len) != int(start_pos): + raise RuntimeError( + "DFLASH native draft cache length mismatch. " + f"cache_len={int(cache_len)}, ctx_len={int(ctx_len)}, start_pos={int(start_pos)}. " + "This can happen if prefix caching is enabled; start with `--disable-radix-cache` for now." + ) + + total_ctx = int(sum(int(x) for x in draft_input.ctx_lens_cpu)) + if total_ctx > 0: + ctx_start = torch.tensor( + draft_input.draft_seq_lens_cpu, dtype=torch.int64, device=device + ) + ctx_len = torch.tensor(draft_input.ctx_lens_cpu, dtype=torch.int64, device=device) + ctx_end = ctx_start + ctx_len + + ctx_cache_loc = self.native_draft_model_runner.token_to_kv_pool_allocator.alloc( + total_ctx + ) + if ctx_cache_loc is None: + raise RuntimeError( + f"DFLASH native draft OOM when allocating {total_ctx} context tokens." + ) + + assign_req_to_token_pool_func( + batch.req_pool_indices, + self.native_draft_model_runner.req_to_token_pool.req_to_token, + ctx_start, + ctx_end, + ctx_cache_loc, + bs, + ) + + ctx_positions_chunks: List[torch.Tensor] = [] + for s, e in zip(ctx_start.tolist(), ctx_end.tolist(), strict=True): + if e > s: + ctx_positions_chunks.append( + torch.arange(s, e, device=device, dtype=torch.int64) + ) + ctx_positions = ( + torch.cat(ctx_positions_chunks, dim=0) + if ctx_positions_chunks + else torch.empty((0,), dtype=torch.int64, device=device) + ) + + with torch.inference_mode(): + ctx_hidden = self.native_draft_model.project_target_hidden( + draft_input.target_hidden + ) # [sum(ctx), hidden] + + for layer in self.native_draft_model.layers: + attn = layer.self_attn + k = attn.k_proj(ctx_hidden) + v = attn.v_proj(ctx_hidden) + k = attn.k_norm(k.view(-1, attn.head_dim)).view_as(k) + dummy_q = k.new_empty((k.shape[0], attn.head_dim)) + _, k = attn.rotary_emb(ctx_positions, dummy_q, k) + k = k.view(-1, attn.num_kv_heads, attn.head_dim) + v = v.view(-1, attn.num_kv_heads, attn.head_dim) + self.native_draft_model_runner.token_to_kv_pool.set_kv_buffer( + attn.attn, + ctx_cache_loc, + k, + v, + attn.attn.k_scale, + attn.attn.v_scale, + ) + + draft_input.draft_seq_lens_cpu = ctx_end.to(torch.int64).cpu().tolist() + for req, seq_len in zip( + batch.reqs, draft_input.draft_seq_lens_cpu, strict=True + ): + req.dflash_draft_seq_len = int(seq_len) + draft_input.ctx_lens_cpu = [0] * bs + draft_input.target_hidden = draft_input.target_hidden[:0] + + # --- 2) Draft a non-causal block with the native draft model. + block_ids = torch.full( + (bs, self.block_size), + self._mask_token_id, + dtype=torch.long, + device=device, + ) + block_ids[:, 0] = draft_input.verified_id.to(torch.long) + + noise_embedding = F.embedding(block_ids, embed_weight) + input_embeds = noise_embedding.view(-1, noise_embedding.shape[-1]) + + prefix_lens_cpu = [int(x) for x in draft_input.draft_seq_lens_cpu] + prefix_lens = torch.tensor(prefix_lens_cpu, dtype=torch.int32, device=device) + extend_lens = torch.full( + (bs,), int(self.block_size), dtype=torch.int32, device=device + ) + positions, extend_start_loc = compute_position( + self.native_draft_model_runner.server_args.attention_backend, + prefix_lens, + extend_lens, + bs * self.block_size, + ) + + block_start = prefix_lens.to(torch.int64) + block_end = block_start + int(self.block_size) + block_cache_loc = self.native_draft_model_runner.token_to_kv_pool_allocator.alloc( + bs * self.block_size + ) + if block_cache_loc is None: + raise RuntimeError( + f"DFLASH native draft OOM when allocating {bs * self.block_size} block tokens." + ) + + assign_req_to_token_pool_func( + batch.req_pool_indices, + self.native_draft_model_runner.req_to_token_pool.req_to_token, + block_start, + block_end, + block_cache_loc, + bs, + ) + + seq_lens = block_end.to(torch.int64) + forward_batch = ForwardBatch( + forward_mode=ForwardMode.EXTEND, + batch_size=bs, + input_ids=block_ids.flatten(), + req_pool_indices=batch.req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=block_cache_loc, + seq_lens_sum=int(seq_lens.sum().item()), + seq_lens_cpu=torch.tensor(seq_lens.cpu().tolist(), dtype=torch.int64), + positions=positions, + extend_num_tokens=bs * self.block_size, + extend_seq_lens=extend_lens, + extend_prefix_lens=prefix_lens, + extend_start_loc=extend_start_loc, + extend_prefix_lens_cpu=prefix_lens_cpu, + extend_seq_lens_cpu=[int(self.block_size)] * bs, + extend_logprob_start_lens_cpu=[0] * bs, + req_to_token_pool=self.native_draft_model_runner.req_to_token_pool, + token_to_kv_pool=self.native_draft_model_runner.token_to_kv_pool, + attn_backend=self.native_draft_model_runner.attn_backend, + input_embeds=input_embeds, + ) + + with torch.inference_mode(): + draft_hidden = self.native_draft_model_runner.forward(forward_batch).logits_output + + # Crop: drop the speculative block from the draft KV cache (context stays). + self.native_draft_model_runner.token_to_kv_pool_allocator.free(block_cache_loc) + + draft_hidden = draft_hidden.view(bs, self.block_size, -1) + draft_logits = F.linear(draft_hidden[:, 1:, :], head_weight) + draft_next = torch.argmax(draft_logits, dim=-1).to(torch.long) + draft_tokens = torch.cat([block_ids[:, :1], draft_next], dim=1) # [bs, block_size] + positions = ( + batch.seq_lens.to(torch.long).unsqueeze(1) + + torch.arange(self.block_size, device=device, dtype=torch.long)[None, :] + ).flatten() + + verify_input = DFlashVerifyInput( + draft_token=draft_tokens.flatten(), + positions=positions, + draft_token_num=self.block_size, + ) + verify_input.prepare_for_verify(batch, self.page_size) + + batch.forward_mode = ForwardMode.TARGET_VERIFY if not batch.forward_mode.is_idle() else ForwardMode.IDLE + batch.spec_info = verify_input + batch.return_hidden_states = False + + def _prepare_for_speculative_decoding_hf( + self, batch: ScheduleBatch, draft_input: DFlashDraftInput + ): + if batch.forward_mode.is_extend() or batch.forward_mode.is_idle(): + return + + if batch.has_grammar: + raise ValueError("DFLASH does not support grammar-constrained decoding yet.") + if batch.sampling_info is not None and not batch.sampling_info.is_all_greedy: + if not self._warned_forced_greedy and self.tp_rank == 0: + logger.warning( + "DFLASH currently supports greedy verification only; " + "ignoring non-greedy sampling params (e.g. temperature/top_p/top_k) and using argmax." + ) + self._warned_forced_greedy = True + + if self.draft_model is None: + raise RuntimeError("DFLASH HF draft model is not initialized.") + if draft_input.draft_caches is None: + raise RuntimeError("DFLASH HF draft state missing draft_caches.") + + bs = batch.batch_size() + device = self.model_runner.device + + if draft_input.target_hidden is None: + raise RuntimeError("DFLASH draft state missing target_hidden context features.") + if len(draft_input.ctx_lens_cpu) != bs: + raise RuntimeError( + f"DFLASH ctx_lens_cpu length mismatch: got {len(draft_input.ctx_lens_cpu)} for bs={bs}." + ) + if len(draft_input.draft_caches) != bs: + raise RuntimeError( + f"DFLASH draft_caches length mismatch: got {len(draft_input.draft_caches)} for bs={bs}." + ) embed_weight, head_weight = self.target_worker.model_runner.model.get_embed_and_head() @@ -150,7 +464,7 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D offsets.append(offsets[-1] + int(ln)) candidates: List[torch.Tensor] = [] - for i, (req, ctx_len) in enumerate(zip(batch.reqs, draft_input.ctx_lens_cpu, strict=True)): + for i, ctx_len in enumerate(draft_input.ctx_lens_cpu): start_pos = int(batch.seq_lens_cpu[i].item()) cache = draft_input.draft_caches[i] cache_len = int(cache.get_seq_length()) @@ -253,15 +567,29 @@ def forward_batch_generation( "Make sure the target model has DFlash layers-to-capture configured." ) - draft_caches = [self.draft_model.make_cache() for _ in batch.reqs] ctx_lens_cpu = model_worker_batch.seq_lens_cpu.tolist() - batch.spec_info = DFlashDraftInput( - verified_id=next_token_ids.to(torch.int64), - target_hidden=logits_output.hidden_states, - ctx_lens_cpu=ctx_lens_cpu, - draft_caches=draft_caches, - ) + if self.draft_impl == "native": + batch.spec_info = DFlashDraftInput( + verified_id=next_token_ids.to(torch.int64), + target_hidden=logits_output.hidden_states, + ctx_lens_cpu=ctx_lens_cpu, + draft_seq_lens_cpu=[0] * len(ctx_lens_cpu), + draft_caches=None, + ) + for req in batch.reqs: + req.dflash_draft_seq_len = 0 + else: + if self.draft_model is None: + raise RuntimeError("DFLASH HF draft model is not initialized.") + draft_caches = [self.draft_model.make_cache() for _ in batch.reqs] + batch.spec_info = DFlashDraftInput( + verified_id=next_token_ids.to(torch.int64), + target_hidden=logits_output.hidden_states, + ctx_lens_cpu=ctx_lens_cpu, + draft_seq_lens_cpu=None, + draft_caches=draft_caches, + ) return GenerationBatchResult( logits_output=logits_output, diff --git a/test/manual/models/test_qwen3_dflash_gsm8k_bench.py b/test/manual/models/test_qwen3_dflash_gsm8k_bench.py index a46d38419591..74e5551a61cf 100644 --- a/test/manual/models/test_qwen3_dflash_gsm8k_bench.py +++ b/test/manual/models/test_qwen3_dflash_gsm8k_bench.py @@ -330,6 +330,121 @@ def _collect_common_metrics(outputs: list[dict]) -> tuple[int, list[int]]: } print(json.dumps(report, indent=2), flush=True) + def test_qwen3_dflash_native_matches_hf(self): + if is_in_ci(): + self.skipTest("Manual benchmark; skipped in CI.") + if not torch.cuda.is_available(): + self.skipTest("CUDA is required for this manual DFlash benchmark.") + + target_model = os.getenv("SGLANG_DFLASH_TARGET_MODEL", "Qwen/Qwen3-8B") + draft_model_path = os.getenv( + "SGLANG_DFLASH_DRAFT_MODEL_PATH", "/tmp/Qwen3-8B-DFlash-bf16" + ) + if not os.path.isdir(draft_model_path): + self.skipTest( + f"Draft model folder not found: {draft_model_path}. " + "Set SGLANG_DFLASH_DRAFT_MODEL_PATH to run this benchmark." + ) + + attention_backend = os.getenv("SGLANG_DFLASH_ATTENTION_BACKEND", "flashinfer") + max_new_tokens = int(os.getenv("SGLANG_DFLASH_PARITY_MAX_NEW_TOKENS", "256")) + parallel = int(os.getenv("SGLANG_DFLASH_PARITY_PARALLEL", "1")) + num_questions = int(os.getenv("SGLANG_DFLASH_PARITY_NUM_QUESTIONS", "10")) + max_total_tokens = int(os.getenv("SGLANG_DFLASH_PARITY_MAX_TOTAL_TOKENS", "8192")) + num_shots = int(os.getenv("SGLANG_DFLASH_NUM_SHOTS", "1")) + disable_radix_cache = os.getenv("SGLANG_DFLASH_DISABLE_RADIX_CACHE", "1") != "0" + prompt_style = os.getenv("SGLANG_DFLASH_PROMPT_STYLE", "fewshot_qa") + + # Read GSM8K data (download if absent). + data_path = os.getenv("SGLANG_DFLASH_GSM8K_PATH", "test.jsonl") + url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + if not os.path.isfile(data_path): + data_path = download_and_cache_file(url) + lines = list(read_jsonl(data_path)) + + tokenizer = None + if prompt_style == "dflash_chat": + tokenizer = AutoTokenizer.from_pretrained(target_model) + + few_shot = _get_few_shot_examples(lines, num_shots) if prompt_style == "fewshot_qa" else "" + prompts: list[str] = [] + for i in range(len(lines[:num_questions])): + if prompt_style == "fewshot_qa": + prompts.append(few_shot + _get_one_example(lines, i, False)) + elif prompt_style == "dflash_chat": + assert tokenizer is not None + user_content = ( + lines[i]["question"] + + "\nPlease reason step by step, and put your final answer within \\boxed{}." + ) + prompts.append( + tokenizer.apply_chat_template( + [{"role": "user", "content": user_content}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + ) + else: + raise ValueError(f"Unsupported SGLANG_DFLASH_PROMPT_STYLE: {prompt_style}") + + common_server_args = [ + "--attention-backend", + attention_backend, + "--max-total-tokens", + str(max_total_tokens), + ] + if disable_radix_cache: + common_server_args.append("--disable-radix-cache") + extra_server_args = os.getenv("SGLANG_DFLASH_EXTRA_SERVER_ARGS", "").strip() + if extra_server_args: + common_server_args.extend(shlex.split(extra_server_args)) + + def _run_dflash(draft_impl: str, port: int) -> list[dict]: + base_url = f"http://127.0.0.1:{port}" + env = dict(os.environ) + env["SGLANG_DFLASH_DRAFT_IMPL"] = draft_impl + proc = popen_launch_server( + target_model, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + *common_server_args, + "--speculative-algorithm", + "DFLASH", + "--speculative-draft-model-path", + draft_model_path, + ], + env=env, + ) + try: + _send_generate(base_url, "Hello", max_new_tokens=8) # warmup + _, outputs = _run_generate_batch( + base_url, + prompts, + max_new_tokens=max_new_tokens, + parallel=parallel, + ) + return outputs + finally: + kill_process_tree(proc.pid) + try: + proc.wait(timeout=30) + except Exception: + pass + + hf_outputs = _run_dflash("hf", find_available_port(21000)) + native_outputs = _run_dflash("native", find_available_port(22000)) + + for i, (hf_out, native_out) in enumerate( + zip(hf_outputs, native_outputs, strict=True) + ): + if hf_out.get("output_ids") != native_out.get("output_ids"): + raise AssertionError( + "HF and native DFLASH outputs diverged at index " + f"{i}.\nhf={hf_out.get('output_ids')}\nnative={native_out.get('output_ids')}" + ) + if __name__ == "__main__": unittest.main() From 2c64b0e4bacc67ae0382959457e2c9ab97eaf381 Mon Sep 17 00:00:00 2001 From: David Wang Date: Thu, 8 Jan 2026 06:34:33 +0000 Subject: [PATCH 06/73] remove hf backend --- .../srt/speculative/dflash_draft_model.py | 273 ------------------ python/sglang/srt/speculative/dflash_info.py | 22 +- .../sglang/srt/speculative/dflash_worker.py | 250 +++------------- .../models/test_qwen3_dflash_gsm8k_bench.py | 77 +++-- 4 files changed, 76 insertions(+), 546 deletions(-) delete mode 100644 python/sglang/srt/speculative/dflash_draft_model.py diff --git a/python/sglang/srt/speculative/dflash_draft_model.py b/python/sglang/srt/speculative/dflash_draft_model.py deleted file mode 100644 index bd6b22132c09..000000000000 --- a/python/sglang/srt/speculative/dflash_draft_model.py +++ /dev/null @@ -1,273 +0,0 @@ -from __future__ import annotations - -import logging -import os -from typing import Optional - -import torch -from safetensors.torch import safe_open -from torch import nn -from transformers import AutoConfig, DynamicCache -from transformers.cache_utils import Cache -from transformers.models.qwen3.modeling_qwen3 import ( - ALL_ATTENTION_FUNCTIONS, - Qwen3MLP, - Qwen3RMSNorm, - Qwen3RotaryEmbedding, - FlashAttentionKwargs, - eager_attention_forward, - rotate_half, -) -from typing_extensions import Unpack - -logger = logging.getLogger(__name__) - - -def apply_rotary_pos_emb( - q: torch.Tensor, - k: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - unsqueeze_dim: int = 1, -) -> tuple[torch.Tensor, torch.Tensor]: - cos = cos.unsqueeze(unsqueeze_dim) - sin = sin.unsqueeze(unsqueeze_dim) - q_len = q.size(-2) - q_embed = (q * cos[..., -q_len:, :]) + (rotate_half(q) * sin[..., -q_len:, :]) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class Qwen3DFlashAttention(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads - self.scaling = self.head_dim**-0.5 - self.attention_dropout = config.attention_dropout - self.is_causal = False - - self.q_proj = nn.Linear( - config.hidden_size, - config.num_attention_heads * self.head_dim, - bias=config.attention_bias, - ) - self.k_proj = nn.Linear( - config.hidden_size, - config.num_key_value_heads * self.head_dim, - bias=config.attention_bias, - ) - self.v_proj = nn.Linear( - config.hidden_size, - config.num_key_value_heads * self.head_dim, - bias=config.attention_bias, - ) - self.o_proj = nn.Linear( - config.num_attention_heads * self.head_dim, - config.hidden_size, - bias=config.attention_bias, - ) - - self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) - self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) - - def forward( - self, - hidden_states: torch.Tensor, - target_hidden: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor] = None, - past_key_values: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - bsz, q_len = hidden_states.shape[:-1] - ctx_len = target_hidden.shape[1] - - q = self.q_proj(hidden_states) - q = q.view(bsz, q_len, -1, self.head_dim) - q = self.q_norm(q).transpose(1, 2) - - k_ctx = self.k_proj(target_hidden) - k_noise = self.k_proj(hidden_states) - v_ctx = self.v_proj(target_hidden) - v_noise = self.v_proj(hidden_states) - k = torch.cat([k_ctx, k_noise], dim=1).view(bsz, ctx_len + q_len, -1, self.head_dim) - v = torch.cat([v_ctx, v_noise], dim=1).view(bsz, ctx_len + q_len, -1, self.head_dim) - k = self.k_norm(k).transpose(1, 2) - v = v.transpose(1, 2) - - cos, sin = position_embeddings - q, k = apply_rotary_pos_emb(q, k, cos, sin) - - if past_key_values is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - k, v = past_key_values.update(k, v, self.layer_idx, cache_kwargs) - - attn_fn = eager_attention_forward - if getattr(self.config, "_attn_implementation", "eager") != "eager": - attn_fn = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attn_fn( - self, - q, - k, - v, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - sliding_window=None, - **kwargs, - ) - attn_output = attn_output.reshape(bsz, q_len, -1) - attn_output = self.o_proj(attn_output) - return attn_output, attn_weights - - -class Qwen3DFlashDecoderLayer(nn.Module): - def __init__(self, config, layer_idx: int): - super().__init__() - self.self_attn = Qwen3DFlashAttention(config=config, layer_idx=layer_idx) - self.mlp = Qwen3MLP(config) - self.input_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - def forward( - self, - *, - target_hidden: torch.Tensor, - hidden_states: torch.Tensor, - position_embeddings: tuple[torch.Tensor, torch.Tensor], - attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, - cache_position: Optional[torch.LongTensor] = None, - use_cache: bool = False, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> torch.Tensor: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.self_attn( - hidden_states=hidden_states, - target_hidden=target_hidden, - attention_mask=attention_mask, - past_key_values=past_key_value, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - )[0] - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - return hidden_states - - -class DFlashDraftModel(nn.Module): - """Local (non-trust_remote_code) DFlash draft model implementation. - - This is adapted from the DFlash reference `modeling_dflash.py` shipped with - the draft checkpoint, but is loaded as first-party code in SGLang. - - The model intentionally does NOT include embedding or lm_head weights; the - DFlash algorithm uses the target model's embedding and lm_head. - """ - - def __init__(self, config) -> None: - super().__init__() - self.config = config - - self.layers = nn.ModuleList( - [ - Qwen3DFlashDecoderLayer(config, layer_idx=i) - for i in range(config.num_hidden_layers) - ] - ) - self.norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = Qwen3RotaryEmbedding(config) - - # DFlash context feature projection: concat(draft_num_layers x hidden_size) -> hidden_size. - self.fc = nn.Linear( - config.num_hidden_layers * config.hidden_size, - config.hidden_size, - bias=False, - ) - self.hidden_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.block_size = config.block_size - - def forward( - self, - *, - noise_embedding: torch.Tensor, - target_hidden: torch.Tensor, - position_ids: torch.LongTensor, - past_key_values: Optional[Cache] = None, - use_cache: bool = False, - attention_mask: Optional[torch.Tensor] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> torch.Tensor: - hidden_states = noise_embedding - target_hidden = self.hidden_norm(self.fc(target_hidden)) - - position_embeddings = self.rotary_emb(hidden_states, position_ids) - for layer in self.layers: - hidden_states = layer( - hidden_states=hidden_states, - target_hidden=target_hidden, - attention_mask=attention_mask, - past_key_value=past_key_values, - cache_position=cache_position, - use_cache=use_cache, - position_embeddings=position_embeddings, - **kwargs, - ) - - return self.norm(hidden_states) - - def make_cache(self) -> DynamicCache: - return DynamicCache() - - -def load_dflash_draft_model( - model_path: str, - *, - device: torch.device, - dtype: torch.dtype, -) -> tuple[DFlashDraftModel, object]: - """Load DFlash draft model weights from a local folder.""" - config = AutoConfig.from_pretrained(model_path, trust_remote_code=False) - # Ensure we don't accidentally select optional FlashAttention implementations. - setattr(config, "_attn_implementation", "eager") - - model = DFlashDraftModel(config).to(device=device, dtype=dtype) - - weights_path = os.path.join(model_path, "model.safetensors") - if not os.path.isfile(weights_path): - raise FileNotFoundError(f"DFLASH draft weights not found: {weights_path}") - - model_state = model.state_dict() - unexpected: list[str] = [] - with safe_open(weights_path, framework="pt", device=str(device)) as f: - for key in f.keys(): - if key not in model_state: - unexpected.append(key) - continue - model_state[key].copy_(f.get_tensor(key)) - - if unexpected: - logger.warning( - "DFLASH draft checkpoint has %d unexpected keys (ignored). Example: %s", - len(unexpected), - unexpected[0], - ) - - model.eval() - model.requires_grad_(False) - return model, config - diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py index e5a7bf012262..e7804356af54 100644 --- a/python/sglang/srt/speculative/dflash_info.py +++ b/python/sglang/srt/speculative/dflash_info.py @@ -27,8 +27,7 @@ class DFlashDraftInput(SpecInput): the draft model and to track draft-side cache progress. Invariant (per request): - - Native path: `draft_seq_len + ctx_len == batch.seq_lens[i]` - - HF path: `draft_cache.get_seq_length() + ctx_len == batch.seq_lens[i]` + - `draft_seq_len + ctx_len == batch.seq_lens[i]` where `ctx_len` is the number of target context-feature tokens carried in `target_hidden` for that request. """ @@ -45,10 +44,7 @@ class DFlashDraftInput(SpecInput): # Native implementation: how many tokens are already materialized in the draft KV cache. # The next draft step appends `ctx_lens_cpu[i]` tokens starting at `draft_seq_lens_cpu[i]`. - draft_seq_lens_cpu: List[int] | None = None - - # HF-style baseline implementation: per-request transformers DynamicCache objects. - draft_caches: List[object] | None = None + draft_seq_lens_cpu: List[int] def __post_init__(self): super().__init__(spec_input_type=SpecInputType.DFLASH_DRAFT) @@ -65,10 +61,7 @@ def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True self.verified_id = self.verified_id[new_indices] self.ctx_lens_cpu = [old_ctx_lens_cpu[i] for i in keep_indices] - if self.draft_seq_lens_cpu is not None: - self.draft_seq_lens_cpu = [self.draft_seq_lens_cpu[i] for i in keep_indices] - if self.draft_caches is not None: - self.draft_caches = [self.draft_caches[i] for i in keep_indices] + self.draft_seq_lens_cpu = [self.draft_seq_lens_cpu[i] for i in keep_indices] if old_target_hidden is None or old_target_hidden.numel() == 0: self.target_hidden = old_target_hidden @@ -90,14 +83,7 @@ def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True def merge_batch(self, spec_info: "DFlashDraftInput"): self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], dim=0) self.ctx_lens_cpu.extend(spec_info.ctx_lens_cpu) - if self.draft_seq_lens_cpu is not None or spec_info.draft_seq_lens_cpu is not None: - if self.draft_seq_lens_cpu is None or spec_info.draft_seq_lens_cpu is None: - raise ValueError("Cannot merge DFLASH draft batches with mismatched draft_seq_lens_cpu presence.") - self.draft_seq_lens_cpu.extend(spec_info.draft_seq_lens_cpu) - if self.draft_caches is not None or spec_info.draft_caches is not None: - if self.draft_caches is None or spec_info.draft_caches is None: - raise ValueError("Cannot merge DFLASH draft batches with mismatched draft_caches presence.") - self.draft_caches.extend(spec_info.draft_caches) + self.draft_seq_lens_cpu.extend(spec_info.draft_seq_lens_cpu) if self.target_hidden is None or self.target_hidden.numel() == 0: self.target_hidden = spec_info.target_hidden elif spec_info.target_hidden is not None and spec_info.target_hidden.numel() > 0: diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index cf9d4963ccd1..e99cc9b55464 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -1,5 +1,4 @@ import logging -import os from copy import deepcopy from typing import List, Optional, Union @@ -16,7 +15,6 @@ compute_position, ) from sglang.srt.server_args import ServerArgs -from sglang.srt.speculative.dflash_draft_model import load_dflash_draft_model from sglang.srt.speculative.dflash_info import DFlashDraftInput, DFlashVerifyInput from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func @@ -53,73 +51,36 @@ def __init__( self._warned_forced_greedy = False self._logged_first_verify = False - self.draft_impl = os.getenv("SGLANG_DFLASH_DRAFT_IMPL", "native").strip().lower() - if self.draft_impl not in ("native", "hf"): - raise ValueError( - "Invalid SGLANG_DFLASH_DRAFT_IMPL. " - f"Expected 'native' or 'hf', got: {self.draft_impl!r}" - ) - - self.native_draft_worker: TpModelWorker | None = None - self.native_draft_model_runner = None - self.native_draft_model = None - self.draft_model = None - self.draft_config = None - - draft_device = torch.device(target_worker.device) - draft_dtype = target_worker.model_runner.dtype - - if self.draft_impl == "native": - # Native (SGLang) draft runner (separate KV cache + attention backend). - draft_server_args = deepcopy(server_args) - draft_server_args.skip_tokenizer_init = True - draft_server_args.disable_cuda_graph = True - # Force FA3 for draft (Hopper-friendly, and supports ENCODER_ONLY attention). - draft_server_args.attention_backend = "fa3" - # Keep draft context length aligned with the target. - draft_server_args.context_length = target_worker.model_runner.model_config.context_len - self.native_draft_worker = TpModelWorker( - server_args=draft_server_args, - gpu_id=gpu_id, - tp_rank=tp_rank, - moe_ep_rank=moe_ep_rank, - pp_rank=0, - dp_rank=dp_rank, - nccl_port=nccl_port, - is_draft_worker=True, - ) - self.native_draft_model_runner = self.native_draft_worker.model_runner - self.native_draft_model = self.native_draft_model_runner.model - self.block_size = int(getattr(self.native_draft_model, "block_size", 16)) - if self.tp_rank == 0: - logger.info( - "Initialized native DFLASH draft runner. attention_backend=%s, model=%s, block_size=%s", - getattr(draft_server_args, "attention_backend", None), - self.native_draft_model.__class__.__name__, - self.block_size, - ) - else: - # HF-style baseline draft implementation (DynamicCache). - self.draft_model, self.draft_config = load_dflash_draft_model( - server_args.speculative_draft_model_path, - device=draft_device, - dtype=draft_dtype, - ) - self.block_size = int(getattr(self.draft_config, "block_size", 16)) - if self.tp_rank == 0: - logger.info( - "Initialized HF-style DFLASH draft model. path=%s, dtype=%s, device=%s, block_size=%s, num_hidden_layers=%s", - server_args.speculative_draft_model_path, - draft_dtype, - draft_device, - self.block_size, - getattr(self.draft_config, "num_hidden_layers", None), - ) - + # Native (SGLang) draft runner (separate KV cache + attention backend). + draft_server_args = deepcopy(server_args) + draft_server_args.skip_tokenizer_init = True + draft_server_args.disable_cuda_graph = True + # Force FA3 for draft (Hopper-friendly, and supports ENCODER_ONLY attention). + draft_server_args.attention_backend = "fa3" + # Keep draft context length aligned with the target. + draft_server_args.context_length = target_worker.model_runner.model_config.context_len + self.native_draft_worker = TpModelWorker( + server_args=draft_server_args, + gpu_id=gpu_id, + tp_rank=tp_rank, + moe_ep_rank=moe_ep_rank, + pp_rank=0, + dp_rank=dp_rank, + nccl_port=nccl_port, + is_draft_worker=True, + ) + self.native_draft_model_runner = self.native_draft_worker.model_runner + self.native_draft_model = self.native_draft_model_runner.model + self.block_size = int(getattr(self.native_draft_model, "block_size", 16)) if self.tp_rank == 0: logger.info( - "DFLASH draft impl selected. impl=%s, mask_token_id=%s", - self.draft_impl, + "Initialized native DFLASH draft runner. attention_backend=%s, model=%s, block_size=%s", + getattr(draft_server_args, "attention_backend", None), + self.native_draft_model.__class__.__name__, + self.block_size, + ) + logger.info( + "DFLASH draft impl selected. impl=native, mask_token_id=%s", self._mask_token_id, ) @@ -141,10 +102,6 @@ def on_req_finished(self, req): when the request completes to avoid leaking draft KV memory across requests. """ - if self.draft_impl != "native": - return - if self.native_draft_model_runner is None: - return req_pool_idx = getattr(req, "req_pool_idx", None) if req_pool_idx is None: return @@ -204,13 +161,6 @@ def _resolve_mask_token_id(self) -> int: return int(mask_token_id) def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: DFlashDraftInput): - if self.draft_impl == "native": - return self._prepare_for_speculative_decoding_native(batch, draft_input) - return self._prepare_for_speculative_decoding_hf(batch, draft_input) - - def _prepare_for_speculative_decoding_native( - self, batch: ScheduleBatch, draft_input: DFlashDraftInput - ): if batch.forward_mode.is_extend() or batch.forward_mode.is_idle(): return @@ -233,15 +183,11 @@ def _prepare_for_speculative_decoding_native( raise RuntimeError( f"DFLASH ctx_lens_cpu length mismatch: got {len(draft_input.ctx_lens_cpu)} for bs={bs}." ) - if draft_input.draft_seq_lens_cpu is None: - raise RuntimeError("DFLASH native draft state missing draft_seq_lens_cpu.") if len(draft_input.draft_seq_lens_cpu) != bs: raise RuntimeError( "DFLASH draft_seq_lens_cpu length mismatch: " f"got {len(draft_input.draft_seq_lens_cpu)} for bs={bs}." ) - if self.native_draft_model_runner is None or self.native_draft_model is None: - raise RuntimeError("DFLASH native draft runner is not initialized.") embed_weight, head_weight = self.target_worker.model_runner.model.get_embed_and_head() @@ -421,117 +367,6 @@ def _prepare_for_speculative_decoding_native( batch.spec_info = verify_input batch.return_hidden_states = False - def _prepare_for_speculative_decoding_hf( - self, batch: ScheduleBatch, draft_input: DFlashDraftInput - ): - if batch.forward_mode.is_extend() or batch.forward_mode.is_idle(): - return - - if batch.has_grammar: - raise ValueError("DFLASH does not support grammar-constrained decoding yet.") - if batch.sampling_info is not None and not batch.sampling_info.is_all_greedy: - if not self._warned_forced_greedy and self.tp_rank == 0: - logger.warning( - "DFLASH currently supports greedy verification only; " - "ignoring non-greedy sampling params (e.g. temperature/top_p/top_k) and using argmax." - ) - self._warned_forced_greedy = True - - if self.draft_model is None: - raise RuntimeError("DFLASH HF draft model is not initialized.") - if draft_input.draft_caches is None: - raise RuntimeError("DFLASH HF draft state missing draft_caches.") - - bs = batch.batch_size() - device = self.model_runner.device - - if draft_input.target_hidden is None: - raise RuntimeError("DFLASH draft state missing target_hidden context features.") - if len(draft_input.ctx_lens_cpu) != bs: - raise RuntimeError( - f"DFLASH ctx_lens_cpu length mismatch: got {len(draft_input.ctx_lens_cpu)} for bs={bs}." - ) - if len(draft_input.draft_caches) != bs: - raise RuntimeError( - f"DFLASH draft_caches length mismatch: got {len(draft_input.draft_caches)} for bs={bs}." - ) - - embed_weight, head_weight = self.target_worker.model_runner.model.get_embed_and_head() - - # Slice ragged target_hidden on CPU for simplicity. - offsets: List[int] = [0] - for ln in draft_input.ctx_lens_cpu: - offsets.append(offsets[-1] + int(ln)) - - candidates: List[torch.Tensor] = [] - for i, ctx_len in enumerate(draft_input.ctx_lens_cpu): - start_pos = int(batch.seq_lens_cpu[i].item()) - cache = draft_input.draft_caches[i] - cache_len = int(cache.get_seq_length()) - - if cache_len + int(ctx_len) != start_pos: - raise RuntimeError( - "DFLASH draft cache length mismatch. " - f"{cache_len=} + {ctx_len=} != {start_pos=}. " - "This can happen if prefix caching is enabled; start with `--disable-radix-cache` for now." - ) - - target_hidden = draft_input.target_hidden[offsets[i] : offsets[i + 1]] - target_hidden = target_hidden.unsqueeze(0) # [1, ctx_len, feat] - - block_ids = torch.full( - (1, self.block_size), - self._mask_token_id, - dtype=torch.long, - device=device, - ) - block_ids[0, 0] = draft_input.verified_id[i].to(torch.long) - - noise_embedding = F.embedding(block_ids, embed_weight) - position_ids = torch.arange( - cache_len, - start_pos + self.block_size, - dtype=torch.long, - device=device, - ).unsqueeze(0) - - with torch.inference_mode(): - hidden = self.draft_model( - noise_embedding=noise_embedding, - target_hidden=target_hidden, - position_ids=position_ids, - past_key_values=cache, - use_cache=True, - ) - cache.crop(start_pos) - - draft_hidden = hidden[:, -self.block_size + 1 :, :] - draft_logits = F.linear(draft_hidden, head_weight) - draft_tokens = torch.argmax(draft_logits, dim=-1).to(torch.long) - - candidate = torch.cat( - [block_ids[0, 0].view(1), draft_tokens.view(-1)], - dim=0, - ) - candidates.append(candidate) - - draft_tokens = torch.stack(candidates, dim=0) # [bs, block_size] - positions = ( - batch.seq_lens.to(torch.long).unsqueeze(1) - + torch.arange(self.block_size, device=device, dtype=torch.long)[None, :] - ).flatten() - - verify_input = DFlashVerifyInput( - draft_token=draft_tokens.flatten(), - positions=positions, - draft_token_num=self.block_size, - ) - verify_input.prepare_for_verify(batch, self.page_size) - - batch.forward_mode = ForwardMode.TARGET_VERIFY if not batch.forward_mode.is_idle() else ForwardMode.IDLE - batch.spec_info = verify_input - batch.return_hidden_states = False - def forward_batch_generation( self, batch: Union[ScheduleBatch, ModelWorkerBatch], @@ -569,27 +404,14 @@ def forward_batch_generation( ctx_lens_cpu = model_worker_batch.seq_lens_cpu.tolist() - if self.draft_impl == "native": - batch.spec_info = DFlashDraftInput( - verified_id=next_token_ids.to(torch.int64), - target_hidden=logits_output.hidden_states, - ctx_lens_cpu=ctx_lens_cpu, - draft_seq_lens_cpu=[0] * len(ctx_lens_cpu), - draft_caches=None, - ) - for req in batch.reqs: - req.dflash_draft_seq_len = 0 - else: - if self.draft_model is None: - raise RuntimeError("DFLASH HF draft model is not initialized.") - draft_caches = [self.draft_model.make_cache() for _ in batch.reqs] - batch.spec_info = DFlashDraftInput( - verified_id=next_token_ids.to(torch.int64), - target_hidden=logits_output.hidden_states, - ctx_lens_cpu=ctx_lens_cpu, - draft_seq_lens_cpu=None, - draft_caches=draft_caches, - ) + batch.spec_info = DFlashDraftInput( + verified_id=next_token_ids.to(torch.int64), + target_hidden=logits_output.hidden_states, + ctx_lens_cpu=ctx_lens_cpu, + draft_seq_lens_cpu=[0] * len(ctx_lens_cpu), + ) + for req in batch.reqs: + req.dflash_draft_seq_len = 0 return GenerationBatchResult( logits_output=logits_output, diff --git a/test/manual/models/test_qwen3_dflash_gsm8k_bench.py b/test/manual/models/test_qwen3_dflash_gsm8k_bench.py index 74e5551a61cf..fc51bd64135e 100644 --- a/test/manual/models/test_qwen3_dflash_gsm8k_bench.py +++ b/test/manual/models/test_qwen3_dflash_gsm8k_bench.py @@ -331,6 +331,7 @@ def _collect_common_metrics(outputs: list[dict]) -> tuple[int, list[int]]: print(json.dumps(report, indent=2), flush=True) def test_qwen3_dflash_native_matches_hf(self): + """Legacy name: previously asserted HF-vs-native parity; now a native smoke/stability run.""" if is_in_ci(): self.skipTest("Manual benchmark; skipped in CI.") if not torch.cuda.is_available(): @@ -347,6 +348,7 @@ def test_qwen3_dflash_native_matches_hf(self): ) attention_backend = os.getenv("SGLANG_DFLASH_ATTENTION_BACKEND", "flashinfer") + # Keep env var names for backwards compatibility with the previous parity test. max_new_tokens = int(os.getenv("SGLANG_DFLASH_PARITY_MAX_NEW_TOKENS", "256")) parallel = int(os.getenv("SGLANG_DFLASH_PARITY_PARALLEL", "1")) num_questions = int(os.getenv("SGLANG_DFLASH_PARITY_NUM_QUESTIONS", "10")) @@ -400,50 +402,43 @@ def test_qwen3_dflash_native_matches_hf(self): if extra_server_args: common_server_args.extend(shlex.split(extra_server_args)) - def _run_dflash(draft_impl: str, port: int) -> list[dict]: - base_url = f"http://127.0.0.1:{port}" - env = dict(os.environ) - env["SGLANG_DFLASH_DRAFT_IMPL"] = draft_impl - proc = popen_launch_server( - target_model, + port = find_available_port(21000) + base_url = f"http://127.0.0.1:{port}" + proc = popen_launch_server( + target_model, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + *common_server_args, + "--speculative-algorithm", + "DFLASH", + "--speculative-draft-model-path", + draft_model_path, + ], + ) + try: + _send_generate(base_url, "Hello", max_new_tokens=8) # warmup + _, outputs = _run_generate_batch( base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - *common_server_args, - "--speculative-algorithm", - "DFLASH", - "--speculative-draft-model-path", - draft_model_path, - ], - env=env, + prompts, + max_new_tokens=max_new_tokens, + parallel=parallel, ) + finally: + kill_process_tree(proc.pid) try: - _send_generate(base_url, "Hello", max_new_tokens=8) # warmup - _, outputs = _run_generate_batch( - base_url, - prompts, - max_new_tokens=max_new_tokens, - parallel=parallel, - ) - return outputs - finally: - kill_process_tree(proc.pid) - try: - proc.wait(timeout=30) - except Exception: - pass - - hf_outputs = _run_dflash("hf", find_available_port(21000)) - native_outputs = _run_dflash("native", find_available_port(22000)) - - for i, (hf_out, native_out) in enumerate( - zip(hf_outputs, native_outputs, strict=True) - ): - if hf_out.get("output_ids") != native_out.get("output_ids"): - raise AssertionError( - "HF and native DFLASH outputs diverged at index " - f"{i}.\nhf={hf_out.get('output_ids')}\nnative={native_out.get('output_ids')}" - ) + proc.wait(timeout=30) + except Exception: + pass + + self.assertEqual(len(outputs), len(prompts)) + spec_verify_cts: list[int] = [] + for out in outputs: + meta = out.get("meta_info", {}) + if "spec_verify_ct" in meta: + spec_verify_cts.append(int(meta["spec_verify_ct"])) + self.assertTrue(spec_verify_cts, "Missing spec_verify_ct in DFLASH responses.") + self.assertGreater(sum(spec_verify_cts), 0, "DFLASH did not run verify steps.") if __name__ == "__main__": From f1a42626a9e90b3f1b8854f75a20fbd01e219186 Mon Sep 17 00:00:00 2001 From: David Wang Date: Thu, 8 Jan 2026 06:54:38 +0000 Subject: [PATCH 07/73] dflash support flashinfer --- .../layers/attention/flashinfer_backend.py | 21 ++++++++----------- .../sglang/srt/speculative/dflash_worker.py | 19 +++++++++++++++-- 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 3a72264a85b6..19a1e7816cf1 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -768,10 +768,14 @@ def forward_extend( layer, cache_loc, k, v, layer.k_scale, layer.v_scale ) + causal = ( + not layer.is_cross_attention + and layer.attn_type != AttentionType.ENCODER_ONLY + ) o = prefill_wrapper_paged.forward( q.view(-1, layer.tp_q_head_num, layer.head_dim), forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), - causal=not layer.is_cross_attention, + causal=causal, sm_scale=layer.scaling, # Disable sliding window attention for multi-item scoring: # - Sliding window could cut across item boundaries, breaking semantic coherence @@ -793,12 +797,10 @@ def forward_extend( v_scale=layer.v_scale_float, ) else: - causal = True - if ( - layer.is_cross_attention - or layer.attn_type == AttentionType.ENCODER_ONLY - ): - causal = False + causal = ( + not layer.is_cross_attention + and layer.attn_type != AttentionType.ENCODER_ONLY + ) if save_kv_cache and layer.attn_type == AttentionType.ENCODER_ONLY: save_kv_cache = False @@ -816,11 +818,6 @@ def forward_extend( ) else: - if not self.is_dllm_model: - # TODO: design a better interface - # For other models, use causal attention for the ragged part as previously - causal = True - o1, s1 = self.prefill_wrapper_ragged.forward_return_lse( q.view(-1, layer.tp_q_head_num, layer.head_dim), k.view(-1, layer.tp_k_head_num, layer.head_dim), diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index e99cc9b55464..63a154fd2a64 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -55,8 +55,23 @@ def __init__( draft_server_args = deepcopy(server_args) draft_server_args.skip_tokenizer_init = True draft_server_args.disable_cuda_graph = True - # Force FA3 for draft (Hopper-friendly, and supports ENCODER_ONLY attention). - draft_server_args.attention_backend = "fa3" + draft_backend = draft_server_args.speculative_draft_attention_backend + if draft_backend is None: + draft_backend, _ = draft_server_args.get_attention_backends() + if draft_backend is None: + draft_backend = "flashinfer" + if draft_backend not in ("flashinfer", "fa3"): + raise ValueError( + "DFLASH draft worker only supports attention_backend in {'flashinfer', 'fa3'} for now, " + f"but got {draft_backend!r}. " + "Use `--speculative-draft-attention-backend` to override the draft backend." + ) + + # Make the draft worker backend explicit and self-contained (no further overrides). + draft_server_args.speculative_draft_attention_backend = None + draft_server_args.prefill_attention_backend = None + draft_server_args.decode_attention_backend = None + draft_server_args.attention_backend = draft_backend # Keep draft context length aligned with the target. draft_server_args.context_length = target_worker.model_runner.model_config.context_len self.native_draft_worker = TpModelWorker( From 2c5b34610273ecccce192cdc7cae09936da6336b Mon Sep 17 00:00:00 2001 From: David Wang Date: Thu, 8 Jan 2026 07:32:55 +0000 Subject: [PATCH 08/73] remove manual management of dflash kv pool --- .../sglang/srt/speculative/dflash_worker.py | 185 ++++++++---------- 1 file changed, 87 insertions(+), 98 deletions(-) diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 63a154fd2a64..4f8caf720b22 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -52,6 +52,11 @@ def __init__( self._logged_first_verify = False # Native (SGLang) draft runner (separate KV cache + attention backend). + # Share req_to_token_pool + token_to_kv_pool_allocator with the target worker (EAGLE3-style), + # while keeping a separate draft KV cache pool (the draft model has different KV values). + shared_req_to_token_pool, shared_token_to_kv_pool_allocator = ( + target_worker.get_memory_pool() + ) draft_server_args = deepcopy(server_args) draft_server_args.skip_tokenizer_init = True draft_server_args.disable_cuda_graph = True @@ -83,6 +88,8 @@ def __init__( dp_rank=dp_rank, nccl_port=nccl_port, is_draft_worker=True, + req_to_token_pool=shared_req_to_token_pool, + token_to_kv_pool_allocator=shared_token_to_kv_pool_allocator, ) self.native_draft_model_runner = self.native_draft_worker.model_runner self.native_draft_model = self.native_draft_model_runner.model @@ -104,33 +111,14 @@ def __getattr__(self, name): return getattr(self.target_worker, name) def clear_cache_pool(self): - if self.native_draft_model_runner is None: - return - self.native_draft_model_runner.req_to_token_pool.clear() - self.native_draft_model_runner.token_to_kv_pool_allocator.clear() + # allocator and req_to_token_pool are shared with target worker + pass def on_req_finished(self, req): - """Release native-draft KV cache for a finished request. - - The native draft path uses a separate KV pool that is not managed by the - scheduler/tree-cache. We must explicitly free per-request draft KV slots - when the request completes to avoid leaking draft KV memory across - requests. - """ - req_pool_idx = getattr(req, "req_pool_idx", None) - if req_pool_idx is None: - return - draft_len = getattr(req, "dflash_draft_seq_len", None) - if draft_len is None: - return - draft_len = int(draft_len) - if draft_len <= 0: - return - kv_indices = self.native_draft_model_runner.req_to_token_pool.req_to_token[ - req_pool_idx, :draft_len - ] - self.native_draft_model_runner.token_to_kv_pool_allocator.free(kv_indices) - req.dflash_draft_seq_len = 0 + # allocator and req_to_token_pool are shared with the target worker; + # there is no separate draft allocation to release here. + if hasattr(req, "dflash_draft_seq_len"): + req.dflash_draft_seq_len = 0 def _resolve_mask_token_id(self) -> int: tokenizer = getattr(self.target_worker, "tokenizer", None) @@ -223,35 +211,37 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D total_ctx = int(sum(int(x) for x in draft_input.ctx_lens_cpu)) if total_ctx > 0: - ctx_start = torch.tensor( - draft_input.draft_seq_lens_cpu, dtype=torch.int64, device=device - ) - ctx_len = torch.tensor(draft_input.ctx_lens_cpu, dtype=torch.int64, device=device) - ctx_end = ctx_start + ctx_len + req_to_token = self.native_draft_model_runner.req_to_token_pool.req_to_token + req_pool_indices_cpu = batch.req_pool_indices.tolist() - ctx_cache_loc = self.native_draft_model_runner.token_to_kv_pool_allocator.alloc( - total_ctx - ) - if ctx_cache_loc is None: - raise RuntimeError( - f"DFLASH native draft OOM when allocating {total_ctx} context tokens." + ctx_cache_loc_chunks: List[torch.Tensor] = [] + ctx_positions_chunks: List[torch.Tensor] = [] + new_draft_seq_lens_cpu: List[int] = [] + for req_pool_idx, cache_len, ctx_len in zip( + req_pool_indices_cpu, + draft_input.draft_seq_lens_cpu, + draft_input.ctx_lens_cpu, + strict=True, + ): + cache_len_i = int(cache_len) + ctx_len_i = int(ctx_len) + new_draft_seq_lens_cpu.append(cache_len_i + ctx_len_i) + if ctx_len_i <= 0: + continue + s = cache_len_i + e = cache_len_i + ctx_len_i + ctx_cache_loc_chunks.append( + req_to_token[req_pool_idx, s:e].to(torch.int64) ) - - assign_req_to_token_pool_func( - batch.req_pool_indices, - self.native_draft_model_runner.req_to_token_pool.req_to_token, - ctx_start, - ctx_end, - ctx_cache_loc, - bs, + ctx_positions_chunks.append( + torch.arange(s, e, device=device, dtype=torch.int64) + ) + ctx_cache_loc = ( + torch.cat(ctx_cache_loc_chunks, dim=0) + if ctx_cache_loc_chunks + else torch.empty((0,), dtype=torch.int64, device=device) ) - ctx_positions_chunks: List[torch.Tensor] = [] - for s, e in zip(ctx_start.tolist(), ctx_end.tolist(), strict=True): - if e > s: - ctx_positions_chunks.append( - torch.arange(s, e, device=device, dtype=torch.int64) - ) ctx_positions = ( torch.cat(ctx_positions_chunks, dim=0) if ctx_positions_chunks @@ -281,11 +271,7 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D attn.attn.v_scale, ) - draft_input.draft_seq_lens_cpu = ctx_end.to(torch.int64).cpu().tolist() - for req, seq_len in zip( - batch.reqs, draft_input.draft_seq_lens_cpu, strict=True - ): - req.dflash_draft_seq_len = int(seq_len) + draft_input.draft_seq_lens_cpu = new_draft_seq_lens_cpu draft_input.ctx_lens_cpu = [0] * bs draft_input.target_hidden = draft_input.target_hidden[:0] @@ -315,52 +301,55 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D block_start = prefix_lens.to(torch.int64) block_end = block_start + int(self.block_size) - block_cache_loc = self.native_draft_model_runner.token_to_kv_pool_allocator.alloc( - bs * self.block_size - ) - if block_cache_loc is None: - raise RuntimeError( - f"DFLASH native draft OOM when allocating {bs * self.block_size} block tokens." - ) - - assign_req_to_token_pool_func( - batch.req_pool_indices, - self.native_draft_model_runner.req_to_token_pool.req_to_token, - block_start, - block_end, - block_cache_loc, - bs, - ) + allocator = self.native_draft_model_runner.token_to_kv_pool_allocator + token_to_kv_pool_state_backup = allocator.backup_state() + try: + block_cache_loc = allocator.alloc(bs * self.block_size) + if block_cache_loc is None: + raise RuntimeError( + f"DFLASH native draft OOM when allocating {bs * self.block_size} block tokens." + ) - seq_lens = block_end.to(torch.int64) - forward_batch = ForwardBatch( - forward_mode=ForwardMode.EXTEND, - batch_size=bs, - input_ids=block_ids.flatten(), - req_pool_indices=batch.req_pool_indices, - seq_lens=seq_lens, - out_cache_loc=block_cache_loc, - seq_lens_sum=int(seq_lens.sum().item()), - seq_lens_cpu=torch.tensor(seq_lens.cpu().tolist(), dtype=torch.int64), - positions=positions, - extend_num_tokens=bs * self.block_size, - extend_seq_lens=extend_lens, - extend_prefix_lens=prefix_lens, - extend_start_loc=extend_start_loc, - extend_prefix_lens_cpu=prefix_lens_cpu, - extend_seq_lens_cpu=[int(self.block_size)] * bs, - extend_logprob_start_lens_cpu=[0] * bs, - req_to_token_pool=self.native_draft_model_runner.req_to_token_pool, - token_to_kv_pool=self.native_draft_model_runner.token_to_kv_pool, - attn_backend=self.native_draft_model_runner.attn_backend, - input_embeds=input_embeds, - ) + assign_req_to_token_pool_func( + batch.req_pool_indices, + self.native_draft_model_runner.req_to_token_pool.req_to_token, + block_start, + block_end, + block_cache_loc, + bs, + ) - with torch.inference_mode(): - draft_hidden = self.native_draft_model_runner.forward(forward_batch).logits_output + seq_lens = block_end.to(torch.int64) + forward_batch = ForwardBatch( + forward_mode=ForwardMode.EXTEND, + batch_size=bs, + input_ids=block_ids.flatten(), + req_pool_indices=batch.req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=block_cache_loc, + seq_lens_sum=int(seq_lens.sum().item()), + seq_lens_cpu=torch.tensor(seq_lens.cpu().tolist(), dtype=torch.int64), + positions=positions, + extend_num_tokens=bs * self.block_size, + extend_seq_lens=extend_lens, + extend_prefix_lens=prefix_lens, + extend_start_loc=extend_start_loc, + extend_prefix_lens_cpu=prefix_lens_cpu, + extend_seq_lens_cpu=[int(self.block_size)] * bs, + extend_logprob_start_lens_cpu=[0] * bs, + req_to_token_pool=self.native_draft_model_runner.req_to_token_pool, + token_to_kv_pool=self.native_draft_model_runner.token_to_kv_pool, + attn_backend=self.native_draft_model_runner.attn_backend, + input_embeds=input_embeds, + ) - # Crop: drop the speculative block from the draft KV cache (context stays). - self.native_draft_model_runner.token_to_kv_pool_allocator.free(block_cache_loc) + with torch.inference_mode(): + draft_hidden = self.native_draft_model_runner.forward( + forward_batch + ).logits_output + finally: + # Drop the speculative block from the shared allocator (EAGLE3-style). + allocator.restore_state(token_to_kv_pool_state_backup) draft_hidden = draft_hidden.view(bs, self.block_size, -1) draft_logits = F.linear(draft_hidden[:, 1:, :], head_weight) From 6a38e63a8179503479dd9811a13a9427b820ada3 Mon Sep 17 00:00:00 2001 From: David Wang Date: Thu, 8 Jan 2026 08:10:28 +0000 Subject: [PATCH 09/73] add cuda graph --- .../srt/model_executor/cuda_graph_runner.py | 22 +++++++++++++++ .../sglang/srt/model_executor/model_runner.py | 4 +++ python/sglang/srt/speculative/dflash_info.py | 28 ++++++++++++++++++- 3 files changed, 53 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 6be1b11c68a8..234de4224d39 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -279,6 +279,7 @@ def __init__(self, model_runner: ModelRunner): model_runner.spec_algorithm.is_eagle() or model_runner.spec_algorithm.is_standalone() or model_runner.spec_algorithm.is_ngram() + or model_runner.spec_algorithm.is_dflash() ): if self.model_runner.is_draft_worker: raise RuntimeError("This should not happen") @@ -354,6 +355,15 @@ def __init__(self, model_runner: ModelRunner): and model_runner.eagle_use_aux_hidden_state ): self.model_runner.model.set_eagle3_layers_to_capture() + if model_runner.spec_algorithm.is_dflash() and model_runner.dflash_use_aux_hidden_state: + if not hasattr(self.model_runner.model, "set_dflash_layers_to_capture"): + raise ValueError( + f"Model {self.model_runner.model.__class__.__name__} does not implement set_dflash_layers_to_capture, " + "which is required for DFLASH aux hidden capture." + ) + self.model_runner.model.set_dflash_layers_to_capture( + self.model_runner.dflash_aux_hidden_state_layer_ids + ) # Capture try: @@ -904,6 +914,18 @@ def get_spec_info(self, num_tokens: int): seq_lens_sum=None, seq_lens_cpu=None, ) + elif self.model_runner.spec_algorithm.is_dflash(): + from sglang.srt.speculative.dflash_info import DFlashVerifyInput + + if self.model_runner.is_draft_worker: + raise RuntimeError("This should not happen.") + spec_info = DFlashVerifyInput( + draft_token=None, + positions=None, + draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, + custom_mask=self.buffers.custom_mask, + capture_hidden_mode=CaptureHiddenMode.FULL, + ) elif self.model_runner.spec_algorithm.is_ngram(): from sglang.srt.speculative.ngram_info import NgramVerifyInput diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index d50668ff950e..7d8999e28368 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1747,6 +1747,7 @@ def _should_run_flashinfer_autotune(self) -> bool: self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone() or self.spec_algorithm.is_ngram() + or self.spec_algorithm.is_dflash() ): return not self.is_draft_worker @@ -1775,6 +1776,7 @@ def _dummy_run(self, batch_size: int): self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone() or self.spec_algorithm.is_ngram() + or self.spec_algorithm.is_dflash() ): if self.is_draft_worker: raise RuntimeError("This should not happen") @@ -1794,6 +1796,8 @@ def _dummy_run(self, batch_size: int): if self.eagle_use_aux_hidden_state: self.model.set_eagle3_layers_to_capture() + if self.dflash_use_aux_hidden_state: + self.model.set_dflash_layers_to_capture(self.dflash_aux_hidden_state_layer_ids) require_mlp_tp_gather_ = require_mlp_tp_gather(self.server_args) if require_gathered_buffer(self.server_args): diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py index e7804356af54..c1144be52900 100644 --- a/python/sglang/srt/speculative/dflash_info.py +++ b/python/sglang/srt/speculative/dflash_info.py @@ -107,8 +107,13 @@ class DFlashVerifyInput(SpecInput): custom_mask: torch.Tensor | None = None capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL + # Shape info for padding (e.g., DP attention / CUDA graph). + num_tokens_per_batch: int = -1 + def __post_init__(self): super().__init__(spec_input_type=SpecInputType.DFLASH_VERIFY) + if self.num_tokens_per_batch == -1: + self.num_tokens_per_batch = int(self.draft_token_num) def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]: return self.draft_token_num, self.draft_token_num @@ -207,7 +212,28 @@ def generate_attn_arg_prefill( kv_indices, req_to_token.size(1), ) - return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask + mask = self.custom_mask + if mask is not None: + mask_numel = ( + paged_kernel_lens_sum * self.draft_token_num + + (self.draft_token_num**2) * bs + ) + if mask.numel() < mask_numel: + # FIXME(attn): temporary fix for custom mask padding with cuda graph + mask = torch.cat( + [ + mask, + torch.full( + (mask_numel - mask.numel(),), + True, + dtype=torch.bool, + device=device, + ), + ], + dim=0, + ) + self.custom_mask = mask + return kv_indices, cum_kv_seq_len, qo_indptr, mask def verify( self, From 40a81aff6c14677f889d3dbc8bffa0b976df2ada Mon Sep 17 00:00:00 2001 From: David Wang Date: Thu, 8 Jan 2026 08:43:16 +0000 Subject: [PATCH 10/73] add cuda graph to draft worker --- .../srt/model_executor/cuda_graph_runner.py | 49 +++++++++++++++---- .../sglang/srt/speculative/dflash_worker.py | 27 +++++----- 2 files changed, 55 insertions(+), 21 deletions(-) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 234de4224d39..1515f441066f 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -282,12 +282,13 @@ def __init__(self, model_runner: ModelRunner): or model_runner.spec_algorithm.is_dflash() ): if self.model_runner.is_draft_worker: - raise RuntimeError("This should not happen") - else: - self.capture_forward_mode = ForwardMode.TARGET_VERIFY - self.num_tokens_per_bs = ( - self.model_runner.server_args.speculative_num_draft_tokens - ) + # EAGLE/standalone/ngram draft workers use separate cuda-graph runners; do not + # capture TARGET_VERIFY graphs here. DFLASH draft uses a fixed-size block and + # reuses TARGET_VERIFY graphs for performance. + if not self.model_runner.spec_algorithm.is_dflash(): + raise RuntimeError("This should not happen") + self.capture_forward_mode = ForwardMode.TARGET_VERIFY + self.num_tokens_per_bs = self.model_runner.server_args.speculative_num_draft_tokens elif self.is_dllm: self.capture_forward_mode = ForwardMode.DLLM_EXTEND self.num_tokens_per_bs = self.dllm_config.block_size @@ -716,6 +717,12 @@ def run_once(): kwargs["pp_proxy_tensors"] = PPProxyTensors( {k: v.clone() for k, v in pp_proxy_tensors.tensors.items()} ) + if ( + self.model_runner.spec_algorithm.is_dflash() + and self.model_runner.is_draft_worker + and "input_embeds" in inspect.signature(forward).parameters + ): + kwargs["input_embeds"] = buffers.input_embeds[:num_tokens] logits_output_or_pp_proxy_tensors = forward( input_ids, @@ -813,6 +820,14 @@ def replay_prepare( ), pp_proxy_tensors=pp_proxy_tensors, ) + if ( + self.model_runner.spec_algorithm.is_dflash() + and self.model_runner.is_draft_worker + and forward_batch.input_embeds is not None + ): + buffers.input_embeds[:raw_num_token].copy_(forward_batch.input_embeds) + if bs != raw_bs: + buffers.input_embeds[raw_num_token : bs * self.num_tokens_per_bs].zero_() if self.enable_two_batch_overlap: self.tbo_plugin.replay_prepare( forward_mode=self.capture_forward_mode, @@ -858,6 +873,14 @@ def replay( # In speculative decoding, these two fields are still needed. self.buffers.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids) self.buffers.positions[: self.raw_num_token].copy_(forward_batch.positions) + if ( + self.model_runner.spec_algorithm.is_dflash() + and self.model_runner.is_draft_worker + and forward_batch.input_embeds is not None + ): + self.buffers.input_embeds[: self.raw_num_token].copy_( + forward_batch.input_embeds + ) # Replay if self.enable_pdmux: @@ -867,6 +890,8 @@ def replay( self.graphs[graph_key].replay() output = self.output_buffers[graph_key] + if isinstance(output, torch.Tensor): + return output[: self.raw_num_token] if isinstance(output, LogitsProcessorOutput): if self.is_dllm: next_token_logits = None @@ -917,14 +942,18 @@ def get_spec_info(self, num_tokens: int): elif self.model_runner.spec_algorithm.is_dflash(): from sglang.srt.speculative.dflash_info import DFlashVerifyInput - if self.model_runner.is_draft_worker: - raise RuntimeError("This should not happen.") spec_info = DFlashVerifyInput( draft_token=None, positions=None, draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, - custom_mask=self.buffers.custom_mask, - capture_hidden_mode=CaptureHiddenMode.FULL, + custom_mask=( + None if self.model_runner.is_draft_worker else self.buffers.custom_mask + ), + capture_hidden_mode=( + CaptureHiddenMode.NULL + if self.model_runner.is_draft_worker + else CaptureHiddenMode.FULL + ), ) elif self.model_runner.spec_algorithm.is_ngram(): diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 4f8caf720b22..dd3ac36436c7 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -59,7 +59,6 @@ def __init__( ) draft_server_args = deepcopy(server_args) draft_server_args.skip_tokenizer_init = True - draft_server_args.disable_cuda_graph = True draft_backend = draft_server_args.speculative_draft_attention_backend if draft_backend is None: draft_backend, _ = draft_server_args.get_attention_backends() @@ -319,28 +318,34 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D bs, ) - seq_lens = block_end.to(torch.int64) + # Use TARGET_VERIFY mode (cuda-graphable) to run a fixed-size draft block. + # In this mode, `seq_lens` stores the prefix lengths; attention backends + # derive kv_len by adding `draft_token_num`. + draft_spec_info = DFlashVerifyInput( + draft_token=torch.empty((0,), dtype=torch.long, device=device), + positions=torch.empty((0,), dtype=torch.int64, device=device), + draft_token_num=int(self.block_size), + custom_mask=None, + capture_hidden_mode=CaptureHiddenMode.NULL, + ) + seq_lens = prefix_lens.to(torch.int32) forward_batch = ForwardBatch( - forward_mode=ForwardMode.EXTEND, + forward_mode=ForwardMode.TARGET_VERIFY, batch_size=bs, input_ids=block_ids.flatten(), req_pool_indices=batch.req_pool_indices, seq_lens=seq_lens, out_cache_loc=block_cache_loc, seq_lens_sum=int(seq_lens.sum().item()), - seq_lens_cpu=torch.tensor(seq_lens.cpu().tolist(), dtype=torch.int64), + seq_lens_cpu=torch.tensor(prefix_lens_cpu, dtype=torch.int32), positions=positions, - extend_num_tokens=bs * self.block_size, - extend_seq_lens=extend_lens, - extend_prefix_lens=prefix_lens, - extend_start_loc=extend_start_loc, - extend_prefix_lens_cpu=prefix_lens_cpu, - extend_seq_lens_cpu=[int(self.block_size)] * bs, - extend_logprob_start_lens_cpu=[0] * bs, req_to_token_pool=self.native_draft_model_runner.req_to_token_pool, token_to_kv_pool=self.native_draft_model_runner.token_to_kv_pool, attn_backend=self.native_draft_model_runner.attn_backend, input_embeds=input_embeds, + spec_algorithm=SpeculativeAlgorithm.DFLASH, + spec_info=draft_spec_info, + capture_hidden_mode=CaptureHiddenMode.NULL, ) with torch.inference_mode(): From 510bf0ccf610a9d1da5aaeb8c0fd5e8ff29e84fa Mon Sep 17 00:00:00 2001 From: David Wang Date: Thu, 8 Jan 2026 18:18:02 +0000 Subject: [PATCH 11/73] update test --- test/manual/models/test_qwen3_dflash_gsm8k_bench.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/test/manual/models/test_qwen3_dflash_gsm8k_bench.py b/test/manual/models/test_qwen3_dflash_gsm8k_bench.py index fc51bd64135e..6cee6f173c0c 100644 --- a/test/manual/models/test_qwen3_dflash_gsm8k_bench.py +++ b/test/manual/models/test_qwen3_dflash_gsm8k_bench.py @@ -132,11 +132,6 @@ def test_qwen3_dflash_gsm8k_speedup_and_acceptance(self): draft_model_path = os.getenv( "SGLANG_DFLASH_DRAFT_MODEL_PATH", "/tmp/Qwen3-8B-DFlash-bf16" ) - if not os.path.isdir(draft_model_path): - self.skipTest( - f"Draft model folder not found: {draft_model_path}. " - "Set SGLANG_DFLASH_DRAFT_MODEL_PATH to run this benchmark." - ) attention_backend = os.getenv("SGLANG_DFLASH_ATTENTION_BACKEND", "flashinfer") max_new_tokens = int(os.getenv("SGLANG_DFLASH_MAX_NEW_TOKENS", "2048")) @@ -341,11 +336,6 @@ def test_qwen3_dflash_native_matches_hf(self): draft_model_path = os.getenv( "SGLANG_DFLASH_DRAFT_MODEL_PATH", "/tmp/Qwen3-8B-DFlash-bf16" ) - if not os.path.isdir(draft_model_path): - self.skipTest( - f"Draft model folder not found: {draft_model_path}. " - "Set SGLANG_DFLASH_DRAFT_MODEL_PATH to run this benchmark." - ) attention_backend = os.getenv("SGLANG_DFLASH_ATTENTION_BACKEND", "flashinfer") # Keep env var names for backwards compatibility with the previous parity test. From c54f336e4f833670efbc921605d0134201324538 Mon Sep 17 00:00:00 2001 From: David Wang Date: Fri, 9 Jan 2026 01:21:17 +0000 Subject: [PATCH 12/73] fix flashinfer backend --- .../srt/layers/attention/flashinfer_backend.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 19a1e7816cf1..f3e82ced943f 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -583,8 +583,23 @@ def init_forward_metadata_capture_cuda_graph( fast_decode_plan, decode_wrappers[i] ) elif forward_mode.is_target_verify(): + # FlashInfer's prefill wrapper decides mask mode based on whether + # `custom_mask_buf` is initialized (not whether a custom mask is provided). + # For cases like DFLASH draft (ENCODER_ONLY / non-causal) we do NOT use a + # custom mask, so we must avoid initializing `custom_mask_buf`, otherwise + # FlashInfer will treat the (zero) buffer as a real mask and block attention. + use_custom_mask = ( + spec_info is not None and getattr(spec_info, "custom_mask", None) is not None + ) prefill_wrappers = [] for i in range(self.num_wrappers): + wrapper_kwargs = {} + if use_custom_mask: + wrapper_kwargs = { + "custom_mask_buf": self.cuda_graph_custom_mask, + "mask_indptr_buf": self.cuda_graph_qk_indptr[i][: bs + 1], + } + prefill_wrappers.append( BatchPrefillWithPagedKVCacheWrapper( self.workspace_buffer, @@ -594,8 +609,7 @@ def init_forward_metadata_capture_cuda_graph( paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1], paged_kv_indices_buf=self.cuda_graph_kv_indices[i], paged_kv_last_page_len_buf=self.kv_last_page_len[:bs], - custom_mask_buf=self.cuda_graph_custom_mask, - mask_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1], + **wrapper_kwargs, ) ) seq_lens_sum = seq_lens.sum().item() From 8c8ee9c9da52946e63d07b6898653481952e7ba0 Mon Sep 17 00:00:00 2001 From: David Wang Date: Fri, 9 Jan 2026 02:15:25 +0000 Subject: [PATCH 13/73] initial radix cache support --- .../sglang/srt/speculative/dflash_worker.py | 228 ++++++++++-------- 1 file changed, 123 insertions(+), 105 deletions(-) diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index dd3ac36436c7..3c07fdf5ca5e 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -179,101 +179,11 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D bs = batch.batch_size() device = self.model_runner.device - if draft_input.target_hidden is None: - raise RuntimeError("DFLASH draft state missing target_hidden context features.") - if len(draft_input.ctx_lens_cpu) != bs: - raise RuntimeError( - f"DFLASH ctx_lens_cpu length mismatch: got {len(draft_input.ctx_lens_cpu)} for bs={bs}." - ) - if len(draft_input.draft_seq_lens_cpu) != bs: - raise RuntimeError( - "DFLASH draft_seq_lens_cpu length mismatch: " - f"got {len(draft_input.draft_seq_lens_cpu)} for bs={bs}." - ) + # --- 1) Append any newly committed tokens into the native draft KV cache. + self._append_target_hidden_to_native_draft_kv(batch, draft_input) embed_weight, head_weight = self.target_worker.model_runner.model.get_embed_and_head() - # --- 1) Append new context tokens into the native draft KV cache. - start_pos_cpu = batch.seq_lens_cpu.tolist() - for cache_len, ctx_len, start_pos in zip( - draft_input.draft_seq_lens_cpu, - draft_input.ctx_lens_cpu, - start_pos_cpu, - strict=True, - ): - if int(cache_len) + int(ctx_len) != int(start_pos): - raise RuntimeError( - "DFLASH native draft cache length mismatch. " - f"cache_len={int(cache_len)}, ctx_len={int(ctx_len)}, start_pos={int(start_pos)}. " - "This can happen if prefix caching is enabled; start with `--disable-radix-cache` for now." - ) - - total_ctx = int(sum(int(x) for x in draft_input.ctx_lens_cpu)) - if total_ctx > 0: - req_to_token = self.native_draft_model_runner.req_to_token_pool.req_to_token - req_pool_indices_cpu = batch.req_pool_indices.tolist() - - ctx_cache_loc_chunks: List[torch.Tensor] = [] - ctx_positions_chunks: List[torch.Tensor] = [] - new_draft_seq_lens_cpu: List[int] = [] - for req_pool_idx, cache_len, ctx_len in zip( - req_pool_indices_cpu, - draft_input.draft_seq_lens_cpu, - draft_input.ctx_lens_cpu, - strict=True, - ): - cache_len_i = int(cache_len) - ctx_len_i = int(ctx_len) - new_draft_seq_lens_cpu.append(cache_len_i + ctx_len_i) - if ctx_len_i <= 0: - continue - s = cache_len_i - e = cache_len_i + ctx_len_i - ctx_cache_loc_chunks.append( - req_to_token[req_pool_idx, s:e].to(torch.int64) - ) - ctx_positions_chunks.append( - torch.arange(s, e, device=device, dtype=torch.int64) - ) - ctx_cache_loc = ( - torch.cat(ctx_cache_loc_chunks, dim=0) - if ctx_cache_loc_chunks - else torch.empty((0,), dtype=torch.int64, device=device) - ) - - ctx_positions = ( - torch.cat(ctx_positions_chunks, dim=0) - if ctx_positions_chunks - else torch.empty((0,), dtype=torch.int64, device=device) - ) - - with torch.inference_mode(): - ctx_hidden = self.native_draft_model.project_target_hidden( - draft_input.target_hidden - ) # [sum(ctx), hidden] - - for layer in self.native_draft_model.layers: - attn = layer.self_attn - k = attn.k_proj(ctx_hidden) - v = attn.v_proj(ctx_hidden) - k = attn.k_norm(k.view(-1, attn.head_dim)).view_as(k) - dummy_q = k.new_empty((k.shape[0], attn.head_dim)) - _, k = attn.rotary_emb(ctx_positions, dummy_q, k) - k = k.view(-1, attn.num_kv_heads, attn.head_dim) - v = v.view(-1, attn.num_kv_heads, attn.head_dim) - self.native_draft_model_runner.token_to_kv_pool.set_kv_buffer( - attn.attn, - ctx_cache_loc, - k, - v, - attn.attn.k_scale, - attn.attn.v_scale, - ) - - draft_input.draft_seq_lens_cpu = new_draft_seq_lens_cpu - draft_input.ctx_lens_cpu = [0] * bs - draft_input.target_hidden = draft_input.target_hidden[:0] - # --- 2) Draft a non-causal block with the native draft model. block_ids = torch.full( (bs, self.block_size), @@ -376,6 +286,111 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D batch.spec_info = verify_input batch.return_hidden_states = False + def _append_target_hidden_to_native_draft_kv( + self, + batch: ScheduleBatch, + draft_input: DFlashDraftInput, + ) -> None: + """Materialize the target hidden-state features into the native draft KV cache. + + This must be run before exposing new tokens to radix cache (prefix hits), otherwise + another request could reuse target KV indices without having draft KV values. + """ + + bs = batch.batch_size() + device = self.model_runner.device + + if draft_input.target_hidden is None: + raise RuntimeError("DFLASH draft state missing target_hidden context features.") + if len(draft_input.ctx_lens_cpu) != bs: + raise RuntimeError( + f"DFLASH ctx_lens_cpu length mismatch: got {len(draft_input.ctx_lens_cpu)} for bs={bs}." + ) + if len(draft_input.draft_seq_lens_cpu) != bs: + raise RuntimeError( + "DFLASH draft_seq_lens_cpu length mismatch: " + f"got {len(draft_input.draft_seq_lens_cpu)} for bs={bs}." + ) + + # Invariant: draft_seq_len + ctx_len == current target prefix length. + start_pos_cpu = batch.seq_lens_cpu.tolist() + for cache_len, ctx_len, start_pos in zip( + draft_input.draft_seq_lens_cpu, + draft_input.ctx_lens_cpu, + start_pos_cpu, + strict=True, + ): + if int(cache_len) + int(ctx_len) != int(start_pos): + raise RuntimeError( + "DFLASH native draft cache length mismatch. " + f"cache_len={int(cache_len)}, ctx_len={int(ctx_len)}, start_pos={int(start_pos)}." + ) + + total_ctx = int(sum(int(x) for x in draft_input.ctx_lens_cpu)) + if total_ctx <= 0: + return + + req_to_token = self.native_draft_model_runner.req_to_token_pool.req_to_token + req_pool_indices_cpu = batch.req_pool_indices.tolist() + + ctx_cache_loc_chunks: List[torch.Tensor] = [] + ctx_positions_chunks: List[torch.Tensor] = [] + new_draft_seq_lens_cpu: List[int] = [] + for req_pool_idx, cache_len, ctx_len in zip( + req_pool_indices_cpu, + draft_input.draft_seq_lens_cpu, + draft_input.ctx_lens_cpu, + strict=True, + ): + cache_len_i = int(cache_len) + ctx_len_i = int(ctx_len) + new_draft_seq_lens_cpu.append(cache_len_i + ctx_len_i) + if ctx_len_i <= 0: + continue + s = cache_len_i + e = cache_len_i + ctx_len_i + ctx_cache_loc_chunks.append(req_to_token[req_pool_idx, s:e].to(torch.int64)) + ctx_positions_chunks.append(torch.arange(s, e, device=device, dtype=torch.int64)) + + ctx_cache_loc = ( + torch.cat(ctx_cache_loc_chunks, dim=0) + if ctx_cache_loc_chunks + else torch.empty((0,), dtype=torch.int64, device=device) + ) + + ctx_positions = ( + torch.cat(ctx_positions_chunks, dim=0) + if ctx_positions_chunks + else torch.empty((0,), dtype=torch.int64, device=device) + ) + + with torch.inference_mode(): + ctx_hidden = self.native_draft_model.project_target_hidden( + draft_input.target_hidden + ) # [sum(ctx), hidden] + + for layer in self.native_draft_model.layers: + attn = layer.self_attn + k = attn.k_proj(ctx_hidden) + v = attn.v_proj(ctx_hidden) + k = attn.k_norm(k.view(-1, attn.head_dim)).view_as(k) + dummy_q = k.new_empty((k.shape[0], attn.head_dim)) + _, k = attn.rotary_emb(ctx_positions, dummy_q, k) + k = k.view(-1, attn.num_kv_heads, attn.head_dim) + v = v.view(-1, attn.num_kv_heads, attn.head_dim) + self.native_draft_model_runner.token_to_kv_pool.set_kv_buffer( + attn.attn, + ctx_cache_loc, + k, + v, + attn.attn.k_scale, + attn.attn.v_scale, + ) + + draft_input.draft_seq_lens_cpu = new_draft_seq_lens_cpu + draft_input.ctx_lens_cpu = [0] * bs + draft_input.target_hidden = draft_input.target_hidden[:0] + def forward_batch_generation( self, batch: Union[ScheduleBatch, ModelWorkerBatch], @@ -389,12 +404,6 @@ def forward_batch_generation( return self.target_worker.forward_batch_generation(batch, **kwargs) if batch.forward_mode.is_extend() or batch.is_extend_in_batch: - if any(len(req.prefix_indices) > 0 for req in batch.reqs): - raise ValueError( - "DFLASH currently does not support radix/prefix cache hits (prefix_indices != 0). " - "Start with `--disable-radix-cache` for now." - ) - model_worker_batch = batch.get_model_worker_batch() model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL @@ -411,16 +420,23 @@ def forward_batch_generation( "Make sure the target model has DFlash layers-to-capture configured." ) - ctx_lens_cpu = model_worker_batch.seq_lens_cpu.tolist() + if model_worker_batch.extend_seq_lens is None or model_worker_batch.extend_prefix_lens is None: + raise RuntimeError( + "DFLASH expected extend_seq_lens / extend_prefix_lens to be populated in extend mode, but got None." + ) - batch.spec_info = DFlashDraftInput( + # Materialize the prompt tokens into the draft KV cache immediately. This is required + # for radix cache support, since the scheduler may update radix after prefill returns. + draft_input = DFlashDraftInput( verified_id=next_token_ids.to(torch.int64), target_hidden=logits_output.hidden_states, - ctx_lens_cpu=ctx_lens_cpu, - draft_seq_lens_cpu=[0] * len(ctx_lens_cpu), + ctx_lens_cpu=[int(x) for x in model_worker_batch.extend_seq_lens], + draft_seq_lens_cpu=[int(x) for x in model_worker_batch.extend_prefix_lens], ) - for req in batch.reqs: - req.dflash_draft_seq_len = 0 + self._append_target_hidden_to_native_draft_kv(batch, draft_input) + batch.spec_info = draft_input + for req, draft_len in zip(batch.reqs, draft_input.draft_seq_lens_cpu, strict=True): + req.dflash_draft_seq_len = int(draft_len) return GenerationBatchResult( logits_output=logits_output, @@ -463,10 +479,12 @@ def forward_batch_generation( page_size=self.page_size, ) - # Update draft state for the next iteration. + # Update draft state for the next iteration. Also materialize the committed verify tokens + # into the draft KV cache immediately so radix cache entries are safe to reuse. draft_input.verified_id = new_verified_id draft_input.target_hidden = next_target_hidden draft_input.ctx_lens_cpu = commit_lens.cpu().tolist() + self._append_target_hidden_to_native_draft_kv(batch, draft_input) batch.spec_info = draft_input batch.forward_mode = ForwardMode.DECODE From 0edea3f163ee5241f57553deb5989895e227a474 Mon Sep 17 00:00:00 2001 From: David Wang Date: Fri, 9 Jan 2026 03:42:21 +0000 Subject: [PATCH 14/73] tp_size > 1 support --- python/sglang/srt/models/dflash.py | 88 ++++++++--- python/sglang/srt/server_args.py | 4 +- .../sglang/srt/speculative/dflash_worker.py | 147 ++++++++++++++++-- .../models/test_qwen3_dflash_gsm8k_bench.py | 23 ++- 4 files changed, 230 insertions(+), 32 deletions(-) diff --git a/python/sglang/srt/models/dflash.py b/python/sglang/srt/models/dflash.py index 68286140e613..3592bdb1f42a 100644 --- a/python/sglang/srt/models/dflash.py +++ b/python/sglang/srt/models/dflash.py @@ -12,6 +12,8 @@ import torch.nn.functional as F from torch import nn +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.radix_attention import AttentionType, RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -26,24 +28,51 @@ class DFlashAttention(nn.Module): def __init__(self, config, layer_id: int) -> None: super().__init__() hidden_size = int(config.hidden_size) - num_heads = int(config.num_attention_heads) - num_kv_heads = int(getattr(config, "num_key_value_heads", num_heads)) - head_dim = int(getattr(config, "head_dim", hidden_size // num_heads)) + tp_size = int(get_tensor_model_parallel_world_size()) + total_num_heads = int(config.num_attention_heads) + total_num_kv_heads = int(getattr(config, "num_key_value_heads", total_num_heads)) + head_dim = int(getattr(config, "head_dim", hidden_size // total_num_heads)) self.hidden_size = hidden_size - self.num_heads = num_heads - self.num_kv_heads = num_kv_heads + self.total_num_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + assert self.total_num_heads % tp_size == 0, ( + f"DFlashAttention requires total_num_heads divisible by tp_size. " + f"total_num_heads={self.total_num_heads}, tp_size={tp_size}." + ) + self.num_heads = self.total_num_heads // tp_size + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0, ( + f"DFlashAttention requires total_num_kv_heads divisible by tp_size when >= tp_size. " + f"total_num_kv_heads={self.total_num_kv_heads}, tp_size={tp_size}." + ) + else: + assert tp_size % self.total_num_kv_heads == 0, ( + f"DFlashAttention requires tp_size divisible by total_num_kv_heads when total_num_kv_heads < tp_size. " + f"total_num_kv_heads={self.total_num_kv_heads}, tp_size={tp_size}." + ) + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.head_dim = head_dim - self.q_size = num_heads * head_dim - self.kv_size = num_kv_heads * head_dim + self.q_size = self.num_heads * head_dim + self.kv_size = self.num_kv_heads * head_dim attention_bias = bool(getattr(config, "attention_bias", False)) rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-6)) - self.q_proj = nn.Linear(hidden_size, self.q_size, bias=attention_bias) - self.k_proj = nn.Linear(hidden_size, self.kv_size, bias=attention_bias) - self.v_proj = nn.Linear(hidden_size, self.kv_size, bias=attention_bias) - self.o_proj = nn.Linear(self.q_size, hidden_size, bias=attention_bias) + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=attention_bias, + prefix="qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * head_dim, + hidden_size, + bias=attention_bias, + prefix="o_proj", + ) # Per-head Q/K RMSNorm, matching HF Qwen3. self.q_norm = RMSNorm(head_dim, eps=rms_norm_eps) @@ -63,10 +92,10 @@ def __init__(self, config, layer_id: int) -> None: self.scaling = head_dim**-0.5 # DFlash uses non-causal attention over the draft block. self.attn = RadixAttention( - num_heads=num_heads, + num_heads=self.num_heads, head_dim=head_dim, scaling=self.scaling, - num_kv_heads=num_kv_heads, + num_kv_heads=self.num_kv_heads, layer_id=layer_id, attn_type=AttentionType.ENCODER_ONLY, ) @@ -77,9 +106,8 @@ def forward( hidden_states: torch.Tensor, forward_batch: ForwardBatch, ) -> torch.Tensor: - q = self.q_proj(hidden_states) - k = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = apply_qk_norm( q=q, @@ -91,7 +119,8 @@ def forward( q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, forward_batch) - return self.o_proj(attn_output) + output, _ = self.o_proj(attn_output) + return output class DFlashMLP(nn.Module): @@ -200,11 +229,33 @@ def forward( return self.norm(hidden_states) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, weight_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if name.endswith(".bias") and name not in params_dict: # Some quantized checkpoints may have extra biases. - continue + # (May still be mappable to a fused/parallel param.) + pass + + for param_name, weight_name, shard_id in stacked_params_mapping: + if f".{weight_name}." not in name: + continue + mapped_name = name.replace(weight_name, param_name) + param = params_dict.get(mapped_name) + if param is None: + continue + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight, shard_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict.get(name) if param is None: # Ignore unexpected weights (e.g., HF rotary caches). @@ -214,4 +265,3 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): EntryClass = DFlashDraftModel - diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 67ad21690d0d..974f146d24e4 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2043,9 +2043,9 @@ def _handle_speculative_decoding(self): "Currently DFLASH speculative decoding does not support dp attention." ) - if self.tp_size != 1 or self.pp_size != 1: + if self.pp_size != 1: raise ValueError( - "Currently DFLASH speculative decoding only supports tp_size == 1 and pp_size == 1." + "Currently DFLASH speculative decoding only supports pp_size == 1." ) if self.speculative_draft_model_path is None: diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 3c07fdf5ca5e..8d5e786ac09a 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -5,9 +5,11 @@ import torch import torch.nn.functional as F +from sglang.srt.distributed import get_tp_group from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch from sglang.srt.managers.scheduler import GenerationBatchResult from sglang.srt.managers.tp_worker import TpModelWorker +from sglang.srt.models.utils import apply_qk_norm from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, ForwardBatch, @@ -23,7 +25,7 @@ class DFlashWorker: - """DFlash speculative decoding worker (spec-v1, tp=1/pp=1).""" + """DFlash speculative decoding worker (spec-v1, tp>=1/pp=1).""" def __init__( self, @@ -182,7 +184,14 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D # --- 1) Append any newly committed tokens into the native draft KV cache. self._append_target_hidden_to_native_draft_kv(batch, draft_input) - embed_weight, head_weight = self.target_worker.model_runner.model.get_embed_and_head() + target_model = self.target_worker.model_runner.model + embed_module = target_model.get_input_embeddings() + lm_head = getattr(target_model, "lm_head", None) + if lm_head is None or not hasattr(lm_head, "weight") or not hasattr(lm_head, "shard_indices"): + raise RuntimeError( + "DFLASH requires the target model to expose a vocab-parallel `lm_head` with `weight` and " + "`shard_indices` attributes." + ) # --- 2) Draft a non-causal block with the native draft model. block_ids = torch.full( @@ -193,7 +202,7 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D ) block_ids[:, 0] = draft_input.verified_id.to(torch.long) - noise_embedding = F.embedding(block_ids, embed_weight) + noise_embedding = embed_module(block_ids) input_embeds = noise_embedding.view(-1, noise_embedding.shape[-1]) prefix_lens_cpu = [int(x) for x in draft_input.draft_seq_lens_cpu] @@ -267,8 +276,10 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D allocator.restore_state(token_to_kv_pool_state_backup) draft_hidden = draft_hidden.view(bs, self.block_size, -1) - draft_logits = F.linear(draft_hidden[:, 1:, :], head_weight) - draft_next = torch.argmax(draft_logits, dim=-1).to(torch.long) + draft_next = self._greedy_sample_from_vocab_parallel_head( + hidden_states=draft_hidden[:, 1:, :].reshape(-1, draft_hidden.shape[-1]), + lm_head=lm_head, + ).view(bs, self.block_size - 1) draft_tokens = torch.cat([block_ids[:, :1], draft_next], dim=1) # [bs, block_size] positions = ( batch.seq_lens.to(torch.long).unsqueeze(1) @@ -286,6 +297,116 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D batch.spec_info = verify_input batch.return_hidden_states = False + def _greedy_sample_from_vocab_parallel_head( + self, + *, + hidden_states: torch.Tensor, + lm_head, + chunk_size: int = 256, + ) -> torch.Tensor: + """Greedy argmax over the target LM head in a TP-safe way. + + We cannot materialize full logits for large vocabularies efficiently, and with + TP>1 each rank only owns a shard of the LM head weight. This computes the + per-rank max, gathers candidates across TP ranks, and selects the global max. + """ + + if hidden_states.numel() == 0: + return torch.empty((0,), dtype=torch.long, device=hidden_states.device) + + tp_group = get_tp_group() + tp_size = int(tp_group.world_size) + + if not hasattr(lm_head, "weight") or not hasattr(lm_head, "shard_indices"): + raise RuntimeError( + "DFLASH greedy sampling requires a vocab-parallel head with `weight` and `shard_indices`." + ) + + shard = lm_head.shard_indices + weight = lm_head.weight # [local_vocab_padded, hidden] + weight_dtype = weight.dtype + + # Valid ranges in the local shard (excluding padding): + # base vocab: [0, num_org) + # added vocab: [num_org_padded, num_org_padded + num_added) + num_org = int(shard.num_org_elements) + num_org_padded = int(shard.num_org_elements_padded) + num_added = int(shard.num_added_elements) + org_vocab_start = int(shard.org_vocab_start_index) + added_vocab_start = int(shard.added_vocab_start_index) + + num_tokens = int(hidden_states.shape[0]) + out_token_ids = torch.empty( + (num_tokens,), dtype=torch.long, device=hidden_states.device + ) + + for start in range(0, num_tokens, int(chunk_size)): + end = min(num_tokens, start + int(chunk_size)) + hs = hidden_states[start:end].to(weight_dtype) + chunk_len = int(hs.shape[0]) + + # Base vocab logits. + if num_org > 0: + base_logits = torch.matmul(hs, weight[:num_org].T) + local_max, local_arg = torch.max(base_logits, dim=-1) + else: + local_max = torch.full( + (chunk_len,), + torch.finfo(weight_dtype).min, + dtype=weight_dtype, + device=hs.device, + ) + local_arg = torch.zeros( + (chunk_len,), dtype=torch.int64, device=hs.device + ) + + # Added vocab logits (e.g., LoRA-added embeddings), if present. + if num_added > 0: + added_slice_start = num_org_padded + added_slice_end = num_org_padded + num_added + added_logits = torch.matmul(hs, weight[added_slice_start:added_slice_end].T) + added_max, added_arg = torch.max(added_logits, dim=-1) + use_added = added_max > local_max + local_max = torch.where(use_added, added_max, local_max) + local_arg = torch.where( + use_added, added_arg.to(local_arg.dtype) + num_org_padded, local_arg + ) + + # Convert local argmax indices to global token ids. + global_ids = torch.empty( + (chunk_len,), dtype=torch.int64, device=hs.device + ) + is_base = local_arg < num_org + global_ids[is_base] = org_vocab_start + local_arg[is_base] + if num_added > 0: + global_ids[~is_base] = added_vocab_start + (local_arg[~is_base] - num_org_padded) + + if tp_size == 1: + out_token_ids[start:end] = global_ids.to(torch.long) + continue + + # Gather per-rank maxima and associated global ids, then select the global max. + gathered_max = torch.empty( + (tp_size * chunk_len,), + dtype=local_max.dtype, + device=hs.device, + ) + gathered_ids = torch.empty( + (tp_size * chunk_len,), + dtype=global_ids.dtype, + device=hs.device, + ) + tp_group.all_gather_into_tensor(gathered_max, local_max.contiguous()) + tp_group.all_gather_into_tensor(gathered_ids, global_ids.contiguous()) + gathered_max = gathered_max.view(tp_size, chunk_len) + gathered_ids = gathered_ids.view(tp_size, chunk_len) + + best_rank = torch.argmax(gathered_max, dim=0) + idx = torch.arange(chunk_len, device=hs.device) + out_token_ids[start:end] = gathered_ids[best_rank, idx].to(torch.long) + + return out_token_ids + def _append_target_hidden_to_native_draft_kv( self, batch: ScheduleBatch, @@ -371,11 +492,17 @@ def _append_target_hidden_to_native_draft_kv( for layer in self.native_draft_model.layers: attn = layer.self_attn - k = attn.k_proj(ctx_hidden) - v = attn.v_proj(ctx_hidden) - k = attn.k_norm(k.view(-1, attn.head_dim)).view_as(k) - dummy_q = k.new_empty((k.shape[0], attn.head_dim)) - _, k = attn.rotary_emb(ctx_positions, dummy_q, k) + qkv, _ = attn.qkv_proj(ctx_hidden) + q, k, v = qkv.split([attn.q_size, attn.kv_size, attn.kv_size], dim=-1) + + q, k = apply_qk_norm( + q=q, + k=k, + q_norm=attn.q_norm, + k_norm=attn.k_norm, + head_dim=attn.head_dim, + ) + q, k = attn.rotary_emb(ctx_positions, q, k) k = k.view(-1, attn.num_kv_heads, attn.head_dim) v = v.view(-1, attn.num_kv_heads, attn.head_dim) self.native_draft_model_runner.token_to_kv_pool.set_kv_buffer( diff --git a/test/manual/models/test_qwen3_dflash_gsm8k_bench.py b/test/manual/models/test_qwen3_dflash_gsm8k_bench.py index 6cee6f173c0c..799c55c8a61b 100644 --- a/test/manual/models/test_qwen3_dflash_gsm8k_bench.py +++ b/test/manual/models/test_qwen3_dflash_gsm8k_bench.py @@ -138,9 +138,18 @@ def test_qwen3_dflash_gsm8k_speedup_and_acceptance(self): parallel = int(os.getenv("SGLANG_DFLASH_PARALLEL", "1")) num_questions = int(os.getenv("SGLANG_DFLASH_NUM_QUESTIONS", "100")) num_shots = int(os.getenv("SGLANG_DFLASH_NUM_SHOTS", "1")) + tp_size = int(os.getenv("SGLANG_DFLASH_TP_SIZE", "1")) disable_radix_cache = os.getenv("SGLANG_DFLASH_DISABLE_RADIX_CACHE", "1") != "0" prompt_style = os.getenv("SGLANG_DFLASH_PROMPT_STYLE", "fewshot_qa") assert_match = os.getenv("SGLANG_DFLASH_ASSERT_MATCH", "0") != "0" + if tp_size < 1: + raise ValueError(f"Invalid SGLANG_DFLASH_TP_SIZE={tp_size}; expected >= 1.") + if torch.cuda.device_count() < tp_size: + self.skipTest( + f"tp_size={tp_size} requires at least {tp_size} visible CUDA devices, " + f"but only {torch.cuda.device_count()} are available. " + "Set CUDA_VISIBLE_DEVICES accordingly." + ) # Read GSM8K data (download if absent). data_path = os.getenv("SGLANG_DFLASH_GSM8K_PATH", "test.jsonl") @@ -178,7 +187,7 @@ def test_qwen3_dflash_gsm8k_speedup_and_acceptance(self): labels.append(_get_answer_value(lines[i]["answer"])) self.assertTrue(all(l != INVALID for l in labels), "Invalid labels in GSM8K data") - common_server_args = ["--attention-backend", attention_backend] + common_server_args = ["--attention-backend", attention_backend, "--tp-size", str(tp_size)] if disable_radix_cache: common_server_args.append("--disable-radix-cache") extra_server_args = os.getenv("SGLANG_DFLASH_EXTRA_SERVER_ARGS", "").strip() @@ -303,6 +312,7 @@ def _collect_common_metrics(outputs: list[dict]) -> tuple[int, list[int]]: "parallel": parallel, "num_questions": num_questions, "num_shots": num_shots, + "tp_size": tp_size, "prompt_style": prompt_style, "disable_radix_cache": disable_radix_cache, }, @@ -344,8 +354,17 @@ def test_qwen3_dflash_native_matches_hf(self): num_questions = int(os.getenv("SGLANG_DFLASH_PARITY_NUM_QUESTIONS", "10")) max_total_tokens = int(os.getenv("SGLANG_DFLASH_PARITY_MAX_TOTAL_TOKENS", "8192")) num_shots = int(os.getenv("SGLANG_DFLASH_NUM_SHOTS", "1")) + tp_size = int(os.getenv("SGLANG_DFLASH_TP_SIZE", "1")) disable_radix_cache = os.getenv("SGLANG_DFLASH_DISABLE_RADIX_CACHE", "1") != "0" prompt_style = os.getenv("SGLANG_DFLASH_PROMPT_STYLE", "fewshot_qa") + if tp_size < 1: + raise ValueError(f"Invalid SGLANG_DFLASH_TP_SIZE={tp_size}; expected >= 1.") + if torch.cuda.device_count() < tp_size: + self.skipTest( + f"tp_size={tp_size} requires at least {tp_size} visible CUDA devices, " + f"but only {torch.cuda.device_count()} are available. " + "Set CUDA_VISIBLE_DEVICES accordingly." + ) # Read GSM8K data (download if absent). data_path = os.getenv("SGLANG_DFLASH_GSM8K_PATH", "test.jsonl") @@ -383,6 +402,8 @@ def test_qwen3_dflash_native_matches_hf(self): common_server_args = [ "--attention-backend", attention_backend, + "--tp-size", + str(tp_size), "--max-total-tokens", str(max_total_tokens), ] From f23555b1b6bf75e845f52c7d2a17521049b4e50d Mon Sep 17 00:00:00 2001 From: David Wang Date: Sat, 10 Jan 2026 01:41:57 +0000 Subject: [PATCH 15/73] add optional dflash_config for overrides, add --speculative-dflash-block-size to server args --- .../srt/model_executor/cuda_graph_runner.py | 2 +- .../sglang/srt/model_executor/model_runner.py | 13 ++-- python/sglang/srt/models/dflash.py | 24 ++++++- python/sglang/srt/server_args.py | 68 +++++++++++++++++- python/sglang/srt/speculative/dflash_utils.py | 70 ++++++++++++++++++- .../sglang/srt/speculative/dflash_worker.py | 49 +++++++++---- 6 files changed, 204 insertions(+), 22 deletions(-) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 1515f441066f..c68b2cd90e63 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -363,7 +363,7 @@ def __init__(self, model_runner: ModelRunner): "which is required for DFLASH aux hidden capture." ) self.model_runner.model.set_dflash_layers_to_capture( - self.model_runner.dflash_aux_hidden_state_layer_ids + self.model_runner.dflash_target_layer_ids ) # Capture diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 7d8999e28368..540f52a949ba 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -138,7 +138,7 @@ set_global_server_args_for_scheduler, ) from sglang.srt.speculative.spec_info import SpeculativeAlgorithm -from sglang.srt.speculative.dflash_utils import build_target_layer_ids +from sglang.srt.speculative.dflash_utils import resolve_dflash_target_layer_ids from sglang.srt.utils import ( MultiprocessingSerializer, cpu_has_amx_support, @@ -318,6 +318,7 @@ def __init__( # auxiliary hidden capture mode. TODO: expose this to server args? self.eagle_use_aux_hidden_state = False self.dflash_use_aux_hidden_state = False + self.dflash_target_layer_ids = None if self.spec_algorithm.is_eagle3() and not self.is_draft_worker: # load draft config draft_model_config = ModelConfig.from_server_args( @@ -369,8 +370,10 @@ def __init__( ) self.dflash_use_aux_hidden_state = True - self.dflash_aux_hidden_state_layer_ids = build_target_layer_ids( - int(target_num_layers), int(draft_num_layers) + self.dflash_target_layer_ids = resolve_dflash_target_layer_ids( + draft_hf_config=draft_model_config.hf_config, + target_num_layers=int(target_num_layers), + draft_num_layers=int(draft_num_layers), ) # Apply the rank zero filter to logger @@ -619,7 +622,7 @@ def initialize(self, min_per_gpu_memory: float): f"Model {self.model.__class__.__name__} does not implement set_dflash_layers_to_capture, " "which is required for DFLASH." ) - self.model.set_dflash_layers_to_capture(self.dflash_aux_hidden_state_layer_ids) + self.model.set_dflash_layers_to_capture(self.dflash_target_layer_ids) # Initialize piecewise CUDA graph self.init_piecewise_cuda_graphs() @@ -1797,7 +1800,7 @@ def _dummy_run(self, batch_size: int): if self.eagle_use_aux_hidden_state: self.model.set_eagle3_layers_to_capture() if self.dflash_use_aux_hidden_state: - self.model.set_dflash_layers_to_capture(self.dflash_aux_hidden_state_layer_ids) + self.model.set_dflash_layers_to_capture(self.dflash_target_layer_ids) require_mlp_tp_gather_ = require_mlp_tp_gather(self.server_args) if require_gathered_buffer(self.server_args): diff --git a/python/sglang/srt/models/dflash.py b/python/sglang/srt/models/dflash.py index 3592bdb1f42a..a51d36f558ee 100644 --- a/python/sglang/srt/models/dflash.py +++ b/python/sglang/srt/models/dflash.py @@ -199,7 +199,29 @@ def __init__(self, config, quant_config=None, prefix: str = "") -> None: self.fc = nn.Linear(num_layers * hidden_size, hidden_size, bias=False) self.hidden_norm = RMSNorm(hidden_size, eps=rms_norm_eps) - self.block_size = int(getattr(config, "block_size", 16)) + dflash_cfg = getattr(config, "dflash_config", None) + dflash_block_size = None + if isinstance(dflash_cfg, dict): + dflash_block_size = dflash_cfg.get("block_size", None) + + block_size = ( + dflash_block_size + if dflash_block_size is not None + else getattr(config, "block_size", None) + ) + if block_size is None: + block_size = 16 + elif getattr(config, "block_size", None) is not None and dflash_block_size is not None: + if int(dflash_block_size) != int(getattr(config, "block_size")): + logger.warning( + "DFLASH draft config has both block_size=%s and dflash_config.block_size=%s; using dflash_config.block_size.", + getattr(config, "block_size"), + dflash_block_size, + ) + try: + self.block_size = int(block_size) + except Exception as e: + raise ValueError(f"Invalid DFLASH block_size={block_size!r}.") from e def project_target_hidden(self, target_hidden: torch.Tensor) -> torch.Tensor: """Project concatenated target-layer hidden states into draft hidden_size.""" diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 974f146d24e4..f79ba27600a5 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -421,6 +421,7 @@ class ServerArgs: speculative_num_steps: Optional[int] = None speculative_eagle_topk: Optional[int] = None speculative_num_draft_tokens: Optional[int] = None + speculative_dflash_block_size: Optional[int] = None speculative_accept_threshold_single: float = 1.0 speculative_accept_threshold_acc: float = 1.0 speculative_token_map: Optional[str] = None @@ -2053,12 +2054,48 @@ def _handle_speculative_decoding(self): "DFLASH speculative decoding requires setting --speculative-draft-model-path." ) - # Set default spec params expected by generic spec-v1 plumbing. + # DFLASH does not use EAGLE-style `num_steps`/`topk`, but those fields still + # affect generic scheduler/KV-cache accounting (buffer sizing, KV freeing, + # RoPE reservation). Force them to 1 to avoid surprising memory behavior. + # # For DFlash, the natural unit is `block_size` (verify window length). if self.speculative_num_steps is None: self.speculative_num_steps = 1 + elif int(self.speculative_num_steps) != 1: + logger.warning( + "DFLASH only supports speculative_num_steps == 1; overriding speculative_num_steps=%s to 1.", + self.speculative_num_steps, + ) + self.speculative_num_steps = 1 + if self.speculative_eagle_topk is None: self.speculative_eagle_topk = 1 + elif int(self.speculative_eagle_topk) != 1: + logger.warning( + "DFLASH only supports speculative_eagle_topk == 1; overriding speculative_eagle_topk=%s to 1.", + self.speculative_eagle_topk, + ) + self.speculative_eagle_topk = 1 + + if self.speculative_dflash_block_size is not None: + if int(self.speculative_dflash_block_size) <= 0: + raise ValueError( + "DFLASH requires --speculative-dflash-block-size to be positive, " + f"got {self.speculative_dflash_block_size}." + ) + if ( + self.speculative_num_draft_tokens is not None + and int(self.speculative_num_draft_tokens) + != int(self.speculative_dflash_block_size) + ): + raise ValueError( + "Both --speculative-num-draft-tokens and --speculative-dflash-block-size are set " + "but they differ. For DFLASH they must match. " + f"speculative_num_draft_tokens={self.speculative_num_draft_tokens}, " + f"speculative_dflash_block_size={self.speculative_dflash_block_size}." + ) + self.speculative_num_draft_tokens = int(self.speculative_dflash_block_size) + if self.speculative_num_draft_tokens is None: inferred_block_size = None try: @@ -2069,7 +2106,28 @@ def _handle_speculative_decoding(self): if os.path.isfile(draft_config_path): with open(draft_config_path, "r") as f: draft_config_json = json.load(f) - inferred_block_size = draft_config_json.get("block_size") + top_level_block_size = draft_config_json.get("block_size", None) + dflash_cfg = draft_config_json.get("dflash_config", None) + dflash_block_size = ( + dflash_cfg.get("block_size", None) + if isinstance(dflash_cfg, dict) + else None + ) + + if dflash_block_size is not None: + inferred_block_size = dflash_block_size + if ( + top_level_block_size is not None + and int(dflash_block_size) != int(top_level_block_size) + ): + logger.warning( + "DFLASH draft config has both block_size=%s and dflash_config.block_size=%s; " + "using dflash_config.block_size for speculative_num_draft_tokens inference.", + top_level_block_size, + dflash_block_size, + ) + else: + inferred_block_size = top_level_block_size except Exception as e: logger.warning( "Failed to infer DFlash block_size from draft config.json; " @@ -3536,6 +3594,12 @@ def add_cli_args(parser: argparse.ArgumentParser): help="The number of tokens sampled from the draft model in Speculative Decoding.", default=ServerArgs.speculative_num_draft_tokens, ) + parser.add_argument( + "--speculative-dflash-block-size", + type=int, + help="DFLASH only. Block size (verify window length). Alias of --speculative-num-draft-tokens for DFLASH.", + default=ServerArgs.speculative_dflash_block_size, + ) parser.add_argument( "--speculative-accept-threshold-single", type=float, diff --git a/python/sglang/srt/speculative/dflash_utils.py b/python/sglang/srt/speculative/dflash_utils.py index f1a9b191bc72..4f1af9871318 100644 --- a/python/sglang/srt/speculative/dflash_utils.py +++ b/python/sglang/srt/speculative/dflash_utils.py @@ -1,10 +1,13 @@ from __future__ import annotations -from typing import List, Tuple +from typing import Any, List, Tuple import torch +DEFAULT_DFLASH_MASK_TOKEN = "<|MASK|>" + + def build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> List[int]: """Select target layer indices used to build DFlash context features. @@ -46,6 +49,71 @@ def build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> Lis ] +def _get_dflash_config(config: Any) -> dict: + cfg = getattr(config, "dflash_config", None) + if cfg is None: + return {} + if isinstance(cfg, dict): + return cfg + + try: + return dict(cfg) + except Exception: + return {} + + +def resolve_dflash_target_layer_ids( + *, + draft_hf_config: Any, + target_num_layers: int, + draft_num_layers: int, +) -> List[int]: + """Resolve target layer ids used to build DFlash context features. + + Precedence: + 1) `draft_hf_config.dflash_config.target_layer_ids` + 2) default `build_target_layer_ids(target_num_layers, draft_num_layers)` + """ + cfg = _get_dflash_config(draft_hf_config) + layer_ids = cfg.get("target_layer_ids", None) + if layer_ids is None: + return build_target_layer_ids(target_num_layers, draft_num_layers) + + if not isinstance(layer_ids, (list, tuple)): + raise ValueError( + "DFLASH dflash_config.target_layer_ids must be a list of ints, " + f"got type={type(layer_ids).__name__}." + ) + + resolved: List[int] = [int(x) for x in layer_ids] + if len(resolved) != int(draft_num_layers): + raise ValueError( + "DFLASH target_layer_ids length must equal the draft num_hidden_layers. " + f"Got len(target_layer_ids)={len(resolved)}, draft_num_layers={int(draft_num_layers)}." + ) + + for idx, val in enumerate(resolved): + if val < 0 or val >= int(target_num_layers): + raise ValueError( + "DFLASH target_layer_ids contains an out-of-range layer id. " + f"target_layer_ids[{idx}]={val}, target_num_layers={int(target_num_layers)}." + ) + return resolved + + +def resolve_dflash_mask_token(*, draft_hf_config: Any) -> str: + cfg = _get_dflash_config(draft_hf_config) + mask_token = cfg.get("mask_token", None) + if mask_token is None: + return DEFAULT_DFLASH_MASK_TOKEN + if not isinstance(mask_token, str) or not mask_token: + raise ValueError( + "DFLASH dflash_config.mask_token must be a non-empty string, " + f"got {mask_token!r}." + ) + return mask_token + + def compute_dflash_accept_len_and_bonus( *, candidates: torch.Tensor, diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 8d5e786ac09a..3d9693eac478 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -18,6 +18,7 @@ ) from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.dflash_info import DFlashDraftInput, DFlashVerifyInput +from sglang.srt.speculative.dflash_utils import resolve_dflash_mask_token from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func @@ -49,7 +50,6 @@ def __init__( self.page_size = server_args.page_size self.device = target_worker.device - self._mask_token_id = self._resolve_mask_token_id() self._warned_forced_greedy = False self._logged_first_verify = False @@ -94,7 +94,23 @@ def __init__( ) self.native_draft_model_runner = self.native_draft_worker.model_runner self.native_draft_model = self.native_draft_model_runner.model - self.block_size = int(getattr(self.native_draft_model, "block_size", 16)) + if server_args.speculative_num_draft_tokens is None: + # Should not happen (ServerArgs should have inferred it), but keep a fallback. + self.block_size = int(getattr(self.native_draft_model, "block_size", 16)) + else: + self.block_size = int(server_args.speculative_num_draft_tokens) + model_block_size = getattr(self.native_draft_model, "block_size", None) + if model_block_size is not None and int(model_block_size) != int(self.block_size): + logger.warning( + "DFLASH block size mismatch: using speculative_num_draft_tokens=%s but draft config block_size=%s.", + self.block_size, + model_block_size, + ) + + self._mask_token = resolve_dflash_mask_token( + draft_hf_config=self.native_draft_model_runner.model_config.hf_config + ) + self._mask_token_id = self._resolve_mask_token_id(mask_token=self._mask_token) if self.tp_rank == 0: logger.info( "Initialized native DFLASH draft runner. attention_backend=%s, model=%s, block_size=%s", @@ -103,7 +119,8 @@ def __init__( self.block_size, ) logger.info( - "DFLASH draft impl selected. impl=native, mask_token_id=%s", + "DFLASH draft impl selected. impl=native, mask_token=%s, mask_token_id=%s", + self._mask_token, self._mask_token_id, ) @@ -121,44 +138,52 @@ def on_req_finished(self, req): if hasattr(req, "dflash_draft_seq_len"): req.dflash_draft_seq_len = 0 - def _resolve_mask_token_id(self) -> int: + def _resolve_mask_token_id(self, *, mask_token: str) -> int: + if not isinstance(mask_token, str) or not mask_token: + raise ValueError(f"DFLASH mask_token must be a non-empty string, got {mask_token!r}.") + tokenizer = getattr(self.target_worker, "tokenizer", None) if tokenizer is None: raise RuntimeError("DFLASH requires tokenizer initialization (skip_tokenizer_init is not supported).") vocab_size = int(self.target_worker.model_runner.model_config.vocab_size) - mask_token_id = getattr(tokenizer, "mask_token_id", None) + mask_token_id = None + if getattr(tokenizer, "mask_token", None) == mask_token: + mask_token_id = getattr(tokenizer, "mask_token_id", None) + if mask_token_id is None: - # `convert_tokens_to_ids` can return `None` (or an unk id) depending on tokenizer. # Prefer checking the explicit vocab mapping first. vocab = tokenizer.get_vocab() - mask_token_id = vocab.get("<|MASK|>", None) + mask_token_id = vocab.get(mask_token, None) if mask_token_id is None: # Mirror the reference DFlash HF demo by adding the mask token to the tokenizer. # This is safe only when the resulting id stays within the target model vocab size. - added = tokenizer.add_special_tokens({"mask_token": "<|MASK|>"}) + added = tokenizer.add_special_tokens({"mask_token": mask_token}) mask_token_id = getattr(tokenizer, "mask_token_id", None) if mask_token_id is None: - mask_token_id = tokenizer.convert_tokens_to_ids("<|MASK|>") + mask_token_id = tokenizer.convert_tokens_to_ids(mask_token) if added and self.tp_rank == 0: logger.info( "Added DFLASH mask token to tokenizer. token=%s, mask_token_id=%s, tokenizer_len=%s, model_vocab_size=%s", - "<|MASK|>", + mask_token, mask_token_id, len(tokenizer), vocab_size, ) if mask_token_id is None or int(mask_token_id) < 0: - raise ValueError("DFLASH requires a `<|MASK|>` token id, but it could not be resolved.") + raise ValueError( + "DFLASH requires resolving a mask token id, but it could not be resolved. " + f"mask_token={mask_token!r}." + ) if mask_token_id >= vocab_size: raise ValueError( "DFLASH mask_token_id is outside the target vocab size. " f"mask_token_id={mask_token_id}, vocab_size={vocab_size}. " - "This likely means `<|MASK|>` requires vocab expansion beyond the model's embedding size. " + f"This likely means mask_token={mask_token!r} requires vocab expansion beyond the model's embedding size. " "SGLang does not support resizing target embeddings for DFLASH yet." ) From 63c0b9a2d5c2db453a1d884e873b4ea365bfcf19 Mon Sep 17 00:00:00 2001 From: David Wang Date: Sat, 10 Jan 2026 02:11:38 +0000 Subject: [PATCH 16/73] fix OOMs with default settings --- python/sglang/srt/server_args.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f79ba27600a5..f6c1a7697cc8 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -974,6 +974,9 @@ def _handle_gpu_memory_settings(self, gpu_mem): if self.speculative_algorithm == "STANDALONE": # standalonedraft model and cuda graphs reserved_mem += 6 * 1024 + elif self.speculative_algorithm == "DFLASH": + # dflash draft model and cuda graphs + reserved_mem += 6 * 1024 elif self.speculative_algorithm != "NGRAM": # eagle draft models and cuda graphs reserved_mem += 2 * 1024 From 9309764c12a777ce8695c54ef7c74b58484b1788 Mon Sep 17 00:00:00 2001 From: David Wang Date: Sat, 10 Jan 2026 06:28:39 +0000 Subject: [PATCH 17/73] clean up --- .../models/test_qwen3_dflash_correctness.py | 137 ------ .../models/test_qwen3_dflash_gsm8k_bench.py | 456 ------------------ test/srt/test_dflash_acceptance_unit.py | 55 --- 3 files changed, 648 deletions(-) delete mode 100644 test/manual/models/test_qwen3_dflash_correctness.py delete mode 100644 test/manual/models/test_qwen3_dflash_gsm8k_bench.py delete mode 100644 test/srt/test_dflash_acceptance_unit.py diff --git a/test/manual/models/test_qwen3_dflash_correctness.py b/test/manual/models/test_qwen3_dflash_correctness.py deleted file mode 100644 index 8bef7afb63a6..000000000000 --- a/test/manual/models/test_qwen3_dflash_correctness.py +++ /dev/null @@ -1,137 +0,0 @@ -import os -import unittest - -import requests -import torch - -from sglang.srt.utils import kill_process_tree -from sglang.test.test_utils import ( - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - CustomTestCase, - find_available_port, - is_in_ci, - popen_launch_server, -) - - -def _send_generate(base_url: str, prompt: str, *, max_new_tokens: int) -> dict: - resp = requests.post( - base_url + "/generate", - json={ - "text": prompt, - "sampling_params": { - "temperature": 0.0, - "top_p": 1.0, - "top_k": 1, - "max_new_tokens": max_new_tokens, - }, - }, - timeout=600, - ) - resp.raise_for_status() - return resp.json() - - -class TestQwen3DFlashCorrectness(CustomTestCase): - def test_qwen3_dflash_matches_target_only_greedy(self): - if is_in_ci(): - self.skipTest("Manual test; skipped in CI.") - if not torch.cuda.is_available(): - self.skipTest("CUDA is required for this manual DFlash integration test.") - - target_model = os.getenv("SGLANG_DFLASH_TARGET_MODEL", "Qwen/Qwen3-8B") - draft_model_path = os.getenv( - "SGLANG_DFLASH_DRAFT_MODEL_PATH", "/tmp/Qwen3-8B-DFlash-bf16" - ) - if not os.path.isdir(draft_model_path): - self.skipTest( - f"Draft model folder not found: {draft_model_path}. " - "Set SGLANG_DFLASH_DRAFT_MODEL_PATH to run this test." - ) - - max_new_tokens = int(os.getenv("SGLANG_DFLASH_MAX_NEW_TOKENS", "128")) - prompt = os.getenv( - "SGLANG_DFLASH_PROMPT", - "How many positive whole-number divisors does 196 have?", - ) - attention_backend = os.getenv("SGLANG_DFLASH_ATTENTION_BACKEND", "flashinfer") - - baseline_port = find_available_port(20000) - dflash_port = find_available_port(baseline_port + 1) - baseline_url = f"http://127.0.0.1:{baseline_port}" - dflash_url = f"http://127.0.0.1:{dflash_port}" - - # 1) Target-only baseline. - baseline_proc = popen_launch_server( - target_model, - baseline_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--disable-radix-cache", - "--attention-backend", - attention_backend, - ], - ) - try: - baseline = _send_generate( - baseline_url, prompt, max_new_tokens=max_new_tokens - ) - finally: - kill_process_tree(baseline_proc.pid) - - # 2) DFLASH speculative decoding. - dflash_proc = popen_launch_server( - target_model, - dflash_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - "--disable-radix-cache", - "--attention-backend", - attention_backend, - "--speculative-algorithm", - "DFLASH", - "--speculative-draft-model-path", - draft_model_path, - ], - ) - try: - dflash = _send_generate(dflash_url, prompt, max_new_tokens=max_new_tokens) - finally: - kill_process_tree(dflash_proc.pid) - - self.assertEqual( - baseline["output_ids"], - dflash["output_ids"], - f"Token IDs mismatch.\nbaseline={baseline['output_ids']}\ndflash={dflash['output_ids']}", - ) - self.assertEqual( - baseline["text"], - dflash["text"], - "Decoded text mismatch for greedy decoding.", - ) - - meta = dflash.get("meta_info", {}) - self.assertIn("spec_verify_ct", meta, f"Missing spec metrics: {meta.keys()}") - self.assertGreater(meta["spec_verify_ct"], 0, "DFLASH did not run verify steps.") - self.assertIn("spec_accept_length", meta, f"Missing spec_accept_length: {meta.keys()}") - self.assertGreaterEqual( - float(meta["spec_accept_length"]), - 1.0, - "Spec accept length should be >= 1.0 (bonus token).", - ) - print( - "DFLASH metrics:", - { - "spec_verify_ct": meta.get("spec_verify_ct"), - "spec_accept_length": meta.get("spec_accept_length"), - "spec_accept_rate": meta.get("spec_accept_rate"), - "spec_accept_token_num": meta.get("spec_accept_token_num"), - "spec_draft_token_num": meta.get("spec_draft_token_num"), - "completion_tokens": meta.get("completion_tokens"), - }, - flush=True, - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/manual/models/test_qwen3_dflash_gsm8k_bench.py b/test/manual/models/test_qwen3_dflash_gsm8k_bench.py deleted file mode 100644 index 799c55c8a61b..000000000000 --- a/test/manual/models/test_qwen3_dflash_gsm8k_bench.py +++ /dev/null @@ -1,456 +0,0 @@ -"""Manual GSM8K benchmark for DFLASH (vs target-only baseline). - -Notes / known limitations (as of this initial integration): - - Prompting style matters a lot for acceptance length. The upstream DFlash HF demo/bench - uses a Qwen chat-template prompt; use `SGLANG_DFLASH_PROMPT_STYLE=dflash_chat` and - typically `SGLANG_DFLASH_STOP=` (empty) to get closer acceptance numbers. - - DFLASH may *diverge* from target-only greedy decoding on some prompts. This is because - DFLASH verifies a whole block with `ForwardMode.TARGET_VERIFY` (prefill-style kernels), - while the baseline uses the normal decode path. Some attention backends can produce - different argmax tokens across these modes (numerical differences), which makes direct - "accuracy" comparisons misleading. - Use `SGLANG_DFLASH_ASSERT_MATCH=1` to detect any token-level divergence. -""" - -import ast -import json -import os -import re -import shlex -import statistics -import time -import unittest -from concurrent.futures import ThreadPoolExecutor, as_completed - -import requests -import torch -from transformers import AutoTokenizer - -from sglang.srt.utils import kill_process_tree -from sglang.test.test_utils import ( - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - CustomTestCase, - find_available_port, - is_in_ci, - popen_launch_server, -) -from sglang.utils import download_and_cache_file, read_jsonl - -INVALID = -9999999 - - -def _get_one_example(lines, i: int, include_answer: bool) -> str: - ret = "Question: " + lines[i]["question"] + "\nAnswer:" - if include_answer: - ret += " " + lines[i]["answer"] - return ret - - -def _get_few_shot_examples(lines, k: int) -> str: - ret = "" - for i in range(k): - ret += _get_one_example(lines, i, True) + "\n\n" - return ret - - -def _get_answer_value(answer_str: str) -> int: - answer_str = answer_str.replace(",", "") - numbers = re.findall(r"\d+", answer_str) - if len(numbers) < 1: - return INVALID - try: - return ast.literal_eval(numbers[-1]) - except SyntaxError: - return INVALID - - -def _send_generate(base_url: str, prompt: str, *, max_new_tokens: int) -> dict: - stop = os.getenv("SGLANG_DFLASH_STOP", "Question,Assistant:,<|separator|>") - stop_list = [s for s in stop.split(",") if s] if stop else [] - sampling_params = { - "temperature": 0.0, - "top_p": 1.0, - "top_k": 1, - "max_new_tokens": max_new_tokens, - } - if stop_list: - sampling_params["stop"] = stop_list - resp = requests.post( - base_url + "/generate", - json={ - "text": prompt, - "sampling_params": sampling_params, - }, - timeout=600, - ) - resp.raise_for_status() - return resp.json() - - -def _run_generate_batch( - base_url: str, - prompts: list[str], - *, - max_new_tokens: int, - parallel: int, -) -> tuple[float, list[dict]]: - start = time.perf_counter() - outputs: list[dict] = [None for _ in range(len(prompts))] # type: ignore[list-item] - with ThreadPoolExecutor(max_workers=parallel) as pool: - futures = { - pool.submit(_send_generate, base_url, prompt, max_new_tokens=max_new_tokens): i - for i, prompt in enumerate(prompts) - } - for fut in as_completed(futures): - idx = futures[fut] - outputs[idx] = fut.result() - latency = time.perf_counter() - start - return latency, outputs - - -def _summarize(values: list[float]) -> dict: - if not values: - return {"mean": None, "p50": None, "p90": None} - values_sorted = sorted(values) - p50 = values_sorted[int(0.50 * (len(values_sorted) - 1))] - p90 = values_sorted[int(0.90 * (len(values_sorted) - 1))] - return { - "mean": float(statistics.mean(values_sorted)), - "p50": float(p50), - "p90": float(p90), - } - - -class TestQwen3DFlashGSM8KBench(CustomTestCase): - def test_qwen3_dflash_gsm8k_speedup_and_acceptance(self): - if is_in_ci(): - self.skipTest("Manual benchmark; skipped in CI.") - if not torch.cuda.is_available(): - self.skipTest("CUDA is required for this manual DFlash benchmark.") - - target_model = os.getenv("SGLANG_DFLASH_TARGET_MODEL", "Qwen/Qwen3-8B") - draft_model_path = os.getenv( - "SGLANG_DFLASH_DRAFT_MODEL_PATH", "/tmp/Qwen3-8B-DFlash-bf16" - ) - - attention_backend = os.getenv("SGLANG_DFLASH_ATTENTION_BACKEND", "flashinfer") - max_new_tokens = int(os.getenv("SGLANG_DFLASH_MAX_NEW_TOKENS", "2048")) - parallel = int(os.getenv("SGLANG_DFLASH_PARALLEL", "1")) - num_questions = int(os.getenv("SGLANG_DFLASH_NUM_QUESTIONS", "100")) - num_shots = int(os.getenv("SGLANG_DFLASH_NUM_SHOTS", "1")) - tp_size = int(os.getenv("SGLANG_DFLASH_TP_SIZE", "1")) - disable_radix_cache = os.getenv("SGLANG_DFLASH_DISABLE_RADIX_CACHE", "1") != "0" - prompt_style = os.getenv("SGLANG_DFLASH_PROMPT_STYLE", "fewshot_qa") - assert_match = os.getenv("SGLANG_DFLASH_ASSERT_MATCH", "0") != "0" - if tp_size < 1: - raise ValueError(f"Invalid SGLANG_DFLASH_TP_SIZE={tp_size}; expected >= 1.") - if torch.cuda.device_count() < tp_size: - self.skipTest( - f"tp_size={tp_size} requires at least {tp_size} visible CUDA devices, " - f"but only {torch.cuda.device_count()} are available. " - "Set CUDA_VISIBLE_DEVICES accordingly." - ) - - # Read GSM8K data (download if absent). - data_path = os.getenv("SGLANG_DFLASH_GSM8K_PATH", "test.jsonl") - url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" - if not os.path.isfile(data_path): - data_path = download_and_cache_file(url) - lines = list(read_jsonl(data_path)) - - tokenizer = None - if prompt_style == "dflash_chat": - tokenizer = AutoTokenizer.from_pretrained(target_model) - - few_shot = _get_few_shot_examples(lines, num_shots) if prompt_style == "fewshot_qa" else "" - prompts: list[str] = [] - labels: list[int] = [] - for i in range(len(lines[:num_questions])): - if prompt_style == "fewshot_qa": - prompts.append(few_shot + _get_one_example(lines, i, False)) - elif prompt_style == "dflash_chat": - assert tokenizer is not None - user_content = ( - lines[i]["question"] - + "\nPlease reason step by step, and put your final answer within \\boxed{}." - ) - prompts.append( - tokenizer.apply_chat_template( - [{"role": "user", "content": user_content}], - tokenize=False, - add_generation_prompt=True, - enable_thinking=False, - ) - ) - else: - raise ValueError(f"Unsupported SGLANG_DFLASH_PROMPT_STYLE: {prompt_style}") - labels.append(_get_answer_value(lines[i]["answer"])) - self.assertTrue(all(l != INVALID for l in labels), "Invalid labels in GSM8K data") - - common_server_args = ["--attention-backend", attention_backend, "--tp-size", str(tp_size)] - if disable_radix_cache: - common_server_args.append("--disable-radix-cache") - extra_server_args = os.getenv("SGLANG_DFLASH_EXTRA_SERVER_ARGS", "").strip() - if extra_server_args: - common_server_args.extend(shlex.split(extra_server_args)) - - baseline_port = find_available_port(20000) - dflash_port = find_available_port(baseline_port + 1) - baseline_url = f"http://127.0.0.1:{baseline_port}" - dflash_url = f"http://127.0.0.1:{dflash_port}" - - # 1) Target-only baseline. - baseline_proc = popen_launch_server( - target_model, - baseline_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=common_server_args, - ) - try: - _send_generate(baseline_url, "Hello", max_new_tokens=8) # warmup - baseline_latency, baseline_outputs = _run_generate_batch( - baseline_url, - prompts, - max_new_tokens=max_new_tokens, - parallel=parallel, - ) - finally: - kill_process_tree(baseline_proc.pid) - try: - baseline_proc.wait(timeout=30) - except Exception: - pass - - # 2) DFLASH speculative decoding. - dflash_proc = popen_launch_server( - target_model, - dflash_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - *common_server_args, - "--speculative-algorithm", - "DFLASH", - "--speculative-draft-model-path", - draft_model_path, - ], - ) - try: - _send_generate(dflash_url, "Hello", max_new_tokens=8) # warmup - dflash_latency, dflash_outputs = _run_generate_batch( - dflash_url, - prompts, - max_new_tokens=max_new_tokens, - parallel=parallel, - ) - finally: - kill_process_tree(dflash_proc.pid) - try: - dflash_proc.wait(timeout=30) - except Exception: - pass - - def _collect_common_metrics(outputs: list[dict]) -> tuple[int, list[int]]: - completion_tokens = [] - preds = [] - total_completion_tokens = 0 - for out in outputs: - meta = out.get("meta_info", {}) - total_completion_tokens += int(meta.get("completion_tokens", 0)) - completion_tokens.append(int(meta.get("completion_tokens", 0))) - preds.append(_get_answer_value(out.get("text", ""))) - return total_completion_tokens, preds - - baseline_total_tokens, baseline_preds = _collect_common_metrics(baseline_outputs) - dflash_total_tokens, dflash_preds = _collect_common_metrics(dflash_outputs) - - if assert_match: - for i, (baseline_out, dflash_out) in enumerate( - zip(baseline_outputs, dflash_outputs, strict=True) - ): - if baseline_out.get("output_ids") != dflash_out.get("output_ids"): - raise AssertionError( - "Baseline and DFLASH outputs diverged at index " - f"{i}.\nbaseline={baseline_out.get('output_ids')}\ndflash={dflash_out.get('output_ids')}" - ) - - baseline_throughput = baseline_total_tokens / max(baseline_latency, 1e-6) - dflash_throughput = dflash_total_tokens / max(dflash_latency, 1e-6) - speedup = dflash_throughput / max(baseline_throughput, 1e-6) - - # WARNING: Until baseline-vs-DFLASH greedy outputs are guaranteed identical, these - # "accuracy" numbers are not strictly comparable. Prefer asserting matches via - # `SGLANG_DFLASH_ASSERT_MATCH=1` when debugging correctness. - baseline_acc = sum( - int(p == l) for p, l in zip(baseline_preds, labels, strict=True) - ) / len(labels) - dflash_acc = sum( - int(p == l) for p, l in zip(dflash_preds, labels, strict=True) - ) / len(labels) - - spec_accept_lengths: list[float] = [] - spec_accept_rates: list[float] = [] - spec_verify_cts: list[int] = [] - for out in dflash_outputs: - meta = out.get("meta_info", {}) - if "spec_verify_ct" in meta: - spec_verify_cts.append(int(meta["spec_verify_ct"])) - if "spec_accept_length" in meta: - spec_accept_lengths.append(float(meta["spec_accept_length"])) - if "spec_accept_rate" in meta: - spec_accept_rates.append(float(meta["spec_accept_rate"])) - - # Basic sanity checks that DFLASH actually ran. - self.assertTrue(spec_verify_cts, "Missing spec_verify_ct in DFLASH responses.") - self.assertGreater(sum(spec_verify_cts), 0, "DFLASH did not run verify steps.") - - report = { - "settings": { - "target_model": target_model, - "draft_model_path": draft_model_path, - "attention_backend": attention_backend, - "max_new_tokens": max_new_tokens, - "parallel": parallel, - "num_questions": num_questions, - "num_shots": num_shots, - "tp_size": tp_size, - "prompt_style": prompt_style, - "disable_radix_cache": disable_radix_cache, - }, - "baseline": { - "latency_s": round(baseline_latency, 3), - "completion_tokens": baseline_total_tokens, - "throughput_tok_s": round(baseline_throughput, 3), - "accuracy": round(baseline_acc, 3), - }, - "dflash": { - "latency_s": round(dflash_latency, 3), - "completion_tokens": dflash_total_tokens, - "throughput_tok_s": round(dflash_throughput, 3), - "accuracy": round(dflash_acc, 3), - "spec_accept_length": _summarize(spec_accept_lengths), - "spec_accept_rate": _summarize(spec_accept_rates), - "spec_verify_ct_mean": float(statistics.mean(spec_verify_cts)), - }, - "speedup": round(speedup, 3), - } - print(json.dumps(report, indent=2), flush=True) - - def test_qwen3_dflash_native_matches_hf(self): - """Legacy name: previously asserted HF-vs-native parity; now a native smoke/stability run.""" - if is_in_ci(): - self.skipTest("Manual benchmark; skipped in CI.") - if not torch.cuda.is_available(): - self.skipTest("CUDA is required for this manual DFlash benchmark.") - - target_model = os.getenv("SGLANG_DFLASH_TARGET_MODEL", "Qwen/Qwen3-8B") - draft_model_path = os.getenv( - "SGLANG_DFLASH_DRAFT_MODEL_PATH", "/tmp/Qwen3-8B-DFlash-bf16" - ) - - attention_backend = os.getenv("SGLANG_DFLASH_ATTENTION_BACKEND", "flashinfer") - # Keep env var names for backwards compatibility with the previous parity test. - max_new_tokens = int(os.getenv("SGLANG_DFLASH_PARITY_MAX_NEW_TOKENS", "256")) - parallel = int(os.getenv("SGLANG_DFLASH_PARITY_PARALLEL", "1")) - num_questions = int(os.getenv("SGLANG_DFLASH_PARITY_NUM_QUESTIONS", "10")) - max_total_tokens = int(os.getenv("SGLANG_DFLASH_PARITY_MAX_TOTAL_TOKENS", "8192")) - num_shots = int(os.getenv("SGLANG_DFLASH_NUM_SHOTS", "1")) - tp_size = int(os.getenv("SGLANG_DFLASH_TP_SIZE", "1")) - disable_radix_cache = os.getenv("SGLANG_DFLASH_DISABLE_RADIX_CACHE", "1") != "0" - prompt_style = os.getenv("SGLANG_DFLASH_PROMPT_STYLE", "fewshot_qa") - if tp_size < 1: - raise ValueError(f"Invalid SGLANG_DFLASH_TP_SIZE={tp_size}; expected >= 1.") - if torch.cuda.device_count() < tp_size: - self.skipTest( - f"tp_size={tp_size} requires at least {tp_size} visible CUDA devices, " - f"but only {torch.cuda.device_count()} are available. " - "Set CUDA_VISIBLE_DEVICES accordingly." - ) - - # Read GSM8K data (download if absent). - data_path = os.getenv("SGLANG_DFLASH_GSM8K_PATH", "test.jsonl") - url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" - if not os.path.isfile(data_path): - data_path = download_and_cache_file(url) - lines = list(read_jsonl(data_path)) - - tokenizer = None - if prompt_style == "dflash_chat": - tokenizer = AutoTokenizer.from_pretrained(target_model) - - few_shot = _get_few_shot_examples(lines, num_shots) if prompt_style == "fewshot_qa" else "" - prompts: list[str] = [] - for i in range(len(lines[:num_questions])): - if prompt_style == "fewshot_qa": - prompts.append(few_shot + _get_one_example(lines, i, False)) - elif prompt_style == "dflash_chat": - assert tokenizer is not None - user_content = ( - lines[i]["question"] - + "\nPlease reason step by step, and put your final answer within \\boxed{}." - ) - prompts.append( - tokenizer.apply_chat_template( - [{"role": "user", "content": user_content}], - tokenize=False, - add_generation_prompt=True, - enable_thinking=False, - ) - ) - else: - raise ValueError(f"Unsupported SGLANG_DFLASH_PROMPT_STYLE: {prompt_style}") - - common_server_args = [ - "--attention-backend", - attention_backend, - "--tp-size", - str(tp_size), - "--max-total-tokens", - str(max_total_tokens), - ] - if disable_radix_cache: - common_server_args.append("--disable-radix-cache") - extra_server_args = os.getenv("SGLANG_DFLASH_EXTRA_SERVER_ARGS", "").strip() - if extra_server_args: - common_server_args.extend(shlex.split(extra_server_args)) - - port = find_available_port(21000) - base_url = f"http://127.0.0.1:{port}" - proc = popen_launch_server( - target_model, - base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - *common_server_args, - "--speculative-algorithm", - "DFLASH", - "--speculative-draft-model-path", - draft_model_path, - ], - ) - try: - _send_generate(base_url, "Hello", max_new_tokens=8) # warmup - _, outputs = _run_generate_batch( - base_url, - prompts, - max_new_tokens=max_new_tokens, - parallel=parallel, - ) - finally: - kill_process_tree(proc.pid) - try: - proc.wait(timeout=30) - except Exception: - pass - - self.assertEqual(len(outputs), len(prompts)) - spec_verify_cts: list[int] = [] - for out in outputs: - meta = out.get("meta_info", {}) - if "spec_verify_ct" in meta: - spec_verify_cts.append(int(meta["spec_verify_ct"])) - self.assertTrue(spec_verify_cts, "Missing spec_verify_ct in DFLASH responses.") - self.assertGreater(sum(spec_verify_cts), 0, "DFLASH did not run verify steps.") - - -if __name__ == "__main__": - unittest.main() diff --git a/test/srt/test_dflash_acceptance_unit.py b/test/srt/test_dflash_acceptance_unit.py deleted file mode 100644 index 021df085575f..000000000000 --- a/test/srt/test_dflash_acceptance_unit.py +++ /dev/null @@ -1,55 +0,0 @@ -import unittest - -import torch - -from sglang.srt.speculative.dflash_utils import compute_dflash_accept_len_and_bonus - - -class TestDFlashAcceptanceUnit(unittest.TestCase): - def test_accept_len_and_bonus_basic(self): - candidates = torch.tensor( - [ - [10, 11, 12, 13], - [20, 21, 22, 23], - ], - dtype=torch.long, - ) - target_predict = torch.tensor( - [ - [11, 12, 55, 0], # accept 11,12 then bonus=55 - [99, 21, 22, 0], # accept none then bonus=99 - ], - dtype=torch.long, - ) - - accept_len, bonus = compute_dflash_accept_len_and_bonus( - candidates=candidates, - target_predict=target_predict, - ) - self.assertEqual(accept_len.tolist(), [2, 0]) - self.assertEqual(bonus.tolist(), [55, 99]) - - def test_accept_len_all_accepted(self): - candidates = torch.tensor([[10, 11, 12, 13]], dtype=torch.long) - target_predict = torch.tensor([[11, 12, 13, 77]], dtype=torch.long) - - accept_len, bonus = compute_dflash_accept_len_and_bonus( - candidates=candidates, - target_predict=target_predict, - ) - self.assertEqual(accept_len.tolist(), [3]) - self.assertEqual(bonus.tolist(), [77]) - - def test_shape_mismatch_raises(self): - candidates = torch.zeros((2, 4), dtype=torch.long) - target_predict = torch.zeros((2, 5), dtype=torch.long) - with self.assertRaises(ValueError): - compute_dflash_accept_len_and_bonus( - candidates=candidates, - target_predict=target_predict, - ) - - -if __name__ == "__main__": - unittest.main() - From 644ab2953644d52609d6d34c67a6ae188f92eb7e Mon Sep 17 00:00:00 2001 From: David Wang Date: Sat, 10 Jan 2026 06:45:23 +0000 Subject: [PATCH 18/73] clean up dflash load_weights --- python/sglang/srt/models/dflash.py | 37 +++++++++++++++++++----------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/python/sglang/srt/models/dflash.py b/python/sglang/srt/models/dflash.py index a51d36f558ee..ed39fa15832f 100644 --- a/python/sglang/srt/models/dflash.py +++ b/python/sglang/srt/models/dflash.py @@ -259,31 +259,40 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] params_dict = dict(self.named_parameters()) - for name, loaded_weight in weights: - if name.endswith(".bias") and name not in params_dict: - # Some quantized checkpoints may have extra biases. - # (May still be mappable to a fused/parallel param.) - pass + def resolve_param_name(name: str) -> Optional[str]: + if name in params_dict: + return name + if name.startswith("model."): + stripped_name = name[len("model.") :] + if stripped_name in params_dict: + return stripped_name + else: + prefixed_name = f"model.{name}" + if prefixed_name in params_dict: + return prefixed_name + return None + + for name, loaded_weight in weights: for param_name, weight_name, shard_id in stacked_params_mapping: if f".{weight_name}." not in name: continue mapped_name = name.replace(weight_name, param_name) - param = params_dict.get(mapped_name) - if param is None: + resolved_name = resolve_param_name(mapped_name) + if resolved_name is None: continue + param = params_dict[resolved_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight, shard_id) break else: - if name.endswith(".bias") and name not in params_dict: + resolved_name = resolve_param_name(name) + if resolved_name is None: + # Ignore unexpected weights (e.g., HF rotary caches). continue - param = params_dict.get(name) - if param is None: - # Ignore unexpected weights (e.g., HF rotary caches). - continue - weight_loader = getattr(param, "weight_loader", default_weight_loader) - weight_loader(param, loaded_weight) + param = params_dict[resolved_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) EntryClass = DFlashDraftModel From ff6876a11caf6faeee914fab9e1948435f57f634 Mon Sep 17 00:00:00 2001 From: David Wang Date: Sat, 10 Jan 2026 06:57:35 +0000 Subject: [PATCH 19/73] attention selection logic --- .../sglang/srt/speculative/dflash_worker.py | 15 +++++++++---- python/sglang/srt/speculative/draft_utils.py | 21 +++++++++++++++++++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 3d9693eac478..f8678387f91a 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -66,12 +66,19 @@ def __init__( draft_backend, _ = draft_server_args.get_attention_backends() if draft_backend is None: draft_backend = "flashinfer" - if draft_backend not in ("flashinfer", "fa3"): - raise ValueError( + elif draft_backend == "trtllm_mha": + logger.warning( + "DFLASH draft worker does not support 'trtllm_mha' yet; " + "falling back to 'flashinfer'." + ) + draft_backend = "flashinfer" + elif draft_backend not in ("flashinfer", "fa3"): + logger.warning( "DFLASH draft worker only supports attention_backend in {'flashinfer', 'fa3'} for now, " - f"but got {draft_backend!r}. " - "Use `--speculative-draft-attention-backend` to override the draft backend." + "but got %r. Falling back to 'flashinfer'.", + draft_backend, ) + draft_backend = "flashinfer" # Make the draft worker backend explicit and self-contained (no further overrides). draft_server_args.speculative_draft_attention_backend = None diff --git a/python/sglang/srt/speculative/draft_utils.py b/python/sglang/srt/speculative/draft_utils.py index 9c630da72fb1..3f3e8c63388a 100644 --- a/python/sglang/srt/speculative/draft_utils.py +++ b/python/sglang/srt/speculative/draft_utils.py @@ -31,6 +31,27 @@ def _create_backend( if backend_type is None: backend_type = self.server_args.attention_backend + if backend_type is None: + backend_type = "flashinfer" + elif backend_type == "trtllm_mha": + logger.warning( + "Draft attention backend does not support 'trtllm_mha' yet; " + "falling back to 'flashinfer'." + ) + backend_type = "flashinfer" + + if backend_type not in backend_map: + fallback_backend = "flashinfer" if "flashinfer" in backend_map else None + if fallback_backend is None: + raise ValueError(error_template.format(backend_type=backend_type)) + logger.warning( + "Draft attention backend '%s' is not supported for speculative draft; " + "falling back to '%s'.", + backend_type, + fallback_backend, + ) + backend_type = fallback_backend + if backend_type not in backend_map: raise ValueError(error_template.format(backend_type=backend_type)) From d808ac920bce8a211fc5ec6f573cbf30565e46d7 Mon Sep 17 00:00:00 2001 From: David Wang Date: Mon, 12 Jan 2026 00:29:29 +0000 Subject: [PATCH 20/73] decouple context feature count K from draft num layers --- python/sglang/srt/models/dflash.py | 51 ++++++++++++++++--- python/sglang/srt/speculative/dflash_info.py | 4 +- python/sglang/srt/speculative/dflash_utils.py | 18 ++++--- 3 files changed, 60 insertions(+), 13 deletions(-) diff --git a/python/sglang/srt/models/dflash.py b/python/sglang/srt/models/dflash.py index ed39fa15832f..6fdd7a13762c 100644 --- a/python/sglang/srt/models/dflash.py +++ b/python/sglang/srt/models/dflash.py @@ -20,6 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.utils import apply_qk_norm +from sglang.srt.speculative.dflash_utils import get_dflash_config logger = logging.getLogger(__name__) @@ -189,20 +190,38 @@ def __init__(self, config, quant_config=None, prefix: str = "") -> None: num_layers = int(config.num_hidden_layers) rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-6)) + dflash_cfg_dict = get_dflash_config(config) + self.layers = nn.ModuleList( [DFlashDecoderLayer(config=config, layer_id=i) for i in range(num_layers)] ) self.norm = RMSNorm(hidden_size, eps=rms_norm_eps) # Project per-token target context features: - # concat(num_layers * hidden_size) -> hidden_size - self.fc = nn.Linear(num_layers * hidden_size, hidden_size, bias=False) + # concat(K * hidden_size) -> hidden_size, where K is the number of target-layer + # feature tensors concatenated per token (not necessarily equal to num_layers). + target_layer_ids = dflash_cfg_dict.get("target_layer_ids", None) + if target_layer_ids is None: + num_context_features = num_layers + else: + if not isinstance(target_layer_ids, (list, tuple)): + raise ValueError( + "DFLASH dflash_config.target_layer_ids must be a list of ints, " + f"got type={type(target_layer_ids).__name__}." + ) + if len(target_layer_ids) <= 0: + raise ValueError( + "DFLASH dflash_config.target_layer_ids must be non-empty, got []." + ) + num_context_features = len(target_layer_ids) + + self.num_context_features = int(num_context_features) + self.fc = nn.Linear( + self.num_context_features * hidden_size, hidden_size, bias=False + ) self.hidden_norm = RMSNorm(hidden_size, eps=rms_norm_eps) - dflash_cfg = getattr(config, "dflash_config", None) - dflash_block_size = None - if isinstance(dflash_cfg, dict): - dflash_block_size = dflash_cfg.get("block_size", None) + dflash_block_size = dflash_cfg_dict.get("block_size", None) block_size = ( dflash_block_size @@ -225,6 +244,16 @@ def __init__(self, config, quant_config=None, prefix: str = "") -> None: def project_target_hidden(self, target_hidden: torch.Tensor) -> torch.Tensor: """Project concatenated target-layer hidden states into draft hidden_size.""" + expected = int(self.fc.in_features) + if target_hidden.ndim != 2 or int(target_hidden.shape[-1]) != expected: + raise ValueError( + "DFLASH target_hidden feature dim mismatch. " + f"Expected shape [N, {expected}] " + f"(num_context_features={self.num_context_features}, hidden_size={int(self.config.hidden_size)}), " + f"but got shape={tuple(target_hidden.shape)}. " + "This usually means the target model is capturing a different number of layer features than " + "the draft checkpoint/config expects." + ) return self.hidden_norm(self.fc(target_hidden)) @torch.no_grad() @@ -291,6 +320,16 @@ def resolve_param_name(name: str) -> Optional[str]: # Ignore unexpected weights (e.g., HF rotary caches). continue param = params_dict[resolved_name] + if resolved_name.endswith("fc.weight") and tuple(loaded_weight.shape) != tuple( + param.shape + ): + raise ValueError( + "DFLASH fc.weight shape mismatch. This usually means the draft checkpoint's " + "number of context features (K) does not match this config. " + f"Expected fc.weight.shape={tuple(param.shape)} " + f"(num_context_features={self.num_context_features}, hidden_size={int(self.config.hidden_size)}), " + f"but got {tuple(loaded_weight.shape)} for weight '{name}'." + ) weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py index c1144be52900..6c04d25f244a 100644 --- a/python/sglang/srt/speculative/dflash_info.py +++ b/python/sglang/srt/speculative/dflash_info.py @@ -36,7 +36,9 @@ class DFlashDraftInput(SpecInput): verified_id: torch.Tensor # Flattened context features for tokens that need to be appended into the draft cache. - # Shape: [sum(ctx_lens), num_draft_layers * hidden_size] + # Shape: [sum(ctx_lens), K * hidden_size], where K is the number of target-layer + # hidden-state features concatenated per token (len(dflash_config.target_layer_ids), + # or default K == draft_num_layers for existing checkpoints). target_hidden: torch.Tensor # Context lengths on CPU, one per request. Used to slice `target_hidden`. diff --git a/python/sglang/srt/speculative/dflash_utils.py b/python/sglang/srt/speculative/dflash_utils.py index 4f1af9871318..8c0aa4facb05 100644 --- a/python/sglang/srt/speculative/dflash_utils.py +++ b/python/sglang/srt/speculative/dflash_utils.py @@ -49,7 +49,7 @@ def build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> Lis ] -def _get_dflash_config(config: Any) -> dict: +def get_dflash_config(config: Any) -> dict: cfg = getattr(config, "dflash_config", None) if cfg is None: return {} @@ -73,8 +73,14 @@ def resolve_dflash_target_layer_ids( Precedence: 1) `draft_hf_config.dflash_config.target_layer_ids` 2) default `build_target_layer_ids(target_num_layers, draft_num_layers)` + + Notes: + The number of draft transformer layers is *not* fundamentally tied to the number + of target-layer features (K) used as DFlash context. We treat + `len(target_layer_ids)` as K when explicitly provided. For backward compatibility + (and for current released checkpoints), the default still uses K == draft_num_layers. """ - cfg = _get_dflash_config(draft_hf_config) + cfg = get_dflash_config(draft_hf_config) layer_ids = cfg.get("target_layer_ids", None) if layer_ids is None: return build_target_layer_ids(target_num_layers, draft_num_layers) @@ -86,10 +92,10 @@ def resolve_dflash_target_layer_ids( ) resolved: List[int] = [int(x) for x in layer_ids] - if len(resolved) != int(draft_num_layers): + if len(resolved) <= 0: raise ValueError( - "DFLASH target_layer_ids length must equal the draft num_hidden_layers. " - f"Got len(target_layer_ids)={len(resolved)}, draft_num_layers={int(draft_num_layers)}." + "DFLASH dflash_config.target_layer_ids must be non-empty. " + f"Got len(target_layer_ids)={len(resolved)}." ) for idx, val in enumerate(resolved): @@ -102,7 +108,7 @@ def resolve_dflash_target_layer_ids( def resolve_dflash_mask_token(*, draft_hf_config: Any) -> str: - cfg = _get_dflash_config(draft_hf_config) + cfg = get_dflash_config(draft_hf_config) mask_token = cfg.get("mask_token", None) if mask_token is None: return DEFAULT_DFLASH_MASK_TOKEN From e589ac1bddb5b465f49744880ed198d3e59d250e Mon Sep 17 00:00:00 2001 From: David Wang Date: Mon, 12 Jan 2026 00:53:58 +0000 Subject: [PATCH 21/73] clean up naming --- python/sglang/srt/models/dflash.py | 2 +- .../sglang/srt/speculative/dflash_worker.py | 60 +++++++++---------- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/python/sglang/srt/models/dflash.py b/python/sglang/srt/models/dflash.py index 6fdd7a13762c..396feeaa64fc 100644 --- a/python/sglang/srt/models/dflash.py +++ b/python/sglang/srt/models/dflash.py @@ -174,7 +174,7 @@ def forward( class DFlashDraftModel(nn.Module): - """SGLang-native DFlash draft model (no embedding / lm_head weights). + """SGLang DFlash draft model (no embedding / lm_head weights). The checkpoint provides: - transformer weights for `layers.*` diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index f8678387f91a..ed86b483f7f1 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -53,7 +53,7 @@ def __init__( self._warned_forced_greedy = False self._logged_first_verify = False - # Native (SGLang) draft runner (separate KV cache + attention backend). + # Draft runner (separate KV cache + attention backend). # Share req_to_token_pool + token_to_kv_pool_allocator with the target worker (EAGLE3-style), # while keeping a separate draft KV cache pool (the draft model has different KV values). shared_req_to_token_pool, shared_token_to_kv_pool_allocator = ( @@ -87,7 +87,7 @@ def __init__( draft_server_args.attention_backend = draft_backend # Keep draft context length aligned with the target. draft_server_args.context_length = target_worker.model_runner.model_config.context_len - self.native_draft_worker = TpModelWorker( + self.draft_worker = TpModelWorker( server_args=draft_server_args, gpu_id=gpu_id, tp_rank=tp_rank, @@ -99,14 +99,14 @@ def __init__( req_to_token_pool=shared_req_to_token_pool, token_to_kv_pool_allocator=shared_token_to_kv_pool_allocator, ) - self.native_draft_model_runner = self.native_draft_worker.model_runner - self.native_draft_model = self.native_draft_model_runner.model + self.draft_model_runner = self.draft_worker.model_runner + self.draft_model = self.draft_model_runner.model if server_args.speculative_num_draft_tokens is None: # Should not happen (ServerArgs should have inferred it), but keep a fallback. - self.block_size = int(getattr(self.native_draft_model, "block_size", 16)) + self.block_size = int(getattr(self.draft_model, "block_size", 16)) else: self.block_size = int(server_args.speculative_num_draft_tokens) - model_block_size = getattr(self.native_draft_model, "block_size", None) + model_block_size = getattr(self.draft_model, "block_size", None) if model_block_size is not None and int(model_block_size) != int(self.block_size): logger.warning( "DFLASH block size mismatch: using speculative_num_draft_tokens=%s but draft config block_size=%s.", @@ -115,18 +115,18 @@ def __init__( ) self._mask_token = resolve_dflash_mask_token( - draft_hf_config=self.native_draft_model_runner.model_config.hf_config + draft_hf_config=self.draft_model_runner.model_config.hf_config ) self._mask_token_id = self._resolve_mask_token_id(mask_token=self._mask_token) if self.tp_rank == 0: logger.info( - "Initialized native DFLASH draft runner. attention_backend=%s, model=%s, block_size=%s", + "Initialized DFLASH draft runner. attention_backend=%s, model=%s, block_size=%s", getattr(draft_server_args, "attention_backend", None), - self.native_draft_model.__class__.__name__, + self.draft_model.__class__.__name__, self.block_size, ) logger.info( - "DFLASH draft impl selected. impl=native, mask_token=%s, mask_token_id=%s", + "DFLASH draft runner ready. mask_token=%s, mask_token_id=%s", self._mask_token, self._mask_token_id, ) @@ -213,8 +213,8 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D bs = batch.batch_size() device = self.model_runner.device - # --- 1) Append any newly committed tokens into the native draft KV cache. - self._append_target_hidden_to_native_draft_kv(batch, draft_input) + # --- 1) Append any newly committed tokens into the draft KV cache. + self._append_target_hidden_to_draft_kv(batch, draft_input) target_model = self.target_worker.model_runner.model embed_module = target_model.get_input_embeddings() @@ -225,7 +225,7 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D "`shard_indices` attributes." ) - # --- 2) Draft a non-causal block with the native draft model. + # --- 2) Draft a non-causal block with the draft model. block_ids = torch.full( (bs, self.block_size), self._mask_token_id, @@ -243,7 +243,7 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D (bs,), int(self.block_size), dtype=torch.int32, device=device ) positions, extend_start_loc = compute_position( - self.native_draft_model_runner.server_args.attention_backend, + self.draft_model_runner.server_args.attention_backend, prefix_lens, extend_lens, bs * self.block_size, @@ -251,18 +251,18 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D block_start = prefix_lens.to(torch.int64) block_end = block_start + int(self.block_size) - allocator = self.native_draft_model_runner.token_to_kv_pool_allocator + allocator = self.draft_model_runner.token_to_kv_pool_allocator token_to_kv_pool_state_backup = allocator.backup_state() try: block_cache_loc = allocator.alloc(bs * self.block_size) if block_cache_loc is None: raise RuntimeError( - f"DFLASH native draft OOM when allocating {bs * self.block_size} block tokens." + f"DFLASH draft OOM when allocating {bs * self.block_size} block tokens." ) assign_req_to_token_pool_func( batch.req_pool_indices, - self.native_draft_model_runner.req_to_token_pool.req_to_token, + self.draft_model_runner.req_to_token_pool.req_to_token, block_start, block_end, block_cache_loc, @@ -290,9 +290,9 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D seq_lens_sum=int(seq_lens.sum().item()), seq_lens_cpu=torch.tensor(prefix_lens_cpu, dtype=torch.int32), positions=positions, - req_to_token_pool=self.native_draft_model_runner.req_to_token_pool, - token_to_kv_pool=self.native_draft_model_runner.token_to_kv_pool, - attn_backend=self.native_draft_model_runner.attn_backend, + req_to_token_pool=self.draft_model_runner.req_to_token_pool, + token_to_kv_pool=self.draft_model_runner.token_to_kv_pool, + attn_backend=self.draft_model_runner.attn_backend, input_embeds=input_embeds, spec_algorithm=SpeculativeAlgorithm.DFLASH, spec_info=draft_spec_info, @@ -300,7 +300,7 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D ) with torch.inference_mode(): - draft_hidden = self.native_draft_model_runner.forward( + draft_hidden = self.draft_model_runner.forward( forward_batch ).logits_output finally: @@ -439,12 +439,12 @@ def _greedy_sample_from_vocab_parallel_head( return out_token_ids - def _append_target_hidden_to_native_draft_kv( + def _append_target_hidden_to_draft_kv( self, batch: ScheduleBatch, draft_input: DFlashDraftInput, ) -> None: - """Materialize the target hidden-state features into the native draft KV cache. + """Materialize the target hidden-state features into the draft KV cache. This must be run before exposing new tokens to radix cache (prefix hits), otherwise another request could reuse target KV indices without having draft KV values. @@ -475,7 +475,7 @@ def _append_target_hidden_to_native_draft_kv( ): if int(cache_len) + int(ctx_len) != int(start_pos): raise RuntimeError( - "DFLASH native draft cache length mismatch. " + "DFLASH draft cache length mismatch. " f"cache_len={int(cache_len)}, ctx_len={int(ctx_len)}, start_pos={int(start_pos)}." ) @@ -483,7 +483,7 @@ def _append_target_hidden_to_native_draft_kv( if total_ctx <= 0: return - req_to_token = self.native_draft_model_runner.req_to_token_pool.req_to_token + req_to_token = self.draft_model_runner.req_to_token_pool.req_to_token req_pool_indices_cpu = batch.req_pool_indices.tolist() ctx_cache_loc_chunks: List[torch.Tensor] = [] @@ -518,11 +518,11 @@ def _append_target_hidden_to_native_draft_kv( ) with torch.inference_mode(): - ctx_hidden = self.native_draft_model.project_target_hidden( + ctx_hidden = self.draft_model.project_target_hidden( draft_input.target_hidden ) # [sum(ctx), hidden] - for layer in self.native_draft_model.layers: + for layer in self.draft_model.layers: attn = layer.self_attn qkv, _ = attn.qkv_proj(ctx_hidden) q, k, v = qkv.split([attn.q_size, attn.kv_size, attn.kv_size], dim=-1) @@ -537,7 +537,7 @@ def _append_target_hidden_to_native_draft_kv( q, k = attn.rotary_emb(ctx_positions, q, k) k = k.view(-1, attn.num_kv_heads, attn.head_dim) v = v.view(-1, attn.num_kv_heads, attn.head_dim) - self.native_draft_model_runner.token_to_kv_pool.set_kv_buffer( + self.draft_model_runner.token_to_kv_pool.set_kv_buffer( attn.attn, ctx_cache_loc, k, @@ -592,7 +592,7 @@ def forward_batch_generation( ctx_lens_cpu=[int(x) for x in model_worker_batch.extend_seq_lens], draft_seq_lens_cpu=[int(x) for x in model_worker_batch.extend_prefix_lens], ) - self._append_target_hidden_to_native_draft_kv(batch, draft_input) + self._append_target_hidden_to_draft_kv(batch, draft_input) batch.spec_info = draft_input for req, draft_len in zip(batch.reqs, draft_input.draft_seq_lens_cpu, strict=True): req.dflash_draft_seq_len = int(draft_len) @@ -643,7 +643,7 @@ def forward_batch_generation( draft_input.verified_id = new_verified_id draft_input.target_hidden = next_target_hidden draft_input.ctx_lens_cpu = commit_lens.cpu().tolist() - self._append_target_hidden_to_native_draft_kv(batch, draft_input) + self._append_target_hidden_to_draft_kv(batch, draft_input) batch.spec_info = draft_input batch.forward_mode = ForwardMode.DECODE From 074efb2b59022c6b20afc6ee08ee0070ea63b8f7 Mon Sep 17 00:00:00 2001 From: David Wang Date: Mon, 12 Jan 2026 04:17:22 +0000 Subject: [PATCH 22/73] performance optimizations --- .../srt/model_executor/cuda_graph_runner.py | 16 ++++-- python/sglang/srt/speculative/dflash_info.py | 47 ++++++++++++---- .../sglang/srt/speculative/dflash_worker.py | 53 ++++++++++++------- 3 files changed, 83 insertions(+), 33 deletions(-) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index c68b2cd90e63..7b0a84afb2dd 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -942,13 +942,23 @@ def get_spec_info(self, num_tokens: int): elif self.model_runner.spec_algorithm.is_dflash(): from sglang.srt.speculative.dflash_info import DFlashVerifyInput + backend_name = type(self.model_runner.attn_backend).__name__ + # Avoid enabling custom-mask modes during graph capture for backends that + # can express DFLASH verify via their built-in causal path. + skip_custom_mask = backend_name in { + "FlashInferAttnBackend", + "FlashInferMLAAttnBackend", + "FlashAttentionBackend", + "TRTLLMHAAttnBackend", + "TRTLLMMLABackend", + } spec_info = DFlashVerifyInput( draft_token=None, positions=None, draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, - custom_mask=( - None if self.model_runner.is_draft_worker else self.buffers.custom_mask - ), + custom_mask=None + if (self.model_runner.is_draft_worker or skip_custom_mask) + else self.buffers.custom_mask, capture_hidden_mode=( CaptureHiddenMode.NULL if self.model_runner.is_draft_worker diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py index 6c04d25f244a..9c09d159a517 100644 --- a/python/sglang/srt/speculative/dflash_info.py +++ b/python/sglang/srt/speculative/dflash_info.py @@ -120,7 +120,13 @@ def __post_init__(self): def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]: return self.draft_token_num, self.draft_token_num - def prepare_for_verify(self, batch: ScheduleBatch, page_size: int): + def prepare_for_verify( + self, + batch: ScheduleBatch, + page_size: int, + *, + build_custom_mask: bool = True, + ): if batch.forward_mode.is_idle(): return @@ -160,23 +166,34 @@ def prepare_for_verify(self, batch: ScheduleBatch, page_size: int): bs, ) - # Build a standard causal attention *allow* mask over [prefix + verify_block] for each request. - # Layout matches other speculative inputs: flatten per request, row-major over - # [q_len=draft_token_num, kv_len=prefix_len + draft_token_num]. + if not build_custom_mask: + self.custom_mask = None + return + if self.draft_token_num <= 0: - raise ValueError(f"DFLASH draft_token_num must be positive, got {self.draft_token_num}.") + raise ValueError( + f"DFLASH draft_token_num must be positive, got {self.draft_token_num}." + ) mask_chunks: List[torch.Tensor] = [] q_len = int(self.draft_token_num) - q_idx = torch.arange(q_len, device=batch.device, dtype=torch.int32).unsqueeze(1) + q_idx = torch.arange( + q_len, device=batch.device, dtype=torch.int32 + ).unsqueeze(1) for prefix_len in batch.seq_lens_cpu.tolist(): prefix_len_i = int(prefix_len) kv_len = prefix_len_i + q_len - k_idx = torch.arange(kv_len, device=batch.device, dtype=torch.int32).unsqueeze(0) + k_idx = torch.arange( + kv_len, device=batch.device, dtype=torch.int32 + ).unsqueeze(0) # Allow attending to the full prefix and to tokens up to (and including) the # current query position within the verify block (standard causal masking). allow = k_idx <= (prefix_len_i + q_idx) mask_chunks.append(allow.flatten()) - self.custom_mask = torch.cat(mask_chunks, dim=0) if mask_chunks else torch.empty((0,), dtype=torch.bool, device=batch.device) + self.custom_mask = ( + torch.cat(mask_chunks, dim=0) + if mask_chunks + else torch.empty((0,), dtype=torch.bool, device=batch.device) + ) def generate_attn_arg_prefill( self, @@ -268,9 +285,17 @@ def verify( target_predict=target_predict, ) - candidates_cpu = candidates.cpu().tolist() - accept_len_cpu = accept_len.cpu().tolist() - bonus_cpu = bonus.cpu().tolist() + packed = torch.empty( + (bs, self.draft_token_num + 2), dtype=torch.int64, device=device + ) + packed[:, : self.draft_token_num].copy_(candidates) + packed[:, self.draft_token_num].copy_(accept_len) + packed[:, self.draft_token_num + 1].copy_(bonus) + packed_cpu = packed.cpu() + + candidates_cpu = packed_cpu[:, : self.draft_token_num].tolist() + accept_len_cpu = packed_cpu[:, self.draft_token_num].tolist() + bonus_cpu = packed_cpu[:, self.draft_token_num + 1].tolist() commit_lens_cpu: List[int] = [] new_verified_cpu: List[int] = [] diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index ed86b483f7f1..8ab3e4bb1828 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -14,7 +14,6 @@ CaptureHiddenMode, ForwardBatch, ForwardMode, - compute_position, ) from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.dflash_info import DFlashDraftInput, DFlashVerifyInput @@ -239,15 +238,10 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D prefix_lens_cpu = [int(x) for x in draft_input.draft_seq_lens_cpu] prefix_lens = torch.tensor(prefix_lens_cpu, dtype=torch.int32, device=device) - extend_lens = torch.full( - (bs,), int(self.block_size), dtype=torch.int32, device=device - ) - positions, extend_start_loc = compute_position( - self.draft_model_runner.server_args.attention_backend, - prefix_lens, - extend_lens, - bs * self.block_size, - ) + positions = ( + prefix_lens.to(torch.long).unsqueeze(1) + + torch.arange(self.block_size, device=device, dtype=torch.long)[None, :] + ).flatten() block_start = prefix_lens.to(torch.int64) block_end = block_start + int(self.block_size) @@ -280,6 +274,7 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D capture_hidden_mode=CaptureHiddenMode.NULL, ) seq_lens = prefix_lens.to(torch.int32) + seq_lens_sum = int(sum(prefix_lens_cpu)) forward_batch = ForwardBatch( forward_mode=ForwardMode.TARGET_VERIFY, batch_size=bs, @@ -287,7 +282,7 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D req_pool_indices=batch.req_pool_indices, seq_lens=seq_lens, out_cache_loc=block_cache_loc, - seq_lens_sum=int(seq_lens.sum().item()), + seq_lens_sum=seq_lens_sum, seq_lens_cpu=torch.tensor(prefix_lens_cpu, dtype=torch.int32), positions=positions, req_to_token_pool=self.draft_model_runner.req_to_token_pool, @@ -323,7 +318,20 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D positions=positions, draft_token_num=self.block_size, ) - verify_input.prepare_for_verify(batch, self.page_size) + backend_name = type(self.model_runner.attn_backend).__name__ + skip_custom_mask = backend_name in { + "FlashInferAttnBackend", + "FlashInferMLAAttnBackend", + "FlashAttentionBackend", + "TRTLLMHAAttnBackend", + "TRTLLMMLABackend", + } + build_custom_mask = not skip_custom_mask + verify_input.prepare_for_verify( + batch, + self.page_size, + build_custom_mask=build_custom_mask, + ) batch.forward_mode = ForwardMode.TARGET_VERIFY if not batch.forward_mode.is_idle() else ForwardMode.IDLE batch.spec_info = verify_input @@ -484,16 +492,20 @@ def _append_target_hidden_to_draft_kv( return req_to_token = self.draft_model_runner.req_to_token_pool.req_to_token - req_pool_indices_cpu = batch.req_pool_indices.tolist() + + req_pool_indices = batch.req_pool_indices + if req_pool_indices.dtype != torch.int64: + req_pool_indices = req_pool_indices.to(torch.int64) ctx_cache_loc_chunks: List[torch.Tensor] = [] ctx_positions_chunks: List[torch.Tensor] = [] new_draft_seq_lens_cpu: List[int] = [] - for req_pool_idx, cache_len, ctx_len in zip( - req_pool_indices_cpu, - draft_input.draft_seq_lens_cpu, - draft_input.ctx_lens_cpu, - strict=True, + for i, (cache_len, ctx_len) in enumerate( + zip( + draft_input.draft_seq_lens_cpu, + draft_input.ctx_lens_cpu, + strict=True, + ) ): cache_len_i = int(cache_len) ctx_len_i = int(ctx_len) @@ -502,8 +514,11 @@ def _append_target_hidden_to_draft_kv( continue s = cache_len_i e = cache_len_i + ctx_len_i + req_pool_idx = req_pool_indices[i] ctx_cache_loc_chunks.append(req_to_token[req_pool_idx, s:e].to(torch.int64)) - ctx_positions_chunks.append(torch.arange(s, e, device=device, dtype=torch.int64)) + ctx_positions_chunks.append( + torch.arange(s, e, device=device, dtype=torch.int64) + ) ctx_cache_loc = ( torch.cat(ctx_cache_loc_chunks, dim=0) From fcc9bf72926cca6a4afd9d748f264c47b58fd35c Mon Sep 17 00:00:00 2001 From: David Wang Date: Mon, 12 Jan 2026 05:32:07 +0000 Subject: [PATCH 23/73] skip Q, fused mlp --- python/sglang/srt/models/dflash.py | 107 +++++++++++++++--- .../sglang/srt/speculative/dflash_worker.py | 15 +-- 2 files changed, 92 insertions(+), 30 deletions(-) diff --git a/python/sglang/srt/models/dflash.py b/python/sglang/srt/models/dflash.py index 396feeaa64fc..60c9cb540aab 100644 --- a/python/sglang/srt/models/dflash.py +++ b/python/sglang/srt/models/dflash.py @@ -13,8 +13,14 @@ from torch import nn from sglang.srt.distributed import get_tensor_model_parallel_world_size -from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.radix_attention import AttentionType, RadixAttention from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -123,21 +129,72 @@ def forward( output, _ = self.o_proj(attn_output) return output + def kv_proj_only(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Project hidden_states to K/V only (skip Q). + + This is used by DFlash to materialize ctx tokens into the draft KV cache: + we only need K/V for the cached tokens; Q is never consumed. + """ + # Fast path for unquantized weights: slice the fused QKV weight and run one GEMM. + if isinstance(getattr(self.qkv_proj, "quant_method", None), UnquantizedLinearMethod): + kv_slice = slice(self.q_size, self.q_size + 2 * self.kv_size) + weight = self.qkv_proj.weight[kv_slice] + bias = self.qkv_proj.bias[kv_slice] if self.qkv_proj.bias is not None else None + kv = F.linear(hidden_states, weight, bias) + k, v = kv.split([self.kv_size, self.kv_size], dim=-1) + return k, v + + # Fallback: compute full QKV and discard Q (keeps compatibility with quantized weights). + qkv, _ = self.qkv_proj(hidden_states) + _, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + return k, v + + def apply_k_norm(self, k: torch.Tensor) -> torch.Tensor: + k_by_head = k.reshape(-1, self.head_dim) + k_by_head = self.k_norm(k_by_head) + return k_by_head.view_as(k) + + def apply_k_rope(self, positions: torch.Tensor, k: torch.Tensor) -> torch.Tensor: + # Use a minimal dummy query (1 head) to avoid doing full-Q work. + dummy_q = k.new_empty((k.shape[0], self.head_dim)) + _, k = self.rotary_emb(positions, dummy_q, k) + return k + class DFlashMLP(nn.Module): - def __init__(self, config) -> None: + def __init__(self, config, quant_config=None, prefix: str = "") -> None: super().__init__() hidden_size = int(config.hidden_size) intermediate_size = int(getattr(config, "intermediate_size", 0)) if intermediate_size <= 0: raise ValueError(f"Invalid intermediate_size={intermediate_size} for DFlash MLP.") - self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) - self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix="gate_up_proj" if not prefix else f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix="down_proj" if not prefix else f"{prefix}.down_proj", + ) + hidden_act = getattr(config, "hidden_act", "silu") + if hidden_act != "silu": + raise ValueError( + f"Unsupported DFlash activation: {hidden_act}. Only silu is supported for now." + ) + self.act_fn = SiluAndMul() def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x class DFlashDecoderLayer(nn.Module): @@ -156,21 +213,29 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, forward_batch: ForwardBatch, - ) -> torch.Tensor: - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.self_attn( + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + if hidden_states.numel() == 0: + # Keep return types consistent for upstream callers. + if residual is None: + residual = hidden_states + return hidden_states, residual + + # Pre-norm attention with fused residual+norm when possible (Qwen3-style). + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + attn_out = self.self_attn( positions=positions, hidden_states=hidden_states, forward_batch=forward_batch, ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, residual = self.post_attention_layernorm(attn_out, residual) hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - return hidden_states + return hidden_states, residual class DFlashDraftModel(nn.Module): @@ -271,13 +336,17 @@ def forward( "DFlashDraftModel requires `input_embeds` (use the target embedding)." ) hidden_states = input_embeds + residual: Optional[torch.Tensor] = None for layer in self.layers: - hidden_states = layer(positions, hidden_states, forward_batch) + hidden_states, residual = layer(positions, hidden_states, forward_batch, residual) if hidden_states.numel() == 0: return hidden_states - return self.norm(hidden_states) + if residual is None: + return self.norm(hidden_states) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -285,6 +354,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), ] params_dict = dict(self.named_parameters()) diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 8ab3e4bb1828..2ff2faf388b0 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -9,7 +9,6 @@ from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch from sglang.srt.managers.scheduler import GenerationBatchResult from sglang.srt.managers.tp_worker import TpModelWorker -from sglang.srt.models.utils import apply_qk_norm from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, ForwardBatch, @@ -539,17 +538,9 @@ def _append_target_hidden_to_draft_kv( for layer in self.draft_model.layers: attn = layer.self_attn - qkv, _ = attn.qkv_proj(ctx_hidden) - q, k, v = qkv.split([attn.q_size, attn.kv_size, attn.kv_size], dim=-1) - - q, k = apply_qk_norm( - q=q, - k=k, - q_norm=attn.q_norm, - k_norm=attn.k_norm, - head_dim=attn.head_dim, - ) - q, k = attn.rotary_emb(ctx_positions, q, k) + k, v = attn.kv_proj_only(ctx_hidden) + k = attn.apply_k_norm(k) + k = attn.apply_k_rope(ctx_positions, k) k = k.view(-1, attn.num_kv_heads, attn.head_dim) v = v.view(-1, attn.num_kv_heads, attn.head_dim) self.draft_model_runner.token_to_kv_pool.set_kv_buffer( From a79264fa6eb81b744b835396a77e8ab7bacbcbc4 Mon Sep 17 00:00:00 2001 From: David Wang Date: Mon, 12 Jan 2026 05:59:46 +0000 Subject: [PATCH 24/73] reuse buffers for decode --- .../sglang/srt/speculative/dflash_worker.py | 102 ++++++++++++------ 1 file changed, 71 insertions(+), 31 deletions(-) diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 2ff2faf388b0..2aff5419ce39 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -129,6 +129,44 @@ def __init__( self._mask_token_id, ) + self._block_pos_offsets = torch.arange( + self.block_size, device=self.device, dtype=torch.int64 + ) + self._draft_block_ids_buf: Optional[torch.Tensor] = None # [cap_bs, block_size] + self._draft_block_positions_buf: Optional[torch.Tensor] = None # [cap_bs, block_size] + self._draft_block_tokens_buf: Optional[torch.Tensor] = None # [cap_bs, block_size] + self._draft_block_end_buf: Optional[torch.Tensor] = None # [cap_bs] + self._draft_seq_lens_cpu_buf: Optional[torch.Tensor] = None # [cap_bs] on CPU + self._draft_block_spec_info = DFlashVerifyInput( + draft_token=torch.empty((0,), dtype=torch.long, device=self.device), + positions=torch.empty((0,), dtype=torch.int64, device=self.device), + draft_token_num=int(self.block_size), + custom_mask=None, + capture_hidden_mode=CaptureHiddenMode.NULL, + ) + + def _ensure_draft_block_buffers(self, bs: int) -> None: + cap = 0 if self._draft_block_ids_buf is None else int(self._draft_block_ids_buf.shape[0]) + if cap >= int(bs): + return + + new_cap = max(int(bs), cap * 2 if cap > 0 else int(bs)) + device = self.device + block_size = int(self.block_size) + self._draft_block_ids_buf = torch.empty( + (new_cap, block_size), dtype=torch.long, device=device + ) + self._draft_block_positions_buf = torch.empty( + (new_cap, block_size), dtype=torch.int64, device=device + ) + self._draft_block_tokens_buf = torch.empty( + (new_cap, block_size), dtype=torch.long, device=device + ) + self._draft_block_end_buf = torch.empty( + (new_cap,), dtype=torch.int32, device=device + ) + self._draft_seq_lens_cpu_buf = torch.empty((new_cap,), dtype=torch.int32, device="cpu") + def __getattr__(self, name): # Delegate anything not implemented yet to the target worker. return getattr(self.target_worker, name) @@ -224,26 +262,35 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D ) # --- 2) Draft a non-causal block with the draft model. - block_ids = torch.full( - (bs, self.block_size), - self._mask_token_id, - dtype=torch.long, - device=device, - ) - block_ids[:, 0] = draft_input.verified_id.to(torch.long) + self._ensure_draft_block_buffers(bs) + assert self._draft_block_ids_buf is not None + assert self._draft_block_positions_buf is not None + assert self._draft_block_tokens_buf is not None + assert self._draft_block_end_buf is not None + assert self._draft_seq_lens_cpu_buf is not None + + block_ids = self._draft_block_ids_buf[:bs] + block_ids.fill_(int(self._mask_token_id)) + block_ids[:, 0].copy_(draft_input.verified_id.to(torch.long)) noise_embedding = embed_module(block_ids) input_embeds = noise_embedding.view(-1, noise_embedding.shape[-1]) - prefix_lens_cpu = [int(x) for x in draft_input.draft_seq_lens_cpu] - prefix_lens = torch.tensor(prefix_lens_cpu, dtype=torch.int32, device=device) - positions = ( - prefix_lens.to(torch.long).unsqueeze(1) - + torch.arange(self.block_size, device=device, dtype=torch.long)[None, :] - ).flatten() + # For spec-v1, the draft KV cache is always materialized to the current target + # prefix before drafting the next block. + prefix_lens = batch.seq_lens # int32, device - block_start = prefix_lens.to(torch.int64) - block_end = block_start + int(self.block_size) + positions_2d = self._draft_block_positions_buf[:bs] + torch.add(prefix_lens.unsqueeze(1), self._block_pos_offsets, out=positions_2d) + positions = positions_2d.reshape(-1) + + block_start = prefix_lens + block_end = self._draft_block_end_buf[:bs] + torch.add(block_start, int(self.block_size), out=block_end) + + seq_lens_cpu = self._draft_seq_lens_cpu_buf[:bs] + for i, ln in enumerate(batch.seq_lens_cpu): + seq_lens_cpu[i] = int(ln) allocator = self.draft_model_runner.token_to_kv_pool_allocator token_to_kv_pool_state_backup = allocator.backup_state() try: @@ -265,15 +312,9 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D # Use TARGET_VERIFY mode (cuda-graphable) to run a fixed-size draft block. # In this mode, `seq_lens` stores the prefix lengths; attention backends # derive kv_len by adding `draft_token_num`. - draft_spec_info = DFlashVerifyInput( - draft_token=torch.empty((0,), dtype=torch.long, device=device), - positions=torch.empty((0,), dtype=torch.int64, device=device), - draft_token_num=int(self.block_size), - custom_mask=None, - capture_hidden_mode=CaptureHiddenMode.NULL, - ) - seq_lens = prefix_lens.to(torch.int32) - seq_lens_sum = int(sum(prefix_lens_cpu)) + draft_spec_info = self._draft_block_spec_info + seq_lens = prefix_lens + seq_lens_sum = int(batch.seq_lens_sum) forward_batch = ForwardBatch( forward_mode=ForwardMode.TARGET_VERIFY, batch_size=bs, @@ -282,7 +323,7 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D seq_lens=seq_lens, out_cache_loc=block_cache_loc, seq_lens_sum=seq_lens_sum, - seq_lens_cpu=torch.tensor(prefix_lens_cpu, dtype=torch.int32), + seq_lens_cpu=seq_lens_cpu, positions=positions, req_to_token_pool=self.draft_model_runner.req_to_token_pool, token_to_kv_pool=self.draft_model_runner.token_to_kv_pool, @@ -306,14 +347,13 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D hidden_states=draft_hidden[:, 1:, :].reshape(-1, draft_hidden.shape[-1]), lm_head=lm_head, ).view(bs, self.block_size - 1) - draft_tokens = torch.cat([block_ids[:, :1], draft_next], dim=1) # [bs, block_size] - positions = ( - batch.seq_lens.to(torch.long).unsqueeze(1) - + torch.arange(self.block_size, device=device, dtype=torch.long)[None, :] - ).flatten() + draft_tokens = self._draft_block_tokens_buf[:bs] + draft_tokens[:, 0].copy_(block_ids[:, 0]) + draft_tokens[:, 1:].copy_(draft_next) + positions = positions_2d.reshape(-1) verify_input = DFlashVerifyInput( - draft_token=draft_tokens.flatten(), + draft_token=draft_tokens.reshape(-1), positions=positions, draft_token_num=self.block_size, ) From ad5adbfba99a279604659497d1ba779f139031da Mon Sep 17 00:00:00 2001 From: David Wang Date: Mon, 12 Jan 2026 06:10:23 +0000 Subject: [PATCH 25/73] optimize greedy sampling --- .../sglang/srt/speculative/dflash_worker.py | 44 ++++++++++++++----- 1 file changed, 32 insertions(+), 12 deletions(-) diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 2aff5419ce39..7a19bcf27455 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -289,8 +289,10 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D torch.add(block_start, int(self.block_size), out=block_end) seq_lens_cpu = self._draft_seq_lens_cpu_buf[:bs] - for i, ln in enumerate(batch.seq_lens_cpu): - seq_lens_cpu[i] = int(ln) + if batch.seq_lens_cpu.dtype == torch.int32: + seq_lens_cpu.copy_(batch.seq_lens_cpu) + else: + seq_lens_cpu.copy_(batch.seq_lens_cpu.to(torch.int32)) allocator = self.draft_model_runner.token_to_kv_pool_allocator token_to_kv_pool_state_backup = allocator.backup_state() try: @@ -419,9 +421,27 @@ def _greedy_sample_from_vocab_parallel_head( (num_tokens,), dtype=torch.long, device=hidden_states.device ) + def _cast_hs(x: torch.Tensor) -> torch.Tensor: + return x if x.dtype == weight_dtype else x.to(weight_dtype) + + # Fast path (common): single-rank greedy sampling over the base vocab shard. + # Avoids extra max/id bookkeeping that is only needed for TP sync or added vocab. + if tp_size == 1 and num_added == 0: + for start in range(0, num_tokens, int(chunk_size)): + end = min(num_tokens, start + int(chunk_size)) + hs = _cast_hs(hidden_states[start:end]) + if num_org > 0: + base_logits = torch.matmul(hs, weight[:num_org].T) + out_token_ids[start:end] = ( + torch.argmax(base_logits, dim=-1).to(torch.long) + org_vocab_start + ) + else: + out_token_ids[start:end] = 0 + return out_token_ids + for start in range(0, num_tokens, int(chunk_size)): end = min(num_tokens, start + int(chunk_size)) - hs = hidden_states[start:end].to(weight_dtype) + hs = _cast_hs(hidden_states[start:end]) chunk_len = int(hs.shape[0]) # Base vocab logits. @@ -447,17 +467,17 @@ def _greedy_sample_from_vocab_parallel_head( added_max, added_arg = torch.max(added_logits, dim=-1) use_added = added_max > local_max local_max = torch.where(use_added, added_max, local_max) - local_arg = torch.where( - use_added, added_arg.to(local_arg.dtype) + num_org_padded, local_arg - ) + # For base/added conversion below, keep local_arg expressed in the full local + # weight index space (base + padding + added), matching `lm_head.weight`. + local_arg = torch.where(use_added, added_arg.to(local_arg.dtype) + num_org_padded, local_arg) # Convert local argmax indices to global token ids. - global_ids = torch.empty( - (chunk_len,), dtype=torch.int64, device=hs.device - ) - is_base = local_arg < num_org - global_ids[is_base] = org_vocab_start + local_arg[is_base] - if num_added > 0: + if num_added == 0: + global_ids = org_vocab_start + local_arg + else: + global_ids = torch.empty((chunk_len,), dtype=torch.int64, device=hs.device) + is_base = local_arg < num_org + global_ids[is_base] = org_vocab_start + local_arg[is_base] global_ids[~is_base] = added_vocab_start + (local_arg[~is_base] - num_org_padded) if tp_size == 1: From 37fc3f1f83f918fb50a5282a77ef08103876fc19 Mon Sep 17 00:00:00 2001 From: David Wang Date: Mon, 12 Jan 2026 06:22:27 +0000 Subject: [PATCH 26/73] preallocate for tp>1 --- .../sglang/srt/speculative/dflash_worker.py | 34 +++++++++++++------ 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 7a19bcf27455..a99948002232 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -144,6 +144,9 @@ def __init__( custom_mask=None, capture_hidden_mode=CaptureHiddenMode.NULL, ) + self._draft_greedy_gathered_max_buf: Optional[torch.Tensor] = None + self._draft_greedy_gathered_ids_buf: Optional[torch.Tensor] = None + self._draft_greedy_gather_cap: int = 0 def _ensure_draft_block_buffers(self, bs: int) -> None: cap = 0 if self._draft_block_ids_buf is None else int(self._draft_block_ids_buf.shape[0]) @@ -485,16 +488,27 @@ def _cast_hs(x: torch.Tensor) -> torch.Tensor: continue # Gather per-rank maxima and associated global ids, then select the global max. - gathered_max = torch.empty( - (tp_size * chunk_len,), - dtype=local_max.dtype, - device=hs.device, - ) - gathered_ids = torch.empty( - (tp_size * chunk_len,), - dtype=global_ids.dtype, - device=hs.device, - ) + needed = tp_size * chunk_len + if ( + self._draft_greedy_gather_cap < needed + or self._draft_greedy_gathered_max_buf is None + or self._draft_greedy_gathered_ids_buf is None + or self._draft_greedy_gathered_max_buf.dtype != local_max.dtype + or self._draft_greedy_gathered_max_buf.device != hs.device + ): + # Allocate enough space for the max chunk size to avoid reallocations. + cap = tp_size * int(chunk_size) + self._draft_greedy_gathered_max_buf = torch.empty( + (cap,), dtype=local_max.dtype, device=hs.device + ) + self._draft_greedy_gathered_ids_buf = torch.empty( + (cap,), dtype=global_ids.dtype, device=hs.device + ) + self._draft_greedy_gather_cap = cap + + gathered_max = self._draft_greedy_gathered_max_buf[:needed] + gathered_ids = self._draft_greedy_gathered_ids_buf[:needed] + tp_group.all_gather_into_tensor(gathered_max, local_max.contiguous()) tp_group.all_gather_into_tensor(gathered_ids, global_ids.contiguous()) gathered_max = gathered_max.view(tp_size, chunk_len) From 72cbd9dad8a4f51af6600be14f93544675bfc2c5 Mon Sep 17 00:00:00 2001 From: David Wang Date: Mon, 12 Jan 2026 06:56:25 +0000 Subject: [PATCH 27/73] more buffers for tp>1 --- .../sglang/srt/speculative/dflash_worker.py | 40 ++++++++++++++++--- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index a99948002232..75b6e03c224e 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -147,6 +147,10 @@ def __init__( self._draft_greedy_gathered_max_buf: Optional[torch.Tensor] = None self._draft_greedy_gathered_ids_buf: Optional[torch.Tensor] = None self._draft_greedy_gather_cap: int = 0 + self._draft_greedy_best_rank_buf: Optional[torch.Tensor] = None + self._draft_greedy_rank_index_buf: Optional[torch.Tensor] = None + self._draft_greedy_selected_ids_buf: Optional[torch.Tensor] = None + self._draft_greedy_index_cap: int = 0 def _ensure_draft_block_buffers(self, bs: int) -> None: cap = 0 if self._draft_block_ids_buf is None else int(self._draft_block_ids_buf.shape[0]) @@ -476,7 +480,8 @@ def _cast_hs(x: torch.Tensor) -> torch.Tensor: # Convert local argmax indices to global token ids. if num_added == 0: - global_ids = org_vocab_start + local_arg + local_arg.add_(org_vocab_start) + global_ids = local_arg else: global_ids = torch.empty((chunk_len,), dtype=torch.int64, device=hs.device) is_base = local_arg < num_org @@ -489,6 +494,7 @@ def _cast_hs(x: torch.Tensor) -> torch.Tensor: # Gather per-rank maxima and associated global ids, then select the global max. needed = tp_size * chunk_len + chunk_cap = int(chunk_size) if ( self._draft_greedy_gather_cap < needed or self._draft_greedy_gathered_max_buf is None @@ -497,7 +503,7 @@ def _cast_hs(x: torch.Tensor) -> torch.Tensor: or self._draft_greedy_gathered_max_buf.device != hs.device ): # Allocate enough space for the max chunk size to avoid reallocations. - cap = tp_size * int(chunk_size) + cap = tp_size * chunk_cap self._draft_greedy_gathered_max_buf = torch.empty( (cap,), dtype=local_max.dtype, device=hs.device ) @@ -506,6 +512,25 @@ def _cast_hs(x: torch.Tensor) -> torch.Tensor: ) self._draft_greedy_gather_cap = cap + if ( + self._draft_greedy_index_cap < chunk_len + or self._draft_greedy_best_rank_buf is None + or self._draft_greedy_rank_index_buf is None + or self._draft_greedy_selected_ids_buf is None + or self._draft_greedy_best_rank_buf.device != hs.device + or self._draft_greedy_selected_ids_buf.device != hs.device + ): + self._draft_greedy_best_rank_buf = torch.empty( + (chunk_cap,), dtype=torch.int64, device=hs.device + ) + self._draft_greedy_rank_index_buf = torch.empty( + (1, chunk_cap), dtype=torch.int64, device=hs.device + ) + self._draft_greedy_selected_ids_buf = torch.empty( + (1, chunk_cap), dtype=torch.int64, device=hs.device + ) + self._draft_greedy_index_cap = chunk_cap + gathered_max = self._draft_greedy_gathered_max_buf[:needed] gathered_ids = self._draft_greedy_gathered_ids_buf[:needed] @@ -514,9 +539,14 @@ def _cast_hs(x: torch.Tensor) -> torch.Tensor: gathered_max = gathered_max.view(tp_size, chunk_len) gathered_ids = gathered_ids.view(tp_size, chunk_len) - best_rank = torch.argmax(gathered_max, dim=0) - idx = torch.arange(chunk_len, device=hs.device) - out_token_ids[start:end] = gathered_ids[best_rank, idx].to(torch.long) + best_rank = self._draft_greedy_best_rank_buf[:chunk_len] + torch.argmax(gathered_max, dim=0, out=best_rank) + + rank_index = self._draft_greedy_rank_index_buf[:, :chunk_len] + rank_index[0].copy_(best_rank) + selected_ids = self._draft_greedy_selected_ids_buf[:, :chunk_len] + torch.gather(gathered_ids, 0, rank_index, out=selected_ids) + out_token_ids[start:end].copy_(selected_ids.view(-1)) return out_token_ids From 5a577a3b2519e5d82c730e51888cb25d83698e2c Mon Sep 17 00:00:00 2001 From: David Wang Date: Mon, 12 Jan 2026 08:32:51 +0000 Subject: [PATCH 28/73] dflash gsm8k benchmark sweep --- benchmark/dflash/bench_dflash_gsm8k_sweep.py | 611 +++++++++++++++++++ 1 file changed, 611 insertions(+) create mode 100644 benchmark/dflash/bench_dflash_gsm8k_sweep.py diff --git a/benchmark/dflash/bench_dflash_gsm8k_sweep.py b/benchmark/dflash/bench_dflash_gsm8k_sweep.py new file mode 100644 index 000000000000..1a2af75ffd8c --- /dev/null +++ b/benchmark/dflash/bench_dflash_gsm8k_sweep.py @@ -0,0 +1,611 @@ +"""DFLASH vs baseline GSM8K sweep. + +This is a *benchmark script* (not a CI test): it can take a long time because it +launches servers for multiple (attention_backend, tp_size) configs and runs a +GSM8K workload for each (concurrency, num_questions) setting. + +Example usage: + ./venv/bin/python benchmark/gsm8k/bench_dflash_gsm8k_sweep.py --output-md dflash_gsm8k_sweep.md +""" + +from __future__ import annotations + +import argparse +import ast +import os +import re +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from typing import Optional + +import requests +import torch +from transformers import AutoTokenizer + +from sglang.srt.environ import envs +from sglang.srt.utils import get_device_sm, kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + find_available_port, + popen_launch_server, +) +from sglang.utils import download_and_cache_file, read_jsonl + +INVALID = -9999999 + + +def _is_blackwell() -> bool: + # Prefer explicit env var, but also infer from compute capability (SM100+). + if envs.IS_BLACKWELL.get(): + return True + return get_device_sm() >= 100 + + +def _get_one_example(lines, i: int, include_answer: bool) -> str: + ret = "Question: " + lines[i]["question"] + "\nAnswer:" + if include_answer: + ret += " " + lines[i]["answer"] + return ret + + +def _get_few_shot_examples(lines, k: int) -> str: + ret = "" + for i in range(k): + ret += _get_one_example(lines, i, True) + "\n\n" + return ret + + +def _get_answer_value(answer_str: str) -> int: + answer_str = answer_str.replace(",", "") + numbers = re.findall(r"\d+", answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +def _maybe_download_gsm8k(data_path: str) -> str: + url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + if os.path.isfile(data_path): + return data_path + return download_and_cache_file(url) + + +def _flush_cache(base_url: str) -> None: + resp = requests.get(base_url + "/flush_cache", timeout=60) + resp.raise_for_status() + + +def _send_generate( + base_url: str, + prompt: str, + *, + max_new_tokens: int, + stop: list[str], + timeout_s: int, +) -> dict: + sampling_params: dict = { + "temperature": 0.0, + "max_new_tokens": int(max_new_tokens), + } + if stop: + sampling_params["stop"] = stop + resp = requests.post( + base_url + "/generate", + json={ + "text": prompt, + "sampling_params": sampling_params, + }, + timeout=int(timeout_s), + ) + resp.raise_for_status() + return resp.json() + + +@dataclass(frozen=True) +class BenchMetrics: + latency_s: float + output_tokens: int + output_toks_per_s: float + accuracy: Optional[float] + invalid_rate: Optional[float] + spec_accept_length: Optional[float] + spec_verify_ct_sum: int + + +def _run_gsm8k_requests( + base_url: str, + *, + prompts: list[str], + labels: Optional[list[int]], + max_new_tokens: int, + concurrency: int, + stop: list[str], + timeout_s: int, + expect_dflash: bool, +) -> BenchMetrics: + if labels is not None and len(labels) != len(prompts): + raise ValueError("labels length must match prompts length") + + start = time.perf_counter() + total_tokens = 0 + spec_verify_ct_sum = 0 + correct = 0 + invalid = 0 + + with ThreadPoolExecutor(max_workers=int(concurrency)) as pool: + futures = { + pool.submit( + _send_generate, + base_url, + prompt, + max_new_tokens=max_new_tokens, + stop=stop, + timeout_s=timeout_s, + ): i + for i, prompt in enumerate(prompts) + } + for fut in as_completed(futures): + i = futures[fut] + out = fut.result() + meta = out.get("meta_info", {}) or {} + total_tokens += int(meta.get("completion_tokens", 0)) + spec_verify_ct_sum += int(meta.get("spec_verify_ct", 0)) + + if labels is not None: + pred = _get_answer_value(out.get("text", "")) + if pred == INVALID: + invalid += 1 + if pred == labels[i]: + correct += 1 + + latency = time.perf_counter() - start + toks_per_s = total_tokens / max(latency, 1e-6) + + if expect_dflash and spec_verify_ct_sum <= 0: + raise RuntimeError( + "DFLASH sanity check failed: did not observe any `spec_verify_ct` in responses " + "(DFLASH may not have been enabled)." + ) + + spec_accept_length = ( + (total_tokens / spec_verify_ct_sum) if spec_verify_ct_sum > 0 else None + ) + + if labels is None: + acc = None + invalid_rate = None + else: + acc = correct / max(len(prompts), 1) + invalid_rate = invalid / max(len(prompts), 1) + + return BenchMetrics( + latency_s=float(latency), + output_tokens=int(total_tokens), + output_toks_per_s=float(toks_per_s), + accuracy=acc, + invalid_rate=invalid_rate, + spec_accept_length=spec_accept_length, + spec_verify_ct_sum=int(spec_verify_ct_sum), + ) + + +def _format_table( + *, + tp_sizes: list[int], + concurrencies: list[int], + values: dict[tuple[int, int], Optional[float]], + float_fmt: str, +) -> str: + header = ["tp\\conc"] + [str(c) for c in concurrencies] + lines = [ + "| " + " | ".join(header) + " |", + "| " + " | ".join(["---"] * len(header)) + " |", + ] + for tp in tp_sizes: + row = [str(tp)] + for c in concurrencies: + v = values.get((tp, c), None) + row.append("N/A" if v is None else format(v, float_fmt)) + lines.append("| " + " | ".join(row) + " |") + return "\n".join(lines) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--output-md", type=str, default="dflash_gsm8k_sweep.md") + parser.add_argument("--data-path", type=str, default="test.jsonl") + parser.add_argument("--target-model", type=str, default="Qwen/Qwen3-8B") + parser.add_argument("--draft-model", type=str, default="z-lab/Qwen3-8B-DFlash-b16") + parser.add_argument( + "--prompt-style", + type=str, + choices=["fewshot_qa", "chat"], + default="chat", + help="Prompting style: 'chat' matches the DFlash HF demo prompt.", + ) + parser.add_argument("--num-shots", type=int, default=0) + parser.add_argument("--max-new-tokens", type=int, default=2048) + parser.add_argument("--timeout-s", type=int, default=3600) + parser.add_argument("--mem-fraction-static", type=float, default=0.75) + parser.add_argument("--disable-radix-cache", action="store_true") + parser.add_argument("--dtype", type=str, default="bfloat16") + parser.add_argument("--chunked-prefill-size", type=int, default=1024) + parser.add_argument("--max-running-requests", type=int, default=64) + parser.add_argument( + "--tp-sizes", + type=str, + default="1,2,4,8", + help="Comma-separated list, filtered by visible CUDA devices.", + ) + parser.add_argument( + "--concurrencies", + type=str, + default="1,2,4,8,16,32", + help="Comma-separated list of client concurrency levels.", + ) + parser.add_argument( + "--questions-per-concurrency-base", + type=int, + default=128, + help="num_questions = base * concurrency (default matches the sweep plan).", + ) + parser.add_argument( + "--max-questions-per-config", + type=int, + default=1024, + help="Cap num_questions per (tp, concurrency) run (default: 1024).", + ) + parser.add_argument( + "--attention-backends", + type=str, + default="flashinfer,fa3", + help="Comma-separated list. Will auto-skip fa3 on Blackwell/SM<90.", + ) + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for this sweep.") + + visible_gpus = int(torch.cuda.device_count()) + tp_sizes = [int(x) for x in args.tp_sizes.split(",") if x.strip()] + tp_sizes = [tp for tp in tp_sizes if tp >= 1 and tp <= visible_gpus] + if not tp_sizes: + raise RuntimeError( + f"No tp sizes are runnable with visible_gpus={visible_gpus}. " + "Set CUDA_VISIBLE_DEVICES accordingly." + ) + + concurrencies = [int(x) for x in args.concurrencies.split(",") if x.strip()] + concurrencies = [c for c in concurrencies if c >= 1] + if not concurrencies: + raise RuntimeError("No concurrencies specified.") + + num_questions_by_conc = { + c: min(int(args.questions_per_concurrency_base) * int(c), int(args.max_questions_per_config)) + for c in concurrencies + } + max_questions = max(num_questions_by_conc.values()) + + attention_backends = [s.strip() for s in args.attention_backends.split(",") if s.strip()] + is_blackwell = _is_blackwell() + device_sm = get_device_sm() + if is_blackwell: + attention_backends = [b for b in attention_backends if b == "flashinfer"] + if device_sm < 90: + attention_backends = [b for b in attention_backends if b != "fa3"] + attention_backends = attention_backends or ["flashinfer"] + + data_path = _maybe_download_gsm8k(args.data_path) + lines = list(read_jsonl(data_path)) + if len(lines) < max_questions: + raise RuntimeError( + f"GSM8K file only has {len(lines)} lines, but need {max_questions}." + ) + + tokenizer = None + if args.prompt_style == "chat": + tokenizer = AutoTokenizer.from_pretrained(args.target_model) + + few_shot = ( + _get_few_shot_examples(lines, int(args.num_shots)) + if args.prompt_style == "fewshot_qa" + else "" + ) + + prompts: list[str] = [] + labels: list[int] = [] + for i in range(max_questions): + if args.prompt_style == "fewshot_qa": + prompts.append(few_shot + _get_one_example(lines, i, False)) + else: + assert tokenizer is not None + user_content = ( + lines[i]["question"] + + "\nPlease reason step by step, and put your final answer within \\boxed{}." + ) + prompts.append( + tokenizer.apply_chat_template( + [{"role": "user", "content": user_content}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + ) + labels.append(_get_answer_value(lines[i]["answer"])) + if not all(l != INVALID for l in labels): + raise RuntimeError("Invalid labels in GSM8K data.") + + default_stop = ( + ["Question", "Assistant:", "<|separator|>"] if args.prompt_style == "fewshot_qa" else [] + ) + + # Results indexed by (backend, tp, concurrency) for baseline + dflash. + baseline_toks: dict[tuple[str, int, int], Optional[float]] = {} + dflash_toks: dict[tuple[str, int, int], Optional[float]] = {} + dflash_accept_len: dict[tuple[str, int, int], Optional[float]] = {} + baseline_acc: dict[tuple[str, int, int], Optional[float]] = {} + dflash_acc: dict[tuple[str, int, int], Optional[float]] = {} + + for backend in attention_backends: + for tp in tp_sizes: + print(f"\n=== backend={backend} tp={tp} (baseline) ===") + baseline_port = find_available_port(20000) + baseline_url = f"http://127.0.0.1:{baseline_port}" + + common_server_args: list[str] = [ + "--trust-remote-code", + "--attention-backend", + backend, + "--tp-size", + str(tp), + "--dtype", + str(args.dtype), + "--mem-fraction-static", + str(args.mem_fraction_static), + "--max-running-requests", + str(args.max_running_requests), + "--chunked-prefill-size", + str(args.chunked_prefill_size), + "--cuda-graph-bs", + "1", + "2", + "4", + "8", + "16", + "32", + "--cuda-graph-max-bs", + "32", + ] + if args.disable_radix_cache: + common_server_args.append("--disable-radix-cache") + + baseline_proc = popen_launch_server( + args.target_model, + baseline_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=common_server_args, + ) + try: + # Warm up. + _send_generate( + baseline_url, + "Hello", + max_new_tokens=8, + stop=[], + timeout_s=min(int(args.timeout_s), 300), + ) + + for conc in concurrencies: + n = num_questions_by_conc[conc] + _flush_cache(baseline_url) + metrics = _run_gsm8k_requests( + baseline_url, + prompts=prompts[:n], + labels=labels[:n], + max_new_tokens=int(args.max_new_tokens), + concurrency=int(conc), + stop=default_stop, + timeout_s=int(args.timeout_s), + expect_dflash=False, + ) + baseline_toks[(backend, tp, conc)] = metrics.output_toks_per_s + baseline_acc[(backend, tp, conc)] = metrics.accuracy + print( + f"[baseline] conc={conc:>2} n={n:<4} " + f"toks/s={metrics.output_toks_per_s:,.2f} " + f"latency={metrics.latency_s:.1f}s " + f"acc={metrics.accuracy:.3f} invalid={metrics.invalid_rate:.3f}" + ) + finally: + kill_process_tree(baseline_proc.pid) + try: + baseline_proc.wait(timeout=30) + except Exception: + pass + + print(f"\n=== backend={backend} tp={tp} (DFLASH) ===") + dflash_port = find_available_port(baseline_port + 1) + dflash_url = f"http://127.0.0.1:{dflash_port}" + dflash_proc = popen_launch_server( + args.target_model, + dflash_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + *common_server_args, + "--speculative-algorithm", + "DFLASH", + "--speculative-draft-model-path", + args.draft_model, + ], + ) + try: + _send_generate( + dflash_url, + "Hello", + max_new_tokens=8, + stop=[], + timeout_s=min(int(args.timeout_s), 300), + ) + for conc in concurrencies: + n = num_questions_by_conc[conc] + _flush_cache(dflash_url) + metrics = _run_gsm8k_requests( + dflash_url, + prompts=prompts[:n], + labels=labels[:n], + max_new_tokens=int(args.max_new_tokens), + concurrency=int(conc), + stop=default_stop, + timeout_s=int(args.timeout_s), + expect_dflash=True, + ) + dflash_toks[(backend, tp, conc)] = metrics.output_toks_per_s + dflash_accept_len[(backend, tp, conc)] = metrics.spec_accept_length + dflash_acc[(backend, tp, conc)] = metrics.accuracy + print( + f"[DFLASH] conc={conc:>2} n={n:<4} " + f"toks/s={metrics.output_toks_per_s:,.2f} " + f"latency={metrics.latency_s:.1f}s " + f"acc={metrics.accuracy:.3f} invalid={metrics.invalid_rate:.3f} " + f"accept_len={metrics.spec_accept_length:.3f} " + f"spec_verify_ct_sum={metrics.spec_verify_ct_sum}" + ) + finally: + kill_process_tree(dflash_proc.pid) + try: + dflash_proc.wait(timeout=30) + except Exception: + pass + + # Render markdown. + md_lines: list[str] = [] + md_lines.append("# DFLASH GSM8K Sweep") + md_lines.append("") + md_lines.append("## Settings") + md_lines.append(f"- target_model: `{args.target_model}`") + md_lines.append(f"- draft_model: `{args.draft_model}`") + md_lines.append(f"- prompt_style: `{args.prompt_style}`") + if args.prompt_style == "fewshot_qa": + md_lines.append(f"- num_shots: `{args.num_shots}`") + md_lines.append(f"- max_new_tokens: `{args.max_new_tokens}`") + md_lines.append(f"- attention_backends: `{', '.join(attention_backends)}`") + md_lines.append(f"- tp_sizes: `{', '.join(str(x) for x in tp_sizes)}`") + md_lines.append(f"- concurrencies: `{', '.join(str(x) for x in concurrencies)}`") + md_lines.append(f"- questions_per_concurrency: `base={args.questions_per_concurrency_base}`") + md_lines.append(f"- device_sm: `{device_sm}`") + md_lines.append(f"- is_blackwell: `{is_blackwell}`") + md_lines.append("") + md_lines.append( + "Note: DFLASH and baseline greedy outputs may diverge on some prompts due to numerical differences " + "(e.g. verify path vs decode path). This sweep focuses on throughput." + ) + md_lines.append("") + + for backend in attention_backends: + md_lines.append(f"## Backend: `{backend}`") + md_lines.append("") + + baseline_values = { + (tp, conc): baseline_toks.get((backend, tp, conc), None) + for tp in tp_sizes + for conc in concurrencies + } + dflash_values = { + (tp, conc): dflash_toks.get((backend, tp, conc), None) + for tp in tp_sizes + for conc in concurrencies + } + speedup_values: dict[tuple[int, int], Optional[float]] = {} + for tp in tp_sizes: + for conc in concurrencies: + b = baseline_values.get((tp, conc), None) + d = dflash_values.get((tp, conc), None) + speedup_values[(tp, conc)] = None if (b is None or d is None or b <= 0) else (d / b) + + md_lines.append("### Baseline output tok/s") + md_lines.append( + _format_table( + tp_sizes=tp_sizes, + concurrencies=concurrencies, + values=baseline_values, + float_fmt=",.2f", + ) + ) + md_lines.append("") + md_lines.append("### Baseline accuracy") + md_lines.append( + _format_table( + tp_sizes=tp_sizes, + concurrencies=concurrencies, + values={ + (tp, conc): baseline_acc.get((backend, tp, conc), None) + for tp in tp_sizes + for conc in concurrencies + }, + float_fmt=".3f", + ) + ) + md_lines.append("") + md_lines.append("### DFLASH output tok/s") + md_lines.append( + _format_table( + tp_sizes=tp_sizes, + concurrencies=concurrencies, + values=dflash_values, + float_fmt=",.2f", + ) + ) + md_lines.append("") + md_lines.append("### DFLASH accuracy") + md_lines.append( + _format_table( + tp_sizes=tp_sizes, + concurrencies=concurrencies, + values={ + (tp, conc): dflash_acc.get((backend, tp, conc), None) + for tp in tp_sizes + for conc in concurrencies + }, + float_fmt=".3f", + ) + ) + md_lines.append("") + md_lines.append("### Speedup (DFLASH / baseline)") + md_lines.append( + _format_table( + tp_sizes=tp_sizes, + concurrencies=concurrencies, + values=speedup_values, + float_fmt=".3f", + ) + ) + md_lines.append("") + + md_lines.append("### DFLASH acceptance length (completion_tokens / spec_verify_ct)") + md_lines.append( + _format_table( + tp_sizes=tp_sizes, + concurrencies=concurrencies, + values={ + (tp, conc): dflash_accept_len.get((backend, tp, conc), None) + for tp in tp_sizes + for conc in concurrencies + }, + float_fmt=".3f", + ) + ) + md_lines.append("") + + with open(args.output_md, "w", encoding="utf-8") as f: + f.write("\n".join(md_lines)) + f.write("\n") + + print(f"\nWrote markdown report to: {args.output_md}") + + +if __name__ == "__main__": + main() From d9685322c8d9857353c50d72c730b6d1fd98169b Mon Sep 17 00:00:00 2001 From: David Wang Date: Mon, 12 Jan 2026 23:19:58 +0000 Subject: [PATCH 29/73] fix benchmark --- benchmark/dflash/bench_dflash_gsm8k_sweep.py | 28 ++++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/benchmark/dflash/bench_dflash_gsm8k_sweep.py b/benchmark/dflash/bench_dflash_gsm8k_sweep.py index 1a2af75ffd8c..b98d566ff9d2 100644 --- a/benchmark/dflash/bench_dflash_gsm8k_sweep.py +++ b/benchmark/dflash/bench_dflash_gsm8k_sweep.py @@ -15,6 +15,7 @@ import os import re import time +import statistics from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass from typing import Optional @@ -89,6 +90,8 @@ def _send_generate( ) -> dict: sampling_params: dict = { "temperature": 0.0, + "top_p": 1.0, + "top_k": 1, "max_new_tokens": int(max_new_tokens), } if stop: @@ -133,6 +136,7 @@ def _run_gsm8k_requests( start = time.perf_counter() total_tokens = 0 spec_verify_ct_sum = 0 + spec_accept_lengths: list[float] = [] correct = 0 invalid = 0 @@ -154,6 +158,11 @@ def _run_gsm8k_requests( meta = out.get("meta_info", {}) or {} total_tokens += int(meta.get("completion_tokens", 0)) spec_verify_ct_sum += int(meta.get("spec_verify_ct", 0)) + if "spec_accept_length" in meta: + try: + spec_accept_lengths.append(float(meta["spec_accept_length"])) + except (TypeError, ValueError): + pass if labels is not None: pred = _get_answer_value(out.get("text", "")) @@ -172,7 +181,7 @@ def _run_gsm8k_requests( ) spec_accept_length = ( - (total_tokens / spec_verify_ct_sum) if spec_verify_ct_sum > 0 else None + float(statistics.mean(spec_accept_lengths)) if spec_accept_lengths else None ) if labels is None: @@ -233,7 +242,6 @@ def main() -> None: parser.add_argument("--mem-fraction-static", type=float, default=0.75) parser.add_argument("--disable-radix-cache", action="store_true") parser.add_argument("--dtype", type=str, default="bfloat16") - parser.add_argument("--chunked-prefill-size", type=int, default=1024) parser.add_argument("--max-running-requests", type=int, default=64) parser.add_argument( "--tp-sizes", @@ -368,18 +376,10 @@ def main() -> None: str(args.mem_fraction_static), "--max-running-requests", str(args.max_running_requests), - "--chunked-prefill-size", - str(args.chunked_prefill_size), - "--cuda-graph-bs", - "1", - "2", - "4", - "8", - "16", - "32", - "--cuda-graph-max-bs", - "32", ] + common_server_args.extend( + ["--cuda-graph-bs", *[str(i) for i in range(1, 33)], "--cuda-graph-max-bs", "32"] + ) if args.disable_radix_cache: common_server_args.append("--disable-radix-cache") @@ -585,7 +585,7 @@ def main() -> None: ) md_lines.append("") - md_lines.append("### DFLASH acceptance length (completion_tokens / spec_verify_ct)") + md_lines.append("### DFLASH acceptance length (mean per-request spec_accept_length)") md_lines.append( _format_table( tp_sizes=tp_sizes, From 3e4177d97d04f3eb4a3254b91665e9f31483a40f Mon Sep 17 00:00:00 2001 From: David Wang Date: Tue, 20 Jan 2026 06:11:46 +0000 Subject: [PATCH 30/73] use device tensors for ctx_lens/draft_seq_lens, vectorize kv append and filter_batch, use cpu tensors for seq_lens_sum, only clear padded buffer regions, add --skip-baseline and --batch-requests to benchmark --- benchmark/dflash/bench_dflash_gsm8k_sweep.py | 245 ++++++++++++------ python/sglang/srt/managers/schedule_batch.py | 6 +- .../srt/model_executor/cuda_graph_runner.py | 3 +- .../srt/model_executor/input_buffers.py | 12 +- .../sglang/srt/model_executor/model_runner.py | 8 +- python/sglang/srt/speculative/dflash_info.py | 142 +++++++--- .../sglang/srt/speculative/dflash_worker.py | 123 ++++----- 7 files changed, 355 insertions(+), 184 deletions(-) diff --git a/benchmark/dflash/bench_dflash_gsm8k_sweep.py b/benchmark/dflash/bench_dflash_gsm8k_sweep.py index b98d566ff9d2..8b772461380d 100644 --- a/benchmark/dflash/bench_dflash_gsm8k_sweep.py +++ b/benchmark/dflash/bench_dflash_gsm8k_sweep.py @@ -5,7 +5,8 @@ GSM8K workload for each (concurrency, num_questions) setting. Example usage: - ./venv/bin/python benchmark/gsm8k/bench_dflash_gsm8k_sweep.py --output-md dflash_gsm8k_sweep.md + ./venv/bin/python benchmark/dflash/bench_dflash_gsm8k_sweep.py --output-md dflash_gsm8k_sweep.md + ./venv/bin/python benchmark/dflash/bench_dflash_gsm8k_sweep.py --skip-baseline --concurrencies 32 --tp-sizes 8 """ from __future__ import annotations @@ -108,6 +109,42 @@ def _send_generate( return resp.json() +def _send_generate_batch( + base_url: str, + prompts: list[str], + *, + max_new_tokens: int, + stop: list[str], + timeout_s: int, +) -> list[dict]: + if not prompts: + return [] + sampling_params: dict = { + "temperature": 0.0, + "top_p": 1.0, + "top_k": 1, + "max_new_tokens": int(max_new_tokens), + } + if stop: + sampling_params["stop"] = stop + resp = requests.post( + base_url + "/generate", + json={ + "text": prompts, + "sampling_params": sampling_params, + }, + timeout=int(timeout_s), + ) + resp.raise_for_status() + out = resp.json() + if not isinstance(out, list): + raise RuntimeError( + "Expected a list response for batched /generate, but got " + f"type={type(out).__name__}." + ) + return out + + @dataclass(frozen=True) class BenchMetrics: latency_s: float @@ -126,6 +163,7 @@ def _run_gsm8k_requests( labels: Optional[list[int]], max_new_tokens: int, concurrency: int, + batch_requests: bool, stop: list[str], timeout_s: int, expect_dflash: bool, @@ -140,36 +178,71 @@ def _run_gsm8k_requests( correct = 0 invalid = 0 - with ThreadPoolExecutor(max_workers=int(concurrency)) as pool: - futures = { - pool.submit( - _send_generate, + if batch_requests: + bs = max(int(concurrency), 1) + for start_idx in range(0, len(prompts), bs): + chunk_prompts = prompts[start_idx : start_idx + bs] + chunk_labels = labels[start_idx : start_idx + bs] if labels is not None else None + outs = _send_generate_batch( base_url, - prompt, + chunk_prompts, max_new_tokens=max_new_tokens, stop=stop, timeout_s=timeout_s, - ): i - for i, prompt in enumerate(prompts) - } - for fut in as_completed(futures): - i = futures[fut] - out = fut.result() - meta = out.get("meta_info", {}) or {} - total_tokens += int(meta.get("completion_tokens", 0)) - spec_verify_ct_sum += int(meta.get("spec_verify_ct", 0)) - if "spec_accept_length" in meta: - try: - spec_accept_lengths.append(float(meta["spec_accept_length"])) - except (TypeError, ValueError): - pass + ) + if len(outs) != len(chunk_prompts): + raise RuntimeError( + "Batched /generate output length mismatch: " + f"got {len(outs)} outputs for {len(chunk_prompts)} prompts." + ) - if labels is not None: - pred = _get_answer_value(out.get("text", "")) - if pred == INVALID: - invalid += 1 - if pred == labels[i]: - correct += 1 + for j, out in enumerate(outs): + meta = out.get("meta_info", {}) or {} + total_tokens += int(meta.get("completion_tokens", 0)) + spec_verify_ct_sum += int(meta.get("spec_verify_ct", 0)) + if "spec_accept_length" in meta: + try: + spec_accept_lengths.append(float(meta["spec_accept_length"])) + except (TypeError, ValueError): + pass + + if chunk_labels is not None: + pred = _get_answer_value(out.get("text", "")) + if pred == INVALID: + invalid += 1 + if pred == chunk_labels[j]: + correct += 1 + else: + with ThreadPoolExecutor(max_workers=int(concurrency)) as pool: + futures = { + pool.submit( + _send_generate, + base_url, + prompt, + max_new_tokens=max_new_tokens, + stop=stop, + timeout_s=timeout_s, + ): i + for i, prompt in enumerate(prompts) + } + for fut in as_completed(futures): + i = futures[fut] + out = fut.result() + meta = out.get("meta_info", {}) or {} + total_tokens += int(meta.get("completion_tokens", 0)) + spec_verify_ct_sum += int(meta.get("spec_verify_ct", 0)) + if "spec_accept_length" in meta: + try: + spec_accept_lengths.append(float(meta["spec_accept_length"])) + except (TypeError, ValueError): + pass + + if labels is not None: + pred = _get_answer_value(out.get("text", "")) + if pred == INVALID: + invalid += 1 + if pred == labels[i]: + correct += 1 latency = time.perf_counter() - start toks_per_s = total_tokens / max(latency, 1e-6) @@ -225,10 +298,25 @@ def _format_table( def main() -> None: parser = argparse.ArgumentParser() - parser.add_argument("--output-md", type=str, default="dflash_gsm8k_sweep.md") + parser.add_argument( + "--output-md", + type=str, + default=None, + help="Write a markdown report to this file (disabled by default).", + ) parser.add_argument("--data-path", type=str, default="test.jsonl") parser.add_argument("--target-model", type=str, default="Qwen/Qwen3-8B") parser.add_argument("--draft-model", type=str, default="z-lab/Qwen3-8B-DFlash-b16") + parser.add_argument( + "--skip-baseline", + action="store_true", + help="Skip running the baseline (target-only) sweep; only run DFLASH and report N/A for baseline/speedup.", + ) + parser.add_argument( + "--batch-requests", + action="store_true", + help="Send prompts as server-side batched /generate requests (batch size = concurrency) instead of client-side concurrent requests.", + ) parser.add_argument( "--prompt-style", type=str, @@ -360,9 +448,7 @@ def main() -> None: for backend in attention_backends: for tp in tp_sizes: - print(f"\n=== backend={backend} tp={tp} (baseline) ===") - baseline_port = find_available_port(20000) - baseline_url = f"http://127.0.0.1:{baseline_port}" + port_base = find_available_port(20000) common_server_args: list[str] = [ "--trust-remote-code", @@ -383,52 +469,57 @@ def main() -> None: if args.disable_radix_cache: common_server_args.append("--disable-radix-cache") - baseline_proc = popen_launch_server( - args.target_model, - baseline_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=common_server_args, - ) - try: - # Warm up. - _send_generate( + if not args.skip_baseline: + print(f"\n=== backend={backend} tp={tp} (baseline) ===") + baseline_port = port_base + baseline_url = f"http://127.0.0.1:{baseline_port}" + baseline_proc = popen_launch_server( + args.target_model, baseline_url, - "Hello", - max_new_tokens=8, - stop=[], - timeout_s=min(int(args.timeout_s), 300), + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=common_server_args, ) - - for conc in concurrencies: - n = num_questions_by_conc[conc] - _flush_cache(baseline_url) - metrics = _run_gsm8k_requests( + try: + # Warm up. + _send_generate( baseline_url, - prompts=prompts[:n], - labels=labels[:n], - max_new_tokens=int(args.max_new_tokens), - concurrency=int(conc), - stop=default_stop, - timeout_s=int(args.timeout_s), - expect_dflash=False, + "Hello", + max_new_tokens=8, + stop=[], + timeout_s=min(int(args.timeout_s), 300), ) - baseline_toks[(backend, tp, conc)] = metrics.output_toks_per_s - baseline_acc[(backend, tp, conc)] = metrics.accuracy - print( - f"[baseline] conc={conc:>2} n={n:<4} " - f"toks/s={metrics.output_toks_per_s:,.2f} " - f"latency={metrics.latency_s:.1f}s " - f"acc={metrics.accuracy:.3f} invalid={metrics.invalid_rate:.3f}" - ) - finally: - kill_process_tree(baseline_proc.pid) - try: - baseline_proc.wait(timeout=30) - except Exception: - pass + + for conc in concurrencies: + n = num_questions_by_conc[conc] + _flush_cache(baseline_url) + metrics = _run_gsm8k_requests( + baseline_url, + prompts=prompts[:n], + labels=labels[:n], + max_new_tokens=int(args.max_new_tokens), + concurrency=int(conc), + batch_requests=bool(args.batch_requests), + stop=default_stop, + timeout_s=int(args.timeout_s), + expect_dflash=False, + ) + baseline_toks[(backend, tp, conc)] = metrics.output_toks_per_s + baseline_acc[(backend, tp, conc)] = metrics.accuracy + print( + f"[baseline] conc={conc:>2} n={n:<4} " + f"toks/s={metrics.output_toks_per_s:,.2f} " + f"latency={metrics.latency_s:.1f}s " + f"acc={metrics.accuracy:.3f} invalid={metrics.invalid_rate:.3f}" + ) + finally: + kill_process_tree(baseline_proc.pid) + try: + baseline_proc.wait(timeout=30) + except Exception: + pass print(f"\n=== backend={backend} tp={tp} (DFLASH) ===") - dflash_port = find_available_port(baseline_port + 1) + dflash_port = find_available_port(port_base + 1) dflash_url = f"http://127.0.0.1:{dflash_port}" dflash_proc = popen_launch_server( args.target_model, @@ -459,6 +550,7 @@ def main() -> None: labels=labels[:n], max_new_tokens=int(args.max_new_tokens), concurrency=int(conc), + batch_requests=bool(args.batch_requests), stop=default_stop, timeout_s=int(args.timeout_s), expect_dflash=True, @@ -498,6 +590,7 @@ def main() -> None: md_lines.append(f"- questions_per_concurrency: `base={args.questions_per_concurrency_base}`") md_lines.append(f"- device_sm: `{device_sm}`") md_lines.append(f"- is_blackwell: `{is_blackwell}`") + md_lines.append(f"- skip_baseline: `{bool(args.skip_baseline)}`") md_lines.append("") md_lines.append( "Note: DFLASH and baseline greedy outputs may diverge on some prompts due to numerical differences " @@ -600,11 +693,13 @@ def main() -> None: ) md_lines.append("") - with open(args.output_md, "w", encoding="utf-8") as f: - f.write("\n".join(md_lines)) - f.write("\n") - - print(f"\nWrote markdown report to: {args.output_md}") + if args.output_md: + with open(args.output_md, "w", encoding="utf-8") as f: + f.write("\n".join(md_lines)) + f.write("\n") + print(f"\nWrote markdown report to: {args.output_md}") + else: + print("\nMarkdown report disabled (pass --output-md to write one).") if __name__ == "__main__": diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 5f168595c5e5..21fdb7ce59ec 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -2024,7 +2024,11 @@ def filter_batch( self.seq_lens_cpu = self.seq_lens_cpu[keep_indices] self.orig_seq_lens = self.orig_seq_lens[keep_indices_device] self.out_cache_loc = None - self.seq_lens_sum = self.seq_lens.sum().item() + # Use CPU copy to avoid GPU sync. + if self.seq_lens_cpu is not None: + self.seq_lens_sum = int(self.seq_lens_cpu.sum().item()) + else: + self.seq_lens_sum = int(self.seq_lens.sum().item()) if self.output_ids is not None: self.output_ids = self.output_ids[keep_indices_device] diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 7b0a84afb2dd..c2c436f915cc 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -826,8 +826,7 @@ def replay_prepare( and forward_batch.input_embeds is not None ): buffers.input_embeds[:raw_num_token].copy_(forward_batch.input_embeds) - if bs != raw_bs: - buffers.input_embeds[raw_num_token : bs * self.num_tokens_per_bs].zero_() + # Padded tokens aren't read, so skip zeroing them. if self.enable_two_batch_overlap: self.tbo_plugin.replay_prepare( forward_mode=self.capture_forward_mode, diff --git a/python/sglang/srt/model_executor/input_buffers.py b/python/sglang/srt/model_executor/input_buffers.py index f4468a70c634..4d5bd2af56dd 100644 --- a/python/sglang/srt/model_executor/input_buffers.py +++ b/python/sglang/srt/model_executor/input_buffers.py @@ -145,12 +145,14 @@ def populate_from_forward_batch( pp_proxy_tensors: Optional[PPProxyTensors] = None, ) -> Optional[torch.Tensor]: if bs != raw_bs: - self.seq_lens.fill_(seq_len_fill_value) - self.out_cache_loc.zero_() + # Only clear the padded region, not the entire buffer. + num_tokens = int(bs) * int(num_tokens_per_bs) + self.seq_lens[raw_bs:bs].fill_(seq_len_fill_value) + self.out_cache_loc[raw_num_token:num_tokens].zero_() if self.mamba_track_indices is not None: - self.mamba_track_indices.zero_() + self.mamba_track_indices[raw_bs:bs].zero_() if self.mamba_track_mask is not None: - self.mamba_track_mask.fill_(False) + self.mamba_track_mask[raw_bs:bs].fill_(False) # Common inputs self.input_ids[:raw_num_token].copy_(forward_batch.input_ids) @@ -173,7 +175,7 @@ def populate_from_forward_batch( seq_lens_cpu: Optional[torch.Tensor] = None if forward_batch.seq_lens_cpu is not None: if bs != raw_bs: - self.seq_lens_cpu.fill_(seq_len_fill_value) + self.seq_lens_cpu[raw_bs:bs].fill_(seq_len_fill_value) self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu) seq_lens_cpu = self.seq_lens_cpu[:bs] diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 08acb1b31ccc..4f7a72ba78a2 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1923,6 +1923,12 @@ def get_spec_info(): else: lora_ids = None + # Use CPU copy to avoid GPU sync. + if buffers.seq_lens_cpu is not None: + seq_lens_sum = int(buffers.seq_lens_cpu.sum().item()) + else: + seq_lens_sum = int(buffers.seq_lens.sum().item()) + forward_batch = ForwardBatch( forward_mode=capture_forward_mode, batch_size=batch_size, @@ -1936,7 +1942,7 @@ def get_spec_info(): token_to_kv_pool=self.token_to_kv_pool, attn_backend=self.attn_backend, out_cache_loc=buffers.out_cache_loc, - seq_lens_sum=buffers.seq_lens.sum().item(), + seq_lens_sum=seq_lens_sum, encoder_lens=buffers.encoder_lens, return_logprob=False, positions=buffers.positions, diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py index 9c09d159a517..5e2ceefd2261 100644 --- a/python/sglang/srt/speculative/dflash_info.py +++ b/python/sglang/srt/speculative/dflash_info.py @@ -41,12 +41,12 @@ class DFlashDraftInput(SpecInput): # or default K == draft_num_layers for existing checkpoints). target_hidden: torch.Tensor - # Context lengths on CPU, one per request. Used to slice `target_hidden`. - ctx_lens_cpu: List[int] + # Context lengths per request, used to slice `target_hidden`. Device tensor (int32). + ctx_lens: torch.Tensor - # Native implementation: how many tokens are already materialized in the draft KV cache. - # The next draft step appends `ctx_lens_cpu[i]` tokens starting at `draft_seq_lens_cpu[i]`. - draft_seq_lens_cpu: List[int] + # How many tokens are already in the draft KV cache per request. + # The next draft step appends ctx_lens[i] tokens starting at draft_seq_lens[i]. + draft_seq_lens: torch.Tensor def __post_init__(self): super().__init__(spec_input_type=SpecInputType.DFLASH_DRAFT) @@ -56,36 +56,49 @@ def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]: return (1, 1) def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True): - keep_indices = new_indices.tolist() - - old_ctx_lens_cpu = self.ctx_lens_cpu + old_ctx_lens = self.ctx_lens old_target_hidden = self.target_hidden self.verified_id = self.verified_id[new_indices] - self.ctx_lens_cpu = [old_ctx_lens_cpu[i] for i in keep_indices] - self.draft_seq_lens_cpu = [self.draft_seq_lens_cpu[i] for i in keep_indices] + self.ctx_lens = old_ctx_lens[new_indices] + self.draft_seq_lens = self.draft_seq_lens[new_indices] if old_target_hidden is None or old_target_hidden.numel() == 0: self.target_hidden = old_target_hidden return - old_offsets: List[int] = [0] - for ln in old_ctx_lens_cpu: - old_offsets.append(old_offsets[-1] + int(ln)) + # Rebuild target_hidden for the filtered batch using vectorized indexing. + old_bs = int(old_ctx_lens.shape[0]) + offsets = torch.zeros( + (old_bs + 1,), dtype=torch.int64, device=old_ctx_lens.device + ) + offsets[1:].copy_(old_ctx_lens.to(torch.int64).cumsum(0)) - segments: List[torch.Tensor] = [] - for idx in keep_indices: - ln = int(old_ctx_lens_cpu[idx]) - if ln == 0: - continue - segments.append(old_target_hidden[old_offsets[idx] : old_offsets[idx + 1]]) + start = offsets[:-1] + seg_start = start[new_indices] + seg_lens = old_ctx_lens[new_indices].to(torch.int64) - self.target_hidden = torch.cat(segments, dim=0) if segments else old_target_hidden[:0] + max_len = int(seg_lens.max().item()) if seg_lens.numel() > 0 else 0 + if max_len <= 0: + self.target_hidden = old_target_hidden[:0] + return + + r = torch.arange(max_len, device=old_ctx_lens.device, dtype=torch.int64)[None, :] + pos2d = seg_start[:, None] + r + mask = r < seg_lens[:, None] + flat_pos = pos2d[mask] + self.target_hidden = ( + old_target_hidden.index_select(0, flat_pos) + if flat_pos.numel() > 0 + else old_target_hidden[:0] + ) def merge_batch(self, spec_info: "DFlashDraftInput"): self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], dim=0) - self.ctx_lens_cpu.extend(spec_info.ctx_lens_cpu) - self.draft_seq_lens_cpu.extend(spec_info.draft_seq_lens_cpu) + self.ctx_lens = torch.cat([self.ctx_lens, spec_info.ctx_lens], dim=0) + self.draft_seq_lens = torch.cat( + [self.draft_seq_lens, spec_info.draft_seq_lens], dim=0 + ) if self.target_hidden is None or self.target_hidden.numel() == 0: self.target_hidden = spec_info.target_hidden elif spec_info.target_hidden is not None and spec_info.target_hidden.numel() > 0: @@ -285,35 +298,80 @@ def verify( target_predict=target_predict, ) - packed = torch.empty( - (bs, self.draft_token_num + 2), dtype=torch.int64, device=device - ) - packed[:, : self.draft_token_num].copy_(candidates) - packed[:, self.draft_token_num].copy_(accept_len) - packed[:, self.draft_token_num + 1].copy_(bonus) - packed_cpu = packed.cpu() + # Build output tokens on GPU: accepted drafts + bonus token. + out_lens = accept_len.to(torch.int32) + 1 + accept_len_i64 = accept_len.to(torch.int64) + + out_tokens = torch.empty((bs, self.draft_token_num), dtype=torch.int64, device=device) + if int(self.draft_token_num) > 1: + out_tokens[:, : self.draft_token_num - 1].copy_(candidates[:, 1:]) + out_tokens[:, self.draft_token_num - 1].fill_(0) + out_tokens.scatter_(1, accept_len_i64[:, None], bonus[:, None]) - candidates_cpu = packed_cpu[:, : self.draft_token_num].tolist() - accept_len_cpu = packed_cpu[:, self.draft_token_num].tolist() - bonus_cpu = packed_cpu[:, self.draft_token_num + 1].tolist() + out_tokens_cpu = out_tokens.cpu() + out_lens_cpu = out_lens.cpu() commit_lens_cpu: List[int] = [] new_verified_cpu: List[int] = [] accept_length_per_req_cpu: List[int] = [] for i, req in enumerate(batch.reqs): - # Proposed: accepted draft tokens, then the bonus token. - proposed = candidates_cpu[i][1 : 1 + accept_len_cpu[i]] + [bonus_cpu[i]] + proposed_len = int(out_lens_cpu[i]) + proposed = out_tokens_cpu[i, :proposed_len].tolist() appended = 0 - for tok in proposed: - req.output_ids.append(int(tok)) - appended += 1 - req.check_finished() - if req.finished(): - break - if req.grammar is not None: - req.grammar.accept_token(int(tok)) + if ( + req.grammar is None + and not req.sampling_params.stop_strs + and not req.sampling_params.stop_regex_strs + ): + remaining = int(req.sampling_params.max_new_tokens) - len(req.output_ids) + if remaining > 0: + tokens = proposed[:remaining] + if not req.sampling_params.ignore_eos: + stop_token_ids = req.sampling_params.stop_token_ids + eos_token_ids = req.eos_token_ids + tokenizer = req.tokenizer + tokenizer_eos = tokenizer.eos_token_id if tokenizer is not None else None + additional_stop = ( + tokenizer.additional_stop_token_ids + if tokenizer is not None + else None + ) + vocab_size = getattr(req, "vocab_size", None) + + for j, token_id in enumerate(tokens): + if vocab_size is not None and ( + int(token_id) > int(vocab_size) or int(token_id) < 0 + ): + tokens = tokens[: j + 1] + break + if stop_token_ids and token_id in stop_token_ids: + tokens = tokens[: j + 1] + break + if eos_token_ids and token_id in eos_token_ids: + tokens = tokens[: j + 1] + break + if tokenizer_eos is not None and int(token_id) == int(tokenizer_eos): + tokens = tokens[: j + 1] + break + if additional_stop and token_id in additional_stop: + tokens = tokens[: j + 1] + break + + req.output_ids.extend(int(tok) for tok in tokens) + appended = len(tokens) + if appended > 0: + req.check_finished(new_accepted_len=appended) + else: + for tok in proposed: + req.output_ids.append(int(tok)) + appended += 1 + req.check_finished() + if req.finished(): + break + if req.grammar is not None: + req.grammar.accept_token(int(tok)) # DFlash always treats the last appended token as the new "current token" # (uncommitted); therefore we commit exactly `appended` verify-input tokens. diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 75b6e03c224e..8205c38fc970 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -566,31 +566,16 @@ def _append_target_hidden_to_draft_kv( if draft_input.target_hidden is None: raise RuntimeError("DFLASH draft state missing target_hidden context features.") - if len(draft_input.ctx_lens_cpu) != bs: + if draft_input.ctx_lens.numel() != bs: raise RuntimeError( - f"DFLASH ctx_lens_cpu length mismatch: got {len(draft_input.ctx_lens_cpu)} for bs={bs}." + f"DFLASH ctx_lens length mismatch: got {draft_input.ctx_lens.numel()} for bs={bs}." ) - if len(draft_input.draft_seq_lens_cpu) != bs: + if draft_input.draft_seq_lens.numel() != bs: raise RuntimeError( - "DFLASH draft_seq_lens_cpu length mismatch: " - f"got {len(draft_input.draft_seq_lens_cpu)} for bs={bs}." + f"DFLASH draft_seq_lens length mismatch: got {draft_input.draft_seq_lens.numel()} for bs={bs}." ) - # Invariant: draft_seq_len + ctx_len == current target prefix length. - start_pos_cpu = batch.seq_lens_cpu.tolist() - for cache_len, ctx_len, start_pos in zip( - draft_input.draft_seq_lens_cpu, - draft_input.ctx_lens_cpu, - start_pos_cpu, - strict=True, - ): - if int(cache_len) + int(ctx_len) != int(start_pos): - raise RuntimeError( - "DFLASH draft cache length mismatch. " - f"cache_len={int(cache_len)}, ctx_len={int(ctx_len)}, start_pos={int(start_pos)}." - ) - - total_ctx = int(sum(int(x) for x in draft_input.ctx_lens_cpu)) + total_ctx = int(draft_input.target_hidden.shape[0]) if total_ctx <= 0: return @@ -600,45 +585,58 @@ def _append_target_hidden_to_draft_kv( if req_pool_indices.dtype != torch.int64: req_pool_indices = req_pool_indices.to(torch.int64) - ctx_cache_loc_chunks: List[torch.Tensor] = [] - ctx_positions_chunks: List[torch.Tensor] = [] - new_draft_seq_lens_cpu: List[int] = [] - for i, (cache_len, ctx_len) in enumerate( - zip( - draft_input.draft_seq_lens_cpu, - draft_input.ctx_lens_cpu, - strict=True, - ) - ): - cache_len_i = int(cache_len) - ctx_len_i = int(ctx_len) - new_draft_seq_lens_cpu.append(cache_len_i + ctx_len_i) - if ctx_len_i <= 0: - continue - s = cache_len_i - e = cache_len_i + ctx_len_i - req_pool_idx = req_pool_indices[i] - ctx_cache_loc_chunks.append(req_to_token[req_pool_idx, s:e].to(torch.int64)) - ctx_positions_chunks.append( - torch.arange(s, e, device=device, dtype=torch.int64) - ) + ctx_lens = draft_input.ctx_lens + draft_seq_lens = draft_input.draft_seq_lens + if ctx_lens.dtype != torch.int32: + ctx_lens = ctx_lens.to(torch.int32) + if draft_seq_lens.dtype != torch.int32: + draft_seq_lens = draft_seq_lens.to(torch.int32) + if ctx_lens.device != device: + ctx_lens = ctx_lens.to(device, non_blocking=True) + if draft_seq_lens.device != device: + draft_seq_lens = draft_seq_lens.to(device, non_blocking=True) + + if bs == 1: + # Fast path for single request. + max_ctx = int(total_ctx) + if max_ctx <= self._block_pos_offsets.numel(): + r = self._block_pos_offsets[:max_ctx] + else: + r = torch.arange(max_ctx, device=device, dtype=torch.int64) + pos2d = draft_seq_lens.to(torch.int64)[:, None] + r[None, :] # [1, ctx] + cache2d = req_to_token[req_pool_indices[:, None], pos2d] # [1, ctx] + ctx_cache_loc = cache2d.reshape(-1).to(torch.int64) # [ctx] + ctx_positions = pos2d.reshape(-1) # [ctx] + else: + # In decode mode, ctx_lens <= block_size so we can skip the .item() sync. + if batch.forward_mode.is_extend() or batch.is_extend_in_batch: + max_ctx = int(ctx_lens.max().item()) + else: + max_ctx = int(self.block_size) + if max_ctx <= 0: + raise RuntimeError(f"DFLASH invalid max_ctx={max_ctx} for KV append.") - ctx_cache_loc = ( - torch.cat(ctx_cache_loc_chunks, dim=0) - if ctx_cache_loc_chunks - else torch.empty((0,), dtype=torch.int64, device=device) - ) + if max_ctx <= self._block_pos_offsets.numel(): + r = self._block_pos_offsets[:max_ctx] + else: + r = torch.arange(max_ctx, device=device, dtype=torch.int64) + r = r[None, :] # [1, max_ctx] + pos2d = draft_seq_lens.to(torch.int64)[:, None] + r # [bs, max_ctx] + mask = r < ctx_lens[:, None] - ctx_positions = ( - torch.cat(ctx_positions_chunks, dim=0) - if ctx_positions_chunks - else torch.empty((0,), dtype=torch.int64, device=device) - ) + # Batched gather of cache locations and positions. + cache2d = req_to_token[req_pool_indices[:, None], pos2d] # [bs, max_ctx] + ctx_cache_loc = cache2d[mask].to(torch.int64) # [sum(ctx_lens)] + ctx_positions = pos2d[mask] # [sum(ctx_lens)] with torch.inference_mode(): ctx_hidden = self.draft_model.project_target_hidden( draft_input.target_hidden ) # [sum(ctx), hidden] + if ctx_hidden.shape[0] != ctx_cache_loc.numel(): + raise RuntimeError( + f"DFLASH ctx_hidden/cache_loc mismatch: {ctx_hidden.shape[0]} vs {ctx_cache_loc.numel()}." + ) for layer in self.draft_model.layers: attn = layer.self_attn @@ -656,8 +654,8 @@ def _append_target_hidden_to_draft_kv( attn.attn.v_scale, ) - draft_input.draft_seq_lens_cpu = new_draft_seq_lens_cpu - draft_input.ctx_lens_cpu = [0] * bs + draft_input.draft_seq_lens = draft_seq_lens + ctx_lens + draft_input.ctx_lens = torch.zeros_like(ctx_lens) draft_input.target_hidden = draft_input.target_hidden[:0] def forward_batch_generation( @@ -696,15 +694,24 @@ def forward_batch_generation( # Materialize the prompt tokens into the draft KV cache immediately. This is required # for radix cache support, since the scheduler may update radix after prefill returns. + device = next_token_ids.device + + def _to_int32_device_tensor(x, *, device=device): + if isinstance(x, torch.Tensor): + if x.device != device: + x = x.to(device, non_blocking=True) + return x if x.dtype == torch.int32 else x.to(torch.int32) + return torch.tensor(x, dtype=torch.int32, device=device) + draft_input = DFlashDraftInput( verified_id=next_token_ids.to(torch.int64), target_hidden=logits_output.hidden_states, - ctx_lens_cpu=[int(x) for x in model_worker_batch.extend_seq_lens], - draft_seq_lens_cpu=[int(x) for x in model_worker_batch.extend_prefix_lens], + ctx_lens=_to_int32_device_tensor(model_worker_batch.extend_seq_lens), + draft_seq_lens=_to_int32_device_tensor(model_worker_batch.extend_prefix_lens), ) self._append_target_hidden_to_draft_kv(batch, draft_input) batch.spec_info = draft_input - for req, draft_len in zip(batch.reqs, draft_input.draft_seq_lens_cpu, strict=True): + for req, draft_len in zip(batch.reqs, batch.seq_lens_cpu, strict=True): req.dflash_draft_seq_len = int(draft_len) return GenerationBatchResult( @@ -752,7 +759,7 @@ def forward_batch_generation( # into the draft KV cache immediately so radix cache entries are safe to reuse. draft_input.verified_id = new_verified_id draft_input.target_hidden = next_target_hidden - draft_input.ctx_lens_cpu = commit_lens.cpu().tolist() + draft_input.ctx_lens = commit_lens self._append_target_hidden_to_draft_kv(batch, draft_input) batch.spec_info = draft_input batch.forward_mode = ForwardMode.DECODE From b5b4bd60c4c2063f96218b334d2e2099dae88052 Mon Sep 17 00:00:00 2001 From: David Wang Date: Tue, 20 Jan 2026 06:13:30 +0000 Subject: [PATCH 31/73] precommit fixes --- benchmark/dflash/bench_dflash_gsm8k_sweep.py | 38 ++++++-- .../layers/attention/flashinfer_backend.py | 3 +- python/sglang/srt/managers/scheduler.py | 1 - .../scheduler_output_processor_mixin.py | 4 +- .../srt/model_executor/cuda_graph_runner.py | 17 +++- .../sglang/srt/model_executor/model_runner.py | 19 +++- python/sglang/srt/models/dflash.py | 37 +++++--- python/sglang/srt/models/qwen3.py | 4 +- python/sglang/srt/server_args.py | 23 ++--- python/sglang/srt/speculative/dflash_info.py | 62 +++++++++---- python/sglang/srt/speculative/dflash_utils.py | 5 +- .../sglang/srt/speculative/dflash_worker.py | 93 ++++++++++++++----- python/sglang/srt/speculative/spec_info.py | 9 +- 13 files changed, 226 insertions(+), 89 deletions(-) diff --git a/benchmark/dflash/bench_dflash_gsm8k_sweep.py b/benchmark/dflash/bench_dflash_gsm8k_sweep.py index 8b772461380d..5d34ddb06cd3 100644 --- a/benchmark/dflash/bench_dflash_gsm8k_sweep.py +++ b/benchmark/dflash/bench_dflash_gsm8k_sweep.py @@ -15,8 +15,8 @@ import ast import os import re -import time import statistics +import time from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass from typing import Optional @@ -182,7 +182,9 @@ def _run_gsm8k_requests( bs = max(int(concurrency), 1) for start_idx in range(0, len(prompts), bs): chunk_prompts = prompts[start_idx : start_idx + bs] - chunk_labels = labels[start_idx : start_idx + bs] if labels is not None else None + chunk_labels = ( + labels[start_idx : start_idx + bs] if labels is not None else None + ) outs = _send_generate_batch( base_url, chunk_prompts, @@ -381,12 +383,17 @@ def main() -> None: raise RuntimeError("No concurrencies specified.") num_questions_by_conc = { - c: min(int(args.questions_per_concurrency_base) * int(c), int(args.max_questions_per_config)) + c: min( + int(args.questions_per_concurrency_base) * int(c), + int(args.max_questions_per_config), + ) for c in concurrencies } max_questions = max(num_questions_by_conc.values()) - attention_backends = [s.strip() for s in args.attention_backends.split(",") if s.strip()] + attention_backends = [ + s.strip() for s in args.attention_backends.split(",") if s.strip() + ] is_blackwell = _is_blackwell() device_sm = get_device_sm() if is_blackwell: @@ -436,7 +443,9 @@ def main() -> None: raise RuntimeError("Invalid labels in GSM8K data.") default_stop = ( - ["Question", "Assistant:", "<|separator|>"] if args.prompt_style == "fewshot_qa" else [] + ["Question", "Assistant:", "<|separator|>"] + if args.prompt_style == "fewshot_qa" + else [] ) # Results indexed by (backend, tp, concurrency) for baseline + dflash. @@ -464,7 +473,12 @@ def main() -> None: str(args.max_running_requests), ] common_server_args.extend( - ["--cuda-graph-bs", *[str(i) for i in range(1, 33)], "--cuda-graph-max-bs", "32"] + [ + "--cuda-graph-bs", + *[str(i) for i in range(1, 33)], + "--cuda-graph-max-bs", + "32", + ] ) if args.disable_radix_cache: common_server_args.append("--disable-radix-cache") @@ -587,7 +601,9 @@ def main() -> None: md_lines.append(f"- attention_backends: `{', '.join(attention_backends)}`") md_lines.append(f"- tp_sizes: `{', '.join(str(x) for x in tp_sizes)}`") md_lines.append(f"- concurrencies: `{', '.join(str(x) for x in concurrencies)}`") - md_lines.append(f"- questions_per_concurrency: `base={args.questions_per_concurrency_base}`") + md_lines.append( + f"- questions_per_concurrency: `base={args.questions_per_concurrency_base}`" + ) md_lines.append(f"- device_sm: `{device_sm}`") md_lines.append(f"- is_blackwell: `{is_blackwell}`") md_lines.append(f"- skip_baseline: `{bool(args.skip_baseline)}`") @@ -617,7 +633,9 @@ def main() -> None: for conc in concurrencies: b = baseline_values.get((tp, conc), None) d = dflash_values.get((tp, conc), None) - speedup_values[(tp, conc)] = None if (b is None or d is None or b <= 0) else (d / b) + speedup_values[(tp, conc)] = ( + None if (b is None or d is None or b <= 0) else (d / b) + ) md_lines.append("### Baseline output tok/s") md_lines.append( @@ -678,7 +696,9 @@ def main() -> None: ) md_lines.append("") - md_lines.append("### DFLASH acceptance length (mean per-request spec_accept_length)") + md_lines.append( + "### DFLASH acceptance length (mean per-request spec_accept_length)" + ) md_lines.append( _format_table( tp_sizes=tp_sizes, diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index f3e82ced943f..25c690db35ed 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -589,7 +589,8 @@ def init_forward_metadata_capture_cuda_graph( # custom mask, so we must avoid initializing `custom_mask_buf`, otherwise # FlashInfer will treat the (zero) buffer as a real mask and block attention. use_custom_mask = ( - spec_info is not None and getattr(spec_info, "custom_mask", None) is not None + spec_info is not None + and getattr(spec_info, "custom_mask", None) is not None ) prefill_wrappers = [] for i in range(self.num_wrappers): diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 47f09ddcdf58..e98a58964034 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -158,7 +158,6 @@ from sglang.srt.managers.session_controller import Session from sglang.srt.managers.utils import GenerationBatchResult, validate_input_length from sglang.srt.mem_cache.cache_init_params import CacheInitParams -from sglang.srt.mem_cache.common import release_kv_cache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors from sglang.srt.multiplex.multiplexing_mixin import SchedulerMultiplexMixin diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index d6a8d1589e6b..4062f69a499a 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -44,7 +44,9 @@ class SchedulerOutputProcessorMixin: We put them into a separate file to make the `scheduler.py` shorter. """ - def _release_kv_cache_and_draft(self: Scheduler, req: Req, *, is_insert: bool = True): + def _release_kv_cache_and_draft( + self: Scheduler, req: Req, *, is_insert: bool = True + ): release_kv_cache(req, self.tree_cache, is_insert=is_insert) draft_worker = getattr(self, "draft_worker", None) hook = getattr(draft_worker, "on_req_finished", None) if draft_worker else None diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index c2c436f915cc..b624dddf309e 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -288,7 +288,9 @@ def __init__(self, model_runner: ModelRunner): if not self.model_runner.spec_algorithm.is_dflash(): raise RuntimeError("This should not happen") self.capture_forward_mode = ForwardMode.TARGET_VERIFY - self.num_tokens_per_bs = self.model_runner.server_args.speculative_num_draft_tokens + self.num_tokens_per_bs = ( + self.model_runner.server_args.speculative_num_draft_tokens + ) elif self.is_dllm: self.capture_forward_mode = ForwardMode.DLLM_EXTEND self.num_tokens_per_bs = self.dllm_config.block_size @@ -356,7 +358,10 @@ def __init__(self, model_runner: ModelRunner): and model_runner.eagle_use_aux_hidden_state ): self.model_runner.model.set_eagle3_layers_to_capture() - if model_runner.spec_algorithm.is_dflash() and model_runner.dflash_use_aux_hidden_state: + if ( + model_runner.spec_algorithm.is_dflash() + and model_runner.dflash_use_aux_hidden_state + ): if not hasattr(self.model_runner.model, "set_dflash_layers_to_capture"): raise ValueError( f"Model {self.model_runner.model.__class__.__name__} does not implement set_dflash_layers_to_capture, " @@ -955,9 +960,11 @@ def get_spec_info(self, num_tokens: int): draft_token=None, positions=None, draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, - custom_mask=None - if (self.model_runner.is_draft_worker or skip_custom_mask) - else self.buffers.custom_mask, + custom_mask=( + None + if (self.model_runner.is_draft_worker or skip_custom_mask) + else self.buffers.custom_mask + ), capture_hidden_mode=( CaptureHiddenMode.NULL if self.model_runner.is_draft_worker diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 4f7a72ba78a2..67a432f9e36f 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -137,8 +137,8 @@ get_global_server_args, set_global_server_args_for_scheduler, ) -from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.dflash_utils import resolve_dflash_target_layer_ids +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.utils import ( MultiprocessingSerializer, cpu_has_amx_support, @@ -352,16 +352,25 @@ def __init__( model_revision=server_args.speculative_draft_model_revision, is_draft_model=True, ) - draft_num_layers = getattr(draft_model_config.hf_config, "num_hidden_layers", None) - target_num_layers = getattr(self.model_config.hf_config, "num_hidden_layers", None) + draft_num_layers = getattr( + draft_model_config.hf_config, "num_hidden_layers", None + ) + target_num_layers = getattr( + self.model_config.hf_config, "num_hidden_layers", None + ) if draft_num_layers is None or target_num_layers is None: raise ValueError( "DFLASH requires both draft and target to expose num_hidden_layers in config. " f"Got draft={draft_num_layers}, target={target_num_layers}." ) - trained_target_layers = getattr(draft_model_config.hf_config, "num_target_layers", None) - if trained_target_layers is not None and trained_target_layers != target_num_layers: + trained_target_layers = getattr( + draft_model_config.hf_config, "num_target_layers", None + ) + if ( + trained_target_layers is not None + and trained_target_layers != target_num_layers + ): logger.warning( "DFLASH draft config num_target_layers=%s differs from runtime target num_hidden_layers=%s; " "selecting capture layers based on the runtime target model.", diff --git a/python/sglang/srt/models/dflash.py b/python/sglang/srt/models/dflash.py index 60c9cb540aab..aab2a2d3f1ff 100644 --- a/python/sglang/srt/models/dflash.py +++ b/python/sglang/srt/models/dflash.py @@ -14,12 +14,12 @@ from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ) -from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.radix_attention import AttentionType, RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -37,7 +37,9 @@ def __init__(self, config, layer_id: int) -> None: hidden_size = int(config.hidden_size) tp_size = int(get_tensor_model_parallel_world_size()) total_num_heads = int(config.num_attention_heads) - total_num_kv_heads = int(getattr(config, "num_key_value_heads", total_num_heads)) + total_num_kv_heads = int( + getattr(config, "num_key_value_heads", total_num_heads) + ) head_dim = int(getattr(config, "head_dim", hidden_size // total_num_heads)) self.hidden_size = hidden_size @@ -129,17 +131,23 @@ def forward( output, _ = self.o_proj(attn_output) return output - def kv_proj_only(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def kv_proj_only( + self, hidden_states: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: """Project hidden_states to K/V only (skip Q). This is used by DFlash to materialize ctx tokens into the draft KV cache: we only need K/V for the cached tokens; Q is never consumed. """ # Fast path for unquantized weights: slice the fused QKV weight and run one GEMM. - if isinstance(getattr(self.qkv_proj, "quant_method", None), UnquantizedLinearMethod): + if isinstance( + getattr(self.qkv_proj, "quant_method", None), UnquantizedLinearMethod + ): kv_slice = slice(self.q_size, self.q_size + 2 * self.kv_size) weight = self.qkv_proj.weight[kv_slice] - bias = self.qkv_proj.bias[kv_slice] if self.qkv_proj.bias is not None else None + bias = ( + self.qkv_proj.bias[kv_slice] if self.qkv_proj.bias is not None else None + ) kv = F.linear(hidden_states, weight, bias) k, v = kv.split([self.kv_size, self.kv_size], dim=-1) return k, v @@ -167,7 +175,9 @@ def __init__(self, config, quant_config=None, prefix: str = "") -> None: hidden_size = int(config.hidden_size) intermediate_size = int(getattr(config, "intermediate_size", 0)) if intermediate_size <= 0: - raise ValueError(f"Invalid intermediate_size={intermediate_size} for DFlash MLP.") + raise ValueError( + f"Invalid intermediate_size={intermediate_size} for DFlash MLP." + ) self.gate_up_proj = MergedColumnParallelLinear( hidden_size, @@ -295,7 +305,10 @@ def __init__(self, config, quant_config=None, prefix: str = "") -> None: ) if block_size is None: block_size = 16 - elif getattr(config, "block_size", None) is not None and dflash_block_size is not None: + elif ( + getattr(config, "block_size", None) is not None + and dflash_block_size is not None + ): if int(dflash_block_size) != int(getattr(config, "block_size")): logger.warning( "DFLASH draft config has both block_size=%s and dflash_config.block_size=%s; using dflash_config.block_size.", @@ -339,7 +352,9 @@ def forward( residual: Optional[torch.Tensor] = None for layer in self.layers: - hidden_states, residual = layer(positions, hidden_states, forward_batch, residual) + hidden_states, residual = layer( + positions, hidden_states, forward_batch, residual + ) if hidden_states.numel() == 0: return hidden_states @@ -391,9 +406,9 @@ def resolve_param_name(name: str) -> Optional[str]: # Ignore unexpected weights (e.g., HF rotary caches). continue param = params_dict[resolved_name] - if resolved_name.endswith("fc.weight") and tuple(loaded_weight.shape) != tuple( - param.shape - ): + if resolved_name.endswith("fc.weight") and tuple( + loaded_weight.shape + ) != tuple(param.shape): raise ValueError( "DFLASH fc.weight shape mismatch. This usually means the draft checkpoint's " "number of context features (K) does not match this config. " diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index 62f5fbc33d76..b7e23fc73df5 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -593,7 +593,9 @@ def set_dflash_layers_to_capture(self, layer_ids: List[int]): return if layer_ids is None: - raise ValueError("DFLASH requires explicit layer_ids for aux hidden capture.") + raise ValueError( + "DFLASH requires explicit layer_ids for aux hidden capture." + ) self.capture_aux_hidden_states = True # SGLang captures "before layer i". To capture the hidden state after target diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 68d1c7aead8f..af2401228257 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2116,18 +2116,18 @@ def _handle_speculative_decoding(self): "DFLASH requires --speculative-dflash-block-size to be positive, " f"got {self.speculative_dflash_block_size}." ) - if ( - self.speculative_num_draft_tokens is not None - and int(self.speculative_num_draft_tokens) - != int(self.speculative_dflash_block_size) - ): + if self.speculative_num_draft_tokens is not None and int( + self.speculative_num_draft_tokens + ) != int(self.speculative_dflash_block_size): raise ValueError( "Both --speculative-num-draft-tokens and --speculative-dflash-block-size are set " "but they differ. For DFLASH they must match. " f"speculative_num_draft_tokens={self.speculative_num_draft_tokens}, " f"speculative_dflash_block_size={self.speculative_dflash_block_size}." ) - self.speculative_num_draft_tokens = int(self.speculative_dflash_block_size) + self.speculative_num_draft_tokens = int( + self.speculative_dflash_block_size + ) if self.speculative_num_draft_tokens is None: inferred_block_size = None @@ -2139,7 +2139,9 @@ def _handle_speculative_decoding(self): if os.path.isfile(draft_config_path): with open(draft_config_path, "r") as f: draft_config_json = json.load(f) - top_level_block_size = draft_config_json.get("block_size", None) + top_level_block_size = draft_config_json.get( + "block_size", None + ) dflash_cfg = draft_config_json.get("dflash_config", None) dflash_block_size = ( dflash_cfg.get("block_size", None) @@ -2149,10 +2151,9 @@ def _handle_speculative_decoding(self): if dflash_block_size is not None: inferred_block_size = dflash_block_size - if ( - top_level_block_size is not None - and int(dflash_block_size) != int(top_level_block_size) - ): + if top_level_block_size is not None and int( + dflash_block_size + ) != int(top_level_block_size): logger.warning( "DFLASH draft config has both block_size=%s and dflash_config.block_size=%s; " "using dflash_config.block_size for speculative_num_draft_tokens inference.", diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py index 5e2ceefd2261..4b85468a7474 100644 --- a/python/sglang/srt/speculative/dflash_info.py +++ b/python/sglang/srt/speculative/dflash_info.py @@ -4,6 +4,7 @@ from typing import List, Tuple import torch + from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import ScheduleBatch @@ -13,9 +14,9 @@ get_last_loc, ) from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode +from sglang.srt.speculative.dflash_utils import compute_dflash_accept_len_and_bonus from sglang.srt.speculative.spec_info import SpecInput, SpecInputType from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func -from sglang.srt.speculative.dflash_utils import compute_dflash_accept_len_and_bonus @dataclass @@ -83,7 +84,9 @@ def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True self.target_hidden = old_target_hidden[:0] return - r = torch.arange(max_len, device=old_ctx_lens.device, dtype=torch.int64)[None, :] + r = torch.arange(max_len, device=old_ctx_lens.device, dtype=torch.int64)[ + None, : + ] pos2d = seg_start[:, None] + r mask = r < seg_lens[:, None] flat_pos = pos2d[mask] @@ -101,8 +104,12 @@ def merge_batch(self, spec_info: "DFlashDraftInput"): ) if self.target_hidden is None or self.target_hidden.numel() == 0: self.target_hidden = spec_info.target_hidden - elif spec_info.target_hidden is not None and spec_info.target_hidden.numel() > 0: - self.target_hidden = torch.cat([self.target_hidden, spec_info.target_hidden], dim=0) + elif ( + spec_info.target_hidden is not None and spec_info.target_hidden.numel() > 0 + ): + self.target_hidden = torch.cat( + [self.target_hidden, spec_info.target_hidden], dim=0 + ) @dataclass @@ -146,7 +153,9 @@ def prepare_for_verify( batch.input_ids = self.draft_token if page_size == 1: - batch.out_cache_loc = alloc_token_slots(batch.tree_cache, len(batch.input_ids)) + batch.out_cache_loc = alloc_token_slots( + batch.tree_cache, len(batch.input_ids) + ) end_offset = batch.seq_lens + self.draft_token_num else: prefix_lens = batch.seq_lens @@ -189,9 +198,7 @@ def prepare_for_verify( ) mask_chunks: List[torch.Tensor] = [] q_len = int(self.draft_token_num) - q_idx = torch.arange( - q_len, device=batch.device, dtype=torch.int32 - ).unsqueeze(1) + q_idx = torch.arange(q_len, device=batch.device, dtype=torch.int32).unsqueeze(1) for prefix_len in batch.seq_lens_cpu.tolist(): prefix_len_i = int(prefix_len) kv_len = prefix_len_i + q_len @@ -302,7 +309,9 @@ def verify( out_lens = accept_len.to(torch.int32) + 1 accept_len_i64 = accept_len.to(torch.int64) - out_tokens = torch.empty((bs, self.draft_token_num), dtype=torch.int64, device=device) + out_tokens = torch.empty( + (bs, self.draft_token_num), dtype=torch.int64, device=device + ) if int(self.draft_token_num) > 1: out_tokens[:, : self.draft_token_num - 1].copy_(candidates[:, 1:]) out_tokens[:, self.draft_token_num - 1].fill_(0) @@ -325,14 +334,18 @@ def verify( and not req.sampling_params.stop_strs and not req.sampling_params.stop_regex_strs ): - remaining = int(req.sampling_params.max_new_tokens) - len(req.output_ids) + remaining = int(req.sampling_params.max_new_tokens) - len( + req.output_ids + ) if remaining > 0: tokens = proposed[:remaining] if not req.sampling_params.ignore_eos: stop_token_ids = req.sampling_params.stop_token_ids eos_token_ids = req.eos_token_ids tokenizer = req.tokenizer - tokenizer_eos = tokenizer.eos_token_id if tokenizer is not None else None + tokenizer_eos = ( + tokenizer.eos_token_id if tokenizer is not None else None + ) additional_stop = ( tokenizer.additional_stop_token_ids if tokenizer is not None @@ -352,7 +365,9 @@ def verify( if eos_token_ids and token_id in eos_token_ids: tokens = tokens[: j + 1] break - if tokenizer_eos is not None and int(token_id) == int(tokenizer_eos): + if tokenizer_eos is not None and int(token_id) == int( + tokenizer_eos + ): tokens = tokens[: j + 1] break if additional_stop and token_id in additional_stop: @@ -397,7 +412,9 @@ def verify( batch.out_cache_loc = out_cache_loc[keep_mask] else: # Page-size > 1 is not supported in the initial DFlash implementation. - raise NotImplementedError("DFLASH verify with page_size > 1 is not supported yet.") + raise NotImplementedError( + "DFLASH verify with page_size > 1 is not supported yet." + ) # Update req-level KV cache accounting. for req, commit_len in zip(batch.reqs, commit_lens_cpu, strict=True): @@ -417,14 +434,18 @@ def verify( # Update batch seq lens. batch.seq_lens.add_(commit_lens.to(batch.seq_lens.dtype)) - batch.seq_lens_cpu.add_(torch.tensor(commit_lens_cpu, dtype=batch.seq_lens_cpu.dtype)) + batch.seq_lens_cpu.add_( + torch.tensor(commit_lens_cpu, dtype=batch.seq_lens_cpu.dtype) + ) # Keep seq_lens_sum in sync; flashinfer indices updaters rely on this for buffer sizing. batch.seq_lens_sum += sum(commit_lens_cpu) # Build next-step context features from the committed verify-input tokens. hidden = logits_output.hidden_states if hidden is None: - raise RuntimeError("DFLASH verify requires target hidden states, but got None.") + raise RuntimeError( + "DFLASH verify requires target hidden states, but got None." + ) hidden = hidden.view(bs, self.draft_token_num, -1) segments: List[torch.Tensor] = [] for i, ln in enumerate(commit_lens_cpu): @@ -435,5 +456,12 @@ def verify( # Avoid confusing downstream consumers (spec-v1 decode doesn't use this). logits_output.hidden_states = None - new_verified_id = torch.tensor(new_verified_cpu, dtype=torch.int64, device=device) - return new_verified_id, commit_lens, next_target_hidden, accept_length_per_req_cpu + new_verified_id = torch.tensor( + new_verified_cpu, dtype=torch.int64, device=device + ) + return ( + new_verified_id, + commit_lens, + next_target_hidden, + accept_length_per_req_cpu, + ) diff --git a/python/sglang/srt/speculative/dflash_utils.py b/python/sglang/srt/speculative/dflash_utils.py index 8c0aa4facb05..518049984caf 100644 --- a/python/sglang/srt/speculative/dflash_utils.py +++ b/python/sglang/srt/speculative/dflash_utils.py @@ -4,7 +4,6 @@ import torch - DEFAULT_DFLASH_MASK_TOKEN = "<|MASK|>" @@ -27,7 +26,9 @@ def build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> Lis when mapping to capture points. """ if num_target_layers <= 0: - raise ValueError(f"num_target_layers must be positive, got {num_target_layers}.") + raise ValueError( + f"num_target_layers must be positive, got {num_target_layers}." + ) if num_draft_layers <= 0: raise ValueError(f"num_draft_layers must be positive, got {num_draft_layers}.") diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 8205c38fc970..acdc8581c5bc 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -1,9 +1,8 @@ import logging from copy import deepcopy -from typing import List, Optional, Union +from typing import Optional, Union import torch -import torch.nn.functional as F from sglang.srt.distributed import get_tp_group from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch @@ -84,7 +83,9 @@ def __init__( draft_server_args.decode_attention_backend = None draft_server_args.attention_backend = draft_backend # Keep draft context length aligned with the target. - draft_server_args.context_length = target_worker.model_runner.model_config.context_len + draft_server_args.context_length = ( + target_worker.model_runner.model_config.context_len + ) self.draft_worker = TpModelWorker( server_args=draft_server_args, gpu_id=gpu_id, @@ -105,7 +106,9 @@ def __init__( else: self.block_size = int(server_args.speculative_num_draft_tokens) model_block_size = getattr(self.draft_model, "block_size", None) - if model_block_size is not None and int(model_block_size) != int(self.block_size): + if model_block_size is not None and int(model_block_size) != int( + self.block_size + ): logger.warning( "DFLASH block size mismatch: using speculative_num_draft_tokens=%s but draft config block_size=%s.", self.block_size, @@ -133,8 +136,12 @@ def __init__( self.block_size, device=self.device, dtype=torch.int64 ) self._draft_block_ids_buf: Optional[torch.Tensor] = None # [cap_bs, block_size] - self._draft_block_positions_buf: Optional[torch.Tensor] = None # [cap_bs, block_size] - self._draft_block_tokens_buf: Optional[torch.Tensor] = None # [cap_bs, block_size] + self._draft_block_positions_buf: Optional[torch.Tensor] = ( + None # [cap_bs, block_size] + ) + self._draft_block_tokens_buf: Optional[torch.Tensor] = ( + None # [cap_bs, block_size] + ) self._draft_block_end_buf: Optional[torch.Tensor] = None # [cap_bs] self._draft_seq_lens_cpu_buf: Optional[torch.Tensor] = None # [cap_bs] on CPU self._draft_block_spec_info = DFlashVerifyInput( @@ -153,7 +160,11 @@ def __init__( self._draft_greedy_index_cap: int = 0 def _ensure_draft_block_buffers(self, bs: int) -> None: - cap = 0 if self._draft_block_ids_buf is None else int(self._draft_block_ids_buf.shape[0]) + cap = ( + 0 + if self._draft_block_ids_buf is None + else int(self._draft_block_ids_buf.shape[0]) + ) if cap >= int(bs): return @@ -172,7 +183,9 @@ def _ensure_draft_block_buffers(self, bs: int) -> None: self._draft_block_end_buf = torch.empty( (new_cap,), dtype=torch.int32, device=device ) - self._draft_seq_lens_cpu_buf = torch.empty((new_cap,), dtype=torch.int32, device="cpu") + self._draft_seq_lens_cpu_buf = torch.empty( + (new_cap,), dtype=torch.int32, device="cpu" + ) def __getattr__(self, name): # Delegate anything not implemented yet to the target worker. @@ -190,11 +203,15 @@ def on_req_finished(self, req): def _resolve_mask_token_id(self, *, mask_token: str) -> int: if not isinstance(mask_token, str) or not mask_token: - raise ValueError(f"DFLASH mask_token must be a non-empty string, got {mask_token!r}.") + raise ValueError( + f"DFLASH mask_token must be a non-empty string, got {mask_token!r}." + ) tokenizer = getattr(self.target_worker, "tokenizer", None) if tokenizer is None: - raise RuntimeError("DFLASH requires tokenizer initialization (skip_tokenizer_init is not supported).") + raise RuntimeError( + "DFLASH requires tokenizer initialization (skip_tokenizer_init is not supported)." + ) vocab_size = int(self.target_worker.model_runner.model_config.vocab_size) mask_token_id = None @@ -239,12 +256,16 @@ def _resolve_mask_token_id(self, *, mask_token: str) -> int: return int(mask_token_id) - def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: DFlashDraftInput): + def _prepare_for_speculative_decoding( + self, batch: ScheduleBatch, draft_input: DFlashDraftInput + ): if batch.forward_mode.is_extend() or batch.forward_mode.is_idle(): return if batch.has_grammar: - raise ValueError("DFLASH does not support grammar-constrained decoding yet.") + raise ValueError( + "DFLASH does not support grammar-constrained decoding yet." + ) if batch.sampling_info is not None and not batch.sampling_info.is_all_greedy: if not self._warned_forced_greedy and self.tp_rank == 0: logger.warning( @@ -262,7 +283,11 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D target_model = self.target_worker.model_runner.model embed_module = target_model.get_input_embeddings() lm_head = getattr(target_model, "lm_head", None) - if lm_head is None or not hasattr(lm_head, "weight") or not hasattr(lm_head, "shard_indices"): + if ( + lm_head is None + or not hasattr(lm_head, "weight") + or not hasattr(lm_head, "shard_indices") + ): raise RuntimeError( "DFLASH requires the target model to expose a vocab-parallel `lm_head` with `weight` and " "`shard_indices` attributes." @@ -381,7 +406,11 @@ def _prepare_for_speculative_decoding(self, batch: ScheduleBatch, draft_input: D build_custom_mask=build_custom_mask, ) - batch.forward_mode = ForwardMode.TARGET_VERIFY if not batch.forward_mode.is_idle() else ForwardMode.IDLE + batch.forward_mode = ( + ForwardMode.TARGET_VERIFY + if not batch.forward_mode.is_idle() + else ForwardMode.IDLE + ) batch.spec_info = verify_input batch.return_hidden_states = False @@ -440,7 +469,8 @@ def _cast_hs(x: torch.Tensor) -> torch.Tensor: if num_org > 0: base_logits = torch.matmul(hs, weight[:num_org].T) out_token_ids[start:end] = ( - torch.argmax(base_logits, dim=-1).to(torch.long) + org_vocab_start + torch.argmax(base_logits, dim=-1).to(torch.long) + + org_vocab_start ) else: out_token_ids[start:end] = 0 @@ -470,23 +500,31 @@ def _cast_hs(x: torch.Tensor) -> torch.Tensor: if num_added > 0: added_slice_start = num_org_padded added_slice_end = num_org_padded + num_added - added_logits = torch.matmul(hs, weight[added_slice_start:added_slice_end].T) + added_logits = torch.matmul( + hs, weight[added_slice_start:added_slice_end].T + ) added_max, added_arg = torch.max(added_logits, dim=-1) use_added = added_max > local_max local_max = torch.where(use_added, added_max, local_max) # For base/added conversion below, keep local_arg expressed in the full local # weight index space (base + padding + added), matching `lm_head.weight`. - local_arg = torch.where(use_added, added_arg.to(local_arg.dtype) + num_org_padded, local_arg) + local_arg = torch.where( + use_added, added_arg.to(local_arg.dtype) + num_org_padded, local_arg + ) # Convert local argmax indices to global token ids. if num_added == 0: local_arg.add_(org_vocab_start) global_ids = local_arg else: - global_ids = torch.empty((chunk_len,), dtype=torch.int64, device=hs.device) + global_ids = torch.empty( + (chunk_len,), dtype=torch.int64, device=hs.device + ) is_base = local_arg < num_org global_ids[is_base] = org_vocab_start + local_arg[is_base] - global_ids[~is_base] = added_vocab_start + (local_arg[~is_base] - num_org_padded) + global_ids[~is_base] = added_vocab_start + ( + local_arg[~is_base] - num_org_padded + ) if tp_size == 1: out_token_ids[start:end] = global_ids.to(torch.long) @@ -565,7 +603,9 @@ def _append_target_hidden_to_draft_kv( device = self.model_runner.device if draft_input.target_hidden is None: - raise RuntimeError("DFLASH draft state missing target_hidden context features.") + raise RuntimeError( + "DFLASH draft state missing target_hidden context features." + ) if draft_input.ctx_lens.numel() != bs: raise RuntimeError( f"DFLASH ctx_lens length mismatch: got {draft_input.ctx_lens.numel()} for bs={bs}." @@ -664,7 +704,9 @@ def forward_batch_generation( **kwargs, ) -> GenerationBatchResult: if getattr(batch, "return_logprob", False): - raise ValueError("DFLASH speculative decoding does not support return_logprob yet.") + raise ValueError( + "DFLASH speculative decoding does not support return_logprob yet." + ) if isinstance(batch, ModelWorkerBatch): # Should not happen for spec-v1 (non-overlap) scheduling, but keep a sane fallback. @@ -687,7 +729,10 @@ def forward_batch_generation( "Make sure the target model has DFlash layers-to-capture configured." ) - if model_worker_batch.extend_seq_lens is None or model_worker_batch.extend_prefix_lens is None: + if ( + model_worker_batch.extend_seq_lens is None + or model_worker_batch.extend_prefix_lens is None + ): raise RuntimeError( "DFLASH expected extend_seq_lens / extend_prefix_lens to be populated in extend mode, but got None." ) @@ -707,7 +752,9 @@ def _to_int32_device_tensor(x, *, device=device): verified_id=next_token_ids.to(torch.int64), target_hidden=logits_output.hidden_states, ctx_lens=_to_int32_device_tensor(model_worker_batch.extend_seq_lens), - draft_seq_lens=_to_int32_device_tensor(model_worker_batch.extend_prefix_lens), + draft_seq_lens=_to_int32_device_tensor( + model_worker_batch.extend_prefix_lens + ), ) self._append_target_hidden_to_draft_kv(batch, draft_input) batch.spec_info = draft_input diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py index 6735c59f286d..5c4daeb702d4 100644 --- a/python/sglang/srt/speculative/spec_info.py +++ b/python/sglang/srt/speculative/spec_info.py @@ -64,7 +64,9 @@ def create_worker( if self.is_dflash(): if enable_overlap: - raise ValueError("DFLASH does not support overlap scheduling (spec v2).") + raise ValueError( + "DFLASH does not support overlap scheduling (spec v2)." + ) from sglang.srt.speculative.dflash_worker import DFlashWorker return DFlashWorker @@ -134,7 +136,10 @@ def __init__(self, spec_input_type: SpecInputType): def is_draft_input(self) -> bool: # FIXME: remove this function which is only used for assertion # or use another variable name like `draft_input` to substitute `spec_info` - return self.spec_input_type in {SpecInputType.EAGLE_DRAFT, SpecInputType.DFLASH_DRAFT} + return self.spec_input_type in { + SpecInputType.EAGLE_DRAFT, + SpecInputType.DFLASH_DRAFT, + } def is_verify_input(self) -> bool: return self.spec_input_type in { From ed8b16de88c1f7d17e196eac5f52e62bb9b6037e Mon Sep 17 00:00:00 2001 From: xiaomin-D Date: Tue, 20 Jan 2026 07:39:36 +0000 Subject: [PATCH 32/73] feat(dflash): add fused KV materialization kernel and optimize D2H - Add Triton-based fused KV materialization kernel (batched proj + norm + RoPE + cache write) - Enable fused kernel by default on CUDA devices - Optimize verify with single D2H transfer (merge candidates/accept_len/bonus) - Simplify dflash.py code formatting Co-authored-by: yilian49 --- python/sglang/srt/models/dflash.py | 10 +- python/sglang/srt/speculative/dflash_info.py | 42 ++- .../sglang/srt/speculative/dflash_worker.py | 124 ++++++++- .../srt/speculative/triton_ops/__init__.py | 20 ++ .../triton_ops/fused_kv_materialize.py | 246 ++++++++++++++++++ 5 files changed, 392 insertions(+), 50 deletions(-) create mode 100644 python/sglang/srt/speculative/triton_ops/__init__.py create mode 100644 python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py diff --git a/python/sglang/srt/models/dflash.py b/python/sglang/srt/models/dflash.py index aab2a2d3f1ff..1f2824d0692d 100644 --- a/python/sglang/srt/models/dflash.py +++ b/python/sglang/srt/models/dflash.py @@ -117,16 +117,8 @@ def forward( ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - q, k = apply_qk_norm( - q=q, - k=k, - q_norm=self.q_norm, - k_norm=self.k_norm, - head_dim=self.head_dim, - ) + q, k = apply_qk_norm(q, k, self.q_norm, self.k_norm, self.head_dim) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, forward_batch) output, _ = self.o_proj(attn_output) return output diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py index 4b85468a7474..e5be27117292 100644 --- a/python/sglang/srt/speculative/dflash_info.py +++ b/python/sglang/srt/speculative/dflash_info.py @@ -305,28 +305,21 @@ def verify( target_predict=target_predict, ) - # Build output tokens on GPU: accepted drafts + bonus token. - out_lens = accept_len.to(torch.int32) + 1 - accept_len_i64 = accept_len.to(torch.int64) + # Single D2H transfer: candidates[1:] + accept_len + bonus + packed = torch.cat( + [candidates[:, 1:], accept_len.unsqueeze(1), bonus.unsqueeze(1)], dim=1 + ).cpu() - out_tokens = torch.empty( - (bs, self.draft_token_num), dtype=torch.int64, device=device - ) - if int(self.draft_token_num) > 1: - out_tokens[:, : self.draft_token_num - 1].copy_(candidates[:, 1:]) - out_tokens[:, self.draft_token_num - 1].fill_(0) - out_tokens.scatter_(1, accept_len_i64[:, None], bonus[:, None]) - - out_tokens_cpu = out_tokens.cpu() - out_lens_cpu = out_lens.cpu() - - commit_lens_cpu: List[int] = [] - new_verified_cpu: List[int] = [] + max_acc = self.draft_token_num - 1 accept_length_per_req_cpu: List[int] = [] + commit_lens_cpu: List[int] = [] + new_verified_list: List[int] = [] for i, req in enumerate(batch.reqs): - proposed_len = int(out_lens_cpu[i]) - proposed = out_tokens_cpu[i, :proposed_len].tolist() + acc_len = int(packed[i, max_acc].item()) + proposed = packed[i, :acc_len].tolist() + [ + int(packed[i, max_acc + 1].item()) + ] appended = 0 if ( @@ -388,18 +381,16 @@ def verify( if req.grammar is not None: req.grammar.accept_token(int(tok)) - # DFlash always treats the last appended token as the new "current token" - # (uncommitted); therefore we commit exactly `appended` verify-input tokens. - if appended <= 0: - raise RuntimeError("DFLASH verify unexpectedly appended 0 tokens.") commit_lens_cpu.append(appended) - new_verified_cpu.append(req.output_ids[-1]) + new_verified_list.append(req.output_ids[-1]) accept_length_per_req_cpu.append(max(0, appended - 1)) - req.spec_verify_ct += 1 req.spec_accepted_tokens += accept_length_per_req_cpu[-1] commit_lens = torch.tensor(commit_lens_cpu, dtype=torch.int32, device=device) + new_verified_id = torch.tensor( + new_verified_list, dtype=torch.int64, device=device + ) # Free uncommitted KV cache slots and compact out_cache_loc. if page_size == 1: @@ -456,9 +447,6 @@ def verify( # Avoid confusing downstream consumers (spec-v1 decode doesn't use this). logits_output.hidden_states = None - new_verified_id = torch.tensor( - new_verified_cpu, dtype=torch.int64, device=device - ) return ( new_verified_id, commit_lens, diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index acdc8581c5bc..9c234b5bd6ff 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -18,9 +18,23 @@ from sglang.srt.speculative.dflash_utils import resolve_dflash_mask_token from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func +from sglang.srt.utils import is_cuda logger = logging.getLogger(__name__) +_FusedKVMaterializeHelper = None + + +def _get_fused_kv_materialize_helper(): + global _FusedKVMaterializeHelper + if _FusedKVMaterializeHelper is None: + from sglang.srt.speculative.triton_ops.fused_kv_materialize import ( + FusedKVMaterializeHelper, + ) + + _FusedKVMaterializeHelper = FusedKVMaterializeHelper + return _FusedKVMaterializeHelper + class DFlashWorker: """DFlash speculative decoding worker (spec-v1, tp>=1/pp=1).""" @@ -159,6 +173,49 @@ def __init__( self._draft_greedy_selected_ids_buf: Optional[torch.Tensor] = None self._draft_greedy_index_cap: int = 0 + self._use_fused_kv_materialize = is_cuda() + self._fused_kv_helper: Optional[object] = None + if self._use_fused_kv_materialize: + self._init_fused_kv_helper() + + def _init_fused_kv_helper(self) -> None: + """Initialize the fused KV materialization helper with pre-stacked weights.""" + try: + FusedKVMaterializeHelper = _get_fused_kv_materialize_helper() + layers = self.draft_model.layers + if len(layers) == 0: + logger.warning( + "DFLASH fused KV: no layers found, disabling fused path." + ) + self._use_fused_kv_materialize = False + return + + first_attn = layers[0].self_attn + rotary_emb = first_attn.rotary_emb + + self._fused_kv_helper = FusedKVMaterializeHelper( + layers=layers, + rotary_emb=rotary_emb, + num_kv_heads=first_attn.num_kv_heads, + head_dim=first_attn.head_dim, + device=self.device, + ) + if self.tp_rank == 0: + logger.info( + "DFLASH fused KV materialization enabled. " + "n_layers=%d, num_kv_heads=%d, head_dim=%d", + len(layers), + first_attn.num_kv_heads, + first_attn.head_dim, + ) + except Exception as e: + logger.warning( + "DFLASH fused KV initialization failed, falling back to sequential path: %s", + e, + ) + self._use_fused_kv_materialize = False + self._fused_kv_helper = None + def _ensure_draft_block_buffers(self, bs: int) -> None: cap = ( 0 @@ -678,26 +735,65 @@ def _append_target_hidden_to_draft_kv( f"DFLASH ctx_hidden/cache_loc mismatch: {ctx_hidden.shape[0]} vs {ctx_cache_loc.numel()}." ) - for layer in self.draft_model.layers: - attn = layer.self_attn - k, v = attn.kv_proj_only(ctx_hidden) - k = attn.apply_k_norm(k) - k = attn.apply_k_rope(ctx_positions, k) - k = k.view(-1, attn.num_kv_heads, attn.head_dim) - v = v.view(-1, attn.num_kv_heads, attn.head_dim) - self.draft_model_runner.token_to_kv_pool.set_kv_buffer( - attn.attn, - ctx_cache_loc, - k, - v, - attn.attn.k_scale, - attn.attn.v_scale, + if self._use_fused_kv_materialize and self._fused_kv_helper is not None: + self._append_target_hidden_fused( + ctx_hidden, ctx_positions, ctx_cache_loc + ) + else: + self._append_target_hidden_sequential( + ctx_hidden, ctx_positions, ctx_cache_loc ) draft_input.draft_seq_lens = draft_seq_lens + ctx_lens draft_input.ctx_lens = torch.zeros_like(ctx_lens) draft_input.target_hidden = draft_input.target_hidden[:0] + def _append_target_hidden_sequential( + self, + ctx_hidden: torch.Tensor, + ctx_positions: torch.Tensor, + ctx_cache_loc: torch.Tensor, + ) -> None: + for layer in self.draft_model.layers: + attn = layer.self_attn + k, v = attn.kv_proj_only(ctx_hidden) + k = attn.apply_k_norm(k) + k = attn.apply_k_rope(ctx_positions, k) + k = k.view(-1, attn.num_kv_heads, attn.head_dim) + v = v.view(-1, attn.num_kv_heads, attn.head_dim) + self.draft_model_runner.token_to_kv_pool.set_kv_buffer( + attn.attn, + ctx_cache_loc, + k, + v, + attn.attn.k_scale, + attn.attn.v_scale, + ) + + def _append_target_hidden_fused( + self, + ctx_hidden: torch.Tensor, + ctx_positions: torch.Tensor, + ctx_cache_loc: torch.Tensor, + ) -> None: + """Fused KV materialization using batched projection + Triton kernel.""" + token_to_kv_pool = self.draft_model_runner.token_to_kv_pool + k_cache_buffers = [] + v_cache_buffers = [] + for layer in self.draft_model.layers: + layer_id = layer.self_attn.attn.layer_id + k_buf, v_buf = token_to_kv_pool.get_kv_buffer(layer_id) + k_cache_buffers.append(k_buf) + v_cache_buffers.append(v_buf) + + self._fused_kv_helper.materialize( + ctx_hidden=ctx_hidden, + positions=ctx_positions, + cache_locs=ctx_cache_loc, + k_cache_buffers=k_cache_buffers, + v_cache_buffers=v_cache_buffers, + ) + def forward_batch_generation( self, batch: Union[ScheduleBatch, ModelWorkerBatch], diff --git a/python/sglang/srt/speculative/triton_ops/__init__.py b/python/sglang/srt/speculative/triton_ops/__init__.py new file mode 100644 index 000000000000..a8ea8f4c704b --- /dev/null +++ b/python/sglang/srt/speculative/triton_ops/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Triton kernels for speculative decoding.""" + +from sglang.srt.speculative.triton_ops.fused_kv_materialize import ( + FusedKVMaterializeHelper, +) + +__all__ = ["FusedKVMaterializeHelper"] diff --git a/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py b/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py new file mode 100644 index 000000000000..97cddd454efb --- /dev/null +++ b/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py @@ -0,0 +1,246 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Fused Triton kernel for DFlash KV materialization. + +Combines: KV projection (cuBLAS) + RMSNorm + RoPE + KV cache write (Triton). +""" + +from typing import List + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fused_norm_rope_write_kernel( + kv_ptr, # [total_ctx, kv_size * 2] + k_norm_weight_ptr, # [head_dim] + cos_sin_cache_ptr, # [max_pos, rotary_dim] + positions_ptr, # [total_ctx] + cache_loc_ptr, # [total_ctx] + k_cache_ptr, # [total_slots, num_kv_heads, head_dim] + v_cache_ptr, # [total_slots, num_kv_heads, head_dim] + kv_stride_ctx, + cos_sin_stride_pos, + k_cache_stride_slot, + k_cache_stride_head, + v_cache_stride_slot, + v_cache_stride_head, + total_ctx, + num_kv_heads: tl.constexpr, + head_dim: tl.constexpr, + kv_size: tl.constexpr, + rotary_dim: tl.constexpr, + half_rotary_dim: tl.constexpr, + eps: tl.constexpr, + BLOCK_HD: tl.constexpr, +): + """Fused RMSNorm(K) + RoPE(K) + cache write. Grid: (total_ctx, num_kv_heads).""" + ctx_id = tl.program_id(0) + head_id = tl.program_id(1) + if ctx_id >= total_ctx: + return + + # Load metadata + position = tl.load(positions_ptr + ctx_id) + cache_loc = tl.load(cache_loc_ptr + ctx_id) + + # Compute base pointers + kv_base = kv_ptr + ctx_id * kv_stride_ctx + k_base = kv_base + head_id * head_dim + v_base = kv_base + kv_size + head_id * head_dim + k_write = ( + k_cache_ptr + cache_loc * k_cache_stride_slot + head_id * k_cache_stride_head + ) + v_write = ( + v_cache_ptr + cache_loc * v_cache_stride_slot + head_id * v_cache_stride_head + ) + + # Load K and V + offs = tl.arange(0, BLOCK_HD) + mask_hd = offs < head_dim + mask_half = offs < half_rotary_dim + + k_raw = tl.load(k_base + offs, mask=mask_hd, other=0.0).to(tl.float32) + v_raw = tl.load(v_base + offs, mask=mask_hd, other=0.0) + + # RMSNorm on K + inv_rms = tl.rsqrt(tl.sum(k_raw * k_raw) / head_dim + eps) + norm_w = tl.load(k_norm_weight_ptr + offs, mask=mask_hd, other=1.0).to(tl.float32) + k_normed = k_raw * inv_rms * norm_w + + # RoPE (neox style): k_first, k_second -> rotated + cos_sin_base = cos_sin_cache_ptr + position * cos_sin_stride_pos + cos_v = tl.load(cos_sin_base + offs, mask=mask_half, other=1.0).to(tl.float32) + sin_v = tl.load( + cos_sin_base + half_rotary_dim + offs, mask=mask_half, other=0.0 + ).to(tl.float32) + + # Extract first/second halves of K for rotation + k_first = tl.where(mask_half, k_normed, 0.0) + k_second_raw = tl.load( + k_base + half_rotary_dim + offs, mask=mask_half, other=0.0 + ).to(tl.float32) + norm_w_second = tl.load( + k_norm_weight_ptr + half_rotary_dim + offs, mask=mask_half, other=1.0 + ).to(tl.float32) + k_second = k_second_raw * inv_rms * norm_w_second + + # Apply rotation + k_rot_first = k_first * cos_v - k_second * sin_v + k_rot_second = k_second * cos_v + k_first * sin_v + + # Store V (no transform) + tl.store(v_write + offs, v_raw, mask=mask_hd) + + # Store K: rotated halves + pass-through + tl.store(k_write + offs, k_rot_first.to(v_raw.dtype), mask=mask_half) + tl.store( + k_write + half_rotary_dim + offs, k_rot_second.to(v_raw.dtype), mask=mask_half + ) + mask_pass = (offs >= rotary_dim) & (offs < head_dim) + tl.store(k_write + offs, k_normed.to(v_raw.dtype), mask=mask_pass) + + +def _fused_norm_rope_write( + kv: torch.Tensor, # [total_ctx, kv_size*2] + k_norm_weight: torch.Tensor, # [head_dim] + cos_sin_cache: torch.Tensor, # [max_pos, rotary_dim] + positions: torch.Tensor, # [total_ctx] + cache_locs: torch.Tensor, # [total_ctx] + k_cache: torch.Tensor, # [total_slots, num_kv_heads, head_dim] + v_cache: torch.Tensor, # [total_slots, num_kv_heads, head_dim] + num_kv_heads: int, + head_dim: int, + rotary_dim: int, + eps: float = 1e-6, +) -> None: + """Fused RMSNorm + RoPE + cache write for a single layer.""" + total_ctx = kv.shape[0] + if total_ctx == 0: + return + + kv_size = num_kv_heads * head_dim + half_rotary_dim = rotary_dim // 2 + BLOCK_HD = triton.next_power_of_2(head_dim) + + # Ensure int64 for indexing + if positions.dtype != torch.int64: + positions = positions.to(torch.int64) + if cache_locs.dtype != torch.int64: + cache_locs = cache_locs.to(torch.int64) + + _fused_norm_rope_write_kernel[(total_ctx, num_kv_heads)]( + kv, + k_norm_weight, + cos_sin_cache, + positions, + cache_locs, + k_cache, + v_cache, + kv.stride(0), + cos_sin_cache.stride(0), + k_cache.stride(0), + k_cache.stride(1), + v_cache.stride(0), + v_cache.stride(1), + total_ctx, + num_kv_heads, + head_dim, + kv_size, + rotary_dim, + half_rotary_dim, + eps, + BLOCK_HD, + ) + + +class FusedKVMaterializeHelper: + """Fused KV materialization helper using batched projection. + + Uses torch.einsum for batched KV projection across all layers, + then a Triton kernel for fused RMSNorm + RoPE + cache write per layer. + """ + + def __init__( + self, + layers: List, + rotary_emb, + num_kv_heads: int, + head_dim: int, + device: torch.device, + ): + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.rotary_emb = rotary_emb + self.n_layers = len(layers) + self.device = device + + # Pre-extract and stack weights for batched projection + kv_weights = [] + self.k_norm_weights = [] + self.eps_values = [] + + for layer in layers: + attn = layer.self_attn + # Extract KV portion of QKV weight + qkv_w = attn.qkv_proj.weight + kv_weight = qkv_w[attn.q_size : attn.q_size + 2 * attn.kv_size] + kv_weights.append(kv_weight) + self.k_norm_weights.append(attn.k_norm.weight) + self.eps_values.append(attn.k_norm.variance_epsilon) + + # Stack for batched einsum: [n_layers, kv_size*2, hidden_size] + self.batched_kv_weight = torch.stack(kv_weights) + + self.rotary_dim = getattr(rotary_emb, "rotary_dim", head_dim) + self.is_neox_style = getattr(rotary_emb, "is_neox_style", True) + + if not self.is_neox_style: + raise NotImplementedError("Only neox-style RoPE is supported.") + + def materialize( + self, + ctx_hidden: torch.Tensor, + positions: torch.Tensor, + cache_locs: torch.Tensor, + k_cache_buffers: List[torch.Tensor], + v_cache_buffers: List[torch.Tensor], + ) -> None: + """Materialize KV cache for all layers using batched projection.""" + total_ctx = ctx_hidden.shape[0] + if total_ctx == 0: + return + + cos_sin_cache = self.rotary_emb.cos_sin_cache + + # Batched KV projection: [n_layers, total_ctx, kv_size*2] + kv_all = torch.einsum("th,loh->lto", ctx_hidden, self.batched_kv_weight) + + # Per-layer fused norm/RoPE/write + for layer_id in range(self.n_layers): + _fused_norm_rope_write( + kv_all[layer_id], + self.k_norm_weights[layer_id], + cos_sin_cache, + positions, + cache_locs, + k_cache_buffers[layer_id], + v_cache_buffers[layer_id], + self.num_kv_heads, + self.head_dim, + self.rotary_dim, + self.eps_values[layer_id], + ) From f2a6dbc27b11f413d6af7d64826b1ef3752cc705 Mon Sep 17 00:00:00 2001 From: David Wang Date: Tue, 20 Jan 2026 21:40:45 +0000 Subject: [PATCH 33/73] add support for qwen3_moe --- python/sglang/srt/models/dflash.py | 26 ++++++++----------- python/sglang/srt/models/qwen2_moe.py | 5 ++++ python/sglang/srt/models/qwen3_moe.py | 12 +++++++++ python/sglang/srt/speculative/dflash_utils.py | 5 +++- 4 files changed, 32 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/models/dflash.py b/python/sglang/srt/models/dflash.py index aab2a2d3f1ff..344ccc640ac4 100644 --- a/python/sglang/srt/models/dflash.py +++ b/python/sglang/srt/models/dflash.py @@ -26,7 +26,10 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.utils import apply_qk_norm -from sglang.srt.speculative.dflash_utils import get_dflash_config +from sglang.srt.speculative.dflash_utils import ( + get_dflash_config, + resolve_dflash_target_layer_ids, +) logger = logging.getLogger(__name__) @@ -275,20 +278,13 @@ def __init__(self, config, quant_config=None, prefix: str = "") -> None: # Project per-token target context features: # concat(K * hidden_size) -> hidden_size, where K is the number of target-layer # feature tensors concatenated per token (not necessarily equal to num_layers). - target_layer_ids = dflash_cfg_dict.get("target_layer_ids", None) - if target_layer_ids is None: - num_context_features = num_layers - else: - if not isinstance(target_layer_ids, (list, tuple)): - raise ValueError( - "DFLASH dflash_config.target_layer_ids must be a list of ints, " - f"got type={type(target_layer_ids).__name__}." - ) - if len(target_layer_ids) <= 0: - raise ValueError( - "DFLASH dflash_config.target_layer_ids must be non-empty, got []." - ) - num_context_features = len(target_layer_ids) + target_num_layers = int(getattr(config, "num_target_layers", num_layers)) + target_layer_ids = resolve_dflash_target_layer_ids( + draft_hf_config=config, + target_num_layers=target_num_layers, + draft_num_layers=num_layers, + ) + num_context_features = len(target_layer_ids) self.num_context_features = int(num_context_features) self.fc = nn.Linear( diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 1e4b53a7d1fa..465ac03ffb43 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -604,6 +604,11 @@ def set_eagle3_layers_to_capture(self, layers_to_capture: List[int]): for layer_id in self.layers_to_capture: setattr(self.layers[layer_id], "_is_layer_to_capture", True) + def set_dflash_layers_to_capture(self, layers_to_capture: List[int]): + self.layers_to_capture = layers_to_capture + for layer_id in self.layers_to_capture: + setattr(self.layers[layer_id], "_is_layer_to_capture", True) + def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index d0469a62e146..def766cc98a2 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -1014,6 +1014,18 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None): else: self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids]) + def set_dflash_layers_to_capture(self, layer_ids: List[int]): + if not self.pp_group.is_last_rank: + return + + if layer_ids is None: + raise ValueError( + "DFLASH requires explicit layer_ids for aux hidden capture." + ) + + self.capture_aux_hidden_states = True + self.model.set_dflash_layers_to_capture([val + 1 for val in layer_ids]) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) diff --git a/python/sglang/srt/speculative/dflash_utils.py b/python/sglang/srt/speculative/dflash_utils.py index 518049984caf..27e4a3154438 100644 --- a/python/sglang/srt/speculative/dflash_utils.py +++ b/python/sglang/srt/speculative/dflash_utils.py @@ -73,7 +73,8 @@ def resolve_dflash_target_layer_ids( Precedence: 1) `draft_hf_config.dflash_config.target_layer_ids` - 2) default `build_target_layer_ids(target_num_layers, draft_num_layers)` + 2) `draft_hf_config.target_layer_ids` (fallback to base config) + 3) default `build_target_layer_ids(target_num_layers, draft_num_layers)` Notes: The number of draft transformer layers is *not* fundamentally tied to the number @@ -83,6 +84,8 @@ def resolve_dflash_target_layer_ids( """ cfg = get_dflash_config(draft_hf_config) layer_ids = cfg.get("target_layer_ids", None) + if layer_ids is None: + layer_ids = getattr(draft_hf_config, "target_layer_ids", None) if layer_ids is None: return build_target_layer_ids(target_num_layers, draft_num_layers) From 117352df4cd4975f7f5d6326a5852ed54fec1ed0 Mon Sep 17 00:00:00 2001 From: David Wang Date: Thu, 5 Feb 2026 22:38:27 +0000 Subject: [PATCH 34/73] support dflash_config.mask_token_id --- python/sglang/srt/speculative/dflash_utils.py | 22 +++++- .../sglang/srt/speculative/dflash_worker.py | 75 ++++++++++++++----- 2 files changed, 77 insertions(+), 20 deletions(-) diff --git a/python/sglang/srt/speculative/dflash_utils.py b/python/sglang/srt/speculative/dflash_utils.py index 27e4a3154438..c5b37bc5f75f 100644 --- a/python/sglang/srt/speculative/dflash_utils.py +++ b/python/sglang/srt/speculative/dflash_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, List, Tuple +from numbers import Integral +from typing import Any, List, Optional, Tuple import torch @@ -124,6 +125,25 @@ def resolve_dflash_mask_token(*, draft_hf_config: Any) -> str: return mask_token +def resolve_dflash_mask_token_id(*, draft_hf_config: Any) -> Optional[int]: + cfg = get_dflash_config(draft_hf_config) + mask_token_id = cfg.get("mask_token_id", None) + if mask_token_id is None: + return None + if not isinstance(mask_token_id, Integral) or isinstance(mask_token_id, bool): + raise ValueError( + "DFLASH dflash_config.mask_token_id must be an integer, " + f"got {mask_token_id!r} (type={type(mask_token_id).__name__})." + ) + mask_token_id = int(mask_token_id) + if mask_token_id < 0: + raise ValueError( + "DFLASH dflash_config.mask_token_id must be non-negative, " + f"got {mask_token_id}." + ) + return mask_token_id + + def compute_dflash_accept_len_and_bonus( *, candidates: torch.Tensor, diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index acdc8581c5bc..9c413bd0ad74 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -15,7 +15,10 @@ ) from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.dflash_info import DFlashDraftInput, DFlashVerifyInput -from sglang.srt.speculative.dflash_utils import resolve_dflash_mask_token +from sglang.srt.speculative.dflash_utils import ( + resolve_dflash_mask_token, + resolve_dflash_mask_token_id, +) from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func @@ -118,7 +121,13 @@ def __init__( self._mask_token = resolve_dflash_mask_token( draft_hf_config=self.draft_model_runner.model_config.hf_config ) - self._mask_token_id = self._resolve_mask_token_id(mask_token=self._mask_token) + self._mask_token_id_override = resolve_dflash_mask_token_id( + draft_hf_config=self.draft_model_runner.model_config.hf_config + ) + self._mask_token_id = self._resolve_mask_token_id( + mask_token=self._mask_token, + mask_token_id=self._mask_token_id_override, + ) if self.tp_rank == 0: logger.info( "Initialized DFLASH draft runner. attention_backend=%s, model=%s, block_size=%s", @@ -127,9 +136,10 @@ def __init__( self.block_size, ) logger.info( - "DFLASH draft runner ready. mask_token=%s, mask_token_id=%s", + "DFLASH draft runner ready. mask_token=%s, mask_token_id=%s, mask_token_id_override=%s", self._mask_token, self._mask_token_id, + self._mask_token_id_override, ) self._block_pos_offsets = torch.arange( @@ -201,60 +211,87 @@ def on_req_finished(self, req): if hasattr(req, "dflash_draft_seq_len"): req.dflash_draft_seq_len = 0 - def _resolve_mask_token_id(self, *, mask_token: str) -> int: + def _resolve_mask_token_id( + self, *, mask_token: str, mask_token_id: Optional[int] = None + ) -> int: if not isinstance(mask_token, str) or not mask_token: raise ValueError( f"DFLASH mask_token must be a non-empty string, got {mask_token!r}." ) + vocab_size = int(self.target_worker.model_runner.model_config.vocab_size) + if mask_token_id is not None: + resolved_id = int(mask_token_id) + if resolved_id >= vocab_size: + raise ValueError( + "DFLASH mask_token_id is outside the target vocab size. " + f"mask_token_id={resolved_id}, vocab_size={vocab_size}. " + f"This likely means mask_token={mask_token!r} requires vocab expansion beyond the model's embedding size. " + "SGLang does not support resizing target embeddings for DFLASH yet." + ) + + tokenizer = getattr(self.target_worker, "tokenizer", None) + if tokenizer is not None: + token_id_from_vocab = tokenizer.get_vocab().get(mask_token, None) + if ( + token_id_from_vocab is not None + and int(token_id_from_vocab) != resolved_id + ): + raise ValueError( + "DFLASH config mismatch: dflash_config.mask_token_id conflicts with tokenizer vocab id " + f"for dflash_config.mask_token. mask_token={mask_token!r}, " + f"mask_token_id={resolved_id}, tokenizer_vocab_id={int(token_id_from_vocab)}." + ) + return resolved_id + tokenizer = getattr(self.target_worker, "tokenizer", None) if tokenizer is None: raise RuntimeError( - "DFLASH requires tokenizer initialization (skip_tokenizer_init is not supported)." + "DFLASH requires tokenizer initialization when dflash_config.mask_token_id is not set " + "(skip_tokenizer_init is not supported in this mode)." ) - vocab_size = int(self.target_worker.model_runner.model_config.vocab_size) - mask_token_id = None + resolved_id = None if getattr(tokenizer, "mask_token", None) == mask_token: - mask_token_id = getattr(tokenizer, "mask_token_id", None) + resolved_id = getattr(tokenizer, "mask_token_id", None) - if mask_token_id is None: + if resolved_id is None: # Prefer checking the explicit vocab mapping first. vocab = tokenizer.get_vocab() - mask_token_id = vocab.get(mask_token, None) + resolved_id = vocab.get(mask_token, None) - if mask_token_id is None: + if resolved_id is None: # Mirror the reference DFlash HF demo by adding the mask token to the tokenizer. # This is safe only when the resulting id stays within the target model vocab size. added = tokenizer.add_special_tokens({"mask_token": mask_token}) - mask_token_id = getattr(tokenizer, "mask_token_id", None) - if mask_token_id is None: - mask_token_id = tokenizer.convert_tokens_to_ids(mask_token) + resolved_id = getattr(tokenizer, "mask_token_id", None) + if resolved_id is None: + resolved_id = tokenizer.convert_tokens_to_ids(mask_token) if added and self.tp_rank == 0: logger.info( "Added DFLASH mask token to tokenizer. token=%s, mask_token_id=%s, tokenizer_len=%s, model_vocab_size=%s", mask_token, - mask_token_id, + resolved_id, len(tokenizer), vocab_size, ) - if mask_token_id is None or int(mask_token_id) < 0: + if resolved_id is None or int(resolved_id) < 0: raise ValueError( "DFLASH requires resolving a mask token id, but it could not be resolved. " f"mask_token={mask_token!r}." ) - if mask_token_id >= vocab_size: + if resolved_id >= vocab_size: raise ValueError( "DFLASH mask_token_id is outside the target vocab size. " - f"mask_token_id={mask_token_id}, vocab_size={vocab_size}. " + f"mask_token_id={resolved_id}, vocab_size={vocab_size}. " f"This likely means mask_token={mask_token!r} requires vocab expansion beyond the model's embedding size. " "SGLang does not support resizing target embeddings for DFLASH yet." ) - return int(mask_token_id) + return int(resolved_id) def _prepare_for_speculative_decoding( self, batch: ScheduleBatch, draft_input: DFlashDraftInput From 5ba316c688753b4ecae6219a3e20bc3b69495aef Mon Sep 17 00:00:00 2001 From: David Wang Date: Thu, 5 Feb 2026 23:39:45 +0000 Subject: [PATCH 35/73] add llama3.1 support and fix config block_size logic --- python/sglang/srt/models/llama.py | 12 +++++++ python/sglang/srt/server_args.py | 57 +++++++++++++++++++++---------- 2 files changed, 51 insertions(+), 18 deletions(-) diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 53761dae5b85..3b77d3922898 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -785,6 +785,18 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None): # of the (i-1)th layer as aux hidden state self.model.layers_to_capture = [val + 1 for val in layer_ids] + def set_dflash_layers_to_capture(self, layer_ids: List[int]): + if not self.pp_group.is_last_rank: + return + + if layer_ids is None: + raise ValueError( + "DFLASH requires explicit layer_ids for aux hidden capture." + ) + + self.capture_aux_hidden_states = True + self.model.layers_to_capture = [val + 1 for val in layer_ids] + class Phi3ForCausalLM(LlamaForCausalLM): pass diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index af2401228257..ee4335acc693 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2132,6 +2132,8 @@ def _handle_speculative_decoding(self): if self.speculative_num_draft_tokens is None: inferred_block_size = None try: + top_level_block_size = None + dflash_block_size = None if os.path.isdir(self.speculative_draft_model_path): draft_config_path = os.path.join( self.speculative_draft_model_path, "config.json" @@ -2139,32 +2141,51 @@ def _handle_speculative_decoding(self): if os.path.isfile(draft_config_path): with open(draft_config_path, "r") as f: draft_config_json = json.load(f) - top_level_block_size = draft_config_json.get( - "block_size", None - ) + top_level_block_size = draft_config_json.get("block_size") dflash_cfg = draft_config_json.get("dflash_config", None) dflash_block_size = ( - dflash_cfg.get("block_size", None) + dflash_cfg.get("block_size") if isinstance(dflash_cfg, dict) else None ) - if dflash_block_size is not None: - inferred_block_size = dflash_block_size - if top_level_block_size is not None and int( - dflash_block_size - ) != int(top_level_block_size): - logger.warning( - "DFLASH draft config has both block_size=%s and dflash_config.block_size=%s; " - "using dflash_config.block_size for speculative_num_draft_tokens inference.", - top_level_block_size, - dflash_block_size, - ) - else: - inferred_block_size = top_level_block_size + if top_level_block_size is None and dflash_block_size is None: + from sglang.srt.utils.hf_transformers_utils import get_config + + draft_hf_config = get_config( + self.speculative_draft_model_path, + trust_remote_code=self.trust_remote_code, + revision=self.speculative_draft_model_revision, + model_override_args=json.loads( + self.json_model_override_args + ), + ) + top_level_block_size = getattr( + draft_hf_config, "block_size", None + ) + dflash_cfg = getattr(draft_hf_config, "dflash_config", None) + dflash_block_size = ( + dflash_cfg.get("block_size") + if isinstance(dflash_cfg, dict) + else None + ) + + if dflash_block_size is not None: + inferred_block_size = dflash_block_size + if top_level_block_size is not None and int( + dflash_block_size + ) != int(top_level_block_size): + logger.warning( + "DFLASH draft config has both block_size=%s and dflash_config.block_size=%s; " + "using dflash_config.block_size for speculative_num_draft_tokens inference.", + top_level_block_size, + dflash_block_size, + ) + else: + inferred_block_size = top_level_block_size except Exception as e: logger.warning( - "Failed to infer DFlash block_size from draft config.json; " + "Failed to infer DFLASH block_size from draft config; " "defaulting speculative_num_draft_tokens to 16. Error: %s", e, ) From 189f17757415421430c7e2290bc5c078655b9e70 Mon Sep 17 00:00:00 2001 From: David Wang Date: Thu, 12 Feb 2026 19:03:19 +0000 Subject: [PATCH 36/73] guards for fused path --- python/sglang/srt/models/dflash.py | 7 +- python/sglang/srt/speculative/dflash_info.py | 12 +- python/sglang/srt/speculative/dflash_utils.py | 26 +++ .../sglang/srt/speculative/dflash_worker.py | 83 +++++++-- .../triton_ops/fused_kv_materialize.py | 163 ++++++++++++------ 5 files changed, 216 insertions(+), 75 deletions(-) diff --git a/python/sglang/srt/models/dflash.py b/python/sglang/srt/models/dflash.py index cfaef193cc10..d72d13661524 100644 --- a/python/sglang/srt/models/dflash.py +++ b/python/sglang/srt/models/dflash.py @@ -20,13 +20,13 @@ QKVParallelLinear, RowParallelLinear, ) -from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.radix_attention import AttentionType, RadixAttention from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.utils import apply_qk_norm from sglang.srt.speculative.dflash_utils import ( + can_dflash_slice_qkv_weight, get_dflash_config, resolve_dflash_target_layer_ids, ) @@ -135,9 +135,8 @@ def kv_proj_only( we only need K/V for the cached tokens; Q is never consumed. """ # Fast path for unquantized weights: slice the fused QKV weight and run one GEMM. - if isinstance( - getattr(self.qkv_proj, "quant_method", None), UnquantizedLinearMethod - ): + can_slice_qkv_weight, _ = can_dflash_slice_qkv_weight(self.qkv_proj) + if can_slice_qkv_weight: kv_slice = slice(self.q_size, self.q_size + 2 * self.kv_size) weight = self.qkv_proj.weight[kv_slice] bias = ( diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py index e5be27117292..828c791f47f2 100644 --- a/python/sglang/srt/speculative/dflash_info.py +++ b/python/sglang/srt/speculative/dflash_info.py @@ -381,8 +381,18 @@ def verify( if req.grammar is not None: req.grammar.accept_token(int(tok)) + if req.output_ids: + new_verified_token = int(req.output_ids[-1]) + elif req.origin_input_ids: + # If no token was appended in this verify step, keep the current token unchanged. + new_verified_token = int(req.origin_input_ids[-1]) + else: + raise RuntimeError( + "DFLASH verify cannot determine current token: both output_ids and origin_input_ids are empty." + ) + commit_lens_cpu.append(appended) - new_verified_list.append(req.output_ids[-1]) + new_verified_list.append(new_verified_token) accept_length_per_req_cpu.append(max(0, appended - 1)) req.spec_verify_ct += 1 req.spec_accepted_tokens += accept_length_per_req_cpu[-1] diff --git a/python/sglang/srt/speculative/dflash_utils.py b/python/sglang/srt/speculative/dflash_utils.py index c5b37bc5f75f..d8d7a14c2460 100644 --- a/python/sglang/srt/speculative/dflash_utils.py +++ b/python/sglang/srt/speculative/dflash_utils.py @@ -5,6 +5,8 @@ import torch +from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod + DEFAULT_DFLASH_MASK_TOKEN = "<|MASK|>" @@ -144,6 +146,30 @@ def resolve_dflash_mask_token_id(*, draft_hf_config: Any) -> Optional[int]: return mask_token_id +def can_dflash_slice_qkv_weight(qkv_proj: Any) -> Tuple[bool, str]: + """Validate whether DFlash can slice KV weights from a fused QKV linear layer.""" + quant_method = getattr(qkv_proj, "quant_method", None) + if not isinstance(quant_method, UnquantizedLinearMethod): + return ( + False, + "quantized qkv_proj is not supported for this path " + f"(quant_method={type(quant_method).__name__})", + ) + if not hasattr(qkv_proj, "weight"): + return False, "qkv weight tensor is missing" + return True, "" + + +def can_dflash_use_fused_qkv_proj(qkv_proj: Any) -> Tuple[bool, str]: + """Validate whether a QKV layer is eligible for DFlash fused KV materialization.""" + eligible, reason = can_dflash_slice_qkv_weight(qkv_proj) + if not eligible: + return False, reason + if getattr(qkv_proj, "bias", None) is not None: + return False, "qkv bias is not supported for fused KV path" + return True, "" + + def compute_dflash_accept_len_and_bonus( *, candidates: torch.Tensor, diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index b365d41b6f2d..b5225b368c46 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -1,4 +1,5 @@ import logging +import math from copy import deepcopy from typing import Optional, Union @@ -16,6 +17,7 @@ from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.dflash_info import DFlashDraftInput, DFlashVerifyInput from sglang.srt.speculative.dflash_utils import ( + can_dflash_use_fused_qkv_proj, resolve_dflash_mask_token, resolve_dflash_mask_token_id, ) @@ -191,15 +193,46 @@ def __init__( def _init_fused_kv_helper(self) -> None: """Initialize the fused KV materialization helper with pre-stacked weights.""" try: - FusedKVMaterializeHelper = _get_fused_kv_materialize_helper() layers = self.draft_model.layers + fused_disable_reason: Optional[str] = None + if len(layers) == 0: - logger.warning( - "DFLASH fused KV: no layers found, disabling fused path." - ) + fused_disable_reason = "no layers found" + + for layer_idx, layer in enumerate(layers): + attn = layer.self_attn + eligible, reason = can_dflash_use_fused_qkv_proj(attn.qkv_proj) + if not eligible: + fused_disable_reason = f"{reason}: layer={layer_idx}" + break + + # Keep semantics aligned with set_kv_buffer scaling behavior. + k_scale = getattr(attn.attn, "k_scale", None) + v_scale = getattr(attn.attn, "v_scale", None) + if k_scale is not None and not math.isclose(float(k_scale), 1.0): + fused_disable_reason = ( + "non-unit k_scale is not supported for fused KV path: " + f"layer={layer_idx}, k_scale={k_scale}" + ) + break + if v_scale is not None and not math.isclose(float(v_scale), 1.0): + fused_disable_reason = ( + "non-unit v_scale is not supported for fused KV path: " + f"layer={layer_idx}, v_scale={v_scale}" + ) + break + + if fused_disable_reason is not None: + if self.tp_rank == 0: + logger.info( + "DFLASH fused KV materialization disabled: %s", + fused_disable_reason, + ) self._use_fused_kv_materialize = False + self._fused_kv_helper = None return + FusedKVMaterializeHelper = _get_fused_kv_materialize_helper() first_attn = layers[0].self_attn rotary_emb = first_attn.rotary_emb @@ -773,9 +806,20 @@ def _append_target_hidden_to_draft_kv( ) if self._use_fused_kv_materialize and self._fused_kv_helper is not None: - self._append_target_hidden_fused( - ctx_hidden, ctx_positions, ctx_cache_loc - ) + try: + self._append_target_hidden_fused( + ctx_hidden, ctx_positions, ctx_cache_loc + ) + except Exception as e: + logger.warning( + "DFLASH fused KV append failed; falling back to sequential path: %s", + e, + ) + self._use_fused_kv_materialize = False + self._fused_kv_helper = None + self._append_target_hidden_sequential( + ctx_hidden, ctx_positions, ctx_cache_loc + ) else: self._append_target_hidden_sequential( ctx_hidden, ctx_positions, ctx_cache_loc @@ -815,20 +859,25 @@ def _append_target_hidden_fused( ) -> None: """Fused KV materialization using batched projection + Triton kernel.""" token_to_kv_pool = self.draft_model_runner.token_to_kv_pool - k_cache_buffers = [] - v_cache_buffers = [] - for layer in self.draft_model.layers: - layer_id = layer.self_attn.attn.layer_id - k_buf, v_buf = token_to_kv_pool.get_kv_buffer(layer_id) - k_cache_buffers.append(k_buf) - v_cache_buffers.append(v_buf) + layers = self.draft_model.layers + + def _write_layer_kv( + layer_idx: int, cache_k: torch.Tensor, cache_v: torch.Tensor + ) -> None: + attn = layers[layer_idx].self_attn.attn + token_to_kv_pool.set_kv_buffer( + attn, + ctx_cache_loc, + cache_k, + cache_v, + attn.k_scale, + attn.v_scale, + ) self._fused_kv_helper.materialize( ctx_hidden=ctx_hidden, positions=ctx_positions, - cache_locs=ctx_cache_loc, - k_cache_buffers=k_cache_buffers, - v_cache_buffers=v_cache_buffers, + write_layer_kv=_write_layer_kv, ) def forward_batch_generation( diff --git a/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py b/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py index 97cddd454efb..e7dc4c05ddfc 100644 --- a/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py +++ b/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py @@ -13,10 +13,10 @@ # ============================================================================== """Fused Triton kernel for DFlash KV materialization. -Combines: KV projection (cuBLAS) + RMSNorm + RoPE + KV cache write (Triton). +Combines: KV projection (cuBLAS) + RMSNorm + RoPE (Triton), then pool-managed KV writes. """ -from typing import List +from typing import Callable, List import torch import triton @@ -24,20 +24,19 @@ @triton.jit -def _fused_norm_rope_write_kernel( +def _fused_norm_rope_kernel( kv_ptr, # [total_ctx, kv_size * 2] k_norm_weight_ptr, # [head_dim] cos_sin_cache_ptr, # [max_pos, rotary_dim] positions_ptr, # [total_ctx] - cache_loc_ptr, # [total_ctx] - k_cache_ptr, # [total_slots, num_kv_heads, head_dim] - v_cache_ptr, # [total_slots, num_kv_heads, head_dim] + k_out_ptr, # [total_ctx, num_kv_heads, head_dim] + v_out_ptr, # [total_ctx, num_kv_heads, head_dim] kv_stride_ctx, cos_sin_stride_pos, - k_cache_stride_slot, - k_cache_stride_head, - v_cache_stride_slot, - v_cache_stride_head, + k_out_stride_ctx, + k_out_stride_head, + v_out_stride_ctx, + v_out_stride_head, total_ctx, num_kv_heads: tl.constexpr, head_dim: tl.constexpr, @@ -47,7 +46,7 @@ def _fused_norm_rope_write_kernel( eps: tl.constexpr, BLOCK_HD: tl.constexpr, ): - """Fused RMSNorm(K) + RoPE(K) + cache write. Grid: (total_ctx, num_kv_heads).""" + """Fused RMSNorm(K) + RoPE(K) materialization. Grid: (total_ctx, num_kv_heads).""" ctx_id = tl.program_id(0) head_id = tl.program_id(1) if ctx_id >= total_ctx: @@ -55,18 +54,13 @@ def _fused_norm_rope_write_kernel( # Load metadata position = tl.load(positions_ptr + ctx_id) - cache_loc = tl.load(cache_loc_ptr + ctx_id) # Compute base pointers kv_base = kv_ptr + ctx_id * kv_stride_ctx k_base = kv_base + head_id * head_dim v_base = kv_base + kv_size + head_id * head_dim - k_write = ( - k_cache_ptr + cache_loc * k_cache_stride_slot + head_id * k_cache_stride_head - ) - v_write = ( - v_cache_ptr + cache_loc * v_cache_stride_slot + head_id * v_cache_stride_head - ) + k_write = k_out_ptr + ctx_id * k_out_stride_ctx + head_id * k_out_stride_head + v_write = v_out_ptr + ctx_id * v_out_stride_ctx + head_id * v_out_stride_head # Load K and V offs = tl.arange(0, BLOCK_HD) @@ -114,48 +108,63 @@ def _fused_norm_rope_write_kernel( tl.store(k_write + offs, k_normed.to(v_raw.dtype), mask=mask_pass) -def _fused_norm_rope_write( +def _fused_norm_rope( kv: torch.Tensor, # [total_ctx, kv_size*2] k_norm_weight: torch.Tensor, # [head_dim] cos_sin_cache: torch.Tensor, # [max_pos, rotary_dim] positions: torch.Tensor, # [total_ctx] - cache_locs: torch.Tensor, # [total_ctx] - k_cache: torch.Tensor, # [total_slots, num_kv_heads, head_dim] - v_cache: torch.Tensor, # [total_slots, num_kv_heads, head_dim] num_kv_heads: int, head_dim: int, rotary_dim: int, eps: float = 1e-6, -) -> None: - """Fused RMSNorm + RoPE + cache write for a single layer.""" +) -> tuple[torch.Tensor, torch.Tensor]: + """Fused RMSNorm + RoPE materialization for a single layer.""" total_ctx = kv.shape[0] if total_ctx == 0: - return + empty = torch.empty( + (0, num_kv_heads, head_dim), dtype=kv.dtype, device=kv.device + ) + return empty, empty kv_size = num_kv_heads * head_dim + if kv.shape[1] != kv_size * 2: + raise ValueError( + "Invalid fused KV projection shape: " + f"got {tuple(kv.shape)}, expected second dim {kv_size * 2}." + ) + if rotary_dim <= 0 or rotary_dim > head_dim or rotary_dim % 2 != 0: + raise ValueError( + "Invalid fused KV rotary/head dim pair: " + f"rotary_dim={rotary_dim}, head_dim={head_dim}." + ) + half_rotary_dim = rotary_dim // 2 BLOCK_HD = triton.next_power_of_2(head_dim) # Ensure int64 for indexing - if positions.dtype != torch.int64: + if positions.device != kv.device: + positions = positions.to(device=kv.device, dtype=torch.int64) + elif positions.dtype != torch.int64: positions = positions.to(torch.int64) - if cache_locs.dtype != torch.int64: - cache_locs = cache_locs.to(torch.int64) - _fused_norm_rope_write_kernel[(total_ctx, num_kv_heads)]( + k_out = torch.empty( + (total_ctx, num_kv_heads, head_dim), dtype=kv.dtype, device=kv.device + ) + v_out = torch.empty_like(k_out) + + _fused_norm_rope_kernel[(total_ctx, num_kv_heads)]( kv, k_norm_weight, cos_sin_cache, positions, - cache_locs, - k_cache, - v_cache, + k_out, + v_out, kv.stride(0), cos_sin_cache.stride(0), - k_cache.stride(0), - k_cache.stride(1), - v_cache.stride(0), - v_cache.stride(1), + k_out.stride(0), + k_out.stride(1), + v_out.stride(0), + v_out.stride(1), total_ctx, num_kv_heads, head_dim, @@ -165,13 +174,14 @@ def _fused_norm_rope_write( eps, BLOCK_HD, ) + return k_out, v_out class FusedKVMaterializeHelper: """Fused KV materialization helper using batched projection. Uses torch.einsum for batched KV projection across all layers, - then a Triton kernel for fused RMSNorm + RoPE + cache write per layer. + then a Triton kernel for fused RMSNorm + RoPE materialization per layer. """ def __init__( @@ -188,13 +198,48 @@ def __init__( self.n_layers = len(layers) self.device = device - # Pre-extract and stack weights for batched projection + self.rotary_dim = int(getattr(rotary_emb, "rotary_dim", head_dim)) + self.is_neox_style = bool(getattr(rotary_emb, "is_neox_style", True)) + + if not self.is_neox_style: + raise NotImplementedError("Only neox-style RoPE is supported.") + if self.rotary_dim <= 0 or self.rotary_dim > self.head_dim: + raise ValueError( + "Invalid fused KV rotary/head dim pair: " + f"rotary_dim={self.rotary_dim}, head_dim={self.head_dim}." + ) + + # Pre-extract and stack weights for batched projection. kv_weights = [] self.k_norm_weights = [] self.eps_values = [] - for layer in layers: + for layer_id, layer in enumerate(layers): attn = layer.self_attn + if int(attn.num_kv_heads) != self.num_kv_heads: + raise ValueError( + "num_kv_heads mismatch across layers for fused KV path: " + f"expected {self.num_kv_heads}, got {int(attn.num_kv_heads)} at layer {layer_id}." + ) + if int(attn.head_dim) != self.head_dim: + raise ValueError( + "head_dim mismatch across layers for fused KV path: " + f"expected {self.head_dim}, got {int(attn.head_dim)} at layer {layer_id}." + ) + layer_rotary_dim = int( + getattr(attn.rotary_emb, "rotary_dim", self.head_dim) + ) + layer_is_neox = bool(getattr(attn.rotary_emb, "is_neox_style", True)) + if ( + layer_rotary_dim != self.rotary_dim + or layer_is_neox != self.is_neox_style + ): + raise ValueError( + "RoPE config mismatch across layers for fused KV path: " + f"expected (rotary_dim={self.rotary_dim}, neox={self.is_neox_style}), " + f"got (rotary_dim={layer_rotary_dim}, neox={layer_is_neox}) at layer {layer_id}." + ) + # Extract KV portion of QKV weight qkv_w = attn.qkv_proj.weight kv_weight = qkv_w[attn.q_size : attn.q_size + 2 * attn.kv_size] @@ -205,42 +250,54 @@ def __init__( # Stack for batched einsum: [n_layers, kv_size*2, hidden_size] self.batched_kv_weight = torch.stack(kv_weights) - self.rotary_dim = getattr(rotary_emb, "rotary_dim", head_dim) - self.is_neox_style = getattr(rotary_emb, "is_neox_style", True) - - if not self.is_neox_style: - raise NotImplementedError("Only neox-style RoPE is supported.") - def materialize( self, ctx_hidden: torch.Tensor, positions: torch.Tensor, - cache_locs: torch.Tensor, - k_cache_buffers: List[torch.Tensor], - v_cache_buffers: List[torch.Tensor], + write_layer_kv: Callable[[int, torch.Tensor, torch.Tensor], None], ) -> None: """Materialize KV cache for all layers using batched projection.""" total_ctx = ctx_hidden.shape[0] if total_ctx == 0: return + if positions.ndim != 1: + positions = positions.reshape(-1) + if positions.numel() != total_ctx: + raise ValueError( + "positions must match ctx_hidden token count for fused KV materialization: " + f"positions={positions.numel()}, total_ctx={total_ctx}." + ) + + max_position = int(positions.max().item()) + ensure_cos_sin_cache_length = getattr( + self.rotary_emb, "_ensure_cos_sin_cache_length", None + ) + if callable(ensure_cos_sin_cache_length): + ensure_cos_sin_cache_length(max_position) + cos_sin_cache = self.rotary_emb.cos_sin_cache + if max_position >= int(cos_sin_cache.shape[0]): + raise RuntimeError( + "RoPE cos/sin cache is too short for fused KV materialization: " + f"max_position={max_position}, cache_len={int(cos_sin_cache.shape[0])}." + ) + if cos_sin_cache.device != ctx_hidden.device: + cos_sin_cache = cos_sin_cache.to(ctx_hidden.device) # Batched KV projection: [n_layers, total_ctx, kv_size*2] kv_all = torch.einsum("th,loh->lto", ctx_hidden, self.batched_kv_weight) - # Per-layer fused norm/RoPE/write + # Per-layer fused norm/RoPE/materialize, then delegate writes to the KV pool. for layer_id in range(self.n_layers): - _fused_norm_rope_write( + cache_k, cache_v = _fused_norm_rope( kv_all[layer_id], self.k_norm_weights[layer_id], cos_sin_cache, positions, - cache_locs, - k_cache_buffers[layer_id], - v_cache_buffers[layer_id], self.num_kv_heads, self.head_dim, self.rotary_dim, self.eps_values[layer_id], ) + write_layer_kv(layer_id, cache_k, cache_v) From 0841db68bb38e78b8e75cab0665bf97af633b4c9 Mon Sep 17 00:00:00 2001 From: David Wang Date: Thu, 12 Feb 2026 19:08:21 +0000 Subject: [PATCH 37/73] add support for gpt oss --- python/sglang/srt/models/gpt_oss.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index 608b23cd16ee..c1e1ae433cb2 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -1087,6 +1087,9 @@ def _load_normal_weights( def get_embed_and_head(self): return self.model.embed_tokens.weight, self.lm_head.weight + def get_input_embeddings(self) -> nn.Embedding: + return self.model.embed_tokens + def set_embed_and_head(self, embed, head): del self.model.embed_tokens.weight del self.lm_head.weight @@ -1109,6 +1112,18 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None): # of the (i-1)th layer as aux hidden state self.model.layers_to_capture = [val + 1 for val in layer_ids] + def set_dflash_layers_to_capture(self, layer_ids: List[int]): + if not self.pp_group.is_last_rank: + return + + if layer_ids is None: + raise ValueError( + "DFLASH requires explicit layer_ids for aux hidden capture." + ) + + self.capture_aux_hidden_states = True + self.model.layers_to_capture = [val + 1 for val in layer_ids] + @classmethod def get_model_config_for_expert_location(cls, config): return ModelConfigForExpertLocation( From 56477b97de3bfb6be4b5035a225db218cc2f11e2 Mon Sep 17 00:00:00 2001 From: David Wang Date: Thu, 12 Feb 2026 21:40:28 +0000 Subject: [PATCH 38/73] clean up --- python/sglang/srt/managers/scheduler.py | 20 ++++ python/sglang/srt/models/dflash.py | 28 +----- python/sglang/srt/server_args.py | 96 ++++++++++--------- python/sglang/srt/speculative/dflash_utils.py | 68 ++++++++++++- .../sglang/srt/utils/hf_transformers_utils.py | 5 +- 5 files changed, 142 insertions(+), 75 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index e98a58964034..d850b84a14d8 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1479,6 +1479,26 @@ def handle_generate_request( self._add_request_to_queue(req) return + if self.spec_algorithm.is_dflash() and req.return_logprob: + req.set_finish_with_abort( + "DFLASH speculative decoding does not support return_logprob yet." + ) + self.init_req_max_new_tokens(req) + self._add_request_to_queue(req) + return + if self.spec_algorithm.is_dflash() and ( + req.sampling_params.json_schema is not None + or req.sampling_params.regex is not None + or req.sampling_params.ebnf is not None + or req.sampling_params.structural_tag is not None + ): + req.set_finish_with_abort( + "DFLASH speculative decoding does not support grammar-constrained decoding yet." + ) + self.init_req_max_new_tokens(req) + self._add_request_to_queue(req) + return + # Handle multimodal inputs if recv_req.mm_inputs is not None: image_inputs = self._get_multimodal_inputs(recv_req.mm_inputs) diff --git a/python/sglang/srt/models/dflash.py b/python/sglang/srt/models/dflash.py index d72d13661524..593bd311c0ca 100644 --- a/python/sglang/srt/models/dflash.py +++ b/python/sglang/srt/models/dflash.py @@ -27,7 +27,7 @@ from sglang.srt.models.utils import apply_qk_norm from sglang.srt.speculative.dflash_utils import ( can_dflash_slice_qkv_weight, - get_dflash_config, + resolve_dflash_block_size, resolve_dflash_target_layer_ids, ) @@ -259,8 +259,6 @@ def __init__(self, config, quant_config=None, prefix: str = "") -> None: num_layers = int(config.num_hidden_layers) rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-6)) - dflash_cfg_dict = get_dflash_config(config) - self.layers = nn.ModuleList( [DFlashDecoderLayer(config=config, layer_id=i) for i in range(num_layers)] ) @@ -283,29 +281,7 @@ def __init__(self, config, quant_config=None, prefix: str = "") -> None: ) self.hidden_norm = RMSNorm(hidden_size, eps=rms_norm_eps) - dflash_block_size = dflash_cfg_dict.get("block_size", None) - - block_size = ( - dflash_block_size - if dflash_block_size is not None - else getattr(config, "block_size", None) - ) - if block_size is None: - block_size = 16 - elif ( - getattr(config, "block_size", None) is not None - and dflash_block_size is not None - ): - if int(dflash_block_size) != int(getattr(config, "block_size")): - logger.warning( - "DFLASH draft config has both block_size=%s and dflash_config.block_size=%s; using dflash_config.block_size.", - getattr(config, "block_size"), - dflash_block_size, - ) - try: - self.block_size = int(block_size) - except Exception as e: - raise ValueError(f"Invalid DFLASH block_size={block_size!r}.") from e + self.block_size = resolve_dflash_block_size(draft_hf_config=config, default=16) def project_target_hidden(self, target_hidden: torch.Tensor) -> torch.Tensor: """Project concatenated target-layer hidden states into draft hidden_size.""" diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ee4335acc693..92b859338033 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2082,6 +2082,12 @@ def _handle_speculative_decoding(self): "Currently DFLASH speculative decoding only supports pp_size == 1." ) + if self.page_size != 1: + raise ValueError( + "Currently DFLASH speculative decoding requires page_size == 1. " + f"Got page_size={self.page_size}." + ) + if self.speculative_draft_model_path is None: raise ValueError( "DFLASH speculative decoding requires setting --speculative-draft-model-path." @@ -2130,73 +2136,69 @@ def _handle_speculative_decoding(self): ) if self.speculative_num_draft_tokens is None: + from sglang.srt.speculative.dflash_utils import ( + resolve_dflash_block_size, + ) + + model_override_args = json.loads(self.json_model_override_args) inferred_block_size = None try: - top_level_block_size = None - dflash_block_size = None - if os.path.isdir(self.speculative_draft_model_path): - draft_config_path = os.path.join( - self.speculative_draft_model_path, "config.json" + from sglang.srt.utils.hf_transformers_utils import download_from_hf + + config_root = ( + self.speculative_draft_model_path + if os.path.isdir(self.speculative_draft_model_path) + else download_from_hf( + self.speculative_draft_model_path, + allow_patterns=["config.json"], + revision=self.speculative_draft_model_revision, ) - if os.path.isfile(draft_config_path): - with open(draft_config_path, "r") as f: - draft_config_json = json.load(f) - top_level_block_size = draft_config_json.get("block_size") - dflash_cfg = draft_config_json.get("dflash_config", None) - dflash_block_size = ( - dflash_cfg.get("block_size") - if isinstance(dflash_cfg, dict) - else None - ) + ) + draft_config_path = os.path.join(config_root, "config.json") + if os.path.isfile(draft_config_path): + with open(draft_config_path, "r") as f: + draft_config_json = json.load(f) + if model_override_args: + draft_config_json.update(model_override_args) + inferred_block_size = resolve_dflash_block_size( + draft_hf_config=draft_config_json, + default=None, + ) + except Exception as e: + logger.warning( + "Failed to infer DFLASH block_size from draft config.json; " + "falling back to transformers config loader. Error: %s", + e, + ) - if top_level_block_size is None and dflash_block_size is None: + if inferred_block_size is None: + try: from sglang.srt.utils.hf_transformers_utils import get_config draft_hf_config = get_config( self.speculative_draft_model_path, trust_remote_code=self.trust_remote_code, revision=self.speculative_draft_model_revision, - model_override_args=json.loads( - self.json_model_override_args - ), + model_override_args=model_override_args, ) - top_level_block_size = getattr( - draft_hf_config, "block_size", None + inferred_block_size = resolve_dflash_block_size( + draft_hf_config=draft_hf_config, + default=None, ) - dflash_cfg = getattr(draft_hf_config, "dflash_config", None) - dflash_block_size = ( - dflash_cfg.get("block_size") - if isinstance(dflash_cfg, dict) - else None + except Exception as e: + logger.warning( + "Failed to infer DFLASH block_size from transformers config loader; " + "defaulting speculative_num_draft_tokens to 16. Error: %s", + e, ) - if dflash_block_size is not None: - inferred_block_size = dflash_block_size - if top_level_block_size is not None and int( - dflash_block_size - ) != int(top_level_block_size): - logger.warning( - "DFLASH draft config has both block_size=%s and dflash_config.block_size=%s; " - "using dflash_config.block_size for speculative_num_draft_tokens inference.", - top_level_block_size, - dflash_block_size, - ) - else: - inferred_block_size = top_level_block_size - except Exception as e: - logger.warning( - "Failed to infer DFLASH block_size from draft config; " - "defaulting speculative_num_draft_tokens to 16. Error: %s", - e, - ) - if inferred_block_size is None: inferred_block_size = 16 logger.warning( "speculative_num_draft_tokens is not set; defaulting to %d for DFLASH.", inferred_block_size, ) - self.speculative_num_draft_tokens = int(inferred_block_size) + self.speculative_num_draft_tokens = inferred_block_size if self.max_running_requests is None: self.max_running_requests = 48 diff --git a/python/sglang/srt/speculative/dflash_utils.py b/python/sglang/srt/speculative/dflash_utils.py index d8d7a14c2460..766d14aa2175 100644 --- a/python/sglang/srt/speculative/dflash_utils.py +++ b/python/sglang/srt/speculative/dflash_utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from numbers import Integral from typing import Any, List, Optional, Tuple @@ -8,6 +9,7 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod DEFAULT_DFLASH_MASK_TOKEN = "<|MASK|>" +logger = logging.getLogger(__name__) def build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> List[int]: @@ -54,7 +56,10 @@ def build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> Lis def get_dflash_config(config: Any) -> dict: - cfg = getattr(config, "dflash_config", None) + if isinstance(config, dict): + cfg = config.get("dflash_config", None) + else: + cfg = getattr(config, "dflash_config", None) if cfg is None: return {} if isinstance(cfg, dict): @@ -66,6 +71,67 @@ def get_dflash_config(config: Any) -> dict: return {} +def resolve_dflash_block_size( + *, + draft_hf_config: Any, + default: Optional[int] = None, +) -> Optional[int]: + """Resolve DFLASH block size from draft config. + + Precedence: + 1) `dflash_config.block_size` + 2) top-level `block_size` + 3) `default` + """ + dflash_cfg = get_dflash_config(draft_hf_config) + dflash_block_size = dflash_cfg.get("block_size", None) + if isinstance(draft_hf_config, dict): + top_level_block_size = draft_hf_config.get("block_size", None) + else: + top_level_block_size = getattr(draft_hf_config, "block_size", None) + + parsed_dflash_block_size = None + if dflash_block_size is not None: + try: + parsed_dflash_block_size = int(dflash_block_size) + except Exception as e: + raise ValueError( + f"Invalid DFLASH dflash_config.block_size={dflash_block_size!r}." + ) from e + + parsed_top_level_block_size = None + if top_level_block_size is not None: + try: + parsed_top_level_block_size = int(top_level_block_size) + except Exception as e: + raise ValueError( + f"Invalid DFLASH block_size={top_level_block_size!r}." + ) from e + + if ( + parsed_dflash_block_size is not None + and parsed_top_level_block_size is not None + and parsed_dflash_block_size != parsed_top_level_block_size + ): + logger.warning( + "DFLASH draft config has both block_size=%s and dflash_config.block_size=%s; using dflash_config.block_size.", + top_level_block_size, + dflash_block_size, + ) + + block_size = ( + parsed_dflash_block_size + if parsed_dflash_block_size is not None + else parsed_top_level_block_size + ) + if block_size is None: + return default + + if block_size <= 0: + raise ValueError(f"DFLASH block_size must be positive, got {block_size}.") + return block_size + + def resolve_dflash_target_layer_ids( *, draft_hf_config: Any, diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index f88c4889d431..976fa7181234 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -105,6 +105,7 @@ def download_from_hf( model_path: str, allow_patterns: Optional[Union[str, list]] = None, + revision: Optional[str] = None, ): if os.path.exists(model_path): return model_path @@ -112,7 +113,9 @@ def download_from_hf( if not allow_patterns: allow_patterns = ["*.json", "*.bin", "*.model"] - return snapshot_download(model_path, allow_patterns=allow_patterns) + return snapshot_download( + model_path, allow_patterns=allow_patterns, revision=revision + ) def get_hf_text_config(config: PretrainedConfig): From d9c68a1f0ee7d96f820632db02aef403572c4aae Mon Sep 17 00:00:00 2001 From: David Wang Date: Tue, 24 Feb 2026 01:10:10 +0000 Subject: [PATCH 39/73] add qwen3-coder-next support (mamba) --- python/sglang/srt/models/qwen3_next.py | 20 ++++ python/sglang/srt/speculative/dflash_info.py | 3 + .../sglang/srt/speculative/dflash_worker.py | 93 +++++++++++++++++-- 3 files changed, 107 insertions(+), 9 deletions(-) diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index a013dca531e4..6aabcd9e55a0 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -868,6 +868,11 @@ def set_eagle3_layers_to_capture(self, layers_to_capture: list[int]): for layer_id in self.layers_to_capture: setattr(self.layers[layer_id], "_is_layer_to_capture", True) + def set_dflash_layers_to_capture(self, layers_to_capture: list[int]): + self.layers_to_capture = layers_to_capture + for layer_id in self.layers_to_capture: + setattr(self.layers[layer_id], "_is_layer_to_capture", True) + def forward( self, input_ids: torch.Tensor, @@ -1002,6 +1007,9 @@ def forward( def get_embed_and_head(self): return self.model.embed_tokens.weight, self.lm_head.weight + def get_input_embeddings(self) -> nn.Embedding: + return self.model.embed_tokens + def set_embed_and_head(self, embed, head): del self.model.embed_tokens.weight del self.lm_head.weight @@ -1170,6 +1178,18 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[list[int]] = None): else: self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids]) + def set_dflash_layers_to_capture(self, layer_ids: list[int]): + if not self.pp_group.is_last_rank: + return + + if layer_ids is None: + raise ValueError( + "DFLASH requires explicit layer_ids for aux hidden capture." + ) + + self.capture_aux_hidden_states = True + self.model.set_dflash_layers_to_capture([val + 1 for val in layer_ids]) + EntryClass = Qwen3NextForCausalLM diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py index 828c791f47f2..abe133a8fc61 100644 --- a/python/sglang/srt/speculative/dflash_info.py +++ b/python/sglang/srt/speculative/dflash_info.py @@ -124,6 +124,9 @@ class DFlashVerifyInput(SpecInput): draft_token: torch.Tensor positions: torch.Tensor draft_token_num: int + # Kept for compatibility with attention backends that gate tree metadata by `topk > 1`. + # DFLASH verify is linear (non-tree), so this is always 1. + topk: int = 1 # Custom attention "allow mask" for TARGET_VERIFY in backends that require it (e.g. triton). # Semantics follow SGLang speculative conventions: True means the (q, k) pair is allowed. custom_mask: torch.Tensor | None = None diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index b5225b368c46..48c81e84fe14 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -44,6 +44,14 @@ def _get_fused_kv_materialize_helper(): class DFlashWorker: """DFlash speculative decoding worker (spec-v1, tp>=1/pp=1).""" + _VERIFY_SKIP_CUSTOM_MASK_BACKENDS = { + "FlashInferAttnBackend", + "FlashInferMLAAttnBackend", + "FlashAttentionBackend", + "TRTLLMHAAttnBackend", + "TRTLLMMLABackend", + } + def __init__( self, server_args: ServerArgs, @@ -518,15 +526,7 @@ def _prepare_for_speculative_decoding( positions=positions, draft_token_num=self.block_size, ) - backend_name = type(self.model_runner.attn_backend).__name__ - skip_custom_mask = backend_name in { - "FlashInferAttnBackend", - "FlashInferMLAAttnBackend", - "FlashAttentionBackend", - "TRTLLMHAAttnBackend", - "TRTLLMMLABackend", - } - build_custom_mask = not skip_custom_mask + _, build_custom_mask = self._resolve_verify_mask_policy() verify_input.prepare_for_verify( batch, self.page_size, @@ -880,6 +880,67 @@ def _write_layer_kv( write_layer_kv=_write_layer_kv, ) + def _resolve_verify_mask_backend_name(self) -> str: + backend = self.model_runner.attn_backend + for _ in range(4): + full_backend = getattr(backend, "full_attn_backend", None) + if full_backend is None: + break + backend = full_backend + return type(backend).__name__ + + def _resolve_verify_mask_policy(self) -> tuple[str, bool]: + backend_name = self._resolve_verify_mask_backend_name() + return backend_name, ( + backend_name not in self._VERIFY_SKIP_CUSTOM_MASK_BACKENDS + ) + + def _update_target_mamba_state_after_verify( + self, + *, + batch: ScheduleBatch, + seq_lens_pre_verify: torch.Tensor, + commit_lens: torch.Tensor, + ) -> None: + """Commit Mamba intermediate states for accepted verify steps. + + During TARGET_VERIFY, Mamba kernels run with `disable_state_update=True` and + cache per-step intermediate states. After acceptance, we need to commit the + state corresponding to each request's last accepted step. + """ + attn_backend = self.target_worker.model_runner.attn_backend + if not hasattr(attn_backend, "update_mamba_state_after_mtp_verify"): + return + + accepted_steps = commit_lens.to(torch.int64) - 1 + mamba_steps_to_track = None + + if batch.mamba_track_indices is not None: + mamba_track_interval = self.server_args.mamba_track_interval + to_track_mask = ( + seq_lens_pre_verify // mamba_track_interval + != batch.seq_lens // mamba_track_interval + ) + tracking_point = ( + batch.seq_lens // mamba_track_interval * mamba_track_interval + ) + to_track_ith = torch.clamp(tracking_point - seq_lens_pre_verify - 1, min=0) + can_track_mask = to_track_mask & ( + to_track_ith < commit_lens.to(to_track_ith.dtype) + ) + mamba_steps_to_track = torch.where( + can_track_mask, + to_track_ith.to(torch.int64), + torch.full_like(to_track_ith, -1, dtype=torch.int64), + ) + + attn_backend.update_mamba_state_after_mtp_verify( + accepted_steps=accepted_steps, + mamba_track_indices=batch.mamba_track_indices, + mamba_steps_to_track=mamba_steps_to_track, + model=self.target_worker.model_runner.model, + ) + def forward_batch_generation( self, batch: Union[ScheduleBatch, ModelWorkerBatch], @@ -964,6 +1025,13 @@ def _to_int32_device_tensor(x, *, device=device): assert model_worker_batch.forward_mode.is_target_verify() verify_input = model_worker_batch.spec_info assert isinstance(verify_input, DFlashVerifyInput) + need_mamba_verify_commit = hasattr( + self.target_worker.model_runner.attn_backend, + "update_mamba_state_after_mtp_verify", + ) + seq_lens_pre_verify = ( + batch.seq_lens.clone() if need_mamba_verify_commit else None + ) batch_result = self.target_worker.forward_batch_generation( model_worker_batch, is_verify=True, **kwargs @@ -983,6 +1051,13 @@ def _to_int32_device_tensor(x, *, device=device): logits_output=logits_output, page_size=self.page_size, ) + if need_mamba_verify_commit: + assert seq_lens_pre_verify is not None + self._update_target_mamba_state_after_verify( + batch=batch, + seq_lens_pre_verify=seq_lens_pre_verify, + commit_lens=commit_lens, + ) # Update draft state for the next iteration. Also materialize the committed verify tokens # into the draft KV cache immediately so radix cache entries are safe to reuse. From 7e189bdc7937e10f5ce79651c91b1acf5c0b1c07 Mon Sep 17 00:00:00 2001 From: David Wang Date: Tue, 24 Feb 2026 22:11:00 +0000 Subject: [PATCH 40/73] add page size > 1 support --- benchmark/dflash/bench_dflash_gsm8k_sweep.py | 24 ++++++----- .../sglang/srt/model_executor/model_runner.py | 15 +++++++ python/sglang/srt/server_args.py | 6 --- python/sglang/srt/speculative/dflash_info.py | 41 +++++++++++++++++-- .../sglang/srt/speculative/dflash_worker.py | 19 ++++++++- 5 files changed, 85 insertions(+), 20 deletions(-) diff --git a/benchmark/dflash/bench_dflash_gsm8k_sweep.py b/benchmark/dflash/bench_dflash_gsm8k_sweep.py index 5d34ddb06cd3..ca2677ee78ab 100644 --- a/benchmark/dflash/bench_dflash_gsm8k_sweep.py +++ b/benchmark/dflash/bench_dflash_gsm8k_sweep.py @@ -332,6 +332,12 @@ def main() -> None: parser.add_argument("--mem-fraction-static", type=float, default=0.75) parser.add_argument("--disable-radix-cache", action="store_true") parser.add_argument("--dtype", type=str, default="bfloat16") + parser.add_argument( + "--page-size", + type=int, + default=None, + help="Optional server --page-size override for both baseline and DFLASH runs.", + ) parser.add_argument("--max-running-requests", type=int, default=64) parser.add_argument( "--tp-sizes", @@ -361,7 +367,7 @@ def main() -> None: "--attention-backends", type=str, default="flashinfer,fa3", - help="Comma-separated list. Will auto-skip fa3 on Blackwell/SM<90.", + help="Comma-separated list. Auto-skips unsupported backends for the current GPU.", ) args = parser.parse_args() @@ -397,9 +403,11 @@ def main() -> None: is_blackwell = _is_blackwell() device_sm = get_device_sm() if is_blackwell: - attention_backends = [b for b in attention_backends if b == "flashinfer"] + attention_backends = [b for b in attention_backends if b != "fa3"] if device_sm < 90: attention_backends = [b for b in attention_backends if b != "fa3"] + if device_sm < 100: + attention_backends = [b for b in attention_backends if b != "trtllm_mha"] attention_backends = attention_backends or ["flashinfer"] data_path = _maybe_download_gsm8k(args.data_path) @@ -471,17 +479,13 @@ def main() -> None: str(args.mem_fraction_static), "--max-running-requests", str(args.max_running_requests), + "--cuda-graph-max-bs", + "32", ] - common_server_args.extend( - [ - "--cuda-graph-bs", - *[str(i) for i in range(1, 33)], - "--cuda-graph-max-bs", - "32", - ] - ) if args.disable_radix_cache: common_server_args.append("--disable-radix-cache") + if args.page_size is not None: + common_server_args.extend(["--page-size", str(int(args.page_size))]) if not args.skip_baseline: print(f"\n=== backend={backend} tp={tp} (baseline) ===") diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 4ed5b1d106ce..2ef72f1ab119 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2019,6 +2019,21 @@ def get_spec_info(): seq_lens_sum=None, seq_lens_cpu=None, ) + elif self.spec_algorithm.is_dflash(): + from sglang.srt.speculative.dflash_info import DFlashVerifyInput + + # Dummy warmup only needs shape metadata; avoid forcing custom-mask mode. + spec_info = DFlashVerifyInput( + draft_token=None, + positions=None, + draft_token_num=self.server_args.speculative_num_draft_tokens, + custom_mask=None, + capture_hidden_mode=( + CaptureHiddenMode.NULL + if self.is_draft_worker + else CaptureHiddenMode.FULL + ), + ) elif self.spec_algorithm.is_ngram(): from sglang.srt.speculative.ngram_info import NgramVerifyInput diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 8062669df6e2..55202bba65d9 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -2303,12 +2303,6 @@ def _handle_speculative_decoding(self): "Currently DFLASH speculative decoding only supports pp_size == 1." ) - if self.page_size != 1: - raise ValueError( - "Currently DFLASH speculative decoding requires page_size == 1. " - f"Got page_size={self.page_size}." - ) - if self.speculative_draft_model_path is None: raise ValueError( "DFLASH speculative decoding requires setting --speculative-draft-model-path." diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py index abe133a8fc61..0d0c5f12dff5 100644 --- a/python/sglang/srt/speculative/dflash_info.py +++ b/python/sglang/srt/speculative/dflash_info.py @@ -19,6 +19,32 @@ from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func +def _compute_paged_keep_slots( + *, + prefix_lens: torch.Tensor, + commit_lens: torch.Tensor, + draft_token_num: int, + page_size: int, +) -> torch.Tensor: + """Compute how many draft slots per request must remain allocated. + + The allocator frees at page granularity for paged mode, so we can only release + full pages from the tail after verify. + """ + + if page_size <= 1: + raise ValueError(f"Expected page_size > 1, got {page_size}.") + + seq_dtype = prefix_lens.dtype + extended_lens = prefix_lens + int(draft_token_num) + new_lens = prefix_lens + commit_lens.to(seq_dtype) + aligned_new_lens = ((new_lens + page_size - 1) // page_size) * page_size + keep_lens = torch.minimum(aligned_new_lens, extended_lens) + keep_slots = (keep_lens - prefix_lens).to(torch.int64) + keep_slots.clamp_(min=0, max=int(draft_token_num)) + return keep_slots + + @dataclass class DFlashDraftInput(SpecInput): """Per-batch DFlash draft state for spec-v1 (non-overlap) scheduling. @@ -415,10 +441,19 @@ def verify( batch.token_to_kv_pool_allocator.free(out_cache_loc[~keep_mask]) batch.out_cache_loc = out_cache_loc[keep_mask] else: - # Page-size > 1 is not supported in the initial DFlash implementation. - raise NotImplementedError( - "DFLASH verify with page_size > 1 is not supported yet." + out_cache_loc = batch.out_cache_loc.view(bs, self.draft_token_num) + row_offsets = torch.arange(self.draft_token_num, device=device)[None, :] + keep_slots = _compute_paged_keep_slots( + prefix_lens=batch.seq_lens, + commit_lens=commit_lens, + draft_token_num=self.draft_token_num, + page_size=page_size, ) + free_mask = row_offsets >= keep_slots[:, None] + batch.token_to_kv_pool_allocator.free(out_cache_loc[free_mask]) + + keep_mask = row_offsets < commit_lens[:, None] + batch.out_cache_loc = out_cache_loc[keep_mask] # Update req-level KV cache accounting. for req, commit_len in zip(batch.reqs, commit_lens_cpu, strict=True): diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 48c81e84fe14..90b59033fda2 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -9,6 +9,7 @@ from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch from sglang.srt.managers.scheduler import GenerationBatchResult from sglang.srt.managers.tp_worker import TpModelWorker +from sglang.srt.mem_cache.common import get_last_loc from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, ForwardBatch, @@ -463,7 +464,23 @@ def _prepare_for_speculative_decoding( allocator = self.draft_model_runner.token_to_kv_pool_allocator token_to_kv_pool_state_backup = allocator.backup_state() try: - block_cache_loc = allocator.alloc(bs * self.block_size) + if self.page_size == 1: + block_cache_loc = allocator.alloc(bs * self.block_size) + else: + block_end_cpu = seq_lens_cpu + int(self.block_size) + last_loc = get_last_loc( + self.draft_model_runner.req_to_token_pool.req_to_token, + batch.req_pool_indices, + block_start, + ) + block_cache_loc = allocator.alloc_extend( + block_start, + seq_lens_cpu, + block_end, + block_end_cpu, + last_loc, + bs * self.block_size, + ) if block_cache_loc is None: raise RuntimeError( f"DFLASH draft OOM when allocating {bs * self.block_size} block tokens." From 7a739f8839a6d0ddc09eecdf7ab75e5ae9fa3c38 Mon Sep 17 00:00:00 2001 From: David Wang Date: Wed, 25 Feb 2026 01:00:06 +0000 Subject: [PATCH 41/73] non greedy --- benchmark/dflash/bench_dflash_gsm8k_sweep.py | 66 ++++- python/sglang/srt/speculative/dflash_info.py | 64 ++++- python/sglang/srt/speculative/dflash_utils.py | 262 ++++++++++++++++++ .../sglang/srt/speculative/dflash_worker.py | 15 +- 4 files changed, 387 insertions(+), 20 deletions(-) diff --git a/benchmark/dflash/bench_dflash_gsm8k_sweep.py b/benchmark/dflash/bench_dflash_gsm8k_sweep.py index ca2677ee78ab..ba20229fd43a 100644 --- a/benchmark/dflash/bench_dflash_gsm8k_sweep.py +++ b/benchmark/dflash/bench_dflash_gsm8k_sweep.py @@ -86,13 +86,16 @@ def _send_generate( prompt: str, *, max_new_tokens: int, + temperature: float, + top_p: float, + top_k: int, stop: list[str], timeout_s: int, ) -> dict: sampling_params: dict = { - "temperature": 0.0, - "top_p": 1.0, - "top_k": 1, + "temperature": float(temperature), + "top_p": float(top_p), + "top_k": int(top_k), "max_new_tokens": int(max_new_tokens), } if stop: @@ -114,15 +117,18 @@ def _send_generate_batch( prompts: list[str], *, max_new_tokens: int, + temperature: float, + top_p: float, + top_k: int, stop: list[str], timeout_s: int, ) -> list[dict]: if not prompts: return [] sampling_params: dict = { - "temperature": 0.0, - "top_p": 1.0, - "top_k": 1, + "temperature": float(temperature), + "top_p": float(top_p), + "top_k": int(top_k), "max_new_tokens": int(max_new_tokens), } if stop: @@ -162,6 +168,9 @@ def _run_gsm8k_requests( prompts: list[str], labels: Optional[list[int]], max_new_tokens: int, + temperature: float, + top_p: float, + top_k: int, concurrency: int, batch_requests: bool, stop: list[str], @@ -189,6 +198,9 @@ def _run_gsm8k_requests( base_url, chunk_prompts, max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, stop=stop, timeout_s=timeout_s, ) @@ -222,6 +234,9 @@ def _run_gsm8k_requests( base_url, prompt, max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, stop=stop, timeout_s=timeout_s, ): i @@ -328,6 +343,24 @@ def main() -> None: ) parser.add_argument("--num-shots", type=int, default=0) parser.add_argument("--max-new-tokens", type=int, default=2048) + parser.add_argument( + "--temperature", + type=float, + default=0.0, + help="Sampling temperature for /generate requests. Default 0.0 (greedy).", + ) + parser.add_argument( + "--top-p", + type=float, + default=1.0, + help="Sampling top-p for /generate requests. Default 1.0.", + ) + parser.add_argument( + "--top-k", + type=int, + default=1, + help="Sampling top-k for /generate requests. Default 1 (greedy).", + ) parser.add_argument("--timeout-s", type=int, default=3600) parser.add_argument("--mem-fraction-static", type=float, default=0.75) parser.add_argument("--disable-radix-cache", action="store_true") @@ -373,6 +406,12 @@ def main() -> None: if not torch.cuda.is_available(): raise RuntimeError("CUDA is required for this sweep.") + if args.temperature < 0.0: + raise RuntimeError(f"--temperature must be >= 0, got {args.temperature}.") + if not (0.0 < args.top_p <= 1.0): + raise RuntimeError(f"--top-p must be in (0, 1], got {args.top_p}.") + if args.top_k == 0 or args.top_k < -1: + raise RuntimeError(f"--top-k must be -1 (all vocab) or >= 1, got {args.top_k}.") visible_gpus = int(torch.cuda.device_count()) tp_sizes = [int(x) for x in args.tp_sizes.split(",") if x.strip()] @@ -503,6 +542,9 @@ def main() -> None: baseline_url, "Hello", max_new_tokens=8, + temperature=float(args.temperature), + top_p=float(args.top_p), + top_k=int(args.top_k), stop=[], timeout_s=min(int(args.timeout_s), 300), ) @@ -515,6 +557,9 @@ def main() -> None: prompts=prompts[:n], labels=labels[:n], max_new_tokens=int(args.max_new_tokens), + temperature=float(args.temperature), + top_p=float(args.top_p), + top_k=int(args.top_k), concurrency=int(conc), batch_requests=bool(args.batch_requests), stop=default_stop, @@ -556,6 +601,9 @@ def main() -> None: dflash_url, "Hello", max_new_tokens=8, + temperature=float(args.temperature), + top_p=float(args.top_p), + top_k=int(args.top_k), stop=[], timeout_s=min(int(args.timeout_s), 300), ) @@ -567,6 +615,9 @@ def main() -> None: prompts=prompts[:n], labels=labels[:n], max_new_tokens=int(args.max_new_tokens), + temperature=float(args.temperature), + top_p=float(args.top_p), + top_k=int(args.top_k), concurrency=int(conc), batch_requests=bool(args.batch_requests), stop=default_stop, @@ -602,6 +653,9 @@ def main() -> None: if args.prompt_style == "fewshot_qa": md_lines.append(f"- num_shots: `{args.num_shots}`") md_lines.append(f"- max_new_tokens: `{args.max_new_tokens}`") + md_lines.append( + f"- sampling: `temperature={args.temperature}, top_p={args.top_p}, top_k={args.top_k}`" + ) md_lines.append(f"- attention_backends: `{', '.join(attention_backends)}`") md_lines.append(f"- tp_sizes: `{', '.join(str(x) for x in tp_sizes)}`") md_lines.append(f"- concurrencies: `{', '.join(str(x) for x in concurrencies)}`") diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py index 0d0c5f12dff5..111ddadd57cd 100644 --- a/python/sglang/srt/speculative/dflash_info.py +++ b/python/sglang/srt/speculative/dflash_info.py @@ -7,6 +7,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.sampler import apply_custom_logit_processor from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.mem_cache.common import ( alloc_paged_token_slots_extend, @@ -14,7 +15,11 @@ get_last_loc, ) from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode -from sglang.srt.speculative.dflash_utils import compute_dflash_accept_len_and_bonus +from sglang.srt.speculative.dflash_utils import ( + compute_dflash_accept_len_and_bonus, + compute_dflash_sampling_accept_len_and_bonus, + is_dflash_sampling_verify_available, +) from sglang.srt.speculative.spec_info import SpecInput, SpecInputType from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func @@ -310,7 +315,7 @@ def verify( logits_output: LogitsProcessorOutput, page_size: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: - """Greedy DFlash verification. + """DFlash verification for greedy and non-greedy sampling. Returns: new_verified_id: int64 tensor [bs] (the new current token per request) @@ -325,14 +330,55 @@ def verify( bs = batch.batch_size() device = logits_output.next_token_logits.device + sampling_info = batch.sampling_info + if sampling_info is not None: + if len(sampling_info) != bs: + raise RuntimeError( + "DFLASH verify sampling_info size mismatch: " + f"len(sampling_info)={len(sampling_info)}, bs={bs}." + ) + + # Keep speculative verify semantics consistent with normal sampling path. + if sampling_info.has_custom_logit_processor: + apply_custom_logit_processor( + logits_output.next_token_logits, + sampling_info, + num_tokens_in_batch=self.draft_token_num, + ) + + if ( + sampling_info.penalizer_orchestrator.is_required + or sampling_info.logit_bias is not None + ): + linear_penalty = torch.zeros( + (bs, logits_output.next_token_logits.shape[1]), + dtype=torch.float32, + device=device, + ) + sampling_info.apply_logits_bias(linear_penalty) + logits_output.next_token_logits.add_( + torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0) + ) + candidates = self.draft_token.view(bs, self.draft_token_num) - target_predict = torch.argmax(logits_output.next_token_logits, dim=-1).view( - bs, self.draft_token_num - ) - accept_len, bonus = compute_dflash_accept_len_and_bonus( - candidates=candidates, - target_predict=target_predict, - ) + if ( + sampling_info is not None + and not sampling_info.is_all_greedy + and is_dflash_sampling_verify_available() + ): + accept_len, bonus = compute_dflash_sampling_accept_len_and_bonus( + candidates=candidates, + next_token_logits=logits_output.next_token_logits, + sampling_info=sampling_info, + ) + else: + target_predict = torch.argmax(logits_output.next_token_logits, dim=-1).view( + bs, self.draft_token_num + ) + accept_len, bonus = compute_dflash_accept_len_and_bonus( + candidates=candidates, + target_predict=target_predict, + ) # Single D2H transfer: candidates[1:] + accept_len + bonus packed = torch.cat( diff --git a/python/sglang/srt/speculative/dflash_utils.py b/python/sglang/srt/speculative/dflash_utils.py index 766d14aa2175..92b5dc895298 100644 --- a/python/sglang/srt/speculative/dflash_utils.py +++ b/python/sglang/srt/speculative/dflash_utils.py @@ -5,12 +5,100 @@ from typing import Any, List, Optional, Tuple import torch +import torch.nn.functional as F from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod +from sglang.srt.server_args import get_global_server_args +from sglang.srt.utils import is_cuda DEFAULT_DFLASH_MASK_TOKEN = "<|MASK|>" logger = logging.getLogger(__name__) +_DFLASH_SAMPLING_VERIFY_AVAILABLE = False +_DFLASH_CHAIN_VERIFY_BUFFERS: dict[tuple[Optional[int], int], dict[str, Any]] = {} + + +if is_cuda(): + try: + from sgl_kernel import ( + top_k_renorm_prob, + top_p_renorm_prob, + tree_speculative_sampling_target_only, + ) + + _DFLASH_SAMPLING_VERIFY_AVAILABLE = True + except Exception: + top_k_renorm_prob = None + top_p_renorm_prob = None + tree_speculative_sampling_target_only = None +else: + top_k_renorm_prob = None + top_p_renorm_prob = None + tree_speculative_sampling_target_only = None + + +def is_dflash_sampling_verify_available() -> bool: + return _DFLASH_SAMPLING_VERIFY_AVAILABLE + + +def _get_or_create_chain_verify_buffers( + *, + bs: int, + draft_token_num: int, + device: torch.device, +) -> tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor +]: + key = (device.index, int(draft_token_num)) + cached = _DFLASH_CHAIN_VERIFY_BUFFERS.get(key) + cap_bs = 0 if cached is None else int(cached["cap_bs"]) + if cap_bs < bs: + new_cap = max(int(bs), cap_bs * 2 if cap_bs > 0 else int(bs)) + retrieve_index = torch.arange( + new_cap * draft_token_num, dtype=torch.int64, device=device + ).view(new_cap, draft_token_num) + row_next = torch.arange( + 1, draft_token_num + 1, dtype=torch.int64, device=device + ) + row_next[-1] = -1 + retrieve_next_token = row_next.unsqueeze(0).expand(new_cap, -1).clone() + retrieve_next_sibling = torch.full( + (new_cap, draft_token_num), -1, dtype=torch.int64, device=device + ) + predicts = torch.empty( + (new_cap * draft_token_num,), dtype=torch.int32, device=device + ) + accept_index = torch.empty( + (new_cap, draft_token_num), dtype=torch.int32, device=device + ) + accept_token_num = torch.empty((new_cap,), dtype=torch.int32, device=device) + cached = { + "cap_bs": int(new_cap), + "retrieve_index": retrieve_index, + "retrieve_next_token": retrieve_next_token, + "retrieve_next_sibling": retrieve_next_sibling, + "predicts": predicts, + "accept_index": accept_index, + "accept_token_num": accept_token_num, + } + _DFLASH_CHAIN_VERIFY_BUFFERS[key] = cached + + assert cached is not None + retrieve_index = cached["retrieve_index"][:bs] + retrieve_next_token = cached["retrieve_next_token"][:bs] + retrieve_next_sibling = cached["retrieve_next_sibling"][:bs] + predicts = cached["predicts"][: bs * draft_token_num] + accept_index = cached["accept_index"][:bs] + accept_token_num = cached["accept_token_num"][:bs] + return ( + retrieve_index, + retrieve_next_token, + retrieve_next_sibling, + predicts, + accept_index, + accept_token_num, + ) + def build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> List[int]: """Select target layer indices used to build DFlash context features. @@ -275,3 +363,177 @@ def compute_dflash_accept_len_and_bonus( accept_len = matches.to(torch.int32).cumprod(dim=1).sum(dim=1) bonus = target_predict[torch.arange(bs, device=target_predict.device), accept_len] return accept_len, bonus.to(torch.int64) + + +def compute_dflash_sampling_accept_len_and_bonus( + *, + candidates: torch.Tensor, + next_token_logits: torch.Tensor, + sampling_info: Any, + threshold_single: Optional[float] = None, + threshold_acc: Optional[float] = None, + uniform_samples: Optional[torch.Tensor] = None, + uniform_samples_for_final_sampling: Optional[torch.Tensor] = None, + use_sparse_topk: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute DFlash accept lengths and bonus tokens for non-greedy sampling. + + This is a chain-specialized variant of speculative target-only verification: + - DFlash proposals are linear (topk == 1), so each verify level has at most one candidate. + - When a candidate is rejected at a level, the final token is sampled from + `relu(q - p)` where `p` has only the rejected candidate mass. + """ + if not _DFLASH_SAMPLING_VERIFY_AVAILABLE: + raise RuntimeError( + "DFLASH non-greedy verification is unavailable on this build/device." + ) + if candidates.ndim != 2: + raise ValueError(f"candidates must be 2D, got shape={tuple(candidates.shape)}") + if next_token_logits.ndim != 2: + raise ValueError( + "next_token_logits must be 2D, " + f"got shape={tuple(next_token_logits.shape)}." + ) + + bs, draft_token_num = candidates.shape + if bs <= 0: + raise ValueError(f"batch size must be positive, got {bs}.") + if draft_token_num <= 0: + raise ValueError(f"draft_token_num must be positive, got {draft_token_num}.") + if next_token_logits.shape[0] != bs * draft_token_num: + raise ValueError( + "next_token_logits row count mismatch. " + f"Expected {bs * draft_token_num}, got {next_token_logits.shape[0]}." + ) + if candidates.device != next_token_logits.device: + raise ValueError( + "candidates and next_token_logits must be on the same device, " + f"got {candidates.device} and {next_token_logits.device}." + ) + + if threshold_single is None: + threshold_single = get_global_server_args().speculative_accept_threshold_single + if threshold_acc is None: + threshold_acc = get_global_server_args().speculative_accept_threshold_acc + threshold_single = float(threshold_single) + threshold_acc = max(float(threshold_acc), 1e-9) + + device = next_token_logits.device + + if uniform_samples is None: + uniform_samples = torch.rand( + (bs, draft_token_num), dtype=torch.float32, device=device + ) + else: + if uniform_samples.shape != (bs, draft_token_num): + raise ValueError( + "uniform_samples shape mismatch. " + f"Expected {(bs, draft_token_num)}, got {tuple(uniform_samples.shape)}." + ) + uniform_samples = uniform_samples.to(device=device, dtype=torch.float32) + + if uniform_samples_for_final_sampling is None: + uniform_samples_for_final_sampling = torch.rand( + (bs,), dtype=torch.float32, device=device + ) + else: + if uniform_samples_for_final_sampling.shape != (bs,): + raise ValueError( + "uniform_samples_for_final_sampling shape mismatch. " + f"Expected {(bs,)}, got {tuple(uniform_samples_for_final_sampling.shape)}." + ) + uniform_samples_for_final_sampling = uniform_samples_for_final_sampling.to( + device=device, + dtype=torch.float32, + ) + + need_top_k = bool(getattr(sampling_info, "need_top_k_sampling", True)) + need_top_p = bool(getattr(sampling_info, "need_top_p_sampling", False)) + # Build target distribution once over all verify rows. + expanded_temperature = torch.repeat_interleave( + sampling_info.temperatures, draft_token_num, dim=0 + ) + scaled_logits = next_token_logits / expanded_temperature + sparse_topk_applied = False + + if use_sparse_topk and need_top_k: + repeated_top_ks = torch.repeat_interleave( + sampling_info.top_ks, draft_token_num, dim=0 + ).to(dtype=torch.int64) + vocab_size = int(scaled_logits.shape[-1]) + repeated_top_ks.clamp_(min=1, max=vocab_size) + max_top_k = int(repeated_top_ks.max().item()) + + # Sparse exact path for top-k/top-p (top-k-first semantics), then scatter to dense. + if 0 < max_top_k < vocab_size: + topk_logits, topk_indices = torch.topk(scaled_logits, k=max_top_k, dim=-1) + if not torch.all(repeated_top_ks == max_top_k): + ranks = torch.arange(max_top_k, device=device, dtype=torch.int64)[ + None, : + ] + valid = ranks < repeated_top_ks.unsqueeze(1) + topk_logits = topk_logits.masked_fill(~valid, float("-inf")) + + topk_probs = F.softmax(topk_logits, dim=-1) + if need_top_p: + repeated_top_ps = torch.repeat_interleave( + sampling_info.top_ps, draft_token_num, dim=0 + ) + topk_probs = top_p_renorm_prob(topk_probs, repeated_top_ps) + + target_probs = torch.zeros_like(scaled_logits, dtype=topk_probs.dtype) + target_probs.scatter_(1, topk_indices, topk_probs) + sparse_topk_applied = True + + if not sparse_topk_applied: + target_probs = F.softmax(scaled_logits, dim=-1) + if need_top_k: + target_probs = top_k_renorm_prob( + target_probs, + torch.repeat_interleave(sampling_info.top_ks, draft_token_num, dim=0), + ) + if need_top_p: + target_probs = top_p_renorm_prob( + target_probs, + torch.repeat_interleave(sampling_info.top_ps, draft_token_num, dim=0), + ) + target_probs = target_probs.view(bs, draft_token_num, -1).contiguous() + draft_probs = torch.zeros_like(target_probs) + + ( + retrieve_index, + retrieve_next_token, + retrieve_next_sibling, + predicts, + accept_index, + accept_token_num, + ) = _get_or_create_chain_verify_buffers( + bs=bs, + draft_token_num=draft_token_num, + device=device, + ) + candidates_i64 = ( + candidates if candidates.dtype == torch.int64 else candidates.to(torch.int64) + ) + tree_speculative_sampling_target_only( + predicts=predicts, + accept_index=accept_index, + accept_token_num=accept_token_num, + candidates=candidates_i64, + retrive_index=retrieve_index, + retrive_next_token=retrieve_next_token, + retrive_next_sibling=retrieve_next_sibling, + uniform_samples=uniform_samples, + uniform_samples_for_final_sampling=uniform_samples_for_final_sampling, + target_probs=target_probs, + draft_probs=draft_probs, + threshold_single=threshold_single, + threshold_acc=threshold_acc, + deterministic=True, + ) + + accept_len = accept_token_num + row_ids = torch.arange(bs, dtype=torch.long, device=device) + accept_pos = accept_index[row_ids, accept_len.to(torch.long)].to(torch.long) + bonus = predicts[accept_pos].to(torch.int64) + return accept_len, bonus diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 90b59033fda2..6ec82275cff0 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -19,6 +19,7 @@ from sglang.srt.speculative.dflash_info import DFlashDraftInput, DFlashVerifyInput from sglang.srt.speculative.dflash_utils import ( can_dflash_use_fused_qkv_proj, + is_dflash_sampling_verify_available, resolve_dflash_mask_token, resolve_dflash_mask_token_id, ) @@ -75,7 +76,7 @@ def __init__( self.page_size = server_args.page_size self.device = target_worker.device - self._warned_forced_greedy = False + self._warned_sampling_fallback = False self._logged_first_verify = False # Draft runner (separate KV cache + attention backend). @@ -403,12 +404,16 @@ def _prepare_for_speculative_decoding( "DFLASH does not support grammar-constrained decoding yet." ) if batch.sampling_info is not None and not batch.sampling_info.is_all_greedy: - if not self._warned_forced_greedy and self.tp_rank == 0: + if ( + not is_dflash_sampling_verify_available() + and not self._warned_sampling_fallback + and self.tp_rank == 0 + ): logger.warning( - "DFLASH currently supports greedy verification only; " - "ignoring non-greedy sampling params (e.g. temperature/top_p/top_k) and using argmax." + "DFLASH non-greedy verification is unavailable on this build/device; " + "falling back to greedy argmax verification." ) - self._warned_forced_greedy = True + self._warned_sampling_fallback = True bs = batch.batch_size() device = self.model_runner.device From 0fe389ddb80fa3be3429fca70c4aa7bde9ab7d76 Mon Sep 17 00:00:00 2001 From: David Wang Date: Thu, 26 Feb 2026 20:59:43 +0000 Subject: [PATCH 42/73] rope rotation support --- python/sglang/srt/models/dflash.py | 6 ++++++ python/sglang/srt/speculative/dflash_utils.py | 3 --- python/sglang/srt/speculative/dflash_worker.py | 10 ++++++++++ 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/models/dflash.py b/python/sglang/srt/models/dflash.py index 593bd311c0ca..09441fb6ebac 100644 --- a/python/sglang/srt/models/dflash.py +++ b/python/sglang/srt/models/dflash.py @@ -92,6 +92,11 @@ def __init__(self, config, layer_id: int) -> None: rope_theta = float(getattr(config, "rope_theta", 1000000)) rope_scaling = getattr(config, "rope_scaling", None) + rope_is_neox_style = bool( + getattr( + config, "rope_is_neox_style", getattr(config, "is_neox_style", True) + ) + ) max_position_embeddings = int(getattr(config, "max_position_embeddings", 32768)) self.rotary_emb = get_rope( head_dim, @@ -99,6 +104,7 @@ def __init__(self, config, layer_id: int) -> None: max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, + is_neox_style=rope_is_neox_style, ) self.scaling = head_dim**-0.5 diff --git a/python/sglang/srt/speculative/dflash_utils.py b/python/sglang/srt/speculative/dflash_utils.py index 92b5dc895298..9bb982ed7a56 100644 --- a/python/sglang/srt/speculative/dflash_utils.py +++ b/python/sglang/srt/speculative/dflash_utils.py @@ -103,9 +103,6 @@ def _get_or_create_chain_verify_buffers( def build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> List[int]: """Select target layer indices used to build DFlash context features. - Mirrors the upstream DFlash helper in `docs/dflash/model/utils.py`, but keeps the - logic local to SGLang. - Args: num_target_layers: Number of transformer layers in the runtime target model. num_draft_layers: Number of layers in the DFlash draft model. diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 4c22ab476826..899ffccf8f8c 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -238,6 +238,16 @@ def _init_fused_kv_helper(self) -> None: ) break + rope_is_neox_style = bool( + getattr(attn.rotary_emb, "is_neox_style", True) + ) + if not rope_is_neox_style: + fused_disable_reason = ( + "non-neox RoPE is not supported for fused KV path: " + f"layer={layer_idx}, rope_is_neox_style={rope_is_neox_style}" + ) + break + if fused_disable_reason is not None: if self.tp_rank == 0: logger.info( From a134f0a0e898e730b5ec2e85014fcd7f76fe0c1a Mon Sep 17 00:00:00 2001 From: David Wang Date: Thu, 26 Feb 2026 22:14:54 +0000 Subject: [PATCH 43/73] clean up schedule_batch.py --- python/sglang/srt/managers/schedule_batch.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index cd7cf599848a..5576cd577f95 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -2055,11 +2055,7 @@ def filter_batch( self.seq_lens_cpu = self.seq_lens_cpu[keep_indices] self.orig_seq_lens = self.orig_seq_lens[keep_indices_device] self.out_cache_loc = None - # Use CPU copy to avoid GPU sync. - if self.seq_lens_cpu is not None: - self.seq_lens_sum = int(self.seq_lens_cpu.sum().item()) - else: - self.seq_lens_sum = int(self.seq_lens.sum().item()) + self.seq_lens_sum = self.seq_lens.sum().item() if self.output_ids is not None: self.output_ids = self.output_ids[keep_indices_device] From f62e5de1df6d3e932e2f676f5eefd938d73af69e Mon Sep 17 00:00:00 2001 From: David Wang Date: Sat, 28 Feb 2026 21:18:49 +0000 Subject: [PATCH 44/73] fix auto memory oom, cleanup --- benchmark/dflash/bench_dflash_gsm8k_sweep.py | 15 +- .../srt/model_executor/cuda_graph_runner.py | 16 +- .../sglang/srt/model_executor/model_runner.py | 26 +- .../model_runner_kv_cache_mixin.py | 13 + python/sglang/srt/models/dflash.py | 18 +- python/sglang/srt/server_args.py | 60 +--- python/sglang/srt/speculative/dflash_utils.py | 325 +++++++++++------- .../sglang/srt/speculative/dflash_worker.py | 48 +-- .../sglang/srt/utils/hf_transformers_utils.py | 5 +- 9 files changed, 288 insertions(+), 238 deletions(-) diff --git a/benchmark/dflash/bench_dflash_gsm8k_sweep.py b/benchmark/dflash/bench_dflash_gsm8k_sweep.py index ba20229fd43a..8765557796ff 100644 --- a/benchmark/dflash/bench_dflash_gsm8k_sweep.py +++ b/benchmark/dflash/bench_dflash_gsm8k_sweep.py @@ -362,7 +362,12 @@ def main() -> None: help="Sampling top-k for /generate requests. Default 1 (greedy).", ) parser.add_argument("--timeout-s", type=int, default=3600) - parser.add_argument("--mem-fraction-static", type=float, default=0.75) + parser.add_argument( + "--mem-fraction-static", + type=float, + default=None, + help="Optional server --mem-fraction-static override. If unset, use the server auto heuristic.", + ) parser.add_argument("--disable-radix-cache", action="store_true") parser.add_argument("--dtype", type=str, default="bfloat16") parser.add_argument( @@ -371,7 +376,7 @@ def main() -> None: default=None, help="Optional server --page-size override for both baseline and DFLASH runs.", ) - parser.add_argument("--max-running-requests", type=int, default=64) + parser.add_argument("--max-running-requests", type=int, default=32) parser.add_argument( "--tp-sizes", type=str, @@ -514,13 +519,15 @@ def main() -> None: str(tp), "--dtype", str(args.dtype), - "--mem-fraction-static", - str(args.mem_fraction_static), "--max-running-requests", str(args.max_running_requests), "--cuda-graph-max-bs", "32", ] + if args.mem_fraction_static is not None: + common_server_args.extend( + ["--mem-fraction-static", str(args.mem_fraction_static)] + ) if args.disable_radix_cache: common_server_args.append("--disable-radix-cache") if args.page_size is not None: diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 23dd5463823a..b2c759961218 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -1152,24 +1152,22 @@ def get_spec_info(self, num_tokens: int): ) elif self.model_runner.spec_algorithm.is_dflash(): from sglang.srt.speculative.dflash_info import DFlashVerifyInput + from sglang.srt.speculative.dflash_utils import ( + resolve_dflash_verify_mask_policy, + ) - backend_name = type(self.model_runner.attn_backend).__name__ # Avoid enabling custom-mask modes during graph capture for backends that # can express DFLASH verify via their built-in causal path. - skip_custom_mask = backend_name in { - "FlashInferAttnBackend", - "FlashInferMLAAttnBackend", - "FlashAttentionBackend", - "TRTLLMHAAttnBackend", - "TRTLLMMLABackend", - } + _, build_custom_mask = resolve_dflash_verify_mask_policy( + self.model_runner.attn_backend + ) spec_info = DFlashVerifyInput( draft_token=None, positions=None, draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, custom_mask=( None - if (self.model_runner.is_draft_worker or skip_custom_mask) + if (self.model_runner.is_draft_worker or not build_custom_mask) else self.buffers.custom_mask ), capture_hidden_mode=( diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 345f8fae2008..7a4764b9260c 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -149,7 +149,9 @@ get_global_server_args, set_global_server_args_for_scheduler, ) -from sglang.srt.speculative.dflash_utils import resolve_dflash_target_layer_ids +from sglang.srt.speculative.dflash_utils import ( + parse_dflash_draft_config, +) from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.utils import ( MultiprocessingSerializer, @@ -347,6 +349,7 @@ def __init__( self.eagle_use_aux_hidden_state = False self.dflash_use_aux_hidden_state = False self.dflash_target_layer_ids = None + self.dflash_draft_num_layers = None if self.spec_algorithm.is_eagle3() and not self.is_draft_worker: # load draft config draft_model_config = ModelConfig.from_server_args( @@ -380,21 +383,22 @@ def __init__( model_revision=server_args.speculative_draft_model_revision, is_draft_model=True, ) - draft_num_layers = getattr( - draft_model_config.hf_config, "num_hidden_layers", None + dflash_draft_config = parse_dflash_draft_config( + draft_hf_config=draft_model_config.hf_config ) + draft_num_layers = dflash_draft_config.require_num_layers() + trained_target_layers = dflash_draft_config.num_target_layers + target_num_layers = getattr( self.model_config.hf_config, "num_hidden_layers", None ) - if draft_num_layers is None or target_num_layers is None: + if target_num_layers is None: raise ValueError( - "DFLASH requires both draft and target to expose num_hidden_layers in config. " - f"Got draft={draft_num_layers}, target={target_num_layers}." + "DFLASH requires target num_hidden_layers in config. " + f"Got target={target_num_layers}." ) + target_num_layers = int(target_num_layers) - trained_target_layers = getattr( - draft_model_config.hf_config, "num_target_layers", None - ) if ( trained_target_layers is not None and trained_target_layers != target_num_layers @@ -407,8 +411,8 @@ def __init__( ) self.dflash_use_aux_hidden_state = True - self.dflash_target_layer_ids = resolve_dflash_target_layer_ids( - draft_hf_config=draft_model_config.hf_config, + self.dflash_draft_num_layers = int(draft_num_layers) + self.dflash_target_layer_ids = dflash_draft_config.resolve_target_layer_ids( target_num_layers=int(target_num_layers), draft_num_layers=int(draft_num_layers), ) diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py index 0fccfd027f7e..a12baa5cc64f 100644 --- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py +++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py @@ -24,6 +24,7 @@ ReqToTokenPool, ) from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool, SWATokenToKVPoolAllocator +from sglang.srt.speculative.dflash_utils import scale_kv_cell_size_per_token_for_dflash from sglang.srt.utils.common import ( get_available_gpu_memory, is_float4_e2m1fn_x2, @@ -139,6 +140,18 @@ def profile_max_num_token(self: ModelRunner, total_gpu_memory: int): num_layers = self.num_effective_layers cell_size = self.get_cell_size_per_token(num_layers) + if self.spec_algorithm.is_dflash() and not self.is_draft_worker: + draft_num_layers = getattr(self, "dflash_draft_num_layers", None) + if ( + draft_num_layers is not None + and int(draft_num_layers) > 0 + and int(num_layers) > 0 + ): + cell_size = scale_kv_cell_size_per_token_for_dflash( + target_cell_size_per_token=cell_size, + target_num_layers=int(num_layers), + draft_num_layers=int(draft_num_layers), + ) rest_memory = available_gpu_memory - total_gpu_memory * ( 1 - self.mem_fraction_static diff --git a/python/sglang/srt/models/dflash.py b/python/sglang/srt/models/dflash.py index 09441fb6ebac..c76617e4b40b 100644 --- a/python/sglang/srt/models/dflash.py +++ b/python/sglang/srt/models/dflash.py @@ -27,8 +27,7 @@ from sglang.srt.models.utils import apply_qk_norm from sglang.srt.speculative.dflash_utils import ( can_dflash_slice_qkv_weight, - resolve_dflash_block_size, - resolve_dflash_target_layer_ids, + parse_dflash_draft_config, ) logger = logging.getLogger(__name__) @@ -273,11 +272,14 @@ def __init__(self, config, quant_config=None, prefix: str = "") -> None: # Project per-token target context features: # concat(K * hidden_size) -> hidden_size, where K is the number of target-layer # feature tensors concatenated per token (not necessarily equal to num_layers). - target_num_layers = int(getattr(config, "num_target_layers", num_layers)) - target_layer_ids = resolve_dflash_target_layer_ids( - draft_hf_config=config, - target_num_layers=target_num_layers, - draft_num_layers=num_layers, + draft_config = parse_dflash_draft_config(draft_hf_config=config) + target_num_layers = ( + int(draft_config.num_target_layers) + if draft_config.num_target_layers is not None + else num_layers + ) + target_layer_ids = draft_config.resolve_target_layer_ids( + target_num_layers=target_num_layers, draft_num_layers=num_layers ) num_context_features = len(target_layer_ids) @@ -287,7 +289,7 @@ def __init__(self, config, quant_config=None, prefix: str = "") -> None: ) self.hidden_norm = RMSNorm(hidden_size, eps=rms_norm_eps) - self.block_size = resolve_dflash_block_size(draft_hf_config=config, default=16) + self.block_size = draft_config.resolve_block_size(default=16) def project_target_hidden(self, target_hidden: torch.Tensor) -> torch.Tensor: """Project concatenated target-layer hidden states into draft hidden_size.""" diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 0c4283aa817b..6f3eb09a4949 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1047,9 +1047,6 @@ def _handle_gpu_memory_settings(self, gpu_mem): if self.speculative_algorithm == "STANDALONE": # standalonedraft model and cuda graphs reserved_mem += 6 * 1024 - elif self.speculative_algorithm == "DFLASH": - # dflash draft model and cuda graphs - reserved_mem += 6 * 1024 elif self.speculative_algorithm != "NGRAM": # eagle draft models and cuda graphs reserved_mem += 2 * 1024 @@ -2449,61 +2446,30 @@ def _handle_speculative_decoding(self): if self.speculative_num_draft_tokens is None: from sglang.srt.speculative.dflash_utils import ( - resolve_dflash_block_size, + parse_dflash_draft_config, ) model_override_args = json.loads(self.json_model_override_args) inferred_block_size = None try: - from sglang.srt.utils.hf_transformers_utils import download_from_hf - - config_root = ( - self.speculative_draft_model_path - if os.path.isdir(self.speculative_draft_model_path) - else download_from_hf( - self.speculative_draft_model_path, - allow_patterns=["config.json"], - revision=self.speculative_draft_model_revision, - ) + from sglang.srt.utils.hf_transformers_utils import get_config + + draft_hf_config = get_config( + self.speculative_draft_model_path, + trust_remote_code=self.trust_remote_code, + revision=self.speculative_draft_model_revision, + model_override_args=model_override_args, ) - draft_config_path = os.path.join(config_root, "config.json") - if os.path.isfile(draft_config_path): - with open(draft_config_path, "r") as f: - draft_config_json = json.load(f) - if model_override_args: - draft_config_json.update(model_override_args) - inferred_block_size = resolve_dflash_block_size( - draft_hf_config=draft_config_json, - default=None, - ) + inferred_block_size = parse_dflash_draft_config( + draft_hf_config=draft_hf_config + ).resolve_block_size(default=None) except Exception as e: logger.warning( - "Failed to infer DFLASH block_size from draft config.json; " - "falling back to transformers config loader. Error: %s", + "Failed to infer DFLASH block_size from draft model config; " + "defaulting speculative_num_draft_tokens to 16. Error: %s", e, ) - if inferred_block_size is None: - try: - from sglang.srt.utils.hf_transformers_utils import get_config - - draft_hf_config = get_config( - self.speculative_draft_model_path, - trust_remote_code=self.trust_remote_code, - revision=self.speculative_draft_model_revision, - model_override_args=model_override_args, - ) - inferred_block_size = resolve_dflash_block_size( - draft_hf_config=draft_hf_config, - default=None, - ) - except Exception as e: - logger.warning( - "Failed to infer DFLASH block_size from transformers config loader; " - "defaulting speculative_num_draft_tokens to 16. Error: %s", - e, - ) - if inferred_block_size is None: inferred_block_size = 16 logger.warning( diff --git a/python/sglang/srt/speculative/dflash_utils.py b/python/sglang/srt/speculative/dflash_utils.py index 9bb982ed7a56..36ad26c05a96 100644 --- a/python/sglang/srt/speculative/dflash_utils.py +++ b/python/sglang/srt/speculative/dflash_utils.py @@ -1,6 +1,6 @@ from __future__ import annotations -import logging +from dataclasses import dataclass from numbers import Integral from typing import Any, List, Optional, Tuple @@ -8,14 +8,21 @@ import torch.nn.functional as F from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod -from sglang.srt.server_args import get_global_server_args from sglang.srt.utils import is_cuda DEFAULT_DFLASH_MASK_TOKEN = "<|MASK|>" -logger = logging.getLogger(__name__) _DFLASH_SAMPLING_VERIFY_AVAILABLE = False _DFLASH_CHAIN_VERIFY_BUFFERS: dict[tuple[Optional[int], int], dict[str, Any]] = {} +_DFLASH_VERIFY_SKIP_CUSTOM_MASK_BACKENDS = frozenset( + { + "FlashInferAttnBackend", + "FlashInferMLAAttnBackend", + "FlashAttentionBackend", + "TRTLLMHAAttnBackend", + "TRTLLMMLABackend", + } +) if is_cuda(): @@ -41,6 +48,60 @@ def is_dflash_sampling_verify_available() -> bool: return _DFLASH_SAMPLING_VERIFY_AVAILABLE +def scale_kv_cell_size_per_token_for_dflash( + *, + target_cell_size_per_token: int, + target_num_layers: int, + draft_num_layers: int, + draft_cell_size_per_token: Optional[int] = None, +) -> int: + """Compute bytes/token budget for combined target+draft KV pools (DFLASH). + + DFLASH runs a separate draft runner with its own KV pool. The target runner's + token capacity must fit both pools in aggregate. For DFLASH checkpoints, the + draft KV geometry typically matches the target KV geometry and differs mainly + by layer count, so draft KV bytes/token scales linearly with `draft_num_layers`. + + Returns: + Approximate per-token bytes for (target KV + draft KV), expressed as a + scaled version of `target_cell_size_per_token`, unless an explicit + `draft_cell_size_per_token` is provided (in which case we sum them). + """ + if target_cell_size_per_token <= 0: + raise ValueError( + "target_cell_size_per_token must be positive, " + f"got {target_cell_size_per_token}." + ) + + if draft_cell_size_per_token is not None: + draft_cell_size_per_token = int(draft_cell_size_per_token) + if draft_cell_size_per_token <= 0: + raise ValueError( + "draft_cell_size_per_token must be positive when provided, " + f"got {draft_cell_size_per_token}." + ) + return int(target_cell_size_per_token) + int(draft_cell_size_per_token) + + if target_num_layers <= 0 or draft_num_layers <= 0: + return int(target_cell_size_per_token) + + total_layers = int(target_num_layers) + int(draft_num_layers) + return ( + int(target_cell_size_per_token) * int(total_layers) + int(target_num_layers) - 1 + ) // int(target_num_layers) + + +def resolve_dflash_verify_mask_policy(attn_backend: Any) -> tuple[str, bool]: + backend = attn_backend + for _ in range(4): + full_backend = getattr(backend, "full_attn_backend", None) + if full_backend is None: + break + backend = full_backend + backend_name = type(backend).__name__ + return backend_name, (backend_name not in _DFLASH_VERIFY_SKIP_CUSTOM_MASK_BACKENDS) + + def _get_or_create_chain_verify_buffers( *, bs: int, @@ -140,7 +201,13 @@ def build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> Lis ] -def get_dflash_config(config: Any) -> dict: +def _cfg_get(config: Any, key: str, default: Any = None) -> Any: + if isinstance(config, dict): + return config.get(key, default) + return getattr(config, key, default) + + +def _get_dflash_config(config: Any) -> dict: if isinstance(config, dict): cfg = config.get("dflash_config", None) else: @@ -156,145 +223,157 @@ def get_dflash_config(config: Any) -> dict: return {} -def resolve_dflash_block_size( +def _parse_optional_int( + value: Any, *, - draft_hf_config: Any, - default: Optional[int] = None, + field_name: str, + min_value: Optional[int] = None, ) -> Optional[int]: - """Resolve DFLASH block size from draft config. - - Precedence: - 1) `dflash_config.block_size` - 2) top-level `block_size` - 3) `default` - """ - dflash_cfg = get_dflash_config(draft_hf_config) - dflash_block_size = dflash_cfg.get("block_size", None) - if isinstance(draft_hf_config, dict): - top_level_block_size = draft_hf_config.get("block_size", None) - else: - top_level_block_size = getattr(draft_hf_config, "block_size", None) - - parsed_dflash_block_size = None - if dflash_block_size is not None: - try: - parsed_dflash_block_size = int(dflash_block_size) - except Exception as e: + if value is None: + return None + try: + parsed = int(value) + except Exception as e: + raise ValueError(f"Invalid {field_name}={value!r}.") from e + if min_value is not None and parsed < int(min_value): + comparator = "positive" if int(min_value) == 1 else f">= {int(min_value)}" + raise ValueError(f"{field_name} must be {comparator}, got {parsed}.") + return parsed + + +@dataclass(frozen=True) +class DFlashDraftConfig: + num_hidden_layers: Optional[int] + num_target_layers: Optional[int] + block_size: Optional[int] + target_layer_ids: Optional[List[int]] + mask_token: str + mask_token_id: Optional[int] + + def require_num_layers(self) -> int: + if self.num_hidden_layers is None: raise ValueError( - f"Invalid DFLASH dflash_config.block_size={dflash_block_size!r}." - ) from e - - parsed_top_level_block_size = None - if top_level_block_size is not None: - try: - parsed_top_level_block_size = int(top_level_block_size) - except Exception as e: + "DFLASH requires draft num_hidden_layers in config. " + "Got config without num_hidden_layers." + ) + return int(self.num_hidden_layers) + + def resolve_block_size(self, *, default: Optional[int] = None) -> Optional[int]: + return self.block_size if self.block_size is not None else default + + def resolve_target_layer_ids( + self, + *, + target_num_layers: int, + draft_num_layers: Optional[int] = None, + ) -> List[int]: + target_num_layers = int(target_num_layers) + if target_num_layers <= 0: raise ValueError( - f"Invalid DFLASH block_size={top_level_block_size!r}." - ) from e - - if ( - parsed_dflash_block_size is not None - and parsed_top_level_block_size is not None - and parsed_dflash_block_size != parsed_top_level_block_size - ): - logger.warning( - "DFLASH draft config has both block_size=%s and dflash_config.block_size=%s; using dflash_config.block_size.", - top_level_block_size, - dflash_block_size, - ) - - block_size = ( - parsed_dflash_block_size - if parsed_dflash_block_size is not None - else parsed_top_level_block_size - ) - if block_size is None: - return default - - if block_size <= 0: - raise ValueError(f"DFLASH block_size must be positive, got {block_size}.") - return block_size + f"target_num_layers must be positive, got {target_num_layers}." + ) + if self.target_layer_ids is None: + if draft_num_layers is None: + draft_num_layers = self.require_num_layers() + return build_target_layer_ids(target_num_layers, int(draft_num_layers)) -def resolve_dflash_target_layer_ids( - *, - draft_hf_config: Any, - target_num_layers: int, - draft_num_layers: int, -) -> List[int]: - """Resolve target layer ids used to build DFlash context features. + resolved = list(self.target_layer_ids) + if len(resolved) <= 0: + raise ValueError( + "DFLASH dflash_config.target_layer_ids must be non-empty. " + f"Got len(target_layer_ids)={len(resolved)}." + ) + for idx, val in enumerate(resolved): + if val < 0 or val >= target_num_layers: + raise ValueError( + "DFLASH target_layer_ids contains an out-of-range layer id. " + f"target_layer_ids[{idx}]={val}, target_num_layers={target_num_layers}." + ) + return resolved - Precedence: - 1) `draft_hf_config.dflash_config.target_layer_ids` - 2) `draft_hf_config.target_layer_ids` (fallback to base config) - 3) default `build_target_layer_ids(target_num_layers, draft_num_layers)` - Notes: - The number of draft transformer layers is *not* fundamentally tied to the number - of target-layer features (K) used as DFlash context. We treat - `len(target_layer_ids)` as K when explicitly provided. For backward compatibility - (and for current released checkpoints), the default still uses K == draft_num_layers. - """ - cfg = get_dflash_config(draft_hf_config) - layer_ids = cfg.get("target_layer_ids", None) - if layer_ids is None: - layer_ids = getattr(draft_hf_config, "target_layer_ids", None) - if layer_ids is None: - return build_target_layer_ids(target_num_layers, draft_num_layers) +def parse_dflash_draft_config(*, draft_hf_config: Any) -> DFlashDraftConfig: + """Parse and validate DFLASH draft config fields from HF config/dict.""" + dflash_cfg = _get_dflash_config(draft_hf_config) - if not isinstance(layer_ids, (list, tuple)): - raise ValueError( - "DFLASH dflash_config.target_layer_ids must be a list of ints, " - f"got type={type(layer_ids).__name__}." - ) + num_hidden_layers = _parse_optional_int( + _cfg_get(draft_hf_config, "num_hidden_layers", None), + field_name="DFLASH draft num_hidden_layers", + min_value=1, + ) + raw_num_target_layers = dflash_cfg.get( + "num_target_layers", + _cfg_get(draft_hf_config, "num_target_layers", None), + ) + num_target_layers = _parse_optional_int( + raw_num_target_layers, + field_name="DFLASH draft num_target_layers", + min_value=1, + ) - resolved: List[int] = [int(x) for x in layer_ids] - if len(resolved) <= 0: - raise ValueError( - "DFLASH dflash_config.target_layer_ids must be non-empty. " - f"Got len(target_layer_ids)={len(resolved)}." - ) + # Keep support for current checkpoints where block_size is top-level. + raw_block_size = dflash_cfg.get( + "block_size", + _cfg_get(draft_hf_config, "block_size", None), + ) + block_size = _parse_optional_int( + raw_block_size, + field_name="DFLASH block_size", + min_value=1, + ) - for idx, val in enumerate(resolved): - if val < 0 or val >= int(target_num_layers): + layer_ids = dflash_cfg.get( + "target_layer_ids", + _cfg_get(draft_hf_config, "target_layer_ids", None), + ) + parsed_target_layer_ids: Optional[List[int]] + if layer_ids is None: + parsed_target_layer_ids = None + else: + if not isinstance(layer_ids, (list, tuple)): raise ValueError( - "DFLASH target_layer_ids contains an out-of-range layer id. " - f"target_layer_ids[{idx}]={val}, target_num_layers={int(target_num_layers)}." + "DFLASH dflash_config.target_layer_ids must be a list of ints, " + f"got type={type(layer_ids).__name__}." + ) + parsed_target_layer_ids = [int(x) for x in layer_ids] + if len(parsed_target_layer_ids) <= 0: + raise ValueError( + "DFLASH dflash_config.target_layer_ids must be non-empty. " + f"Got len(target_layer_ids)={len(parsed_target_layer_ids)}." ) - return resolved - -def resolve_dflash_mask_token(*, draft_hf_config: Any) -> str: - cfg = get_dflash_config(draft_hf_config) - mask_token = cfg.get("mask_token", None) + mask_token = dflash_cfg.get("mask_token", None) if mask_token is None: - return DEFAULT_DFLASH_MASK_TOKEN + mask_token = DEFAULT_DFLASH_MASK_TOKEN if not isinstance(mask_token, str) or not mask_token: raise ValueError( "DFLASH dflash_config.mask_token must be a non-empty string, " f"got {mask_token!r}." ) - return mask_token + mask_token_id = dflash_cfg.get("mask_token_id", None) + if mask_token_id is not None: + if not isinstance(mask_token_id, Integral) or isinstance(mask_token_id, bool): + raise ValueError( + "DFLASH dflash_config.mask_token_id must be an integer, " + f"got {mask_token_id!r} (type={type(mask_token_id).__name__})." + ) + mask_token_id = int(mask_token_id) + if mask_token_id < 0: + raise ValueError( + "DFLASH dflash_config.mask_token_id must be non-negative, " + f"got {mask_token_id}." + ) -def resolve_dflash_mask_token_id(*, draft_hf_config: Any) -> Optional[int]: - cfg = get_dflash_config(draft_hf_config) - mask_token_id = cfg.get("mask_token_id", None) - if mask_token_id is None: - return None - if not isinstance(mask_token_id, Integral) or isinstance(mask_token_id, bool): - raise ValueError( - "DFLASH dflash_config.mask_token_id must be an integer, " - f"got {mask_token_id!r} (type={type(mask_token_id).__name__})." - ) - mask_token_id = int(mask_token_id) - if mask_token_id < 0: - raise ValueError( - "DFLASH dflash_config.mask_token_id must be non-negative, " - f"got {mask_token_id}." - ) - return mask_token_id + return DFlashDraftConfig( + num_hidden_layers=num_hidden_layers, + num_target_layers=num_target_layers, + block_size=block_size, + target_layer_ids=parsed_target_layer_ids, + mask_token=mask_token, + mask_token_id=mask_token_id, + ) def can_dflash_slice_qkv_weight(qkv_proj: Any) -> Tuple[bool, str]: @@ -409,8 +488,12 @@ def compute_dflash_sampling_accept_len_and_bonus( ) if threshold_single is None: + from sglang.srt.server_args import get_global_server_args + threshold_single = get_global_server_args().speculative_accept_threshold_single if threshold_acc is None: + from sglang.srt.server_args import get_global_server_args + threshold_acc = get_global_server_args().speculative_accept_threshold_acc threshold_single = float(threshold_single) threshold_acc = max(float(threshold_acc), 1e-9) diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 899ffccf8f8c..aa2da57686e5 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -20,8 +20,8 @@ from sglang.srt.speculative.dflash_utils import ( can_dflash_use_fused_qkv_proj, is_dflash_sampling_verify_available, - resolve_dflash_mask_token, - resolve_dflash_mask_token_id, + parse_dflash_draft_config, + resolve_dflash_verify_mask_policy, ) from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func @@ -46,14 +46,6 @@ def _get_fused_kv_materialize_helper(): class DFlashWorker: """DFlash speculative decoding worker (spec-v1, tp>=1/pp=1).""" - _VERIFY_SKIP_CUSTOM_MASK_BACKENDS = { - "FlashInferAttnBackend", - "FlashInferMLAAttnBackend", - "FlashAttentionBackend", - "TRTLLMHAAttnBackend", - "TRTLLMMLABackend", - } - def __init__( self, server_args: ServerArgs, @@ -135,12 +127,17 @@ def __init__( ) self.draft_model_runner = self.draft_worker.model_runner self.draft_model = self.draft_model_runner.model + draft_config = parse_dflash_draft_config( + draft_hf_config=self.draft_model_runner.model_config.hf_config + ) if server_args.speculative_num_draft_tokens is None: # Should not happen (ServerArgs should have inferred it), but keep a fallback. - self.block_size = int(getattr(self.draft_model, "block_size", 16)) + self.block_size = int(draft_config.resolve_block_size(default=16)) else: self.block_size = int(server_args.speculative_num_draft_tokens) - model_block_size = getattr(self.draft_model, "block_size", None) + model_block_size = draft_config.block_size + if model_block_size is None: + model_block_size = getattr(self.draft_model, "block_size", None) if model_block_size is not None and int(model_block_size) != int( self.block_size ): @@ -150,12 +147,8 @@ def __init__( model_block_size, ) - self._mask_token = resolve_dflash_mask_token( - draft_hf_config=self.draft_model_runner.model_config.hf_config - ) - self._mask_token_id_override = resolve_dflash_mask_token_id( - draft_hf_config=self.draft_model_runner.model_config.hf_config - ) + self._mask_token = draft_config.mask_token + self._mask_token_id_override = draft_config.mask_token_id self._mask_token_id = self._resolve_mask_token_id( mask_token=self._mask_token, mask_token_id=self._mask_token_id_override, @@ -564,7 +557,9 @@ def _prepare_for_speculative_decoding( positions=positions, draft_token_num=self.block_size, ) - _, build_custom_mask = self._resolve_verify_mask_policy() + _, build_custom_mask = resolve_dflash_verify_mask_policy( + self.model_runner.attn_backend + ) verify_input.prepare_for_verify( batch, self.page_size, @@ -918,21 +913,6 @@ def _write_layer_kv( write_layer_kv=_write_layer_kv, ) - def _resolve_verify_mask_backend_name(self) -> str: - backend = self.model_runner.attn_backend - for _ in range(4): - full_backend = getattr(backend, "full_attn_backend", None) - if full_backend is None: - break - backend = full_backend - return type(backend).__name__ - - def _resolve_verify_mask_policy(self) -> tuple[str, bool]: - backend_name = self._resolve_verify_mask_backend_name() - return backend_name, ( - backend_name not in self._VERIFY_SKIP_CUSTOM_MASK_BACKENDS - ) - def _update_target_mamba_state_after_verify( self, *, diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 717610f7cab3..9c6886fe54ca 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -119,7 +119,6 @@ def download_from_hf( model_path: str, allow_patterns: Optional[Union[str, list]] = None, - revision: Optional[str] = None, ): if os.path.exists(model_path): return model_path @@ -127,9 +126,7 @@ def download_from_hf( if not allow_patterns: allow_patterns = ["*.json", "*.bin", "*.model"] - return snapshot_download( - model_path, allow_patterns=allow_patterns, revision=revision - ) + return snapshot_download(model_path, allow_patterns=allow_patterns) def get_hf_text_config(config: PretrainedConfig): From 2cc5f070484f9bbac82ba00299b0aad7cefad9a3 Mon Sep 17 00:00:00 2001 From: David Wang Date: Sat, 28 Feb 2026 21:47:06 +0000 Subject: [PATCH 45/73] clean up --- python/sglang/srt/model_executor/model_runner.py | 8 +------- python/sglang/srt/speculative/dflash_utils.py | 4 +--- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 7a4764b9260c..3b641ea653f4 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2109,12 +2109,6 @@ def get_spec_info(): else: lora_ids = None - # Use CPU copy to avoid GPU sync. - if buffers.seq_lens_cpu is not None: - seq_lens_sum = int(buffers.seq_lens_cpu.sum().item()) - else: - seq_lens_sum = int(buffers.seq_lens.sum().item()) - forward_batch = ForwardBatch( forward_mode=capture_forward_mode, batch_size=batch_size, @@ -2128,7 +2122,7 @@ def get_spec_info(): token_to_kv_pool=self.token_to_kv_pool, attn_backend=self.attn_backend, out_cache_loc=buffers.out_cache_loc, - seq_lens_sum=seq_lens_sum, + seq_lens_sum=buffers.seq_lens.sum().item(), encoder_lens=buffers.encoder_lens, return_logprob=False, positions=buffers.positions, diff --git a/python/sglang/srt/speculative/dflash_utils.py b/python/sglang/srt/speculative/dflash_utils.py index 36ad26c05a96..5a88022a817e 100644 --- a/python/sglang/srt/speculative/dflash_utils.py +++ b/python/sglang/srt/speculative/dflash_utils.py @@ -58,9 +58,7 @@ def scale_kv_cell_size_per_token_for_dflash( """Compute bytes/token budget for combined target+draft KV pools (DFLASH). DFLASH runs a separate draft runner with its own KV pool. The target runner's - token capacity must fit both pools in aggregate. For DFLASH checkpoints, the - draft KV geometry typically matches the target KV geometry and differs mainly - by layer count, so draft KV bytes/token scales linearly with `draft_num_layers`. + token capacity must fit both pools in aggregate. Returns: Approximate per-token bytes for (target KV + draft KV), expressed as a From 26441b848d136a36aaaaee4df42091a32d1a353c Mon Sep 17 00:00:00 2001 From: David Wang Date: Mon, 2 Mar 2026 04:19:37 +0000 Subject: [PATCH 46/73] initial fa4 support to dflash, clean up benchmarking script --- benchmark/dflash/bench_dflash_gsm8k_sweep.py | 878 ++++++++---------- .../sglang/srt/speculative/dflash_worker.py | 6 +- 2 files changed, 401 insertions(+), 483 deletions(-) diff --git a/benchmark/dflash/bench_dflash_gsm8k_sweep.py b/benchmark/dflash/bench_dflash_gsm8k_sweep.py index 8765557796ff..14e634d66108 100644 --- a/benchmark/dflash/bench_dflash_gsm8k_sweep.py +++ b/benchmark/dflash/bench_dflash_gsm8k_sweep.py @@ -5,7 +5,7 @@ GSM8K workload for each (concurrency, num_questions) setting. Example usage: - ./venv/bin/python benchmark/dflash/bench_dflash_gsm8k_sweep.py --output-md dflash_gsm8k_sweep.md + ./venv/bin/python benchmark/dflash/bench_dflash_gsm8k_sweep.py ./venv/bin/python benchmark/dflash/bench_dflash_gsm8k_sweep.py --skip-baseline --concurrencies 32 --tp-sizes 8 """ @@ -25,7 +25,6 @@ import torch from transformers import AutoTokenizer -from sglang.srt.environ import envs from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -37,25 +36,16 @@ INVALID = -9999999 -def _is_blackwell() -> bool: - # Prefer explicit env var, but also infer from compute capability (SM100+). - if envs.IS_BLACKWELL.get(): - return True - return get_device_sm() >= 100 +def _parse_int_csv(value: str) -> list[int]: + return [int(x) for x in value.split(",") if x.strip()] -def _get_one_example(lines, i: int, include_answer: bool) -> str: - ret = "Question: " + lines[i]["question"] + "\nAnswer:" - if include_answer: - ret += " " + lines[i]["answer"] - return ret - - -def _get_few_shot_examples(lines, k: int) -> str: - ret = "" - for i in range(k): - ret += _get_one_example(lines, i, True) + "\n\n" - return ret +def _filter_attention_backends(backends: list[str], *, device_sm: int) -> list[str]: + if not (80 <= device_sm <= 90): + backends = [b for b in backends if b != "fa3"] + if device_sm < 100: + backends = [b for b in backends if b not in ("fa4", "trtllm_mha")] + return backends or ["flashinfer"] def _get_answer_value(answer_str: str) -> int: @@ -83,47 +73,15 @@ def _flush_cache(base_url: str) -> None: def _send_generate( base_url: str, - prompt: str, - *, - max_new_tokens: int, - temperature: float, - top_p: float, - top_k: int, - stop: list[str], - timeout_s: int, -) -> dict: - sampling_params: dict = { - "temperature": float(temperature), - "top_p": float(top_p), - "top_k": int(top_k), - "max_new_tokens": int(max_new_tokens), - } - if stop: - sampling_params["stop"] = stop - resp = requests.post( - base_url + "/generate", - json={ - "text": prompt, - "sampling_params": sampling_params, - }, - timeout=int(timeout_s), - ) - resp.raise_for_status() - return resp.json() - - -def _send_generate_batch( - base_url: str, - prompts: list[str], + text: str | list[str], *, max_new_tokens: int, temperature: float, top_p: float, top_k: int, - stop: list[str], timeout_s: int, ) -> list[dict]: - if not prompts: + if isinstance(text, list) and not text: return [] sampling_params: dict = { "temperature": float(temperature), @@ -131,24 +89,35 @@ def _send_generate_batch( "top_k": int(top_k), "max_new_tokens": int(max_new_tokens), } - if stop: - sampling_params["stop"] = stop resp = requests.post( base_url + "/generate", json={ - "text": prompts, + "text": text, "sampling_params": sampling_params, }, timeout=int(timeout_s), ) resp.raise_for_status() out = resp.json() - if not isinstance(out, list): + if isinstance(text, list): + if not isinstance(out, list): + raise RuntimeError( + "Expected a list response for batched /generate, but got " + f"type={type(out).__name__}." + ) + if len(out) != len(text): + raise RuntimeError( + "Batched /generate output length mismatch: " + f"got {len(out)} outputs for {len(text)} prompts." + ) + return out + + if isinstance(out, list): raise RuntimeError( - "Expected a list response for batched /generate, but got " + "Expected an object response for single /generate, but got " f"type={type(out).__name__}." ) - return out + return [out] @dataclass(frozen=True) @@ -173,7 +142,6 @@ def _run_gsm8k_requests( top_k: int, concurrency: int, batch_requests: bool, - stop: list[str], timeout_s: int, expect_dflash: bool, ) -> BenchMetrics: @@ -187,6 +155,24 @@ def _run_gsm8k_requests( correct = 0 invalid = 0 + def _handle_output(out: dict, label: Optional[int]) -> None: + nonlocal total_tokens, spec_verify_ct_sum, correct, invalid + meta = out.get("meta_info", {}) or {} + total_tokens += int(meta.get("completion_tokens", 0)) + spec_verify_ct_sum += int(meta.get("spec_verify_ct", 0)) + if "spec_accept_length" in meta: + try: + spec_accept_lengths.append(float(meta["spec_accept_length"])) + except (TypeError, ValueError): + pass + + if label is not None: + pred = _get_answer_value(out.get("text", "")) + if pred == INVALID: + invalid += 1 + if pred == label: + correct += 1 + if batch_requests: bs = max(int(concurrency), 1) for start_idx in range(0, len(prompts), bs): @@ -194,72 +180,45 @@ def _run_gsm8k_requests( chunk_labels = ( labels[start_idx : start_idx + bs] if labels is not None else None ) - outs = _send_generate_batch( + outs = _send_generate( base_url, chunk_prompts, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, - stop=stop, timeout_s=timeout_s, ) - if len(outs) != len(chunk_prompts): - raise RuntimeError( - "Batched /generate output length mismatch: " - f"got {len(outs)} outputs for {len(chunk_prompts)} prompts." - ) - - for j, out in enumerate(outs): - meta = out.get("meta_info", {}) or {} - total_tokens += int(meta.get("completion_tokens", 0)) - spec_verify_ct_sum += int(meta.get("spec_verify_ct", 0)) - if "spec_accept_length" in meta: - try: - spec_accept_lengths.append(float(meta["spec_accept_length"])) - except (TypeError, ValueError): - pass - - if chunk_labels is not None: - pred = _get_answer_value(out.get("text", "")) - if pred == INVALID: - invalid += 1 - if pred == chunk_labels[j]: - correct += 1 + if chunk_labels is None: + for out in outs: + _handle_output(out, None) + else: + for out, label in zip(outs, chunk_labels): + _handle_output(out, label) else: with ThreadPoolExecutor(max_workers=int(concurrency)) as pool: futures = { pool.submit( _send_generate, - base_url, - prompt, + base_url=base_url, + text=prompt, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k, - stop=stop, timeout_s=timeout_s, ): i for i, prompt in enumerate(prompts) } for fut in as_completed(futures): i = futures[fut] - out = fut.result() - meta = out.get("meta_info", {}) or {} - total_tokens += int(meta.get("completion_tokens", 0)) - spec_verify_ct_sum += int(meta.get("spec_verify_ct", 0)) - if "spec_accept_length" in meta: - try: - spec_accept_lengths.append(float(meta["spec_accept_length"])) - except (TypeError, ValueError): - pass - - if labels is not None: - pred = _get_answer_value(out.get("text", "")) - if pred == INVALID: - invalid += 1 - if pred == labels[i]: - correct += 1 + outs = fut.result() + if len(outs) != 1: + raise RuntimeError( + "Expected exactly one output for single /generate request." + ) + label = None if labels is None else labels[i] + _handle_output(outs[0], label) latency = time.perf_counter() - start toks_per_s = total_tokens / max(latency, 1e-6) @@ -300,30 +259,286 @@ def _format_table( float_fmt: str, ) -> str: header = ["tp\\conc"] + [str(c) for c in concurrencies] - lines = [ - "| " + " | ".join(header) + " |", - "| " + " | ".join(["---"] * len(header)) + " |", - ] + rows: list[list[str]] = [header] for tp in tp_sizes: row = [str(tp)] for c in concurrencies: v = values.get((tp, c), None) row.append("N/A" if v is None else format(v, float_fmt)) - lines.append("| " + " | ".join(row) + " |") + rows.append(row) + + col_widths = [ + max(len(row[col_idx]) for row in rows) for col_idx in range(len(rows[0])) + ] + + lines: list[str] = [] + lines.append(" ".join(cell.rjust(col_widths[i]) for i, cell in enumerate(rows[0]))) + lines.append(" ".join("-" * w for w in col_widths)) + for row in rows[1:]: + lines.append(" ".join(cell.rjust(col_widths[i]) for i, cell in enumerate(row))) return "\n".join(lines) -def main() -> None: - parser = argparse.ArgumentParser() - parser.add_argument( - "--output-md", - type=str, - default=None, - help="Write a markdown report to this file (disabled by default).", +def _build_common_server_args( + args: argparse.Namespace, *, backend: str, tp: int +) -> list[str]: + common_server_args: list[str] = [ + "--trust-remote-code", + "--attention-backend", + backend, + "--tp-size", + str(tp), + "--dtype", + str(args.dtype), + "--max-running-requests", + str(args.max_running_requests), + "--cuda-graph-max-bs", + "32", + ] + if args.mem_fraction_static is not None: + common_server_args.extend( + ["--mem-fraction-static", str(args.mem_fraction_static)] + ) + if args.disable_radix_cache: + common_server_args.append("--disable-radix-cache") + if args.page_size is not None: + common_server_args.extend(["--page-size", str(int(args.page_size))]) + return common_server_args + + +def _build_mode_runs( + args: argparse.Namespace, common_server_args: list[str] +) -> list[tuple[str, str, list[str], bool]]: + mode_runs: list[tuple[str, str, list[str], bool]] = [] + if not args.skip_baseline: + mode_runs.append(("baseline", "baseline", common_server_args, False)) + mode_runs.append( + ( + "dflash", + "DFLASH", + [ + *common_server_args, + "--speculative-algorithm", + "DFLASH", + "--speculative-draft-model-path", + args.draft_model, + *( + [ + "--speculative-draft-attention-backend", + args.speculative_draft_attention_backend, + ] + if args.speculative_draft_attention_backend + else [] + ), + ], + True, + ) ) - parser.add_argument("--data-path", type=str, default="test.jsonl") - parser.add_argument("--target-model", type=str, default="Qwen/Qwen3-8B") - parser.add_argument("--draft-model", type=str, default="z-lab/Qwen3-8B-DFlash-b16") + return mode_runs + + +def _collect_metric( + *, + results: dict[tuple[str, int, int, str], BenchMetrics], + backend: str, + tp_sizes: list[int], + concurrencies: list[int], + mode: str, + field: str, +) -> dict[tuple[int, int], Optional[float]]: + out: dict[tuple[int, int], Optional[float]] = {} + for tp in tp_sizes: + for conc in concurrencies: + metrics = results.get((backend, tp, conc, mode), None) + out[(tp, conc)] = None if metrics is None else getattr(metrics, field) + return out + + +def _compute_speedup( + baseline: dict[tuple[int, int], Optional[float]], + dflash: dict[tuple[int, int], Optional[float]], +) -> dict[tuple[int, int], Optional[float]]: + return { + key: None if (b is None or d is None or b <= 0) else (d / b) + for key, b in baseline.items() + for d in [dflash.get(key, None)] + } + + +def _print_kv_lines(items: list[tuple[str, object]]) -> None: + for key, value in items: + print(f"{key}={value}") + + +def _run_mode_for_backend_tp( + *, + mode_label: str, + model_path: str, + base_url: str, + server_args: list[str], + expect_dflash: bool, + prompts: list[str], + labels: list[int], + concurrencies: list[int], + num_questions_by_conc: dict[int, int], + args: argparse.Namespace, +) -> dict[int, BenchMetrics]: + print(f"\n=== {mode_label} ===") + proc = popen_launch_server( + model_path, + base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=server_args, + ) + try: + _send_generate( + base_url, + "Hello", + max_new_tokens=8, + temperature=float(args.temperature), + top_p=float(args.top_p), + top_k=int(args.top_k), + timeout_s=min(int(args.timeout_s), 300), + ) + + metrics_by_conc: dict[int, BenchMetrics] = {} + for conc in concurrencies: + n = num_questions_by_conc[conc] + _flush_cache(base_url) + metrics = _run_gsm8k_requests( + base_url, + prompts=prompts[:n], + labels=labels[:n], + max_new_tokens=int(args.max_new_tokens), + temperature=float(args.temperature), + top_p=float(args.top_p), + top_k=int(args.top_k), + concurrency=int(conc), + batch_requests=bool(args.batch_requests), + timeout_s=int(args.timeout_s), + expect_dflash=expect_dflash, + ) + metrics_by_conc[conc] = metrics + line = ( + f"[{mode_label}] conc={conc:>2} n={n:<4} " + f"toks/s={metrics.output_toks_per_s:,.2f} " + f"latency={metrics.latency_s:.1f}s " + f"acc={metrics.accuracy:.3f} invalid={metrics.invalid_rate:.3f}" + ) + if expect_dflash: + accept_len = ( + "N/A" + if metrics.spec_accept_length is None + else f"{metrics.spec_accept_length:.3f}" + ) + line += ( + f" accept_len={accept_len} " + f"spec_verify_ct_sum={metrics.spec_verify_ct_sum}" + ) + print(line) + return metrics_by_conc + finally: + kill_process_tree(proc.pid) + try: + proc.wait(timeout=30) + except Exception: + pass + + +def _print_summary( + *, + args: argparse.Namespace, + attention_backends: list[str], + tp_sizes: list[int], + concurrencies: list[int], + device_sm: int, + results: dict[tuple[str, int, int, str], BenchMetrics], +) -> None: + print("\n=== DFLASH GSM8K Sweep Summary ===") + _print_kv_lines( + [ + ("target_model", args.target_model), + ("draft_model", args.draft_model), + ("max_new_tokens", args.max_new_tokens), + ( + "sampling", + f"temperature:{args.temperature}, top_p:{args.top_p}, top_k:{args.top_k}", + ), + ("attention_backends", ",".join(attention_backends)), + ( + "speculative_draft_attention_backend", + args.speculative_draft_attention_backend, + ), + ("tp_sizes", ",".join(str(x) for x in tp_sizes)), + ("concurrencies", ",".join(str(x) for x in concurrencies)), + ( + "questions_per_concurrency_base", + args.questions_per_concurrency_base, + ), + ("device_sm", device_sm), + ("skip_baseline", bool(args.skip_baseline)), + ] + ) + + section_fields = [ + ("Baseline output tok/s", "baseline", "output_toks_per_s", ",.2f"), + ("Baseline accuracy", "baseline", "accuracy", ".3f"), + ("DFLASH output tok/s", "dflash", "output_toks_per_s", ",.2f"), + ("DFLASH accuracy", "dflash", "accuracy", ".3f"), + ( + "DFLASH acceptance length (mean spec_accept_length)", + "dflash", + "spec_accept_length", + ".3f", + ), + ] + + for backend in attention_backends: + print(f"\n=== Backend: {backend} ===") + metrics_map = { + (mode, field): _collect_metric( + results=results, + backend=backend, + tp_sizes=tp_sizes, + concurrencies=concurrencies, + mode=mode, + field=field, + ) + for _, mode, field, _ in section_fields + } + sections: list[tuple[str, dict[tuple[int, int], Optional[float]], str]] = [ + (title, metrics_map[(mode, field)], fmt) + for title, mode, field, fmt in section_fields + ] + sections.insert( + 4, + ( + "Speedup (DFLASH / baseline)", + _compute_speedup( + metrics_map[("baseline", "output_toks_per_s")], + metrics_map[("dflash", "output_toks_per_s")], + ), + ".3f", + ), + ) + + for title, values, fmt in sections: + print(f"\n{title}") + print( + _format_table( + tp_sizes=tp_sizes, + concurrencies=concurrencies, + values=values, + float_fmt=fmt, + ) + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", default="test.jsonl") + parser.add_argument("--target-model", default="Qwen/Qwen3-8B") + parser.add_argument("--draft-model", default="z-lab/Qwen3-8B-DFlash-b16") parser.add_argument( "--skip-baseline", action="store_true", @@ -334,33 +549,10 @@ def main() -> None: action="store_true", help="Send prompts as server-side batched /generate requests (batch size = concurrency) instead of client-side concurrent requests.", ) - parser.add_argument( - "--prompt-style", - type=str, - choices=["fewshot_qa", "chat"], - default="chat", - help="Prompting style: 'chat' matches the DFlash HF demo prompt.", - ) - parser.add_argument("--num-shots", type=int, default=0) parser.add_argument("--max-new-tokens", type=int, default=2048) - parser.add_argument( - "--temperature", - type=float, - default=0.0, - help="Sampling temperature for /generate requests. Default 0.0 (greedy).", - ) - parser.add_argument( - "--top-p", - type=float, - default=1.0, - help="Sampling top-p for /generate requests. Default 1.0.", - ) - parser.add_argument( - "--top-k", - type=int, - default=1, - help="Sampling top-k for /generate requests. Default 1 (greedy).", - ) + parser.add_argument("--temperature", type=float, default=0.0) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--top-k", type=int, default=1) parser.add_argument("--timeout-s", type=int, default=3600) parser.add_argument( "--mem-fraction-static", @@ -369,7 +561,7 @@ def main() -> None: help="Optional server --mem-fraction-static override. If unset, use the server auto heuristic.", ) parser.add_argument("--disable-radix-cache", action="store_true") - parser.add_argument("--dtype", type=str, default="bfloat16") + parser.add_argument("--dtype", default="bfloat16") parser.add_argument( "--page-size", type=int, @@ -377,18 +569,8 @@ def main() -> None: help="Optional server --page-size override for both baseline and DFLASH runs.", ) parser.add_argument("--max-running-requests", type=int, default=32) - parser.add_argument( - "--tp-sizes", - type=str, - default="1,2,4,8", - help="Comma-separated list, filtered by visible CUDA devices.", - ) - parser.add_argument( - "--concurrencies", - type=str, - default="1,2,4,8,16,32", - help="Comma-separated list of client concurrency levels.", - ) + parser.add_argument("--tp-sizes", default="1,2,4,8") + parser.add_argument("--concurrencies", default="1,2,4,8,16,32") parser.add_argument( "--questions-per-concurrency-base", type=int, @@ -401,13 +583,17 @@ def main() -> None: default=1024, help="Cap num_questions per (tp, concurrency) run (default: 1024).", ) + parser.add_argument("--attention-backends", default="flashinfer,fa3,trtllm_mha,fa4") parser.add_argument( - "--attention-backends", - type=str, - default="flashinfer,fa3", - help="Comma-separated list. Auto-skips unsupported backends for the current GPU.", + "--speculative-draft-attention-backend", + default=None, + help="Optional server --speculative-draft-attention-backend override for DFLASH runs.", ) - args = parser.parse_args() + return parser.parse_args() + + +def main() -> None: + args = parse_args() if not torch.cuda.is_available(): raise RuntimeError("CUDA is required for this sweep.") @@ -419,7 +605,7 @@ def main() -> None: raise RuntimeError(f"--top-k must be -1 (all vocab) or >= 1, got {args.top_k}.") visible_gpus = int(torch.cuda.device_count()) - tp_sizes = [int(x) for x in args.tp_sizes.split(",") if x.strip()] + tp_sizes = _parse_int_csv(args.tp_sizes) tp_sizes = [tp for tp in tp_sizes if tp >= 1 and tp <= visible_gpus] if not tp_sizes: raise RuntimeError( @@ -427,7 +613,7 @@ def main() -> None: "Set CUDA_VISIBLE_DEVICES accordingly." ) - concurrencies = [int(x) for x in args.concurrencies.split(",") if x.strip()] + concurrencies = _parse_int_csv(args.concurrencies) concurrencies = [c for c in concurrencies if c >= 1] if not concurrencies: raise RuntimeError("No concurrencies specified.") @@ -444,15 +630,10 @@ def main() -> None: attention_backends = [ s.strip() for s in args.attention_backends.split(",") if s.strip() ] - is_blackwell = _is_blackwell() device_sm = get_device_sm() - if is_blackwell: - attention_backends = [b for b in attention_backends if b != "fa3"] - if device_sm < 90: - attention_backends = [b for b in attention_backends if b != "fa3"] - if device_sm < 100: - attention_backends = [b for b in attention_backends if b != "trtllm_mha"] - attention_backends = attention_backends or ["flashinfer"] + attention_backends = _filter_attention_backends( + attention_backends, device_sm=device_sm + ) data_path = _maybe_download_gsm8k(args.data_path) lines = list(read_jsonl(data_path)) @@ -461,330 +642,65 @@ def main() -> None: f"GSM8K file only has {len(lines)} lines, but need {max_questions}." ) - tokenizer = None - if args.prompt_style == "chat": - tokenizer = AutoTokenizer.from_pretrained(args.target_model) - - few_shot = ( - _get_few_shot_examples(lines, int(args.num_shots)) - if args.prompt_style == "fewshot_qa" - else "" - ) + tokenizer = AutoTokenizer.from_pretrained(args.target_model) prompts: list[str] = [] labels: list[int] = [] for i in range(max_questions): - if args.prompt_style == "fewshot_qa": - prompts.append(few_shot + _get_one_example(lines, i, False)) - else: - assert tokenizer is not None - user_content = ( - lines[i]["question"] - + "\nPlease reason step by step, and put your final answer within \\boxed{}." - ) - prompts.append( - tokenizer.apply_chat_template( - [{"role": "user", "content": user_content}], - tokenize=False, - add_generation_prompt=True, - enable_thinking=False, - ) + user_content = ( + lines[i]["question"] + + "\nPlease reason step by step, and put your final answer within \\boxed{}." + ) + prompts.append( + tokenizer.apply_chat_template( + [{"role": "user", "content": user_content}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, ) + ) labels.append(_get_answer_value(lines[i]["answer"])) if not all(l != INVALID for l in labels): raise RuntimeError("Invalid labels in GSM8K data.") - default_stop = ( - ["Question", "Assistant:", "<|separator|>"] - if args.prompt_style == "fewshot_qa" - else [] - ) - - # Results indexed by (backend, tp, concurrency) for baseline + dflash. - baseline_toks: dict[tuple[str, int, int], Optional[float]] = {} - dflash_toks: dict[tuple[str, int, int], Optional[float]] = {} - dflash_accept_len: dict[tuple[str, int, int], Optional[float]] = {} - baseline_acc: dict[tuple[str, int, int], Optional[float]] = {} - dflash_acc: dict[tuple[str, int, int], Optional[float]] = {} + # Results indexed by (backend, tp, concurrency, mode). + results: dict[tuple[str, int, int, str], BenchMetrics] = {} for backend in attention_backends: for tp in tp_sizes: port_base = find_available_port(20000) - - common_server_args: list[str] = [ - "--trust-remote-code", - "--attention-backend", - backend, - "--tp-size", - str(tp), - "--dtype", - str(args.dtype), - "--max-running-requests", - str(args.max_running_requests), - "--cuda-graph-max-bs", - "32", - ] - if args.mem_fraction_static is not None: - common_server_args.extend( - ["--mem-fraction-static", str(args.mem_fraction_static)] - ) - if args.disable_radix_cache: - common_server_args.append("--disable-radix-cache") - if args.page_size is not None: - common_server_args.extend(["--page-size", str(int(args.page_size))]) - - if not args.skip_baseline: - print(f"\n=== backend={backend} tp={tp} (baseline) ===") - baseline_port = port_base - baseline_url = f"http://127.0.0.1:{baseline_port}" - baseline_proc = popen_launch_server( - args.target_model, - baseline_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=common_server_args, - ) - try: - # Warm up. - _send_generate( - baseline_url, - "Hello", - max_new_tokens=8, - temperature=float(args.temperature), - top_p=float(args.top_p), - top_k=int(args.top_k), - stop=[], - timeout_s=min(int(args.timeout_s), 300), - ) - - for conc in concurrencies: - n = num_questions_by_conc[conc] - _flush_cache(baseline_url) - metrics = _run_gsm8k_requests( - baseline_url, - prompts=prompts[:n], - labels=labels[:n], - max_new_tokens=int(args.max_new_tokens), - temperature=float(args.temperature), - top_p=float(args.top_p), - top_k=int(args.top_k), - concurrency=int(conc), - batch_requests=bool(args.batch_requests), - stop=default_stop, - timeout_s=int(args.timeout_s), - expect_dflash=False, - ) - baseline_toks[(backend, tp, conc)] = metrics.output_toks_per_s - baseline_acc[(backend, tp, conc)] = metrics.accuracy - print( - f"[baseline] conc={conc:>2} n={n:<4} " - f"toks/s={metrics.output_toks_per_s:,.2f} " - f"latency={metrics.latency_s:.1f}s " - f"acc={metrics.accuracy:.3f} invalid={metrics.invalid_rate:.3f}" - ) - finally: - kill_process_tree(baseline_proc.pid) - try: - baseline_proc.wait(timeout=30) - except Exception: - pass - - print(f"\n=== backend={backend} tp={tp} (DFLASH) ===") - dflash_port = find_available_port(port_base + 1) - dflash_url = f"http://127.0.0.1:{dflash_port}" - dflash_proc = popen_launch_server( - args.target_model, - dflash_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - *common_server_args, - "--speculative-algorithm", - "DFLASH", - "--speculative-draft-model-path", - args.draft_model, - ], - ) - try: - _send_generate( - dflash_url, - "Hello", - max_new_tokens=8, - temperature=float(args.temperature), - top_p=float(args.top_p), - top_k=int(args.top_k), - stop=[], - timeout_s=min(int(args.timeout_s), 300), + common_server_args = _build_common_server_args(args, backend=backend, tp=tp) + mode_runs = _build_mode_runs(args, common_server_args) + + for idx, ( + mode_key, + mode_name, + mode_server_args, + expect_dflash, + ) in enumerate(mode_runs): + mode_metrics = _run_mode_for_backend_tp( + mode_label=f"backend={backend} tp={tp} ({mode_name})", + model_path=args.target_model, + base_url=f"http://127.0.0.1:{find_available_port(port_base + idx)}", + server_args=mode_server_args, + expect_dflash=expect_dflash, + prompts=prompts, + labels=labels, + concurrencies=concurrencies, + num_questions_by_conc=num_questions_by_conc, + args=args, ) - for conc in concurrencies: - n = num_questions_by_conc[conc] - _flush_cache(dflash_url) - metrics = _run_gsm8k_requests( - dflash_url, - prompts=prompts[:n], - labels=labels[:n], - max_new_tokens=int(args.max_new_tokens), - temperature=float(args.temperature), - top_p=float(args.top_p), - top_k=int(args.top_k), - concurrency=int(conc), - batch_requests=bool(args.batch_requests), - stop=default_stop, - timeout_s=int(args.timeout_s), - expect_dflash=True, - ) - dflash_toks[(backend, tp, conc)] = metrics.output_toks_per_s - dflash_accept_len[(backend, tp, conc)] = metrics.spec_accept_length - dflash_acc[(backend, tp, conc)] = metrics.accuracy - print( - f"[DFLASH] conc={conc:>2} n={n:<4} " - f"toks/s={metrics.output_toks_per_s:,.2f} " - f"latency={metrics.latency_s:.1f}s " - f"acc={metrics.accuracy:.3f} invalid={metrics.invalid_rate:.3f} " - f"accept_len={metrics.spec_accept_length:.3f} " - f"spec_verify_ct_sum={metrics.spec_verify_ct_sum}" - ) - finally: - kill_process_tree(dflash_proc.pid) - try: - dflash_proc.wait(timeout=30) - except Exception: - pass - - # Render markdown. - md_lines: list[str] = [] - md_lines.append("# DFLASH GSM8K Sweep") - md_lines.append("") - md_lines.append("## Settings") - md_lines.append(f"- target_model: `{args.target_model}`") - md_lines.append(f"- draft_model: `{args.draft_model}`") - md_lines.append(f"- prompt_style: `{args.prompt_style}`") - if args.prompt_style == "fewshot_qa": - md_lines.append(f"- num_shots: `{args.num_shots}`") - md_lines.append(f"- max_new_tokens: `{args.max_new_tokens}`") - md_lines.append( - f"- sampling: `temperature={args.temperature}, top_p={args.top_p}, top_k={args.top_k}`" - ) - md_lines.append(f"- attention_backends: `{', '.join(attention_backends)}`") - md_lines.append(f"- tp_sizes: `{', '.join(str(x) for x in tp_sizes)}`") - md_lines.append(f"- concurrencies: `{', '.join(str(x) for x in concurrencies)}`") - md_lines.append( - f"- questions_per_concurrency: `base={args.questions_per_concurrency_base}`" + for conc, metrics in mode_metrics.items(): + results[(backend, tp, conc, mode_key)] = metrics + + _print_summary( + args=args, + attention_backends=attention_backends, + tp_sizes=tp_sizes, + concurrencies=concurrencies, + device_sm=device_sm, + results=results, ) - md_lines.append(f"- device_sm: `{device_sm}`") - md_lines.append(f"- is_blackwell: `{is_blackwell}`") - md_lines.append(f"- skip_baseline: `{bool(args.skip_baseline)}`") - md_lines.append("") - md_lines.append( - "Note: DFLASH and baseline greedy outputs may diverge on some prompts due to numerical differences " - "(e.g. verify path vs decode path). This sweep focuses on throughput." - ) - md_lines.append("") - - for backend in attention_backends: - md_lines.append(f"## Backend: `{backend}`") - md_lines.append("") - - baseline_values = { - (tp, conc): baseline_toks.get((backend, tp, conc), None) - for tp in tp_sizes - for conc in concurrencies - } - dflash_values = { - (tp, conc): dflash_toks.get((backend, tp, conc), None) - for tp in tp_sizes - for conc in concurrencies - } - speedup_values: dict[tuple[int, int], Optional[float]] = {} - for tp in tp_sizes: - for conc in concurrencies: - b = baseline_values.get((tp, conc), None) - d = dflash_values.get((tp, conc), None) - speedup_values[(tp, conc)] = ( - None if (b is None or d is None or b <= 0) else (d / b) - ) - - md_lines.append("### Baseline output tok/s") - md_lines.append( - _format_table( - tp_sizes=tp_sizes, - concurrencies=concurrencies, - values=baseline_values, - float_fmt=",.2f", - ) - ) - md_lines.append("") - md_lines.append("### Baseline accuracy") - md_lines.append( - _format_table( - tp_sizes=tp_sizes, - concurrencies=concurrencies, - values={ - (tp, conc): baseline_acc.get((backend, tp, conc), None) - for tp in tp_sizes - for conc in concurrencies - }, - float_fmt=".3f", - ) - ) - md_lines.append("") - md_lines.append("### DFLASH output tok/s") - md_lines.append( - _format_table( - tp_sizes=tp_sizes, - concurrencies=concurrencies, - values=dflash_values, - float_fmt=",.2f", - ) - ) - md_lines.append("") - md_lines.append("### DFLASH accuracy") - md_lines.append( - _format_table( - tp_sizes=tp_sizes, - concurrencies=concurrencies, - values={ - (tp, conc): dflash_acc.get((backend, tp, conc), None) - for tp in tp_sizes - for conc in concurrencies - }, - float_fmt=".3f", - ) - ) - md_lines.append("") - md_lines.append("### Speedup (DFLASH / baseline)") - md_lines.append( - _format_table( - tp_sizes=tp_sizes, - concurrencies=concurrencies, - values=speedup_values, - float_fmt=".3f", - ) - ) - md_lines.append("") - - md_lines.append( - "### DFLASH acceptance length (mean per-request spec_accept_length)" - ) - md_lines.append( - _format_table( - tp_sizes=tp_sizes, - concurrencies=concurrencies, - values={ - (tp, conc): dflash_accept_len.get((backend, tp, conc), None) - for tp in tp_sizes - for conc in concurrencies - }, - float_fmt=".3f", - ) - ) - md_lines.append("") - - if args.output_md: - with open(args.output_md, "w", encoding="utf-8") as f: - f.write("\n".join(md_lines)) - f.write("\n") - print(f"\nWrote markdown report to: {args.output_md}") - else: - print("\nMarkdown report disabled (pass --output-md to write one).") if __name__ == "__main__": diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index aa2da57686e5..00945758aa53 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -84,6 +84,7 @@ def __init__( draft_server_args = deepcopy(server_args) draft_server_args.skip_tokenizer_init = True draft_backend = draft_server_args.speculative_draft_attention_backend + supported_draft_backends = ("flashinfer", "fa3", "fa4") if draft_backend is None: draft_backend, _ = draft_server_args.get_attention_backends() if draft_backend is None: @@ -94,10 +95,11 @@ def __init__( "falling back to 'flashinfer'." ) draft_backend = "flashinfer" - elif draft_backend not in ("flashinfer", "fa3"): + elif draft_backend not in supported_draft_backends: logger.warning( - "DFLASH draft worker only supports attention_backend in {'flashinfer', 'fa3'} for now, " + "DFLASH draft worker only supports attention_backend in %s for now, " "but got %r. Falling back to 'flashinfer'.", + supported_draft_backends, draft_backend, ) draft_backend = "flashinfer" From 74814dedeb450203a30b77aa54df0ede37ba1c42 Mon Sep 17 00:00:00 2001 From: David Wang Date: Mon, 2 Mar 2026 07:36:33 +0000 Subject: [PATCH 47/73] clean up --- benchmark/dflash/bench_dflash_gsm8k_sweep.py | 2 +- python/sglang/srt/mem_cache/common.py | 5 ----- .../srt/model_executor/cuda_graph_runner.py | 2 ++ .../sglang/srt/model_executor/model_runner.py | 8 +++---- python/sglang/srt/models/qwen2_moe.py | 5 ----- python/sglang/srt/models/qwen3_moe.py | 5 +++++ python/sglang/srt/speculative/dflash_info.py | 2 +- .../sglang/srt/speculative/dflash_worker.py | 18 ++++------------ python/sglang/srt/speculative/draft_utils.py | 21 ------------------- 9 files changed, 17 insertions(+), 51 deletions(-) diff --git a/benchmark/dflash/bench_dflash_gsm8k_sweep.py b/benchmark/dflash/bench_dflash_gsm8k_sweep.py index 14e634d66108..b42e15f9811a 100644 --- a/benchmark/dflash/bench_dflash_gsm8k_sweep.py +++ b/benchmark/dflash/bench_dflash_gsm8k_sweep.py @@ -660,7 +660,7 @@ def main() -> None: ) ) labels.append(_get_answer_value(lines[i]["answer"])) - if not all(l != INVALID for l in labels): + if not all(label != INVALID for label in labels): raise RuntimeError("Invalid labels in GSM8K data.") # Results indexed by (backend, tp, concurrency, mode). diff --git a/python/sglang/srt/mem_cache/common.py b/python/sglang/srt/mem_cache/common.py index 896536700f6f..5a759ed11bcd 100644 --- a/python/sglang/srt/mem_cache/common.py +++ b/python/sglang/srt/mem_cache/common.py @@ -474,9 +474,6 @@ def release_kv_cache(req: Req, tree_cache: BasePrefixCache, is_insert: bool = Tr req.mamba_pool_idx.unsqueeze(-1) ) req.mamba_pool_idx = None - # DFLASH tracks per-request draft progress on Req. - if hasattr(req, "dflash_draft_seq_len"): - req.dflash_draft_seq_len = 0 return tree_cache.cache_finished_req(req, is_insert=is_insert) @@ -516,8 +513,6 @@ def release_kv_cache(req: Req, tree_cache: BasePrefixCache, is_insert: bool = Tr ), "mamba state is freed while the tree cache does not manage mamba states" tree_cache.req_to_token_pool.free_mamba_cache(req) tree_cache.req_to_token_pool.free(req) - if hasattr(req, "dflash_draft_seq_len"): - req.dflash_draft_seq_len = 0 def available_and_evictable_str(tree_cache: BasePrefixCache) -> str: diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index b2c759961218..089590c5cc2c 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -601,6 +601,7 @@ def can_run(self, forward_batch: ForwardBatch): max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() or self.model_runner.spec_algorithm.is_standalone() + or self.model_runner.spec_algorithm.is_dflash() else max(forward_batch.global_num_tokens_cpu) ) else: @@ -1010,6 +1011,7 @@ def replay_prepare( max_num_tokens / self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() or self.model_runner.spec_algorithm.is_standalone() + or self.model_runner.spec_algorithm.is_dflash() else max_num_tokens ) index = bisect.bisect_left(self.capture_bs, max_batch_size) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 2486a83497ac..d3f7ae899e7b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1962,10 +1962,10 @@ def _dummy_run(self, batch_size: int, run_ctx=None): or self.spec_algorithm.is_dflash() ): if self.is_draft_worker: - raise RuntimeError("This should not happen") - else: - capture_forward_mode = ForwardMode.TARGET_VERIFY - num_tokens_per_bs = self.server_args.speculative_num_draft_tokens + if not self.spec_algorithm.is_dflash(): + raise RuntimeError("This should not happen") + capture_forward_mode = ForwardMode.TARGET_VERIFY + num_tokens_per_bs = self.server_args.speculative_num_draft_tokens if self.server_args.enable_return_hidden_states: capture_hidden_mode = CaptureHiddenMode.FULL diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index d0585fa46748..bbb883a2deff 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -608,11 +608,6 @@ def set_eagle3_layers_to_capture(self, layers_to_capture: List[int]): for layer_id in self.layers_to_capture: setattr(self.layers[layer_id], "_is_layer_to_capture", True) - def set_dflash_layers_to_capture(self, layers_to_capture: List[int]): - self.layers_to_capture = layers_to_capture - for layer_id in self.layers_to_capture: - setattr(self.layers[layer_id], "_is_layer_to_capture", True) - def forward( self, input_ids: torch.Tensor, diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index ea79c4802c40..fff13c830a4d 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -886,6 +886,11 @@ def __init__( alt_stream=alt_stream, ) + def set_dflash_layers_to_capture(self, layers_to_capture: List[int]): + self.layers_to_capture = layers_to_capture + for layer_id in self.layers_to_capture: + setattr(self.layers[layer_id], "_is_layer_to_capture", True) + class Qwen3MoeForCausalLM(nn.Module): fall_back_to_pt_during_load = False diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py index 111ddadd57cd..5a77e8e39f5d 100644 --- a/python/sglang/srt/speculative/dflash_info.py +++ b/python/sglang/srt/speculative/dflash_info.py @@ -423,7 +423,7 @@ def verify( for j, token_id in enumerate(tokens): if vocab_size is not None and ( - int(token_id) > int(vocab_size) or int(token_id) < 0 + int(token_id) >= int(vocab_size) or int(token_id) < 0 ): tokens = tokens[: j + 1] break diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 00945758aa53..05a58a6aa581 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -68,7 +68,6 @@ def __init__( self.nccl_port = nccl_port self.target_worker = target_worker self.model_runner = target_worker.model_runner - self.tp_rank = tp_rank self.page_size = server_args.page_size self.device = target_worker.device @@ -316,12 +315,6 @@ def clear_cache_pool(self): # allocator and req_to_token_pool are shared with target worker pass - def on_req_finished(self, req): - # allocator and req_to_token_pool are shared with the target worker; - # there is no separate draft allocation to release here. - if hasattr(req, "dflash_draft_seq_len"): - req.dflash_draft_seq_len = 0 - def _resolve_mask_token_id( self, *, mask_token: str, mask_token_id: Optional[int] = None ) -> int: @@ -411,8 +404,8 @@ def _prepare_for_speculative_decoding( return if batch.has_grammar: - raise ValueError( - "DFLASH does not support grammar-constrained decoding yet." + raise RuntimeError( + "Invariant broken: DFLASH batch has grammar constraints, but scheduler should have rejected this request." ) if batch.sampling_info is not None and not batch.sampling_info.is_all_greedy: if ( @@ -427,7 +420,6 @@ def _prepare_for_speculative_decoding( self._warned_sampling_fallback = True bs = batch.batch_size() - device = self.model_runner.device # --- 1) Append any newly committed tokens into the draft KV cache. self._append_target_hidden_to_draft_kv(batch, draft_input) @@ -967,8 +959,8 @@ def forward_batch_generation( **kwargs, ) -> GenerationBatchResult: if getattr(batch, "return_logprob", False): - raise ValueError( - "DFLASH speculative decoding does not support return_logprob yet." + raise RuntimeError( + "Invariant broken: DFLASH batch requested return_logprob, but scheduler should have rejected this request." ) if isinstance(batch, ModelWorkerBatch): @@ -1021,8 +1013,6 @@ def _to_int32_device_tensor(x, *, device=device): ) self._append_target_hidden_to_draft_kv(batch, draft_input) batch.spec_info = draft_input - for req, draft_len in zip(batch.reqs, batch.seq_lens_cpu, strict=True): - req.dflash_draft_seq_len = int(draft_len) return GenerationBatchResult( logits_output=logits_output, diff --git a/python/sglang/srt/speculative/draft_utils.py b/python/sglang/srt/speculative/draft_utils.py index 3f3e8c63388a..9c630da72fb1 100644 --- a/python/sglang/srt/speculative/draft_utils.py +++ b/python/sglang/srt/speculative/draft_utils.py @@ -31,27 +31,6 @@ def _create_backend( if backend_type is None: backend_type = self.server_args.attention_backend - if backend_type is None: - backend_type = "flashinfer" - elif backend_type == "trtllm_mha": - logger.warning( - "Draft attention backend does not support 'trtllm_mha' yet; " - "falling back to 'flashinfer'." - ) - backend_type = "flashinfer" - - if backend_type not in backend_map: - fallback_backend = "flashinfer" if "flashinfer" in backend_map else None - if fallback_backend is None: - raise ValueError(error_template.format(backend_type=backend_type)) - logger.warning( - "Draft attention backend '%s' is not supported for speculative draft; " - "falling back to '%s'.", - backend_type, - fallback_backend, - ) - backend_type = fallback_backend - if backend_type not in backend_map: raise ValueError(error_template.format(backend_type=backend_type)) From e4933535c127cb2ea425c01a03d2107aed14c5e1 Mon Sep 17 00:00:00 2001 From: David Wang Date: Mon, 2 Mar 2026 07:49:22 +0000 Subject: [PATCH 48/73] only run baseline once --- benchmark/dflash/bench_dflash_gsm8k_sweep.py | 39 +++++++++++++------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/benchmark/dflash/bench_dflash_gsm8k_sweep.py b/benchmark/dflash/bench_dflash_gsm8k_sweep.py index b42e15f9811a..68e6d9980fe2 100644 --- a/benchmark/dflash/bench_dflash_gsm8k_sweep.py +++ b/benchmark/dflash/bench_dflash_gsm8k_sweep.py @@ -665,8 +665,10 @@ def main() -> None: # Results indexed by (backend, tp, concurrency, mode). results: dict[tuple[str, int, int, str], BenchMetrics] = {} + # Baseline metrics are backend-agnostic in this sweep; run once per TP and reuse. + baseline_cache_by_tp: dict[int, dict[int, BenchMetrics]] = {} - for backend in attention_backends: + for backend_idx, backend in enumerate(attention_backends): for tp in tp_sizes: port_base = find_available_port(20000) common_server_args = _build_common_server_args(args, backend=backend, tp=tp) @@ -678,18 +680,29 @@ def main() -> None: mode_server_args, expect_dflash, ) in enumerate(mode_runs): - mode_metrics = _run_mode_for_backend_tp( - mode_label=f"backend={backend} tp={tp} ({mode_name})", - model_path=args.target_model, - base_url=f"http://127.0.0.1:{find_available_port(port_base + idx)}", - server_args=mode_server_args, - expect_dflash=expect_dflash, - prompts=prompts, - labels=labels, - concurrencies=concurrencies, - num_questions_by_conc=num_questions_by_conc, - args=args, - ) + if ( + mode_key == "baseline" + and not args.skip_baseline + and backend_idx > 0 + and tp in baseline_cache_by_tp + ): + mode_metrics = baseline_cache_by_tp[tp] + else: + mode_metrics = _run_mode_for_backend_tp( + mode_label=f"backend={backend} tp={tp} ({mode_name})", + model_path=args.target_model, + base_url=f"http://127.0.0.1:{find_available_port(port_base + idx)}", + server_args=mode_server_args, + expect_dflash=expect_dflash, + prompts=prompts, + labels=labels, + concurrencies=concurrencies, + num_questions_by_conc=num_questions_by_conc, + args=args, + ) + if mode_key == "baseline" and not args.skip_baseline: + baseline_cache_by_tp[tp] = mode_metrics + for conc, metrics in mode_metrics.items(): results[(backend, tp, conc, mode_key)] = metrics From 7878ec4d4a1c10fd81e9403a59d7b6bdaaca5dcb Mon Sep 17 00:00:00 2001 From: David Wang Date: Sun, 8 Mar 2026 21:12:29 +0000 Subject: [PATCH 49/73] fix server startup timeout for autotuning --- benchmark/dflash/bench_dflash_gsm8k_sweep.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/benchmark/dflash/bench_dflash_gsm8k_sweep.py b/benchmark/dflash/bench_dflash_gsm8k_sweep.py index 68e6d9980fe2..4696da5c07bb 100644 --- a/benchmark/dflash/bench_dflash_gsm8k_sweep.py +++ b/benchmark/dflash/bench_dflash_gsm8k_sweep.py @@ -384,10 +384,11 @@ def _run_mode_for_backend_tp( args: argparse.Namespace, ) -> dict[int, BenchMetrics]: print(f"\n=== {mode_label} ===") + server_start_timeout_s = int(max(DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, args.timeout_s)) proc = popen_launch_server( model_path, base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + timeout=server_start_timeout_s, other_args=server_args, ) try: @@ -553,7 +554,15 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--temperature", type=float, default=0.0) parser.add_argument("--top-p", type=float, default=1.0) parser.add_argument("--top-k", type=int, default=1) - parser.add_argument("--timeout-s", type=int, default=3600) + parser.add_argument( + "--timeout-s", + type=int, + default=3600, + help=( + "Timeout in seconds for benchmarked /generate calls and server startup " + "health checks." + ), + ) parser.add_argument( "--mem-fraction-static", type=float, @@ -603,6 +612,8 @@ def main() -> None: raise RuntimeError(f"--top-p must be in (0, 1], got {args.top_p}.") if args.top_k == 0 or args.top_k < -1: raise RuntimeError(f"--top-k must be -1 (all vocab) or >= 1, got {args.top_k}.") + if args.timeout_s <= 0: + raise RuntimeError(f"--timeout-s must be > 0, got {args.timeout_s}.") visible_gpus = int(torch.cuda.device_count()) tp_sizes = _parse_int_csv(args.tp_sizes) From 5d689d70baf2683765eb0641b432c92e977bb8a0 Mon Sep 17 00:00:00 2001 From: David Wang Date: Sun, 8 Mar 2026 23:18:26 +0000 Subject: [PATCH 50/73] qwen3_5 support --- .../sglang/srt/model_executor/model_runner.py | 2 +- python/sglang/srt/models/qwen3_5.py | 39 ++++++++++++++++--- python/sglang/srt/models/qwen3_vl.py | 16 ++++++++ python/sglang/srt/speculative/dflash_utils.py | 22 ++++++++++- .../sglang/srt/speculative/dflash_worker.py | 8 +++- 5 files changed, 79 insertions(+), 8 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index b06d9cd2f3e5..93dc727f4ee6 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -391,7 +391,7 @@ def __init__( trained_target_layers = dflash_draft_config.num_target_layers target_num_layers = getattr( - self.model_config.hf_config, "num_hidden_layers", None + self.model_config.hf_text_config, "num_hidden_layers", None ) if target_num_layers is None: raise ValueError( diff --git a/python/sglang/srt/models/qwen3_5.py b/python/sglang/srt/models/qwen3_5.py index b7c3f1ec0ea3..ce9b033fef48 100644 --- a/python/sglang/srt/models/qwen3_5.py +++ b/python/sglang/srt/models/qwen3_5.py @@ -364,8 +364,15 @@ def forward( ): forward_batch = kwargs.get("forward_batch", None) - hidden_states, residual = self.layer_communicator.prepare_attn( - hidden_states, residual, forward_batch + hidden_states, residual = ( + self.layer_communicator.prepare_attn_and_capture_last_layer_outputs( + hidden_states, + residual, + forward_batch, + captured_last_layer_outputs=kwargs.get( + "captured_last_layer_outputs", None + ), + ) ) if not forward_batch.forward_mode.is_idle(): @@ -609,10 +616,16 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], forward_batch: ForwardBatch, + captured_last_layer_outputs: Optional[list[torch.Tensor]] = None, **kwargs, ): - hidden_states, residual = self.layer_communicator.prepare_attn( - hidden_states, residual, forward_batch + hidden_states, residual = ( + self.layer_communicator.prepare_attn_and_capture_last_layer_outputs( + hidden_states, + residual, + forward_batch, + captured_last_layer_outputs=captured_last_layer_outputs, + ) ) if not forward_batch.forward_mode.is_idle(): @@ -684,6 +697,8 @@ def __init__( else: self.embed_tokens = PPMissingLayer() + self.layers_to_capture = [] + # Decoder layers def get_layer(idx: int, prefix: str): layer_type = config.layers_block_type[idx] @@ -725,6 +740,11 @@ def get_layer(idx: int, prefix: str): else: self.norm = PPMissingLayer() + def set_dflash_layers_to_capture(self, layers_to_capture: list[int]): + self.layers_to_capture = layers_to_capture + for layer_id in self.layers_to_capture: + setattr(self.layers[layer_id], "_is_layer_to_capture", True) + def get_input_embeddings(self): return self.embed_tokens @@ -758,6 +778,7 @@ def forward( hidden_states = pp_proxy_tensors["hidden_states"] residual = pp_proxy_tensors["residual"] + aux_hidden_states = [] # Pass through decoder layers for layer_idx in range(self.start_layer, self.end_layer): layer = self.layers[layer_idx] @@ -769,6 +790,11 @@ def forward( hidden_states=hidden_states, residual=residual, forward_batch=forward_batch, + captured_last_layer_outputs=( + aux_hidden_states + if getattr(layer, "_is_layer_to_capture", False) + else None + ), ) # Process deepstack embeddings if provided @@ -798,7 +824,10 @@ def forward( else: hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states + if len(aux_hidden_states) == 0: + return hidden_states + + return hidden_states, aux_hidden_states def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index 34a645078c71..96ce39945a1d 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -1074,6 +1074,7 @@ def __init__( self.logits_processor = LogitsProcessor(self.config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + self.capture_aux_hidden_states = False # like {8:0, 16:1, 24:2}, which stands for the captured deepstack features on # 8, 16, 24 layer will be merged to 0, 1, 2 layer of decoder output hidden_states @@ -1278,6 +1279,10 @@ def forward( pp_proxy_tensors=pp_proxy_tensors, ) + aux_hidden_states = None + if self.capture_aux_hidden_states: + hidden_states, aux_hidden_states = hidden_states + if self.pp_group.is_last_rank: if not get_embedding: return self.logits_processor( @@ -1285,12 +1290,23 @@ def forward( hidden_states, self.lm_head, forward_batch, + aux_hidden_states, ) else: return self.pooler(hidden_states, forward_batch) else: return hidden_states + def set_dflash_layers_to_capture(self, layer_ids: List[int]): + if not self.pp_group.is_last_rank: + return + if layer_ids is None: + raise ValueError( + "DFLASH requires explicit layer_ids for aux hidden capture." + ) + self.capture_aux_hidden_states = True + self.model.set_dflash_layers_to_capture([val + 1 for val in layer_ids]) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) diff --git a/python/sglang/srt/speculative/dflash_utils.py b/python/sglang/srt/speculative/dflash_utils.py index 5a88022a817e..ddec049e0a24 100644 --- a/python/sglang/srt/speculative/dflash_utils.py +++ b/python/sglang/srt/speculative/dflash_utils.py @@ -205,6 +205,25 @@ def _cfg_get(config: Any, key: str, default: Any = None) -> Any: return getattr(config, key, default) +def _get_text_config(config: Any) -> Any: + if config is None: + return None + if isinstance(config, dict): + return config.get("text_config", config) + text_config = getattr(config, "text_config", None) + if text_config is not None: + return text_config + get_text_config = getattr(config, "get_text_config", None) + if callable(get_text_config): + try: + resolved = get_text_config() + if resolved is not None: + return resolved + except TypeError: + pass + return config + + def _get_dflash_config(config: Any) -> dict: if isinstance(config, dict): cfg = config.get("dflash_config", None) @@ -294,9 +313,10 @@ def resolve_target_layer_ids( def parse_dflash_draft_config(*, draft_hf_config: Any) -> DFlashDraftConfig: """Parse and validate DFLASH draft config fields from HF config/dict.""" dflash_cfg = _get_dflash_config(draft_hf_config) + draft_text_config = _get_text_config(draft_hf_config) num_hidden_layers = _parse_optional_int( - _cfg_get(draft_hf_config, "num_hidden_layers", None), + _cfg_get(draft_text_config, "num_hidden_layers", None), field_name="DFLASH draft num_hidden_layers", min_value=1, ) diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 05a58a6aa581..8fadaba0d96d 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -15,7 +15,11 @@ ForwardBatch, ForwardMode, ) -from sglang.srt.server_args import ServerArgs +from sglang.srt.server_args import ( + ServerArgs, + get_global_server_args, + set_global_server_args_for_scheduler, +) from sglang.srt.speculative.dflash_info import DFlashDraftInput, DFlashVerifyInput from sglang.srt.speculative.dflash_utils import ( can_dflash_use_fused_qkv_proj, @@ -112,6 +116,7 @@ def __init__( draft_server_args.context_length = ( target_worker.model_runner.model_config.context_len ) + saved_server_args = get_global_server_args() self.draft_worker = TpModelWorker( server_args=draft_server_args, gpu_id=gpu_id, @@ -126,6 +131,7 @@ def __init__( req_to_token_pool=shared_req_to_token_pool, token_to_kv_pool_allocator=shared_token_to_kv_pool_allocator, ) + set_global_server_args_for_scheduler(saved_server_args) self.draft_model_runner = self.draft_worker.model_runner self.draft_model = self.draft_model_runner.model draft_config = parse_dflash_draft_config( From 277fe057ca65edf1fc61c1f9cd02daa16f7af3da Mon Sep 17 00:00:00 2001 From: David Wang Date: Thu, 12 Mar 2026 05:47:04 +0000 Subject: [PATCH 51/73] add draft swa --- benchmark/dflash/bench_dflash_gsm8k_sweep.py | 18 ++ python/sglang/srt/server_args.py | 29 ++++ python/sglang/srt/speculative/dflash_info.py | 13 +- .../sglang/srt/speculative/dflash_worker.py | 162 ++++++++++++++---- 4 files changed, 182 insertions(+), 40 deletions(-) diff --git a/benchmark/dflash/bench_dflash_gsm8k_sweep.py b/benchmark/dflash/bench_dflash_gsm8k_sweep.py index 4696da5c07bb..a790eb40decc 100644 --- a/benchmark/dflash/bench_dflash_gsm8k_sweep.py +++ b/benchmark/dflash/bench_dflash_gsm8k_sweep.py @@ -322,6 +322,14 @@ def _build_mode_runs( "DFLASH", "--speculative-draft-model-path", args.draft_model, + *( + [ + "--speculative-dflash-draft-window-size", + str(int(args.speculative_dflash_draft_window_size)), + ] + if args.speculative_dflash_draft_window_size is not None + else [] + ), *( [ "--speculative-draft-attention-backend", @@ -470,6 +478,10 @@ def _print_summary( "speculative_draft_attention_backend", args.speculative_draft_attention_backend, ), + ( + "speculative_dflash_draft_window_size", + args.speculative_dflash_draft_window_size, + ), ("tp_sizes", ",".join(str(x) for x in tp_sizes)), ("concurrencies", ",".join(str(x) for x in concurrencies)), ( @@ -598,6 +610,12 @@ def parse_args() -> argparse.Namespace: default=None, help="Optional server --speculative-draft-attention-backend override for DFLASH runs.", ) + parser.add_argument( + "--speculative-dflash-draft-window-size", + type=int, + default=None, + help="Optional server --speculative-dflash-draft-window-size override for DFLASH runs.", + ) return parser.parse_args() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 3489227a0512..fa5f85d5f600 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -485,6 +485,7 @@ class ServerArgs: speculative_eagle_topk: Optional[int] = None speculative_num_draft_tokens: Optional[int] = None speculative_dflash_block_size: Optional[int] = None + speculative_dflash_draft_window_size: Optional[int] = None speculative_accept_threshold_single: float = 1.0 speculative_accept_threshold_acc: float = 1.0 speculative_token_map: Optional[str] = None @@ -2738,6 +2739,16 @@ def _handle_speculative_decoding(self): self.speculative_dflash_block_size ) + window_size = None + if self.speculative_dflash_draft_window_size is not None: + window_size = int(self.speculative_dflash_draft_window_size) + if window_size <= 0: + raise ValueError( + "DFLASH requires --speculative-dflash-draft-window-size " + f"to be positive, got {window_size}." + ) + self.speculative_dflash_draft_window_size = window_size + if self.speculative_num_draft_tokens is None: from sglang.srt.speculative.dflash_utils import ( parse_dflash_draft_config, @@ -2772,6 +2783,15 @@ def _handle_speculative_decoding(self): ) self.speculative_num_draft_tokens = inferred_block_size + if window_size is not None: + draft_tokens = int(self.speculative_num_draft_tokens) + if window_size < draft_tokens: + raise ValueError( + "DFLASH --speculative-dflash-draft-window-size must be >= " + "--speculative-num-draft-tokens (block_size). " + f"window_size={window_size}, block_size={draft_tokens}." + ) + if self.max_running_requests is None: self.max_running_requests = 48 logger.warning( @@ -4480,6 +4500,15 @@ def add_cli_args(parser: argparse.ArgumentParser): help="DFLASH only. Block size (verify window length). Alias of --speculative-num-draft-tokens for DFLASH.", default=ServerArgs.speculative_dflash_block_size, ) + parser.add_argument( + "--speculative-dflash-draft-window-size", + type=int, + help="DFLASH only. Sliding window size for the draft-model KV cache. " + "When set, the draft worker keeps a recent target-token window in its " + "local cache (paged backends may retain up to one extra page on the left " + "for alignment). Default is full context.", + default=ServerArgs.speculative_dflash_draft_window_size, + ) parser.add_argument( "--speculative-accept-threshold-single", type=float, diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py index 5a77e8e39f5d..3b3e7434b6c1 100644 --- a/python/sglang/srt/speculative/dflash_info.py +++ b/python/sglang/srt/speculative/dflash_info.py @@ -58,10 +58,12 @@ class DFlashDraftInput(SpecInput): It is NOT sent to model attention backends; the DFlash worker uses it to run the draft model and to track draft-side cache progress. - Invariant (per request): - - `draft_seq_len + ctx_len == batch.seq_lens[i]` - where `ctx_len` is the number of target context-feature tokens carried in - `target_hidden` for that request. + When draft windowing is disabled, `draft_seq_lens` matches the committed target + prefix length already materialized in the draft KV cache. When windowing is + enabled, `draft_seq_lens` is the logical resident length in the draft worker's + compact req-to-token mapping. In paged mode this may exceed the requested + window by up to `page_size - 1` so the local page table remains valid. `ctx_lens` + tracks newly committed target tokens that still need draft KV materialization. """ # Current token to start the next DFlash block (one per request). @@ -76,8 +78,7 @@ class DFlashDraftInput(SpecInput): # Context lengths per request, used to slice `target_hidden`. Device tensor (int32). ctx_lens: torch.Tensor - # How many tokens are already in the draft KV cache per request. - # The next draft step appends ctx_lens[i] tokens starting at draft_seq_lens[i]. + # How many committed tokens are visible to the draft worker per request. draft_seq_lens: torch.Tensor def __post_init__(self): diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 8fadaba0d96d..22725a804c59 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -73,17 +73,29 @@ def __init__( self.target_worker = target_worker self.model_runner = target_worker.model_runner self.page_size = server_args.page_size + self.draft_window_size: Optional[int] = ( + int(server_args.speculative_dflash_draft_window_size) + if server_args.speculative_dflash_draft_window_size is not None + else None + ) + self.use_compact_draft_cache = self.draft_window_size is not None self.device = target_worker.device self._warned_sampling_fallback = False self._logged_first_verify = False # Draft runner (separate KV cache + attention backend). - # Share req_to_token_pool + token_to_kv_pool_allocator with the target worker (EAGLE3-style), - # while keeping a separate draft KV cache pool (the draft model has different KV values). - shared_req_to_token_pool, shared_token_to_kv_pool_allocator = ( + # Without draft windowing, the draft worker aliases the target request->token + # mapping and allocation state. With draft windowing enabled, the draft worker + # keeps a private compact req->token table over the same global KV index space, + # so radix-cache/prefix-hit KV remains reusable while draft attention sees only + # the recent window. + target_req_to_token_pool, target_token_to_kv_pool_allocator = ( target_worker.get_memory_pool() ) + shared_req_to_token_pool = ( + None if self.use_compact_draft_cache else target_req_to_token_pool + ) draft_server_args = deepcopy(server_args) draft_server_args.skip_tokenizer_init = True draft_backend = draft_server_args.speculative_draft_attention_backend @@ -94,8 +106,9 @@ def __init__( draft_backend = "flashinfer" elif draft_backend == "trtllm_mha": logger.warning( - "DFLASH draft worker does not support 'trtllm_mha' yet; " - "falling back to 'flashinfer'." + "DFLASH draft worker does not support 'trtllm_mha' because the " + "draft path requires non-causal attention. Falling back to " + "'flashinfer'." ) draft_backend = "flashinfer" elif draft_backend not in supported_draft_backends: @@ -106,7 +119,6 @@ def __init__( draft_backend, ) draft_backend = "flashinfer" - # Make the draft worker backend explicit and self-contained (no further overrides). draft_server_args.speculative_draft_attention_backend = None draft_server_args.prefill_attention_backend = None @@ -129,7 +141,7 @@ def __init__( nccl_port=nccl_port, is_draft_worker=True, req_to_token_pool=shared_req_to_token_pool, - token_to_kv_pool_allocator=shared_token_to_kv_pool_allocator, + token_to_kv_pool_allocator=target_token_to_kv_pool_allocator, ) set_global_server_args_for_scheduler(saved_server_args) self.draft_model_runner = self.draft_worker.model_runner @@ -162,10 +174,12 @@ def __init__( ) if self.tp_rank == 0: logger.info( - "Initialized DFLASH draft runner. attention_backend=%s, model=%s, block_size=%s", + "Initialized DFLASH draft runner. attention_backend=%s, model=%s, block_size=%s, draft_window_size=%s, compact_cache=%s", getattr(draft_server_args, "attention_backend", None), self.draft_model.__class__.__name__, self.block_size, + self.draft_window_size, + self.use_compact_draft_cache, ) logger.info( "DFLASH draft runner ready. mask_token=%s, mask_token_id=%s, mask_token_id_override=%s", @@ -318,9 +332,57 @@ def __getattr__(self, name): return getattr(self.target_worker, name) def clear_cache_pool(self): - # allocator and req_to_token_pool are shared with target worker + # The target worker owns the shared KV allocator/cache. For the compact + # sliding-window path, the draft req->token view is rebuilt from committed + # target state before each draft forward, so there is nothing persistent + # to flush here. pass + def _gather_req_to_token_segments( + self, + *, + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + start: torch.Tensor | None, + lengths: torch.Tensor, + ) -> torch.Tensor: + lengths = lengths.to(torch.int64) + if lengths.numel() == 0: + return torch.empty((0,), dtype=torch.int64, device=self.device) + max_len = int(lengths.max().item()) + if max_len <= 0: + return torch.empty((0,), dtype=torch.int64, device=self.device) + + if req_pool_indices.dtype != torch.int64: + req_pool_indices = req_pool_indices.to(torch.int64) + offsets = torch.arange( + max_len, device=self.device, dtype=torch.int64 + ).unsqueeze(0) + if start is None: + pos2d = offsets.expand(req_pool_indices.shape[0], -1) + else: + pos2d = start.to(torch.int64).unsqueeze(1) + offsets + mask = offsets < lengths.unsqueeze(1) + return req_to_token[req_pool_indices[:, None], pos2d][mask].to(torch.int64) + + def _compute_compact_draft_seq_lens(self, seq_lens: torch.Tensor) -> torch.Tensor: + assert self.draft_window_size is not None + visible_lens = torch.clamp( + seq_lens.to(dtype=torch.int32, device=self.device), + max=int(self.draft_window_size), + ) + if self.page_size <= 1: + return visible_lens + + # Paged FA backends derive the page table from local token positions, so the + # compact suffix must start on a page boundary. Keep up to page_size - 1 extra + # tokens on the left to preserve valid local page structure. + seq_lens_i64 = seq_lens.to(torch.int64) + visible_lens_i64 = visible_lens.to(torch.int64) + visible_start = seq_lens_i64 - visible_lens_i64 + aligned_start = visible_start - torch.remainder(visible_start, self.page_size) + return (seq_lens_i64 - aligned_start).to(torch.int32) + def _resolve_mask_token_id( self, *, mask_token: str, mask_token_id: Optional[int] = None ) -> int: @@ -458,23 +520,28 @@ def _prepare_for_speculative_decoding( noise_embedding = embed_module(block_ids) input_embeds = noise_embedding.view(-1, noise_embedding.shape[-1]) - # For spec-v1, the draft KV cache is always materialized to the current target - # prefix before drafting the next block. - prefix_lens = batch.seq_lens # int32, device + # For spec-v1, the draft KV cache is always materialized before drafting the + # next block. `target_prefix_lens` stay absolute for RoPE; `draft_prefix_lens` + # are the logical resident lengths in the draft-local cache. + target_prefix_lens = batch.seq_lens # int32, device + draft_prefix_lens = draft_input.draft_seq_lens + if draft_prefix_lens.dtype != torch.int32: + draft_prefix_lens = draft_prefix_lens.to(torch.int32) + if draft_prefix_lens.device != self.device: + draft_prefix_lens = draft_prefix_lens.to(self.device, non_blocking=True) positions_2d = self._draft_block_positions_buf[:bs] - torch.add(prefix_lens.unsqueeze(1), self._block_pos_offsets, out=positions_2d) + torch.add( + target_prefix_lens.unsqueeze(1), self._block_pos_offsets, out=positions_2d + ) positions = positions_2d.reshape(-1) - block_start = prefix_lens + block_start = draft_prefix_lens block_end = self._draft_block_end_buf[:bs] torch.add(block_start, int(self.block_size), out=block_end) seq_lens_cpu = self._draft_seq_lens_cpu_buf[:bs] - if batch.seq_lens_cpu.dtype == torch.int32: - seq_lens_cpu.copy_(batch.seq_lens_cpu) - else: - seq_lens_cpu.copy_(batch.seq_lens_cpu.to(torch.int32)) + seq_lens_cpu.copy_(draft_prefix_lens.to(device="cpu", dtype=torch.int32)) allocator = self.draft_model_runner.token_to_kv_pool_allocator token_to_kv_pool_state_backup = allocator.backup_state() try: @@ -513,8 +580,8 @@ def _prepare_for_speculative_decoding( # In this mode, `seq_lens` stores the prefix lengths; attention backends # derive kv_len by adding `draft_token_num`. draft_spec_info = self._draft_block_spec_info - seq_lens = prefix_lens - seq_lens_sum = int(batch.seq_lens_sum) + seq_lens = draft_prefix_lens + seq_lens_sum = int(draft_prefix_lens.sum().item()) forward_batch = ForwardBatch( forward_mode=ForwardMode.TARGET_VERIFY, batch_size=bs, @@ -777,24 +844,23 @@ def _append_target_hidden_to_draft_kv( total_ctx = int(draft_input.target_hidden.shape[0]) if total_ctx <= 0: + draft_input.ctx_lens = torch.zeros_like(draft_input.ctx_lens) + draft_input.target_hidden = draft_input.target_hidden[:0] return - req_to_token = self.draft_model_runner.req_to_token_pool.req_to_token + target_req_to_token = batch.req_to_token_pool.req_to_token + draft_req_to_token = self.draft_model_runner.req_to_token_pool.req_to_token req_pool_indices = batch.req_pool_indices if req_pool_indices.dtype != torch.int64: req_pool_indices = req_pool_indices.to(torch.int64) ctx_lens = draft_input.ctx_lens - draft_seq_lens = draft_input.draft_seq_lens if ctx_lens.dtype != torch.int32: ctx_lens = ctx_lens.to(torch.int32) - if draft_seq_lens.dtype != torch.int32: - draft_seq_lens = draft_seq_lens.to(torch.int32) if ctx_lens.device != device: ctx_lens = ctx_lens.to(device, non_blocking=True) - if draft_seq_lens.device != device: - draft_seq_lens = draft_seq_lens.to(device, non_blocking=True) + ctx_start = batch.seq_lens.to(torch.int64) - ctx_lens.to(torch.int64) if bs == 1: # Fast path for single request. @@ -803,8 +869,8 @@ def _append_target_hidden_to_draft_kv( r = self._block_pos_offsets[:max_ctx] else: r = torch.arange(max_ctx, device=device, dtype=torch.int64) - pos2d = draft_seq_lens.to(torch.int64)[:, None] + r[None, :] # [1, ctx] - cache2d = req_to_token[req_pool_indices[:, None], pos2d] # [1, ctx] + pos2d = ctx_start[:, None] + r[None, :] # [1, ctx] + cache2d = target_req_to_token[req_pool_indices[:, None], pos2d] # [1, ctx] ctx_cache_loc = cache2d.reshape(-1).to(torch.int64) # [ctx] ctx_positions = pos2d.reshape(-1) # [ctx] else: @@ -821,11 +887,13 @@ def _append_target_hidden_to_draft_kv( else: r = torch.arange(max_ctx, device=device, dtype=torch.int64) r = r[None, :] # [1, max_ctx] - pos2d = draft_seq_lens.to(torch.int64)[:, None] + r # [bs, max_ctx] + pos2d = ctx_start[:, None] + r # [bs, max_ctx] mask = r < ctx_lens[:, None] # Batched gather of cache locations and positions. - cache2d = req_to_token[req_pool_indices[:, None], pos2d] # [bs, max_ctx] + cache2d = target_req_to_token[ + req_pool_indices[:, None], pos2d + ] # [bs, max_ctx] ctx_cache_loc = cache2d[mask].to(torch.int64) # [sum(ctx_lens)] ctx_positions = pos2d[mask] # [sum(ctx_lens)] @@ -858,7 +926,28 @@ def _append_target_hidden_to_draft_kv( ctx_hidden, ctx_positions, ctx_cache_loc ) - draft_input.draft_seq_lens = draft_seq_lens + ctx_lens + if self.use_compact_draft_cache: + new_draft_seq_lens = self._compute_compact_draft_seq_lens(batch.seq_lens) + suffix_start = batch.seq_lens.to(torch.int64) - new_draft_seq_lens.to( + torch.int64 + ) + suffix_cache_loc = self._gather_req_to_token_segments( + req_to_token=target_req_to_token, + req_pool_indices=req_pool_indices, + start=suffix_start, + lengths=new_draft_seq_lens, + ) + assign_req_to_token_pool_func( + batch.req_pool_indices, + draft_req_to_token, + torch.zeros_like(new_draft_seq_lens), + new_draft_seq_lens, + suffix_cache_loc, + bs, + ) + draft_input.draft_seq_lens = new_draft_seq_lens + else: + draft_input.draft_seq_lens = batch.seq_lens.to(dtype=torch.int32) draft_input.ctx_lens = torch.zeros_like(ctx_lens) draft_input.target_hidden = draft_input.target_hidden[:0] @@ -1009,12 +1098,17 @@ def _to_int32_device_tensor(x, *, device=device): return x if x.dtype == torch.int32 else x.to(torch.int32) return torch.tensor(x, dtype=torch.int32, device=device) + extend_seq_lens = _to_int32_device_tensor( + model_worker_batch.extend_seq_lens + ) draft_input = DFlashDraftInput( verified_id=next_token_ids.to(torch.int64), target_hidden=logits_output.hidden_states, - ctx_lens=_to_int32_device_tensor(model_worker_batch.extend_seq_lens), - draft_seq_lens=_to_int32_device_tensor( - model_worker_batch.extend_prefix_lens + ctx_lens=extend_seq_lens, + draft_seq_lens=( + torch.zeros_like(extend_seq_lens) + if self.use_compact_draft_cache + else _to_int32_device_tensor(model_worker_batch.extend_prefix_lens) ), ) self._append_target_hidden_to_draft_kv(batch, draft_input) From e5ef869c45f6d5983b34427562ce62e299570036 Mon Sep 17 00:00:00 2001 From: David Wang Date: Mon, 16 Mar 2026 20:52:39 +0000 Subject: [PATCH 52/73] avoid OOB in masked req_to_token gathers --- .../sglang/srt/speculative/dflash_worker.py | 56 +++++++++++++++++-- 1 file changed, 51 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 22725a804c59..353557e1f12a 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -338,6 +338,43 @@ def clear_cache_pool(self): # to flush here. pass + def _gather_req_to_token_masked( + self, + *, + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + pos2d: torch.Tensor, + mask: torch.Tensor, + context: str, + ) -> torch.Tensor: + if pos2d.ndim != 2: + raise RuntimeError( + f"{context} expected 2D positions, got shape={tuple(pos2d.shape)}." + ) + if mask.shape != pos2d.shape: + raise RuntimeError( + f"{context} mask/position shape mismatch: {tuple(mask.shape)} vs {tuple(pos2d.shape)}." + ) + + if req_pool_indices.dtype != torch.int64: + req_pool_indices = req_pool_indices.to(torch.int64) + if mask.dtype != torch.bool: + mask = mask.to(torch.bool) + + table_width = int(req_to_token.shape[1]) + if table_width <= 0: + if bool(mask.any().item()): + raise RuntimeError( + f"{context} req_to_token table is empty but gather mask is non-empty." + ) + return torch.empty((0,), dtype=torch.int64, device=self.device) + + # Only the masked-off rectangular padding can be out of range in the normal + # ragged-batch case. Replace those don't-care columns with a valid in-range + # position before the gather so the kernel only sees real positions. + safe_pos2d = pos2d.masked_fill(~mask, 0) + return req_to_token[req_pool_indices[:, None], safe_pos2d][mask].to(torch.int64) + def _gather_req_to_token_segments( self, *, @@ -363,7 +400,13 @@ def _gather_req_to_token_segments( else: pos2d = start.to(torch.int64).unsqueeze(1) + offsets mask = offsets < lengths.unsqueeze(1) - return req_to_token[req_pool_indices[:, None], pos2d][mask].to(torch.int64) + return self._gather_req_to_token_masked( + req_to_token=req_to_token, + req_pool_indices=req_pool_indices, + pos2d=pos2d, + mask=mask, + context="DFLASH req_to_token segment gather", + ) def _compute_compact_draft_seq_lens(self, seq_lens: torch.Tensor) -> torch.Tensor: assert self.draft_window_size is not None @@ -891,10 +934,13 @@ def _append_target_hidden_to_draft_kv( mask = r < ctx_lens[:, None] # Batched gather of cache locations and positions. - cache2d = target_req_to_token[ - req_pool_indices[:, None], pos2d - ] # [bs, max_ctx] - ctx_cache_loc = cache2d[mask].to(torch.int64) # [sum(ctx_lens)] + ctx_cache_loc = self._gather_req_to_token_masked( + req_to_token=target_req_to_token, + req_pool_indices=req_pool_indices, + pos2d=pos2d, + mask=mask, + context="DFLASH target hidden KV append", + ) # [sum(ctx_lens)] ctx_positions = pos2d[mask] # [sum(ctx_lens)] with torch.inference_mode(): From 8063b658dbafd609c37abab23e87a303ee7cc3e0 Mon Sep 17 00:00:00 2001 From: David Wang Date: Tue, 24 Mar 2026 16:56:26 +0000 Subject: [PATCH 53/73] clean up model support --- python/sglang/srt/models/gpt_oss.py | 15 ---------- python/sglang/srt/models/qwen3.py | 14 --------- python/sglang/srt/models/qwen3_5.py | 39 ++++---------------------- python/sglang/srt/models/qwen3_moe.py | 17 ----------- python/sglang/srt/models/qwen3_next.py | 20 ------------- python/sglang/srt/models/qwen3_vl.py | 16 ----------- 6 files changed, 5 insertions(+), 116 deletions(-) diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index 228555ea164d..96caaa65b57c 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -1104,9 +1104,6 @@ def _load_normal_weights( def get_embed_and_head(self): return self.model.embed_tokens.weight, self.lm_head.weight - def get_input_embeddings(self) -> nn.Embedding: - return self.model.embed_tokens - def set_embed_and_head(self, embed, head): del self.model.embed_tokens.weight del self.lm_head.weight @@ -1129,18 +1126,6 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None): # of the (i-1)th layer as aux hidden state self.model.layers_to_capture = [val + 1 for val in layer_ids] - def set_dflash_layers_to_capture(self, layer_ids: List[int]): - if not self.pp_group.is_last_rank: - return - - if layer_ids is None: - raise ValueError( - "DFLASH requires explicit layer_ids for aux hidden capture." - ) - - self.capture_aux_hidden_states = True - self.model.layers_to_capture = [val + 1 for val in layer_ids] - @classmethod def get_model_config_for_expert_location(cls, config): return ModelConfigForExpertLocation( diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index e9e9b61b3e2b..b6b955ae6c7c 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -586,19 +586,5 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None): else: self.model.layers_to_capture = [val + 1 for val in layer_ids] - def set_dflash_layers_to_capture(self, layer_ids: List[int]): - if not self.pp_group.is_last_rank: - return - - if layer_ids is None: - raise ValueError( - "DFLASH requires explicit layer_ids for aux hidden capture." - ) - - self.capture_aux_hidden_states = True - # SGLang captures "before layer i". To capture the hidden state after target - # layer `k` (HF-style), we capture before layer `k + 1`. - self.model.layers_to_capture = [val + 1 for val in layer_ids] - EntryClass = Qwen3ForCausalLM diff --git a/python/sglang/srt/models/qwen3_5.py b/python/sglang/srt/models/qwen3_5.py index ce9b033fef48..b7c3f1ec0ea3 100644 --- a/python/sglang/srt/models/qwen3_5.py +++ b/python/sglang/srt/models/qwen3_5.py @@ -364,15 +364,8 @@ def forward( ): forward_batch = kwargs.get("forward_batch", None) - hidden_states, residual = ( - self.layer_communicator.prepare_attn_and_capture_last_layer_outputs( - hidden_states, - residual, - forward_batch, - captured_last_layer_outputs=kwargs.get( - "captured_last_layer_outputs", None - ), - ) + hidden_states, residual = self.layer_communicator.prepare_attn( + hidden_states, residual, forward_batch ) if not forward_batch.forward_mode.is_idle(): @@ -616,16 +609,10 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], forward_batch: ForwardBatch, - captured_last_layer_outputs: Optional[list[torch.Tensor]] = None, **kwargs, ): - hidden_states, residual = ( - self.layer_communicator.prepare_attn_and_capture_last_layer_outputs( - hidden_states, - residual, - forward_batch, - captured_last_layer_outputs=captured_last_layer_outputs, - ) + hidden_states, residual = self.layer_communicator.prepare_attn( + hidden_states, residual, forward_batch ) if not forward_batch.forward_mode.is_idle(): @@ -697,8 +684,6 @@ def __init__( else: self.embed_tokens = PPMissingLayer() - self.layers_to_capture = [] - # Decoder layers def get_layer(idx: int, prefix: str): layer_type = config.layers_block_type[idx] @@ -740,11 +725,6 @@ def get_layer(idx: int, prefix: str): else: self.norm = PPMissingLayer() - def set_dflash_layers_to_capture(self, layers_to_capture: list[int]): - self.layers_to_capture = layers_to_capture - for layer_id in self.layers_to_capture: - setattr(self.layers[layer_id], "_is_layer_to_capture", True) - def get_input_embeddings(self): return self.embed_tokens @@ -778,7 +758,6 @@ def forward( hidden_states = pp_proxy_tensors["hidden_states"] residual = pp_proxy_tensors["residual"] - aux_hidden_states = [] # Pass through decoder layers for layer_idx in range(self.start_layer, self.end_layer): layer = self.layers[layer_idx] @@ -790,11 +769,6 @@ def forward( hidden_states=hidden_states, residual=residual, forward_batch=forward_batch, - captured_last_layer_outputs=( - aux_hidden_states - if getattr(layer, "_is_layer_to_capture", False) - else None - ), ) # Process deepstack embeddings if provided @@ -824,10 +798,7 @@ def forward( else: hidden_states, _ = self.norm(hidden_states, residual) - if len(aux_hidden_states) == 0: - return hidden_states - - return hidden_states, aux_hidden_states + return hidden_states def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 0e39d74a497c..845502fe2e88 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -889,11 +889,6 @@ def __init__( alt_stream=alt_stream, ) - def set_dflash_layers_to_capture(self, layers_to_capture: List[int]): - self.layers_to_capture = layers_to_capture - for layer_id in self.layers_to_capture: - setattr(self.layers[layer_id], "_is_layer_to_capture", True) - class Qwen3MoeForCausalLM(nn.Module): fall_back_to_pt_during_load = False @@ -1031,18 +1026,6 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None): else: self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids]) - def set_dflash_layers_to_capture(self, layer_ids: List[int]): - if not self.pp_group.is_last_rank: - return - - if layer_ids is None: - raise ValueError( - "DFLASH requires explicit layer_ids for aux hidden capture." - ) - - self.capture_aux_hidden_states = True - self.model.set_dflash_layers_to_capture([val + 1 for val in layer_ids]) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index 34445121d555..42146e6057ea 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -864,11 +864,6 @@ def set_eagle3_layers_to_capture(self, layers_to_capture: list[int]): for layer_id in self.layers_to_capture: setattr(self.layers[layer_id], "_is_layer_to_capture", True) - def set_dflash_layers_to_capture(self, layers_to_capture: list[int]): - self.layers_to_capture = layers_to_capture - for layer_id in self.layers_to_capture: - setattr(self.layers[layer_id], "_is_layer_to_capture", True) - def forward( self, input_ids: torch.Tensor, @@ -1003,9 +998,6 @@ def forward( def get_embed_and_head(self): return self.model.embed_tokens.weight, self.lm_head.weight - def get_input_embeddings(self) -> nn.Embedding: - return self.model.embed_tokens - def set_embed_and_head(self, embed, head): del self.model.embed_tokens.weight del self.lm_head.weight @@ -1179,17 +1171,5 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[list[int]] = None): else: self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids]) - def set_dflash_layers_to_capture(self, layer_ids: list[int]): - if not self.pp_group.is_last_rank: - return - - if layer_ids is None: - raise ValueError( - "DFLASH requires explicit layer_ids for aux hidden capture." - ) - - self.capture_aux_hidden_states = True - self.model.set_dflash_layers_to_capture([val + 1 for val in layer_ids]) - EntryClass = Qwen3NextForCausalLM diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index 96ce39945a1d..34a645078c71 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -1074,7 +1074,6 @@ def __init__( self.logits_processor = LogitsProcessor(self.config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) - self.capture_aux_hidden_states = False # like {8:0, 16:1, 24:2}, which stands for the captured deepstack features on # 8, 16, 24 layer will be merged to 0, 1, 2 layer of decoder output hidden_states @@ -1279,10 +1278,6 @@ def forward( pp_proxy_tensors=pp_proxy_tensors, ) - aux_hidden_states = None - if self.capture_aux_hidden_states: - hidden_states, aux_hidden_states = hidden_states - if self.pp_group.is_last_rank: if not get_embedding: return self.logits_processor( @@ -1290,23 +1285,12 @@ def forward( hidden_states, self.lm_head, forward_batch, - aux_hidden_states, ) else: return self.pooler(hidden_states, forward_batch) else: return hidden_states - def set_dflash_layers_to_capture(self, layer_ids: List[int]): - if not self.pp_group.is_last_rank: - return - if layer_ids is None: - raise ValueError( - "DFLASH requires explicit layer_ids for aux hidden capture." - ) - self.capture_aux_hidden_states = True - self.model.set_dflash_layers_to_capture([val + 1 for val in layer_ids]) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) From fd3aa66f81a48adf220d2f771184f2f6ab707533 Mon Sep 17 00:00:00 2001 From: David Wang Date: Tue, 24 Mar 2026 16:58:43 +0000 Subject: [PATCH 54/73] clean up benchmarking script --- benchmark/dflash/bench_dflash_gsm8k_sweep.py | 749 ------------------- 1 file changed, 749 deletions(-) delete mode 100644 benchmark/dflash/bench_dflash_gsm8k_sweep.py diff --git a/benchmark/dflash/bench_dflash_gsm8k_sweep.py b/benchmark/dflash/bench_dflash_gsm8k_sweep.py deleted file mode 100644 index a790eb40decc..000000000000 --- a/benchmark/dflash/bench_dflash_gsm8k_sweep.py +++ /dev/null @@ -1,749 +0,0 @@ -"""DFLASH vs baseline GSM8K sweep. - -This is a *benchmark script* (not a CI test): it can take a long time because it -launches servers for multiple (attention_backend, tp_size) configs and runs a -GSM8K workload for each (concurrency, num_questions) setting. - -Example usage: - ./venv/bin/python benchmark/dflash/bench_dflash_gsm8k_sweep.py - ./venv/bin/python benchmark/dflash/bench_dflash_gsm8k_sweep.py --skip-baseline --concurrencies 32 --tp-sizes 8 -""" - -from __future__ import annotations - -import argparse -import ast -import os -import re -import statistics -import time -from concurrent.futures import ThreadPoolExecutor, as_completed -from dataclasses import dataclass -from typing import Optional - -import requests -import torch -from transformers import AutoTokenizer - -from sglang.srt.utils import get_device_sm, kill_process_tree -from sglang.test.test_utils import ( - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - find_available_port, - popen_launch_server, -) -from sglang.utils import download_and_cache_file, read_jsonl - -INVALID = -9999999 - - -def _parse_int_csv(value: str) -> list[int]: - return [int(x) for x in value.split(",") if x.strip()] - - -def _filter_attention_backends(backends: list[str], *, device_sm: int) -> list[str]: - if not (80 <= device_sm <= 90): - backends = [b for b in backends if b != "fa3"] - if device_sm < 100: - backends = [b for b in backends if b not in ("fa4", "trtllm_mha")] - return backends or ["flashinfer"] - - -def _get_answer_value(answer_str: str) -> int: - answer_str = answer_str.replace(",", "") - numbers = re.findall(r"\d+", answer_str) - if len(numbers) < 1: - return INVALID - try: - return ast.literal_eval(numbers[-1]) - except SyntaxError: - return INVALID - - -def _maybe_download_gsm8k(data_path: str) -> str: - url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" - if os.path.isfile(data_path): - return data_path - return download_and_cache_file(url) - - -def _flush_cache(base_url: str) -> None: - resp = requests.get(base_url + "/flush_cache", timeout=60) - resp.raise_for_status() - - -def _send_generate( - base_url: str, - text: str | list[str], - *, - max_new_tokens: int, - temperature: float, - top_p: float, - top_k: int, - timeout_s: int, -) -> list[dict]: - if isinstance(text, list) and not text: - return [] - sampling_params: dict = { - "temperature": float(temperature), - "top_p": float(top_p), - "top_k": int(top_k), - "max_new_tokens": int(max_new_tokens), - } - resp = requests.post( - base_url + "/generate", - json={ - "text": text, - "sampling_params": sampling_params, - }, - timeout=int(timeout_s), - ) - resp.raise_for_status() - out = resp.json() - if isinstance(text, list): - if not isinstance(out, list): - raise RuntimeError( - "Expected a list response for batched /generate, but got " - f"type={type(out).__name__}." - ) - if len(out) != len(text): - raise RuntimeError( - "Batched /generate output length mismatch: " - f"got {len(out)} outputs for {len(text)} prompts." - ) - return out - - if isinstance(out, list): - raise RuntimeError( - "Expected an object response for single /generate, but got " - f"type={type(out).__name__}." - ) - return [out] - - -@dataclass(frozen=True) -class BenchMetrics: - latency_s: float - output_tokens: int - output_toks_per_s: float - accuracy: Optional[float] - invalid_rate: Optional[float] - spec_accept_length: Optional[float] - spec_verify_ct_sum: int - - -def _run_gsm8k_requests( - base_url: str, - *, - prompts: list[str], - labels: Optional[list[int]], - max_new_tokens: int, - temperature: float, - top_p: float, - top_k: int, - concurrency: int, - batch_requests: bool, - timeout_s: int, - expect_dflash: bool, -) -> BenchMetrics: - if labels is not None and len(labels) != len(prompts): - raise ValueError("labels length must match prompts length") - - start = time.perf_counter() - total_tokens = 0 - spec_verify_ct_sum = 0 - spec_accept_lengths: list[float] = [] - correct = 0 - invalid = 0 - - def _handle_output(out: dict, label: Optional[int]) -> None: - nonlocal total_tokens, spec_verify_ct_sum, correct, invalid - meta = out.get("meta_info", {}) or {} - total_tokens += int(meta.get("completion_tokens", 0)) - spec_verify_ct_sum += int(meta.get("spec_verify_ct", 0)) - if "spec_accept_length" in meta: - try: - spec_accept_lengths.append(float(meta["spec_accept_length"])) - except (TypeError, ValueError): - pass - - if label is not None: - pred = _get_answer_value(out.get("text", "")) - if pred == INVALID: - invalid += 1 - if pred == label: - correct += 1 - - if batch_requests: - bs = max(int(concurrency), 1) - for start_idx in range(0, len(prompts), bs): - chunk_prompts = prompts[start_idx : start_idx + bs] - chunk_labels = ( - labels[start_idx : start_idx + bs] if labels is not None else None - ) - outs = _send_generate( - base_url, - chunk_prompts, - max_new_tokens=max_new_tokens, - temperature=temperature, - top_p=top_p, - top_k=top_k, - timeout_s=timeout_s, - ) - if chunk_labels is None: - for out in outs: - _handle_output(out, None) - else: - for out, label in zip(outs, chunk_labels): - _handle_output(out, label) - else: - with ThreadPoolExecutor(max_workers=int(concurrency)) as pool: - futures = { - pool.submit( - _send_generate, - base_url=base_url, - text=prompt, - max_new_tokens=max_new_tokens, - temperature=temperature, - top_p=top_p, - top_k=top_k, - timeout_s=timeout_s, - ): i - for i, prompt in enumerate(prompts) - } - for fut in as_completed(futures): - i = futures[fut] - outs = fut.result() - if len(outs) != 1: - raise RuntimeError( - "Expected exactly one output for single /generate request." - ) - label = None if labels is None else labels[i] - _handle_output(outs[0], label) - - latency = time.perf_counter() - start - toks_per_s = total_tokens / max(latency, 1e-6) - - if expect_dflash and spec_verify_ct_sum <= 0: - raise RuntimeError( - "DFLASH sanity check failed: did not observe any `spec_verify_ct` in responses " - "(DFLASH may not have been enabled)." - ) - - spec_accept_length = ( - float(statistics.mean(spec_accept_lengths)) if spec_accept_lengths else None - ) - - if labels is None: - acc = None - invalid_rate = None - else: - acc = correct / max(len(prompts), 1) - invalid_rate = invalid / max(len(prompts), 1) - - return BenchMetrics( - latency_s=float(latency), - output_tokens=int(total_tokens), - output_toks_per_s=float(toks_per_s), - accuracy=acc, - invalid_rate=invalid_rate, - spec_accept_length=spec_accept_length, - spec_verify_ct_sum=int(spec_verify_ct_sum), - ) - - -def _format_table( - *, - tp_sizes: list[int], - concurrencies: list[int], - values: dict[tuple[int, int], Optional[float]], - float_fmt: str, -) -> str: - header = ["tp\\conc"] + [str(c) for c in concurrencies] - rows: list[list[str]] = [header] - for tp in tp_sizes: - row = [str(tp)] - for c in concurrencies: - v = values.get((tp, c), None) - row.append("N/A" if v is None else format(v, float_fmt)) - rows.append(row) - - col_widths = [ - max(len(row[col_idx]) for row in rows) for col_idx in range(len(rows[0])) - ] - - lines: list[str] = [] - lines.append(" ".join(cell.rjust(col_widths[i]) for i, cell in enumerate(rows[0]))) - lines.append(" ".join("-" * w for w in col_widths)) - for row in rows[1:]: - lines.append(" ".join(cell.rjust(col_widths[i]) for i, cell in enumerate(row))) - return "\n".join(lines) - - -def _build_common_server_args( - args: argparse.Namespace, *, backend: str, tp: int -) -> list[str]: - common_server_args: list[str] = [ - "--trust-remote-code", - "--attention-backend", - backend, - "--tp-size", - str(tp), - "--dtype", - str(args.dtype), - "--max-running-requests", - str(args.max_running_requests), - "--cuda-graph-max-bs", - "32", - ] - if args.mem_fraction_static is not None: - common_server_args.extend( - ["--mem-fraction-static", str(args.mem_fraction_static)] - ) - if args.disable_radix_cache: - common_server_args.append("--disable-radix-cache") - if args.page_size is not None: - common_server_args.extend(["--page-size", str(int(args.page_size))]) - return common_server_args - - -def _build_mode_runs( - args: argparse.Namespace, common_server_args: list[str] -) -> list[tuple[str, str, list[str], bool]]: - mode_runs: list[tuple[str, str, list[str], bool]] = [] - if not args.skip_baseline: - mode_runs.append(("baseline", "baseline", common_server_args, False)) - mode_runs.append( - ( - "dflash", - "DFLASH", - [ - *common_server_args, - "--speculative-algorithm", - "DFLASH", - "--speculative-draft-model-path", - args.draft_model, - *( - [ - "--speculative-dflash-draft-window-size", - str(int(args.speculative_dflash_draft_window_size)), - ] - if args.speculative_dflash_draft_window_size is not None - else [] - ), - *( - [ - "--speculative-draft-attention-backend", - args.speculative_draft_attention_backend, - ] - if args.speculative_draft_attention_backend - else [] - ), - ], - True, - ) - ) - return mode_runs - - -def _collect_metric( - *, - results: dict[tuple[str, int, int, str], BenchMetrics], - backend: str, - tp_sizes: list[int], - concurrencies: list[int], - mode: str, - field: str, -) -> dict[tuple[int, int], Optional[float]]: - out: dict[tuple[int, int], Optional[float]] = {} - for tp in tp_sizes: - for conc in concurrencies: - metrics = results.get((backend, tp, conc, mode), None) - out[(tp, conc)] = None if metrics is None else getattr(metrics, field) - return out - - -def _compute_speedup( - baseline: dict[tuple[int, int], Optional[float]], - dflash: dict[tuple[int, int], Optional[float]], -) -> dict[tuple[int, int], Optional[float]]: - return { - key: None if (b is None or d is None or b <= 0) else (d / b) - for key, b in baseline.items() - for d in [dflash.get(key, None)] - } - - -def _print_kv_lines(items: list[tuple[str, object]]) -> None: - for key, value in items: - print(f"{key}={value}") - - -def _run_mode_for_backend_tp( - *, - mode_label: str, - model_path: str, - base_url: str, - server_args: list[str], - expect_dflash: bool, - prompts: list[str], - labels: list[int], - concurrencies: list[int], - num_questions_by_conc: dict[int, int], - args: argparse.Namespace, -) -> dict[int, BenchMetrics]: - print(f"\n=== {mode_label} ===") - server_start_timeout_s = int(max(DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, args.timeout_s)) - proc = popen_launch_server( - model_path, - base_url, - timeout=server_start_timeout_s, - other_args=server_args, - ) - try: - _send_generate( - base_url, - "Hello", - max_new_tokens=8, - temperature=float(args.temperature), - top_p=float(args.top_p), - top_k=int(args.top_k), - timeout_s=min(int(args.timeout_s), 300), - ) - - metrics_by_conc: dict[int, BenchMetrics] = {} - for conc in concurrencies: - n = num_questions_by_conc[conc] - _flush_cache(base_url) - metrics = _run_gsm8k_requests( - base_url, - prompts=prompts[:n], - labels=labels[:n], - max_new_tokens=int(args.max_new_tokens), - temperature=float(args.temperature), - top_p=float(args.top_p), - top_k=int(args.top_k), - concurrency=int(conc), - batch_requests=bool(args.batch_requests), - timeout_s=int(args.timeout_s), - expect_dflash=expect_dflash, - ) - metrics_by_conc[conc] = metrics - line = ( - f"[{mode_label}] conc={conc:>2} n={n:<4} " - f"toks/s={metrics.output_toks_per_s:,.2f} " - f"latency={metrics.latency_s:.1f}s " - f"acc={metrics.accuracy:.3f} invalid={metrics.invalid_rate:.3f}" - ) - if expect_dflash: - accept_len = ( - "N/A" - if metrics.spec_accept_length is None - else f"{metrics.spec_accept_length:.3f}" - ) - line += ( - f" accept_len={accept_len} " - f"spec_verify_ct_sum={metrics.spec_verify_ct_sum}" - ) - print(line) - return metrics_by_conc - finally: - kill_process_tree(proc.pid) - try: - proc.wait(timeout=30) - except Exception: - pass - - -def _print_summary( - *, - args: argparse.Namespace, - attention_backends: list[str], - tp_sizes: list[int], - concurrencies: list[int], - device_sm: int, - results: dict[tuple[str, int, int, str], BenchMetrics], -) -> None: - print("\n=== DFLASH GSM8K Sweep Summary ===") - _print_kv_lines( - [ - ("target_model", args.target_model), - ("draft_model", args.draft_model), - ("max_new_tokens", args.max_new_tokens), - ( - "sampling", - f"temperature:{args.temperature}, top_p:{args.top_p}, top_k:{args.top_k}", - ), - ("attention_backends", ",".join(attention_backends)), - ( - "speculative_draft_attention_backend", - args.speculative_draft_attention_backend, - ), - ( - "speculative_dflash_draft_window_size", - args.speculative_dflash_draft_window_size, - ), - ("tp_sizes", ",".join(str(x) for x in tp_sizes)), - ("concurrencies", ",".join(str(x) for x in concurrencies)), - ( - "questions_per_concurrency_base", - args.questions_per_concurrency_base, - ), - ("device_sm", device_sm), - ("skip_baseline", bool(args.skip_baseline)), - ] - ) - - section_fields = [ - ("Baseline output tok/s", "baseline", "output_toks_per_s", ",.2f"), - ("Baseline accuracy", "baseline", "accuracy", ".3f"), - ("DFLASH output tok/s", "dflash", "output_toks_per_s", ",.2f"), - ("DFLASH accuracy", "dflash", "accuracy", ".3f"), - ( - "DFLASH acceptance length (mean spec_accept_length)", - "dflash", - "spec_accept_length", - ".3f", - ), - ] - - for backend in attention_backends: - print(f"\n=== Backend: {backend} ===") - metrics_map = { - (mode, field): _collect_metric( - results=results, - backend=backend, - tp_sizes=tp_sizes, - concurrencies=concurrencies, - mode=mode, - field=field, - ) - for _, mode, field, _ in section_fields - } - sections: list[tuple[str, dict[tuple[int, int], Optional[float]], str]] = [ - (title, metrics_map[(mode, field)], fmt) - for title, mode, field, fmt in section_fields - ] - sections.insert( - 4, - ( - "Speedup (DFLASH / baseline)", - _compute_speedup( - metrics_map[("baseline", "output_toks_per_s")], - metrics_map[("dflash", "output_toks_per_s")], - ), - ".3f", - ), - ) - - for title, values, fmt in sections: - print(f"\n{title}") - print( - _format_table( - tp_sizes=tp_sizes, - concurrencies=concurrencies, - values=values, - float_fmt=fmt, - ) - ) - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser() - parser.add_argument("--data-path", default="test.jsonl") - parser.add_argument("--target-model", default="Qwen/Qwen3-8B") - parser.add_argument("--draft-model", default="z-lab/Qwen3-8B-DFlash-b16") - parser.add_argument( - "--skip-baseline", - action="store_true", - help="Skip running the baseline (target-only) sweep; only run DFLASH and report N/A for baseline/speedup.", - ) - parser.add_argument( - "--batch-requests", - action="store_true", - help="Send prompts as server-side batched /generate requests (batch size = concurrency) instead of client-side concurrent requests.", - ) - parser.add_argument("--max-new-tokens", type=int, default=2048) - parser.add_argument("--temperature", type=float, default=0.0) - parser.add_argument("--top-p", type=float, default=1.0) - parser.add_argument("--top-k", type=int, default=1) - parser.add_argument( - "--timeout-s", - type=int, - default=3600, - help=( - "Timeout in seconds for benchmarked /generate calls and server startup " - "health checks." - ), - ) - parser.add_argument( - "--mem-fraction-static", - type=float, - default=None, - help="Optional server --mem-fraction-static override. If unset, use the server auto heuristic.", - ) - parser.add_argument("--disable-radix-cache", action="store_true") - parser.add_argument("--dtype", default="bfloat16") - parser.add_argument( - "--page-size", - type=int, - default=None, - help="Optional server --page-size override for both baseline and DFLASH runs.", - ) - parser.add_argument("--max-running-requests", type=int, default=32) - parser.add_argument("--tp-sizes", default="1,2,4,8") - parser.add_argument("--concurrencies", default="1,2,4,8,16,32") - parser.add_argument( - "--questions-per-concurrency-base", - type=int, - default=128, - help="num_questions = base * concurrency (default matches the sweep plan).", - ) - parser.add_argument( - "--max-questions-per-config", - type=int, - default=1024, - help="Cap num_questions per (tp, concurrency) run (default: 1024).", - ) - parser.add_argument("--attention-backends", default="flashinfer,fa3,trtllm_mha,fa4") - parser.add_argument( - "--speculative-draft-attention-backend", - default=None, - help="Optional server --speculative-draft-attention-backend override for DFLASH runs.", - ) - parser.add_argument( - "--speculative-dflash-draft-window-size", - type=int, - default=None, - help="Optional server --speculative-dflash-draft-window-size override for DFLASH runs.", - ) - return parser.parse_args() - - -def main() -> None: - args = parse_args() - - if not torch.cuda.is_available(): - raise RuntimeError("CUDA is required for this sweep.") - if args.temperature < 0.0: - raise RuntimeError(f"--temperature must be >= 0, got {args.temperature}.") - if not (0.0 < args.top_p <= 1.0): - raise RuntimeError(f"--top-p must be in (0, 1], got {args.top_p}.") - if args.top_k == 0 or args.top_k < -1: - raise RuntimeError(f"--top-k must be -1 (all vocab) or >= 1, got {args.top_k}.") - if args.timeout_s <= 0: - raise RuntimeError(f"--timeout-s must be > 0, got {args.timeout_s}.") - - visible_gpus = int(torch.cuda.device_count()) - tp_sizes = _parse_int_csv(args.tp_sizes) - tp_sizes = [tp for tp in tp_sizes if tp >= 1 and tp <= visible_gpus] - if not tp_sizes: - raise RuntimeError( - f"No tp sizes are runnable with visible_gpus={visible_gpus}. " - "Set CUDA_VISIBLE_DEVICES accordingly." - ) - - concurrencies = _parse_int_csv(args.concurrencies) - concurrencies = [c for c in concurrencies if c >= 1] - if not concurrencies: - raise RuntimeError("No concurrencies specified.") - - num_questions_by_conc = { - c: min( - int(args.questions_per_concurrency_base) * int(c), - int(args.max_questions_per_config), - ) - for c in concurrencies - } - max_questions = max(num_questions_by_conc.values()) - - attention_backends = [ - s.strip() for s in args.attention_backends.split(",") if s.strip() - ] - device_sm = get_device_sm() - attention_backends = _filter_attention_backends( - attention_backends, device_sm=device_sm - ) - - data_path = _maybe_download_gsm8k(args.data_path) - lines = list(read_jsonl(data_path)) - if len(lines) < max_questions: - raise RuntimeError( - f"GSM8K file only has {len(lines)} lines, but need {max_questions}." - ) - - tokenizer = AutoTokenizer.from_pretrained(args.target_model) - - prompts: list[str] = [] - labels: list[int] = [] - for i in range(max_questions): - user_content = ( - lines[i]["question"] - + "\nPlease reason step by step, and put your final answer within \\boxed{}." - ) - prompts.append( - tokenizer.apply_chat_template( - [{"role": "user", "content": user_content}], - tokenize=False, - add_generation_prompt=True, - enable_thinking=False, - ) - ) - labels.append(_get_answer_value(lines[i]["answer"])) - if not all(label != INVALID for label in labels): - raise RuntimeError("Invalid labels in GSM8K data.") - - # Results indexed by (backend, tp, concurrency, mode). - results: dict[tuple[str, int, int, str], BenchMetrics] = {} - # Baseline metrics are backend-agnostic in this sweep; run once per TP and reuse. - baseline_cache_by_tp: dict[int, dict[int, BenchMetrics]] = {} - - for backend_idx, backend in enumerate(attention_backends): - for tp in tp_sizes: - port_base = find_available_port(20000) - common_server_args = _build_common_server_args(args, backend=backend, tp=tp) - mode_runs = _build_mode_runs(args, common_server_args) - - for idx, ( - mode_key, - mode_name, - mode_server_args, - expect_dflash, - ) in enumerate(mode_runs): - if ( - mode_key == "baseline" - and not args.skip_baseline - and backend_idx > 0 - and tp in baseline_cache_by_tp - ): - mode_metrics = baseline_cache_by_tp[tp] - else: - mode_metrics = _run_mode_for_backend_tp( - mode_label=f"backend={backend} tp={tp} ({mode_name})", - model_path=args.target_model, - base_url=f"http://127.0.0.1:{find_available_port(port_base + idx)}", - server_args=mode_server_args, - expect_dflash=expect_dflash, - prompts=prompts, - labels=labels, - concurrencies=concurrencies, - num_questions_by_conc=num_questions_by_conc, - args=args, - ) - if mode_key == "baseline" and not args.skip_baseline: - baseline_cache_by_tp[tp] = mode_metrics - - for conc, metrics in mode_metrics.items(): - results[(backend, tp, conc, mode_key)] = metrics - - _print_summary( - args=args, - attention_backends=attention_backends, - tp_sizes=tp_sizes, - concurrencies=concurrencies, - device_sm=device_sm, - results=results, - ) - - -if __name__ == "__main__": - main() From 180d2a8bc757628a5e1136cc74eebd081380ac71 Mon Sep 17 00:00:00 2001 From: David Wang Date: Tue, 24 Mar 2026 19:00:31 +0000 Subject: [PATCH 55/73] clean up dflash cuda graph runner paths --- .../srt/model_executor/cuda_graph_runner.py | 25 +++++++++---------- .../sglang/srt/model_executor/model_runner.py | 14 ++--------- python/sglang/srt/models/dflash.py | 19 ++++++++------ .../sglang/srt/speculative/dflash_worker.py | 5 +++- python/sglang/srt/speculative/spec_info.py | 3 +++ 5 files changed, 33 insertions(+), 33 deletions(-) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 089590c5cc2c..ea3d4d21322d 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -472,16 +472,9 @@ def __init__(self, model_runner: ModelRunner): self.capture_forward_mode = ForwardMode.DECODE self.capture_hidden_mode = CaptureHiddenMode.NULL self.num_tokens_per_bs = 1 - if ( - model_runner.spec_algorithm.is_eagle() - or model_runner.spec_algorithm.is_standalone() - or model_runner.spec_algorithm.is_ngram() - or model_runner.spec_algorithm.is_dflash() - ): + if model_runner.spec_algorithm.is_speculative(): if self.model_runner.is_draft_worker: - # EAGLE/standalone/ngram draft workers use separate cuda-graph runners; do not - # capture TARGET_VERIFY graphs here. DFLASH draft uses a fixed-size block and - # reuses TARGET_VERIFY graphs for performance. + # DFLASH draft workers reuse this runner for TARGET_VERIFY mode. if not self.model_runner.spec_algorithm.is_dflash(): raise RuntimeError("This should not happen") self.capture_forward_mode = ForwardMode.TARGET_VERIFY @@ -1102,15 +1095,21 @@ def replay( self.graphs[graph_key].replay() output = self.output_buffers[graph_key] - if isinstance(output, torch.Tensor): - return output[: self.raw_num_token] if isinstance(output, LogitsProcessorOutput): if self.is_dllm: next_token_logits = None - full_logits = output.full_logits[: self.raw_num_token] + full_logits = ( + output.full_logits[: self.raw_num_token] + if output.full_logits is not None + else None + ) else: full_logits = None - next_token_logits = output.next_token_logits[: self.raw_num_token] + next_token_logits = ( + output.next_token_logits[: self.raw_num_token] + if output.next_token_logits is not None + else None + ) return LogitsProcessorOutput( next_token_logits=next_token_logits, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 93dc727f4ee6..1b06dc56fa2a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1899,12 +1899,7 @@ def _should_run_flashinfer_autotune(self) -> bool: if major < 9: return False - if ( - self.spec_algorithm.is_eagle() - or self.spec_algorithm.is_standalone() - or self.spec_algorithm.is_ngram() - or self.spec_algorithm.is_dflash() - ): + if self.spec_algorithm.is_speculative(): return not self.is_draft_worker return True @@ -1934,12 +1929,7 @@ def _dummy_run(self, batch_size: int, run_ctx=None): capture_forward_mode = ForwardMode.EXTEND capture_hidden_mode = CaptureHiddenMode.NULL num_tokens_per_bs = 1 - if ( - self.spec_algorithm.is_eagle() - or self.spec_algorithm.is_standalone() - or self.spec_algorithm.is_ngram() - or self.spec_algorithm.is_dflash() - ): + if self.spec_algorithm.is_speculative(): if self.is_draft_worker: if not self.spec_algorithm.is_dflash(): raise RuntimeError("This should not happen") diff --git a/python/sglang/srt/models/dflash.py b/python/sglang/srt/models/dflash.py index c76617e4b40b..27f5cdbf539d 100644 --- a/python/sglang/srt/models/dflash.py +++ b/python/sglang/srt/models/dflash.py @@ -20,6 +20,7 @@ QKVParallelLinear, RowParallelLinear, ) +from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.radix_attention import AttentionType, RadixAttention from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -314,7 +315,7 @@ def forward( input_embeds: Optional[torch.Tensor] = None, get_embedding: bool = False, pp_proxy_tensors=None, - ) -> torch.Tensor: + ) -> LogitsProcessorOutput: if input_embeds is None: raise ValueError( "DFlashDraftModel requires `input_embeds` (use the target embedding)." @@ -327,12 +328,16 @@ def forward( positions, hidden_states, forward_batch, residual ) - if hidden_states.numel() == 0: - return hidden_states - if residual is None: - return self.norm(hidden_states) - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states + if hidden_states.numel() != 0: + if residual is None: + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) + + return LogitsProcessorOutput( + next_token_logits=None, + hidden_states=hidden_states, + ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 353557e1f12a..b987f9cf395a 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -645,13 +645,16 @@ def _prepare_for_speculative_decoding( ) with torch.inference_mode(): - draft_hidden = self.draft_model_runner.forward( + draft_logits_output = self.draft_model_runner.forward( forward_batch ).logits_output finally: # Drop the speculative block from the shared allocator (EAGLE3-style). allocator.restore_state(token_to_kv_pool_state_backup) + draft_hidden = draft_logits_output.hidden_states + if draft_hidden is None: + raise RuntimeError("DFLASH draft model returned no hidden states.") draft_hidden = draft_hidden.view(bs, self.block_size, -1) draft_next = self._greedy_sample_from_vocab_parallel_head( hidden_states=draft_hidden[:, 1:, :].reshape(-1, draft_hidden.shape[-1]), diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py index 5c4daeb702d4..3e5727187572 100644 --- a/python/sglang/srt/speculative/spec_info.py +++ b/python/sglang/srt/speculative/spec_info.py @@ -34,6 +34,9 @@ def from_string(cls, name: Optional[str]) -> SpeculativeAlgorithm: def is_none(self) -> bool: return self == SpeculativeAlgorithm.NONE + def is_speculative(self) -> bool: + return self != SpeculativeAlgorithm.NONE + def is_eagle(self) -> bool: # NOTE: EAGLE3 is a variant of EAGLE return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.EAGLE3 From ab97a910f6110a71e2a63468ba996551281c4480 Mon Sep 17 00:00:00 2001 From: David Wang Date: Tue, 24 Mar 2026 19:41:36 +0000 Subject: [PATCH 56/73] clean up dflash request validation --- python/sglang/srt/managers/scheduler.py | 44 ++++++++++++++----------- 1 file changed, 25 insertions(+), 19 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 86c3c9f05e48..3ffa0eee6c28 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -248,6 +248,24 @@ def copy_to_cpu(self): self.copy_done.record() +def validate_dflash_request(req: Req) -> Optional[str]: + if req.return_logprob: + return "DFLASH speculative decoding does not support return_logprob yet." + + if ( + req.sampling_params.json_schema is not None + or req.sampling_params.regex is not None + or req.sampling_params.ebnf is not None + or req.sampling_params.structural_tag is not None + ): + return ( + "DFLASH speculative decoding does not support " + "grammar-constrained decoding yet." + ) + + return None + + class Scheduler( SchedulerOutputProcessorMixin, SchedulerUpdateWeightsMixin, @@ -1633,25 +1651,13 @@ def handle_generate_request( self._add_request_to_queue(req) return - if self.spec_algorithm.is_dflash() and req.return_logprob: - req.set_finish_with_abort( - "DFLASH speculative decoding does not support return_logprob yet." - ) - self.init_req_max_new_tokens(req) - self._add_request_to_queue(req) - return - if self.spec_algorithm.is_dflash() and ( - req.sampling_params.json_schema is not None - or req.sampling_params.regex is not None - or req.sampling_params.ebnf is not None - or req.sampling_params.structural_tag is not None - ): - req.set_finish_with_abort( - "DFLASH speculative decoding does not support grammar-constrained decoding yet." - ) - self.init_req_max_new_tokens(req) - self._add_request_to_queue(req) - return + if self.spec_algorithm.is_dflash(): + error_msg = validate_dflash_request(req) + if error_msg is not None: + req.set_finish_with_abort(error_msg) + self.init_req_max_new_tokens(req) + self._add_request_to_queue(req) + return # Handle multimodal inputs if recv_req.mm_inputs is not None: From 6231c4b7d0bcac1f8aaf4b2b9a8d2bc56bc788df Mon Sep 17 00:00:00 2001 From: David Wang Date: Tue, 24 Mar 2026 20:33:19 +0000 Subject: [PATCH 57/73] clean up stop string handling --- python/sglang/srt/speculative/dflash_info.py | 76 +++++--------------- 1 file changed, 16 insertions(+), 60 deletions(-) diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py index 3b3e7434b6c1..54eab66fddd8 100644 --- a/python/sglang/srt/speculative/dflash_info.py +++ b/python/sglang/srt/speculative/dflash_info.py @@ -8,7 +8,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import apply_custom_logit_processor -from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.mem_cache.common import ( alloc_paged_token_slots_extend, alloc_token_slots, @@ -50,6 +50,20 @@ def _compute_paged_keep_slots( return keep_slots +def _append_verified_tokens(req: Req, proposed_tokens: List[int]) -> int: + appended = 0 + for token_id in proposed_tokens: + token_id = int(token_id) + req.output_ids.append(token_id) + appended += 1 + req.check_finished() + if req.finished(): + break + if req.grammar is not None: + req.grammar.accept_token(token_id) + return appended + + @dataclass class DFlashDraftInput(SpecInput): """Per-batch DFlash draft state for spec-v1 (non-overlap) scheduling. @@ -397,65 +411,7 @@ def verify( int(packed[i, max_acc + 1].item()) ] - appended = 0 - if ( - req.grammar is None - and not req.sampling_params.stop_strs - and not req.sampling_params.stop_regex_strs - ): - remaining = int(req.sampling_params.max_new_tokens) - len( - req.output_ids - ) - if remaining > 0: - tokens = proposed[:remaining] - if not req.sampling_params.ignore_eos: - stop_token_ids = req.sampling_params.stop_token_ids - eos_token_ids = req.eos_token_ids - tokenizer = req.tokenizer - tokenizer_eos = ( - tokenizer.eos_token_id if tokenizer is not None else None - ) - additional_stop = ( - tokenizer.additional_stop_token_ids - if tokenizer is not None - else None - ) - vocab_size = getattr(req, "vocab_size", None) - - for j, token_id in enumerate(tokens): - if vocab_size is not None and ( - int(token_id) >= int(vocab_size) or int(token_id) < 0 - ): - tokens = tokens[: j + 1] - break - if stop_token_ids and token_id in stop_token_ids: - tokens = tokens[: j + 1] - break - if eos_token_ids and token_id in eos_token_ids: - tokens = tokens[: j + 1] - break - if tokenizer_eos is not None and int(token_id) == int( - tokenizer_eos - ): - tokens = tokens[: j + 1] - break - if additional_stop and token_id in additional_stop: - tokens = tokens[: j + 1] - break - - req.output_ids.extend(int(tok) for tok in tokens) - appended = len(tokens) - if appended > 0: - req.check_finished(new_accepted_len=appended) - else: - for tok in proposed: - req.output_ids.append(int(tok)) - appended += 1 - req.check_finished() - if req.finished(): - break - if req.grammar is not None: - req.grammar.accept_token(int(tok)) + appended = _append_verified_tokens(req, proposed) if req.output_ids: new_verified_token = int(req.output_ids[-1]) From 2a48abffede910ab6bda717c7d750fbbf48c413c Mon Sep 17 00:00:00 2001 From: David Wang Date: Tue, 24 Mar 2026 20:45:50 +0000 Subject: [PATCH 58/73] inline stop strings logic --- python/sglang/srt/speculative/dflash_info.py | 27 ++++++++------------ 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py index 54eab66fddd8..fbb06cc70ee1 100644 --- a/python/sglang/srt/speculative/dflash_info.py +++ b/python/sglang/srt/speculative/dflash_info.py @@ -8,7 +8,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import apply_custom_logit_processor -from sglang.srt.managers.schedule_batch import Req, ScheduleBatch +from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.mem_cache.common import ( alloc_paged_token_slots_extend, alloc_token_slots, @@ -50,20 +50,6 @@ def _compute_paged_keep_slots( return keep_slots -def _append_verified_tokens(req: Req, proposed_tokens: List[int]) -> int: - appended = 0 - for token_id in proposed_tokens: - token_id = int(token_id) - req.output_ids.append(token_id) - appended += 1 - req.check_finished() - if req.finished(): - break - if req.grammar is not None: - req.grammar.accept_token(token_id) - return appended - - @dataclass class DFlashDraftInput(SpecInput): """Per-batch DFlash draft state for spec-v1 (non-overlap) scheduling. @@ -411,7 +397,16 @@ def verify( int(packed[i, max_acc + 1].item()) ] - appended = _append_verified_tokens(req, proposed) + appended = 0 + for token_id in proposed: + token_id = int(token_id) + req.output_ids.append(token_id) + appended += 1 + req.check_finished() + if req.finished(): + break + if req.grammar is not None: + req.grammar.accept_token(token_id) if req.output_ids: new_verified_token = int(req.output_ids[-1]) From bea986b2f91432e0bdf2c283527427b0e63fcda1 Mon Sep 17 00:00:00 2001 From: David Wang Date: Wed, 25 Mar 2026 02:09:21 +0000 Subject: [PATCH 59/73] dflash tests --- .../test/server_fixtures/dflash_fixture.py | 40 +++ python/sglang/test/test_utils.py | 4 + .../spec/dflash/test_dflash_basic.py | 65 ++++ .../spec/dflash/test_dflash_infer_a.py | 283 ++++++++++++++++++ .../spec/dflash/test_dflash_infer_beta.py | 101 +++++++ 5 files changed, 493 insertions(+) create mode 100644 python/sglang/test/server_fixtures/dflash_fixture.py create mode 100644 test/registered/spec/dflash/test_dflash_basic.py create mode 100644 test/registered/spec/dflash/test_dflash_infer_a.py create mode 100644 test/registered/spec/dflash/test_dflash_infer_beta.py diff --git a/python/sglang/test/server_fixtures/dflash_fixture.py b/python/sglang/test/server_fixtures/dflash_fixture.py new file mode 100644 index 000000000000..9ad1692fdfb1 --- /dev/null +++ b/python/sglang/test/server_fixtures/dflash_fixture.py @@ -0,0 +1,40 @@ +from sglang.srt.environ import envs +from sglang.srt.utils.common import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_DRAFT_MODEL_DFLASH, + DEFAULT_TARGET_MODEL_DFLASH, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class DFlashServerBase(CustomTestCase): + target_model = DEFAULT_TARGET_MODEL_DFLASH + draft_model = DEFAULT_DRAFT_MODEL_DFLASH + spec_algo = "DFLASH" + spec_block_size = 16 + extra_args = [] + + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + with envs.SGLANG_SPEC_NAN_DETECTION.override( + True + ), envs.SGLANG_SPEC_OOB_DETECTION.override(True): + cls.process = popen_launch_server( + cls.target_model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + f"--speculative-algorithm={cls.spec_algo}", + f"--speculative-draft-model-path={cls.draft_model}", + f"--speculative-num-draft-tokens={cls.spec_block_size}", + ] + + cls.extra_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index e24cdfced1db..6132ba7ebe86 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -106,6 +106,10 @@ DEFAULT_TARGET_MODEL_EAGLE3 = "meta-llama/Llama-3.1-8B-Instruct" DEFAULT_DRAFT_MODEL_EAGLE3 = "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B" +# DFLASH model +DEFAULT_TARGET_MODEL_DFLASH = "meta-llama/Llama-3.1-8B-Instruct" +DEFAULT_DRAFT_MODEL_DFLASH = "z-lab/LLaMA3.1-8B-Instruct-DFlash-UltraChat" + # EAGLE2 with DP-Attention models DEFAULT_TARGET_MODEL_EAGLE_DP_ATTN = "Qwen/Qwen3-30B-A3B" DEFAULT_DRAFT_MODEL_EAGLE_DP_ATTN = "Tengyunw/qwen3_30b_moe_eagle3" diff --git a/test/registered/spec/dflash/test_dflash_basic.py b/test/registered/spec/dflash/test_dflash_basic.py new file mode 100644 index 000000000000..7c34c9402753 --- /dev/null +++ b/test/registered/spec/dflash/test_dflash_basic.py @@ -0,0 +1,65 @@ +import os +import unittest +from types import SimpleNamespace + +import requests + +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.run_eval import run_eval +from sglang.test.server_fixtures.dflash_fixture import DFlashServerBase +from sglang.test.test_utils import ( + DEFAULT_DRAFT_MODEL_DFLASH, + DEFAULT_TARGET_MODEL_DFLASH, +) + +register_cuda_ci(est_time=50, suite="stage-b-test-small-1-gpu") + + +class TestDFlashBasic(DFlashServerBase): + target_model = DEFAULT_TARGET_MODEL_DFLASH + draft_model = DEFAULT_DRAFT_MODEL_DFLASH + + spec_algo = "DFLASH" + spec_block_size = 16 + + extra_args = [ + "--dtype", + "float16", + "--chunked-prefill-size", + 1024, + ] + + @classmethod + def setUpClass(cls): + old_value = os.environ.get("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN") + os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1" + try: + super().setUpClass() + finally: + if old_value is None: + del os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] + else: + os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = old_value + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.target_model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + self.assertGreaterEqual(metrics["score"], 0.72) + + server_info = requests.get(self.base_url + "/server_info").json() + avg_spec_accept_length = server_info["internal_states"][0][ + "avg_spec_accept_length" + ] + print(f"{avg_spec_accept_length=}") + self.assertGreater(avg_spec_accept_length, 3.15) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/spec/dflash/test_dflash_infer_a.py b/test/registered/spec/dflash/test_dflash_infer_a.py new file mode 100644 index 000000000000..d2375bb9c8c3 --- /dev/null +++ b/test/registered/spec/dflash/test_dflash_infer_a.py @@ -0,0 +1,283 @@ +import os +import random +import unittest + +import sglang as sgl +from sglang.srt.utils.hf_transformers_utils import get_tokenizer +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.test_utils import ( + DEFAULT_DRAFT_MODEL_DFLASH, + DEFAULT_TARGET_MODEL_DFLASH, + CustomTestCase, +) + +register_cuda_ci(est_time=561, suite="stage-b-test-large-1-gpu") + + +class TestDFlashEngine(CustomTestCase): + BASE_CONFIG = { + "model_path": DEFAULT_TARGET_MODEL_DFLASH, + "speculative_draft_model_path": DEFAULT_DRAFT_MODEL_DFLASH, + "speculative_algorithm": "DFLASH", + "speculative_num_draft_tokens": 16, + "cuda_graph_max_bs": 5, + "dtype": "bfloat16", + "trust_remote_code": True, + } + NUM_CONFIGS = 2 + + THRESHOLDS = { + "batch_avg_accept_len": 2.09, + "accept_len": 5.61, + } + + def setUp(self): + self.prompt = "Today is a sunny day and I like" + self.sampling_params = {"temperature": 0, "max_new_tokens": 8} + + old_value = os.environ.get("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN") + os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1" + try: + ref_engine = sgl.Engine( + model_path=self.BASE_CONFIG["model_path"], + cuda_graph_max_bs=1, + ) + self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)[ + "text" + ] + ref_engine.shutdown() + finally: + if old_value is None: + del os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] + else: + os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = old_value + + def test_correctness(self): + old_value = os.environ.get("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN") + os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1" + try: + configs = [ + # Basic config + self.BASE_CONFIG, + # Chunked prefill + {**self.BASE_CONFIG, "chunked_prefill_size": 4}, + ] + + for i, config in enumerate(configs[: self.NUM_CONFIGS]): + with self.subTest(i=i): + print(f"{config=}") + engine = sgl.Engine( + **config, log_level="info", decode_log_interval=10 + ) + try: + self._test_single_generation(engine) + self._test_first_token_finish(engine) + self._test_batch_generation(engine) + self._test_eos_token(engine) + self._test_acc_length(engine) + finally: + engine.flush_cache() + engine.shutdown() + print("=" * 100) + finally: + if old_value is None: + del os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] + else: + os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = old_value + + def _test_single_generation(self, engine): + output = engine.generate(self.prompt, self.sampling_params)["text"] + print(f"{output=}, {self.ref_output=}") + self.assertEqual(output, self.ref_output) + + def _test_batch_generation(self, engine): + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + params = {"temperature": 0, "max_new_tokens": 50} + + outputs = engine.generate(prompts, params) + for prompt, output in zip(prompts, outputs): + print(f"Prompt: {prompt}") + print(f"Generated: {output['text']}") + print("-" * 40) + + print(f"{engine.get_server_info()=}") + + avg_spec_accept_length = engine.get_server_info()["internal_states"][0][ + "avg_spec_accept_length" + ] + print(f"{avg_spec_accept_length=}") + self.assertGreater( + avg_spec_accept_length, self.THRESHOLDS["batch_avg_accept_len"] + ) + + def _test_first_token_finish(self, engine): + prompt = [ + f"There are {i} apples on the table. How to divide them equally?" + for i in range(8) + ] + params = [ + {"temperature": 0, "max_new_tokens": random.randint(1, 3)} for _ in range(8) + ] + outputs = engine.generate(prompt, params) + for i, output in enumerate(outputs): + print(f"Prompt: {prompt[i]}") + print(f"Generated: {output['text']}") + print("-" * 40) + + def _test_eos_token(self, engine): + prompt = "[INST] <>\nYou are a helpful assistant.\n<>\nToday is a sunny day and I like [/INST]" + params = { + "temperature": 0.1, + "max_new_tokens": 1024, + "skip_special_tokens": False, + } + + tokenizer = get_tokenizer(DEFAULT_TARGET_MODEL_DFLASH) + output = engine.generate(prompt, params)["text"] + print(f"{output=}") + + tokens = tokenizer.encode(output, truncation=False) + self.assertNotIn(tokenizer.eos_token_id, tokens) + + def _test_acc_length(self, engine): + prompt = [ + "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:", + ] * 5 + sampling_params = {"temperature": 0, "max_new_tokens": 512} + output = engine.generate(prompt, sampling_params) + output = output[0] + + if "spec_verify_ct" in output["meta_info"]: + acc_length = ( + output["meta_info"]["completion_tokens"] + / output["meta_info"]["spec_verify_ct"] + ) + else: + acc_length = 1.0 + + speed = ( + output["meta_info"]["completion_tokens"] + / output["meta_info"]["e2e_latency"] + ) + print(f"{acc_length=:.4f}, {speed=}") + + self.assertGreater(acc_length, self.THRESHOLDS["accept_len"]) + + +class TestDFlashRadixCache(CustomTestCase): + BASE_CONFIG = { + "model_path": DEFAULT_TARGET_MODEL_DFLASH, + "speculative_draft_model_path": DEFAULT_DRAFT_MODEL_DFLASH, + "speculative_algorithm": "DFLASH", + "speculative_num_draft_tokens": 16, + "dtype": "bfloat16", + "trust_remote_code": True, + "attention_backend": "flashinfer", + "skip_server_warmup": True, + "cuda_graph_max_bs": 5, + } + THRESHOLDS = { + "batch_avg_accept_len": 4.2, + "accept_len": 4.3, + } + + def test_correctness(self): + old_value = os.environ.get("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN") + os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1" + configs = [ + # Basic config + self.BASE_CONFIG, + # Chunked prefill & Page Size > 1 + {**self.BASE_CONFIG, "chunked_prefill_size": 64, "page_size": 4}, + {**self.BASE_CONFIG, "page_size": 4}, + # Large page size tend to expose IMA bugs. + {**self.BASE_CONFIG, "page_size": 256}, + {**self.BASE_CONFIG, "cuda_graph_bs": [5], "page_size": 4}, + # Disable CUDA Graph + { + **self.BASE_CONFIG, + "disable_cuda_graph": True, + "page_size": 4, + }, + ] + + try: + for i, config in enumerate(configs): + with self.subTest(i=i): + print(f"{config=}") + engine = sgl.Engine( + **config, log_level="info", decode_log_interval=10 + ) + try: + self._test_acc_length(engine) + self._test_batch_generation(engine) + finally: + engine.shutdown() + print("=" * 100) + finally: + if old_value is None: + del os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] + else: + os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = old_value + + def _test_acc_length(self, engine): + warmup_prompt = [ + "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:", + ] + sampling_params = {"temperature": 0, "max_new_tokens": 512} + engine.generate(warmup_prompt, sampling_params) + test_prompt = [ + "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGive me a fully functional FastAPI server. Show the python code.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ] + output = engine.generate(test_prompt, sampling_params) + output = output[0] + + if "spec_verify_ct" in output["meta_info"]: + acc_length = ( + output["meta_info"]["completion_tokens"] + / output["meta_info"]["spec_verify_ct"] + ) + else: + acc_length = 1.0 + + speed = ( + output["meta_info"]["completion_tokens"] + / output["meta_info"]["e2e_latency"] + ) + + print(f"{acc_length=:.4f}, {speed=}") + self.assertGreater(acc_length, self.THRESHOLDS["accept_len"]) + + def _test_batch_generation(self, engine): + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + params = {"temperature": 0, "max_new_tokens": 50} + + outputs = engine.generate(prompts, params) + for prompt, output in zip(prompts, outputs): + print(f"Prompt: {prompt}") + print(f"Generated: {output['text']}") + print("-" * 40) + + print(f"{engine.get_server_info()=}") + + avg_spec_accept_length = engine.get_server_info()["internal_states"][0][ + "avg_spec_accept_length" + ] + print(f"{avg_spec_accept_length=}") + self.assertGreater( + avg_spec_accept_length, self.THRESHOLDS["batch_avg_accept_len"] + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/registered/spec/dflash/test_dflash_infer_beta.py b/test/registered/spec/dflash/test_dflash_infer_beta.py new file mode 100644 index 000000000000..c5cea35a7e46 --- /dev/null +++ b/test/registered/spec/dflash/test_dflash_infer_beta.py @@ -0,0 +1,101 @@ +import os +import unittest +from types import SimpleNamespace + +from sglang.srt.environ import envs +from sglang.srt.utils import kill_process_tree +from sglang.test.ci.ci_register import register_cuda_ci +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.kits.matched_stop_kit import MatchedStopMixin +from sglang.test.kits.radix_cache_server_kit import run_radix_attention_test +from sglang.test.test_utils import ( + DEFAULT_DRAFT_MODEL_DFLASH, + DEFAULT_TARGET_MODEL_DFLASH, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + +register_cuda_ci(est_time=283, suite="stage-b-test-small-1-gpu") + + +class TestDFlashServerBase(CustomTestCase, MatchedStopMixin): + max_running_requests = 64 + attention_backend = "flashinfer" + page_size = 1 + other_launch_args = [] + model = DEFAULT_TARGET_MODEL_DFLASH + draft_model = DEFAULT_DRAFT_MODEL_DFLASH + + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + launch_args = [ + "--trust-remote-code", + "--attention-backend", + cls.attention_backend, + "--speculative-algorithm", + "DFLASH", + "--speculative-draft-model-path", + cls.draft_model, + "--page-size", + str(cls.page_size), + "--max-running-requests", + str(cls.max_running_requests), + "--cuda-graph-bs", + *[str(i) for i in range(1, cls.max_running_requests + 1)], + ] + launch_args.extend(cls.other_launch_args) + old_value = os.environ.get("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN") + os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1" + try: + with envs.SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_BUSY.override( + 1 + ), envs.SGLANG_SPEC_NAN_DETECTION.override( + True + ), envs.SGLANG_SPEC_OOB_DETECTION.override( + True + ): + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=launch_args, + ) + finally: + if old_value is None: + del os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] + else: + os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = old_value + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_radix_attention(self): + run_radix_attention_test(self.base_url) + assert self.process.poll() is None + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=1000, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"TestDFlashServerBase -- {metrics=}") + self.assertGreater(metrics["accuracy"], 0.23) + assert self.process.poll() is None + + +class TestDFlashServerPage(TestDFlashServerBase): + page_size = 64 + + +if __name__ == "__main__": + unittest.main() From 5e234f88ab622406bd76eb4a30ff4c2a64cb0387 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Tue, 7 Apr 2026 00:00:48 -0700 Subject: [PATCH 60/73] lazy import dflash_utils in model_runner --- python/sglang/srt/model_executor/model_runner.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 1b06dc56fa2a..179c9564997c 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -150,9 +150,6 @@ get_global_server_args, set_global_server_args_for_scheduler, ) -from sglang.srt.speculative.dflash_utils import ( - parse_dflash_draft_config, -) from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.utils import ( MultiprocessingSerializer, @@ -377,6 +374,10 @@ def __init__( self.eagle_aux_hidden_state_layer_ids = None if self.spec_algorithm.is_dflash() and not self.is_draft_worker: + from sglang.srt.speculative.dflash_utils import ( + parse_dflash_draft_config, + ) + # Select target layers to capture for building DFlash context features. draft_model_config = ModelConfig.from_server_args( server_args, From aff48a22c034e0be3e4e9bfa3acfcd6d7855a0fc Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Tue, 7 Apr 2026 00:17:24 -0700 Subject: [PATCH 61/73] lazy import dflash_utils in kv_cache_mixin --- .../sglang/srt/model_executor/model_runner_kv_cache_mixin.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py index dadf0feeb1e6..8720fbf06f7f 100644 --- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py +++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py @@ -24,7 +24,6 @@ ReqToTokenPool, ) from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool, SWATokenToKVPoolAllocator -from sglang.srt.speculative.dflash_utils import scale_kv_cell_size_per_token_for_dflash from sglang.srt.utils.common import ( get_available_gpu_memory, is_float4_e2m1fn_x2, @@ -141,6 +140,10 @@ def profile_max_num_token(self: ModelRunner, pre_model_load_memory: int): cell_size = self.get_cell_size_per_token(num_layers) if self.spec_algorithm.is_dflash() and not self.is_draft_worker: + from sglang.srt.speculative.dflash_utils import ( + scale_kv_cell_size_per_token_for_dflash, + ) + draft_num_layers = getattr(self, "dflash_draft_num_layers", None) if ( draft_num_layers is not None From 0f0132be6af55c1838885ebc77685a23242d942d Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Tue, 7 Apr 2026 00:51:40 -0700 Subject: [PATCH 62/73] migrate dflash infer_a to infer_b; rename beta to b; delete engine test --- .../spec/dflash/test_dflash_infer_a.py | 283 ------------------ ...h_infer_beta.py => test_dflash_infer_b.py} | 60 +++- 2 files changed, 59 insertions(+), 284 deletions(-) delete mode 100644 test/registered/spec/dflash/test_dflash_infer_a.py rename test/registered/spec/dflash/{test_dflash_infer_beta.py => test_dflash_infer_b.py} (60%) diff --git a/test/registered/spec/dflash/test_dflash_infer_a.py b/test/registered/spec/dflash/test_dflash_infer_a.py deleted file mode 100644 index d2375bb9c8c3..000000000000 --- a/test/registered/spec/dflash/test_dflash_infer_a.py +++ /dev/null @@ -1,283 +0,0 @@ -import os -import random -import unittest - -import sglang as sgl -from sglang.srt.utils.hf_transformers_utils import get_tokenizer -from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.test_utils import ( - DEFAULT_DRAFT_MODEL_DFLASH, - DEFAULT_TARGET_MODEL_DFLASH, - CustomTestCase, -) - -register_cuda_ci(est_time=561, suite="stage-b-test-large-1-gpu") - - -class TestDFlashEngine(CustomTestCase): - BASE_CONFIG = { - "model_path": DEFAULT_TARGET_MODEL_DFLASH, - "speculative_draft_model_path": DEFAULT_DRAFT_MODEL_DFLASH, - "speculative_algorithm": "DFLASH", - "speculative_num_draft_tokens": 16, - "cuda_graph_max_bs": 5, - "dtype": "bfloat16", - "trust_remote_code": True, - } - NUM_CONFIGS = 2 - - THRESHOLDS = { - "batch_avg_accept_len": 2.09, - "accept_len": 5.61, - } - - def setUp(self): - self.prompt = "Today is a sunny day and I like" - self.sampling_params = {"temperature": 0, "max_new_tokens": 8} - - old_value = os.environ.get("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN") - os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1" - try: - ref_engine = sgl.Engine( - model_path=self.BASE_CONFIG["model_path"], - cuda_graph_max_bs=1, - ) - self.ref_output = ref_engine.generate(self.prompt, self.sampling_params)[ - "text" - ] - ref_engine.shutdown() - finally: - if old_value is None: - del os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] - else: - os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = old_value - - def test_correctness(self): - old_value = os.environ.get("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN") - os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1" - try: - configs = [ - # Basic config - self.BASE_CONFIG, - # Chunked prefill - {**self.BASE_CONFIG, "chunked_prefill_size": 4}, - ] - - for i, config in enumerate(configs[: self.NUM_CONFIGS]): - with self.subTest(i=i): - print(f"{config=}") - engine = sgl.Engine( - **config, log_level="info", decode_log_interval=10 - ) - try: - self._test_single_generation(engine) - self._test_first_token_finish(engine) - self._test_batch_generation(engine) - self._test_eos_token(engine) - self._test_acc_length(engine) - finally: - engine.flush_cache() - engine.shutdown() - print("=" * 100) - finally: - if old_value is None: - del os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] - else: - os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = old_value - - def _test_single_generation(self, engine): - output = engine.generate(self.prompt, self.sampling_params)["text"] - print(f"{output=}, {self.ref_output=}") - self.assertEqual(output, self.ref_output) - - def _test_batch_generation(self, engine): - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - params = {"temperature": 0, "max_new_tokens": 50} - - outputs = engine.generate(prompts, params) - for prompt, output in zip(prompts, outputs): - print(f"Prompt: {prompt}") - print(f"Generated: {output['text']}") - print("-" * 40) - - print(f"{engine.get_server_info()=}") - - avg_spec_accept_length = engine.get_server_info()["internal_states"][0][ - "avg_spec_accept_length" - ] - print(f"{avg_spec_accept_length=}") - self.assertGreater( - avg_spec_accept_length, self.THRESHOLDS["batch_avg_accept_len"] - ) - - def _test_first_token_finish(self, engine): - prompt = [ - f"There are {i} apples on the table. How to divide them equally?" - for i in range(8) - ] - params = [ - {"temperature": 0, "max_new_tokens": random.randint(1, 3)} for _ in range(8) - ] - outputs = engine.generate(prompt, params) - for i, output in enumerate(outputs): - print(f"Prompt: {prompt[i]}") - print(f"Generated: {output['text']}") - print("-" * 40) - - def _test_eos_token(self, engine): - prompt = "[INST] <>\nYou are a helpful assistant.\n<>\nToday is a sunny day and I like [/INST]" - params = { - "temperature": 0.1, - "max_new_tokens": 1024, - "skip_special_tokens": False, - } - - tokenizer = get_tokenizer(DEFAULT_TARGET_MODEL_DFLASH) - output = engine.generate(prompt, params)["text"] - print(f"{output=}") - - tokens = tokenizer.encode(output, truncation=False) - self.assertNotIn(tokenizer.eos_token_id, tokens) - - def _test_acc_length(self, engine): - prompt = [ - "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:", - ] * 5 - sampling_params = {"temperature": 0, "max_new_tokens": 512} - output = engine.generate(prompt, sampling_params) - output = output[0] - - if "spec_verify_ct" in output["meta_info"]: - acc_length = ( - output["meta_info"]["completion_tokens"] - / output["meta_info"]["spec_verify_ct"] - ) - else: - acc_length = 1.0 - - speed = ( - output["meta_info"]["completion_tokens"] - / output["meta_info"]["e2e_latency"] - ) - print(f"{acc_length=:.4f}, {speed=}") - - self.assertGreater(acc_length, self.THRESHOLDS["accept_len"]) - - -class TestDFlashRadixCache(CustomTestCase): - BASE_CONFIG = { - "model_path": DEFAULT_TARGET_MODEL_DFLASH, - "speculative_draft_model_path": DEFAULT_DRAFT_MODEL_DFLASH, - "speculative_algorithm": "DFLASH", - "speculative_num_draft_tokens": 16, - "dtype": "bfloat16", - "trust_remote_code": True, - "attention_backend": "flashinfer", - "skip_server_warmup": True, - "cuda_graph_max_bs": 5, - } - THRESHOLDS = { - "batch_avg_accept_len": 4.2, - "accept_len": 4.3, - } - - def test_correctness(self): - old_value = os.environ.get("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN") - os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1" - configs = [ - # Basic config - self.BASE_CONFIG, - # Chunked prefill & Page Size > 1 - {**self.BASE_CONFIG, "chunked_prefill_size": 64, "page_size": 4}, - {**self.BASE_CONFIG, "page_size": 4}, - # Large page size tend to expose IMA bugs. - {**self.BASE_CONFIG, "page_size": 256}, - {**self.BASE_CONFIG, "cuda_graph_bs": [5], "page_size": 4}, - # Disable CUDA Graph - { - **self.BASE_CONFIG, - "disable_cuda_graph": True, - "page_size": 4, - }, - ] - - try: - for i, config in enumerate(configs): - with self.subTest(i=i): - print(f"{config=}") - engine = sgl.Engine( - **config, log_level="info", decode_log_interval=10 - ) - try: - self._test_acc_length(engine) - self._test_batch_generation(engine) - finally: - engine.shutdown() - print("=" * 100) - finally: - if old_value is None: - del os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] - else: - os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = old_value - - def _test_acc_length(self, engine): - warmup_prompt = [ - "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:", - ] - sampling_params = {"temperature": 0, "max_new_tokens": 512} - engine.generate(warmup_prompt, sampling_params) - test_prompt = [ - "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGive me a fully functional FastAPI server. Show the python code.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" - ] - output = engine.generate(test_prompt, sampling_params) - output = output[0] - - if "spec_verify_ct" in output["meta_info"]: - acc_length = ( - output["meta_info"]["completion_tokens"] - / output["meta_info"]["spec_verify_ct"] - ) - else: - acc_length = 1.0 - - speed = ( - output["meta_info"]["completion_tokens"] - / output["meta_info"]["e2e_latency"] - ) - - print(f"{acc_length=:.4f}, {speed=}") - self.assertGreater(acc_length, self.THRESHOLDS["accept_len"]) - - def _test_batch_generation(self, engine): - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - params = {"temperature": 0, "max_new_tokens": 50} - - outputs = engine.generate(prompts, params) - for prompt, output in zip(prompts, outputs): - print(f"Prompt: {prompt}") - print(f"Generated: {output['text']}") - print("-" * 40) - - print(f"{engine.get_server_info()=}") - - avg_spec_accept_length = engine.get_server_info()["internal_states"][0][ - "avg_spec_accept_length" - ] - print(f"{avg_spec_accept_length=}") - self.assertGreater( - avg_spec_accept_length, self.THRESHOLDS["batch_avg_accept_len"] - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/registered/spec/dflash/test_dflash_infer_beta.py b/test/registered/spec/dflash/test_dflash_infer_b.py similarity index 60% rename from test/registered/spec/dflash/test_dflash_infer_beta.py rename to test/registered/spec/dflash/test_dflash_infer_b.py index c5cea35a7e46..b30f551da3ac 100644 --- a/test/registered/spec/dflash/test_dflash_infer_beta.py +++ b/test/registered/spec/dflash/test_dflash_infer_b.py @@ -2,6 +2,8 @@ import unittest from types import SimpleNamespace +import openai + from sglang.srt.environ import envs from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci @@ -17,7 +19,7 @@ popen_launch_server, ) -register_cuda_ci(est_time=283, suite="stage-b-test-small-1-gpu") +register_cuda_ci(est_time=360, suite="stage-b-test-small-1-gpu") class TestDFlashServerBase(CustomTestCase, MatchedStopMixin): @@ -92,10 +94,66 @@ def test_gsm8k(self): self.assertGreater(metrics["accuracy"], 0.23) assert self.process.poll() is None + def test_early_stop(self): + client = openai.Client(base_url=self.base_url + "/v1", api_key="EMPTY") + for i in range(8): + max_tokens = (i % 3) + 1 + response = client.completions.create( + model=self.model, + prompt=f"There are {i} apples on the table. How to divide them equally?", + max_tokens=max_tokens, + temperature=0, + ) + text = response.choices[0].text + print(f"early_stop: max_tokens={max_tokens}, text={text!r}") + assert self.process.poll() is None + + def test_eos_handling(self): + client = openai.Client(base_url=self.base_url + "/v1", api_key="EMPTY") + response = client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": "Today is a sunny day and I like"}], + max_tokens=256, + temperature=0.1, + ) + text = response.choices[0].message.content + print(f"eos_handling: text={text!r}") + self.assertNotIn("<|eot_id|>", text) + self.assertNotIn("<|end_of_text|>", text) + assert self.process.poll() is None + + def test_greedy_determinism(self): + client = openai.Client(base_url=self.base_url + "/v1", api_key="EMPTY") + prompt = "The capital of France is" + outputs = [] + for _ in range(2): + response = client.completions.create( + model=self.model, + prompt=prompt, + max_tokens=32, + temperature=0, + ) + outputs.append(response.choices[0].text) + print(f"determinism: {outputs=}") + self.assertEqual(outputs[0], outputs[1]) + assert self.process.poll() is None + class TestDFlashServerPage(TestDFlashServerBase): page_size = 64 +class TestDFlashServerPage256(TestDFlashServerBase): + page_size = 256 + + +class TestDFlashServerChunkedPrefill(TestDFlashServerBase): + other_launch_args = ["--chunked-prefill-size", "4"] + + +class TestDFlashServerNoCudaGraph(TestDFlashServerBase): + other_launch_args = ["--disable-cuda-graph"] + + if __name__ == "__main__": unittest.main() From dab5401099b7c5ed4df89210f00f873893e0b72e Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Tue, 7 Apr 2026 00:55:28 -0700 Subject: [PATCH 63/73] use GSM8KMixin; rename to test_dflash.py --- ...{test_dflash_infer_b.py => test_dflash.py} | 21 +++---------------- 1 file changed, 3 insertions(+), 18 deletions(-) rename test/registered/spec/dflash/{test_dflash_infer_b.py => test_dflash.py} (88%) diff --git a/test/registered/spec/dflash/test_dflash_infer_b.py b/test/registered/spec/dflash/test_dflash.py similarity index 88% rename from test/registered/spec/dflash/test_dflash_infer_b.py rename to test/registered/spec/dflash/test_dflash.py index b30f551da3ac..96211c6ddba7 100644 --- a/test/registered/spec/dflash/test_dflash_infer_b.py +++ b/test/registered/spec/dflash/test_dflash.py @@ -1,13 +1,12 @@ import os import unittest -from types import SimpleNamespace import openai from sglang.srt.environ import envs from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.kits.gsm8k_accuracy_kit import GSM8KMixin from sglang.test.kits.matched_stop_kit import MatchedStopMixin from sglang.test.kits.radix_cache_server_kit import run_radix_attention_test from sglang.test.test_utils import ( @@ -22,13 +21,14 @@ register_cuda_ci(est_time=360, suite="stage-b-test-small-1-gpu") -class TestDFlashServerBase(CustomTestCase, MatchedStopMixin): +class TestDFlashServerBase(CustomTestCase, MatchedStopMixin, GSM8KMixin): max_running_requests = 64 attention_backend = "flashinfer" page_size = 1 other_launch_args = [] model = DEFAULT_TARGET_MODEL_DFLASH draft_model = DEFAULT_DRAFT_MODEL_DFLASH + gsm8k_accuracy_thres = 0.23 @classmethod def setUpClass(cls): @@ -79,21 +79,6 @@ def test_radix_attention(self): run_radix_attention_test(self.base_url) assert self.process.poll() is None - def test_gsm8k(self): - args = SimpleNamespace( - num_shots=5, - data_path=None, - num_questions=1000, - max_new_tokens=512, - parallel=128, - host="http://127.0.0.1", - port=int(self.base_url.split(":")[-1]), - ) - metrics = run_eval(args) - print(f"TestDFlashServerBase -- {metrics=}") - self.assertGreater(metrics["accuracy"], 0.23) - assert self.process.poll() is None - def test_early_stop(self): client = openai.Client(base_url=self.base_url + "/v1", api_key="EMPTY") for i in range(8): From fca00c789522e6bea359a04792921d840f1fbd85 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Tue, 7 Apr 2026 00:58:47 -0700 Subject: [PATCH 64/73] delete test_dflash_basic; add accept_length_thres; remove unused fixture --- .../test/server_fixtures/dflash_fixture.py | 40 ------------ test/registered/spec/dflash/test_dflash.py | 1 + .../spec/dflash/test_dflash_basic.py | 65 ------------------- 3 files changed, 1 insertion(+), 105 deletions(-) delete mode 100644 python/sglang/test/server_fixtures/dflash_fixture.py delete mode 100644 test/registered/spec/dflash/test_dflash_basic.py diff --git a/python/sglang/test/server_fixtures/dflash_fixture.py b/python/sglang/test/server_fixtures/dflash_fixture.py deleted file mode 100644 index 9ad1692fdfb1..000000000000 --- a/python/sglang/test/server_fixtures/dflash_fixture.py +++ /dev/null @@ -1,40 +0,0 @@ -from sglang.srt.environ import envs -from sglang.srt.utils.common import kill_process_tree -from sglang.test.test_utils import ( - DEFAULT_DRAFT_MODEL_DFLASH, - DEFAULT_TARGET_MODEL_DFLASH, - DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - DEFAULT_URL_FOR_TEST, - CustomTestCase, - popen_launch_server, -) - - -class DFlashServerBase(CustomTestCase): - target_model = DEFAULT_TARGET_MODEL_DFLASH - draft_model = DEFAULT_DRAFT_MODEL_DFLASH - spec_algo = "DFLASH" - spec_block_size = 16 - extra_args = [] - - @classmethod - def setUpClass(cls): - cls.base_url = DEFAULT_URL_FOR_TEST - with envs.SGLANG_SPEC_NAN_DETECTION.override( - True - ), envs.SGLANG_SPEC_OOB_DETECTION.override(True): - cls.process = popen_launch_server( - cls.target_model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=[ - f"--speculative-algorithm={cls.spec_algo}", - f"--speculative-draft-model-path={cls.draft_model}", - f"--speculative-num-draft-tokens={cls.spec_block_size}", - ] - + cls.extra_args, - ) - - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) diff --git a/test/registered/spec/dflash/test_dflash.py b/test/registered/spec/dflash/test_dflash.py index 96211c6ddba7..1a276d65abaf 100644 --- a/test/registered/spec/dflash/test_dflash.py +++ b/test/registered/spec/dflash/test_dflash.py @@ -29,6 +29,7 @@ class TestDFlashServerBase(CustomTestCase, MatchedStopMixin, GSM8KMixin): model = DEFAULT_TARGET_MODEL_DFLASH draft_model = DEFAULT_DRAFT_MODEL_DFLASH gsm8k_accuracy_thres = 0.23 + gsm8k_accept_length_thres = 3.15 @classmethod def setUpClass(cls): diff --git a/test/registered/spec/dflash/test_dflash_basic.py b/test/registered/spec/dflash/test_dflash_basic.py deleted file mode 100644 index 7c34c9402753..000000000000 --- a/test/registered/spec/dflash/test_dflash_basic.py +++ /dev/null @@ -1,65 +0,0 @@ -import os -import unittest -from types import SimpleNamespace - -import requests - -from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.run_eval import run_eval -from sglang.test.server_fixtures.dflash_fixture import DFlashServerBase -from sglang.test.test_utils import ( - DEFAULT_DRAFT_MODEL_DFLASH, - DEFAULT_TARGET_MODEL_DFLASH, -) - -register_cuda_ci(est_time=50, suite="stage-b-test-small-1-gpu") - - -class TestDFlashBasic(DFlashServerBase): - target_model = DEFAULT_TARGET_MODEL_DFLASH - draft_model = DEFAULT_DRAFT_MODEL_DFLASH - - spec_algo = "DFLASH" - spec_block_size = 16 - - extra_args = [ - "--dtype", - "float16", - "--chunked-prefill-size", - 1024, - ] - - @classmethod - def setUpClass(cls): - old_value = os.environ.get("SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN") - os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1" - try: - super().setUpClass() - finally: - if old_value is None: - del os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] - else: - os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = old_value - - def test_mmlu(self): - args = SimpleNamespace( - base_url=self.base_url, - model=self.target_model, - eval_name="mmlu", - num_examples=64, - num_threads=32, - ) - - metrics = run_eval(args) - self.assertGreaterEqual(metrics["score"], 0.72) - - server_info = requests.get(self.base_url + "/server_info").json() - avg_spec_accept_length = server_info["internal_states"][0][ - "avg_spec_accept_length" - ] - print(f"{avg_spec_accept_length=}") - self.assertGreater(avg_spec_accept_length, 3.15) - - -if __name__ == "__main__": - unittest.main() From 085bbb7a4c93e44d1394eb17a06973216cb5e8a6 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Tue, 7 Apr 2026 01:06:29 -0700 Subject: [PATCH 65/73] fix suite name: stage-b-test-1-gpu-small --- test/registered/spec/dflash/test_dflash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/registered/spec/dflash/test_dflash.py b/test/registered/spec/dflash/test_dflash.py index 1a276d65abaf..9069cd4d0a0c 100644 --- a/test/registered/spec/dflash/test_dflash.py +++ b/test/registered/spec/dflash/test_dflash.py @@ -18,7 +18,7 @@ popen_launch_server, ) -register_cuda_ci(est_time=360, suite="stage-b-test-small-1-gpu") +register_cuda_ci(est_time=360, suite="stage-b-test-1-gpu-small") class TestDFlashServerBase(CustomTestCase, MatchedStopMixin, GSM8KMixin): From 80a238a1b65a12f35820d8364b96e0bd3294f1d2 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Tue, 7 Apr 2026 01:09:57 -0700 Subject: [PATCH 66/73] fix GSM8KMixin import path --- test/registered/spec/dflash/test_dflash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/registered/spec/dflash/test_dflash.py b/test/registered/spec/dflash/test_dflash.py index 9069cd4d0a0c..370b4aa7a46d 100644 --- a/test/registered/spec/dflash/test_dflash.py +++ b/test/registered/spec/dflash/test_dflash.py @@ -6,7 +6,7 @@ from sglang.srt.environ import envs from sglang.srt.utils import kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci -from sglang.test.kits.gsm8k_accuracy_kit import GSM8KMixin +from sglang.test.kits.eval_accuracy_kit import GSM8KMixin from sglang.test.kits.matched_stop_kit import MatchedStopMixin from sglang.test.kits.radix_cache_server_kit import run_radix_attention_test from sglang.test.test_utils import ( From 97017f28f2f3d50792fe25b3785990c7f42c4511 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Tue, 7 Apr 2026 11:03:53 -0700 Subject: [PATCH 67/73] pass memory_pool_config to draft worker --- python/sglang/srt/speculative/dflash_worker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index b987f9cf395a..030aa21e5b35 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -142,6 +142,7 @@ def __init__( is_draft_worker=True, req_to_token_pool=shared_req_to_token_pool, token_to_kv_pool_allocator=target_token_to_kv_pool_allocator, + memory_pool_config=target_worker.model_runner.memory_pool_config, ) set_global_server_args_for_scheduler(saved_server_args) self.draft_model_runner = self.draft_worker.model_runner From 213bf695372bb16818a369df588fa2f67766bfa4 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Tue, 7 Apr 2026 11:10:33 -0700 Subject: [PATCH 68/73] set gsm8k accept_length_thres to 2.8 --- test/registered/spec/dflash/test_dflash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/registered/spec/dflash/test_dflash.py b/test/registered/spec/dflash/test_dflash.py index 370b4aa7a46d..3822521dfad1 100644 --- a/test/registered/spec/dflash/test_dflash.py +++ b/test/registered/spec/dflash/test_dflash.py @@ -29,7 +29,7 @@ class TestDFlashServerBase(CustomTestCase, MatchedStopMixin, GSM8KMixin): model = DEFAULT_TARGET_MODEL_DFLASH draft_model = DEFAULT_DRAFT_MODEL_DFLASH gsm8k_accuracy_thres = 0.23 - gsm8k_accept_length_thres = 3.15 + gsm8k_accept_length_thres = 2.8 @classmethod def setUpClass(cls): From 2b833bbd5a3787ba31d881c1008c09e0ebce50f2 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Tue, 7 Apr 2026 11:11:14 -0700 Subject: [PATCH 69/73] gsm8k thresholds: accuracy 0.75, accept_length 2.8 --- test/registered/spec/dflash/test_dflash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/registered/spec/dflash/test_dflash.py b/test/registered/spec/dflash/test_dflash.py index 3822521dfad1..7c555bc79b5f 100644 --- a/test/registered/spec/dflash/test_dflash.py +++ b/test/registered/spec/dflash/test_dflash.py @@ -28,7 +28,7 @@ class TestDFlashServerBase(CustomTestCase, MatchedStopMixin, GSM8KMixin): other_launch_args = [] model = DEFAULT_TARGET_MODEL_DFLASH draft_model = DEFAULT_DRAFT_MODEL_DFLASH - gsm8k_accuracy_thres = 0.23 + gsm8k_accuracy_thres = 0.75 gsm8k_accept_length_thres = 2.8 @classmethod From 16c61a6dfd9d1f9ff8497f14a9a662556a281a68 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Tue, 7 Apr 2026 11:30:24 -0700 Subject: [PATCH 70/73] reduce CI time: drop page64; radix test only on page256 with 50 nodes --- test/registered/spec/dflash/test_dflash.py | 25 ++++++++++++++-------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/test/registered/spec/dflash/test_dflash.py b/test/registered/spec/dflash/test_dflash.py index 7c555bc79b5f..6ef12d355d0b 100644 --- a/test/registered/spec/dflash/test_dflash.py +++ b/test/registered/spec/dflash/test_dflash.py @@ -8,7 +8,7 @@ from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.kits.eval_accuracy_kit import GSM8KMixin from sglang.test.kits.matched_stop_kit import MatchedStopMixin -from sglang.test.kits.radix_cache_server_kit import run_radix_attention_test +from sglang.test.kits.radix_cache_server_kit import gen_radix_tree from sglang.test.test_utils import ( DEFAULT_DRAFT_MODEL_DFLASH, DEFAULT_TARGET_MODEL_DFLASH, @@ -76,10 +76,6 @@ def setUpClass(cls): def tearDownClass(cls): kill_process_tree(cls.process.pid) - def test_radix_attention(self): - run_radix_attention_test(self.base_url) - assert self.process.poll() is None - def test_early_stop(self): client = openai.Client(base_url=self.base_url + "/v1", api_key="EMPTY") for i in range(8): @@ -125,13 +121,24 @@ def test_greedy_determinism(self): assert self.process.poll() is None -class TestDFlashServerPage(TestDFlashServerBase): - page_size = 64 - - class TestDFlashServerPage256(TestDFlashServerBase): page_size = 256 + def test_radix_attention(self): + import requests + + nodes = gen_radix_tree(num_nodes=50) + data = { + "input_ids": [node["input_ids"] for node in nodes], + "sampling_params": [ + {"max_new_tokens": node["decode_len"], "temperature": 0} + for node in nodes + ], + } + res = requests.post(self.base_url + "/generate", json=data) + assert res.status_code == 200 + assert self.process.poll() is None + class TestDFlashServerChunkedPrefill(TestDFlashServerBase): other_launch_args = ["--chunked-prefill-size", "4"] From ef360dd90b46f5ba542ddd4842cf7c701698fce4 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Tue, 7 Apr 2026 11:37:36 -0700 Subject: [PATCH 71/73] increase max_running_requests to 128 --- test/registered/spec/dflash/test_dflash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/registered/spec/dflash/test_dflash.py b/test/registered/spec/dflash/test_dflash.py index 6ef12d355d0b..87ac158b2eca 100644 --- a/test/registered/spec/dflash/test_dflash.py +++ b/test/registered/spec/dflash/test_dflash.py @@ -22,7 +22,7 @@ class TestDFlashServerBase(CustomTestCase, MatchedStopMixin, GSM8KMixin): - max_running_requests = 64 + max_running_requests = 128 attention_backend = "flashinfer" page_size = 1 other_launch_args = [] From a4a2504e8eaa4353ec7b2206ef89b4adc07fb82e Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Tue, 7 Apr 2026 11:42:33 -0700 Subject: [PATCH 72/73] add mem-fraction-static 0.65 for 5090 --- test/registered/spec/dflash/test_dflash.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/registered/spec/dflash/test_dflash.py b/test/registered/spec/dflash/test_dflash.py index 87ac158b2eca..8584f0b80719 100644 --- a/test/registered/spec/dflash/test_dflash.py +++ b/test/registered/spec/dflash/test_dflash.py @@ -46,6 +46,8 @@ def setUpClass(cls): str(cls.page_size), "--max-running-requests", str(cls.max_running_requests), + "--mem-fraction-static", + "0.65", "--cuda-graph-bs", *[str(i) for i in range(1, cls.max_running_requests + 1)], ] From f937b282910cd8a570d299f0d12cae8da3aeceb7 Mon Sep 17 00:00:00 2001 From: hnyls2002 Date: Tue, 7 Apr 2026 12:32:43 -0700 Subject: [PATCH 73/73] revert to max_running 64; est_time 300 --- test/registered/spec/dflash/test_dflash.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/registered/spec/dflash/test_dflash.py b/test/registered/spec/dflash/test_dflash.py index 8584f0b80719..aa9ee2327d21 100644 --- a/test/registered/spec/dflash/test_dflash.py +++ b/test/registered/spec/dflash/test_dflash.py @@ -18,11 +18,11 @@ popen_launch_server, ) -register_cuda_ci(est_time=360, suite="stage-b-test-1-gpu-small") +register_cuda_ci(est_time=300, suite="stage-b-test-1-gpu-small") class TestDFlashServerBase(CustomTestCase, MatchedStopMixin, GSM8KMixin): - max_running_requests = 128 + max_running_requests = 64 attention_backend = "flashinfer" page_size = 1 other_launch_args = [] @@ -46,8 +46,6 @@ def setUpClass(cls): str(cls.page_size), "--max-running-requests", str(cls.max_running_requests), - "--mem-fraction-static", - "0.65", "--cuda-graph-bs", *[str(i) for i in range(1, cls.max_running_requests + 1)], ]