diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 7a614b61a954..db8bb9514d82 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -146,6 +146,7 @@ Please consult the documentation below to learn more about the parameters you ma * `speculative_num_steps`: How many draft passes we run before verifying. * `speculative_num_draft_tokens`: The number of tokens proposed in a draft. * `speculative_eagle_topk`: The number of top candidates we keep for verification at each step for [Eagle](https://arxiv.org/html/2406.16858v1). +* `speculative_token_map`: Optional, the path to the high frequency token list of [FR-Spec](https://arxiv.org/html/2502.14856v1), used for accelerating [Eagle](https://arxiv.org/html/2406.16858v1). ## Double Sparsity diff --git a/docs/backend/speculative_decoding.ipynb b/docs/backend/speculative_decoding.ipynb index 011d8030b6a0..d8397bd87090 100644 --- a/docs/backend/speculative_decoding.ipynb +++ b/docs/backend/speculative_decoding.ipynb @@ -26,7 +26,7 @@ "source": [ "## EAGLE Decoding\n", "\n", - "To enable EAGLE-based speculative decoding, specify the draft model (`--speculative-draft`) and the relevant EAGLE parameters:" + "To enable EAGLE-based speculative decoding, specify the draft model (`--speculative-draft-model-path`) and the relevant EAGLE parameters:" ] }, { @@ -46,8 +46,8 @@ "\n", "server_process, port = launch_server_cmd(\n", " \"\"\"\n", - "python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n", - " --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n", + "python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \\\n", + " --speculative-draft-model-path lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n", " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64\n", "\"\"\"\n", ")\n", @@ -103,8 +103,8 @@ "source": [ "server_process, port = launch_server_cmd(\n", " \"\"\"\n", - "python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n", - " --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n", + "python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \\\n", + " --speculative-draft-model-path lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n", " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.6 \\\n", " --enable-torch-compile --cuda-graph-max-bs 2\n", "\"\"\"\n", @@ -172,9 +172,10 @@ "\n", "server_process, port = launch_server_cmd(\n", " \"\"\"\n", - "python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algo EAGLE \\\n", - " --speculative-draft lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \\\n", - " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map {hot_token_ids.pt} \n", + "python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algorithm EAGLE \\\n", + " --speculative-draft-model-path lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \\\n", + " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt \\\n", + " --mem-fraction 0.7 --cuda-graph-max-bs 2 --dtype float16 \n", "\"\"\"\n", ")\n", "\n", diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index f360e7f389fc..e3a1f9792262 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -3,6 +3,8 @@ from typing import List, Optional, Union import torch +import os +from huggingface_hub import snapshot_download from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import Req, ScheduleBatch @@ -46,12 +48,12 @@ def __init__( server_args.disable_cuda_graph = True if server_args.speculative_token_map is not None: - try: + if os.path.exists(server_args.speculative_token_map): self.hot_token_id = torch.load(server_args.speculative_token_map) - except: - raise RuntimeError( - f"there is not hot_token_ids.pt file in {self.server_args.speculative_token_map}" - ) + else: + cache_dir = snapshot_download(os.path.dirname(server_args.speculative_token_map), ignore_patterns=["*.bin", "*.safetensors"]) + file_path = os.path.join(cache_dir, os.path.basename(server_args.speculative_token_map)) + self.hot_token_id = torch.load(file_path) server_args.json_model_override_args = ( f'{{"hot_vocab_size": {len(self.hot_token_id)}}}' )