From 621b85c28be3f4494d0c274947e3ea0027fc4dcf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Do=C4=9Fa=C3=A7=20Eldenk?= Date: Fri, 8 May 2026 04:49:09 +0000 Subject: [PATCH 1/3] support SWA in EAGLE --- python/sglang/srt/models/llama_eagle3.py | 22 +++++++++++-- python/sglang/srt/server_args.py | 31 ++++++++++--------- .../sglang/srt/speculative/dflash_worker.py | 4 +-- 3 files changed, 38 insertions(+), 19 deletions(-) diff --git a/python/sglang/srt/models/llama_eagle3.py b/python/sglang/srt/models/llama_eagle3.py index 0bff63788eae..d04dfe3c14e3 100644 --- a/python/sglang/srt/models/llama_eagle3.py +++ b/python/sglang/srt/models/llama_eagle3.py @@ -38,6 +38,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaMLP +from sglang.srt.server_args import get_global_server_args class LlamaDecoderLayer(LlamaDecoderLayer): @@ -46,6 +47,7 @@ def __init__( config: LlamaConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + draft_window_size: Optional[int] = None, prefix: str = "", ) -> None: super().__init__(config, layer_id, quant_config, prefix) @@ -61,6 +63,9 @@ def __init__( prefix=add_prefix("qkv_proj", prefix), ) + if draft_window_size is not None: + self.self_attn.attn.sliding_window_size = draft_window_size + if config.model_type == "llama4_text": inter_size = config.intermediate_size_mlp else: @@ -106,6 +111,7 @@ def __init__( self, config: LlamaConfig, quant_config: Optional[QuantizationConfig] = None, + draft_window_size: Optional[int] = None, prefix: str = "", ) -> None: super().__init__() @@ -150,7 +156,7 @@ def __init__( bias=getattr(config, "bias", False), ) - self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix) + self.midlayer = LlamaDecoderLayer(config, 0, quant_config, draft_window_size, prefix) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -229,7 +235,10 @@ def __init__( raise ValueError("EAGLE3 currently only supports 1 layer") self.model = LlamaModel( - config, quant_config=quant_config, prefix=add_prefix("model", prefix) + config, + quant_config=quant_config, + draft_window_size=self.get_attention_sliding_window_size(), + prefix=add_prefix("model", prefix), ) # Llama 3.2 1B Instruct set tie_word_embeddings to True # Llama 3.1 8B Instruct set tie_word_embeddings to False @@ -302,5 +311,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None: def get_hot_token_id(self): return self.hot_token_id + def get_attention_sliding_window_size(self): + server_args = get_global_server_args() + draft_window_size: Optional[int] = ( + int(server_args.speculative_draft_window_size) - 1 + if server_args.speculative_draft_window_size is not None + else None + ) + return draft_window_size + EntryClass = [LlamaForCausalLMEagle3] diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 1660a24179be..161f7e9d4888 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -567,12 +567,12 @@ class ServerArgs: speculative_eagle_topk: Optional[int] = None speculative_num_draft_tokens: Optional[int] = None speculative_dflash_block_size: Optional[int] = None - speculative_dflash_draft_window_size: Optional[int] = None speculative_accept_threshold_single: float = 1.0 speculative_accept_threshold_acc: float = 1.0 speculative_token_map: Optional[str] = None speculative_attention_mode: str = "prefill" speculative_draft_attention_backend: Optional[str] = None + speculative_draft_window_size: Optional[int] = None speculative_moe_runner_backend: Optional[str] = None speculative_moe_a2a_backend: Optional[str] = None speculative_draft_model_quantization: Optional[str] = None @@ -3444,14 +3444,13 @@ def _handle_speculative_decoding(self): ) window_size = None - if self.speculative_dflash_draft_window_size is not None: - window_size = int(self.speculative_dflash_draft_window_size) + if self.speculative_draft_window_size is not None: + window_size = int(self.speculative_draft_window_size) if window_size <= 0: raise ValueError( - "DFLASH requires --speculative-dflash-draft-window-size " - f"to be positive, got {window_size}." + "--speculative-draft-window-size must be positive, got {window_size}." ) - self.speculative_dflash_draft_window_size = window_size + self.speculative_draft_window_size = window_size if self.speculative_num_draft_tokens is None: from sglang.srt.speculative.dflash_utils import ( @@ -5568,15 +5567,6 @@ def add_cli_args(parser: argparse.ArgumentParser): help="DFLASH only. Block size (verify window length). Alias of --speculative-num-draft-tokens for DFLASH.", default=ServerArgs.speculative_dflash_block_size, ) - parser.add_argument( - "--speculative-dflash-draft-window-size", - type=int, - help="DFLASH only. Sliding window size for the draft-model KV cache. " - "When set, the draft worker keeps a recent target-token window in its " - "local cache (paged backends may retain up to one extra page on the left " - "for alignment). Default is full context.", - default=ServerArgs.speculative_dflash_draft_window_size, - ) parser.add_argument( "--speculative-accept-threshold-single", type=float, @@ -5608,6 +5598,17 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Attention backend for speculative decoding drafting.", default=ServerArgs.speculative_draft_attention_backend, ) + parser.add_argument( + "--speculative-draft-window-size", + type=int, + help="Sliding window size for the draft model (honored by EAGLE-3 and DFLASH). " + "For EAGLE-3, the drafter only attends to the most recent N keys " + "(verifier hidden states + its own outputs); the verifier is unaffected. " + "For DFLASH, the draft worker keeps a recent target-token window in its " + "local KV cache (paged backends may retain up to one extra page on the " + "left for alignment). Default is full attention/context.", + default=ServerArgs.speculative_draft_window_size, + ) parser.add_argument( "--speculative-moe-runner-backend", type=str, diff --git a/python/sglang/srt/speculative/dflash_worker.py b/python/sglang/srt/speculative/dflash_worker.py index 8d34db1748a5..7b1d357bb6d1 100644 --- a/python/sglang/srt/speculative/dflash_worker.py +++ b/python/sglang/srt/speculative/dflash_worker.py @@ -74,8 +74,8 @@ def __init__( self.model_runner = target_worker.model_runner self.page_size = server_args.page_size self.draft_window_size: Optional[int] = ( - int(server_args.speculative_dflash_draft_window_size) - if server_args.speculative_dflash_draft_window_size is not None + int(server_args.speculative_draft_window_size) + if server_args.speculative_draft_window_size is not None else None ) self.use_compact_draft_cache = self.draft_window_size is not None From 650e4eb59500c7f8d2a3f6b258c0a4171748de3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Do=C4=9Fa=C3=A7=20Eldenk?= Date: Sun, 10 May 2026 17:05:21 -0500 Subject: [PATCH 2/3] lint --- python/sglang/srt/models/llama_eagle3.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/models/llama_eagle3.py b/python/sglang/srt/models/llama_eagle3.py index d04dfe3c14e3..404d15c741ab 100644 --- a/python/sglang/srt/models/llama_eagle3.py +++ b/python/sglang/srt/models/llama_eagle3.py @@ -156,7 +156,9 @@ def __init__( bias=getattr(config, "bias", False), ) - self.midlayer = LlamaDecoderLayer(config, 0, quant_config, draft_window_size, prefix) + self.midlayer = LlamaDecoderLayer( + config, 0, quant_config, draft_window_size, prefix + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) From 541d8572c1a962a167e5a26e266a68695d1b43fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Do=C4=9Fa=C3=A7=20Eldenk?= Date: Mon, 11 May 2026 05:07:47 +0000 Subject: [PATCH 3/3] fix args --- python/sglang/srt/models/llama_eagle3.py | 2 +- python/sglang/srt/server_args.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/models/llama_eagle3.py b/python/sglang/srt/models/llama_eagle3.py index 404d15c741ab..316865d4d05a 100644 --- a/python/sglang/srt/models/llama_eagle3.py +++ b/python/sglang/srt/models/llama_eagle3.py @@ -316,7 +316,7 @@ def get_hot_token_id(self): def get_attention_sliding_window_size(self): server_args = get_global_server_args() draft_window_size: Optional[int] = ( - int(server_args.speculative_draft_window_size) - 1 + int(server_args.speculative_draft_window_size) if server_args.speculative_draft_window_size is not None else None ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 161f7e9d4888..4e5a3694c98f 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -3448,7 +3448,7 @@ def _handle_speculative_decoding(self): window_size = int(self.speculative_draft_window_size) if window_size <= 0: raise ValueError( - "--speculative-draft-window-size must be positive, got {window_size}." + f"--speculative-draft-window-size must be positive, got {window_size}." ) self.speculative_draft_window_size = window_size @@ -3490,7 +3490,7 @@ def _handle_speculative_decoding(self): draft_tokens = int(self.speculative_num_draft_tokens) if window_size < draft_tokens: raise ValueError( - "DFLASH --speculative-dflash-draft-window-size must be >= " + "--speculative-draft-window-size must be >= " "--speculative-num-draft-tokens (block_size). " f"window_size={window_size}, block_size={draft_tokens}." ) @@ -5600,7 +5600,9 @@ def add_cli_args(parser: argparse.ArgumentParser): ) parser.add_argument( "--speculative-draft-window-size", + "--speculative-dflash-draft-window-size", type=int, + dest="speculative_draft_window_size", help="Sliding window size for the draft model (honored by EAGLE-3 and DFLASH). " "For EAGLE-3, the drafter only attends to the most recent N keys " "(verifier hidden states + its own outputs); the verifier is unaffected. "