Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/backend/server_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ Please consult the documentation below to learn more about the parameters you ma
* `speculative_num_steps`: How many draft passes we run before verifying.
* `speculative_num_draft_tokens`: The number of tokens proposed in a draft.
* `speculative_eagle_topk`: The number of top candidates we keep for verification at each step for [Eagle](https://arxiv.org/html/2406.16858v1).
* `speculative_token_map`: Optional, the path to the high frequency token list of [FR-Spec](https://arxiv.org/html/2502.14856v1), used for accelerating [Eagle](https://arxiv.org/html/2406.16858v1).


## Double Sparsity
Expand Down
81 changes: 76 additions & 5 deletions docs/backend/speculative_decoding.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"source": [
"## EAGLE Decoding\n",
"\n",
"To enable EAGLE-based speculative decoding, specify the draft model (`--speculative-draft`) and the relevant EAGLE parameters:"
"To enable EAGLE-based speculative decoding, specify the draft model (`--speculative-draft-model-path`) and the relevant EAGLE parameters:"
]
},
{
Expand All @@ -46,8 +46,8 @@
"\n",
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n",
" --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \\\n",
" --speculative-draft-model-path lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 64\n",
"\"\"\"\n",
")\n",
Expand Down Expand Up @@ -103,8 +103,8 @@
"source": [
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n",
" --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
"python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \\\n",
" --speculative-draft-model-path lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n",
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.6 \\\n",
" --enable-torch-compile --cuda-graph-max-bs 2\n",
"\"\"\"\n",
Expand Down Expand Up @@ -135,6 +135,77 @@
"print_highlight(f\"Response: {response}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"terminate_process(server_process)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### EAGLE Decoding via Frequency-Ranked Speculative Sampling\n",
"\n",
"By employing a truncated high-frequency token vocabulary in the draft model, Eagle speculative decoding reduces `lm_head` computational overhead while accelerating the pipeline without quality degradation. For more details, checkout [the paper](https://arxiv.org/pdf/arXiv:2502.14856).\n",
"\n",
"In our implementation, set `--speculative-token-map` to enable the optimization. You can get the high-frequency token in FR-Spec from [this model](https://huggingface.co/thunlp/LLaMA3-Instruct-8B-FR-Spec). Or you can obtain high-frequency token by directly downloading these token from [this repo](https://github.com/thunlp/FR-Spec/tree/main?tab=readme-ov-file#prepare-fr-spec-vocabulary-subset).\n",
"\n",
"Thanks for the contribution from [Weilin Zhao](https://github.com/https://github.com/Achazwl) and [Zhousx](https://github.com/Zhou-sx). "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sglang.test.test_utils import is_in_ci\n",
"\n",
"if is_in_ci():\n",
" from patch import launch_server_cmd\n",
"else:\n",
" from sglang.utils import launch_server_cmd\n",
"\n",
"from sglang.utils import wait_for_server, print_highlight, terminate_process\n",
"\n",
"server_process, port = launch_server_cmd(\n",
" \"\"\"\n",
"python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algorithm EAGLE \\\n",
" --speculative-draft-model-path lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \\\n",
" --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt \\\n",
" --mem-fraction 0.7 --cuda-graph-max-bs 2 --dtype float16 \n",
"\"\"\"\n",
")\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"\n",
"client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n",
"\n",
"response = client.chat.completions.create(\n",
" model=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n",
" messages=[\n",
" {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n",
" ],\n",
" temperature=0,\n",
" max_tokens=64,\n",
")\n",
"\n",
"print_highlight(f\"Response: {response}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
11 changes: 8 additions & 3 deletions python/sglang/srt/models/llama_eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,14 @@ def __init__(
if self.config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)
if hasattr(config, "hot_vocab_size"):
self.lm_head = ParallelLMHead(
config.hot_vocab_size, config.hidden_size, quant_config=quant_config
)
else:
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)
self.logits_processor = LogitsProcessor(config)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
Expand Down
7 changes: 7 additions & 0 deletions python/sglang/srt/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class ServerArgs:
speculative_num_steps: int = 5
speculative_eagle_topk: int = 8
speculative_num_draft_tokens: int = 64
speculative_token_map: Optional[str] = None

# Double Sparsity
enable_double_sparsity: bool = False
Expand Down Expand Up @@ -751,6 +752,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
help="The number of token sampled from draft model in Speculative Decoding.",
default=ServerArgs.speculative_num_draft_tokens,
)
parser.add_argument(
"--speculative-token-map",
type=str,
help="The path of the draft model's small vocab table.",
default=ServerArgs.speculative_token_map,
)

# Double Sparsity
parser.add_argument(
Expand Down
37 changes: 37 additions & 0 deletions python/sglang/srt/speculative/eagle_worker.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging
import os
import time
from typing import List, Optional, Union

import torch
from huggingface_hub import snapshot_download

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
Expand Down Expand Up @@ -44,6 +46,23 @@ def __init__(
# We will capture it later
backup_disable_cuda_graph = server_args.disable_cuda_graph
server_args.disable_cuda_graph = True

if server_args.speculative_token_map is not None:
if os.path.exists(server_args.speculative_token_map):
self.hot_token_id = torch.load(server_args.speculative_token_map)
else:
cache_dir = snapshot_download(
os.path.dirname(server_args.speculative_token_map),
ignore_patterns=["*.bin", "*.safetensors"],
)
file_path = os.path.join(
cache_dir, os.path.basename(server_args.speculative_token_map)
)
self.hot_token_id = torch.load(file_path)
server_args.json_model_override_args = (
f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
)

super().__init__(
gpu_id=gpu_id,
tp_rank=tp_rank,
Expand All @@ -66,7 +85,21 @@ def __init__(
# Share the embedding and lm_head
if not self.speculative_algorithm.is_nextn():
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
if server_args.speculative_token_map is not None:
head = head.clone()
self.hot_token_id = torch.tensor(
self.hot_token_id, dtype=torch.int32, device=head.device
)
head.data = head.data[self.hot_token_id]
else:
self.hot_token_id = None
self.model_runner.model.set_embed_and_head(embed, head)
else:
if server_args.speculative_token_map is not None:
raise NotImplementedError(
"NEXTN does not support speculative-token-map now"
)
self.hot_token_id = None
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph

# Create multi-step attn backends and cuda graph runners
Expand Down Expand Up @@ -223,6 +256,8 @@ def draft_forward(self, forward_batch: ForwardBatch):
spec_info.topk_index,
spec_info.hidden_states,
)
if self.hot_token_id is not None:
topk_index = self.hot_token_id[topk_index]

# Return values
score_list: List[torch.Tensor] = []
Expand Down Expand Up @@ -262,6 +297,8 @@ def draft_forward(self, forward_batch: ForwardBatch):
)
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
if self.hot_token_id is not None:
topk_index = self.hot_token_id[topk_index]
hidden_states = logits_output.hidden_states

return score_list, token_list, parents_list
Expand Down
Loading