feat: add thinking_budget (version 2) #6208
The first part of the diff adds the new `thinking_budget` field to the request models:

```diff
@@ -172,6 +172,7 @@ class CompletionRequest(BaseModel):
     top_k: int = -1
     min_p: float = 0.0
     min_tokens: int = 0
+    thinking_budget: Optional[int] = None
     json_schema: Optional[str] = None
     regex: Optional[str] = None
     ebnf: Optional[str] = None
@@ -378,6 +379,7 @@ def set_tool_choice_default(cls, values):
     top_k: int = -1
     min_p: float = 0.0
     min_tokens: int = 0
+    thinking_budget: Optional[int] = None
     regex: Optional[str] = None
     ebnf: Optional[str] = None
     repetition_penalty: float = 1.0
```
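For a sense of the client-side surface, a request that exercises the new field could look like the sketch below. It mirrors the test added at the end of this PR; the URL and API key are placeholders for whatever the server was launched with.

```python
import requests

# Placeholder endpoint and key; adjust to your server launch settings.
resp = requests.post(
    "http://127.0.0.1:30000/v1/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "Qwen/Qwen3-8B",
        "messages": [
            {"role": "user", "content": "9.11 and 9.8, which is greater?"}
        ],
        "separate_reasoning": True,
        "chat_template_kwargs": {"enable_thinking": True},
        "thinking_budget": 20,  # cap the reasoning section at ~20 tokens
    },
)
print(resp.json()["choices"][0]["message"].get("reasoning_content"))
```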
`SamplingBatchInfo` gains per-request tensors to track thinking tokens and enforce the budget:

```diff
@@ -30,8 +30,13 @@ class SamplingBatchInfo:
     # Whether any request needs min_p sampling
     need_min_p_sampling: bool

+    # Use thinking_budget to truncate thinking
+    num_thinking_tokens: Optional[torch.Tensor] = None
+    think_end_ids: Optional[torch.Tensor] = None
+    thinking_budgets: Optional[torch.Tensor] = None
+
     # Masking tensors for grammar-guided structured outputs
-    vocab_size: int
+    vocab_size: int = 0
     grammars: Optional[List] = None
     vocab_mask: Optional[torch.Tensor] = None
     apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
@@ -76,7 +81,22 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         min_ps = torch.tensor(
             [r.sampling_params.min_p for r in reqs], dtype=torch.float
         ).to(device, non_blocking=True)

+        if any(hasattr(r.tokenizer, "think_end_id") for r in reqs):
+            think_end_ids = torch.tensor(
+                [getattr(r.tokenizer, "think_end_id", -1) for r in reqs],
+                dtype=torch.int64,
+            ).to(device, non_blocking=True)
+            num_thinking_tokens = torch.tensor([0 for _ in reqs], dtype=torch.int64).to(
+                device, non_blocking=True
+            )
+            thinking_budgets = torch.tensor(
+                [r.sampling_params.thinking_budget or -1 for r in reqs],
+                dtype=torch.int64,
+            ).to(device, non_blocking=True)
+        else:
+            think_end_ids = None
+            num_thinking_tokens = None
+            thinking_budgets = None
         # Check if any request has custom logit processor
         has_custom_logit_processor = (
            batch.enable_custom_logit_processor  # check the flag first.
```

Review thread on the `hasattr(r.tokenizer, "think_end_id")` check:

Collaborator: For a non-reasoning model, can we skip this check? Do we need to identify whether a model is a reasoning model from its architecture?

Contributor (author): I am thinking about it. Right now I don't see a way to decide from a request whether the model is a reasoning model.

Contributor (author): As far as I can tell, there is no better way. The information available here is very limited: this code only performs sampling and does not know the specific model architecture.
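For reference, the check above relies on the tokenizer carrying a `think_end_id` attribute. The sketch below shows one way such an attribute could be derived for a Qwen3-style tokenizer; the derivation is an assumption for illustration, not necessarily how SGLang populates it.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

# Assumption: "</think>" closes the reasoning section and encodes to one token.
ids = tokenizer.encode("</think>", add_special_tokens=False)
if len(ids) == 1:
    # Attach the id so the sampling code's hasattr() check finds it.
    tokenizer.think_end_id = ids[0]
```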
The new tensors are passed into the constructed `SamplingBatchInfo`, and two new methods implement the truncation (note `should_stop.unsqueeze(1)`, so the per-request mask broadcasts over the vocab dimension of the logits):

```diff
@@ -132,6 +152,9 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
             top_ps=top_ps,
             top_ks=top_ks,
             min_ps=min_ps,
+            think_end_ids=think_end_ids,
+            num_thinking_tokens=num_thinking_tokens,
+            thinking_budgets=thinking_budgets,
             is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             vocab_size=vocab_size,
@@ -146,6 +169,35 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
     def __len__(self):
         return len(self.temperatures)

+    def apply_thinking_budgets(self, next_token_logits: torch.Tensor):
+        has_budget = self.thinking_budgets > 0
+        if not has_budget.any():
+            return
+        torch.where(
+            has_budget,
+            self.num_thinking_tokens + 1,
+            self.num_thinking_tokens,
+            out=self.num_thinking_tokens,
+        )
+        should_stop = has_budget & (
+            self.num_thinking_tokens - 1 > self.thinking_budgets
+        )
+        next_token_logits.masked_fill_(should_stop.unsqueeze(1), float("-inf"))
+        batch_indices = torch.nonzero(should_stop, as_tuple=True)[0]
+        if len(batch_indices) > 0:
+            end_token_indices = self.think_end_ids[batch_indices]
+            next_token_logits[batch_indices, end_token_indices] = 0.0
+
+    def update_thinking_budgets(self, next_token_ids: torch.Tensor):
+        if not torch.any(self.thinking_budgets > 0):
+            return
+        torch.where(
+            next_token_ids == self.think_end_ids,
+            torch.tensor(-1, device=self.thinking_budgets.device),
+            self.thinking_budgets,
+            out=self.thinking_budgets,
+        )

     def update_regex_vocab_mask(self):
         if not self.grammars:
             self.vocab_mask = None
```
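To make the masking step concrete, here is a self-contained sketch of what `apply_thinking_budgets` does to a request that has just exceeded its budget: all of its logits are filled with `-inf` and the think-end token is reset to `0.0`, so sampling can only emit the token that closes the thinking section.

```python
import torch

# Toy batch of 2 requests over a 6-token vocab.
# Request 0 has exhausted its thinking budget; request 1 has no budget.
next_token_logits = torch.randn(2, 6)
think_end_ids = torch.tensor([5, -1])      # per-request id of "</think>"
should_stop = torch.tensor([True, False])  # budget exceeded this step?

# Mask every logit of the stopping request to -inf ...
next_token_logits.masked_fill_(should_stop.unsqueeze(1), float("-inf"))

# ... then restore the think-end token to 0.0, making it the only finite logit.
batch_indices = torch.nonzero(should_stop, as_tuple=True)[0]
next_token_logits[batch_indices, think_end_ids[batch_indices]] = 0.0

assert next_token_logits[0].argmax().item() == 5   # request 0 must emit </think>
assert torch.isfinite(next_token_logits[1]).all()  # request 1 is untouched
```

The sampler presumably calls `apply_thinking_budgets` on the logits before sampling and `update_thinking_budgets` on the sampled ids afterwards, so a request's budget is disabled (set to -1) once its think-end token has actually been emitted; those call sites are not part of this diff.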
Finally, the PR adds an end-to-end test, `test_thinking_budget.py`, which launches a Qwen3-8B server with the qwen3 reasoning parser and asserts that the reasoning content is truncated to exactly the requested number of tokens:

```python
"""
Usage:
python3 -m unittest test_thinking_budget.TestThinkingBudget.test_chat_completion_with_thinking_budget_20
python3 -m unittest test_thinking_budget.TestThinkingBudget.test_chat_completion_with_thinking_budget_200
"""

import unittest

import requests
from transformers import AutoTokenizer

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    CustomTestCase,
    popen_launch_server,
)


class TestThinkingBudget(CustomTestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = "Qwen/Qwen3-8B"
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model)
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.api_key = "sk-1234"
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            api_key=cls.api_key,
            other_args=[
                "--reasoning-parser",
                "qwen3",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_chat_completion_with_thinking_budget_20(self):
        response = requests.post(
            f"{self.base_url}/v1/chat/completions",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json={
                "model": self.model,
                "messages": [
                    {"role": "user", "content": "9.11 and 9.8, which is greater?"}
                ],
                "temperature": 0,
                "separate_reasoning": True,
                "chat_template_kwargs": {"enable_thinking": True},
                "thinking_budget": 20,
            },
        )
        self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
        data = response.json()
        reasoning_content = data["choices"][0]["message"]["reasoning_content"]
        tokens = self.tokenizer.encode(reasoning_content)
        self.assertEqual(
            len(tokens),
            20,
            f"Reasoning content length {len(tokens)} not equal to 20, tokens: {tokens}, reasoning_content: {reasoning_content}",
        )

    def test_chat_completion_with_thinking_budget_200(self):
        response = requests.post(
            f"{self.base_url}/v1/chat/completions",
            headers={"Authorization": f"Bearer {self.api_key}"},
            json={
                "model": self.model,
                "messages": [
                    {"role": "user", "content": "9.11 and 9.8, which is greater?"}
                ],
                "temperature": 0,
                "separate_reasoning": True,
                "chat_template_kwargs": {"enable_thinking": True},
                "thinking_budget": 200,
            },
        )
        self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
        data = response.json()
        reasoning_content = data["choices"][0]["message"]["reasoning_content"]
        tokens = self.tokenizer.encode(reasoning_content)
        self.assertEqual(
            len(tokens),
            200,
            f"Reasoning content length {len(tokens)} not equal to 200, tokens: {tokens}, reasoning_content: {reasoning_content}",
        )


if __name__ == "__main__":
    unittest.main()
```
Review thread on the new parameter:

Reviewer: It seems this is only valid when `enable_thinking = True`. We need to add validation for that.

Author: Done in 1c741a6.
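For reference, a minimal sketch of what that validation could look like on the request model; the field names mirror this PR, but the exact shape of the check in commit 1c741a6 may differ.

```python
from typing import Optional

from pydantic import BaseModel, model_validator


class ChatCompletionRequest(BaseModel):
    # Trimmed to the fields relevant to the check.
    thinking_budget: Optional[int] = None
    chat_template_kwargs: Optional[dict] = None

    @model_validator(mode="after")
    def check_thinking_budget(self):
        # Reject a thinking budget when thinking is not enabled,
        # since the budget would silently have no effect otherwise.
        if self.thinking_budget is not None:
            kwargs = self.chat_template_kwargs or {}
            if not kwargs.get("enable_thinking", False):
                raise ValueError(
                    "thinking_budget is only valid when "
                    "chat_template_kwargs.enable_thinking is true"
                )
        return self
```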