Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions python/sglang/srt/models/llama_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaMLP
from sglang.srt.server_args import get_global_server_args


class LlamaDecoderLayer(LlamaDecoderLayer):
Expand All @@ -46,6 +47,7 @@ def __init__(
config: LlamaConfig,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
draft_window_size: Optional[int] = None,
prefix: str = "",
) -> None:
super().__init__(config, layer_id, quant_config, prefix)
Expand All @@ -65,6 +67,9 @@ def __init__(
prefix=add_prefix("qkv_proj", prefix),
)

if draft_window_size is not None:
self.self_attn.attn.sliding_window_size = draft_window_size

if config.model_type == "llama4_text":
inter_size = config.intermediate_size_mlp
else:
Expand Down Expand Up @@ -115,6 +120,7 @@ def __init__(
self,
config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None,
draft_window_size: Optional[int] = None,
prefix: str = "",
) -> None:
super().__init__()
Expand Down Expand Up @@ -173,7 +179,7 @@ def __init__(

self.layers = nn.ModuleList(
[
LlamaDecoderLayer(config, i, quant_config, prefix)
LlamaDecoderLayer(config, i, quant_config, draft_window_size, prefix)
for i in range(config.num_hidden_layers)
]
)
Expand Down Expand Up @@ -254,7 +260,10 @@ def __init__(
self.pp_group = get_pp_group()

self.model = LlamaModel(
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
config,
quant_config=quant_config,
draft_window_size=self.get_attention_sliding_window_size(),
prefix=add_prefix("model", prefix),
)
# Llama 3.2 1B Instruct set tie_word_embeddings to True
# Llama 3.1 8B Instruct set tie_word_embeddings to False
Expand Down Expand Up @@ -339,5 +348,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
def get_hot_token_id(self):
    """Return the ``hot_token_id`` attribute stored on this instance."""
    return self.hot_token_id

def get_attention_sliding_window_size(self):
    """Return the draft sliding-window size from the global server args.

    Reads ``speculative_draft_window_size`` from the object returned by
    ``get_global_server_args()``.

    Returns:
        The configured value coerced to ``int``, or ``None`` when the
        option is unset.
    """
    configured = get_global_server_args().speculative_draft_window_size
    # Unset means no window is configured; pass None through unchanged.
    if configured is None:
        return None
    return int(configured)


EntryClass = [LlamaForCausalLMEagle3]
35 changes: 19 additions & 16 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,12 +568,12 @@ class ServerArgs:
speculative_eagle_topk: Optional[int] = None
speculative_num_draft_tokens: Optional[int] = None
speculative_dflash_block_size: Optional[int] = None
speculative_dflash_draft_window_size: Optional[int] = None
speculative_accept_threshold_single: float = 1.0
speculative_accept_threshold_acc: float = 1.0
speculative_token_map: Optional[str] = None
speculative_attention_mode: str = "prefill"
speculative_draft_attention_backend: Optional[str] = None
speculative_draft_window_size: Optional[int] = None
speculative_moe_runner_backend: Optional[str] = None
speculative_moe_a2a_backend: Optional[str] = None
speculative_draft_model_quantization: Optional[str] = None
Expand Down Expand Up @@ -3551,14 +3551,13 @@ def _handle_speculative_decoding(self):
)

window_size = None
if self.speculative_dflash_draft_window_size is not None:
window_size = int(self.speculative_dflash_draft_window_size)
if self.speculative_draft_window_size is not None:
window_size = int(self.speculative_draft_window_size)
if window_size <= 0:
raise ValueError(
"DFLASH requires --speculative-dflash-draft-window-size "
f"to be positive, got {window_size}."
f"--speculative-draft-window-size must be positive, got {window_size}."
)
self.speculative_dflash_draft_window_size = window_size
self.speculative_draft_window_size = window_size

if self.speculative_num_draft_tokens is None:
from sglang.srt.speculative.dflash_utils import (
Expand Down Expand Up @@ -3598,7 +3597,7 @@ def _handle_speculative_decoding(self):
draft_tokens = int(self.speculative_num_draft_tokens)
if window_size < draft_tokens:
raise ValueError(
"DFLASH --speculative-dflash-draft-window-size must be >= "
"--speculative-draft-window-size must be >= "
"--speculative-num-draft-tokens (block_size). "
f"window_size={window_size}, block_size={draft_tokens}."
)
Expand Down Expand Up @@ -5680,15 +5679,6 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="DFLASH only. Block size (verify window length). Alias of --speculative-num-draft-tokens for DFLASH.",
default=ServerArgs.speculative_dflash_block_size,
)
parser.add_argument(
"--speculative-dflash-draft-window-size",
type=int,
help="DFLASH only. Sliding window size for the draft-model KV cache. "
"When set, the draft worker keeps a recent target-token window in its "
"local cache (paged backends may retain up to one extra page on the left "
"for alignment). Default is full context.",
default=ServerArgs.speculative_dflash_draft_window_size,
)
parser.add_argument(
"--speculative-accept-threshold-single",
type=float,
Expand Down Expand Up @@ -5720,6 +5710,19 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="Attention backend for speculative decoding drafting.",
default=ServerArgs.speculative_draft_attention_backend,
)
parser.add_argument(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For backward compatibility, maybe keep the old flag as an alias:

parser.add_argument(
      "--speculative-draft-window-size",
      "--speculative-dflash-draft-window-size",  # alias
      type=int,
      dest="speculative_draft_window_size",
      ...
  )

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh makes sense 👍

"--speculative-draft-window-size",
"--speculative-dflash-draft-window-size",
type=int,
dest="speculative_draft_window_size",
help="Sliding window size for the draft model (honored by EAGLE-3 and DFLASH). "
"For EAGLE-3, the drafter only attends to the most recent N keys "
"(verifier hidden states + its own outputs); the verifier is unaffected. "
"For DFLASH, the draft worker keeps a recent target-token window in its "
"local KV cache (paged backends may retain up to one extra page on the "
"left for alignment). Default is full attention/context.",
default=ServerArgs.speculative_draft_window_size,
)
parser.add_argument(
"--speculative-moe-runner-backend",
type=str,
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/speculative/dflash_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def __init__(
self.model_runner = target_worker.model_runner
self.page_size = server_args.page_size
self.draft_window_size: Optional[int] = (
int(server_args.speculative_dflash_draft_window_size)
if server_args.speculative_dflash_draft_window_size is not None
int(server_args.speculative_draft_window_size)
if server_args.speculative_draft_window_size is not None
else None
)
self.use_compact_draft_cache = self.draft_window_size is not None
Expand Down
Loading