Skip to content

Feat: Support SWA (Sliding Window Attention) for EAGLE-3 drafter#24664

Merged
hnyls2002 merged 4 commits into
sgl-project:mainfrom
Dogacel:eagle-swa
May 12, 2026
Merged

Feat: Support SWA (Sliding Window Attention) for EAGLE-3 drafter#24664
hnyls2002 merged 4 commits into
sgl-project:mainfrom
Dogacel:eagle-swa

Conversation

@Dogacel
Copy link
Copy Markdown
Contributor

@Dogacel Dogacel commented May 8, 2026

Motivation

Add sliding-window attention support to EAGLE series models. Related to #24663, in our upcoming paper we showed how SWA can help increase acceptance lengths if model is not trained to handle long context lengths. Some models are not usable without SWA as their training length is usually short (2-4K).

Modifications

  1. Add SWA support to EAGLE via CLI flag --speculative-draft-window-size, which defaults to None to disable SWA.
  2. Support SWA in llama_eagle3.py
  3. Unify D-Flash's SWA and EAGLE's SWA flags.

Accuracy Tests

MT-Bench on gpt-oss-20b and llama3.1 ensured acceptance lengths are the same when SWA is disabled.

Speed Tests and Profiling

Qwen3.5 or GPT-oss-20b drafter accuracies are less effected by long context as they have internal sliding window layers. However some models like Llama 3.1 8B's accuracy hurts pretty badly when context length exceeds trained context length.

Benchmark: MT-Bench (80 questions)
Temperature: 0.7
Max Tokens: 2048

Context Pre-fill: 32K tokens — gpt-oss-20b drafter

Config 64-5-1-6

Setting Latency (s) Output Throughput (tok/s) Accept Length
No SWA 72.38 1341.76 2.370
SWA=2048 63.74 1599.73 2.568

Config 1-5-1-6

Setting Latency (s) Output Throughput (tok/s) Accept Length
No SWA 324.49 303.81 2.401
SWA=2048 311.974 324.83 2.5576

Checklist

Review and Merge Process

  1. Ping Merge Oncalls to start the process. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • Common commands include /tag-and-rerun-ci, /tag-run-ci-label, /rerun-failed-ci
  4. After green CI and required approvals, ask Merge Oncalls or people with Write permission to merge the PR.

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Warning

You have reached your daily quota limit. Please wait up to 24 hours and I will start processing your requests again!

@Qiaolin-Yu
Copy link
Copy Markdown
Collaborator

/tag-and-rerun-ci

@Qiaolin-Yu Qiaolin-Yu self-assigned this May 10, 2026
def get_attention_sliding_window_size(self):
server_args = get_global_server_args()
draft_window_size: Optional[int] = (
int(server_args.speculative_draft_window_size) - 1
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qq: does dflash also minus 1 here? if not, the concept of this var is not consistent.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I remember D-Flash doing it, but maybe it is refactored or my memory is poor. Reverted it to be just speculative_draft_window_size anyway.

Comment thread python/sglang/srt/server_args.py Outdated
raise ValueError(
"DFLASH requires --speculative-dflash-draft-window-size "
f"to be positive, got {window_size}."
"--speculative-draft-window-size must be positive, got {window_size}."
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be f"...", right?

help="Attention backend for speculative decoding drafting.",
default=ServerArgs.speculative_draft_attention_backend,
)
parser.add_argument(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For backward compatibility, may be

parser.add_argument(
      "--speculative-draft-window-size",
      "--speculative-dflash-draft-window-size",  # alias
      type=int,
      dest="speculative_draft_window_size",
      ...
  )

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh makes sense 👍

@kpham-sgl kpham-sgl self-assigned this May 11, 2026
@Qiaolin-Yu Qiaolin-Yu self-requested a review May 11, 2026 06:15
Copy link
Copy Markdown
Collaborator

@Qiaolin-Yu Qiaolin-Yu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your contribution!

@Qiaolin-Yu Qiaolin-Yu self-requested a review May 11, 2026 06:17
# Conflicts:
#	python/sglang/srt/models/llama_eagle3.py
@hnyls2002 hnyls2002 merged commit 186eb42 into sgl-project:main May 12, 2026
38 of 79 checks passed
SpencerGarnets added a commit to ai-blaise/optimization-playground that referenced this pull request May 12, 2026
…ack)

Brings in upstream sgl-project/sglang main commits since
096ad02 (merge base, Laguna-XS.2 model support).
Total: 28 upstream commits composed.

Custom-stack files preserved intact (entirely-ours, byte-identical to
origin/main):
  - Blackwell CuTe kernel suite (warp_decode_cute, g1_attention_cute,
    gated_norm_cute, layersplit_cute, fused_store_index_cache)
  - TurboQuant 2.5-bit dense KV cache path
  - HIGGS 2-bit dense KV cache path (with split-K decode)
  - NVFP4 IndexCache dispatcher (active gate)
  - quantization_config_dispatch (HF-config-driven runtime routing)
  - All custom server-args flags and runtime methods preserved

Verification:
  - 200+ merged Python files compile cleanly
  - Dispatcher symbol presence verified
  - HIGGS pool / TurboQuant pool classes present at expected lines
  - compressed_tensors_w4a4_nvfp4_moe imports clean
  - All custom server-args flags present (enable_higgs_dense_2bit_kv_cache,
    enable_turboquant_dense_kv_cache, turboquant_dense_kv_preset,
    indexer_quantization_declared, higgs_mla_decode_num_splits, etc.)

Manual-merged shared files (auto-merge gave broken/mixed output; cleaned
up post-merge):
  - python/sglang/srt/disaggregation/mooncake/conn.py: upstream's PR#24932
    refactored maybe_send_extra into a state-types-loop. Replayed our
    LayerSplit NSA state-index-length-mismatch check inside the SWA/NSA
    branch of the new loop body.
  - sgl-kernel/python/sgl_kernel/__init__.py: upstream's PR#23449 (Apple
    Silicon Metal kernel) wrapped the entire module body in
    `if darwin/arm64: from sgl_kernel.metal import * else: ...`. The
    auto-merge duplicated the file body; rewrote cleanly with upstream's
    structure and re-injected our `g1_gate_forward`,
    `warp_decode_cute_moe_forward`, and
    `warp_decode_cute_moe_packed_forward` imports plus `g1_gate_forward`
    in _DEBUG_EXPORT_NAMES.
  - python/sglang/srt/managers/scheduler_output_processor_mixin.py: line
    628 still referenced `result.num_accepted_drafts` (renamed by PR
    sgl-project#25038 to `num_correct_drafts`). Renamed in place.
  - python/sglang/srt/observability/scheduler_metrics_mixin.py: a block
    around the spec-decode logging path had mixed old/new names from
    auto-merge (lines 553/557/560). Renamed `spec_num_accepted_tokens`
    -> `spec_num_accept_tokens` and local `num_accepted_drafts` ->
    `num_correct_drafts` to match the rest of the file.
  - test/test_smc_info.py: stub Req mock used the old field names
    `spec_accepted_drafts` and `update_spec_acceptance_histogram`.
    Renamed to `spec_num_correct_drafts` and
    `update_spec_correct_drafts_histogram` per PR sgl-project#24081.

Auto-merge cleanly integrated upstream changes to:
  - server_args.py (new fields: prefill_only_disable_kv_cache,
    weight_loader_drop_cache_after_load, prefill_delayer_queue_min_ratio,
    prefill_delayer_max_delay_ms, speculative_draft_window_size, etc.)
  - mem_cache/memory_pool.py (new NoOpMHATokenToKVPool)
  - model_executor/model_runner_kv_cache_mixin.py (NoOpMHATokenToKVPool
    pool factory + _validate_prefill_only_disable_kv_cache_pool_family)
  - layers/attention/nsa_backend.py (spec rename
    num_accepted_drafts -> num_correct_drafts;
    num_accepted_tokens -> num_accept_tokens)
  - layers/attention/nsa/nsa_indexer.py (new _apply_q_scale_and_softmax_scale
    compile method; torch.mm replaces deep_gemm wrapper)
  - 28+ disaggregation/spec/runner files with mostly clean
    upstream-side-only integration.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

----- upstream commit subjects (28) -----
fd3eb77 [Cookbook]: add Laguna-XS.2 (Poolside) (sgl-project#24730)
6be1a45 Fix swa component host hit (sgl-project#25085)
693f497 [NPU] use causal_conv1d_update_v2 for performance (sgl-project#24595)
1efe9e2 [Bug Fix] Reject incompatible combination of --disable-cuda-graph-padding and --enable-torch-compile (sgl-project#23903)
8d27ce7 Optimize uvicorn startup command (sgl-project#25041)
b35fd5f [fix] skip legacy minicpmv conv template for MiniCPM-V 4.6 (sgl-project#24998)
7582237 [Tiny Fix] Disable BCG when inner layer_model unresolved (sgl-project#25021)
ca3bc05 Deepseek-v4-Pro share expert tp1 (sgl-project#24949)
a72d3ae [Spec] Multi-layer mamba scatter cleanup; fix positional call bug (sgl-project#25030)
7128533 Revert "Migrate Intel CPU cases to the test/registered." (sgl-project#25044)
1f985c5 [Spec] Rename `accepted_indices` -> `accept_indices`; drop `_token_id` suffix per Rule 5 (sgl-project#25038)
ecf5d84 Migrate Intel CPU cases to the test/registered. (sgl-project#22670)
d7f4761 [PD] Refactor hybrid state transfer (sgl-project#24932)
91907b7 [UnifiedTree]: Fix Unified HiCache tombstone lock release replay (sgl-project#24972)
4ad63ad [Spec] Rename `accepted_drafts` -> `correct_drafts` for unambiguous naming (sgl-project#24081)
6bfb365 [PD] Rate limit prefill inflight polling warnings (sgl-project#24967)
6bb79c1 [Linear Attn] Add CUSTOM enum and plugin extensibility for kernel backends (sgl-project#24937)
cfc41d5 Fix kimi k2.5 mla eagle + dp attention (sgl-project#25033)
0f3932c [Fix] Qwen3-ASR config: set thinker_config before super().__init__ (sgl-project#24187)
f526e3f [Spec] Mamba scatter cleanup; fix multi-layer positional bug; dflash naming (sgl-project#25029)
10375a1 [NIXL][XPU] Fix uint64 overflow for mismatched P/D TP sizes (e.g. prefill_tp=1, decode_tp=2) (sgl-project#24648)
0a37d24 [diffusion] hardware: support sage attention backend on MUSA (attn backend, 21/N) (sgl-project#24752)
5495026 [HiCache] feat: default storage prefetch timeout (sgl-project#23309)
186eb42 Feat: Support SWA (Sliding Window Attention) for EAGLE-3 drafter (sgl-project#24664)
a75b79e Feat: Support newer EAGLE-3 drafters (sgl-project#24663)
f3a8189 [Spec] Internal rename per N2 v2 naming rule (sgl-project#25014)
bfc2eda [MUSA] Use MUSA-optimized operators in piecewise CUDA graph (sgl-project#23633)
74d70af [Apple Silicon] Add Metal kernel support in sgl-kernel (sgl-project#23449)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants