From 707ddffeae2df6cf0f6509bbc3a9c56c8562af4e Mon Sep 17 00:00:00 2001 From: zhousx Date: Mon, 17 Feb 2025 14:25:37 +0000 Subject: [PATCH 01/13] opt eagle lm_head --- python/sglang/srt/models/llama_eagle.py | 11 ++++++++--- python/sglang/srt/server_args.py | 7 +++++++ python/sglang/srt/speculative/eagle_worker.py | 16 ++++++++++++++++ 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/models/llama_eagle.py b/python/sglang/srt/models/llama_eagle.py index 09bfbb170c0f..9e4e57a33cd4 100644 --- a/python/sglang/srt/models/llama_eagle.py +++ b/python/sglang/srt/models/llama_eagle.py @@ -117,9 +117,14 @@ def __init__( if self.config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=quant_config - ) + if hasattr(config, 'hot_vocab_size'): + self.lm_head = ParallelLMHead( + config.hot_vocab_size, config.hidden_size, quant_config=quant_config + ) + else: + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) self.logits_processor = LogitsProcessor(config) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ddb10e390f50..296400fa6e76 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -128,6 +128,7 @@ class ServerArgs: speculative_num_steps: int = 5 speculative_eagle_topk: int = 8 speculative_num_draft_tokens: int = 64 + speculative_token_map: Optional[str] = None # Double Sparsity enable_double_sparsity: bool = False @@ -751,6 +752,12 @@ def add_cli_args(parser: argparse.ArgumentParser): help="The number of token sampled from draft model in Speculative Decoding.", default=ServerArgs.speculative_num_draft_tokens, ) + parser.add_argument( + "--speculative-token-map", + type=str, + help="The path of the draft model's small vocab table.", + 
default=ServerArgs.speculative_token_map, + ) # Double Sparsity parser.add_argument( diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index eb8e839f950f..3aacc0c157ae 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -44,6 +44,15 @@ def __init__( # We will capture it later backup_disable_cuda_graph = server_args.disable_cuda_graph server_args.disable_cuda_graph = True + + if server_args.speculative_token_map is not None: + try: + self.hot_token_id = torch.load(server_args.speculative_token_map) + except: + raise RuntimeError(f"there is not hot_token_ids.pt file in {self.server_args.speculative_token_map}") + self.hot_token_id = torch.tensor(self.hot_token_id, dtype=torch.int32, device='cuda') + server_args.json_model_override_args=f'{{"hot_vocab_size": {len(self.hot_token_id)}}}' + super().__init__( gpu_id=gpu_id, tp_rank=tp_rank, @@ -66,6 +75,11 @@ def __init__( # Share the embedding and lm_head if not self.speculative_algorithm.is_nextn(): embed, head = self.target_worker.model_runner.model.get_embed_and_head() + if server_args.speculative_token_map is not None: + head = head.clone() + head.data = head.data[self.hot_token_id] + else: + self.hot_token_id = None self.model_runner.model.set_embed_and_head(embed, head) self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph @@ -223,6 +237,7 @@ def draft_forward(self, forward_batch: ForwardBatch): spec_info.topk_index, spec_info.hidden_states, ) + topk_index = self.hot_token_id[topk_index] if self.hot_token_id is not None else topk_index # Return values score_list: List[torch.Tensor] = [] @@ -262,6 +277,7 @@ def draft_forward(self, forward_batch: ForwardBatch): ) probs = torch.softmax(logits_output.next_token_logits, dim=-1) topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) + topk_index = self.hot_token_id[topk_index] if self.hot_token_id is not None else topk_index 
hidden_states = logits_output.hidden_states return score_list, token_list, parents_list From 372222fecd726b0a4f9d464cbf0e10958f6d47fb Mon Sep 17 00:00:00 2001 From: zhousx Date: Mon, 24 Feb 2025 14:57:50 +0000 Subject: [PATCH 02/13] opt eagle dynamic lm_head --- python/sglang/srt/server_args.py | 7 ++ python/sglang/srt/speculative/eagle_worker.py | 31 +++++--- .../sglang/srt/speculative/hot_table_utils.py | 71 +++++++++++++++++++ 3 files changed, 100 insertions(+), 9 deletions(-) create mode 100644 python/sglang/srt/speculative/hot_table_utils.py diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 296400fa6e76..e91f6577c58c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -129,6 +129,7 @@ class ServerArgs: speculative_eagle_topk: int = 8 speculative_num_draft_tokens: int = 64 speculative_token_map: Optional[str] = None + speculative_token_map_num_dynamic_tokens: int = 0 # Double Sparsity enable_double_sparsity: bool = False @@ -758,6 +759,12 @@ def add_cli_args(parser: argparse.ArgumentParser): help="The path of the draft model's small vocab table.", default=ServerArgs.speculative_token_map, ) + parser.add_argument( + "--speculative-token-map-num-dynamic-tokens", + type=int, + help="The number of hot tokens which change in runtime.", + default=ServerArgs.speculative_token_map_num_dynamic_tokens, + ) # Double Sparsity parser.add_argument( diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 3aacc0c157ae..cb3d8c23c6e5 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -25,6 +25,7 @@ select_top_k_tokens, ) from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.speculative.hot_table_utils import HotVocabTable logger = logging.getLogger(__name__) @@ -47,11 +48,14 @@ def __init__( if server_args.speculative_token_map is not None: try: - 
self.hot_token_id = torch.load(server_args.speculative_token_map) + init_hot_token_ids = torch.load(server_args.speculative_token_map) except: raise RuntimeError(f"there is not hot_token_ids.pt file in {self.server_args.speculative_token_map}") - self.hot_token_id = torch.tensor(self.hot_token_id, dtype=torch.int32, device='cuda') - server_args.json_model_override_args=f'{{"hot_vocab_size": {len(self.hot_token_id)}}}' + if server_args.speculative_token_map_num_dynamic_tokens > len(init_hot_token_ids): + server_args.speculative_token_map_num_dynamic_tokens = len(init_hot_token_ids) + self.hot_token_pool = HotVocabTable(init_hot_token_ids, server_args.speculative_token_map_num_dynamic_tokens) + server_args.json_model_override_args=f'{{"hot_vocab_size": {len(init_hot_token_ids)}}}' + self.hot_token_ids = self.hot_token_pool.get_hot_token_ids() super().__init__( gpu_id=gpu_id, @@ -76,11 +80,14 @@ def __init__( if not self.speculative_algorithm.is_nextn(): embed, head = self.target_worker.model_runner.model.get_embed_and_head() if server_args.speculative_token_map is not None: - head = head.clone() - head.data = head.data[self.hot_token_id] + self.target_worker_lm_haed = head + head_small = head.clone() + head_small.data = head_small.data[self.hot_token_ids] + self.model_runner.model.set_embed_and_head(embed, head_small) else: - self.hot_token_id = None - self.model_runner.model.set_embed_and_head(embed, head) + self.hot_token_pool = None + self.hot_token_ids = None + self.model_runner.model.set_embed_and_head(embed, head) self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph # Create multi-step attn backends and cuda graph runners @@ -127,6 +134,11 @@ def init_cuda_graphs(self): def forward_batch_speculative_generation(self, batch: ScheduleBatch): if batch.forward_mode.is_decode(): # Draft + if self.hot_token_pool is not None: + # Update lm_head + self.hot_token_pool.add_token(batch.input_ids) + self.hot_token_ids = 
self.hot_token_pool.get_hot_token_ids() + self.model_runner.model.lm_head.weight = self.target_worker_lm_haed.data[self.hot_token_ids] spec_info: EagleVerifyInput = self.draft(batch) # Verify @@ -237,7 +249,8 @@ def draft_forward(self, forward_batch: ForwardBatch): spec_info.topk_index, spec_info.hidden_states, ) - topk_index = self.hot_token_id[topk_index] if self.hot_token_id is not None else topk_index + + topk_index = self.hot_token_ids[topk_index] if self.hot_token_ids is not None else topk_index # Return values score_list: List[torch.Tensor] = [] @@ -277,7 +290,7 @@ def draft_forward(self, forward_batch: ForwardBatch): ) probs = torch.softmax(logits_output.next_token_logits, dim=-1) topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) - topk_index = self.hot_token_id[topk_index] if self.hot_token_id is not None else topk_index + topk_index = self.hot_token_ids[topk_index] if self.hot_token_ids is not None else topk_index hidden_states = logits_output.hidden_states return score_list, token_list, parents_list diff --git a/python/sglang/srt/speculative/hot_table_utils.py b/python/sglang/srt/speculative/hot_table_utils.py new file mode 100644 index 000000000000..0ed48d39e301 --- /dev/null +++ b/python/sglang/srt/speculative/hot_table_utils.py @@ -0,0 +1,71 @@ +import heapq +from typing import List, Optional, Union +import torch + +class HotVocabTable: + def __init__(self, initial_tokens, num_dynamic_tokens=256): + self.topk = len(initial_tokens) + self.heap = [(self.topk - i, token_id) for i, token_id in enumerate(initial_tokens[:-num_dynamic_tokens])] + self.heap.extend([(0, token_id) for i, token_id in enumerate(initial_tokens[-num_dynamic_tokens:])]) + heapq.heapify(self.heap) + self.counters = {token_id: self.topk - i for i, token_id in enumerate(initial_tokens[:-num_dynamic_tokens])} + self.counters.update({token_id: 0 for token_id in initial_tokens[-num_dynamic_tokens:]}) + self.pos = {token_id: idx for idx, (_, token_id) in enumerate(self.heap)} + 
self.token_ids = torch.tensor(initial_tokens, dtype=torch.int32, device='cuda') + + def _sift_up(self, i): + while i > 0: + parent = (i - 1) // 2 + if self.heap[i][0] < self.heap[parent][0]: + self.heap[i], self.heap[parent] = self.heap[parent], self.heap[i] + self.pos[self.heap[i][1]], self.pos[self.heap[parent][1]] = i, parent + i = parent + else: + break + + def _sift_down(self, i): + n = len(self.heap) + while True: + left, right, smallest = 2*i+1, 2*i+2, i + if left < n and self.heap[left][0] < self.heap[smallest][0]: + smallest = left + if right < n and self.heap[right][0] < self.heap[smallest][0]: + smallest = right + if smallest == i: + break + self.heap[i], self.heap[smallest] = self.heap[smallest], self.heap[i] + self.pos[self.heap[i][1]], self.pos[self.heap[smallest][1]] = i, smallest + i = smallest + + def add_token(self, token_ids: Union[torch.Tensor, List[torch.Tensor]]) -> None: + if not isinstance(token_ids, list): + token_ids = [token_ids] + + for t in token_ids: + if t.dim() != 1: + t = t.flatten() + for item in t: + self._add_to_heap(item.item()) + + def _add_to_heap(self, token_id): + self.counters[token_id] = self.counters.get(token_id, 0) + 1 + current_count = self.counters[token_id] + + if token_id in self.pos: + idx = self.pos[token_id] + old_count, _ = self.heap[idx] + if current_count == old_count: + return + self.heap[idx] = (current_count, token_id) + self._sift_down(idx) if current_count > old_count else self._sift_up(idx) + else: + if current_count > self.heap[0][0]: + old_token = self.heap[0][1] + del self.pos[old_token] + self.heap[0] = (current_count, token_id) + self.token_ids[self.token_ids == old_token] = token_id + self.pos[token_id] = 0 + self._sift_down(0) + + def get_hot_token_ids(self): + return self.token_ids \ No newline at end of file From 2cc3e2bafa5c03e4b2eb137c8e28b4812b5fc31f Mon Sep 17 00:00:00 2001 From: zhousx Date: Tue, 25 Feb 2025 02:29:47 +0000 Subject: [PATCH 03/13] code formatting --- 
python/sglang/srt/models/llama_eagle.py | 2 +- python/sglang/srt/speculative/eagle_worker.py | 40 +++++++++++++----- .../sglang/srt/speculative/hot_table_utils.py | 41 +++++++++++++------ 3 files changed, 59 insertions(+), 24 deletions(-) diff --git a/python/sglang/srt/models/llama_eagle.py b/python/sglang/srt/models/llama_eagle.py index 9e4e57a33cd4..4f34e625e121 100644 --- a/python/sglang/srt/models/llama_eagle.py +++ b/python/sglang/srt/models/llama_eagle.py @@ -117,7 +117,7 @@ def __init__( if self.config.tie_word_embeddings: self.lm_head = self.model.embed_tokens else: - if hasattr(config, 'hot_vocab_size'): + if hasattr(config, "hot_vocab_size"): self.lm_head = ParallelLMHead( config.hot_vocab_size, config.hidden_size, quant_config=quant_config ) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index cb3d8c23c6e5..0dbbcbab32f0 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -24,8 +24,8 @@ fast_topk, select_top_k_tokens, ) -from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.hot_table_utils import HotVocabTable +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm logger = logging.getLogger(__name__) @@ -50,13 +50,23 @@ def __init__( try: init_hot_token_ids = torch.load(server_args.speculative_token_map) except: - raise RuntimeError(f"there is not hot_token_ids.pt file in {self.server_args.speculative_token_map}") - if server_args.speculative_token_map_num_dynamic_tokens > len(init_hot_token_ids): - server_args.speculative_token_map_num_dynamic_tokens = len(init_hot_token_ids) - self.hot_token_pool = HotVocabTable(init_hot_token_ids, server_args.speculative_token_map_num_dynamic_tokens) - server_args.json_model_override_args=f'{{"hot_vocab_size": {len(init_hot_token_ids)}}}' + raise RuntimeError( + f"there is not hot_token_ids.pt file in {self.server_args.speculative_token_map}" + ) + 
if server_args.speculative_token_map_num_dynamic_tokens > len( + init_hot_token_ids + ): + server_args.speculative_token_map_num_dynamic_tokens = len( + init_hot_token_ids + ) + self.hot_token_pool = HotVocabTable( + init_hot_token_ids, server_args.speculative_token_map_num_dynamic_tokens + ) + server_args.json_model_override_args = ( + f'{{"hot_vocab_size": {len(init_hot_token_ids)}}}' + ) self.hot_token_ids = self.hot_token_pool.get_hot_token_ids() - + super().__init__( gpu_id=gpu_id, tp_rank=tp_rank, @@ -138,7 +148,9 @@ def forward_batch_speculative_generation(self, batch: ScheduleBatch): # Update lm_head self.hot_token_pool.add_token(batch.input_ids) self.hot_token_ids = self.hot_token_pool.get_hot_token_ids() - self.model_runner.model.lm_head.weight = self.target_worker_lm_haed.data[self.hot_token_ids] + self.model_runner.model.lm_head.weight = ( + self.target_worker_lm_haed.data[self.hot_token_ids] + ) spec_info: EagleVerifyInput = self.draft(batch) # Verify @@ -250,7 +262,11 @@ def draft_forward(self, forward_batch: ForwardBatch): spec_info.hidden_states, ) - topk_index = self.hot_token_ids[topk_index] if self.hot_token_ids is not None else topk_index + topk_index = ( + self.hot_token_ids[topk_index] + if self.hot_token_ids is not None + else topk_index + ) # Return values score_list: List[torch.Tensor] = [] @@ -290,7 +306,11 @@ def draft_forward(self, forward_batch: ForwardBatch): ) probs = torch.softmax(logits_output.next_token_logits, dim=-1) topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) - topk_index = self.hot_token_ids[topk_index] if self.hot_token_ids is not None else topk_index + topk_index = ( + self.hot_token_ids[topk_index] + if self.hot_token_ids is not None + else topk_index + ) hidden_states = logits_output.hidden_states return score_list, token_list, parents_list diff --git a/python/sglang/srt/speculative/hot_table_utils.py b/python/sglang/srt/speculative/hot_table_utils.py index 0ed48d39e301..4713ec71fdcb 100644 --- 
a/python/sglang/srt/speculative/hot_table_utils.py +++ b/python/sglang/srt/speculative/hot_table_utils.py @@ -1,17 +1,32 @@ import heapq from typing import List, Optional, Union + import torch + class HotVocabTable: def __init__(self, initial_tokens, num_dynamic_tokens=256): self.topk = len(initial_tokens) - self.heap = [(self.topk - i, token_id) for i, token_id in enumerate(initial_tokens[:-num_dynamic_tokens])] - self.heap.extend([(0, token_id) for i, token_id in enumerate(initial_tokens[-num_dynamic_tokens:])]) + self.heap = [ + (self.topk - i, token_id) + for i, token_id in enumerate(initial_tokens[:-num_dynamic_tokens]) + ] + self.heap.extend( + [ + (0, token_id) + for i, token_id in enumerate(initial_tokens[-num_dynamic_tokens:]) + ] + ) heapq.heapify(self.heap) - self.counters = {token_id: self.topk - i for i, token_id in enumerate(initial_tokens[:-num_dynamic_tokens])} - self.counters.update({token_id: 0 for token_id in initial_tokens[-num_dynamic_tokens:]}) + self.counters = { + token_id: self.topk - i + for i, token_id in enumerate(initial_tokens[:-num_dynamic_tokens]) + } + self.counters.update( + {token_id: 0 for token_id in initial_tokens[-num_dynamic_tokens:]} + ) self.pos = {token_id: idx for idx, (_, token_id) in enumerate(self.heap)} - self.token_ids = torch.tensor(initial_tokens, dtype=torch.int32, device='cuda') + self.token_ids = torch.tensor(initial_tokens, dtype=torch.int32, device="cuda") def _sift_up(self, i): while i > 0: @@ -22,11 +37,11 @@ def _sift_up(self, i): i = parent else: break - + def _sift_down(self, i): n = len(self.heap) while True: - left, right, smallest = 2*i+1, 2*i+2, i + left, right, smallest = 2 * i + 1, 2 * i + 2, i if left < n and self.heap[left][0] < self.heap[smallest][0]: smallest = left if right < n and self.heap[right][0] < self.heap[smallest][0]: @@ -36,11 +51,11 @@ def _sift_down(self, i): self.heap[i], self.heap[smallest] = self.heap[smallest], self.heap[i] self.pos[self.heap[i][1]], 
self.pos[self.heap[smallest][1]] = i, smallest i = smallest - + def add_token(self, token_ids: Union[torch.Tensor, List[torch.Tensor]]) -> None: if not isinstance(token_ids, list): token_ids = [token_ids] - + for t in token_ids: if t.dim() != 1: t = t.flatten() @@ -50,7 +65,7 @@ def add_token(self, token_ids: Union[torch.Tensor, List[torch.Tensor]]) -> None: def _add_to_heap(self, token_id): self.counters[token_id] = self.counters.get(token_id, 0) + 1 current_count = self.counters[token_id] - + if token_id in self.pos: idx = self.pos[token_id] old_count, _ = self.heap[idx] @@ -59,13 +74,13 @@ def _add_to_heap(self, token_id): self.heap[idx] = (current_count, token_id) self._sift_down(idx) if current_count > old_count else self._sift_up(idx) else: - if current_count > self.heap[0][0]: + if current_count > self.heap[0][0]: old_token = self.heap[0][1] del self.pos[old_token] self.heap[0] = (current_count, token_id) self.token_ids[self.token_ids == old_token] = token_id self.pos[token_id] = 0 self._sift_down(0) - + def get_hot_token_ids(self): - return self.token_ids \ No newline at end of file + return self.token_ids From e5b2bfa8f19dbe21f9d635eddf3089b07c30414e Mon Sep 17 00:00:00 2001 From: zhousx Date: Sat, 1 Mar 2025 03:19:40 +0000 Subject: [PATCH 04/13] revert dynamic lm_head --- python/sglang/srt/server_args.py | 7 -- python/sglang/srt/speculative/eagle_worker.py | 45 +++------- .../sglang/srt/speculative/hot_table_utils.py | 86 ------------------- 3 files changed, 13 insertions(+), 125 deletions(-) delete mode 100644 python/sglang/srt/speculative/hot_table_utils.py diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e91f6577c58c..296400fa6e76 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -129,7 +129,6 @@ class ServerArgs: speculative_eagle_topk: int = 8 speculative_num_draft_tokens: int = 64 speculative_token_map: Optional[str] = None - speculative_token_map_num_dynamic_tokens: int = 0 
# Double Sparsity enable_double_sparsity: bool = False @@ -759,12 +758,6 @@ def add_cli_args(parser: argparse.ArgumentParser): help="The path of the draft model's small vocab table.", default=ServerArgs.speculative_token_map, ) - parser.add_argument( - "--speculative-token-map-num-dynamic-tokens", - type=int, - help="The number of hot tokens which change in runtime.", - default=ServerArgs.speculative_token_map_num_dynamic_tokens, - ) # Double Sparsity parser.add_argument( diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 0dbbcbab32f0..cf343d40b179 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -24,7 +24,6 @@ fast_topk, select_top_k_tokens, ) -from sglang.srt.speculative.hot_table_utils import HotVocabTable from sglang.srt.speculative.spec_info import SpeculativeAlgorithm logger = logging.getLogger(__name__) @@ -48,24 +47,14 @@ def __init__( if server_args.speculative_token_map is not None: try: - init_hot_token_ids = torch.load(server_args.speculative_token_map) + self.hot_token_id = torch.load(server_args.speculative_token_map) except: raise RuntimeError( f"there is not hot_token_ids.pt file in {self.server_args.speculative_token_map}" ) - if server_args.speculative_token_map_num_dynamic_tokens > len( - init_hot_token_ids - ): - server_args.speculative_token_map_num_dynamic_tokens = len( - init_hot_token_ids - ) - self.hot_token_pool = HotVocabTable( - init_hot_token_ids, server_args.speculative_token_map_num_dynamic_tokens - ) server_args.json_model_override_args = ( - f'{{"hot_vocab_size": {len(init_hot_token_ids)}}}' + f'{{"hot_vocab_size": {len(self.hot_token_id)}}}' ) - self.hot_token_ids = self.hot_token_pool.get_hot_token_ids() super().__init__( gpu_id=gpu_id, @@ -90,14 +79,14 @@ def __init__( if not self.speculative_algorithm.is_nextn(): embed, head = self.target_worker.model_runner.model.get_embed_and_head() if 
server_args.speculative_token_map is not None: - self.target_worker_lm_haed = head - head_small = head.clone() - head_small.data = head_small.data[self.hot_token_ids] - self.model_runner.model.set_embed_and_head(embed, head_small) + head = head.clone() + self.hot_token_id = torch.tensor( + self.hot_token_id, dtype=torch.int32, device=head.device + ) + head.data = head.data[self.hot_token_id] else: - self.hot_token_pool = None - self.hot_token_ids = None - self.model_runner.model.set_embed_and_head(embed, head) + self.hot_token_id = None + self.model_runner.model.set_embed_and_head(embed, head) self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph # Create multi-step attn backends and cuda graph runners @@ -144,13 +133,6 @@ def init_cuda_graphs(self): def forward_batch_speculative_generation(self, batch: ScheduleBatch): if batch.forward_mode.is_decode(): # Draft - if self.hot_token_pool is not None: - # Update lm_head - self.hot_token_pool.add_token(batch.input_ids) - self.hot_token_ids = self.hot_token_pool.get_hot_token_ids() - self.model_runner.model.lm_head.weight = ( - self.target_worker_lm_haed.data[self.hot_token_ids] - ) spec_info: EagleVerifyInput = self.draft(batch) # Verify @@ -261,10 +243,9 @@ def draft_forward(self, forward_batch: ForwardBatch): spec_info.topk_index, spec_info.hidden_states, ) - topk_index = ( - self.hot_token_ids[topk_index] - if self.hot_token_ids is not None + self.hot_token_id[topk_index] + if self.hot_token_id is not None else topk_index ) @@ -307,8 +288,8 @@ def draft_forward(self, forward_batch: ForwardBatch): probs = torch.softmax(logits_output.next_token_logits, dim=-1) topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) topk_index = ( - self.hot_token_ids[topk_index] - if self.hot_token_ids is not None + self.hot_token_id[topk_index] + if self.hot_token_id is not None else topk_index ) hidden_states = logits_output.hidden_states diff --git a/python/sglang/srt/speculative/hot_table_utils.py 
b/python/sglang/srt/speculative/hot_table_utils.py deleted file mode 100644 index 4713ec71fdcb..000000000000 --- a/python/sglang/srt/speculative/hot_table_utils.py +++ /dev/null @@ -1,86 +0,0 @@ -import heapq -from typing import List, Optional, Union - -import torch - - -class HotVocabTable: - def __init__(self, initial_tokens, num_dynamic_tokens=256): - self.topk = len(initial_tokens) - self.heap = [ - (self.topk - i, token_id) - for i, token_id in enumerate(initial_tokens[:-num_dynamic_tokens]) - ] - self.heap.extend( - [ - (0, token_id) - for i, token_id in enumerate(initial_tokens[-num_dynamic_tokens:]) - ] - ) - heapq.heapify(self.heap) - self.counters = { - token_id: self.topk - i - for i, token_id in enumerate(initial_tokens[:-num_dynamic_tokens]) - } - self.counters.update( - {token_id: 0 for token_id in initial_tokens[-num_dynamic_tokens:]} - ) - self.pos = {token_id: idx for idx, (_, token_id) in enumerate(self.heap)} - self.token_ids = torch.tensor(initial_tokens, dtype=torch.int32, device="cuda") - - def _sift_up(self, i): - while i > 0: - parent = (i - 1) // 2 - if self.heap[i][0] < self.heap[parent][0]: - self.heap[i], self.heap[parent] = self.heap[parent], self.heap[i] - self.pos[self.heap[i][1]], self.pos[self.heap[parent][1]] = i, parent - i = parent - else: - break - - def _sift_down(self, i): - n = len(self.heap) - while True: - left, right, smallest = 2 * i + 1, 2 * i + 2, i - if left < n and self.heap[left][0] < self.heap[smallest][0]: - smallest = left - if right < n and self.heap[right][0] < self.heap[smallest][0]: - smallest = right - if smallest == i: - break - self.heap[i], self.heap[smallest] = self.heap[smallest], self.heap[i] - self.pos[self.heap[i][1]], self.pos[self.heap[smallest][1]] = i, smallest - i = smallest - - def add_token(self, token_ids: Union[torch.Tensor, List[torch.Tensor]]) -> None: - if not isinstance(token_ids, list): - token_ids = [token_ids] - - for t in token_ids: - if t.dim() != 1: - t = t.flatten() - for item in 
t: - self._add_to_heap(item.item()) - - def _add_to_heap(self, token_id): - self.counters[token_id] = self.counters.get(token_id, 0) + 1 - current_count = self.counters[token_id] - - if token_id in self.pos: - idx = self.pos[token_id] - old_count, _ = self.heap[idx] - if current_count == old_count: - return - self.heap[idx] = (current_count, token_id) - self._sift_down(idx) if current_count > old_count else self._sift_up(idx) - else: - if current_count > self.heap[0][0]: - old_token = self.heap[0][1] - del self.pos[old_token] - self.heap[0] = (current_count, token_id) - self.token_ids[self.token_ids == old_token] = token_id - self.pos[token_id] = 0 - self._sift_down(0) - - def get_hot_token_ids(self): - return self.token_ids From 16d71a755ac64e553e1569350b15dc9f3ab2904f Mon Sep 17 00:00:00 2001 From: zhousx Date: Sat, 1 Mar 2025 03:51:45 +0000 Subject: [PATCH 05/13] update docs. --- docs/backend/speculative_decoding.ipynb | 75 +++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/docs/backend/speculative_decoding.ipynb b/docs/backend/speculative_decoding.ipynb index c7bd7fb31d02..afec9c4d51b7 100644 --- a/docs/backend/speculative_decoding.ipynb +++ b/docs/backend/speculative_decoding.ipynb @@ -135,6 +135,81 @@ "print_highlight(f\"Response: {response}\")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "terminate_process(server_process)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### EAGLE Decoding via Frequency-Ranked Speculative Sampling\n", + "\n", + "By employing a truncated high-frequency token vocabulary in the draft model, Eagle speculative decoding reduces lm_head computational overhead while accelerating the pipeline without quality degradation.For more details checkout [this paper](https://arxiv.org/pdf/arXiv:2502.14856)\n", + "\n", + "Set `--speculative-token-map` to use this optimization. 
You can get the high-frequency token in FR-Spec from https://huggingface.co/thunlp/LLaMA3-Instruct-8B-FR-Spec. Or you can obtain high-frequency token by yourself.\n", + "+ Execute inference on your dataset using sglang's standard inference mode and persist the outputs.\n", + "+ Extract the top-k high-frequency tokens from the saved file. There is a reference implementation (https://gist.github.com/Zhou-sx/71a9196d2f324c93f79016579fdf57da). \n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sglang.test.test_utils import is_in_ci\n", + "\n", + "if is_in_ci():\n", + " from patch import launch_server_cmd\n", + "else:\n", + " from sglang.utils import launch_server_cmd\n", + "\n", + "from sglang.utils import wait_for_server, print_highlight, terminate_process\n", + "\n", + "server_process, port = launch_server_cmd(\n", + " \"\"\"\n", + "python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct --speculative-algo EAGLE \\\n", + " --speculative-draft lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 3 \\\n", + " --speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --speculative-token-map {hot_token_ids.pt} \n", + "\"\"\"\n", + ")\n", + "\n", + "wait_for_server(f\"http://localhost:{port}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import openai\n", + "\n", + "client = openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n", + "\n", + "response = client.chat.completions.create(\n", + " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", + " messages=[\n", + " {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n", + " ],\n", + " temperature=0,\n", + " max_tokens=64,\n", + ")\n", + "\n", + "print_highlight(f\"Response: {response}\")" + ] + }, { "cell_type": "code", 
"execution_count": null, From 5fc7af520a854b8278e9b2efb0f43c59869cf16d Mon Sep 17 00:00:00 2001 From: Achazwl <323163497@qq.com> Date: Sat, 1 Mar 2025 14:36:52 +0800 Subject: [PATCH 06/13] fix --- docs/backend/speculative_decoding.ipynb | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/docs/backend/speculative_decoding.ipynb b/docs/backend/speculative_decoding.ipynb index afec9c4d51b7..4366ae6fb3e9 100644 --- a/docs/backend/speculative_decoding.ipynb +++ b/docs/backend/speculative_decoding.ipynb @@ -150,18 +150,11 @@ "source": [ "### EAGLE Decoding via Frequency-Ranked Speculative Sampling\n", "\n", - "By employing a truncated high-frequency token vocabulary in the draft model, Eagle speculative decoding reduces lm_head computational overhead while accelerating the pipeline without quality degradation.For more details checkout [this paper](https://arxiv.org/pdf/arXiv:2502.14856)\n", + "By employing a truncated high-frequency token vocabulary in the draft model, Eagle speculative decoding reduces lm_head computational overhead while accelerating the pipeline without quality degradation. For more details checkout [this paper](https://arxiv.org/pdf/arXiv:2502.14856)\n", "\n", - "Set `--speculative-token-map` to use this optimization. You can get the high-frequency token in FR-Spec from https://huggingface.co/thunlp/LLaMA3-Instruct-8B-FR-Spec. Or you can obtain high-frequency token by yourself.\n", - "+ Execute inference on your dataset using sglang's standard inference mode and persist the outputs.\n", - "+ Extract the top-k high-frequency tokens from the saved file. There is a reference implementation (https://gist.github.com/Zhou-sx/71a9196d2f324c93f79016579fdf57da). \n" + "Set `--speculative-token-map` to use this optimization. You can get the high-frequency token in FR-Spec from https://huggingface.co/thunlp/LLaMA3-Instruct-8B-FR-Spec. 
Or you can obtain high-frequency token by yourself (https://github.com/thunlp/FR-Spec/tree/main?tab=readme-ov-file#prepare-fr-spec-vocabulary-subset).\n" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [] - }, { "cell_type": "code", "execution_count": null, @@ -180,8 +173,8 @@ "server_process, port = launch_server_cmd(\n", " \"\"\"\n", "python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct --speculative-algo EAGLE \\\n", - " --speculative-draft lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 3 \\\n", - " --speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --speculative-token-map {hot_token_ids.pt} \n", + " --speculative-draft lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \\\n", + " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map {hot_token_ids.pt} \n", "\"\"\"\n", ")\n", "\n", From 6c738ead3b9665cf84f38ab037343a25bb4874c3 Mon Sep 17 00:00:00 2001 From: Achazwl <323163497@qq.com> Date: Sat, 1 Mar 2025 15:03:13 +0800 Subject: [PATCH 07/13] fix model path --- docs/backend/speculative_decoding.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/backend/speculative_decoding.ipynb b/docs/backend/speculative_decoding.ipynb index 4366ae6fb3e9..011d8030b6a0 100644 --- a/docs/backend/speculative_decoding.ipynb +++ b/docs/backend/speculative_decoding.ipynb @@ -172,7 +172,7 @@ "\n", "server_process, port = launch_server_cmd(\n", " \"\"\"\n", - "python3 -m sglang.launch_server --model meta-llama/Llama-3.1-8B-Instruct --speculative-algo EAGLE \\\n", + "python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algo EAGLE \\\n", " --speculative-draft lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \\\n", " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map {hot_token_ids.pt} \n", "\"\"\"\n", @@ -192,7 +192,7 @@ "client = 
openai.Client(base_url=f\"http://127.0.0.1:{port}/v1\", api_key=\"None\")\n", "\n", "response = client.chat.completions.create(\n", - " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", + " model=\"meta-llama/Meta-Llama-3-8B-Instruct\",\n", " messages=[\n", " {\"role\": \"user\", \"content\": \"List 3 countries and their capitals.\"},\n", " ],\n", From d7f8f9e268fc7b1e93f798873d17dcd89be08010 Mon Sep 17 00:00:00 2001 From: Achazwl <323163497@qq.com> Date: Sun, 2 Mar 2025 11:29:58 +0800 Subject: [PATCH 08/13] refactor --- python/sglang/srt/speculative/eagle_worker.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index cf343d40b179..f360e7f389fc 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -243,11 +243,8 @@ def draft_forward(self, forward_batch: ForwardBatch): spec_info.topk_index, spec_info.hidden_states, ) - topk_index = ( - self.hot_token_id[topk_index] - if self.hot_token_id is not None - else topk_index - ) + if self.hot_token_id is not None: + topk_index = self.hot_token_id[topk_index] # Return values score_list: List[torch.Tensor] = [] @@ -287,11 +284,8 @@ def draft_forward(self, forward_batch: ForwardBatch): ) probs = torch.softmax(logits_output.next_token_logits, dim=-1) topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) - topk_index = ( - self.hot_token_id[topk_index] - if self.hot_token_id is not None - else topk_index - ) + if self.hot_token_id is not None: + topk_index = self.hot_token_id[topk_index] hidden_states = logits_output.hidden_states return score_list, token_list, parents_list From 111ca8644b5920ee6ae2496471437e55ebea91cc Mon Sep 17 00:00:00 2001 From: Achazwl <323163497@qq.com> Date: Sun, 2 Mar 2025 16:18:51 +0800 Subject: [PATCH 09/13] support downloading speculative-token-map from hf --- python/sglang/srt/speculative/eagle_worker.py | 12 
+++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index f360e7f389fc..e3a1f9792262 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -3,6 +3,8 @@ from typing import List, Optional, Union import torch +import os +from huggingface_hub import snapshot_download from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import Req, ScheduleBatch @@ -46,12 +48,12 @@ def __init__( server_args.disable_cuda_graph = True if server_args.speculative_token_map is not None: - try: + if os.path.exists(server_args.speculative_token_map): self.hot_token_id = torch.load(server_args.speculative_token_map) - except: - raise RuntimeError( - f"there is not hot_token_ids.pt file in {self.server_args.speculative_token_map}" - ) + else: + cache_dir = snapshot_download(os.path.dirname(server_args.speculative_token_map), ignore_patterns=["*.bin", "*.safetensors"]) + file_path = os.path.join(cache_dir, os.path.basename(server_args.speculative_token_map)) + self.hot_token_id = torch.load(file_path) server_args.json_model_override_args = ( f'{{"hot_vocab_size": {len(self.hot_token_id)}}}' ) From 1a5aad837cd387c5dea25002bd5ea0e33b383d27 Mon Sep 17 00:00:00 2001 From: Achazwl <323163497@qq.com> Date: Sun, 2 Mar 2025 16:21:09 +0800 Subject: [PATCH 10/13] fix doc --- docs/backend/server_arguments.md | 1 + docs/backend/speculative_decoding.ipynb | 16 ++++++++-------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 7a614b61a954..db8bb9514d82 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -146,6 +146,7 @@ Please consult the documentation below to learn more about the parameters you ma * `speculative_num_steps`: How many draft passes we run before 
verifying. * `speculative_num_draft_tokens`: The number of tokens proposed in a draft. * `speculative_eagle_topk`: The number of top candidates we keep for verification at each step for [Eagle](https://arxiv.org/html/2406.16858v1). +* `speculative_token_map`: Optional, the path to the high frequency token list of [FR-Spec](https://arxiv.org/html/2502.14856v1), used for accelerating [Eagle](https://arxiv.org/html/2406.16858v1). ## Double Sparsity diff --git a/docs/backend/speculative_decoding.ipynb b/docs/backend/speculative_decoding.ipynb index 011d8030b6a0..5f344e03afb0 100644 --- a/docs/backend/speculative_decoding.ipynb +++ b/docs/backend/speculative_decoding.ipynb @@ -26,7 +26,7 @@ "source": [ "## EAGLE Decoding\n", "\n", - "To enable EAGLE-based speculative decoding, specify the draft model (`--speculative-draft`) and the relevant EAGLE parameters:" + "To enable EAGLE-based speculative decoding, specify the draft model (`--speculative-draft-model-path`) and the relevant EAGLE parameters:" ] }, { @@ -46,8 +46,8 @@ "\n", "server_process, port = launch_server_cmd(\n", " \"\"\"\n", - "python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n", - " --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n", + "python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \\\n", + " --speculative-draft-model-path lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n", " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64\n", "\"\"\"\n", ")\n", @@ -103,8 +103,8 @@ "source": [ "server_process, port = launch_server_cmd(\n", " \"\"\"\n", - "python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algo EAGLE \\\n", - " --speculative-draft lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n", + "python3 -m sglang.launch_server --model meta-llama/Llama-2-7b-chat-hf --speculative-algorithm EAGLE \\\n", + 
" --speculative-draft-model-path lmzheng/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 5 \\\n", " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --mem-fraction 0.6 \\\n", " --enable-torch-compile --cuda-graph-max-bs 2\n", "\"\"\"\n", @@ -172,9 +172,9 @@ "\n", "server_process, port = launch_server_cmd(\n", " \"\"\"\n", - "python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algo EAGLE \\\n", - " --speculative-draft lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \\\n", - " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map {hot_token_ids.pt} \n", + "python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algorithm EAGLE \\\n", + " --speculative-draft-model-path lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \\\n", + " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt \n", "\"\"\"\n", ")\n", "\n", From cd9b2a753358a8abe8dfae4918120a47d78983ac Mon Sep 17 00:00:00 2001 From: Achazwl <323163497@qq.com> Date: Sun, 2 Mar 2025 17:28:52 +0800 Subject: [PATCH 11/13] fix doc --- docs/backend/speculative_decoding.ipynb | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/backend/speculative_decoding.ipynb b/docs/backend/speculative_decoding.ipynb index 5f344e03afb0..d8397bd87090 100644 --- a/docs/backend/speculative_decoding.ipynb +++ b/docs/backend/speculative_decoding.ipynb @@ -174,7 +174,8 @@ " \"\"\"\n", "python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3-8B-Instruct --speculative-algorithm EAGLE \\\n", " --speculative-draft-model-path lmzheng/sglang-EAGLE-LLaMA3-Instruct-8B --speculative-num-steps 5 \\\n", - " --speculative-eagle-topk 8 --speculative-num-draft-tokens 64 --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt \n", + " --speculative-eagle-topk 8 
--speculative-num-draft-tokens 64 --speculative-token-map thunlp/LLaMA3-Instruct-8B-FR-Spec/freq_32768.pt \\\n", + " --mem-fraction 0.7 --cuda-graph-max-bs 2 --dtype float16 \n", "\"\"\"\n", ")\n", "\n", From a811ae7a79161614f727288a54c8467a25999851 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Sun, 2 Mar 2025 19:57:30 +0000 Subject: [PATCH 12/13] fix docs link --- docs/backend/speculative_decoding.ipynb | 4 ++-- python/sglang/srt/speculative/eagle_worker.py | 11 ++++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/docs/backend/speculative_decoding.ipynb b/docs/backend/speculative_decoding.ipynb index d8397bd87090..87f0328be6a8 100644 --- a/docs/backend/speculative_decoding.ipynb +++ b/docs/backend/speculative_decoding.ipynb @@ -150,9 +150,9 @@ "source": [ "### EAGLE Decoding via Frequency-Ranked Speculative Sampling\n", "\n", - "By employing a truncated high-frequency token vocabulary in the draft model, Eagle speculative decoding reduces lm_head computational overhead while accelerating the pipeline without quality degradation. For more details checkout [this paper](https://arxiv.org/pdf/arXiv:2502.14856)\n", + "Thanks for the contribution from THUNLP. By employing a truncated high-frequency token vocabulary in the draft model, Eagle speculative decoding reduces `lm_head` computational overhead while accelerating the pipeline without quality degradation. For more details, check out [the paper](https://arxiv.org/pdf/2502.14856).\n", "\n", - "Set `--speculative-token-map` to use this optimization. You can get the high-frequency token in FR-Spec from https://huggingface.co/thunlp/LLaMA3-Instruct-8B-FR-Spec. Or you can obtain high-frequency token by yourself (https://github.com/thunlp/FR-Spec/tree/main?tab=readme-ov-file#prepare-fr-spec-vocabulary-subset).\n", + "In our implementation, set `--speculative-token-map` to enable the optimization. 
You can get the high-frequency tokens in FR-Spec from [this model](https://huggingface.co/thunlp/LLaMA3-Instruct-8B-FR-Spec). Or you can obtain high-frequency tokens by directly downloading these tokens from [this repo](https://github.com/thunlp/FR-Spec/tree/main?tab=readme-ov-file#prepare-fr-spec-vocabulary-subset).\n" ] }, { diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index e3a1f9792262..c4319b15b3e6 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -1,9 +1,9 @@ import logging +import os import time from typing import List, Optional, Union import torch -import os from huggingface_hub import snapshot_download from sglang.srt.layers.logits_processor import LogitsProcessorOutput @@ -51,8 +51,13 @@ def __init__( if os.path.exists(server_args.speculative_token_map): self.hot_token_id = torch.load(server_args.speculative_token_map) else: - cache_dir = snapshot_download(os.path.dirname(server_args.speculative_token_map), ignore_patterns=["*.bin", "*.safetensors"]) - file_path = os.path.join(cache_dir, os.path.basename(server_args.speculative_token_map)) + cache_dir = snapshot_download( + os.path.dirname(server_args.speculative_token_map), + ignore_patterns=["*.bin", "*.safetensors"], + ) + file_path = os.path.join( + cache_dir, os.path.basename(server_args.speculative_token_map) + ) + self.hot_token_id = torch.load(file_path) server_args.json_model_override_args = ( f'{{"hot_vocab_size": {len(self.hot_token_id)}}}' ) From fc7ec150cdf68070888be5522247353127a958c4 Mon Sep 17 00:00:00 2001 From: zhaochenyang20 Date: Mon, 3 Mar 2025 00:30:34 +0000 Subject: [PATCH 13/13] fix lint and docs --- python/sglang/srt/speculative/eagle_worker.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index c4319b15b3e6..7639bd999870 100644 --- 
a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -94,6 +94,12 @@ def __init__( else: self.hot_token_id = None self.model_runner.model.set_embed_and_head(embed, head) + else: + if server_args.speculative_token_map is not None: + raise NotImplementedError( + "NEXTN does not support speculative-token-map now" + ) + self.hot_token_id = None self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph # Create multi-step attn backends and cuda graph runners