diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index cf343d40b179..f360e7f389fc 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -243,11 +243,8 @@ def draft_forward(self, forward_batch: ForwardBatch): spec_info.topk_index, spec_info.hidden_states, ) - topk_index = ( - self.hot_token_id[topk_index] - if self.hot_token_id is not None - else topk_index - ) + if self.hot_token_id is not None: + topk_index = self.hot_token_id[topk_index] # Return values score_list: List[torch.Tensor] = [] @@ -287,11 +284,8 @@ def draft_forward(self, forward_batch: ForwardBatch): ) probs = torch.softmax(logits_output.next_token_logits, dim=-1) topk_p, topk_index = fast_topk(probs, self.topk, dim=-1) - topk_index = ( - self.hot_token_id[topk_index] - if self.hot_token_id is not None - else topk_index - ) + if self.hot_token_id is not None: + topk_index = self.hot_token_id[topk_index] hidden_states = logits_output.hidden_states return score_list, token_list, parents_list