diff --git a/benchmark/dflash/bench_dflash_gsm8k_sweep.py b/benchmark/dflash/bench_dflash_gsm8k_sweep.py new file mode 100644 index 000000000000..92f5494d233c --- /dev/null +++ b/benchmark/dflash/bench_dflash_gsm8k_sweep.py @@ -0,0 +1,802 @@ +"""DFLASH vs baseline GSM8K sweep. + +This is a *benchmark script* (not a CI test): it can take a long time because it +launches servers for multiple (attention_backend, tp_size) configs and runs a +GSM8K workload for each (concurrency, num_questions) setting. + +Example usage: + ./venv/bin/python benchmark/dflash/bench_dflash_gsm8k_sweep.py + ./venv/bin/python benchmark/dflash/bench_dflash_gsm8k_sweep.py --skip-baseline --concurrencies 32 --tp-sizes 8 +""" + +from __future__ import annotations + +import argparse +import ast +import os +import re +import statistics +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass +from typing import Optional + +import requests +import torch +from transformers import AutoTokenizer + +from sglang.srt.utils import get_device_sm, kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + find_available_port, + popen_launch_server, +) +from sglang.utils import download_and_cache_file, read_jsonl + +INVALID = -9999999 + + +def _parse_int_csv(value: str) -> list[int]: + return [int(x) for x in value.split(",") if x.strip()] + + +def _filter_attention_backends(backends: list[str], *, device_sm: int) -> list[str]: + if not (80 <= device_sm <= 90): + backends = [b for b in backends if b != "fa3"] + if device_sm < 100: + backends = [b for b in backends if b not in ("fa4", "trtllm_mha")] + return backends or ["flashinfer"] + + +def _get_answer_value(answer_str: str) -> int: + answer_str = answer_str.replace(",", "") + numbers = re.findall(r"\d+", answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +def 
_maybe_download_gsm8k(data_path: str) -> str: + url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + if os.path.isfile(data_path): + return data_path + return download_and_cache_file(url) + + +def _flush_cache(base_url: str) -> None: + resp = requests.get(base_url + "/flush_cache", timeout=60) + resp.raise_for_status() + + +def _send_generate( + base_url: str, + text: str | list[str], + *, + max_new_tokens: int, + temperature: float, + top_p: float, + top_k: int, + timeout_s: int, +) -> list[dict]: + if isinstance(text, list) and not text: + return [] + sampling_params: dict = { + "temperature": float(temperature), + "top_p": float(top_p), + "top_k": int(top_k), + "max_new_tokens": int(max_new_tokens), + } + resp = requests.post( + base_url + "/generate", + json={ + "text": text, + "sampling_params": sampling_params, + }, + timeout=int(timeout_s), + ) + resp.raise_for_status() + out = resp.json() + if isinstance(text, list): + if not isinstance(out, list): + raise RuntimeError( + "Expected a list response for batched /generate, but got " + f"type={type(out).__name__}." + ) + if len(out) != len(text): + raise RuntimeError( + "Batched /generate output length mismatch: " + f"got {len(out)} outputs for {len(text)} prompts." + ) + return out + + if isinstance(out, list): + raise RuntimeError( + "Expected an object response for single /generate, but got " + f"type={type(out).__name__}." 
+ ) + return [out] + + +@dataclass(frozen=True) +class BenchMetrics: + latency_s: float + output_tokens: int + output_toks_per_s: float + accuracy: Optional[float] + invalid_rate: Optional[float] + spec_accept_length: Optional[float] + spec_verify_ct_sum: int + + +def _run_gsm8k_requests( + base_url: str, + *, + prompts: list[str], + labels: Optional[list[int]], + max_new_tokens: int, + temperature: float, + top_p: float, + top_k: int, + concurrency: int, + batch_requests: bool, + timeout_s: int, + expect_dflash: bool, +) -> BenchMetrics: + if labels is not None and len(labels) != len(prompts): + raise ValueError("labels length must match prompts length") + + # Drop the first batch from metrics to exclude one-time JIT/cuda-graph overhead + # that often happens immediately after /flush_cache for large batch sizes. + bs = max(int(concurrency), 1) + if len(prompts) > bs: + warmup_prompts = prompts[:bs] + if batch_requests: + _send_generate( + base_url, + warmup_prompts, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + timeout_s=timeout_s, + ) + else: + with ThreadPoolExecutor(max_workers=int(concurrency)) as pool: + futures = [ + pool.submit( + _send_generate, + base_url=base_url, + text=prompt, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + timeout_s=timeout_s, + ) + for prompt in warmup_prompts + ] + for fut in as_completed(futures): + outs = fut.result() + if len(outs) != 1: + raise RuntimeError( + "Expected exactly one output for single /generate warmup request." 
+ ) + + prompts = prompts[bs:] + labels = labels[bs:] if labels is not None else None + + start = time.perf_counter() + total_tokens = 0 + spec_verify_ct_sum = 0 + spec_accept_lengths: list[float] = [] + correct = 0 + invalid = 0 + + def _handle_output(out: dict, label: Optional[int]) -> None: + nonlocal total_tokens, spec_verify_ct_sum, correct, invalid + meta = out.get("meta_info", {}) or {} + total_tokens += int(meta.get("completion_tokens", 0)) + spec_verify_ct_sum += int(meta.get("spec_verify_ct", 0)) + if "spec_accept_length" in meta: + try: + spec_accept_lengths.append(float(meta["spec_accept_length"])) + except (TypeError, ValueError): + pass + + if label is not None: + pred = _get_answer_value(out.get("text", "")) + if pred == INVALID: + invalid += 1 + if pred == label: + correct += 1 + + if batch_requests: + bs = max(int(concurrency), 1) + for start_idx in range(0, len(prompts), bs): + chunk_prompts = prompts[start_idx : start_idx + bs] + chunk_labels = ( + labels[start_idx : start_idx + bs] if labels is not None else None + ) + outs = _send_generate( + base_url, + chunk_prompts, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + timeout_s=timeout_s, + ) + if chunk_labels is None: + for out in outs: + _handle_output(out, None) + else: + for out, label in zip(outs, chunk_labels): + _handle_output(out, label) + else: + with ThreadPoolExecutor(max_workers=int(concurrency)) as pool: + futures = { + pool.submit( + _send_generate, + base_url=base_url, + text=prompt, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + timeout_s=timeout_s, + ): i + for i, prompt in enumerate(prompts) + } + for fut in as_completed(futures): + i = futures[fut] + outs = fut.result() + if len(outs) != 1: + raise RuntimeError( + "Expected exactly one output for single /generate request." 
+ ) + label = None if labels is None else labels[i] + _handle_output(outs[0], label) + + latency = time.perf_counter() - start + toks_per_s = total_tokens / max(latency, 1e-6) + + if expect_dflash and spec_verify_ct_sum <= 0: + raise RuntimeError( + "DFLASH sanity check failed: did not observe any `spec_verify_ct` in responses " + "(DFLASH may not have been enabled)." + ) + + spec_accept_length = ( + float(statistics.mean(spec_accept_lengths)) if spec_accept_lengths else None + ) + + if labels is None: + acc = None + invalid_rate = None + else: + acc = correct / max(len(prompts), 1) + invalid_rate = invalid / max(len(prompts), 1) + + return BenchMetrics( + latency_s=float(latency), + output_tokens=int(total_tokens), + output_toks_per_s=float(toks_per_s), + accuracy=acc, + invalid_rate=invalid_rate, + spec_accept_length=spec_accept_length, + spec_verify_ct_sum=int(spec_verify_ct_sum), + ) + + +def _format_table( + *, + tp_sizes: list[int], + concurrencies: list[int], + values: dict[tuple[int, int], Optional[float]], + float_fmt: str, +) -> str: + header = ["tp\\conc"] + [str(c) for c in concurrencies] + rows: list[list[str]] = [header] + for tp in tp_sizes: + row = [str(tp)] + for c in concurrencies: + v = values.get((tp, c), None) + row.append("N/A" if v is None else format(v, float_fmt)) + rows.append(row) + + col_widths = [ + max(len(row[col_idx]) for row in rows) for col_idx in range(len(rows[0])) + ] + + lines: list[str] = [] + lines.append(" ".join(cell.rjust(col_widths[i]) for i, cell in enumerate(rows[0]))) + lines.append(" ".join("-" * w for w in col_widths)) + for row in rows[1:]: + lines.append(" ".join(cell.rjust(col_widths[i]) for i, cell in enumerate(row))) + return "\n".join(lines) + + +def _build_common_server_args( + args: argparse.Namespace, *, backend: str, tp: int +) -> list[str]: + common_server_args: list[str] = [ + "--trust-remote-code", + "--attention-backend", + backend, + "--tp-size", + str(tp), + "--dtype", + str(args.dtype), + 
"--max-running-requests", + str(args.max_running_requests), + "--cuda-graph-max-bs", + "32", + "--mamba-scheduler-strategy", + str(args.mamba_scheduler_strategy), + ] + if args.mem_fraction_static is not None: + common_server_args.extend( + ["--mem-fraction-static", str(args.mem_fraction_static)] + ) + if args.disable_radix_cache: + common_server_args.append("--disable-radix-cache") + if args.page_size is not None: + common_server_args.extend(["--page-size", str(int(args.page_size))]) + return common_server_args + + +def _build_mode_runs( + args: argparse.Namespace, common_server_args: list[str] +) -> list[tuple[str, str, list[str], bool]]: + mode_runs: list[tuple[str, str, list[str], bool]] = [] + if not args.skip_baseline: + mode_runs.append(("baseline", "baseline", common_server_args, False)) + mode_runs.append( + ( + "dflash", + "DFLASH", + [ + *common_server_args, + "--speculative-algorithm", + "DFLASH", + "--speculative-draft-model-path", + args.draft_model, + *( + [ + "--speculative-dflash-draft-window-size", + str(int(args.speculative_dflash_draft_window_size)), + ] + if args.speculative_dflash_draft_window_size is not None + else [] + ), + *( + [ + "--speculative-draft-attention-backend", + args.speculative_draft_attention_backend, + ] + if args.speculative_draft_attention_backend + else [] + ), + ], + True, + ) + ) + return mode_runs + + +def _collect_metric( + *, + results: dict[tuple[str, int, int, str], BenchMetrics], + backend: str, + tp_sizes: list[int], + concurrencies: list[int], + mode: str, + field: str, +) -> dict[tuple[int, int], Optional[float]]: + out: dict[tuple[int, int], Optional[float]] = {} + for tp in tp_sizes: + for conc in concurrencies: + metrics = results.get((backend, tp, conc, mode), None) + out[(tp, conc)] = None if metrics is None else getattr(metrics, field) + return out + + +def _compute_speedup( + baseline: dict[tuple[int, int], Optional[float]], + dflash: dict[tuple[int, int], Optional[float]], +) -> dict[tuple[int, int], 
Optional[float]]: + return { + key: None if (b is None or d is None or b <= 0) else (d / b) + for key, b in baseline.items() + for d in [dflash.get(key, None)] + } + + +def _print_kv_lines(items: list[tuple[str, object]]) -> None: + for key, value in items: + print(f"{key}={value}") + + +def _run_mode_for_backend_tp( + *, + mode_label: str, + model_path: str, + base_url: str, + server_args: list[str], + expect_dflash: bool, + prompts: list[str], + labels: list[int], + concurrencies: list[int], + num_questions_by_conc: dict[int, int], + args: argparse.Namespace, +) -> dict[int, BenchMetrics]: + print(f"\n=== {mode_label} ===") + server_start_timeout_s = int(max(DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, args.timeout_s)) + proc = popen_launch_server( + model_path, + base_url, + timeout=server_start_timeout_s, + other_args=server_args, + ) + try: + _send_generate( + base_url, + "Hello", + max_new_tokens=8, + temperature=float(args.temperature), + top_p=float(args.top_p), + top_k=int(args.top_k), + timeout_s=min(int(args.timeout_s), 300), + ) + + metrics_by_conc: dict[int, BenchMetrics] = {} + for conc in concurrencies: + n = num_questions_by_conc[conc] + _flush_cache(base_url) + print( + f"[warmup] run 1 warmup batch (size={conc}) after /flush_cache; excluded from metrics." 
+ ) + metrics = _run_gsm8k_requests( + base_url, + prompts=prompts[: n + conc], + labels=labels[: n + conc], + max_new_tokens=int(args.max_new_tokens), + temperature=float(args.temperature), + top_p=float(args.top_p), + top_k=int(args.top_k), + concurrency=int(conc), + batch_requests=bool(args.batch_requests), + timeout_s=int(args.timeout_s), + expect_dflash=expect_dflash, + ) + metrics_by_conc[conc] = metrics + line = ( + f"[{mode_label}] conc={conc:>2} n={n:<4} " + f"toks/s={metrics.output_toks_per_s:,.2f} " + f"latency={metrics.latency_s:.1f}s " + f"acc={metrics.accuracy:.3f} invalid={metrics.invalid_rate:.3f}" + ) + if expect_dflash: + accept_len = ( + "N/A" + if metrics.spec_accept_length is None + else f"{metrics.spec_accept_length:.3f}" + ) + line += ( + f" accept_len={accept_len} " + f"spec_verify_ct_sum={metrics.spec_verify_ct_sum}" + ) + print(line) + return metrics_by_conc + finally: + kill_process_tree(proc.pid) + try: + proc.wait(timeout=30) + except Exception: + pass + + +def _print_summary( + *, + args: argparse.Namespace, + attention_backends: list[str], + tp_sizes: list[int], + concurrencies: list[int], + device_sm: int, + results: dict[tuple[str, int, int, str], BenchMetrics], +) -> None: + print("\n=== DFLASH GSM8K Sweep Summary ===") + _print_kv_lines( + [ + ("target_model", args.target_model), + ("draft_model", args.draft_model), + ("max_new_tokens", args.max_new_tokens), + ( + "sampling", + f"temperature:{args.temperature}, top_p:{args.top_p}, top_k:{args.top_k}", + ), + ("attention_backends", ",".join(attention_backends)), + ( + "speculative_draft_attention_backend", + args.speculative_draft_attention_backend, + ), + ( + "speculative_dflash_draft_window_size", + args.speculative_dflash_draft_window_size, + ), + ("tp_sizes", ",".join(str(x) for x in tp_sizes)), + ("concurrencies", ",".join(str(x) for x in concurrencies)), + ( + "questions_per_concurrency_base", + args.questions_per_concurrency_base, + ), + ("device_sm", device_sm), + 
("skip_baseline", bool(args.skip_baseline)), + ] + ) + + section_fields = [ + ("Baseline output tok/s", "baseline", "output_toks_per_s", ",.2f"), + ("Baseline accuracy", "baseline", "accuracy", ".3f"), + ("DFLASH output tok/s", "dflash", "output_toks_per_s", ",.2f"), + ("DFLASH accuracy", "dflash", "accuracy", ".3f"), + ( + "DFLASH acceptance length (mean spec_accept_length)", + "dflash", + "spec_accept_length", + ".3f", + ), + ] + + for backend in attention_backends: + print(f"\n=== Backend: {backend} ===") + metrics_map = { + (mode, field): _collect_metric( + results=results, + backend=backend, + tp_sizes=tp_sizes, + concurrencies=concurrencies, + mode=mode, + field=field, + ) + for _, mode, field, _ in section_fields + } + sections: list[tuple[str, dict[tuple[int, int], Optional[float]], str]] = [ + (title, metrics_map[(mode, field)], fmt) + for title, mode, field, fmt in section_fields + ] + sections.insert( + 4, + ( + "Speedup (DFLASH / baseline)", + _compute_speedup( + metrics_map[("baseline", "output_toks_per_s")], + metrics_map[("dflash", "output_toks_per_s")], + ), + ".3f", + ), + ) + + for title, values, fmt in sections: + print(f"\n{title}") + print( + _format_table( + tp_sizes=tp_sizes, + concurrencies=concurrencies, + values=values, + float_fmt=fmt, + ) + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser() + parser.add_argument("--data-path", default="test.jsonl") + parser.add_argument("--target-model", default="Qwen/Qwen3-8B") + parser.add_argument("--draft-model", default="z-lab/Qwen3-8B-DFlash-b16") + parser.add_argument( + "--skip-baseline", + action="store_true", + help="Skip running the baseline (target-only) sweep; only run DFLASH and report N/A for baseline/speedup.", + ) + parser.add_argument( + "--batch-requests", + action="store_true", + help="Send prompts as server-side batched /generate requests (batch size = concurrency) instead of client-side concurrent requests.", + ) + 
parser.add_argument("--max-new-tokens", type=int, default=2048) + parser.add_argument("--temperature", type=float, default=0.0) + parser.add_argument("--top-p", type=float, default=1.0) + parser.add_argument("--top-k", type=int, default=1) + parser.add_argument( + "--timeout-s", + type=int, + default=3600, + help=( + "Timeout in seconds for benchmarked /generate calls and server startup " + "health checks." + ), + ) + parser.add_argument( + "--mem-fraction-static", + type=float, + default=None, + help="Optional server --mem-fraction-static override. If unset, use the server auto heuristic.", + ) + parser.add_argument("--disable-radix-cache", action="store_true") + parser.add_argument("--dtype", default="bfloat16") + parser.add_argument( + "--page-size", + type=int, + default=None, + help="Optional server --page-size override for both baseline and DFLASH runs.", + ) + parser.add_argument("--max-running-requests", type=int, default=32) + parser.add_argument( + "--mamba-scheduler-strategy", + default="no_buffer", + help=( + "Server --mamba-scheduler-strategy value to pass through to benchmark " + "runs, e.g. `no_buffer` or `extra_buffer`." 
+ ), + ) + parser.add_argument("--tp-sizes", default="1,2,4,8") + parser.add_argument("--concurrencies", default="1,2,4,8,16,32") + parser.add_argument( + "--questions-per-concurrency-base", + type=int, + default=128, + help="num_questions = base * concurrency (default matches the sweep plan).", + ) + parser.add_argument( + "--max-questions-per-config", + type=int, + default=1024, + help="Cap num_questions per (tp, concurrency) run (default: 1024).", + ) + parser.add_argument("--attention-backends", default="flashinfer,fa3,trtllm_mha,fa4") + parser.add_argument( + "--speculative-draft-attention-backend", + default=None, + help="Optional server --speculative-draft-attention-backend override for DFLASH runs.", + ) + parser.add_argument( + "--speculative-dflash-draft-window-size", + type=int, + default=None, + help="Optional server --speculative-dflash-draft-window-size override for DFLASH runs.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for this sweep.") + if args.temperature < 0.0: + raise RuntimeError(f"--temperature must be >= 0, got {args.temperature}.") + if not (0.0 < args.top_p <= 1.0): + raise RuntimeError(f"--top-p must be in (0, 1], got {args.top_p}.") + if args.top_k == 0 or args.top_k < -1: + raise RuntimeError(f"--top-k must be -1 (all vocab) or >= 1, got {args.top_k}.") + if args.timeout_s <= 0: + raise RuntimeError(f"--timeout-s must be > 0, got {args.timeout_s}.") + + visible_gpus = int(torch.cuda.device_count()) + tp_sizes = _parse_int_csv(args.tp_sizes) + tp_sizes = [tp for tp in tp_sizes if tp >= 1 and tp <= visible_gpus] + if not tp_sizes: + raise RuntimeError( + f"No tp sizes are runnable with visible_gpus={visible_gpus}. " + "Set CUDA_VISIBLE_DEVICES accordingly." 
+ ) + + concurrencies = _parse_int_csv(args.concurrencies) + concurrencies = [c for c in concurrencies if c >= 1] + if not concurrencies: + raise RuntimeError("No concurrencies specified.") + + num_questions_by_conc = { + c: min( + int(args.questions_per_concurrency_base) * int(c), + int(args.max_questions_per_config), + ) + for c in concurrencies + } + max_questions = max(num_questions_by_conc.values()) + + attention_backends = [ + s.strip() for s in args.attention_backends.split(",") if s.strip() + ] + device_sm = get_device_sm() + attention_backends = _filter_attention_backends( + attention_backends, device_sm=device_sm + ) + + data_path = _maybe_download_gsm8k(args.data_path) + lines = list(read_jsonl(data_path)) + if len(lines) < max_questions: + raise RuntimeError( + f"GSM8K file only has {len(lines)} lines, but need {max_questions}." + ) + + tokenizer = AutoTokenizer.from_pretrained(args.target_model) + + prompts: list[str] = [] + labels: list[int] = [] + for i in range(max_questions): + user_content = ( + lines[i]["question"] + + "\nPlease reason step by step, and put your final answer within \\boxed{}." + ) + prompts.append( + tokenizer.apply_chat_template( + [{"role": "user", "content": user_content}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + ) + labels.append(_get_answer_value(lines[i]["answer"])) + if not all(label != INVALID for label in labels): + raise RuntimeError("Invalid labels in GSM8K data.") + + # Results indexed by (backend, tp, concurrency, mode). + results: dict[tuple[str, int, int, str], BenchMetrics] = {} + # Baseline metrics are backend-agnostic in this sweep; run once per TP and reuse. 
+ baseline_cache_by_tp: dict[int, dict[int, BenchMetrics]] = {} + + for backend_idx, backend in enumerate(attention_backends): + for tp in tp_sizes: + port_base = find_available_port(20000) + common_server_args = _build_common_server_args(args, backend=backend, tp=tp) + mode_runs = _build_mode_runs(args, common_server_args) + + for idx, ( + mode_key, + mode_name, + mode_server_args, + expect_dflash, + ) in enumerate(mode_runs): + if ( + mode_key == "baseline" + and not args.skip_baseline + and backend_idx > 0 + and tp in baseline_cache_by_tp + ): + mode_metrics = baseline_cache_by_tp[tp] + else: + mode_metrics = _run_mode_for_backend_tp( + mode_label=f"backend={backend} tp={tp} ({mode_name})", + model_path=args.target_model, + base_url=f"http://127.0.0.1:{find_available_port(port_base + idx)}", + server_args=mode_server_args, + expect_dflash=expect_dflash, + prompts=prompts, + labels=labels, + concurrencies=concurrencies, + num_questions_by_conc=num_questions_by_conc, + args=args, + ) + if mode_key == "baseline" and not args.skip_baseline: + baseline_cache_by_tp[tp] = mode_metrics + + for conc, metrics in mode_metrics.items(): + results[(backend, tp, conc, mode_key)] = metrics + + _print_summary( + args=args, + attention_backends=attention_backends, + tp_sizes=tp_sizes, + concurrencies=concurrencies, + device_sm=device_sm, + results=results, + ) + + +if __name__ == "__main__": + main() diff --git a/python/sglang/srt/environ.py b/python/sglang/srt/environ.py index 5399fd2a4289..48100794d754 100644 --- a/python/sglang/srt/environ.py +++ b/python/sglang/srt/environ.py @@ -420,6 +420,7 @@ class Envs: # Overlap Spec V2 SGLANG_ENABLE_SPEC_V2 = EnvBool(False) SGLANG_ENABLE_OVERLAP_PLAN_STREAM = EnvBool(False) + SGLANG_ENABLE_DFLASH_SPEC_V2 = EnvBool(False) # Spec Config SGLANG_SPEC_ENABLE_STRICT_FILTER_CHECK = EnvBool(True) diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 
c511272544f9..8f8f846e3989 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -591,8 +591,24 @@ def init_forward_metadata_capture_cuda_graph( fast_decode_plan, decode_wrappers[i] ) elif forward_mode.is_target_verify(): + # FlashInfer's prefill wrapper decides mask mode based on whether + # `custom_mask_buf` is initialized (not whether a custom mask is provided). + # For cases like DFLASH draft (ENCODER_ONLY / non-causal) we do NOT use a + # custom mask, so we must avoid initializing `custom_mask_buf`, otherwise + # FlashInfer will treat the (zero) buffer as a real mask and block attention. + use_custom_mask = ( + spec_info is not None + and getattr(spec_info, "custom_mask", None) is not None + ) prefill_wrappers = [] for i in range(self.num_wrappers): + wrapper_kwargs = {} + if use_custom_mask: + wrapper_kwargs = { + "custom_mask_buf": self.cuda_graph_custom_mask, + "mask_indptr_buf": self.cuda_graph_qk_indptr[i][: bs + 1], + } + prefill_wrappers.append( BatchPrefillWithPagedKVCacheWrapper( self.workspace_buffer, @@ -603,8 +619,7 @@ def init_forward_metadata_capture_cuda_graph( paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1], paged_kv_indices_buf=self.cuda_graph_kv_indices[i], paged_kv_last_page_len_buf=self.kv_last_page_len[:bs], - custom_mask_buf=self.cuda_graph_custom_mask, - mask_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1], + **wrapper_kwargs, ) ) seq_lens_sum = seq_lens.sum().item() @@ -777,10 +792,14 @@ def forward_extend( layer, cache_loc, k, v, layer.k_scale, layer.v_scale ) + causal = ( + not layer.is_cross_attention + and layer.attn_type != AttentionType.ENCODER_ONLY + ) o = prefill_wrapper_paged.forward( q.view(-1, layer.tp_q_head_num, layer.head_dim), forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), - causal=not layer.is_cross_attention, + causal=causal, sm_scale=layer.scaling, # Disable sliding window attention for multi-item scoring: # - Sliding 
window could cut across item boundaries, breaking semantic coherence @@ -832,11 +851,6 @@ def forward_extend( ) else: - if not self.is_dllm_model: - # TODO: design a better interface - # For other models, use causal attention for the ragged part as previously - causal = True - o1, s1 = self.prefill_wrapper_ragged.forward_return_lse( q.view(-1, layer.tp_q_head_num, layer.head_dim), k.view(-1, layer.tp_k_head_num, layer.head_dim), diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 5597d8baea43..089afe06683a 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -2036,10 +2036,20 @@ def prepare_for_decode(self): ) def maybe_wait_verify_done(self): - if self.is_spec_v2: - draft_input: EagleDraftInput = self.spec_info - if draft_input.verify_done is not None: - draft_input.verify_done.synchronize() + if not self.is_spec_v2: + return + + draft_input: EagleDraftInput = self.spec_info + verify_done = getattr(draft_input, "verify_done", None) + if verify_done is None: + return + + if envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get(): + torch.get_device_module(self.device).current_stream().wait_event( + verify_done + ) + else: + verify_done.synchronize() def filter_batch( self, diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index ec3a254adf98..2e97972ae8a7 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -248,6 +248,27 @@ def copy_to_cpu(self): self.copy_done.record() +def validate_dflash_request(req: Req, enable_overlap: bool) -> Optional[str]: + if req.return_logprob: + return "DFLASH speculative decoding does not support return_logprob yet." + + if enable_overlap and req.return_hidden_states: + return "DFLASH speculative decoding does not support return_hidden_states yet." 
+ + if ( + req.sampling_params.json_schema is not None + or req.sampling_params.regex is not None + or req.sampling_params.ebnf is not None + or req.sampling_params.structural_tag is not None + ): + return ( + "DFLASH speculative decoding does not support " + "grammar-constrained decoding yet." + ) + + return None + + class Scheduler( SchedulerOutputProcessorMixin, SchedulerUpdateWeightsMixin, @@ -1633,6 +1654,13 @@ def handle_generate_request( self._add_request_to_queue(req) return + if self.spec_algorithm.is_dflash(): + error_msg = validate_dflash_request(req, self.enable_overlap) + if error_msg is not None: + req.set_finish_with_abort(error_msg) + self.init_req_max_new_tokens(req) + self._add_request_to_queue(req) + return # Handle multimodal inputs if recv_req.mm_inputs is not None: image_inputs = self._get_multimodal_inputs(recv_req.mm_inputs) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 9ff800a9f65d..ea3d4d21322d 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -472,18 +472,15 @@ def __init__(self, model_runner: ModelRunner): self.capture_forward_mode = ForwardMode.DECODE self.capture_hidden_mode = CaptureHiddenMode.NULL self.num_tokens_per_bs = 1 - if ( - model_runner.spec_algorithm.is_eagle() - or model_runner.spec_algorithm.is_standalone() - or model_runner.spec_algorithm.is_ngram() - ): + if model_runner.spec_algorithm.is_speculative(): if self.model_runner.is_draft_worker: - raise RuntimeError("This should not happen") - else: - self.capture_forward_mode = ForwardMode.TARGET_VERIFY - self.num_tokens_per_bs = ( - self.model_runner.server_args.speculative_num_draft_tokens - ) + # DFLASH draft workers reuse this runner for TARGET_VERIFY mode. 
+ if not self.model_runner.spec_algorithm.is_dflash(): + raise RuntimeError("This should not happen") + self.capture_forward_mode = ForwardMode.TARGET_VERIFY + self.num_tokens_per_bs = ( + self.model_runner.server_args.speculative_num_draft_tokens + ) elif self.is_dllm: self.capture_forward_mode = ForwardMode.DLLM_EXTEND self.num_tokens_per_bs = self.dllm_config.block_size @@ -560,6 +557,18 @@ def __init__(self, model_runner: ModelRunner): and model_runner.eagle_use_aux_hidden_state ): self.model_runner.model.set_eagle3_layers_to_capture() + if ( + model_runner.spec_algorithm.is_dflash() + and model_runner.dflash_use_aux_hidden_state + ): + if not hasattr(self.model_runner.model, "set_dflash_layers_to_capture"): + raise ValueError( + f"Model {self.model_runner.model.__class__.__name__} does not implement set_dflash_layers_to_capture, " + "which is required for DFLASH aux hidden capture." + ) + self.model_runner.model.set_dflash_layers_to_capture( + self.model_runner.dflash_target_layer_ids + ) # Capture try: @@ -585,6 +594,7 @@ def can_run(self, forward_batch: ForwardBatch): max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() or self.model_runner.spec_algorithm.is_standalone() + or self.model_runner.spec_algorithm.is_dflash() else max(forward_batch.global_num_tokens_cpu) ) else: @@ -912,6 +922,12 @@ def run_once(): kwargs["pp_proxy_tensors"] = PPProxyTensors( {k: v.clone() for k, v in pp_proxy_tensors.tensors.items()} ) + if ( + self.model_runner.spec_algorithm.is_dflash() + and self.model_runner.is_draft_worker + and "input_embeds" in inspect.signature(forward).parameters + ): + kwargs["input_embeds"] = buffers.input_embeds[:num_tokens] logits_output_or_pp_proxy_tensors = forward( input_ids, @@ -988,6 +1004,7 @@ def replay_prepare( max_num_tokens / self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() or self.model_runner.spec_algorithm.is_standalone() + or 
self.model_runner.spec_algorithm.is_dflash() else max_num_tokens ) index = bisect.bisect_left(self.capture_bs, max_batch_size) @@ -1009,6 +1026,13 @@ def replay_prepare( ), pp_proxy_tensors=pp_proxy_tensors, ) + if ( + self.model_runner.spec_algorithm.is_dflash() + and self.model_runner.is_draft_worker + and forward_batch.input_embeds is not None + ): + buffers.input_embeds[:raw_num_token].copy_(forward_batch.input_embeds) + # Padded tokens aren't read, so skip zeroing them. if self.enable_two_batch_overlap: self.tbo_plugin.replay_prepare( forward_mode=self.capture_forward_mode, @@ -1054,6 +1078,14 @@ def replay( # In speculative decoding, these two fields are still needed. self.buffers.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids) self.buffers.positions[: self.raw_num_token].copy_(forward_batch.positions) + if ( + self.model_runner.spec_algorithm.is_dflash() + and self.model_runner.is_draft_worker + and forward_batch.input_embeds is not None + ): + self.buffers.input_embeds[: self.raw_num_token].copy_( + forward_batch.input_embeds + ) # Replay if self.enable_pdmux: @@ -1066,10 +1098,18 @@ def replay( if isinstance(output, LogitsProcessorOutput): if self.is_dllm: next_token_logits = None - full_logits = output.full_logits[: self.raw_num_token] + full_logits = ( + output.full_logits[: self.raw_num_token] + if output.full_logits is not None + else None + ) else: full_logits = None - next_token_logits = output.next_token_logits[: self.raw_num_token] + next_token_logits = ( + output.next_token_logits[: self.raw_num_token] + if output.next_token_logits is not None + else None + ) return LogitsProcessorOutput( next_token_logits=next_token_logits, @@ -1111,6 +1151,32 @@ def get_spec_info(self, num_tokens: int): seq_lens_sum=None, seq_lens_cpu=None, ) + elif self.model_runner.spec_algorithm.is_dflash(): + from sglang.srt.speculative.dflash_info import DFlashVerifyInput + from sglang.srt.speculative.dflash_utils import ( + resolve_dflash_verify_mask_policy, 
+ ) + + # Avoid enabling custom-mask modes during graph capture for backends that + # can express DFLASH verify via their built-in causal path. + _, build_custom_mask = resolve_dflash_verify_mask_policy( + self.model_runner.attn_backend + ) + spec_info = DFlashVerifyInput( + draft_token=None, + positions=None, + draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens, + custom_mask=( + None + if (self.model_runner.is_draft_worker or not build_custom_mask) + else self.buffers.custom_mask + ), + capture_hidden_mode=( + CaptureHiddenMode.NULL + if self.model_runner.is_draft_worker + else CaptureHiddenMode.FULL + ), + ) elif self.model_runner.spec_algorithm.is_ngram(): from sglang.srt.speculative.ngram_info import NgramVerifyInput diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 9b615cce8499..1b06dc56fa2a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -150,6 +150,9 @@ get_global_server_args, set_global_server_args_for_scheduler, ) +from sglang.srt.speculative.dflash_utils import ( + parse_dflash_draft_config, +) from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.utils import ( MultiprocessingSerializer, @@ -345,6 +348,9 @@ def __init__( self.remote_instance_transfer_engine_weight_info = None # auxiliary hidden capture mode. TODO: expose this to server args? 
self.eagle_use_aux_hidden_state = False + self.dflash_use_aux_hidden_state = False + self.dflash_target_layer_ids = None + self.dflash_draft_num_layers = None if self.spec_algorithm.is_eagle3() and not self.is_draft_worker: # load draft config draft_model_config = ModelConfig.from_server_args( @@ -370,6 +376,48 @@ def __init__( # if there is no aux layer, set to None self.eagle_aux_hidden_state_layer_ids = None + if self.spec_algorithm.is_dflash() and not self.is_draft_worker: + # Select target layers to capture for building DFlash context features. + draft_model_config = ModelConfig.from_server_args( + server_args, + model_path=(server_args.speculative_draft_model_path), + model_revision=server_args.speculative_draft_model_revision, + is_draft_model=True, + ) + dflash_draft_config = parse_dflash_draft_config( + draft_hf_config=draft_model_config.hf_config + ) + draft_num_layers = dflash_draft_config.require_num_layers() + trained_target_layers = dflash_draft_config.num_target_layers + + target_num_layers = getattr( + self.model_config.hf_text_config, "num_hidden_layers", None + ) + if target_num_layers is None: + raise ValueError( + "DFLASH requires target num_hidden_layers in config. " + f"Got target={target_num_layers}." 
+ ) + target_num_layers = int(target_num_layers) + + if ( + trained_target_layers is not None + and trained_target_layers != target_num_layers + ): + logger.warning( + "DFLASH draft config num_target_layers=%s differs from runtime target num_hidden_layers=%s; " + "selecting capture layers based on the runtime target model.", + trained_target_layers, + target_num_layers, + ) + + self.dflash_use_aux_hidden_state = True + self.dflash_draft_num_layers = int(draft_num_layers) + self.dflash_target_layer_ids = dflash_draft_config.resolve_target_layer_ids( + target_num_layers=int(target_num_layers), + draft_num_layers=int(draft_num_layers), + ) + # Apply the rank zero filter to logger if server_args.show_time_cost: enable_show_time_cost() @@ -636,6 +684,14 @@ def initialize(self, pre_model_load_memory: float): self.eagle_aux_hidden_state_layer_ids ) + if self.dflash_use_aux_hidden_state: + if not hasattr(self.model, "set_dflash_layers_to_capture"): + raise ValueError( + f"Model {self.model.__class__.__name__} does not implement set_dflash_layers_to_capture, " + "which is required for DFLASH." 
+ ) + self.model.set_dflash_layers_to_capture(self.dflash_target_layer_ids) + # Initialize piecewise CUDA graph self.init_piecewise_cuda_graphs() @@ -1843,11 +1899,7 @@ def _should_run_flashinfer_autotune(self) -> bool: if major < 9: return False - if ( - self.spec_algorithm.is_eagle() - or self.spec_algorithm.is_standalone() - or self.spec_algorithm.is_ngram() - ): + if self.spec_algorithm.is_speculative(): return not self.is_draft_worker return True @@ -1877,16 +1929,12 @@ def _dummy_run(self, batch_size: int, run_ctx=None): capture_forward_mode = ForwardMode.EXTEND capture_hidden_mode = CaptureHiddenMode.NULL num_tokens_per_bs = 1 - if ( - self.spec_algorithm.is_eagle() - or self.spec_algorithm.is_standalone() - or self.spec_algorithm.is_ngram() - ): + if self.spec_algorithm.is_speculative(): if self.is_draft_worker: - raise RuntimeError("This should not happen") - else: - capture_forward_mode = ForwardMode.TARGET_VERIFY - num_tokens_per_bs = self.server_args.speculative_num_draft_tokens + if not self.spec_algorithm.is_dflash(): + raise RuntimeError("This should not happen") + capture_forward_mode = ForwardMode.TARGET_VERIFY + num_tokens_per_bs = self.server_args.speculative_num_draft_tokens if self.server_args.enable_return_hidden_states: capture_hidden_mode = CaptureHiddenMode.FULL @@ -1906,6 +1954,8 @@ def _dummy_run(self, batch_size: int, run_ctx=None): if self.eagle_use_aux_hidden_state: self.model.set_eagle3_layers_to_capture() + if self.dflash_use_aux_hidden_state: + self.model.set_dflash_layers_to_capture(self.dflash_target_layer_ids) require_mlp_tp_gather_ = require_mlp_tp_gather(self.server_args) if require_gathered_buffer(self.server_args): @@ -2015,6 +2065,21 @@ def get_spec_info(): seq_lens_sum=None, seq_lens_cpu=None, ) + elif self.spec_algorithm.is_dflash(): + from sglang.srt.speculative.dflash_info import DFlashVerifyInput + + # Dummy warmup only needs shape metadata; avoid forcing custom-mask mode. 
+ spec_info = DFlashVerifyInput( + draft_token=None, + positions=None, + draft_token_num=self.server_args.speculative_num_draft_tokens, + custom_mask=None, + capture_hidden_mode=( + CaptureHiddenMode.NULL + if self.is_draft_worker + else CaptureHiddenMode.FULL + ), + ) elif self.spec_algorithm.is_ngram(): from sglang.srt.speculative.ngram_info import NgramVerifyInput diff --git a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py index 32cf5d7e8846..1b836810dba5 100644 --- a/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py +++ b/python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py @@ -24,6 +24,11 @@ ReqToTokenPool, ) from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool, SWATokenToKVPoolAllocator +from sglang.srt.speculative.dflash_utils import ( + resolve_dflash_auto_memory_plan, + resolve_dflash_max_mamba_cache_size, + scale_kv_cell_size_per_token_for_dflash, +) from sglang.srt.utils.common import ( get_available_gpu_memory, is_float4_e2m1fn_x2, @@ -139,10 +144,56 @@ def profile_max_num_token(self: ModelRunner, pre_model_load_memory: int): num_layers = self.num_effective_layers cell_size = self.get_cell_size_per_token(num_layers) + if self.spec_algorithm.is_dflash() and not self.is_draft_worker: + draft_num_layers = getattr(self, "dflash_draft_num_layers", None) + if ( + draft_num_layers is not None + and int(draft_num_layers) > 0 + and int(num_layers) > 0 + ): + cell_size = scale_kv_cell_size_per_token_for_dflash( + target_cell_size_per_token=cell_size, + target_num_layers=int(num_layers), + draft_num_layers=int(draft_num_layers), + ) rest_memory = post_model_load_memory - pre_model_load_memory * ( 1 - self.mem_fraction_static ) + if ( + getattr(self.server_args, "_auto_mem_fraction_static", False) + and self.spec_algorithm.is_dflash() + and (config := self.mambaish_config) is not None + and self.server_args.max_running_requests is not None + ): + 
max_running_requests = self.server_args.max_running_requests // self.dp_size + if max_running_requests > 0: + memory_plan = resolve_dflash_auto_memory_plan( + rest_memory_gb=rest_memory, + post_model_load_memory_gb=post_model_load_memory, + cell_size=cell_size, + max_running_requests=max_running_requests, + mamba_cache_per_req=config.mamba2_cache_params.mamba_cache_per_req, + speculative_num_draft_tokens=int( + self.server_args.speculative_num_draft_tokens or 0 + ), + chunked_prefill_size=self.server_args.chunked_prefill_size, + max_prefill_tokens=int(self.server_args.max_prefill_tokens), + page_size=int(self.server_args.page_size), + mamba_ratio=self._calculate_mamba_ratio(), + explicit_max_mamba_cache_size=self.server_args.max_mamba_cache_size, + ) + if memory_plan.required_rest_memory_gb > float(rest_memory): + logger.info( + "Raise effective DFLASH rest-memory budget from %.2f GB to %.2f GB " + "(max_running_requests=%d, max_mamba_cache_size=%d, min_required_tokens=%d).", + rest_memory, + memory_plan.required_rest_memory_gb, + max_running_requests, + memory_plan.max_mamba_cache_size, + memory_plan.min_required_tokens, + ) + rest_memory = memory_plan.required_rest_memory_gb if self.mambaish_config is not None: rest_memory = self.handle_max_mamba_cache(rest_memory) @@ -175,6 +226,17 @@ def handle_max_mamba_cache(self: ModelRunner, total_rest_memory): server_args.max_mamba_cache_size = server_args.max_mamba_cache_size // ( server_args.dp_size if server_args.enable_dp_attention else 1 ) + elif ( + self.spec_algorithm.is_dflash() + and server_args.max_running_requests is not None + ): + # DFLASH hybrid runs should reserve resident mamba cache directly + # from the requested concurrency so the later request clamp becomes + # a safety backstop instead of the normal path. 
+ server_args.max_mamba_cache_size = resolve_dflash_max_mamba_cache_size( + max_running_requests=server_args.max_running_requests // self.dp_size, + mamba_ratio=self._calculate_mamba_ratio(), + ) elif ( server_args.disable_radix_cache and server_args.max_running_requests is not None diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index e9a4f511bfbb..fc69f193d9e3 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -2230,6 +2230,18 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None): # of the (i-1)th layer as aux hidden state self.model.layers_to_capture = [val + 1 for val in layer_ids] + def set_dflash_layers_to_capture(self, layer_ids: List[int]): + if not self.pp_group.is_last_rank: + return + + if layer_ids is None: + raise ValueError( + "DFLASH requires explicit layer_ids for aux hidden capture." + ) + + self.capture_aux_hidden_states = True + self.model.layers_to_capture = [val + 1 for val in layer_ids] + class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass diff --git a/python/sglang/srt/models/dflash.py b/python/sglang/srt/models/dflash.py new file mode 100644 index 000000000000..27f5cdbf539d --- /dev/null +++ b/python/sglang/srt/models/dflash.py @@ -0,0 +1,399 @@ +# Adapted from the DFlash reference implementation (HF) but implemented with +# SGLang primitives (RadixAttention + SGLang KV cache). This model intentionally +# does not include token embeddings or an LM head; DFlash uses the target model's +# embedding/lm_head. 
+ +from __future__ import annotations + +import logging +from typing import Iterable, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.radix_attention import AttentionType, RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.utils import apply_qk_norm +from sglang.srt.speculative.dflash_utils import ( + can_dflash_slice_qkv_weight, + parse_dflash_draft_config, +) + +logger = logging.getLogger(__name__) + + +class DFlashAttention(nn.Module): + def __init__(self, config, layer_id: int) -> None: + super().__init__() + hidden_size = int(config.hidden_size) + tp_size = int(get_tensor_model_parallel_world_size()) + total_num_heads = int(config.num_attention_heads) + total_num_kv_heads = int( + getattr(config, "num_key_value_heads", total_num_heads) + ) + head_dim = int(getattr(config, "head_dim", hidden_size // total_num_heads)) + + self.hidden_size = hidden_size + self.total_num_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + assert self.total_num_heads % tp_size == 0, ( + f"DFlashAttention requires total_num_heads divisible by tp_size. " + f"total_num_heads={self.total_num_heads}, tp_size={tp_size}." + ) + self.num_heads = self.total_num_heads // tp_size + if self.total_num_kv_heads >= tp_size: + assert self.total_num_kv_heads % tp_size == 0, ( + f"DFlashAttention requires total_num_kv_heads divisible by tp_size when >= tp_size. 
" + f"total_num_kv_heads={self.total_num_kv_heads}, tp_size={tp_size}." + ) + else: + assert tp_size % self.total_num_kv_heads == 0, ( + f"DFlashAttention requires tp_size divisible by total_num_kv_heads when total_num_kv_heads < tp_size. " + f"total_num_kv_heads={self.total_num_kv_heads}, tp_size={tp_size}." + ) + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * head_dim + self.kv_size = self.num_kv_heads * head_dim + + attention_bias = bool(getattr(config, "attention_bias", False)) + rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-6)) + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=attention_bias, + prefix="qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * head_dim, + hidden_size, + bias=attention_bias, + prefix="o_proj", + ) + + # Per-head Q/K RMSNorm, matching HF Qwen3. + self.q_norm = RMSNorm(head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(head_dim, eps=rms_norm_eps) + + rope_theta = float(getattr(config, "rope_theta", 1000000)) + rope_scaling = getattr(config, "rope_scaling", None) + rope_is_neox_style = bool( + getattr( + config, "rope_is_neox_style", getattr(config, "is_neox_style", True) + ) + ) + max_position_embeddings = int(getattr(config, "max_position_embeddings", 32768)) + self.rotary_emb = get_rope( + head_dim, + rotary_dim=head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=rope_is_neox_style, + ) + + self.scaling = head_dim**-0.5 + # DFlash uses non-causal attention over the draft block. 
+ self.attn = RadixAttention( + num_heads=self.num_heads, + head_dim=head_dim, + scaling=self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + attn_type=AttentionType.ENCODER_ONLY, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = apply_qk_norm(q, k, self.q_norm, self.k_norm, self.head_dim) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, forward_batch) + output, _ = self.o_proj(attn_output) + return output + + def kv_proj_only( + self, hidden_states: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Project hidden_states to K/V only (skip Q). + + This is used by DFlash to materialize ctx tokens into the draft KV cache: + we only need K/V for the cached tokens; Q is never consumed. + """ + # Fast path for unquantized weights: slice the fused QKV weight and run one GEMM. + can_slice_qkv_weight, _ = can_dflash_slice_qkv_weight(self.qkv_proj) + if can_slice_qkv_weight: + kv_slice = slice(self.q_size, self.q_size + 2 * self.kv_size) + weight = self.qkv_proj.weight[kv_slice] + bias = ( + self.qkv_proj.bias[kv_slice] if self.qkv_proj.bias is not None else None + ) + kv = F.linear(hidden_states, weight, bias) + k, v = kv.split([self.kv_size, self.kv_size], dim=-1) + return k, v + + # Fallback: compute full QKV and discard Q (keeps compatibility with quantized weights). + qkv, _ = self.qkv_proj(hidden_states) + _, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + return k, v + + def apply_k_norm(self, k: torch.Tensor) -> torch.Tensor: + k_by_head = k.reshape(-1, self.head_dim) + k_by_head = self.k_norm(k_by_head) + return k_by_head.view_as(k) + + def apply_k_rope(self, positions: torch.Tensor, k: torch.Tensor) -> torch.Tensor: + # Use a minimal dummy query (1 head) to avoid doing full-Q work. 
+ dummy_q = k.new_empty((k.shape[0], self.head_dim)) + _, k = self.rotary_emb(positions, dummy_q, k) + return k + + +class DFlashMLP(nn.Module): + def __init__(self, config, quant_config=None, prefix: str = "") -> None: + super().__init__() + hidden_size = int(config.hidden_size) + intermediate_size = int(getattr(config, "intermediate_size", 0)) + if intermediate_size <= 0: + raise ValueError( + f"Invalid intermediate_size={intermediate_size} for DFlash MLP." + ) + + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix="gate_up_proj" if not prefix else f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix="down_proj" if not prefix else f"{prefix}.down_proj", + ) + hidden_act = getattr(config, "hidden_act", "silu") + if hidden_act != "silu": + raise ValueError( + f"Unsupported DFlash activation: {hidden_act}. Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class DFlashDecoderLayer(nn.Module): + def __init__(self, config, layer_id: int) -> None: + super().__init__() + hidden_size = int(config.hidden_size) + rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-6)) + + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.self_attn = DFlashAttention(config=config, layer_id=layer_id) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.mlp = DFlashMLP(config=config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + if hidden_states.numel() == 0: + # Keep return types consistent for upstream callers. 
+ if residual is None: + residual = hidden_states + return hidden_states, residual + + # Pre-norm attention with fused residual+norm when possible (Qwen3-style). + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + attn_out = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + hidden_states, residual = self.post_attention_layernorm(attn_out, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class DFlashDraftModel(nn.Module): + """SGLang DFlash draft model (no embedding / lm_head weights). + + The checkpoint provides: + - transformer weights for `layers.*` + - `fc.weight`, `hidden_norm.weight` for projecting target context features + - `norm.weight` for final normalization + """ + + def __init__(self, config, quant_config=None, prefix: str = "") -> None: + super().__init__() + self.config = config + + hidden_size = int(config.hidden_size) + num_layers = int(config.num_hidden_layers) + rms_norm_eps = float(getattr(config, "rms_norm_eps", 1e-6)) + + self.layers = nn.ModuleList( + [DFlashDecoderLayer(config=config, layer_id=i) for i in range(num_layers)] + ) + self.norm = RMSNorm(hidden_size, eps=rms_norm_eps) + + # Project per-token target context features: + # concat(K * hidden_size) -> hidden_size, where K is the number of target-layer + # feature tensors concatenated per token (not necessarily equal to num_layers). 
+ draft_config = parse_dflash_draft_config(draft_hf_config=config) + target_num_layers = ( + int(draft_config.num_target_layers) + if draft_config.num_target_layers is not None + else num_layers + ) + target_layer_ids = draft_config.resolve_target_layer_ids( + target_num_layers=target_num_layers, draft_num_layers=num_layers + ) + num_context_features = len(target_layer_ids) + + self.num_context_features = int(num_context_features) + self.fc = nn.Linear( + self.num_context_features * hidden_size, hidden_size, bias=False + ) + self.hidden_norm = RMSNorm(hidden_size, eps=rms_norm_eps) + + self.block_size = draft_config.resolve_block_size(default=16) + + def project_target_hidden(self, target_hidden: torch.Tensor) -> torch.Tensor: + """Project concatenated target-layer hidden states into draft hidden_size.""" + expected = int(self.fc.in_features) + if target_hidden.ndim != 2 or int(target_hidden.shape[-1]) != expected: + raise ValueError( + "DFLASH target_hidden feature dim mismatch. " + f"Expected shape [N, {expected}] " + f"(num_context_features={self.num_context_features}, hidden_size={int(self.config.hidden_size)}), " + f"but got shape={tuple(target_hidden.shape)}. " + "This usually means the target model is capturing a different number of layer features than " + "the draft checkpoint/config expects." + ) + return self.hidden_norm(self.fc(target_hidden)) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: Optional[torch.Tensor] = None, + get_embedding: bool = False, + pp_proxy_tensors=None, + ) -> LogitsProcessorOutput: + if input_embeds is None: + raise ValueError( + "DFlashDraftModel requires `input_embeds` (use the target embedding)." 
+ ) + hidden_states = input_embeds + residual: Optional[torch.Tensor] = None + + for layer in self.layers: + hidden_states, residual = layer( + positions, hidden_states, forward_batch, residual + ) + + if hidden_states.numel() != 0: + if residual is None: + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) + + return LogitsProcessorOutput( + next_token_logits=None, + hidden_states=hidden_states, + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, weight_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + + def resolve_param_name(name: str) -> Optional[str]: + if name in params_dict: + return name + if name.startswith("model."): + stripped_name = name[len("model.") :] + if stripped_name in params_dict: + return stripped_name + else: + prefixed_name = f"model.{name}" + if prefixed_name in params_dict: + return prefixed_name + return None + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if f".{weight_name}." not in name: + continue + mapped_name = name.replace(weight_name, param_name) + resolved_name = resolve_param_name(mapped_name) + if resolved_name is None: + continue + param = params_dict[resolved_name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight, shard_id) + break + else: + resolved_name = resolve_param_name(name) + if resolved_name is None: + # Ignore unexpected weights (e.g., HF rotary caches). + continue + param = params_dict[resolved_name] + if resolved_name.endswith("fc.weight") and tuple( + loaded_weight.shape + ) != tuple(param.shape): + raise ValueError( + "DFLASH fc.weight shape mismatch. 
This usually means the draft checkpoint's " + "number of context features (K) does not match this config. " + f"Expected fc.weight.shape={tuple(param.shape)} " + f"(num_context_features={self.num_context_features}, hidden_size={int(self.config.hidden_size)}), " + f"but got {tuple(loaded_weight.shape)} for weight '{name}'." + ) + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +EntryClass = DFlashDraftModel diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index 96caaa65b57c..228555ea164d 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -1104,6 +1104,9 @@ def _load_normal_weights( def get_embed_and_head(self): return self.model.embed_tokens.weight, self.lm_head.weight + def get_input_embeddings(self) -> nn.Embedding: + return self.model.embed_tokens + def set_embed_and_head(self, embed, head): del self.model.embed_tokens.weight del self.lm_head.weight @@ -1126,6 +1129,18 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None): # of the (i-1)th layer as aux hidden state self.model.layers_to_capture = [val + 1 for val in layer_ids] + def set_dflash_layers_to_capture(self, layer_ids: List[int]): + if not self.pp_group.is_last_rank: + return + + if layer_ids is None: + raise ValueError( + "DFLASH requires explicit layer_ids for aux hidden capture." 
+ ) + + self.capture_aux_hidden_states = True + self.model.layers_to_capture = [val + 1 for val in layer_ids] + @classmethod def get_model_config_for_expert_location(cls, config): return ModelConfigForExpertLocation( diff --git a/python/sglang/srt/models/kimi_k25.py b/python/sglang/srt/models/kimi_k25.py index bf931d5cc45e..7924895e055c 100644 --- a/python/sglang/srt/models/kimi_k25.py +++ b/python/sglang/srt/models/kimi_k25.py @@ -805,6 +805,30 @@ def set_eagle3_layers_to_capture( self.language_model.set_eagle3_layers_to_capture(layer_ids) + def set_dflash_layers_to_capture(self, layer_ids: List[int]) -> None: + """Set the layers to capture for DFLASH draft model training.""" + if not hasattr(self.language_model, "set_dflash_layers_to_capture"): + raise AttributeError( + "language_model does not support DFLASH layer capture." + ) + + self.language_model.set_dflash_layers_to_capture(layer_ids) + + def get_input_embeddings(self): + if not hasattr(self.language_model, "get_input_embeddings"): + raise AttributeError( + "language_model does not support get_input_embeddings()." 
+ ) + + return self.language_model.get_input_embeddings() + + @property + def lm_head(self): + if not hasattr(self.language_model, "lm_head"): + raise AttributeError("language_model does not expose lm_head.") + + return self.language_model.lm_head + def get_embed_and_head(self) -> Tuple[torch.Tensor, torch.Tensor]: """Get embedding and LM head weights for speculative decoding.""" if not hasattr(self.language_model, "get_embed_and_head"): diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 01e934dcc096..4e092c2761d7 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -789,6 +789,18 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None): # of the (i-1)th layer as aux hidden state self.model.layers_to_capture = [val + 1 for val in layer_ids] + def set_dflash_layers_to_capture(self, layer_ids: List[int]): + if not self.pp_group.is_last_rank: + return + + if layer_ids is None: + raise ValueError( + "DFLASH requires explicit layer_ids for aux hidden capture." + ) + + self.capture_aux_hidden_states = True + self.model.layers_to_capture = [val + 1 for val in layer_ids] + class Phi3ForCausalLM(LlamaForCausalLM): pass diff --git a/python/sglang/srt/models/qwen3.py b/python/sglang/srt/models/qwen3.py index b6b955ae6c7c..e9e9b61b3e2b 100644 --- a/python/sglang/srt/models/qwen3.py +++ b/python/sglang/srt/models/qwen3.py @@ -586,5 +586,19 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None): else: self.model.layers_to_capture = [val + 1 for val in layer_ids] + def set_dflash_layers_to_capture(self, layer_ids: List[int]): + if not self.pp_group.is_last_rank: + return + + if layer_ids is None: + raise ValueError( + "DFLASH requires explicit layer_ids for aux hidden capture." + ) + + self.capture_aux_hidden_states = True + # SGLang captures "before layer i". 
To capture the hidden state after target + # layer `k` (HF-style), we capture before layer `k + 1`. + self.model.layers_to_capture = [val + 1 for val in layer_ids] + EntryClass = Qwen3ForCausalLM diff --git a/python/sglang/srt/models/qwen3_5.py b/python/sglang/srt/models/qwen3_5.py index b7c3f1ec0ea3..ce9b033fef48 100644 --- a/python/sglang/srt/models/qwen3_5.py +++ b/python/sglang/srt/models/qwen3_5.py @@ -364,8 +364,15 @@ def forward( ): forward_batch = kwargs.get("forward_batch", None) - hidden_states, residual = self.layer_communicator.prepare_attn( - hidden_states, residual, forward_batch + hidden_states, residual = ( + self.layer_communicator.prepare_attn_and_capture_last_layer_outputs( + hidden_states, + residual, + forward_batch, + captured_last_layer_outputs=kwargs.get( + "captured_last_layer_outputs", None + ), + ) ) if not forward_batch.forward_mode.is_idle(): @@ -609,10 +616,16 @@ def forward( hidden_states: torch.Tensor, residual: Optional[torch.Tensor], forward_batch: ForwardBatch, + captured_last_layer_outputs: Optional[list[torch.Tensor]] = None, **kwargs, ): - hidden_states, residual = self.layer_communicator.prepare_attn( - hidden_states, residual, forward_batch + hidden_states, residual = ( + self.layer_communicator.prepare_attn_and_capture_last_layer_outputs( + hidden_states, + residual, + forward_batch, + captured_last_layer_outputs=captured_last_layer_outputs, + ) ) if not forward_batch.forward_mode.is_idle(): @@ -684,6 +697,8 @@ def __init__( else: self.embed_tokens = PPMissingLayer() + self.layers_to_capture = [] + # Decoder layers def get_layer(idx: int, prefix: str): layer_type = config.layers_block_type[idx] @@ -725,6 +740,11 @@ def get_layer(idx: int, prefix: str): else: self.norm = PPMissingLayer() + def set_dflash_layers_to_capture(self, layers_to_capture: list[int]): + self.layers_to_capture = layers_to_capture + for layer_id in self.layers_to_capture: + setattr(self.layers[layer_id], "_is_layer_to_capture", True) + def 
get_input_embeddings(self): return self.embed_tokens @@ -758,6 +778,7 @@ def forward( hidden_states = pp_proxy_tensors["hidden_states"] residual = pp_proxy_tensors["residual"] + aux_hidden_states = [] # Pass through decoder layers for layer_idx in range(self.start_layer, self.end_layer): layer = self.layers[layer_idx] @@ -769,6 +790,11 @@ def forward( hidden_states=hidden_states, residual=residual, forward_batch=forward_batch, + captured_last_layer_outputs=( + aux_hidden_states + if getattr(layer, "_is_layer_to_capture", False) + else None + ), ) # Process deepstack embeddings if provided @@ -798,7 +824,10 @@ def forward( else: hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states + if len(aux_hidden_states) == 0: + return hidden_states + + return hidden_states, aux_hidden_states def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 845502fe2e88..0e39d74a497c 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -889,6 +889,11 @@ def __init__( alt_stream=alt_stream, ) + def set_dflash_layers_to_capture(self, layers_to_capture: List[int]): + self.layers_to_capture = layers_to_capture + for layer_id in self.layers_to_capture: + setattr(self.layers[layer_id], "_is_layer_to_capture", True) + class Qwen3MoeForCausalLM(nn.Module): fall_back_to_pt_during_load = False @@ -1026,6 +1031,18 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None): else: self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids]) + def set_dflash_layers_to_capture(self, layer_ids: List[int]): + if not self.pp_group.is_last_rank: + return + + if layer_ids is None: + raise ValueError( + "DFLASH requires explicit layer_ids for aux hidden capture." 
+ ) + + self.capture_aux_hidden_states = True + self.model.set_dflash_layers_to_capture([val + 1 for val in layer_ids]) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index 42146e6057ea..34445121d555 100644 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -864,6 +864,11 @@ def set_eagle3_layers_to_capture(self, layers_to_capture: list[int]): for layer_id in self.layers_to_capture: setattr(self.layers[layer_id], "_is_layer_to_capture", True) + def set_dflash_layers_to_capture(self, layers_to_capture: list[int]): + self.layers_to_capture = layers_to_capture + for layer_id in self.layers_to_capture: + setattr(self.layers[layer_id], "_is_layer_to_capture", True) + def forward( self, input_ids: torch.Tensor, @@ -998,6 +1003,9 @@ def forward( def get_embed_and_head(self): return self.model.embed_tokens.weight, self.lm_head.weight + def get_input_embeddings(self) -> nn.Embedding: + return self.model.embed_tokens + def set_embed_and_head(self, embed, head): del self.model.embed_tokens.weight del self.lm_head.weight @@ -1171,5 +1179,17 @@ def set_eagle3_layers_to_capture(self, layer_ids: Optional[list[int]] = None): else: self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids]) + def set_dflash_layers_to_capture(self, layer_ids: list[int]): + if not self.pp_group.is_last_rank: + return + + if layer_ids is None: + raise ValueError( + "DFLASH requires explicit layer_ids for aux hidden capture." 
+ ) + + self.capture_aux_hidden_states = True + self.model.set_dflash_layers_to_capture([val + 1 for val in layer_ids]) + EntryClass = Qwen3NextForCausalLM diff --git a/python/sglang/srt/models/qwen3_vl.py b/python/sglang/srt/models/qwen3_vl.py index 34a645078c71..96ce39945a1d 100644 --- a/python/sglang/srt/models/qwen3_vl.py +++ b/python/sglang/srt/models/qwen3_vl.py @@ -1074,6 +1074,7 @@ def __init__( self.logits_processor = LogitsProcessor(self.config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + self.capture_aux_hidden_states = False # like {8:0, 16:1, 24:2}, which stands for the captured deepstack features on # 8, 16, 24 layer will be merged to 0, 1, 2 layer of decoder output hidden_states @@ -1278,6 +1279,10 @@ def forward( pp_proxy_tensors=pp_proxy_tensors, ) + aux_hidden_states = None + if self.capture_aux_hidden_states: + hidden_states, aux_hidden_states = hidden_states + if self.pp_group.is_last_rank: if not get_embedding: return self.logits_processor( @@ -1285,12 +1290,23 @@ def forward( hidden_states, self.lm_head, forward_batch, + aux_hidden_states, ) else: return self.pooler(hidden_states, forward_batch) else: return hidden_states + def set_dflash_layers_to_capture(self, layer_ids: List[int]): + if not self.pp_group.is_last_rank: + return + if layer_ids is None: + raise ValueError( + "DFLASH requires explicit layer_ids for aux hidden capture." 
+ ) + self.capture_aux_hidden_states = True + self.model.set_dflash_layers_to_capture([val + 1 for val in layer_ids]) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f2dfe68d192b..2ac3720e8252 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -484,6 +484,8 @@ class ServerArgs: speculative_num_steps: Optional[int] = None speculative_eagle_topk: Optional[int] = None speculative_num_draft_tokens: Optional[int] = None + speculative_dflash_block_size: Optional[int] = None + speculative_dflash_draft_window_size: Optional[int] = None speculative_accept_threshold_single: float = 1.0 speculative_accept_threshold_acc: float = 1.0 speculative_token_map: Optional[str] = None @@ -1107,6 +1109,7 @@ def _handle_gpu_memory_settings(self, gpu_mem): The coefficient 1.5 is a heuristic value, in the future, we can do better estimation by looking at the model types, hidden sizes or even do a dummy run. """ + self._auto_mem_fraction_static = self.mem_fraction_static is None if gpu_mem is not None: if gpu_mem < 20 * 1024: # T4, 4080 @@ -1244,7 +1247,7 @@ def _handle_gpu_memory_settings(self, gpu_mem): if self.speculative_algorithm == "STANDALONE": # standalonedraft model and cuda graphs reserved_mem += 6 * 1024 - elif self.speculative_algorithm != "NGRAM": + elif self.speculative_algorithm not in {"NGRAM", "DFLASH"}: # eagle draft models and cuda graphs reserved_mem += 4 * 1024 @@ -2679,6 +2682,145 @@ def _handle_speculative_decoding(self): if self.speculative_algorithm == "NEXTN": self.speculative_algorithm = "EAGLE" + if self.speculative_algorithm == "DFLASH": + if self.enable_dp_attention: + raise ValueError( + "Currently DFLASH speculative decoding does not support dp attention." 
+ ) + + if self.pp_size != 1: + raise ValueError( + "Currently DFLASH speculative decoding only supports pp_size == 1." + ) + + if self.speculative_draft_model_path is None: + raise ValueError( + "DFLASH speculative decoding requires setting --speculative-draft-model-path." + ) + + # DFLASH does not use EAGLE-style `num_steps`/`topk`, but those fields still + # affect generic scheduler/KV-cache accounting (buffer sizing, KV freeing, + # RoPE reservation). Force them to 1 to avoid surprising memory behavior. + # + # For DFlash, the natural unit is `block_size` (verify window length). + if self.speculative_num_steps is None: + self.speculative_num_steps = 1 + elif int(self.speculative_num_steps) != 1: + logger.warning( + "DFLASH only supports speculative_num_steps == 1; overriding speculative_num_steps=%s to 1.", + self.speculative_num_steps, + ) + self.speculative_num_steps = 1 + + if self.speculative_eagle_topk is None: + self.speculative_eagle_topk = 1 + elif int(self.speculative_eagle_topk) != 1: + logger.warning( + "DFLASH only supports speculative_eagle_topk == 1; overriding speculative_eagle_topk=%s to 1.", + self.speculative_eagle_topk, + ) + self.speculative_eagle_topk = 1 + + if self.speculative_dflash_block_size is not None: + if int(self.speculative_dflash_block_size) <= 0: + raise ValueError( + "DFLASH requires --speculative-dflash-block-size to be positive, " + f"got {self.speculative_dflash_block_size}." + ) + if self.speculative_num_draft_tokens is not None and int( + self.speculative_num_draft_tokens + ) != int(self.speculative_dflash_block_size): + raise ValueError( + "Both --speculative-num-draft-tokens and --speculative-dflash-block-size are set " + "but they differ. For DFLASH they must match. " + f"speculative_num_draft_tokens={self.speculative_num_draft_tokens}, " + f"speculative_dflash_block_size={self.speculative_dflash_block_size}." 
+ ) + self.speculative_num_draft_tokens = int( + self.speculative_dflash_block_size + ) + + window_size = None + if self.speculative_dflash_draft_window_size is not None: + window_size = int(self.speculative_dflash_draft_window_size) + if window_size <= 0: + raise ValueError( + "DFLASH requires --speculative-dflash-draft-window-size " + f"to be positive, got {window_size}." + ) + self.speculative_dflash_draft_window_size = window_size + + if self.speculative_num_draft_tokens is None: + from sglang.srt.speculative.dflash_utils import ( + parse_dflash_draft_config, + ) + + model_override_args = json.loads(self.json_model_override_args) + inferred_block_size = None + try: + from sglang.srt.utils.hf_transformers_utils import get_config + + draft_hf_config = get_config( + self.speculative_draft_model_path, + trust_remote_code=self.trust_remote_code, + revision=self.speculative_draft_model_revision, + model_override_args=model_override_args, + ) + inferred_block_size = parse_dflash_draft_config( + draft_hf_config=draft_hf_config + ).resolve_block_size(default=None) + except Exception as e: + logger.warning( + "Failed to infer DFLASH block_size from draft model config; " + "defaulting speculative_num_draft_tokens to 16. Error: %s", + e, + ) + + if inferred_block_size is None: + inferred_block_size = 16 + logger.warning( + "speculative_num_draft_tokens is not set; defaulting to %d for DFLASH.", + inferred_block_size, + ) + self.speculative_num_draft_tokens = inferred_block_size + + if window_size is not None: + draft_tokens = int(self.speculative_num_draft_tokens) + if window_size < draft_tokens: + raise ValueError( + "DFLASH --speculative-dflash-draft-window-size must be >= " + "--speculative-num-draft-tokens (block_size). " + f"window_size={window_size}, block_size={draft_tokens}." + ) + + if self.max_running_requests is None: + self.max_running_requests = 48 + logger.warning( + "Max running requests is reset to 48 for speculative decoding. 
You can override this by explicitly setting --max-running-requests." + ) + + if ( + envs.SGLANG_ENABLE_SPEC_V2.get() + and envs.SGLANG_ENABLE_DFLASH_SPEC_V2.get() + ): + self.disable_overlap_schedule = False + logger.warning( + "DFLASH spec v2 is enabled and overlap schedule is turned on (experimental)." + ) + else: + self.disable_overlap_schedule = True + logger.warning( + "Overlap scheduler is disabled when using DFLASH speculative decoding. " + "Set env SGLANG_ENABLE_SPEC_V2=True and SGLANG_ENABLE_DFLASH_SPEC_V2=True to " + "enable the experimental overlap scheduler for DFLASH." + ) + + if self.enable_mixed_chunk: + self.enable_mixed_chunk = False + logger.warning( + "Mixed chunked prefill is disabled because of using dflash speculative decoding." + ) + if self.speculative_algorithm in ("EAGLE", "EAGLE3", "STANDALONE"): if self.speculative_algorithm == "STANDALONE" and self.enable_dp_attention: # TODO: support dp attention for standalone speculative decoding @@ -4320,7 +4462,7 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--speculative-algorithm", type=str, - choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "NGRAM"], + choices=["DFLASH", "EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "NGRAM"], help="Speculative algorithm.", ) parser.add_argument( @@ -4364,6 +4506,21 @@ def add_cli_args(parser: argparse.ArgumentParser): help="The number of tokens sampled from the draft model in Speculative Decoding.", default=ServerArgs.speculative_num_draft_tokens, ) + parser.add_argument( + "--speculative-dflash-block-size", + type=int, + help="DFLASH only. Block size (verify window length). Alias of --speculative-num-draft-tokens for DFLASH.", + default=ServerArgs.speculative_dflash_block_size, + ) + parser.add_argument( + "--speculative-dflash-draft-window-size", + type=int, + help="DFLASH only. Sliding window size for the draft-model KV cache. 
" + "When set, the draft worker keeps a recent target-token window in its " + "local cache (paged backends may retain up to one extra page on the left " + "for alignment). Default is full context.", + default=ServerArgs.speculative_dflash_draft_window_size, + ) parser.add_argument( "--speculative-accept-threshold-single", type=float, diff --git a/python/sglang/srt/speculative/dflash_info.py b/python/sglang/srt/speculative/dflash_info.py new file mode 100644 index 000000000000..dd9d91e4a0a5 --- /dev/null +++ b/python/sglang/srt/speculative/dflash_info.py @@ -0,0 +1,491 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import List, Tuple + +import torch + +from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.mem_cache.common import ( + alloc_paged_token_slots_extend, + alloc_token_slots, + get_last_loc, +) +from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode +from sglang.srt.speculative.dflash_utils import ( + apply_dflash_verify_logits_adjustments, + compute_dflash_accept_len_and_bonus, + compute_dflash_sampling_accept_len_and_bonus, + is_dflash_sampling_verify_available, +) +from sglang.srt.speculative.spec_info import SpecInput, SpecInputType +from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func + + +def _compute_paged_keep_slots( + *, + prefix_lens: torch.Tensor, + commit_lens: torch.Tensor, + draft_token_num: int, + page_size: int, +) -> torch.Tensor: + """Compute how many draft slots per request must remain allocated. + + The allocator frees at page granularity for paged mode, so we can only release + full pages from the tail after verify. 
+ """ + + if page_size <= 1: + raise ValueError(f"Expected page_size > 1, got {page_size}.") + + seq_dtype = prefix_lens.dtype + extended_lens = prefix_lens + int(draft_token_num) + new_lens = prefix_lens + commit_lens.to(seq_dtype) + aligned_new_lens = ((new_lens + page_size - 1) // page_size) * page_size + keep_lens = torch.minimum(aligned_new_lens, extended_lens) + keep_slots = (keep_lens - prefix_lens).to(torch.int64) + keep_slots.clamp_(min=0, max=int(draft_token_num)) + return keep_slots + + +@dataclass +class DFlashDraftInput(SpecInput): + """Per-batch DFlash draft state for spec-v1 (non-overlap) scheduling. + + This object is stored on `ScheduleBatch.spec_info` between decode iterations. + It is NOT sent to model attention backends; the DFlash worker uses it to run + the draft model and to track draft-side cache progress. + + When draft windowing is disabled, `draft_seq_lens` matches the committed target + prefix length already materialized in the draft KV cache. When windowing is + enabled, `draft_seq_lens` is the logical resident length in the draft worker's + compact req-to-token mapping. In paged mode this may exceed the requested + window by up to `page_size - 1` so the local page table remains valid. `ctx_lens` + tracks newly committed target tokens that still need draft KV materialization. + """ + + # Current token to start the next DFlash block (one per request). + verified_id: torch.Tensor + + # Flattened context features for tokens that need to be appended into the draft cache. + # Shape: [sum(ctx_lens), K * hidden_size], where K is the number of target-layer + # hidden-state features concatenated per token (len(dflash_config.target_layer_ids), + # or default K == draft_num_layers for existing checkpoints). + target_hidden: torch.Tensor + + # Context lengths per request, used to slice `target_hidden`. Device tensor (int32). + ctx_lens: torch.Tensor + + # How many committed tokens are visible to the draft worker per request. 
+ draft_seq_lens: torch.Tensor + + def __post_init__(self): + super().__init__(spec_input_type=SpecInputType.DFLASH_DRAFT) + + def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]: + # Draft state does not change token accounting. + return (1, 1) + + def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True): + old_ctx_lens = self.ctx_lens + old_target_hidden = self.target_hidden + + self.verified_id = self.verified_id[new_indices] + self.ctx_lens = old_ctx_lens[new_indices] + self.draft_seq_lens = self.draft_seq_lens[new_indices] + + if old_target_hidden is None or old_target_hidden.numel() == 0: + self.target_hidden = old_target_hidden + return + + # Rebuild target_hidden for the filtered batch using vectorized indexing. + old_bs = int(old_ctx_lens.shape[0]) + offsets = torch.zeros( + (old_bs + 1,), dtype=torch.int64, device=old_ctx_lens.device + ) + offsets[1:].copy_(old_ctx_lens.to(torch.int64).cumsum(0)) + + start = offsets[:-1] + seg_start = start[new_indices] + seg_lens = old_ctx_lens[new_indices].to(torch.int64) + + max_len = int(seg_lens.max().item()) if seg_lens.numel() > 0 else 0 + if max_len <= 0: + self.target_hidden = old_target_hidden[:0] + return + + r = torch.arange(max_len, device=old_ctx_lens.device, dtype=torch.int64)[ + None, : + ] + pos2d = seg_start[:, None] + r + mask = r < seg_lens[:, None] + flat_pos = pos2d[mask] + self.target_hidden = ( + old_target_hidden.index_select(0, flat_pos) + if flat_pos.numel() > 0 + else old_target_hidden[:0] + ) + + def merge_batch(self, spec_info: "DFlashDraftInput"): + self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], dim=0) + self.ctx_lens = torch.cat([self.ctx_lens, spec_info.ctx_lens], dim=0) + self.draft_seq_lens = torch.cat( + [self.draft_seq_lens, spec_info.draft_seq_lens], dim=0 + ) + if self.target_hidden is None or self.target_hidden.numel() == 0: + self.target_hidden = spec_info.target_hidden + elif ( + spec_info.target_hidden is not None and 
spec_info.target_hidden.numel() > 0 + ): + self.target_hidden = torch.cat( + [self.target_hidden, spec_info.target_hidden], dim=0 + ) + + +@dataclass +class DFlashVerifyInput(SpecInput): + """Inputs for a target-model verify forward in DFlash (spec-v1). + + The verify forward is run with `ForwardMode.TARGET_VERIFY` so that the target + model returns logits for all tokens in the block, enabling accept-length + computation. + """ + + draft_token: torch.Tensor + positions: torch.Tensor + draft_token_num: int + # Kept for compatibility with attention backends that gate tree metadata by `topk > 1`. + # DFLASH verify is linear (non-tree), so this is always 1. + topk: int = 1 + # Custom attention "allow mask" for TARGET_VERIFY in backends that require it (e.g. triton). + # Semantics follow SGLang speculative conventions: True means the (q, k) pair is allowed. + custom_mask: torch.Tensor | None = None + capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL + + # Shape info for padding (e.g., DP attention / CUDA graph). 
+ num_tokens_per_batch: int = -1 + + def __post_init__(self): + super().__init__(spec_input_type=SpecInputType.DFLASH_VERIFY) + if self.num_tokens_per_batch == -1: + self.num_tokens_per_batch = int(self.draft_token_num) + + def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]: + return self.draft_token_num, self.draft_token_num + + def prepare_for_verify( + self, + batch: ScheduleBatch, + page_size: int, + *, + build_custom_mask: bool = True, + ): + if batch.forward_mode.is_idle(): + return + + batch.input_ids = self.draft_token + + if page_size == 1: + batch.out_cache_loc = alloc_token_slots( + batch.tree_cache, len(batch.input_ids) + ) + end_offset = batch.seq_lens + self.draft_token_num + else: + prefix_lens = batch.seq_lens + prefix_lens_cpu = batch.seq_lens_cpu + end_offset = prefix_lens + self.draft_token_num + end_offset_cpu = prefix_lens_cpu + self.draft_token_num + last_loc = get_last_loc( + batch.req_to_token_pool.req_to_token, + batch.req_pool_indices, + prefix_lens, + ) + batch.out_cache_loc = alloc_paged_token_slots_extend( + batch.tree_cache, + prefix_lens, + prefix_lens_cpu, + end_offset, + end_offset_cpu, + last_loc, + len(batch.input_ids), + ) + self.last_loc = last_loc + + bs = batch.batch_size() + assign_req_to_token_pool_func( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + end_offset, + batch.out_cache_loc, + bs, + ) + + if not build_custom_mask: + self.custom_mask = None + return + + if self.draft_token_num <= 0: + raise ValueError( + f"DFLASH draft_token_num must be positive, got {self.draft_token_num}." 
+ ) + mask_chunks: List[torch.Tensor] = [] + q_len = int(self.draft_token_num) + q_idx = torch.arange(q_len, device=batch.device, dtype=torch.int32).unsqueeze(1) + for prefix_len in batch.seq_lens_cpu.tolist(): + prefix_len_i = int(prefix_len) + kv_len = prefix_len_i + q_len + k_idx = torch.arange( + kv_len, device=batch.device, dtype=torch.int32 + ).unsqueeze(0) + # Allow attending to the full prefix and to tokens up to (and including) the + # current query position within the verify block (standard causal masking). + allow = k_idx <= (prefix_len_i + q_idx) + mask_chunks.append(allow.flatten()) + self.custom_mask = ( + torch.cat(mask_chunks, dim=0) + if mask_chunks + else torch.empty((0,), dtype=torch.bool, device=batch.device) + ) + + def generate_attn_arg_prefill( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + paged_kernel_lens_sum: int, + req_to_token: torch.Tensor, + ): + device = req_pool_indices.device + bs = len(req_pool_indices) + + qo_indptr = torch.arange( + 0, + (bs + 1) * self.draft_token_num, + step=self.draft_token_num, + dtype=torch.int32, + device=device, + ) + + cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device=device) + paged_kernel_lens = paged_kernel_lens + self.draft_token_num + cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) + + kv_indices = torch.empty( + paged_kernel_lens_sum + self.draft_token_num * bs, + dtype=torch.int32, + device=device, + ) + create_flashinfer_kv_indices_triton[(bs,)]( + req_to_token, + req_pool_indices, + paged_kernel_lens, + cum_kv_seq_len, + None, + kv_indices, + req_to_token.size(1), + ) + mask = self.custom_mask + if mask is not None: + mask_numel = ( + paged_kernel_lens_sum * self.draft_token_num + + (self.draft_token_num**2) * bs + ) + if mask.numel() < mask_numel: + # FIXME(attn): temporary fix for custom mask padding with cuda graph + mask = torch.cat( + [ + mask, + torch.full( + (mask_numel - mask.numel(),), + True, + dtype=torch.bool, + 
device=device, + ), + ], + dim=0, + ) + self.custom_mask = mask + return kv_indices, cum_kv_seq_len, qo_indptr, mask + + def verify( + self, + *, + batch: ScheduleBatch, + logits_output: LogitsProcessorOutput, + page_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: + """DFlash verification for greedy and non-greedy sampling. + + Returns: + new_verified_id: int64 tensor [bs] (the new current token per request) + commit_lens: int32 tensor [bs] (how many verify-input tokens are committed) + next_target_hidden: tensor [sum(commit_lens), feature_dim] + accept_length_per_req_cpu: list[int] (accepted draft tokens per request) + """ + if batch.forward_mode.is_idle(): + empty = torch.empty((0,), dtype=torch.int64, device=batch.device) + return empty, empty.to(torch.int32), empty, [] + + bs = batch.batch_size() + device = logits_output.next_token_logits.device + + sampling_info = batch.sampling_info + if sampling_info is not None: + if len(sampling_info) != bs: + raise RuntimeError( + "DFLASH verify sampling_info size mismatch: " + f"len(sampling_info)={len(sampling_info)}, bs={bs}." 
+ ) + apply_dflash_verify_logits_adjustments( + next_token_logits=logits_output.next_token_logits, + sampling_info=sampling_info, + draft_token_num=self.draft_token_num, + ) + + candidates = self.draft_token.view(bs, self.draft_token_num) + if ( + sampling_info is not None + and not sampling_info.is_all_greedy + and is_dflash_sampling_verify_available() + ): + top_ks = [int(req.sampling_params.top_k) for req in batch.reqs] + accept_len, bonus = compute_dflash_sampling_accept_len_and_bonus( + candidates=candidates, + next_token_logits=logits_output.next_token_logits, + sampling_info=sampling_info, + max_top_k=max(max(top_ks), 1) if top_ks else 1, + uniform_top_k_value=( + top_ks[0] + if top_ks and all(top_k == top_ks[0] for top_k in top_ks) + else None + ), + ) + else: + target_predict = torch.argmax(logits_output.next_token_logits, dim=-1).view( + bs, self.draft_token_num + ) + accept_len, bonus = compute_dflash_accept_len_and_bonus( + candidates=candidates, + target_predict=target_predict, + ) + + # Single D2H transfer: candidates[1:] + accept_len + bonus + packed = torch.cat( + [candidates[:, 1:], accept_len.unsqueeze(1), bonus.unsqueeze(1)], dim=1 + ).cpu() + + max_acc = self.draft_token_num - 1 + accept_length_per_req_cpu: List[int] = [] + commit_lens_cpu: List[int] = [] + new_verified_list: List[int] = [] + + for i, req in enumerate(batch.reqs): + acc_len = int(packed[i, max_acc].item()) + proposed = packed[i, :acc_len].tolist() + [ + int(packed[i, max_acc + 1].item()) + ] + + appended = 0 + for token_id in proposed: + token_id = int(token_id) + req.output_ids.append(token_id) + appended += 1 + req.check_finished() + if req.finished(): + break + if req.grammar is not None: + req.grammar.accept_token(token_id) + + if req.output_ids: + new_verified_token = int(req.output_ids[-1]) + elif req.origin_input_ids: + # If no token was appended in this verify step, keep the current token unchanged. 
+ new_verified_token = int(req.origin_input_ids[-1]) + else: + raise RuntimeError( + "DFLASH verify cannot determine current token: both output_ids and origin_input_ids are empty." + ) + + commit_lens_cpu.append(appended) + new_verified_list.append(new_verified_token) + accept_length_per_req_cpu.append(max(0, appended - 1)) + req.spec_verify_ct += 1 + req.spec_accepted_tokens += accept_length_per_req_cpu[-1] + + commit_lens = torch.tensor(commit_lens_cpu, dtype=torch.int32, device=device) + new_verified_id = torch.tensor( + new_verified_list, dtype=torch.int64, device=device + ) + + # Free uncommitted KV cache slots and compact out_cache_loc. + if page_size == 1: + out_cache_loc = batch.out_cache_loc.view(bs, self.draft_token_num) + keep_mask = ( + torch.arange(self.draft_token_num, device=device)[None, :] + < commit_lens[:, None] + ) + batch.token_to_kv_pool_allocator.free(out_cache_loc[~keep_mask]) + batch.out_cache_loc = out_cache_loc[keep_mask] + else: + out_cache_loc = batch.out_cache_loc.view(bs, self.draft_token_num) + row_offsets = torch.arange(self.draft_token_num, device=device)[None, :] + keep_slots = _compute_paged_keep_slots( + prefix_lens=batch.seq_lens, + commit_lens=commit_lens, + draft_token_num=self.draft_token_num, + page_size=page_size, + ) + free_mask = row_offsets >= keep_slots[:, None] + batch.token_to_kv_pool_allocator.free(out_cache_loc[free_mask]) + + keep_mask = row_offsets < commit_lens[:, None] + batch.out_cache_loc = out_cache_loc[keep_mask] + + # Update req-level KV cache accounting. + for req, commit_len in zip(batch.reqs, commit_lens_cpu, strict=True): + req.kv_committed_len += commit_len + req.kv_allocated_len = req.kv_committed_len + + # Update req_to_token pool mapping for newly committed tokens. 
+ end_offset = batch.seq_lens + commit_lens.to(batch.seq_lens.dtype) + assign_req_to_token_pool_func( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + end_offset, + batch.out_cache_loc, + bs, + ) + + # Update batch seq lens. + batch.seq_lens.add_(commit_lens.to(batch.seq_lens.dtype)) + batch.seq_lens_cpu.add_( + torch.tensor(commit_lens_cpu, dtype=batch.seq_lens_cpu.dtype) + ) + # Keep seq_lens_sum in sync; flashinfer indices updaters rely on this for buffer sizing. + batch.seq_lens_sum += sum(commit_lens_cpu) + + # Build next-step context features from the committed verify-input tokens. + hidden = logits_output.hidden_states + if hidden is None: + raise RuntimeError( + "DFLASH verify requires target hidden states, but got None." + ) + hidden = hidden.view(bs, self.draft_token_num, -1) + segments: List[torch.Tensor] = [] + for i, ln in enumerate(commit_lens_cpu): + if ln > 0: + segments.append(hidden[i, :ln, :]) + next_target_hidden = torch.cat(segments, dim=0) if segments else hidden[:0] + + # Avoid confusing downstream consumers (spec-v1 decode doesn't use this). + logits_output.hidden_states = None + + return ( + new_verified_id, + commit_lens, + next_target_hidden, + accept_length_per_req_cpu, + ) diff --git a/python/sglang/srt/speculative/dflash_info_v2.py b/python/sglang/srt/speculative/dflash_info_v2.py new file mode 100644 index 000000000000..29e61f5971a7 --- /dev/null +++ b/python/sglang/srt/speculative/dflash_info_v2.py @@ -0,0 +1,225 @@ +"""DFLASH spec-v2 overlap scheduling data structures (WIP). + +The spec-v2 path will mirror the scheduler integration used by Eagle v2: +- the worker returns `(next_token_ids, accept_lens, next_draft_input)` +- scheduler output processing (not the worker) mutates `req.output_ids` + +This file is intentionally introduced early to keep spec-v2-specific state +isolated from the existing spec-v1 implementation in `dflash_info.py`. 
+""" + +from __future__ import annotations + +import contextlib +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch + +from sglang.srt.environ import envs +from sglang.srt.managers.overlap_utils import FutureIndices +from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.mem_cache.common import ( + alloc_paged_token_slots_extend, + alloc_token_slots, + get_last_loc, +) +from sglang.srt.server_args import get_global_server_args +from sglang.srt.speculative.spec_info import SpecInput, SpecInputType +from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func + +_OVERLAP_PLAN_STREAMS: dict[str, torch.cuda.Stream] = {} + + +def _get_overlap_plan_stream( + device: torch.device | str, +) -> tuple[Optional[torch.cuda.Stream], contextlib.AbstractContextManager]: + """Return an optional plan stream/context for overlap scheduling prep kernels.""" + if not envs.SGLANG_ENABLE_OVERLAP_PLAN_STREAM.get(): + return None, contextlib.nullcontext() + + device_str = str(device) + stream = _OVERLAP_PLAN_STREAMS.get(device_str) + if stream is None: + stream = torch.get_device_module(device_str).Stream() + _OVERLAP_PLAN_STREAMS[device_str] = stream + return stream, torch.get_device_module(device_str).stream(stream) + + +@dataclass +class DFlashDraftInputV2(SpecInput): + """Draft-side state carried across overlap iterations (spec-v2).""" + + # Required by overlap FutureMap plumbing (match Eagle v2 field names). + topk_p: torch.Tensor + topk_index: torch.Tensor + verified_id: torch.Tensor + new_seq_lens: torch.Tensor + hidden_states: torch.Tensor + verify_done: Optional[torch.cuda.Event] = None + max_top_k: int = 1 + uniform_top_k_value: Optional[int] = None + + # Filled by scheduler after dispatch. 
+    # Placeholder for overlap scheduling: when set, token state lives behind
+    # future indices instead of the materialized tensors below.
+    future_indices: Optional[FutureIndices] = None
+
+    def __post_init__(self):
+        super().__init__(spec_input_type=SpecInputType.DFLASH_DRAFT)
+
+    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
+        # Spec v2 draft state itself does not change token accounting.
+        return (1, 1)
+
+    @classmethod
+    def create_idle_input(cls, device: torch.device) -> "DFlashDraftInputV2":
+        # Zero-row tensors so an idle batch flows through the same code paths.
+        return cls(
+            topk_p=torch.empty((0, 1), device=device, dtype=torch.float32),
+            topk_index=torch.empty((0, 1), device=device, dtype=torch.int64),
+            verified_id=torch.empty((0,), device=device, dtype=torch.int32),
+            new_seq_lens=torch.empty((0,), device=device, dtype=torch.int32),
+            hidden_states=torch.empty((0, 1), device=device, dtype=torch.float16),
+            verify_done=None,
+        )
+
+    def prepare_for_decode(self, batch: ScheduleBatch):
+        """Allocate headroom in the shared req_to_token pool for the next DFLASH step.
+
+        DFLASH spec-v2 uses overlap scheduling's "over-allocation" approach: we reserve
+        future KV slots ahead of time so the worker can gather `out_cache_loc` directly
+        from `req_to_token` without allocator backup/restore.
+        """
+        plan_stream, plan_stream_ctx = _get_overlap_plan_stream(batch.device)
+        if plan_stream is None:
+            # Ensure previous forward is completed before mutating shared buffers.
+            batch.maybe_wait_verify_done()
+
+        bs = batch.batch_size()
+        if bs == 0:
+            return
+
+        # For DFLASH, each decode step needs a fixed-size verify block.
+        block_size = int(get_global_server_args().speculative_num_draft_tokens)
+        if block_size <= 0:
+            raise ValueError(
+                f"DFLASH invalid speculative_num_draft_tokens={block_size}."
+            )
+
+        # Cache per-batch top-k summary used later by the verify path; a single
+        # uniform value lets the verifier skip per-row masking.
+        top_ks = [int(req.sampling_params.top_k) for req in batch.reqs]
+        self.max_top_k = max(max(top_ks), 1) if top_ks else 1
+        self.uniform_top_k_value = (
+            top_ks[0]
+            if top_ks and all(top_k == top_ks[0] for top_k in top_ks)
+            else None
+        )
+
+        page_size = batch.token_to_kv_pool_allocator.page_size
+
+        cur_kv_lens_cpu_t = torch.tensor(
+            [req.kv_allocated_len for req in batch.reqs],
+            dtype=torch.int32,
+            device="cpu",
+        )
+
+        caller_stream = None
+        if plan_stream is not None:
+            caller_stream = torch.get_device_module(batch.device).current_stream()
+
+        with plan_stream_ctx:
+            if plan_stream is not None and caller_stream is not None:
+                # `batch.seq_lens`, `batch.req_pool_indices`, and related tensors may
+                # have just been rebuilt on the scheduler stream by filter/merge ops.
+                # The plan stream must wait for those writes before reading them.
+                plan_stream.wait_stream(caller_stream)
+
+            if plan_stream is not None and self.verify_done is not None:
+                plan_stream.wait_event(self.verify_done)
+
+            cur_kv_lens = cur_kv_lens_cpu_t.to(device=batch.device)
+            committed_kv_lens = batch.seq_lens.to(dtype=torch.int32)
+            # NOTE(review): `2 * block_size` presumably reserves headroom for the
+            # current step plus one overlapped step — confirm against the scheduler.
+            nxt_kv_lens = torch.maximum(cur_kv_lens, committed_kv_lens + 2 * block_size)
+            # NOTE(review): this device->CPU copy and the `.item()` below do
+            # synchronize with the plan stream; the "never force a GPU->CPU sync"
+            # note further down appears to refer to the forward path — confirm.
+            nxt_kv_lens_cpu_t = nxt_kv_lens.to(device="cpu")
+            num_needed_tokens = int(
+                (nxt_kv_lens_cpu_t - cur_kv_lens_cpu_t).sum().item()
+            )
+
+            if num_needed_tokens > 0:
+                if page_size == 1:
+                    out_cache_loc = alloc_token_slots(
+                        batch.tree_cache, num_needed_tokens
+                    )
+                else:
+                    last_loc = get_last_loc(
+                        batch.req_to_token_pool.req_to_token,
+                        batch.req_pool_indices,
+                        cur_kv_lens,
+                    )
+                    out_cache_loc = alloc_paged_token_slots_extend(
+                        batch.tree_cache,
+                        cur_kv_lens,
+                        cur_kv_lens_cpu_t,
+                        nxt_kv_lens,
+                        nxt_kv_lens_cpu_t,
+                        last_loc,
+                        num_needed_tokens,
+                    )
+
+                # Updating req_to_token is a write to a shared tensor: it must not overlap
+                # with the previous batch's forward, which also reads req_to_token.
+                assign_req_to_token_pool_func(
+                    batch.req_pool_indices,
+                    batch.req_to_token_pool.req_to_token,
+                    cur_kv_lens,
+                    nxt_kv_lens,
+                    out_cache_loc,
+                    bs,
+                )
+        if caller_stream is not None:
+            # Enqueue the dependency on the caller's stream, not inside the
+            # plan-stream context, so forward work cannot observe partially
+            # prepared req_to_token / KV allocation state.
+            caller_stream.wait_stream(plan_stream)
+
+        nxt_kv_lens_cpu = nxt_kv_lens_cpu_t.tolist()
+        for req, new_alloc_len in zip(batch.reqs, nxt_kv_lens_cpu, strict=True):
+            req.kv_allocated_len = int(new_alloc_len)
+
+        # NOTE: In overlap scheduling, per-request CPU state (e.g., `req.kv_committed_len`)
+        # can lag behind `batch.seq_lens` by one iteration because result processing is
+        # overlapped with the next forward. Avoid using lagging CPU state for buffer sizing,
+        # and never force a GPU->CPU sync here.
+        #
+        # `seq_lens_sum` is used for allocation sizing in attention backends (e.g., FlashInfer
+        # kv_indices buffers). Use allocated KV lengths as a safe upper bound.
+        batch.seq_lens_cpu = nxt_kv_lens_cpu_t.to(dtype=torch.int64)
+        batch.seq_lens_sum = int(nxt_kv_lens_cpu_t.sum().item())
+
+    def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
+        # NOTE(review): `has_been_filtered` is accepted for interface parity with
+        # other spec inputs but is unused in this implementation.
+        if self.future_indices is not None:
+            self.future_indices.indices = self.future_indices.indices[new_indices]
+            return
+
+        self.topk_p = self.topk_p[new_indices]
+        self.topk_index = self.topk_index[new_indices]
+        self.verified_id = self.verified_id[new_indices]
+        self.new_seq_lens = self.new_seq_lens[new_indices]
+        self.hidden_states = self.hidden_states[new_indices]
+
+    def merge_batch(self, spec_info: "DFlashDraftInputV2"):
+        # Either both sides carry future indices (overlap mode) or both carry
+        # materialized tensors; the two representations are never mixed.
+        if self.future_indices is not None:
+            assert spec_info.future_indices is not None
+            self.future_indices = FutureIndices(
+                indices=torch.cat(
+                    [self.future_indices.indices, spec_info.future_indices.indices]
+                )
+            )
+            return
+
+        self.topk_p = torch.cat([self.topk_p, spec_info.topk_p], dim=0)
+        self.topk_index = torch.cat([self.topk_index, spec_info.topk_index], dim=0)
+        self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], dim=0)
+        self.new_seq_lens = torch.cat(
+            [self.new_seq_lens, spec_info.new_seq_lens], dim=0
+        )
+        self.hidden_states = torch.cat(
+            [self.hidden_states, spec_info.hidden_states], dim=0
+        )
diff --git a/python/sglang/srt/speculative/dflash_utils.py b/python/sglang/srt/speculative/dflash_utils.py
new file mode 100644
index 000000000000..4494991c82d6
--- /dev/null
+++ b/python/sglang/srt/speculative/dflash_utils.py
@@ -0,0 +1,877 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from numbers import Integral
+from typing import Any, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+
+from sglang.srt.environ import envs
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+from sglang.srt.layers.sampler import apply_custom_logit_processor
+from sglang.srt.utils import is_cuda
+
+DEFAULT_DFLASH_MASK_TOKEN = "<|MASK|>"
+
+# Set to True below only when the sgl_kernel sampling ops import succeeds on CUDA.
+_DFLASH_SAMPLING_VERIFY_AVAILABLE = False
+# Cache of chain-verify scratch tensors, keyed by (device index, draft_token_num).
+_DFLASH_CHAIN_VERIFY_BUFFERS: dict[tuple[Optional[int], int], dict[str, Any]] = {}
+# Backends whose verify path does not need an explicit custom attention mask.
+_DFLASH_VERIFY_SKIP_CUSTOM_MASK_BACKENDS = frozenset(
+    {
+        "FlashInferAttnBackend",
+        "FlashInferMLAAttnBackend",
+        "FlashAttentionBackend",
+        "TRTLLMHAAttnBackend",
+        "TRTLLMMLABackend",
+    }
+)
+
+
+if is_cuda():
+    try:
+        from sgl_kernel import (
+            top_k_renorm_prob,
+            top_p_renorm_prob,
+            tree_speculative_sampling_target_only,
+        )
+
+        _DFLASH_SAMPLING_VERIFY_AVAILABLE = True
+    except Exception:
+        # Missing/incompatible sgl_kernel: fall back to greedy-only verification.
+        top_k_renorm_prob = None
+        top_p_renorm_prob = None
+        tree_speculative_sampling_target_only = None
+else:
+    top_k_renorm_prob = None
+    top_p_renorm_prob = None
+    tree_speculative_sampling_target_only = None
+
+
+def is_dflash_sampling_verify_available() -> bool:
+    """Return True when the CUDA sampling-verify kernels were imported."""
+    return _DFLASH_SAMPLING_VERIFY_AVAILABLE
+
+
+def scale_kv_cell_size_per_token_for_dflash(
+    *,
+    target_cell_size_per_token: int,
+    target_num_layers: int,
+    draft_num_layers: int,
+    draft_cell_size_per_token: Optional[int] = None,
+) -> int:
+    """Compute bytes/token budget for combined target+draft KV pools (DFLASH).
+
+    DFLASH runs a separate draft runner with its own KV pool. The target runner's
+    token capacity must fit both pools in aggregate.
+
+    Returns:
+        Approximate per-token bytes for (target KV + draft KV), expressed as a
+        scaled version of `target_cell_size_per_token`, unless an explicit
+        `draft_cell_size_per_token` is provided (in which case we sum them).
+    """
+    if target_cell_size_per_token <= 0:
+        raise ValueError(
+            "target_cell_size_per_token must be positive, "
+            f"got {target_cell_size_per_token}."
+        )
+
+    if draft_cell_size_per_token is not None:
+        draft_cell_size_per_token = int(draft_cell_size_per_token)
+        if draft_cell_size_per_token <= 0:
+            raise ValueError(
+                "draft_cell_size_per_token must be positive when provided, "
+                f"got {draft_cell_size_per_token}."
+            )
+        return int(target_cell_size_per_token) + int(draft_cell_size_per_token)
+
+    if target_num_layers <= 0 or draft_num_layers <= 0:
+        return int(target_cell_size_per_token)
+
+    # Ceiling division: ceil(target_cell_size * total_layers / target_num_layers),
+    # i.e. scale the per-token cost by the combined layer count.
+    total_layers = int(target_num_layers) + int(draft_num_layers)
+    return (
+        int(target_cell_size_per_token) * int(total_layers) + int(target_num_layers) - 1
+    ) // int(target_num_layers)
+
+
+@dataclass(frozen=True)
+class DFlashAutoMemoryPlan:
+    # Result bundle for resolve_dflash_auto_memory_plan.
+    max_mamba_cache_size: int
+    min_required_tokens: int
+    required_rest_memory_gb: float
+
+
+def resolve_dflash_concurrency_required_tokens(
+    *,
+    max_running_requests: int,
+    page_size: int,
+    speculative_num_draft_tokens: int,
+) -> int:
+    """Return the minimum KV-token budget needed to sustain the configured concurrency."""
+    if max_running_requests <= 0:
+        raise ValueError(
+            f"max_running_requests must be positive, got {max_running_requests}."
+        )
+    if page_size <= 0:
+        raise ValueError(f"page_size must be positive, got {page_size}.")
+    if speculative_num_draft_tokens < 0:
+        raise ValueError(
+            "speculative_num_draft_tokens must be non-negative, "
+            f"got {speculative_num_draft_tokens}."
+        )
+
+    estimated_max_decode_tokens_per_req = int(
+        envs.SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION.get()
+    )
+    # NOTE(review): the two page_size terms presumably cover page-alignment slack
+    # on prefix and tail pages — confirm against the allocator.
+    per_request_tokens = (
+        int(page_size)
+        + int(page_size)
+        + max(
+            int(estimated_max_decode_tokens_per_req),
+            2 * int(speculative_num_draft_tokens),
+        )
+    )
+    return int(max_running_requests) * per_request_tokens
+
+
+def resolve_dflash_max_mamba_cache_size(
+    *,
+    max_running_requests: int,
+    mamba_ratio: int,
+    explicit_max_mamba_cache_size: Optional[int] = None,
+) -> int:
+    """Return the mamba cache capacity: explicit override, else requests * ratio."""
+    if max_running_requests <= 0:
+        raise ValueError(
+            f"max_running_requests must be positive, got {max_running_requests}."
+        )
+    if mamba_ratio <= 0:
+        raise ValueError(f"mamba_ratio must be positive, got {mamba_ratio}.")
+    if explicit_max_mamba_cache_size is not None:
+        explicit_max_mamba_cache_size = int(explicit_max_mamba_cache_size)
+        if explicit_max_mamba_cache_size <= 0:
+            raise ValueError(
+                "explicit_max_mamba_cache_size must be positive when provided, "
+                f"got {explicit_max_mamba_cache_size}."
+            )
+        return explicit_max_mamba_cache_size
+    return int(max_running_requests) * int(mamba_ratio)
+
+
+def resolve_dflash_auto_memory_plan(
+    *,
+    rest_memory_gb: float,
+    post_model_load_memory_gb: float,
+    cell_size: int,
+    max_running_requests: int,
+    mamba_cache_per_req: int,
+    speculative_num_draft_tokens: int,
+    chunked_prefill_size: Optional[int],
+    max_prefill_tokens: int,
+    page_size: int,
+    mamba_ratio: int,
+    explicit_max_mamba_cache_size: Optional[int] = None,
+) -> DFlashAutoMemoryPlan:
+    """Validate inputs and size the DFLASH memory plan, raising if it cannot fit."""
+    if cell_size <= 0:
+        raise ValueError(f"cell_size must be positive, got {cell_size}.")
+    if mamba_cache_per_req <= 0:
+        raise ValueError(
+            f"mamba_cache_per_req must be positive, got {mamba_cache_per_req}."
+        )
+    if speculative_num_draft_tokens < 0:
+        raise ValueError(
+            "speculative_num_draft_tokens must be non-negative, "
+            f"got {speculative_num_draft_tokens}."
+        )
+    if max_prefill_tokens <= 0:
+        raise ValueError(
+            f"max_prefill_tokens must be positive, got {max_prefill_tokens}."
+        )
+    if page_size <= 0:
+        raise ValueError(f"page_size must be positive, got {page_size}.")
+
+    max_mamba_cache_size = resolve_dflash_max_mamba_cache_size(
+        max_running_requests=max_running_requests,
+        mamba_ratio=mamba_ratio,
+        explicit_max_mamba_cache_size=explicit_max_mamba_cache_size,
+    )
+
+    # min_required_tokens = max(prefill budget, one page, concurrency budget).
+    if chunked_prefill_size is not None and int(chunked_prefill_size) > 0:
+        min_required_tokens = int(chunked_prefill_size)
+    else:
+        min_required_tokens = int(max_prefill_tokens)
+    min_required_tokens = max(min_required_tokens, int(page_size))
+    min_required_tokens = max(
+        min_required_tokens,
+        resolve_dflash_concurrency_required_tokens(
+            max_running_requests=max_running_requests,
+            page_size=page_size,
+            speculative_num_draft_tokens=speculative_num_draft_tokens,
+        ),
+    )
+
+    linear_state_bytes = int(mamba_cache_per_req) * (
+        int(max_mamba_cache_size)
+        + int(max_running_requests) * int(speculative_num_draft_tokens)
+    )
+    required_rest_memory_gb = (
+        linear_state_bytes + int(min_required_tokens) * int(cell_size)
+    ) / float(1 << 30)
+    if required_rest_memory_gb > float(post_model_load_memory_gb):
+        raise RuntimeError(
+            "Not enough GPU memory for DFLASH auto sizing. "
+            f"Required at least {required_rest_memory_gb:.2f} GB after weight load, "
+            f"but only {float(post_model_load_memory_gb):.2f} GB is available. "
+            f"max_running_requests={max_running_requests}, "
+            f"max_mamba_cache_size={max_mamba_cache_size}, "
+            f"min_required_tokens={min_required_tokens}."
+        )
+
+    return DFlashAutoMemoryPlan(
+        max_mamba_cache_size=int(max_mamba_cache_size),
+        min_required_tokens=int(min_required_tokens),
+        required_rest_memory_gb=max(
+            float(rest_memory_gb), float(required_rest_memory_gb)
+        ),
+    )
+
+
+def resolve_dflash_verify_mask_policy(attn_backend: Any) -> tuple[str, bool]:
+    """Return (innermost backend class name, whether a custom verify mask is needed).
+
+    Unwraps at most 4 levels of `full_attn_backend` wrappers before deciding.
+    """
+    backend = attn_backend
+    for _ in range(4):
+        full_backend = getattr(backend, "full_attn_backend", None)
+        if full_backend is None:
+            break
+        backend = full_backend
+    backend_name = type(backend).__name__
+    return backend_name, (backend_name not in _DFLASH_VERIFY_SKIP_CUSTOM_MASK_BACKENDS)
+
+
+def apply_dflash_verify_logits_adjustments(
+    *,
+    next_token_logits: torch.Tensor,
+    sampling_info: Any,
+    draft_token_num: int,
+) -> None:
+    """Apply sampling-time logit adjustments for DFlash verify in place.
+
+    This keeps v1 and v2 verify semantics aligned while letting overlap scheduling
+    use the cheaper precomputed `acc_linear_penalties` path instead of allocating a
+    repeated `[bs * draft_token_num, vocab]` penalty tensor every step.
+    """
+    if sampling_info is None:
+        return
+    if next_token_logits.ndim != 2:
+        raise ValueError(
+            "next_token_logits must be 2D, "
+            f"got shape={tuple(next_token_logits.shape)}."
+        )
+    if draft_token_num <= 0:
+        raise ValueError(f"draft_token_num must be positive, got {draft_token_num}.")
+
+    bs = len(sampling_info)
+    if next_token_logits.shape[0] != bs * draft_token_num:
+        raise ValueError(
+            "next_token_logits row count mismatch for DFlash verify adjustments. "
+            f"Expected {bs * draft_token_num}, got {next_token_logits.shape[0]}."
+        )
+
+    if sampling_info.has_custom_logit_processor:
+        apply_custom_logit_processor(
+            next_token_logits,
+            sampling_info,
+            num_tokens_in_batch=draft_token_num,
+        )
+
+    acc_linear_penalties = getattr(sampling_info, "acc_linear_penalties", None)
+    penalizer = getattr(sampling_info, "penalizer_orchestrator", None)
+    vocab_mask = getattr(sampling_info, "vocab_mask", None)
+    logit_bias = getattr(sampling_info, "logit_bias", None)
+
+    logits_3d: Optional[torch.Tensor] = None
+
+    def get_logits_3d() -> torch.Tensor:
+        # Lazily view the flat [bs * draft_token_num, vocab] logits as
+        # [bs, draft_token_num, vocab] so per-request terms can be broadcast.
+        nonlocal logits_3d
+        if logits_3d is None:
+            logits_3d = next_token_logits.reshape(bs, draft_token_num, -1)
+        return logits_3d
+
+    # Dense fallback only when we need live penalizer application or a vocab mask.
+    # In overlap scheduling the common path is `acc_linear_penalties`, which can be
+    # broadcast over the verify block without materializing a repeated buffer.
+    if (
+        penalizer is not None and penalizer.is_required and acc_linear_penalties is None
+    ) or vocab_mask is not None:
+        linear_penalty = torch.zeros(
+            (bs, next_token_logits.shape[1]),
+            dtype=torch.float32,
+            device=next_token_logits.device,
+        )
+        sampling_info.apply_logits_bias(linear_penalty)
+        get_logits_3d().add_(
+            linear_penalty[:, None, :].to(dtype=next_token_logits.dtype)
+        )
+        # NOTE(review): this path returns early, presumably because
+        # apply_logits_bias already folds in logit_bias — confirm.
+        return
+
+    if acc_linear_penalties is not None:
+        if (
+            acc_linear_penalties.device != next_token_logits.device
+            or acc_linear_penalties.dtype != next_token_logits.dtype
+        ):
+            acc_linear_penalties = acc_linear_penalties.to(
+                device=next_token_logits.device,
+                dtype=next_token_logits.dtype,
+            )
+        get_logits_3d().add_(acc_linear_penalties[:, None, :])
+
+    if logit_bias is not None:
+        if (
+            logit_bias.device != next_token_logits.device
+            or logit_bias.dtype != next_token_logits.dtype
+        ):
+            logit_bias = logit_bias.to(
+                device=next_token_logits.device,
+                dtype=next_token_logits.dtype,
+            )
+        get_logits_3d().add_(logit_bias[:, None, :])
+
+
+def _get_or_create_chain_verify_buffers(
+    *,
+    bs: int,
+    draft_token_num: int,
+    device: torch.device,
+) -> tuple[
+    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
+]:
+    # Returns views into cached scratch tensors sized for at least `bs` rows.
+    # Capacity grows geometrically (doubling) to amortize reallocation.
+    # NOTE(review): `predicts`/`accept_index`/`accept_token_num` are returned
+    # uninitialized — presumably fully overwritten by the verify kernel; confirm.
+    key = (device.index, int(draft_token_num))
+    cached = _DFLASH_CHAIN_VERIFY_BUFFERS.get(key)
+    cap_bs = 0 if cached is None else int(cached["cap_bs"])
+    if cap_bs < bs:
+        new_cap = max(int(bs), cap_bs * 2 if cap_bs > 0 else int(bs))
+        retrieve_index = torch.arange(
+            new_cap * draft_token_num, dtype=torch.int64, device=device
+        ).view(new_cap, draft_token_num)
+        # Chain topology: each level points to the next; -1 terminates the chain.
+        row_next = torch.arange(
+            1, draft_token_num + 1, dtype=torch.int64, device=device
+        )
+        row_next[-1] = -1
+        retrieve_next_token = row_next.unsqueeze(0).expand(new_cap, -1).clone()
+        retrieve_next_sibling = torch.full(
+            (new_cap, draft_token_num), -1, dtype=torch.int64, device=device
+        )
+        predicts = torch.empty(
+            (new_cap * draft_token_num,), dtype=torch.int32, device=device
+        )
+        accept_index = torch.empty(
+            (new_cap, draft_token_num), dtype=torch.int32, device=device
+        )
+        accept_token_num = torch.empty((new_cap,), dtype=torch.int32, device=device)
+        cached = {
+            "cap_bs": int(new_cap),
+            "retrieve_index": retrieve_index,
+            "retrieve_next_token": retrieve_next_token,
+            "retrieve_next_sibling": retrieve_next_sibling,
+            "predicts": predicts,
+            "accept_index": accept_index,
+            "accept_token_num": accept_token_num,
+        }
+        _DFLASH_CHAIN_VERIFY_BUFFERS[key] = cached
+
+    assert cached is not None
+    retrieve_index = cached["retrieve_index"][:bs]
+    retrieve_next_token = cached["retrieve_next_token"][:bs]
+    retrieve_next_sibling = cached["retrieve_next_sibling"][:bs]
+    predicts = cached["predicts"][: bs * draft_token_num]
+    accept_index = cached["accept_index"][:bs]
+    accept_token_num = cached["accept_token_num"][:bs]
+    return (
+        retrieve_index,
+        retrieve_next_token,
+        retrieve_next_sibling,
+        predicts,
+        accept_index,
+        accept_token_num,
+    )
+
+
+def build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> List[int]:
+    """Select target layer indices used to build DFlash context features.
+
+    Args:
+        num_target_layers: Number of transformer layers in the runtime target model.
+        num_draft_layers: Number of layers in the DFlash draft model.
+
+    Returns:
+        A list of 0-based target layer indices of length `num_draft_layers`.
+
+    Notes:
+        - DFlash uses hidden states after each selected target layer (HF-style).
+        - SGLang captures "before layer i", so the model hook will typically add +1
+          when mapping to capture points.
+    """
+    if num_target_layers <= 0:
+        raise ValueError(
+            f"num_target_layers must be positive, got {num_target_layers}."
+        )
+    if num_draft_layers <= 0:
+        raise ValueError(f"num_draft_layers must be positive, got {num_draft_layers}.")
+
+    if num_draft_layers == 1:
+        return [num_target_layers // 2]
+
+    # Spread the draft layers evenly over [1, num_target_layers - 3].
+    start = 1
+    end = num_target_layers - 3
+    if end < start:
+        raise ValueError(
+            "DFlash layer selection requires num_target_layers >= 4. "
+            f"Got num_target_layers={num_target_layers}."
+        )
+
+    span = end - start
+    return [
+        int(round(start + (i * span) / (num_draft_layers - 1)))
+        for i in range(num_draft_layers)
+    ]
+
+
+def _cfg_get(config: Any, key: str, default: Any = None) -> Any:
+    # Uniform lookup over dict-style and attribute-style configs.
+    if isinstance(config, dict):
+        return config.get(key, default)
+    return getattr(config, key, default)
+
+
+def _get_text_config(config: Any) -> Any:
+    # Resolve the text sub-config of a (possibly multimodal) HF config; falls
+    # back to the config itself when no text_config is present.
+    if config is None:
+        return None
+    if isinstance(config, dict):
+        return config.get("text_config", config)
+    text_config = getattr(config, "text_config", None)
+    if text_config is not None:
+        return text_config
+    get_text_config = getattr(config, "get_text_config", None)
+    if callable(get_text_config):
+        try:
+            resolved = get_text_config()
+            if resolved is not None:
+                return resolved
+        except TypeError:
+            # Some configs expose get_text_config with a different signature.
+            pass
+    return config
+
+
+def _get_dflash_config(config: Any) -> dict:
+    # Extract the `dflash_config` section as a plain dict; {} when absent or
+    # not convertible.
+    if isinstance(config, dict):
+        cfg = config.get("dflash_config", None)
+    else:
+        cfg = getattr(config, "dflash_config", None)
+    if cfg is None:
+        return {}
+    if isinstance(cfg, dict):
+        return cfg
+
+    try:
+        return dict(cfg)
+    except Exception:
+        return {}
+
+
+def _parse_optional_int(
+    value: Any,
+    *,
+    field_name: str,
+    min_value: Optional[int] = None,
+) -> Optional[int]:
+    # Parse an optional integer config field with a descriptive error message.
+    if value is None:
+        return None
+    try:
+        parsed = int(value)
+    except Exception as e:
+        raise ValueError(f"Invalid {field_name}={value!r}.") from e
+    if min_value is not None and parsed < int(min_value):
+        comparator = "positive" if int(min_value) == 1 else f">= {int(min_value)}"
+        raise ValueError(f"{field_name} must be {comparator}, got {parsed}.")
+    return parsed
+
+
+@dataclass(frozen=True)
+class DFlashDraftConfig:
+    # Parsed, validated view of the DFLASH draft model config.
+    num_hidden_layers: Optional[int]
+    num_target_layers: Optional[int]
+    block_size: Optional[int]
+    target_layer_ids: Optional[List[int]]
+    mask_token: str
+    mask_token_id: Optional[int]
+
+    def require_num_layers(self) -> int:
+        if self.num_hidden_layers is None:
+            raise ValueError(
+                "DFLASH requires draft num_hidden_layers in config. "
+                "Got config without num_hidden_layers."
+            )
+        return int(self.num_hidden_layers)
+
+    def resolve_block_size(self, *, default: Optional[int] = None) -> Optional[int]:
+        return self.block_size if self.block_size is not None else default
+
+    def resolve_target_layer_ids(
+        self,
+        *,
+        target_num_layers: int,
+        draft_num_layers: Optional[int] = None,
+    ) -> List[int]:
+        # Use explicit ids from config when present (validated against the
+        # target depth); otherwise derive them via build_target_layer_ids.
+        target_num_layers = int(target_num_layers)
+        if target_num_layers <= 0:
+            raise ValueError(
+                f"target_num_layers must be positive, got {target_num_layers}."
+            )
+
+        if self.target_layer_ids is None:
+            if draft_num_layers is None:
+                draft_num_layers = self.require_num_layers()
+            return build_target_layer_ids(target_num_layers, int(draft_num_layers))
+
+        resolved = list(self.target_layer_ids)
+        if len(resolved) <= 0:
+            raise ValueError(
+                "DFLASH dflash_config.target_layer_ids must be non-empty. "
+                f"Got len(target_layer_ids)={len(resolved)}."
+            )
+        for idx, val in enumerate(resolved):
+            if val < 0 or val >= target_num_layers:
+                raise ValueError(
+                    "DFLASH target_layer_ids contains an out-of-range layer id. "
+                    f"target_layer_ids[{idx}]={val}, target_num_layers={target_num_layers}."
+                )
+        return resolved
+
+
+def parse_dflash_draft_config(*, draft_hf_config: Any) -> DFlashDraftConfig:
+    """Parse and validate DFLASH draft config fields from HF config/dict."""
+    dflash_cfg = _get_dflash_config(draft_hf_config)
+    draft_text_config = _get_text_config(draft_hf_config)
+
+    num_hidden_layers = _parse_optional_int(
+        _cfg_get(draft_text_config, "num_hidden_layers", None),
+        field_name="DFLASH draft num_hidden_layers",
+        min_value=1,
+    )
+    # Each field prefers the dflash_config section and falls back to top-level.
+    raw_num_target_layers = dflash_cfg.get(
+        "num_target_layers",
+        _cfg_get(draft_hf_config, "num_target_layers", None),
+    )
+    num_target_layers = _parse_optional_int(
+        raw_num_target_layers,
+        field_name="DFLASH draft num_target_layers",
+        min_value=1,
+    )
+
+    # Keep support for current checkpoints where block_size is top-level.
+    raw_block_size = dflash_cfg.get(
+        "block_size",
+        _cfg_get(draft_hf_config, "block_size", None),
+    )
+    block_size = _parse_optional_int(
+        raw_block_size,
+        field_name="DFLASH block_size",
+        min_value=1,
+    )
+
+    layer_ids = dflash_cfg.get(
+        "target_layer_ids",
+        _cfg_get(draft_hf_config, "target_layer_ids", None),
+    )
+    parsed_target_layer_ids: Optional[List[int]]
+    if layer_ids is None:
+        parsed_target_layer_ids = None
+    else:
+        if not isinstance(layer_ids, (list, tuple)):
+            raise ValueError(
+                "DFLASH dflash_config.target_layer_ids must be a list of ints, "
+                f"got type={type(layer_ids).__name__}."
+            )
+        parsed_target_layer_ids = [int(x) for x in layer_ids]
+        if len(parsed_target_layer_ids) <= 0:
+            raise ValueError(
+                "DFLASH dflash_config.target_layer_ids must be non-empty. "
+                f"Got len(target_layer_ids)={len(parsed_target_layer_ids)}."
+ ) + + mask_token = dflash_cfg.get("mask_token", None) + if mask_token is None: + mask_token = DEFAULT_DFLASH_MASK_TOKEN + if not isinstance(mask_token, str) or not mask_token: + raise ValueError( + "DFLASH dflash_config.mask_token must be a non-empty string, " + f"got {mask_token!r}." + ) + + mask_token_id = dflash_cfg.get("mask_token_id", None) + if mask_token_id is not None: + if not isinstance(mask_token_id, Integral) or isinstance(mask_token_id, bool): + raise ValueError( + "DFLASH dflash_config.mask_token_id must be an integer, " + f"got {mask_token_id!r} (type={type(mask_token_id).__name__})." + ) + mask_token_id = int(mask_token_id) + if mask_token_id < 0: + raise ValueError( + "DFLASH dflash_config.mask_token_id must be non-negative, " + f"got {mask_token_id}." + ) + + return DFlashDraftConfig( + num_hidden_layers=num_hidden_layers, + num_target_layers=num_target_layers, + block_size=block_size, + target_layer_ids=parsed_target_layer_ids, + mask_token=mask_token, + mask_token_id=mask_token_id, + ) + + +def can_dflash_slice_qkv_weight(qkv_proj: Any) -> Tuple[bool, str]: + """Validate whether DFlash can slice KV weights from a fused QKV linear layer.""" + quant_method = getattr(qkv_proj, "quant_method", None) + if not isinstance(quant_method, UnquantizedLinearMethod): + return ( + False, + "quantized qkv_proj is not supported for this path " + f"(quant_method={type(quant_method).__name__})", + ) + if not hasattr(qkv_proj, "weight"): + return False, "qkv weight tensor is missing" + return True, "" + + +def can_dflash_use_fused_qkv_proj(qkv_proj: Any) -> Tuple[bool, str]: + """Validate whether a QKV layer is eligible for DFlash fused KV materialization.""" + eligible, reason = can_dflash_slice_qkv_weight(qkv_proj) + if not eligible: + return False, reason + if getattr(qkv_proj, "bias", None) is not None: + return False, "qkv bias is not supported for fused KV path" + return True, "" + + +def compute_dflash_accept_len_and_bonus( + *, + candidates: 
torch.Tensor, + target_predict: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute DFlash accept lengths and bonus tokens (greedy verify rule). + + Args: + candidates: Token ids proposed by the DFlash draft, including the current token. + Shape: [bs, block_size]. candidates[:, 0] is the current token. + target_predict: Token ids predicted by the target model for each position in the block. + Shape: [bs, block_size]. target_predict[:, t] corresponds to argmax at position t. + + Returns: + accept_len: int32 tensor [bs], number of accepted *draft* tokens (excluding current token and bonus token). + bonus: int64 tensor [bs], the target-predicted token at index accept_len (the "bonus" token to append). + + Notes: + Matches the reference implementation rule: + accept while candidates[:, 1:] == target_predict[:, :-1] consecutively. + """ + if candidates.ndim != 2: + raise ValueError(f"candidates must be 2D, got shape={tuple(candidates.shape)}") + if target_predict.shape != candidates.shape: + raise ValueError( + "target_predict must have the same shape as candidates. 
" + f"candidates.shape={tuple(candidates.shape)}, target_predict.shape={tuple(target_predict.shape)}" + ) + + bs, block_size = candidates.shape + if bs <= 0: + raise ValueError(f"batch size must be positive, got {bs}.") + if block_size <= 0: + raise ValueError(f"block_size must be positive, got {block_size}.") + + matches = candidates[:, 1:] == target_predict[:, :-1] + accept_len = matches.to(torch.int32).cumprod(dim=1).sum(dim=1) + bonus = target_predict[torch.arange(bs, device=target_predict.device), accept_len] + return accept_len, bonus.to(torch.int64) + + +def compute_dflash_sampling_accept_len_and_bonus( + *, + candidates: torch.Tensor, + next_token_logits: torch.Tensor, + sampling_info: Any, + max_top_k: Optional[int] = None, + uniform_top_k_value: Optional[int] = None, + threshold_single: Optional[float] = None, + threshold_acc: Optional[float] = None, + uniform_samples: Optional[torch.Tensor] = None, + uniform_samples_for_final_sampling: Optional[torch.Tensor] = None, + use_sparse_topk: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute DFlash accept lengths and bonus tokens for non-greedy sampling. + + This is a chain-specialized variant of speculative target-only verification: + - DFlash proposals are linear (topk == 1), so each verify level has at most one candidate. + - When a candidate is rejected at a level, the final token is sampled from + `relu(q - p)` where `p` has only the rejected candidate mass. + """ + if not _DFLASH_SAMPLING_VERIFY_AVAILABLE: + raise RuntimeError( + "DFLASH non-greedy verification is unavailable on this build/device." + ) + if candidates.ndim != 2: + raise ValueError(f"candidates must be 2D, got shape={tuple(candidates.shape)}") + if next_token_logits.ndim != 2: + raise ValueError( + "next_token_logits must be 2D, " + f"got shape={tuple(next_token_logits.shape)}." 
+ ) + + bs, draft_token_num = candidates.shape + if bs <= 0: + raise ValueError(f"batch size must be positive, got {bs}.") + if draft_token_num <= 0: + raise ValueError(f"draft_token_num must be positive, got {draft_token_num}.") + if next_token_logits.shape[0] != bs * draft_token_num: + raise ValueError( + "next_token_logits row count mismatch. " + f"Expected {bs * draft_token_num}, got {next_token_logits.shape[0]}." + ) + if candidates.device != next_token_logits.device: + raise ValueError( + "candidates and next_token_logits must be on the same device, " + f"got {candidates.device} and {next_token_logits.device}." + ) + + if threshold_single is None: + from sglang.srt.server_args import get_global_server_args + + threshold_single = get_global_server_args().speculative_accept_threshold_single + if threshold_acc is None: + from sglang.srt.server_args import get_global_server_args + + threshold_acc = get_global_server_args().speculative_accept_threshold_acc + threshold_single = float(threshold_single) + threshold_acc = max(float(threshold_acc), 1e-9) + + device = next_token_logits.device + + if uniform_samples is None: + uniform_samples = torch.rand( + (bs, draft_token_num), dtype=torch.float32, device=device + ) + else: + if uniform_samples.shape != (bs, draft_token_num): + raise ValueError( + "uniform_samples shape mismatch. " + f"Expected {(bs, draft_token_num)}, got {tuple(uniform_samples.shape)}." + ) + uniform_samples = uniform_samples.to(device=device, dtype=torch.float32) + + if uniform_samples_for_final_sampling is None: + uniform_samples_for_final_sampling = torch.rand( + (bs,), dtype=torch.float32, device=device + ) + else: + if uniform_samples_for_final_sampling.shape != (bs,): + raise ValueError( + "uniform_samples_for_final_sampling shape mismatch. " + f"Expected {(bs,)}, got {tuple(uniform_samples_for_final_sampling.shape)}." 
+ ) + uniform_samples_for_final_sampling = uniform_samples_for_final_sampling.to( + device=device, + dtype=torch.float32, + ) + + need_top_k = bool(getattr(sampling_info, "need_top_k_sampling", True)) + need_top_p = bool(getattr(sampling_info, "need_top_p_sampling", False)) + # Build target distribution once over all verify rows. + expanded_temperature = torch.repeat_interleave( + sampling_info.temperatures, draft_token_num, dim=0 + ) + scaled_logits = next_token_logits / expanded_temperature + sparse_topk_applied = False + + if use_sparse_topk and need_top_k: + repeated_top_ks = torch.repeat_interleave( + sampling_info.top_ks, draft_token_num, dim=0 + ).to(dtype=torch.int64) + vocab_size = int(scaled_logits.shape[-1]) + repeated_top_ks.clamp_(min=1, max=vocab_size) + if max_top_k is None: + max_top_k = int(repeated_top_ks.max().item()) + else: + max_top_k = int(max_top_k) + if max_top_k < 1: + max_top_k = 1 + elif max_top_k > vocab_size: + max_top_k = vocab_size + + # Sparse exact path for top-k/top-p (top-k-first semantics), then scatter to dense. 
+ if 0 < max_top_k < vocab_size: + topk_logits, topk_indices = torch.topk(scaled_logits, k=max_top_k, dim=-1) + if uniform_top_k_value is None or int(uniform_top_k_value) != max_top_k: + ranks = torch.arange(max_top_k, device=device, dtype=torch.int64)[ + None, : + ] + valid = ranks < repeated_top_ks.unsqueeze(1) + topk_logits = topk_logits.masked_fill(~valid, float("-inf")) + + topk_probs = F.softmax(topk_logits, dim=-1) + if need_top_p: + repeated_top_ps = torch.repeat_interleave( + sampling_info.top_ps, draft_token_num, dim=0 + ) + topk_probs = top_p_renorm_prob(topk_probs, repeated_top_ps) + + target_probs = torch.zeros_like(scaled_logits, dtype=topk_probs.dtype) + target_probs.scatter_(1, topk_indices, topk_probs) + sparse_topk_applied = True + + if not sparse_topk_applied: + target_probs = F.softmax(scaled_logits, dim=-1) + if need_top_k: + target_probs = top_k_renorm_prob( + target_probs, + torch.repeat_interleave(sampling_info.top_ks, draft_token_num, dim=0), + ) + if need_top_p: + target_probs = top_p_renorm_prob( + target_probs, + torch.repeat_interleave(sampling_info.top_ps, draft_token_num, dim=0), + ) + target_probs = target_probs.view(bs, draft_token_num, -1).contiguous() + draft_probs = torch.zeros_like(target_probs) + + ( + retrieve_index, + retrieve_next_token, + retrieve_next_sibling, + predicts, + accept_index, + accept_token_num, + ) = _get_or_create_chain_verify_buffers( + bs=bs, + draft_token_num=draft_token_num, + device=device, + ) + candidates_i64 = ( + candidates if candidates.dtype == torch.int64 else candidates.to(torch.int64) + ) + tree_speculative_sampling_target_only( + predicts=predicts, + accept_index=accept_index, + accept_token_num=accept_token_num, + candidates=candidates_i64, + retrive_index=retrieve_index, + retrive_next_token=retrieve_next_token, + retrive_next_sibling=retrieve_next_sibling, + uniform_samples=uniform_samples, + uniform_samples_for_final_sampling=uniform_samples_for_final_sampling, + target_probs=target_probs, 
+ draft_probs=draft_probs, + threshold_single=threshold_single, + threshold_acc=threshold_acc, + deterministic=True, + ) + + accept_len = accept_token_num + row_ids = torch.arange(bs, dtype=torch.long, device=device) + accept_pos = accept_index[row_ids, accept_len.to(torch.long)].to(torch.long) + bonus = predicts[accept_pos].to(torch.int64) + return accept_len, bonus diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py new file mode 100644 index 000000000000..d6e4fcd159ec --- /dev/null +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -0,0 +1,1369 @@ +import logging +import math +from copy import deepcopy +from typing import Optional, Union + +import torch + +from sglang.srt.distributed import get_tp_group +from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch +from sglang.srt.managers.scheduler import GenerationBatchResult +from sglang.srt.managers.tp_worker import TpModelWorker +from sglang.srt.mem_cache.common import get_last_loc +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, +) +from sglang.srt.server_args import ( + ServerArgs, + get_global_server_args, + set_global_server_args_for_scheduler, +) +from sglang.srt.speculative.dflash_info import DFlashDraftInput, DFlashVerifyInput +from sglang.srt.speculative.dflash_utils import ( + can_dflash_use_fused_qkv_proj, + is_dflash_sampling_verify_available, + parse_dflash_draft_config, + resolve_dflash_verify_mask_policy, +) +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func +from sglang.srt.utils import is_cuda + +logger = logging.getLogger(__name__) + +_FusedKVMaterializeHelper = None + + +def _get_fused_kv_materialize_helper(): + global _FusedKVMaterializeHelper + if _FusedKVMaterializeHelper is None: + from sglang.srt.speculative.triton_ops.fused_kv_materialize import ( + 
            FusedKVMaterializeHelper,
        )

        _FusedKVMaterializeHelper = FusedKVMaterializeHelper
    return _FusedKVMaterializeHelper


class DFlashWorker:
    """DFlash speculative decoding worker (spec-v1, tp>=1/pp=1)."""

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        moe_ep_rank: int,
        attn_cp_rank: int,
        moe_dp_rank: int,
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        """Build the DFLASH draft worker alongside an existing target worker.

        Launches a second `TpModelWorker` for the draft model that shares the
        target's token-to-KV allocator (and, without draft windowing, the
        target's req->token table), resolves the draft block size and mask
        token id, and pre-allocates per-batch staging buffers.
        """
        self.server_args = server_args
        self.gpu_id = gpu_id
        self.tp_rank = tp_rank
        self.dp_rank = dp_rank
        self.moe_ep_rank = moe_ep_rank
        self.attn_cp_rank = attn_cp_rank
        self.moe_dp_rank = moe_dp_rank
        self.nccl_port = nccl_port
        self.target_worker = target_worker
        self.model_runner = target_worker.model_runner
        self.page_size = server_args.page_size
        self.draft_window_size: Optional[int] = (
            int(server_args.speculative_dflash_draft_window_size)
            if server_args.speculative_dflash_draft_window_size is not None
            else None
        )
        # Compact (windowed) draft cache is enabled iff a window size was given.
        self.use_compact_draft_cache = self.draft_window_size is not None
        self.device = target_worker.device

        self._warned_sampling_fallback = False
        self._logged_first_verify = False

        # Draft runner (separate KV cache + attention backend).
        # Without draft windowing, the draft worker aliases the target request->token
        # mapping and allocation state. With draft windowing enabled, the draft worker
        # keeps a private compact req->token table over the same global KV index space,
        # so radix-cache/prefix-hit KV remains reusable while draft attention sees only
        # the recent window.
        target_req_to_token_pool, target_token_to_kv_pool_allocator = (
            target_worker.get_memory_pool()
        )
        shared_req_to_token_pool = (
            None if self.use_compact_draft_cache else target_req_to_token_pool
        )
        draft_server_args = deepcopy(server_args)
        draft_server_args.skip_tokenizer_init = True
        # Resolve the draft attention backend, falling back to 'flashinfer' for
        # unsupported or unspecified backends.
        draft_backend = draft_server_args.speculative_draft_attention_backend
        supported_draft_backends = ("flashinfer", "fa3", "fa4")
        if draft_backend is None:
            draft_backend, _ = draft_server_args.get_attention_backends()
            if draft_backend is None:
                draft_backend = "flashinfer"
        elif draft_backend == "trtllm_mha":
            logger.warning(
                "DFLASH draft worker does not support 'trtllm_mha' because the "
                "draft path requires non-causal attention. Falling back to "
                "'flashinfer'."
            )
            draft_backend = "flashinfer"
        elif draft_backend not in supported_draft_backends:
            logger.warning(
                "DFLASH draft worker only supports attention_backend in %s for now, "
                "but got %r. Falling back to 'flashinfer'.",
                supported_draft_backends,
                draft_backend,
            )
            draft_backend = "flashinfer"
        # Make the draft worker backend explicit and self-contained (no further overrides).
        draft_server_args.speculative_draft_attention_backend = None
        draft_server_args.prefill_attention_backend = None
        draft_server_args.decode_attention_backend = None
        draft_server_args.attention_backend = draft_backend
        # Keep draft context length aligned with the target.
        draft_server_args.context_length = (
            target_worker.model_runner.model_config.context_len
        )
        # Constructing the draft TpModelWorker mutates the global server args;
        # save and restore them around the construction.
        saved_server_args = get_global_server_args()
        self.draft_worker = TpModelWorker(
            server_args=draft_server_args,
            gpu_id=gpu_id,
            tp_rank=tp_rank,
            moe_ep_rank=moe_ep_rank,
            pp_rank=0,
            attn_cp_rank=attn_cp_rank,
            moe_dp_rank=moe_dp_rank,
            dp_rank=dp_rank,
            nccl_port=nccl_port,
            is_draft_worker=True,
            req_to_token_pool=shared_req_to_token_pool,
            token_to_kv_pool_allocator=target_token_to_kv_pool_allocator,
        )
        set_global_server_args_for_scheduler(saved_server_args)
        self.draft_model_runner = self.draft_worker.model_runner
        # Keep the same alias that other spec-v2 workers expose.
        self.draft_worker.draft_runner = self.draft_model_runner
        self.draft_model = self.draft_model_runner.model
        draft_config = parse_dflash_draft_config(
            draft_hf_config=self.draft_model_runner.model_config.hf_config
        )
        if server_args.speculative_num_draft_tokens is None:
            # Should not happen (ServerArgs should have inferred it), but keep a fallback.
            self.block_size = int(draft_config.resolve_block_size(default=16))
        else:
            self.block_size = int(server_args.speculative_num_draft_tokens)
        model_block_size = draft_config.block_size
        if model_block_size is None:
            model_block_size = getattr(self.draft_model, "block_size", None)
        if model_block_size is not None and int(model_block_size) != int(
            self.block_size
        ):
            logger.warning(
                "DFLASH block size mismatch: using speculative_num_draft_tokens=%s but draft config block_size=%s.",
                self.block_size,
                model_block_size,
            )
        self.speculative_num_draft_tokens = int(self.block_size)

        self._mask_token = draft_config.mask_token
        self._mask_token_id_override = draft_config.mask_token_id
        self._mask_token_id = self._resolve_mask_token_id(
            mask_token=self._mask_token,
            mask_token_id=self._mask_token_id_override,
        )
        if self.tp_rank == 0:
            logger.info(
                "Initialized DFLASH draft runner. attention_backend=%s, model=%s, block_size=%s, draft_window_size=%s, compact_cache=%s",
                getattr(draft_server_args, "attention_backend", None),
                self.draft_model.__class__.__name__,
                self.block_size,
                self.draft_window_size,
                self.use_compact_draft_cache,
            )
            logger.info(
                "DFLASH draft runner ready. mask_token=%s, mask_token_id=%s, mask_token_id_override=%s",
                self._mask_token,
                self._mask_token_id,
                self._mask_token_id_override,
            )

        # Reusable arange [0, block_size) used for block position construction.
        self._block_pos_offsets = torch.arange(
            self.block_size, device=self.device, dtype=torch.int64
        )
        # Lazily-grown per-batch staging buffers (see _ensure_draft_block_buffers).
        self._draft_block_ids_buf: Optional[torch.Tensor] = None  # [cap_bs, block_size]
        self._draft_block_positions_buf: Optional[torch.Tensor] = (
            None  # [cap_bs, block_size]
        )
        self._draft_block_tokens_buf: Optional[torch.Tensor] = (
            None  # [cap_bs, block_size]
        )
        self._draft_block_end_buf: Optional[torch.Tensor] = None  # [cap_bs]
        self._draft_seq_lens_cpu_buf: Optional[torch.Tensor] = None  # [cap_bs] on CPU
        # Reused spec_info object for the fixed-size draft block forward.
        self._draft_block_spec_info = DFlashVerifyInput(
            draft_token=torch.empty((0,), dtype=torch.long, device=self.device),
            positions=torch.empty((0,), dtype=torch.int64, device=self.device),
            draft_token_num=int(self.block_size),
            custom_mask=None,
            capture_hidden_mode=CaptureHiddenMode.NULL,
        )
        # Buffers for the TP-safe greedy argmax in
        # _greedy_sample_from_vocab_parallel_head (allocated on demand).
        self._draft_greedy_gathered_max_buf: Optional[torch.Tensor] = None
        self._draft_greedy_gathered_ids_buf: Optional[torch.Tensor] = None
        self._draft_greedy_gather_cap: int = 0
        self._draft_greedy_best_rank_buf: Optional[torch.Tensor] = None
        self._draft_greedy_rank_index_buf: Optional[torch.Tensor] = None
        self._draft_greedy_selected_ids_buf: Optional[torch.Tensor] = None
        self._draft_greedy_index_cap: int = 0

        # Fused KV materialization is CUDA-only; may be disabled again below.
        self._use_fused_kv_materialize = is_cuda()
        self._fused_kv_helper: Optional[object] = None
        if self._use_fused_kv_materialize:
            self._init_fused_kv_helper()

    def _init_fused_kv_helper(self) -> None:
        """Initialize the fused KV materialization helper with pre-stacked weights."""
        try:
            layers = self.draft_model.layers
            fused_disable_reason: Optional[str] = None

            if len(layers) == 0:
                fused_disable_reason = "no layers found"

            # Every layer must satisfy the fused-path preconditions; the first
            # violation disables the fused path for the whole model.
            for layer_idx, layer in enumerate(layers):
                attn = layer.self_attn
                eligible, reason = can_dflash_use_fused_qkv_proj(attn.qkv_proj)
                if not eligible:
                    fused_disable_reason = f"{reason}: layer={layer_idx}"
                    break

                # Keep semantics aligned with set_kv_buffer scaling behavior.
                k_scale = getattr(attn.attn, "k_scale", None)
                v_scale = getattr(attn.attn, "v_scale", None)
                if k_scale is not None and not math.isclose(float(k_scale), 1.0):
                    fused_disable_reason = (
                        "non-unit k_scale is not supported for fused KV path: "
                        f"layer={layer_idx}, k_scale={k_scale}"
                    )
                    break
                if v_scale is not None and not math.isclose(float(v_scale), 1.0):
                    fused_disable_reason = (
                        "non-unit v_scale is not supported for fused KV path: "
                        f"layer={layer_idx}, v_scale={v_scale}"
                    )
                    break

                rope_is_neox_style = bool(
                    getattr(attn.rotary_emb, "is_neox_style", True)
                )
                if not rope_is_neox_style:
                    fused_disable_reason = (
                        "non-neox RoPE is not supported for fused KV path: "
                        f"layer={layer_idx}, rope_is_neox_style={rope_is_neox_style}"
                    )
                    break

            if fused_disable_reason is not None:
                if self.tp_rank == 0:
                    logger.info(
                        "DFLASH fused KV materialization disabled: %s",
                        fused_disable_reason,
                    )
                self._use_fused_kv_materialize = False
                self._fused_kv_helper = None
                return

            # NOTE(review): assumes all layers share the first layer's RoPE and
            # KV head geometry — TODO confirm against the draft model family.
            FusedKVMaterializeHelper = _get_fused_kv_materialize_helper()
            first_attn = layers[0].self_attn
            rotary_emb = first_attn.rotary_emb

            self._fused_kv_helper = FusedKVMaterializeHelper(
                layers=layers,
                rotary_emb=rotary_emb,
                num_kv_heads=first_attn.num_kv_heads,
                head_dim=first_attn.head_dim,
                device=self.device,
            )
            if self.tp_rank == 0:
                logger.info(
                    "DFLASH fused KV materialization enabled. "
                    "n_layers=%d, num_kv_heads=%d, head_dim=%d",
                    len(layers),
                    first_attn.num_kv_heads,
                    first_attn.head_dim,
                )
        except Exception as e:
            # Best-effort: any failure here just falls back to the sequential path.
            logger.warning(
                "DFLASH fused KV initialization failed, falling back to sequential path: %s",
                e,
            )
            self._use_fused_kv_materialize = False
            self._fused_kv_helper = None

    def _ensure_draft_block_buffers(self, bs: int) -> None:
        """Grow the per-batch draft block staging buffers to hold at least `bs` rows."""
        cap = (
            0
            if self._draft_block_ids_buf is None
            else int(self._draft_block_ids_buf.shape[0])
        )
        if cap >= int(bs):
            return

        # Double the capacity (or jump straight to bs) to amortize reallocations.
        new_cap = max(int(bs), cap * 2 if cap > 0 else int(bs))
        device = self.device
        block_size = int(self.block_size)
        self._draft_block_ids_buf = torch.empty(
            (new_cap, block_size), dtype=torch.long, device=device
        )
        self._draft_block_positions_buf = torch.empty(
            (new_cap, block_size), dtype=torch.int64, device=device
        )
        self._draft_block_tokens_buf = torch.empty(
            (new_cap, block_size), dtype=torch.long, device=device
        )
        self._draft_block_end_buf = torch.empty(
            (new_cap,), dtype=torch.int32, device=device
        )
        self._draft_seq_lens_cpu_buf = torch.empty(
            (new_cap,), dtype=torch.int32, device="cpu"
        )

    def __getattr__(self, name):
        # Delegate anything not implemented yet to the target worker.
        return getattr(self.target_worker, name)

    def clear_cache_pool(self):
        # The target worker owns the shared KV allocator/cache. For the compact
        # sliding-window path, the draft req->token view is rebuilt from committed
        # target state before each draft forward, so there is nothing persistent
        # to flush here.
        pass

    def _gather_req_to_token_masked(
        self,
        *,
        req_to_token: torch.Tensor,
        req_pool_indices: torch.Tensor,
        pos2d: torch.Tensor,
        mask: torch.Tensor,
        context: str,
    ) -> torch.Tensor:
        """Gather KV cache locations for masked (row, position) pairs.

        `pos2d` is a [bs, width] table of per-request positions; `mask` selects
        the valid entries. Returns a flat int64 tensor of cache locations in
        row-major mask order. `context` is only used in error messages.
        """
        if pos2d.ndim != 2:
            raise RuntimeError(
                f"{context} expected 2D positions, got shape={tuple(pos2d.shape)}."
            )
        if mask.shape != pos2d.shape:
            raise RuntimeError(
                f"{context} mask/position shape mismatch: {tuple(mask.shape)} vs {tuple(pos2d.shape)}."
            )

        if req_pool_indices.dtype != torch.int64:
            req_pool_indices = req_pool_indices.to(torch.int64)
        if mask.dtype != torch.bool:
            mask = mask.to(torch.bool)

        table_width = int(req_to_token.shape[1])
        if table_width <= 0:
            if bool(mask.any().item()):
                raise RuntimeError(
                    f"{context} req_to_token table is empty but gather mask is non-empty."
                )
            return torch.empty((0,), dtype=torch.int64, device=self.device)

        # Only the masked-off rectangular padding can be out of range in the normal
        # ragged-batch case. Replace those don't-care columns with a valid in-range
        # position before the gather so the kernel only sees real positions.
        safe_pos2d = pos2d.masked_fill(~mask, 0)
        return req_to_token[req_pool_indices[:, None], safe_pos2d][mask].to(torch.int64)

    def _gather_req_to_token_segments(
        self,
        *,
        req_to_token: torch.Tensor,
        req_pool_indices: torch.Tensor,
        start: torch.Tensor | None,
        lengths: torch.Tensor,
    ) -> torch.Tensor:
        """Gather per-request contiguous [start, start+length) cache-location segments.

        `start=None` means every segment begins at position 0. Returns a flat
        int64 tensor of cache locations concatenated in request order.
        """
        lengths = lengths.to(torch.int64)
        if lengths.numel() == 0:
            return torch.empty((0,), dtype=torch.int64, device=self.device)
        max_len = int(lengths.max().item())
        if max_len <= 0:
            return torch.empty((0,), dtype=torch.int64, device=self.device)

        if req_pool_indices.dtype != torch.int64:
            req_pool_indices = req_pool_indices.to(torch.int64)
        offsets = torch.arange(
            max_len, device=self.device, dtype=torch.int64
        ).unsqueeze(0)
        if start is None:
            pos2d = offsets.expand(req_pool_indices.shape[0], -1)
        else:
            pos2d = start.to(torch.int64).unsqueeze(1) + offsets
        mask = offsets < lengths.unsqueeze(1)
        return self._gather_req_to_token_masked(
            req_to_token=req_to_token,
            req_pool_indices=req_pool_indices,
            pos2d=pos2d,
            mask=mask,
            context="DFLASH req_to_token segment gather",
        )

    def _compute_compact_draft_seq_lens(self, seq_lens: torch.Tensor) -> torch.Tensor:
        """Clamp target seq lens to the draft window, page-aligned when paged.

        Returns the logical resident length (int32) of each request in the
        compact draft cache.
        """
        assert self.draft_window_size is not None
        visible_lens = torch.clamp(
            seq_lens.to(dtype=torch.int32, device=self.device),
            max=int(self.draft_window_size),
        )
        if self.page_size <= 1:
            return visible_lens

        # Paged FA backends derive the page table from local token positions, so the
        # compact suffix must start on a page boundary. Keep up to page_size - 1 extra
        # tokens on the left to preserve valid local page structure.
        seq_lens_i64 = seq_lens.to(torch.int64)
        visible_lens_i64 = visible_lens.to(torch.int64)
        visible_start = seq_lens_i64 - visible_lens_i64
        aligned_start = visible_start - torch.remainder(visible_start, self.page_size)
        return (seq_lens_i64 - aligned_start).to(torch.int32)

    def _resolve_mask_token_id(
        self, *, mask_token: str, mask_token_id: Optional[int] = None
    ) -> int:
        """Resolve the DFLASH mask token to a target-vocab token id.

        Prefers an explicit `mask_token_id` override (validated against the
        tokenizer vocab when available); otherwise resolves via the tokenizer,
        adding the mask token as a special token as a last resort. Raises if
        the id cannot be resolved or falls outside the target vocab.
        """
        if not isinstance(mask_token, str) or not mask_token:
            raise ValueError(
                f"DFLASH mask_token must be a non-empty string, got {mask_token!r}."
            )

        vocab_size = int(self.target_worker.model_runner.model_config.vocab_size)
        if mask_token_id is not None:
            resolved_id = int(mask_token_id)
            if resolved_id >= vocab_size:
                raise ValueError(
                    "DFLASH mask_token_id is outside the target vocab size. "
                    f"mask_token_id={resolved_id}, vocab_size={vocab_size}. "
                    f"This likely means mask_token={mask_token!r} requires vocab expansion beyond the model's embedding size. "
                    "SGLang does not support resizing target embeddings for DFLASH yet."
                )

            tokenizer = getattr(self.target_worker, "tokenizer", None)
            if tokenizer is not None:
                token_id_from_vocab = tokenizer.get_vocab().get(mask_token, None)
                if (
                    token_id_from_vocab is not None
                    and int(token_id_from_vocab) != resolved_id
                ):
                    raise ValueError(
                        "DFLASH config mismatch: dflash_config.mask_token_id conflicts with tokenizer vocab id "
                        f"for dflash_config.mask_token. mask_token={mask_token!r}, "
                        f"mask_token_id={resolved_id}, tokenizer_vocab_id={int(token_id_from_vocab)}."
                    )
            return resolved_id

        tokenizer = getattr(self.target_worker, "tokenizer", None)
        if tokenizer is None:
            raise RuntimeError(
                "DFLASH requires tokenizer initialization when dflash_config.mask_token_id is not set "
                "(skip_tokenizer_init is not supported in this mode)."
            )

        resolved_id = None
        if getattr(tokenizer, "mask_token", None) == mask_token:
            resolved_id = getattr(tokenizer, "mask_token_id", None)

        if resolved_id is None:
            # Prefer checking the explicit vocab mapping first.
            vocab = tokenizer.get_vocab()
            resolved_id = vocab.get(mask_token, None)

        if resolved_id is None:
            # Mirror the reference DFlash HF demo by adding the mask token to the tokenizer.
            # This is safe only when the resulting id stays within the target model vocab size.
            added = tokenizer.add_special_tokens({"mask_token": mask_token})
            resolved_id = getattr(tokenizer, "mask_token_id", None)
            if resolved_id is None:
                resolved_id = tokenizer.convert_tokens_to_ids(mask_token)

            if added and self.tp_rank == 0:
                logger.info(
                    "Added DFLASH mask token to tokenizer. token=%s, mask_token_id=%s, tokenizer_len=%s, model_vocab_size=%s",
                    mask_token,
                    resolved_id,
                    len(tokenizer),
                    vocab_size,
                )

        if resolved_id is None or int(resolved_id) < 0:
            raise ValueError(
                "DFLASH requires resolving a mask token id, but it could not be resolved. "
                f"mask_token={mask_token!r}."
            )

        if resolved_id >= vocab_size:
            raise ValueError(
                "DFLASH mask_token_id is outside the target vocab size. "
                f"mask_token_id={resolved_id}, vocab_size={vocab_size}. "
                f"This likely means mask_token={mask_token!r} requires vocab expansion beyond the model's embedding size. "
                "SGLang does not support resizing target embeddings for DFLASH yet."
            )

        return int(resolved_id)

    def _prepare_for_speculative_decoding(
        self, batch: ScheduleBatch, draft_input: DFlashDraftInput
    ):
        """Run one draft step and convert `batch` into a TARGET_VERIFY batch.

        Appends newly committed target hidden states into the draft KV cache,
        runs the draft model over a mask-filled block, and attaches the
        resulting `DFlashVerifyInput` to the batch. Extend/idle batches are
        left untouched.
        """
        if batch.forward_mode.is_extend() or batch.forward_mode.is_idle():
            return

        if batch.has_grammar:
            raise RuntimeError(
                "Invariant broken: DFLASH batch has grammar constraints, but scheduler should have rejected this request."
            )
        if batch.sampling_info is not None and not batch.sampling_info.is_all_greedy:
            # Warn once (rank 0 only) when non-greedy verification is unavailable.
            if (
                not is_dflash_sampling_verify_available()
                and not self._warned_sampling_fallback
                and self.tp_rank == 0
            ):
                logger.warning(
                    "DFLASH non-greedy verification is unavailable on this build/device; "
                    "falling back to greedy argmax verification."
                )
                self._warned_sampling_fallback = True

        bs = batch.batch_size()

        # --- 1) Append any newly committed tokens into the draft KV cache.
        self._append_target_hidden_to_draft_kv(batch, draft_input)

        target_model = self.target_worker.model_runner.model
        embed_module = target_model.get_input_embeddings()
        lm_head = getattr(target_model, "lm_head", None)
        if (
            lm_head is None
            or not hasattr(lm_head, "weight")
            or not hasattr(lm_head, "shard_indices")
        ):
            raise RuntimeError(
                "DFLASH requires the target model to expose a vocab-parallel `lm_head` with `weight` and "
                "`shard_indices` attributes."
            )

        # --- 2) Draft a non-causal block with the draft model.
        self._ensure_draft_block_buffers(bs)
        assert self._draft_block_ids_buf is not None
        assert self._draft_block_positions_buf is not None
        assert self._draft_block_tokens_buf is not None
        assert self._draft_block_end_buf is not None
        assert self._draft_seq_lens_cpu_buf is not None

        # Block input: last verified token followed by block_size - 1 mask tokens.
        block_ids = self._draft_block_ids_buf[:bs]
        block_ids.fill_(int(self._mask_token_id))
        block_ids[:, 0].copy_(draft_input.verified_id.to(torch.long))

        noise_embedding = embed_module(block_ids)
        input_embeds = noise_embedding.view(-1, noise_embedding.shape[-1])

        # For spec-v1, the draft KV cache is always materialized before drafting the
        # next block. `target_prefix_lens` stay absolute for RoPE; `draft_prefix_lens`
        # are the logical resident lengths in the draft-local cache.
        target_prefix_lens = batch.seq_lens  # int32, device
        draft_prefix_lens = draft_input.draft_seq_lens
        if draft_prefix_lens.dtype != torch.int32:
            draft_prefix_lens = draft_prefix_lens.to(torch.int32)
        if draft_prefix_lens.device != self.device:
            draft_prefix_lens = draft_prefix_lens.to(self.device, non_blocking=True)

        positions_2d = self._draft_block_positions_buf[:bs]
        torch.add(
            target_prefix_lens.unsqueeze(1), self._block_pos_offsets, out=positions_2d
        )
        positions = positions_2d.reshape(-1)

        block_start = draft_prefix_lens
        block_end = self._draft_block_end_buf[:bs]
        torch.add(block_start, int(self.block_size), out=block_end)

        seq_lens_cpu = self._draft_seq_lens_cpu_buf[:bs]
        seq_lens_cpu.copy_(draft_prefix_lens.to(device="cpu", dtype=torch.int32))
        # The draft block slots are transient: back up the allocator state so the
        # finally-block below can discard the speculative allocation.
        allocator = self.draft_model_runner.token_to_kv_pool_allocator
        token_to_kv_pool_state_backup = allocator.backup_state()
        try:
            if self.page_size == 1:
                block_cache_loc = allocator.alloc(bs * self.block_size)
            else:
                block_end_cpu = seq_lens_cpu + int(self.block_size)
                last_loc = get_last_loc(
                    self.draft_model_runner.req_to_token_pool.req_to_token,
                    batch.req_pool_indices,
                    block_start,
                )
                block_cache_loc = allocator.alloc_extend(
                    block_start,
                    seq_lens_cpu,
                    block_end,
                    block_end_cpu,
                    last_loc,
                    bs * self.block_size,
                )
            if block_cache_loc is None:
                raise RuntimeError(
                    f"DFLASH draft OOM when allocating {bs * self.block_size} block tokens."
                )

            assign_req_to_token_pool_func(
                batch.req_pool_indices,
                self.draft_model_runner.req_to_token_pool.req_to_token,
                block_start,
                block_end,
                block_cache_loc,
                bs,
            )

            # Use TARGET_VERIFY mode (cuda-graphable) to run a fixed-size draft block.
            # In this mode, `seq_lens` stores the prefix lengths; attention backends
            # derive kv_len by adding `draft_token_num`.
            draft_spec_info = self._draft_block_spec_info
            seq_lens = draft_prefix_lens
            seq_lens_sum = int(draft_prefix_lens.sum().item())
            forward_batch = ForwardBatch(
                forward_mode=ForwardMode.TARGET_VERIFY,
                batch_size=bs,
                input_ids=block_ids.flatten(),
                req_pool_indices=batch.req_pool_indices,
                seq_lens=seq_lens,
                out_cache_loc=block_cache_loc,
                seq_lens_sum=seq_lens_sum,
                seq_lens_cpu=seq_lens_cpu,
                positions=positions,
                req_to_token_pool=self.draft_model_runner.req_to_token_pool,
                token_to_kv_pool=self.draft_model_runner.token_to_kv_pool,
                attn_backend=self.draft_model_runner.attn_backend,
                input_embeds=input_embeds,
                spec_algorithm=SpeculativeAlgorithm.DFLASH,
                spec_info=draft_spec_info,
                capture_hidden_mode=CaptureHiddenMode.NULL,
            )

            with torch.inference_mode():
                draft_logits_output = self.draft_model_runner.forward(
                    forward_batch
                ).logits_output
        finally:
            # Drop the speculative block from the shared allocator (EAGLE3-style).
            allocator.restore_state(token_to_kv_pool_state_backup)

        draft_hidden = draft_logits_output.hidden_states
        if draft_hidden is None:
            raise RuntimeError("DFLASH draft model returned no hidden states.")
        draft_hidden = draft_hidden.view(bs, self.block_size, -1)
        # Greedy-decode positions 1..block_size-1; position 0 stays the verified token.
        draft_next = self._greedy_sample_from_vocab_parallel_head(
            hidden_states=draft_hidden[:, 1:, :].reshape(-1, draft_hidden.shape[-1]),
            lm_head=lm_head,
        ).view(bs, self.block_size - 1)
        draft_tokens = self._draft_block_tokens_buf[:bs]
        draft_tokens[:, 0].copy_(block_ids[:, 0])
        draft_tokens[:, 1:].copy_(draft_next)
        positions = positions_2d.reshape(-1)

        verify_input = DFlashVerifyInput(
            draft_token=draft_tokens.reshape(-1),
            positions=positions,
            draft_token_num=self.block_size,
        )
        _, build_custom_mask = resolve_dflash_verify_mask_policy(
            self.model_runner.attn_backend
        )
        verify_input.prepare_for_verify(
            batch,
            self.page_size,
            build_custom_mask=build_custom_mask,
        )

        batch.forward_mode = (
            ForwardMode.TARGET_VERIFY
            if not batch.forward_mode.is_idle()
            else ForwardMode.IDLE
        )
        batch.spec_info = verify_input
        batch.return_hidden_states = False

    def _greedy_sample_from_vocab_parallel_head(
        self,
        *,
        hidden_states: torch.Tensor,
        lm_head,
        chunk_size: int = 256,
    ) -> torch.Tensor:
        """Greedy argmax over the target LM head in a TP-safe way.

        We cannot materialize full logits for large vocabularies efficiently, and with
        TP>1 each rank only owns a shard of the LM head weight. This computes the
        per-rank max, gathers candidates across TP ranks, and selects the global max.
        """

        if hidden_states.numel() == 0:
            return torch.empty((0,), dtype=torch.long, device=hidden_states.device)

        tp_group = get_tp_group()
        tp_size = int(tp_group.world_size)

        if not hasattr(lm_head, "weight") or not hasattr(lm_head, "shard_indices"):
            raise RuntimeError(
                "DFLASH greedy sampling requires a vocab-parallel head with `weight` and `shard_indices`."
            )

        shard = lm_head.shard_indices
        weight = lm_head.weight  # [local_vocab_padded, hidden]
        weight_dtype = weight.dtype

        # Valid ranges in the local shard (excluding padding):
        # base vocab: [0, num_org)
        # added vocab: [num_org_padded, num_org_padded + num_added)
        num_org = int(shard.num_org_elements)
        num_org_padded = int(shard.num_org_elements_padded)
        num_added = int(shard.num_added_elements)
        org_vocab_start = int(shard.org_vocab_start_index)
        added_vocab_start = int(shard.added_vocab_start_index)

        num_tokens = int(hidden_states.shape[0])
        out_token_ids = torch.empty(
            (num_tokens,), dtype=torch.long, device=hidden_states.device
        )

        def _cast_hs(x: torch.Tensor) -> torch.Tensor:
            # Match the LM head weight dtype before the matmul.
            return x if x.dtype == weight_dtype else x.to(weight_dtype)

        # Fast path (common): single-rank greedy sampling over the base vocab shard.
        # Avoids extra max/id bookkeeping that is only needed for TP sync or added vocab.
        if tp_size == 1 and num_added == 0:
            for start in range(0, num_tokens, int(chunk_size)):
                end = min(num_tokens, start + int(chunk_size))
                hs = _cast_hs(hidden_states[start:end])
                if num_org > 0:
                    base_logits = torch.matmul(hs, weight[:num_org].T)
                    out_token_ids[start:end] = (
                        torch.argmax(base_logits, dim=-1).to(torch.long)
                        + org_vocab_start
                    )
                else:
                    out_token_ids[start:end] = 0
            return out_token_ids

        for start in range(0, num_tokens, int(chunk_size)):
            end = min(num_tokens, start + int(chunk_size))
            hs = _cast_hs(hidden_states[start:end])
            chunk_len = int(hs.shape[0])

            # Base vocab logits.
            if num_org > 0:
                base_logits = torch.matmul(hs, weight[:num_org].T)
                local_max, local_arg = torch.max(base_logits, dim=-1)
            else:
                # Empty base shard: sentinel minimum so any added-vocab max wins.
                local_max = torch.full(
                    (chunk_len,),
                    torch.finfo(weight_dtype).min,
                    dtype=weight_dtype,
                    device=hs.device,
                )
                local_arg = torch.zeros(
                    (chunk_len,), dtype=torch.int64, device=hs.device
                )

            # Added vocab logits (e.g., LoRA-added embeddings), if present.
            if num_added > 0:
                added_slice_start = num_org_padded
                added_slice_end = num_org_padded + num_added
                added_logits = torch.matmul(
                    hs, weight[added_slice_start:added_slice_end].T
                )
                added_max, added_arg = torch.max(added_logits, dim=-1)
                use_added = added_max > local_max
                local_max = torch.where(use_added, added_max, local_max)
                # For base/added conversion below, keep local_arg expressed in the full local
                # weight index space (base + padding + added), matching `lm_head.weight`.
                local_arg = torch.where(
                    use_added, added_arg.to(local_arg.dtype) + num_org_padded, local_arg
                )

            # Convert local argmax indices to global token ids.
            if num_added == 0:
                local_arg.add_(org_vocab_start)
                global_ids = local_arg
            else:
                global_ids = torch.empty(
                    (chunk_len,), dtype=torch.int64, device=hs.device
                )
                is_base = local_arg < num_org
                global_ids[is_base] = org_vocab_start + local_arg[is_base]
                global_ids[~is_base] = added_vocab_start + (
                    local_arg[~is_base] - num_org_padded
                )

            if tp_size == 1:
                out_token_ids[start:end] = global_ids.to(torch.long)
                continue

            # Gather per-rank maxima and associated global ids, then select the global max.
            needed = tp_size * chunk_len
            chunk_cap = int(chunk_size)
            if (
                self._draft_greedy_gather_cap < needed
                or self._draft_greedy_gathered_max_buf is None
                or self._draft_greedy_gathered_ids_buf is None
                or self._draft_greedy_gathered_max_buf.dtype != local_max.dtype
                or self._draft_greedy_gathered_max_buf.device != hs.device
            ):
                # Allocate enough space for the max chunk size to avoid reallocations.
                cap = tp_size * chunk_cap
                self._draft_greedy_gathered_max_buf = torch.empty(
                    (cap,), dtype=local_max.dtype, device=hs.device
                )
                self._draft_greedy_gathered_ids_buf = torch.empty(
                    (cap,), dtype=global_ids.dtype, device=hs.device
                )
                self._draft_greedy_gather_cap = cap

            if (
                self._draft_greedy_index_cap < chunk_len
                or self._draft_greedy_best_rank_buf is None
                or self._draft_greedy_rank_index_buf is None
                or self._draft_greedy_selected_ids_buf is None
                or self._draft_greedy_best_rank_buf.device != hs.device
                or self._draft_greedy_selected_ids_buf.device != hs.device
            ):
                self._draft_greedy_best_rank_buf = torch.empty(
                    (chunk_cap,), dtype=torch.int64, device=hs.device
                )
                self._draft_greedy_rank_index_buf = torch.empty(
                    (1, chunk_cap), dtype=torch.int64, device=hs.device
                )
                self._draft_greedy_selected_ids_buf = torch.empty(
                    (1, chunk_cap), dtype=torch.int64, device=hs.device
                )
                self._draft_greedy_index_cap = chunk_cap

            gathered_max = self._draft_greedy_gathered_max_buf[:needed]
            gathered_ids = self._draft_greedy_gathered_ids_buf[:needed]

            # Both collectives must run on every rank with identical shapes.
            tp_group.all_gather_into_tensor(gathered_max, local_max.contiguous())
            tp_group.all_gather_into_tensor(gathered_ids, global_ids.contiguous())
            gathered_max = gathered_max.view(tp_size, chunk_len)
            gathered_ids = gathered_ids.view(tp_size, chunk_len)

            best_rank = self._draft_greedy_best_rank_buf[:chunk_len]
            torch.argmax(gathered_max, dim=0, out=best_rank)

            rank_index = self._draft_greedy_rank_index_buf[:, :chunk_len]
            rank_index[0].copy_(best_rank)
            selected_ids = self._draft_greedy_selected_ids_buf[:, :chunk_len]
            torch.gather(gathered_ids, 0, rank_index, out=selected_ids)
            out_token_ids[start:end].copy_(selected_ids.view(-1))

        return out_token_ids

    def _append_target_hidden_to_draft_kv(
        self,
        batch: ScheduleBatch,
        draft_input: DFlashDraftInput,
    ) -> None:
        """Materialize the target hidden-state features into the draft KV cache.
+ + This must be run before exposing new tokens to radix cache (prefix hits), otherwise + another request could reuse target KV indices without having draft KV values. + """ + + bs = batch.batch_size() + device = self.model_runner.device + + if draft_input.target_hidden is None: + raise RuntimeError( + "DFLASH draft state missing target_hidden context features." + ) + if draft_input.ctx_lens.numel() != bs: + raise RuntimeError( + f"DFLASH ctx_lens length mismatch: got {draft_input.ctx_lens.numel()} for bs={bs}." + ) + if draft_input.draft_seq_lens.numel() != bs: + raise RuntimeError( + f"DFLASH draft_seq_lens length mismatch: got {draft_input.draft_seq_lens.numel()} for bs={bs}." + ) + + total_ctx = int(draft_input.target_hidden.shape[0]) + if total_ctx <= 0: + draft_input.ctx_lens = torch.zeros_like(draft_input.ctx_lens) + draft_input.target_hidden = draft_input.target_hidden[:0] + return + + target_req_to_token = batch.req_to_token_pool.req_to_token + draft_req_to_token = self.draft_model_runner.req_to_token_pool.req_to_token + + req_pool_indices = batch.req_pool_indices + if req_pool_indices.dtype != torch.int64: + req_pool_indices = req_pool_indices.to(torch.int64) + + ctx_lens = draft_input.ctx_lens + if ctx_lens.dtype != torch.int32: + ctx_lens = ctx_lens.to(torch.int32) + if ctx_lens.device != device: + ctx_lens = ctx_lens.to(device, non_blocking=True) + ctx_start = batch.seq_lens.to(torch.int64) - ctx_lens.to(torch.int64) + + if bs == 1: + # Fast path for single request. 
+ max_ctx = int(total_ctx) + if max_ctx <= self._block_pos_offsets.numel(): + r = self._block_pos_offsets[:max_ctx] + else: + r = torch.arange(max_ctx, device=device, dtype=torch.int64) + pos2d = ctx_start[:, None] + r[None, :] # [1, ctx] + cache2d = target_req_to_token[req_pool_indices[:, None], pos2d] # [1, ctx] + ctx_cache_loc = cache2d.reshape(-1).to(torch.int64) # [ctx] + ctx_positions = pos2d.reshape(-1) # [ctx] + else: + # In decode mode, ctx_lens <= block_size so we can skip the .item() sync. + if batch.forward_mode.is_extend() or batch.is_extend_in_batch: + max_ctx = int(ctx_lens.max().item()) + else: + max_ctx = int(self.block_size) + if max_ctx <= 0: + raise RuntimeError(f"DFLASH invalid max_ctx={max_ctx} for KV append.") + + if max_ctx <= self._block_pos_offsets.numel(): + r = self._block_pos_offsets[:max_ctx] + else: + r = torch.arange(max_ctx, device=device, dtype=torch.int64) + r = r[None, :] # [1, max_ctx] + pos2d = ctx_start[:, None] + r # [bs, max_ctx] + mask = r < ctx_lens[:, None] + + # Batched gather of cache locations and positions. + ctx_cache_loc = self._gather_req_to_token_masked( + req_to_token=target_req_to_token, + req_pool_indices=req_pool_indices, + pos2d=pos2d, + mask=mask, + context="DFLASH target hidden KV append", + ) # [sum(ctx_lens)] + ctx_positions = pos2d[mask] # [sum(ctx_lens)] + + with torch.inference_mode(): + ctx_hidden = self.draft_model.project_target_hidden( + draft_input.target_hidden + ) # [sum(ctx), hidden] + if ctx_hidden.shape[0] != ctx_cache_loc.numel(): + raise RuntimeError( + f"DFLASH ctx_hidden/cache_loc mismatch: {ctx_hidden.shape[0]} vs {ctx_cache_loc.numel()}." 
+            )
+
+        # Write the projected context features into the draft KV cache. A single
+        # failure of the fused Triton path permanently disables it for this
+        # process and falls back to the per-layer sequential path.
+        if self._use_fused_kv_materialize and self._fused_kv_helper is not None:
+            try:
+                self._append_target_hidden_fused(
+                    ctx_hidden, ctx_positions, ctx_cache_loc
+                )
+            except Exception as e:
+                logger.warning(
+                    "DFLASH fused KV append failed; falling back to sequential path: %s",
+                    e,
+                )
+                self._use_fused_kv_materialize = False
+                self._fused_kv_helper = None
+                self._append_target_hidden_sequential(
+                    ctx_hidden, ctx_positions, ctx_cache_loc
+                )
+        else:
+            self._append_target_hidden_sequential(
+                ctx_hidden, ctx_positions, ctx_cache_loc
+            )
+
+        # Refresh the draft-side req_to_token view. Compact mode re-anchors a
+        # suffix window of the committed target sequence (see
+        # _compute_compact_draft_seq_lens); otherwise the draft view simply
+        # mirrors the full target seq_lens.
+        if self.use_compact_draft_cache:
+            new_draft_seq_lens = self._compute_compact_draft_seq_lens(batch.seq_lens)
+            suffix_start = batch.seq_lens.to(torch.int64) - new_draft_seq_lens.to(
+                torch.int64
+            )
+            suffix_cache_loc = self._gather_req_to_token_segments(
+                req_to_token=target_req_to_token,
+                req_pool_indices=req_pool_indices,
+                start=suffix_start,
+                lengths=new_draft_seq_lens,
+            )
+            assign_req_to_token_pool_func(
+                batch.req_pool_indices,
+                draft_req_to_token,
+                torch.zeros_like(new_draft_seq_lens),
+                new_draft_seq_lens,
+                suffix_cache_loc,
+                bs,
+            )
+            draft_input.draft_seq_lens = new_draft_seq_lens
+        else:
+            draft_input.draft_seq_lens = batch.seq_lens.to(dtype=torch.int32)
+        # Context features are consumed; reset per-iteration draft state.
+        draft_input.ctx_lens = torch.zeros_like(ctx_lens)
+        draft_input.target_hidden = draft_input.target_hidden[:0]
+
+    def _append_target_hidden_to_draft_kv_by_loc(
+        self,
+        *,
+        target_hidden: torch.Tensor,
+        cache_loc: torch.Tensor,
+        positions: torch.Tensor,
+        mask_valid: Optional[torch.Tensor] = None,
+    ) -> None:
+        """Materialize target context features into the draft KV cache at explicit slots.
+
+        This helper avoids boolean-index packing for the spec-v2 overlap path, which
+        already computes explicit cache locations from over-allocated req_to_token state.
+        """
+        if target_hidden is None:
+            raise RuntimeError("DFLASH missing target hidden context features.")
+        if target_hidden.numel() == 0:
+            return
+        if target_hidden.ndim != 2:
+            raise ValueError(
+                "DFLASH target_hidden must be 2D, "
+                f"got shape={tuple(target_hidden.shape)}."
+            )
+
+        if cache_loc.ndim != 1:
+            raise ValueError(
+                f"DFLASH cache_loc must be 1D, got shape={tuple(cache_loc.shape)}."
+            )
+        if positions.ndim != 1:
+            raise ValueError(
+                f"DFLASH positions must be 1D, got shape={tuple(positions.shape)}."
+            )
+        num_tokens = int(target_hidden.shape[0])
+        if int(cache_loc.numel()) != num_tokens:
+            raise ValueError(
+                "DFLASH cache_loc length mismatch: "
+                f"cache_loc={int(cache_loc.numel())}, target_hidden={num_tokens}."
+            )
+        if int(positions.numel()) != num_tokens:
+            raise ValueError(
+                "DFLASH positions length mismatch: "
+                f"positions={int(positions.numel())}, target_hidden={num_tokens}."
+            )
+
+        # Normalize device and index dtypes up front so the per-layer RoPE /
+        # pool writes below operate on int64 tensors on the model device.
+        device = self.model_runner.device
+        if cache_loc.device != device:
+            cache_loc = cache_loc.to(device, non_blocking=True)
+        if positions.device != device:
+            positions = positions.to(device, non_blocking=True)
+        if target_hidden.device != device:
+            target_hidden = target_hidden.to(device, non_blocking=True)
+
+        if cache_loc.dtype != torch.int64:
+            cache_loc = cache_loc.to(torch.int64)
+        if positions.dtype != torch.int64:
+            positions = positions.to(torch.int64)
+
+        # NOTE(review): masked (padding) tokens are still written -- as zeroed
+        # K/V -- to their cache_loc below; callers appear to redirect masked
+        # slots to a sacrificial cache index. Confirm against the spec-v2 caller.
+        mask_3d: Optional[torch.Tensor] = None
+        if mask_valid is not None:
+            if mask_valid.ndim != 1:
+                raise ValueError(
+                    "DFLASH mask_valid must be 1D, "
+                    f"got shape={tuple(mask_valid.shape)}."
+                )
+            if int(mask_valid.numel()) != num_tokens:
+                raise ValueError(
+                    "DFLASH mask_valid length mismatch: "
+                    f"mask_valid={int(mask_valid.numel())}, target_hidden={num_tokens}."
+                )
+            if mask_valid.device != device:
+                mask_valid = mask_valid.to(device, non_blocking=True)
+            if mask_valid.dtype != torch.bool:
+                mask_valid = mask_valid.to(torch.bool)
+            mask_3d = mask_valid.view(-1, 1, 1)
+
+        with torch.inference_mode():
+            ctx_hidden = self.draft_model.project_target_hidden(target_hidden)
+
+        if mask_3d is None:
+            if self._use_fused_kv_materialize and self._fused_kv_helper is not None:
+                try:
+                    self._append_target_hidden_fused(
+                        ctx_hidden=ctx_hidden,
+                        ctx_positions=positions,
+                        ctx_cache_loc=cache_loc,
+                    )
+                    return
+                except Exception as e:
+                    logger.warning(
+                        "DFLASH fused KV append-by-loc failed; falling back to sequential path: %s",
+                        e,
+                    )
+                    self._use_fused_kv_materialize = False
+                    self._fused_kv_helper = None
+
+            self._append_target_hidden_sequential(
+                ctx_hidden=ctx_hidden,
+                ctx_positions=positions,
+                ctx_cache_loc=cache_loc,
+            )
+            return
+
+        # Masked path: the fused helper has no mask support visible here, so
+        # project / K-norm / K-RoPE per layer and zero out padded K/V rows.
+        for layer in self.draft_model.layers:
+            attn = layer.self_attn
+            k, v = attn.kv_proj_only(ctx_hidden)
+            k = attn.apply_k_norm(k)
+            k = attn.apply_k_rope(positions, k)
+            k = k.view(-1, attn.num_kv_heads, attn.head_dim)
+            v = v.view(-1, attn.num_kv_heads, attn.head_dim)
+
+            if mask_3d is not None:
+                k = k.masked_fill(~mask_3d, 0)
+                v = v.masked_fill(~mask_3d, 0)
+
+            self.draft_model_runner.token_to_kv_pool.set_kv_buffer(
+                attn.attn,
+                cache_loc,
+                k,
+                v,
+                attn.attn.k_scale,
+                attn.attn.v_scale,
+            )
+
+    def _append_target_hidden_sequential(
+        self,
+        ctx_hidden: torch.Tensor,
+        ctx_positions: torch.Tensor,
+        ctx_cache_loc: torch.Tensor,
+    ) -> None:
+        """Per-layer fallback: K/V projection, K-norm, K-RoPE, then one KV-pool write per draft layer."""
+        for layer in self.draft_model.layers:
+            attn = layer.self_attn
+            k, v = attn.kv_proj_only(ctx_hidden)
+            k = attn.apply_k_norm(k)
+            k = attn.apply_k_rope(ctx_positions, k)
+            k = k.view(-1, attn.num_kv_heads, attn.head_dim)
+            v = v.view(-1, attn.num_kv_heads, attn.head_dim)
+            self.draft_model_runner.token_to_kv_pool.set_kv_buffer(
+                attn.attn,
+                ctx_cache_loc,
+                k,
+                v,
+                attn.attn.k_scale,
+                attn.attn.v_scale,
+            )
+
+    def _append_target_hidden_fused(
+        self,
+        ctx_hidden: torch.Tensor,
+        ctx_positions: torch.Tensor,
+        ctx_cache_loc: torch.Tensor,
+    ) -> None:
+        """Fused KV materialization using batched projection + Triton kernel."""
+        token_to_kv_pool = self.draft_model_runner.token_to_kv_pool
+        layers = self.draft_model.layers
+
+        # Invoked by the helper with materialized K/V for one draft layer at a
+        # time -- presumably once per layer; see FusedKVMaterializeHelper.materialize.
+        def _write_layer_kv(
+            layer_idx: int, cache_k: torch.Tensor, cache_v: torch.Tensor
+        ) -> None:
+            attn = layers[layer_idx].self_attn.attn
+            token_to_kv_pool.set_kv_buffer(
+                attn,
+                ctx_cache_loc,
+                cache_k,
+                cache_v,
+                attn.k_scale,
+                attn.v_scale,
+            )
+
+        self._fused_kv_helper.materialize(
+            ctx_hidden=ctx_hidden,
+            positions=ctx_positions,
+            write_layer_kv=_write_layer_kv,
+        )
+
+    def _update_target_mamba_state_after_verify(
+        self,
+        *,
+        batch: ScheduleBatch,
+        seq_lens_pre_verify: torch.Tensor,
+        commit_lens: torch.Tensor,
+    ) -> None:
+        """Commit Mamba intermediate states for accepted verify steps.
+
+        During TARGET_VERIFY, Mamba kernels run with `disable_state_update=True` and
+        cache per-step intermediate states. After acceptance, we need to commit the
+        state corresponding to each request's last accepted step.
+        """
+        attn_backend = self.target_worker.model_runner.attn_backend
+        # Attention backends without Mamba MTP-verify support need no commit.
+        if not hasattr(attn_backend, "update_mamba_state_after_mtp_verify"):
+            return
+
+        # Per-request index of the last accepted verify step (commit_lens appears
+        # to count the bonus token as well, hence the -1 -- confirm with verify()).
+        accepted_steps = commit_lens.to(torch.int64) - 1
+        mamba_steps_to_track = None
+
+        if batch.mamba_track_indices is not None:
+            # Pick, per request, the in-flight verify step whose position crosses
+            # the next tracking boundary (a multiple of mamba_track_interval).
+            # Requests with no crossing, or whose crossing step was not accepted,
+            # get -1 ("nothing to track").
+            mamba_track_interval = self.server_args.mamba_track_interval
+            to_track_mask = (
+                seq_lens_pre_verify // mamba_track_interval
+                != batch.seq_lens // mamba_track_interval
+            )
+            tracking_point = (
+                batch.seq_lens // mamba_track_interval * mamba_track_interval
+            )
+            to_track_ith = torch.clamp(tracking_point - seq_lens_pre_verify - 1, min=0)
+            can_track_mask = to_track_mask & (
+                to_track_ith < commit_lens.to(to_track_ith.dtype)
+            )
+            mamba_steps_to_track = torch.where(
+                can_track_mask,
+                to_track_ith.to(torch.int64),
+                torch.full_like(to_track_ith, -1, dtype=torch.int64),
+            )
+
+        attn_backend.update_mamba_state_after_mtp_verify(
+            accepted_steps=accepted_steps,
+            mamba_track_indices=batch.mamba_track_indices,
+            mamba_steps_to_track=mamba_steps_to_track,
+            model=self.target_worker.model_runner.model,
+        )
+
+    def forward_batch_generation(
+        self,
+        batch: Union[ScheduleBatch, ModelWorkerBatch],
+        **kwargs,
+    ) -> GenerationBatchResult:
+        if getattr(batch, "return_logprob", False):
+            raise RuntimeError(
+                "Invariant broken: DFLASH batch requested return_logprob, but scheduler should have rejected this request."
+            )
+
+        if isinstance(batch, ModelWorkerBatch):
+            # Should not happen for spec-v1 (non-overlap) scheduling, but keep a sane fallback.
+ return self.target_worker.forward_batch_generation(batch, **kwargs) + + if batch.forward_mode.is_extend() or batch.is_extend_in_batch: + model_worker_batch = batch.get_model_worker_batch() + model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL + + batch_result = self.target_worker.forward_batch_generation( + model_worker_batch, **kwargs + ) + logits_output, next_token_ids = ( + batch_result.logits_output, + batch_result.next_token_ids, + ) + if logits_output.hidden_states is None: + raise RuntimeError( + "DFLASH requires target aux hidden capture for prefill, but got None. " + "Make sure the target model has DFlash layers-to-capture configured." + ) + + if ( + model_worker_batch.extend_seq_lens is None + or model_worker_batch.extend_prefix_lens is None + ): + raise RuntimeError( + "DFLASH expected extend_seq_lens / extend_prefix_lens to be populated in extend mode, but got None." + ) + + # Materialize the prompt tokens into the draft KV cache immediately. This is required + # for radix cache support, since the scheduler may update radix after prefill returns. 
+ device = next_token_ids.device + + def _to_int32_device_tensor(x, *, device=device): + if isinstance(x, torch.Tensor): + if x.device != device: + x = x.to(device, non_blocking=True) + return x if x.dtype == torch.int32 else x.to(torch.int32) + return torch.tensor(x, dtype=torch.int32, device=device) + + extend_seq_lens = _to_int32_device_tensor( + model_worker_batch.extend_seq_lens + ) + draft_input = DFlashDraftInput( + verified_id=next_token_ids.to(torch.int64), + target_hidden=logits_output.hidden_states, + ctx_lens=extend_seq_lens, + draft_seq_lens=( + torch.zeros_like(extend_seq_lens) + if self.use_compact_draft_cache + else _to_int32_device_tensor(model_worker_batch.extend_prefix_lens) + ), + ) + self._append_target_hidden_to_draft_kv(batch, draft_input) + batch.spec_info = draft_input + + return GenerationBatchResult( + logits_output=logits_output, + next_token_ids=next_token_ids, + num_accepted_tokens=0, + can_run_cuda_graph=batch_result.can_run_cuda_graph, + ) + + # Decode / target-verify stage. + draft_input = batch.spec_info + if not isinstance(draft_input, DFlashDraftInput): + raise RuntimeError( + "DFLASH decode requires DFlashDraftInput state on the running batch. " + "This usually means the request did not complete the prefill stage." 
+ ) + + self._prepare_for_speculative_decoding(batch, draft_input) + + model_worker_batch = batch.get_model_worker_batch() + assert model_worker_batch.forward_mode.is_target_verify() + verify_input = model_worker_batch.spec_info + assert isinstance(verify_input, DFlashVerifyInput) + need_mamba_verify_commit = hasattr( + self.target_worker.model_runner.attn_backend, + "update_mamba_state_after_mtp_verify", + ) + seq_lens_pre_verify = ( + batch.seq_lens.clone() if need_mamba_verify_commit else None + ) + + batch_result = self.target_worker.forward_batch_generation( + model_worker_batch, is_verify=True, **kwargs + ) + logits_output, can_run_cuda_graph = ( + batch_result.logits_output, + batch_result.can_run_cuda_graph, + ) + + ( + new_verified_id, + commit_lens, + next_target_hidden, + accept_length_per_req_cpu, + ) = verify_input.verify( + batch=batch, + logits_output=logits_output, + page_size=self.page_size, + ) + if need_mamba_verify_commit: + assert seq_lens_pre_verify is not None + self._update_target_mamba_state_after_verify( + batch=batch, + seq_lens_pre_verify=seq_lens_pre_verify, + commit_lens=commit_lens, + ) + + # Update draft state for the next iteration. Also materialize the committed verify tokens + # into the draft KV cache immediately so radix cache entries are safe to reuse. + draft_input.verified_id = new_verified_id + draft_input.target_hidden = next_target_hidden + draft_input.ctx_lens = commit_lens + self._append_target_hidden_to_draft_kv(batch, draft_input) + batch.spec_info = draft_input + batch.forward_mode = ForwardMode.DECODE + + num_accepted_tokens = sum(accept_length_per_req_cpu) + if not self._logged_first_verify and self.tp_rank == 0: + logger.info( + "DFLASH verify completed. 
accept_length_per_req=%s", + accept_length_per_req_cpu, + ) + self._logged_first_verify = True + + return GenerationBatchResult( + logits_output=logits_output, + next_token_ids=new_verified_id, + num_accepted_tokens=num_accepted_tokens, + accept_length_per_req_cpu=accept_length_per_req_cpu, + can_run_cuda_graph=can_run_cuda_graph, + ) diff --git a/python/sglang/srt/speculative/dflash_worker_v2.py b/python/sglang/srt/speculative/dflash_worker_v2.py new file mode 100644 index 000000000000..ff19ca51bcf2 --- /dev/null +++ b/python/sglang/srt/speculative/dflash_worker_v2.py @@ -0,0 +1,493 @@ +import logging +from typing import Optional + +import torch + +from sglang.srt.managers.schedule_batch import ModelWorkerBatch +from sglang.srt.managers.scheduler import GenerationBatchResult +from sglang.srt.managers.tp_worker import TpModelWorker +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, + compute_position, +) +from sglang.srt.server_args import ServerArgs +from sglang.srt.speculative.dflash_info import DFlashVerifyInput +from sglang.srt.speculative.dflash_info_v2 import DFlashDraftInputV2 +from sglang.srt.speculative.dflash_utils import ( + apply_dflash_verify_logits_adjustments, + compute_dflash_accept_len_and_bonus, + compute_dflash_sampling_accept_len_and_bonus, + is_dflash_sampling_verify_available, +) +from sglang.srt.speculative.dflash_worker import DFlashWorker +from sglang.srt.speculative.eagle_info_v2 import assign_extend_cache_locs_func +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func + +logger = logging.getLogger(__name__) + + +class DFlashWorkerV2(DFlashWorker): + """DFLASH speculative decoding worker (spec-v2 overlap scheduling). 
+
+    This is intentionally implemented as a *separate* worker from the existing
+    spec-v1 `DFlashWorker` (non-overlap), to keep the v1 path stable and to
+    minimize risk while bringing up overlap scheduling.
+    """
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        gpu_id: int,
+        tp_rank: int,
+        dp_rank: Optional[int],
+        moe_ep_rank: int,
+        attn_cp_rank: int,
+        moe_dp_rank: int,
+        nccl_port: int,
+        target_worker: TpModelWorker,
+    ):
+        """Delegate construction entirely to the spec-v1 `DFlashWorker` base class."""
+        super().__init__(
+            server_args=server_args,
+            gpu_id=gpu_id,
+            tp_rank=tp_rank,
+            dp_rank=dp_rank,
+            moe_ep_rank=moe_ep_rank,
+            attn_cp_rank=attn_cp_rank,
+            moe_dp_rank=moe_dp_rank,
+            nccl_port=nccl_port,
+            target_worker=target_worker,
+        )
+
+    def _validate_phase1_sampling_support(
+        self, model_worker_batch: ModelWorkerBatch
+    ) -> None:
+        """Warn once (tp_rank 0 only) if non-greedy sampling verify is unavailable.
+
+        The batch is not rejected; verification falls back to greedy argmax on
+        builds/devices where the sampling verify kernel is missing.
+        """
+        sampling_info = model_worker_batch.sampling_info
+        if sampling_info is None or sampling_info.is_all_greedy:
+            return
+
+        if (
+            not is_dflash_sampling_verify_available()
+            and not self._warned_sampling_fallback
+            and self.tp_rank == 0
+        ):
+            logger.warning(
+                "DFLASH non-greedy verification is unavailable on this build/device; "
+                "falling back to greedy argmax verification."
+            )
+            self._warned_sampling_fallback = True
+
+    def _make_next_draft_input_prefill(
+        self,
+        *,
+        verified_id: torch.Tensor,
+        seq_lens: torch.Tensor,
+        verify_done: Optional[torch.cuda.Event] = None,
+    ) -> DFlashDraftInputV2:
+        """Build the post-prefill draft input for the next overlap iteration.
+
+        topk_p / topk_index / hidden_states are fixed placeholders (ones /
+        zeros / empty) -- presumably only verified_id and new_seq_lens carry
+        real DFlash state downstream; TODO confirm against DFlashDraftInputV2.
+        """
+        bs = int(seq_lens.numel())
+        device = verified_id.device
+        return DFlashDraftInputV2(
+            topk_p=torch.ones((bs, 1), device=device, dtype=torch.float32),
+            topk_index=torch.zeros((bs, 1), device=device, dtype=torch.int64),
+            verified_id=verified_id.to(dtype=torch.int32),
+            new_seq_lens=seq_lens.to(dtype=torch.int32),
+            hidden_states=torch.empty((bs, 1), device=device, dtype=torch.float16),
+            verify_done=verify_done,
+        )
+
+    def _make_next_draft_input_decode(
+        self,
+        *,
+        verified_id: torch.Tensor,
+        new_seq_lens: torch.Tensor,
+        verify_done: Optional[torch.cuda.Event] = None,
+    ) -> DFlashDraftInputV2:
+        """Same as the prefill variant, but sequence lengths come from post-verify `new_seq_lens`."""
+        bs = int(new_seq_lens.numel())
+        device = verified_id.device
+        return DFlashDraftInputV2(
+            topk_p=torch.ones((bs, 1), device=device, dtype=torch.float32),
+            topk_index=torch.zeros((bs, 1), device=device, dtype=torch.int64),
+            verified_id=verified_id.to(dtype=torch.int32),
+            new_seq_lens=new_seq_lens.to(dtype=torch.int32),
+            hidden_states=torch.empty((bs, 1), device=device, dtype=torch.float16),
+            verify_done=verify_done,
+        )
+
+    def forward_batch_generation(
+        self,
+        model_worker_batch: ModelWorkerBatch,
+        **kwargs,
+    ) -> GenerationBatchResult:
+        if getattr(model_worker_batch, "return_logprob", False):
+            raise ValueError(
+                "DFLASH speculative decoding does not support return_logprob yet."
+            )
+        self._validate_phase1_sampling_support(model_worker_batch)
+
+        if (
+            model_worker_batch.forward_mode.is_extend()
+            or model_worker_batch.is_extend_in_batch
+        ):
+            # Target prefill: capture DFlash aux hidden states for prompt tokens.
+ model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL + batch_output = self.target_worker.forward_batch_generation( + model_worker_batch, **kwargs + ) + + logits_output, next_token_ids = ( + batch_output.logits_output, + batch_output.next_token_ids, + ) + + if logits_output.hidden_states is None: + raise RuntimeError( + "DFLASH requires target aux hidden capture for prefill, but got None. " + "Make sure the target model has DFlash layers-to-capture configured." + ) + + if ( + model_worker_batch.extend_seq_lens is None + or model_worker_batch.extend_prefix_lens is None + ): + raise RuntimeError( + "DFLASH expected extend_seq_lens / extend_prefix_lens to be populated in extend mode, " + "but got None." + ) + + # Materialize prompt tokens into the draft KV cache immediately. This is required + # for radix cache safety (the scheduler may update radix after prefill returns). + device = next_token_ids.device + ctx_lens = torch.tensor( + model_worker_batch.extend_seq_lens, dtype=torch.int32, device=device + ) + draft_seq_lens = torch.tensor( + model_worker_batch.extend_prefix_lens, dtype=torch.int32, device=device + ) + + if model_worker_batch.out_cache_loc is None: + raise RuntimeError( + "DFLASH prefill expected out_cache_loc, but got None." + ) + positions, _ = compute_position( + self.model_runner.server_args.attention_backend, + draft_seq_lens, + ctx_lens, + int(sum(model_worker_batch.extend_seq_lens)), + ) + self._append_target_hidden_to_draft_kv_by_loc( + target_hidden=logits_output.hidden_states, + cache_loc=model_worker_batch.out_cache_loc, + positions=positions, + ) + + # Avoid copying large hidden-state buffers to CPU in overlap scheduling. 
+ logits_output.hidden_states = None + + batch_output.next_draft_input = self._make_next_draft_input_prefill( + verified_id=next_token_ids, + seq_lens=model_worker_batch.seq_lens, + ) + verify_done = torch.get_device_module(device).Event() + verify_done.record() + batch_output.next_draft_input.verify_done = verify_done + return batch_output + + # Decode / target-verify stage. + if model_worker_batch.spec_info is None: + model_worker_batch.spec_info = DFlashDraftInputV2.create_idle_input( + device=self.device + ) + + draft_input = model_worker_batch.spec_info + if not isinstance(draft_input, DFlashDraftInputV2): + raise RuntimeError( + "DFLASH spec-v2 expected DFlashDraftInputV2 state on the running batch." + ) + + if model_worker_batch.forward_mode.is_idle(): + empty_ids = torch.empty((0,), dtype=torch.int64, device=self.device) + empty_lens = torch.empty((0,), dtype=torch.int32, device=self.device) + next_draft_input = self._make_next_draft_input_decode( + verified_id=torch.empty((0,), device=self.device, dtype=torch.int32), + new_seq_lens=torch.empty((0,), device=self.device, dtype=torch.int32), + ) + verify_done = torch.get_device_module(self.device).Event() + verify_done.record() + next_draft_input.verify_done = verify_done + return GenerationBatchResult( + logits_output=None, + next_token_ids=empty_ids, + accept_lens=empty_lens, + next_draft_input=next_draft_input, + can_run_cuda_graph=False, + ) + + # `seq_lens` is carried over from the previous overlap iteration and may have been + # produced on another stream. + model_worker_batch.seq_lens.record_stream( + torch.get_device_module(self.device).current_stream() + ) + + bs = len(model_worker_batch.seq_lens) + device = self.device + + # --- 1) Draft a non-causal block with the draft model. 
+ target_model = self.target_worker.model_runner.model + embed_module = target_model.get_input_embeddings() + lm_head = getattr(target_model, "lm_head", None) + if ( + lm_head is None + or not hasattr(lm_head, "weight") + or not hasattr(lm_head, "shard_indices") + ): + raise RuntimeError( + "DFLASH requires the target model to expose a vocab-parallel `lm_head` with `weight` and " + "`shard_indices` attributes." + ) + + self._ensure_draft_block_buffers(bs) + assert self._draft_block_ids_buf is not None + assert self._draft_block_positions_buf is not None + assert self._draft_block_tokens_buf is not None + assert self._draft_block_end_buf is not None + assert self._draft_seq_lens_cpu_buf is not None + + block_ids = self._draft_block_ids_buf[:bs] + block_ids.fill_(int(self._mask_token_id)) + block_ids[:, 0].copy_(draft_input.verified_id.to(torch.long)) + + noise_embedding = embed_module(block_ids) + input_embeds = noise_embedding.view(-1, noise_embedding.shape[-1]) + + prefix_lens = model_worker_batch.seq_lens + positions_2d = self._draft_block_positions_buf[:bs] + torch.add( + prefix_lens.to(torch.int64).unsqueeze(1), + self._block_pos_offsets, + out=positions_2d, + ) + positions = positions_2d.reshape(-1) + + end_offset = prefix_lens + int(self.block_size) + verify_out_cache_loc = assign_extend_cache_locs_func( + req_pool_indices=model_worker_batch.req_pool_indices, + req_to_token=self.model_runner.req_to_token_pool.req_to_token, + start_offset=prefix_lens, + end_offset=end_offset, + batch_size=bs, + draft_token_num=int(self.block_size), + device=device, + ) + + seq_lens_cpu = self._draft_seq_lens_cpu_buf[:bs] + if self.use_compact_draft_cache: + # Rebuild the draft-local sliding-window view from committed target state. 
+ draft_prefix_lens = self._compute_compact_draft_seq_lens(prefix_lens) + seq_lens_cpu.copy_(draft_prefix_lens.to(device="cpu", dtype=torch.int32)) + + suffix_start = prefix_lens.to(torch.int64) - draft_prefix_lens.to( + torch.int64 + ) + suffix_cache_loc = self._gather_req_to_token_segments( + req_to_token=self.model_runner.req_to_token_pool.req_to_token, + req_pool_indices=model_worker_batch.req_pool_indices, + start=suffix_start, + lengths=draft_prefix_lens, + ) + assign_req_to_token_pool_func( + model_worker_batch.req_pool_indices, + self.draft_model_runner.req_to_token_pool.req_to_token, + torch.zeros_like(draft_prefix_lens), + draft_prefix_lens, + suffix_cache_loc, + bs, + ) + + block_end = self._draft_block_end_buf[:bs] + torch.add(draft_prefix_lens, int(self.block_size), out=block_end) + assign_req_to_token_pool_func( + model_worker_batch.req_pool_indices, + self.draft_model_runner.req_to_token_pool.req_to_token, + draft_prefix_lens, + block_end, + verify_out_cache_loc, + bs, + ) + draft_seq_lens = draft_prefix_lens + else: + # Non-windowed path uses the shared overallocated mapping directly. 
+ draft_seq_lens = prefix_lens + if model_worker_batch.seq_lens_cpu is not None: + if model_worker_batch.seq_lens_cpu.dtype == torch.int32: + seq_lens_cpu.copy_(model_worker_batch.seq_lens_cpu) + else: + seq_lens_cpu.copy_(model_worker_batch.seq_lens_cpu.to(torch.int32)) + else: + seq_lens_cpu.copy_(prefix_lens.to("cpu", dtype=torch.int32)) + + forward_batch = ForwardBatch( + forward_mode=ForwardMode.TARGET_VERIFY, + batch_size=bs, + input_ids=block_ids.flatten(), + req_pool_indices=model_worker_batch.req_pool_indices, + seq_lens=draft_seq_lens, + out_cache_loc=verify_out_cache_loc, + seq_lens_sum=int(draft_seq_lens.sum().item()), + seq_lens_cpu=seq_lens_cpu, + positions=positions, + req_to_token_pool=self.draft_model_runner.req_to_token_pool, + token_to_kv_pool=self.draft_model_runner.token_to_kv_pool, + attn_backend=self.draft_model_runner.attn_backend, + input_embeds=input_embeds, + spec_algorithm=SpeculativeAlgorithm.DFLASH, + spec_info=self._draft_block_spec_info, + capture_hidden_mode=CaptureHiddenMode.NULL, + ) + + with torch.inference_mode(): + draft_logits_output = self.draft_model_runner.forward( + forward_batch + ).logits_output + + draft_hidden = draft_logits_output.hidden_states + if draft_hidden is None: + raise RuntimeError("DFLASH draft model returned no hidden states.") + draft_hidden = draft_hidden.view(bs, int(self.block_size), -1) + draft_next = self._greedy_sample_from_vocab_parallel_head( + hidden_states=draft_hidden[:, 1:, :].reshape(-1, draft_hidden.shape[-1]), + lm_head=lm_head, + ).view(bs, int(self.block_size) - 1) + + draft_tokens = self._draft_block_tokens_buf[:bs] + draft_tokens[:, 0].copy_(block_ids[:, 0]) + draft_tokens[:, 1:].copy_(draft_next) + + # --- 2) Target verify. + # TARGET_VERIFY uses standard causal masking; custom masks are unnecessary here. 
+ custom_mask = None + + verify_input_ids = draft_tokens.reshape(-1) + verify_input = DFlashVerifyInput( + draft_token=verify_input_ids, + positions=positions, + draft_token_num=int(self.block_size), + custom_mask=custom_mask, + capture_hidden_mode=CaptureHiddenMode.FULL, + ) + + model_worker_batch.forward_mode = ( + ForwardMode.TARGET_VERIFY + if not model_worker_batch.forward_mode.is_idle() + else ForwardMode.IDLE + ) + model_worker_batch.input_ids = verify_input_ids + model_worker_batch.out_cache_loc = verify_out_cache_loc + model_worker_batch.spec_info = verify_input + model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL + + need_mamba_verify_commit = hasattr( + self.target_worker.model_runner.attn_backend, + "update_mamba_state_after_mtp_verify", + ) + seq_lens_pre_verify = ( + model_worker_batch.seq_lens.clone() if need_mamba_verify_commit else None + ) + + target_out = self.target_worker.forward_batch_generation( + model_worker_batch, is_verify=True, **kwargs + ) + logits_output = target_out.logits_output + can_run_cuda_graph = target_out.can_run_cuda_graph + + sampling_info = model_worker_batch.sampling_info + if sampling_info is not None: + apply_dflash_verify_logits_adjustments( + next_token_logits=logits_output.next_token_logits, + sampling_info=sampling_info, + draft_token_num=int(self.block_size), + ) + + candidates = draft_tokens + if ( + sampling_info is not None + and not sampling_info.is_all_greedy + and is_dflash_sampling_verify_available() + ): + accept_len, bonus = compute_dflash_sampling_accept_len_and_bonus( + candidates=candidates, + next_token_logits=logits_output.next_token_logits, + sampling_info=sampling_info, + max_top_k=draft_input.max_top_k, + uniform_top_k_value=draft_input.uniform_top_k_value, + ) + else: + target_predict = torch.argmax(logits_output.next_token_logits, dim=-1).view( + bs, int(self.block_size) + ) + accept_len, bonus = compute_dflash_accept_len_and_bonus( + candidates=candidates, + 
target_predict=target_predict, + ) + commit_lens = accept_len.to(torch.int32) + 1 # [bs] + + if need_mamba_verify_commit: + assert seq_lens_pre_verify is not None + self._update_target_mamba_state_after_verify( + batch=model_worker_batch, + seq_lens_pre_verify=seq_lens_pre_verify, + commit_lens=commit_lens, + ) + + out_tokens = torch.empty( + (bs, int(self.block_size)), dtype=torch.int64, device=device + ) + if int(self.block_size) > 1: + out_tokens[:, : int(self.block_size) - 1].copy_(candidates[:, 1:]) + out_tokens[:, int(self.block_size) - 1].fill_(0) + out_tokens.scatter_(1, accept_len.to(torch.int64)[:, None], bonus[:, None]) + + # --- 3) Materialize committed verify-input tokens into draft KV cache. + hidden = logits_output.hidden_states + if hidden is None: + raise RuntimeError( + "DFLASH verify requires target hidden states, but got None." + ) + hidden = hidden.view(bs, int(self.block_size), -1) + + # Keep KV append dense to avoid boolean-index packing (which can introduce sync). + offsets = self._block_pos_offsets # [block_size] + mask2d = offsets[None, :] < commit_lens.to(torch.int64)[:, None] # [bs, block] + mask_flat = mask2d.reshape(-1) + + loc2d = verify_out_cache_loc.view(bs, int(self.block_size)) + loc2d = torch.where(mask2d, loc2d, loc2d.new_zeros(())) + loc_flat = loc2d.reshape(-1) + + self._append_target_hidden_to_draft_kv_by_loc( + target_hidden=hidden.reshape(-1, hidden.shape[-1]), + cache_loc=loc_flat, + positions=positions, + mask_valid=mask_flat, + ) + + # Avoid copying large hidden-state buffers to CPU in overlap scheduling. 
+ logits_output.hidden_states = None + + new_seq_lens = prefix_lens + commit_lens.to(prefix_lens.dtype) + next_draft_input = self._make_next_draft_input_decode( + verified_id=bonus, + new_seq_lens=new_seq_lens, + ) + verify_done = torch.get_device_module(device).Event() + verify_done.record() + next_draft_input.verify_done = verify_done + + return GenerationBatchResult( + logits_output=logits_output, + next_token_ids=out_tokens.reshape(-1), + accept_lens=commit_lens, + can_run_cuda_graph=can_run_cuda_graph, + next_draft_input=next_draft_input, + ) diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py index a40a8aa0dc33..6f2f54dbe5e5 100644 --- a/python/sglang/srt/speculative/spec_info.py +++ b/python/sglang/srt/speculative/spec_info.py @@ -4,6 +4,8 @@ from enum import Enum, IntEnum, auto from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union +from sglang.srt.environ import envs + if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.managers.tp_worker import TpModelWorker @@ -15,6 +17,7 @@ class SpeculativeAlgorithm(Enum): """Enumeration of speculative decoding algorithms.""" + DFLASH = auto() EAGLE = auto() EAGLE3 = auto() STANDALONE = auto() @@ -33,6 +36,9 @@ def from_string(cls, name: Optional[str]) -> SpeculativeAlgorithm: def is_none(self) -> bool: return self == SpeculativeAlgorithm.NONE + def is_speculative(self) -> bool: + return self != SpeculativeAlgorithm.NONE + def is_eagle(self) -> bool: # NOTE: EAGLE3 is a variant of EAGLE return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.EAGLE3 @@ -40,6 +46,9 @@ def is_eagle(self) -> bool: def is_eagle3(self) -> bool: return self == SpeculativeAlgorithm.EAGLE3 + def is_dflash(self) -> bool: + return self == SpeculativeAlgorithm.DFLASH + def is_standalone(self) -> bool: return self == SpeculativeAlgorithm.STANDALONE @@ -47,7 +56,11 @@ def is_ngram(self) -> bool: return self == 
SpeculativeAlgorithm.NGRAM def supports_spec_v2(self) -> bool: - return self.is_eagle() or self.is_standalone() + return ( + self.is_eagle() + or self.is_standalone() + or (self.is_dflash() and envs.SGLANG_ENABLE_DFLASH_SPEC_V2.get()) + ) def create_worker( self, server_args: ServerArgs @@ -57,6 +70,21 @@ def create_worker( ), "Cannot create worker for NONE speculative algorithm." enable_overlap = not server_args.disable_overlap_schedule + + if self.is_dflash(): + if enable_overlap: + if not envs.SGLANG_ENABLE_DFLASH_SPEC_V2.get(): + raise ValueError( + "DFLASH does not support overlap scheduling (spec v2) by default. " + "Set env SGLANG_ENABLE_DFLASH_SPEC_V2=True to opt in." + ) + from sglang.srt.speculative.dflash_worker_v2 import DFlashWorkerV2 + + return DFlashWorkerV2 + from sglang.srt.speculative.dflash_worker import DFlashWorker + + return DFlashWorker + if self.is_eagle() and server_args.enable_multi_layer_eagle: # FIXME: migrate to EagleWorker if enable_overlap: @@ -110,6 +138,8 @@ class SpecInputType(IntEnum): # If all algorithms can share the same datastrucutre of draft_input and verify_input, consider simplify it EAGLE_DRAFT = auto() EAGLE_VERIFY = auto() + DFLASH_DRAFT = auto() + DFLASH_VERIFY = auto() NGRAM_VERIFY = auto() @@ -120,11 +150,15 @@ def __init__(self, spec_input_type: SpecInputType): def is_draft_input(self) -> bool: # FIXME: remove this function which is only used for assertion # or use another variable name like `draft_input` to substitute `spec_info` - return self.spec_input_type == SpecInputType.EAGLE_DRAFT + return self.spec_input_type in { + SpecInputType.EAGLE_DRAFT, + SpecInputType.DFLASH_DRAFT, + } def is_verify_input(self) -> bool: return self.spec_input_type in { SpecInputType.EAGLE_VERIFY, + SpecInputType.DFLASH_VERIFY, SpecInputType.NGRAM_VERIFY, } diff --git a/python/sglang/srt/speculative/triton_ops/__init__.py b/python/sglang/srt/speculative/triton_ops/__init__.py new file mode 100644 index 000000000000..a8ea8f4c704b --- 
/dev/null +++ b/python/sglang/srt/speculative/triton_ops/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Triton kernels for speculative decoding.""" + +from sglang.srt.speculative.triton_ops.fused_kv_materialize import ( + FusedKVMaterializeHelper, +) + +__all__ = ["FusedKVMaterializeHelper"] diff --git a/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py b/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py new file mode 100644 index 000000000000..e7dc4c05ddfc --- /dev/null +++ b/python/sglang/srt/speculative/triton_ops/fused_kv_materialize.py @@ -0,0 +1,303 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Fused Triton kernel for DFlash KV materialization. 

Combines: KV projection (cuBLAS) + RMSNorm + RoPE (Triton), then pool-managed KV writes.
"""

from typing import Callable, List

import torch
import triton
import triton.language as tl


@triton.jit
def _fused_norm_rope_kernel(
    kv_ptr,  # [total_ctx, kv_size * 2], K block first then V block per row
    k_norm_weight_ptr,  # [head_dim]
    cos_sin_cache_ptr,  # [max_pos, rotary_dim]
    positions_ptr,  # [total_ctx], int64 rotary positions
    k_out_ptr,  # [total_ctx, num_kv_heads, head_dim]
    v_out_ptr,  # [total_ctx, num_kv_heads, head_dim]
    kv_stride_ctx,
    cos_sin_stride_pos,
    k_out_stride_ctx,
    k_out_stride_head,
    v_out_stride_ctx,
    v_out_stride_head,
    total_ctx,
    num_kv_heads: tl.constexpr,
    head_dim: tl.constexpr,
    kv_size: tl.constexpr,
    rotary_dim: tl.constexpr,
    half_rotary_dim: tl.constexpr,
    eps: tl.constexpr,
    BLOCK_HD: tl.constexpr,
):
    """Fused RMSNorm(K) + RoPE(K) materialization. Grid: (total_ctx, num_kv_heads).

    Each program handles one (token, kv-head) pair: it RMS-normalizes that
    head's K vector, rotates the first `rotary_dim` components neox-style
    (halves, not interleaved pairs), and copies V through untransformed.
    All math is done in float32; stores are cast back to V's dtype.

    NOTE(review): assumes each row of cos_sin_cache is laid out as
    [cos(0..half_rotary_dim), sin(0..half_rotary_dim)] — that is what the two
    loads below index; confirm against the cache producer.
    """
    ctx_id = tl.program_id(0)
    head_id = tl.program_id(1)
    # Guard against over-provisioned grids; with grid == (total_ctx, ...) this
    # is a no-op but keeps the kernel safe if the launch is ever padded.
    if ctx_id >= total_ctx:
        return

    # Load metadata
    position = tl.load(positions_ptr + ctx_id)

    # Compute base pointers. K occupies the first kv_size columns of the fused
    # projection row, V the second kv_size columns.
    kv_base = kv_ptr + ctx_id * kv_stride_ctx
    k_base = kv_base + head_id * head_dim
    v_base = kv_base + kv_size + head_id * head_dim
    k_write = k_out_ptr + ctx_id * k_out_stride_ctx + head_id * k_out_stride_head
    v_write = v_out_ptr + ctx_id * v_out_stride_ctx + head_id * v_out_stride_head

    # Load K and V. BLOCK_HD is next_power_of_2(head_dim), so masks clip the
    # padded tail; masked K lanes load 0.0 and therefore do not perturb the
    # RMS sum below.
    offs = tl.arange(0, BLOCK_HD)
    mask_hd = offs < head_dim
    mask_half = offs < half_rotary_dim

    k_raw = tl.load(k_base + offs, mask=mask_hd, other=0.0).to(tl.float32)
    v_raw = tl.load(v_base + offs, mask=mask_hd, other=0.0)

    # RMSNorm on K: x * rsqrt(mean(x^2) + eps) * weight, in float32.
    inv_rms = tl.rsqrt(tl.sum(k_raw * k_raw) / head_dim + eps)
    norm_w = tl.load(k_norm_weight_ptr + offs, mask=mask_hd, other=1.0).to(tl.float32)
    k_normed = k_raw * inv_rms * norm_w

    # RoPE (neox style): k_first, k_second -> rotated
    cos_sin_base = cos_sin_cache_ptr + position * cos_sin_stride_pos
    cos_v = tl.load(cos_sin_base + offs, mask=mask_half, other=1.0).to(tl.float32)
    sin_v = tl.load(
        cos_sin_base + half_rotary_dim + offs, mask=mask_half, other=0.0
    ).to(tl.float32)

    # Extract first/second halves of K for rotation. The second half is
    # reloaded raw and re-normalized (same inv_rms) rather than shuffled from
    # k_normed, since Triton has no cross-lane gather here.
    k_first = tl.where(mask_half, k_normed, 0.0)
    k_second_raw = tl.load(
        k_base + half_rotary_dim + offs, mask=mask_half, other=0.0
    ).to(tl.float32)
    norm_w_second = tl.load(
        k_norm_weight_ptr + half_rotary_dim + offs, mask=mask_half, other=1.0
    ).to(tl.float32)
    k_second = k_second_raw * inv_rms * norm_w_second

    # Apply rotation
    k_rot_first = k_first * cos_v - k_second * sin_v
    k_rot_second = k_second * cos_v + k_first * sin_v

    # Store V (no transform)
    tl.store(v_write + offs, v_raw, mask=mask_hd)

    # Store K: rotated halves + pass-through. Offsets [0, half) and
    # [half, rotary_dim) receive the rotated halves; [rotary_dim, head_dim)
    # keeps the normalized-but-unrotated values (partial-rotary support).
    tl.store(k_write + offs, k_rot_first.to(v_raw.dtype), mask=mask_half)
    tl.store(
        k_write + half_rotary_dim + offs, k_rot_second.to(v_raw.dtype), mask=mask_half
    )
    mask_pass = (offs >= rotary_dim) & (offs < head_dim)
    tl.store(k_write + offs, k_normed.to(v_raw.dtype), mask=mask_pass)


def _fused_norm_rope(
    kv: torch.Tensor,  # [total_ctx, kv_size*2]
    k_norm_weight: torch.Tensor,  # [head_dim]
    cos_sin_cache: torch.Tensor,  # [max_pos, rotary_dim]
    positions: torch.Tensor,  # [total_ctx]
    num_kv_heads: int,
    head_dim: int,
    rotary_dim: int,
    eps: float = 1e-6,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused RMSNorm + RoPE materialization for a single layer.

    Validates shapes, normalizes `positions` to int64 on `kv`'s device, and
    launches `_fused_norm_rope_kernel` with one program per (token, kv-head).

    Returns:
        (k_out, v_out), each of shape [total_ctx, num_kv_heads, head_dim] in
        kv's dtype; both empty when total_ctx == 0 (no kernel launch).

    Raises:
        ValueError: if kv's second dim is not num_kv_heads*head_dim*2, or
            rotary_dim is not an even value in (0, head_dim].
    """
    total_ctx = kv.shape[0]
    if total_ctx == 0:
        empty = torch.empty(
            (0, num_kv_heads, head_dim), dtype=kv.dtype, device=kv.device
        )
        return empty, empty

    kv_size = num_kv_heads * head_dim
    if kv.shape[1] != kv_size * 2:
        raise ValueError(
            "Invalid fused KV projection shape: "
            f"got {tuple(kv.shape)}, expected second dim {kv_size * 2}."
        )
    if rotary_dim <= 0 or rotary_dim > head_dim or rotary_dim % 2 != 0:
        raise ValueError(
            "Invalid fused KV rotary/head dim pair: "
            f"rotary_dim={rotary_dim}, head_dim={head_dim}."
        )

    half_rotary_dim = rotary_dim // 2
    # Power-of-2 block so tl.arange covers head_dim; tail lanes are masked off.
    BLOCK_HD = triton.next_power_of_2(head_dim)

    # Ensure int64 for indexing
    if positions.device != kv.device:
        positions = positions.to(device=kv.device, dtype=torch.int64)
    elif positions.dtype != torch.int64:
        positions = positions.to(torch.int64)

    k_out = torch.empty(
        (total_ctx, num_kv_heads, head_dim), dtype=kv.dtype, device=kv.device
    )
    v_out = torch.empty_like(k_out)

    _fused_norm_rope_kernel[(total_ctx, num_kv_heads)](
        kv,
        k_norm_weight,
        cos_sin_cache,
        positions,
        k_out,
        v_out,
        kv.stride(0),
        cos_sin_cache.stride(0),
        k_out.stride(0),
        k_out.stride(1),
        v_out.stride(0),
        v_out.stride(1),
        total_ctx,
        num_kv_heads,
        head_dim,
        kv_size,
        rotary_dim,
        half_rotary_dim,
        eps,
        BLOCK_HD,
    )
    return k_out, v_out


class FusedKVMaterializeHelper:
    """Fused KV materialization helper using batched projection.

    Uses torch.einsum for batched KV projection across all layers,
    then a Triton kernel for fused RMSNorm + RoPE materialization per layer.

    The constructor validates that every layer shares the same KV geometry
    and RoPE config (required for one batched weight tensor), then stacks the
    KV slices of each layer's QKV weight. `materialize` does one einsum over
    all layers and hands per-layer (K, V) tensors to a caller-supplied writer.

    NOTE(review): `batched_kv_weight` holds views into the live qkv_proj
    weights (torch.stack copies, so subsequent in-place weight updates are
    NOT reflected) — acceptable for inference-only use; confirm.
    """

    def __init__(
        self,
        layers: List,
        rotary_emb,
        num_kv_heads: int,
        head_dim: int,
        device: torch.device,
    ):
        """Build the helper from decoder `layers` sharing one `rotary_emb` config.

        Each layer is expected to expose `self_attn` with `num_kv_heads`,
        `head_dim`, `rotary_emb`, `qkv_proj.weight`, `q_size`, `kv_size`,
        and `k_norm` (weight + variance_epsilon).

        Raises:
            NotImplementedError: if the RoPE is not neox-style.
            ValueError: on an invalid rotary/head dim pair, or on any
                KV-geometry / RoPE-config mismatch across layers.
        """
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.rotary_emb = rotary_emb
        self.n_layers = len(layers)
        self.device = device

        # Defaults mirror the common RotaryEmbedding attributes; absent
        # attributes fall back to full-head, neox-style RoPE.
        self.rotary_dim = int(getattr(rotary_emb, "rotary_dim", head_dim))
        self.is_neox_style = bool(getattr(rotary_emb, "is_neox_style", True))

        if not self.is_neox_style:
            raise NotImplementedError("Only neox-style RoPE is supported.")
        if self.rotary_dim <= 0 or self.rotary_dim > self.head_dim:
            raise ValueError(
                "Invalid fused KV rotary/head dim pair: "
                f"rotary_dim={self.rotary_dim}, head_dim={self.head_dim}."
            )

        # Pre-extract and stack weights for batched projection.
        kv_weights = []
        self.k_norm_weights = []  # per-layer K RMSNorm weights, [head_dim] each
        self.eps_values = []  # per-layer RMSNorm variance epsilons

        for layer_id, layer in enumerate(layers):
            attn = layer.self_attn
            if int(attn.num_kv_heads) != self.num_kv_heads:
                raise ValueError(
                    "num_kv_heads mismatch across layers for fused KV path: "
                    f"expected {self.num_kv_heads}, got {int(attn.num_kv_heads)} at layer {layer_id}."
                )
            if int(attn.head_dim) != self.head_dim:
                raise ValueError(
                    "head_dim mismatch across layers for fused KV path: "
                    f"expected {self.head_dim}, got {int(attn.head_dim)} at layer {layer_id}."
                )
            layer_rotary_dim = int(
                getattr(attn.rotary_emb, "rotary_dim", self.head_dim)
            )
            layer_is_neox = bool(getattr(attn.rotary_emb, "is_neox_style", True))
            if (
                layer_rotary_dim != self.rotary_dim
                or layer_is_neox != self.is_neox_style
            ):
                raise ValueError(
                    "RoPE config mismatch across layers for fused KV path: "
                    f"expected (rotary_dim={self.rotary_dim}, neox={self.is_neox_style}), "
                    f"got (rotary_dim={layer_rotary_dim}, neox={layer_is_neox}) at layer {layer_id}."
                )

            # Extract KV portion of QKV weight (rows [q_size, q_size + 2*kv_size)
            # of the fused projection matrix).
            qkv_w = attn.qkv_proj.weight
            kv_weight = qkv_w[attn.q_size : attn.q_size + 2 * attn.kv_size]
            kv_weights.append(kv_weight)
            self.k_norm_weights.append(attn.k_norm.weight)
            self.eps_values.append(attn.k_norm.variance_epsilon)

        # Stack for batched einsum: [n_layers, kv_size*2, hidden_size]
        self.batched_kv_weight = torch.stack(kv_weights)

    def materialize(
        self,
        ctx_hidden: torch.Tensor,
        positions: torch.Tensor,
        write_layer_kv: Callable[[int, torch.Tensor, torch.Tensor], None],
    ) -> None:
        """Materialize KV cache for all layers using batched projection.

        Args:
            ctx_hidden: [total_ctx, hidden_size] hidden states to project.
            positions: rotary positions, flattened to 1-D; must have exactly
                total_ctx elements.
            write_layer_kv: callback (layer_id, cache_k, cache_v) that writes
                one layer's [total_ctx, num_kv_heads, head_dim] K/V into the
                pool; this helper performs no KV-pool writes itself.

        Raises:
            ValueError: if positions and ctx_hidden token counts disagree.
            RuntimeError: if the RoPE cos/sin cache cannot cover max(positions).
        """
        total_ctx = ctx_hidden.shape[0]
        if total_ctx == 0:
            return

        if positions.ndim != 1:
            positions = positions.reshape(-1)
        if positions.numel() != total_ctx:
            raise ValueError(
                "positions must match ctx_hidden token count for fused KV materialization: "
                f"positions={positions.numel()}, total_ctx={total_ctx}."
            )

        # .item() forces a host sync here; NOTE(review): presumably acceptable
        # on this (prefill/extend-like) path — confirm it is off the hot loop.
        max_position = int(positions.max().item())
        # Best-effort cache growth: only some rotary_emb implementations expose
        # _ensure_cos_sin_cache_length, hence the getattr/callable guard.
        ensure_cos_sin_cache_length = getattr(
            self.rotary_emb, "_ensure_cos_sin_cache_length", None
        )
        if callable(ensure_cos_sin_cache_length):
            ensure_cos_sin_cache_length(max_position)

        cos_sin_cache = self.rotary_emb.cos_sin_cache
        if max_position >= int(cos_sin_cache.shape[0]):
            raise RuntimeError(
                "RoPE cos/sin cache is too short for fused KV materialization: "
                f"max_position={max_position}, cache_len={int(cos_sin_cache.shape[0])}."
            )
        if cos_sin_cache.device != ctx_hidden.device:
            cos_sin_cache = cos_sin_cache.to(ctx_hidden.device)

        # Batched KV projection: [n_layers, total_ctx, kv_size*2]
        kv_all = torch.einsum("th,loh->lto", ctx_hidden, self.batched_kv_weight)

        # Per-layer fused norm/RoPE/materialize, then delegate writes to the KV pool.
        for layer_id in range(self.n_layers):
            cache_k, cache_v = _fused_norm_rope(
                kv_all[layer_id],
                self.k_norm_weights[layer_id],
                cos_sin_cache,
                positions,
                self.num_kv_heads,
                self.head_dim,
                self.rotary_dim,
                self.eps_values[layer_id],
            )
            write_layer_kv(layer_id, cache_k, cache_v)