Changes from 3 commits
27 changes: 27 additions & 0 deletions docs/backend/sampling_params.md
@@ -64,6 +64,7 @@ Please refer to our dedicated guide on [constrained decoding](./structured_outpu
| ignore_eos | `bool = False` | Don't stop generation when EOS token is sampled. |
| skip_special_tokens | `bool = True` | Remove special tokens during decoding. |
| custom_params | `Optional[List[Optional[Dict[str, Any]]]] = None` | Used when employing `CustomLogitProcessor`. For usage, see below. |
| thinking_budget | `Optional[int] = None` | The maximum number of reasoning tokens that can be generated for a request. |

## Examples

@@ -296,3 +297,29 @@ response = requests.post(
)
print(response.json())
```

### Thinking Budget

Launch a server with a `--reasoning-parser` that matches the model:

```bash
python3 -m sglang.launch_server --model Qwen/Qwen3-8B --reasoning-parser qwen3
```

Send a request:

```python
import requests
response = requests.post(
"http://localhost:30000/generate",
json={
"text": "9.11 and 9.8, which is greater?",
"sampling_params": {
"temperature": 0.3,
"max_new_tokens": 256,
"thinking_budget": 20,
},
},
)
print(response.json())
```
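
The same parameter works through the OpenAI-compatible chat endpoint, passed at the top level of the request body. A minimal sketch, mirroring the usage exercised in `test/srt/test_thinking_budget.py` below (field names as in that test; the exact response shape depends on the reasoning parser):

```python
import requests

response = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "Qwen/Qwen3-8B",
        "messages": [
            {"role": "user", "content": "9.11 and 9.8, which is greater?"}
        ],
        "temperature": 0,
        # Return the reasoning separately from the final answer.
        "separate_reasoning": True,
        "chat_template_kwargs": {"enable_thinking": True},
        # Cap the reasoning section at 20 tokens.
        "thinking_budget": 20,
    },
)
print(response.json()["choices"][0]["message"]["reasoning_content"])
```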
6 changes: 5 additions & 1 deletion python/sglang/srt/model_executor/model_runner.py
@@ -1146,7 +1146,9 @@ def sample(
[self.sample(values, forward_batch) for values in logits_output],
axis=-1,
)

sampling_info = forward_batch.sampling_info
if sampling_info.thinking_budgets is not None:
sampling_info.apply_thinking_budgets(logits_output.next_token_logits)
self._preprocess_logits(logits_output, forward_batch.sampling_info)

# Sample the next tokens
@@ -1157,6 +1159,8 @@
forward_batch.top_logprobs_nums,
forward_batch.token_ids_logprobs,
)
if sampling_info.thinking_budgets is not None:
sampling_info.update_thinking_budgets(next_token_ids)
return next_token_ids

@property
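In outline, the two hooks bracket the sampling step. A simplified sketch of the flow above (`preprocess_logits` and `do_sampling` are stand-ins for the surrounding SGLang logic, not real functions):

```python
# Simplified view of ModelRunner.sample() with the thinking-budget hooks.
# `preprocess_logits` and `do_sampling` are placeholders for existing code.
def sample(logits_output, forward_batch):
    sampling_info = forward_batch.sampling_info

    # Before sampling: count thinking tokens and force `</think>` for any
    # request whose budget is exhausted, by masking all other logits.
    if sampling_info.thinking_budgets is not None:
        sampling_info.apply_thinking_budgets(logits_output.next_token_logits)

    preprocess_logits(logits_output, sampling_info)  # penalties, grammar masks
    next_token_ids = do_sampling(logits_output, forward_batch)

    # After sampling: requests that emitted `</think>` leave budget tracking.
    if sampling_info.thinking_budgets is not None:
        sampling_info.update_thinking_budgets(next_token_ids)
    return next_token_ids
```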
2 changes: 2 additions & 0 deletions python/sglang/srt/openai_api/adapter.py
@@ -555,6 +555,7 @@ def v1_generate_request(
"temperature": request.temperature,
"max_new_tokens": request.max_tokens,
"min_new_tokens": request.min_tokens,
"thinking_budget": request.thinking_budget,
"stop": request.stop,
"stop_token_ids": request.stop_token_ids,
"top_p": request.top_p,
@@ -1127,6 +1128,7 @@ def v1_chat_generate_request(
"temperature": request.temperature,
"max_new_tokens": request.max_tokens or request.max_completion_tokens,
"min_new_tokens": request.min_tokens,
"thinking_budget": request.thinking_budget,
Collaborator: It seems this is only valid when `enable_thinking = True`; we need to add validation for that.

Contributor (author): Done in 1c741a6.
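
Commit 1c741a6 is not part of this three-commit view; a hypothetical sketch of the kind of check the reviewer is asking for (placement and wording assumed, not taken from the commit):

```python
# Hypothetical validation sketch -- not the actual change from 1c741a6.
# A thinking budget only makes sense when the chat template enables thinking.
if request.thinking_budget is not None:
    enable_thinking = (request.chat_template_kwargs or {}).get(
        "enable_thinking", False
    )
    if not enable_thinking:
        raise ValueError(
            "thinking_budget is only supported when enable_thinking is True"
        )
```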

"stop": stop,
"stop_token_ids": request.stop_token_ids,
"top_p": request.top_p,
2 changes: 2 additions & 0 deletions python/sglang/srt/openai_api/protocol.py
@@ -172,6 +172,7 @@ class CompletionRequest(BaseModel):
top_k: int = -1
min_p: float = 0.0
min_tokens: int = 0
thinking_budget: Optional[int] = None
Collaborator: Will it be used in a non-chat scenario? Is there any usage example/reference for that?

Contributor (author): I am not familiar with CompletionRequest. Removed thinking_budget from it in 186b1aa.

json_schema: Optional[str] = None
regex: Optional[str] = None
ebnf: Optional[str] = None
@@ -378,6 +379,7 @@ def set_tool_choice_default(cls, values):
top_k: int = -1
min_p: float = 0.0
min_tokens: int = 0
thinking_budget: Optional[int] = None
regex: Optional[str] = None
ebnf: Optional[str] = None
repetition_penalty: float = 1.0
4 changes: 2 additions & 2 deletions python/sglang/srt/reasoning_parser.py
@@ -32,7 +32,7 @@ def detect_and_parse(self, text: str) -> StreamingParseResult:
One-time parsing: Detects and parses reasoning sections in the provided text.
Returns both reasoning content and normal text separately.
"""
text = text.replace(self.think_start_token, "").strip()
text = text.replace(self.think_start_token, "")
if self.think_end_token not in text:
# Assume reasoning was truncated before `</think>` token
return StreamingParseResult(reasoning_text=text)
@@ -73,7 +73,7 @@ def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
normal_text = current_text[end_idx + len(self.think_end_token) :]

return StreamingParseResult(
normal_text=normal_text, reasoning_text=reasoning_text.rstrip()
normal_text=normal_text, reasoning_text=reasoning_text
)

# Continue with reasoning content
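A plausible reason for dropping the `strip()`/`rstrip()` calls above (an inference from this PR, not stated in the diff): trimming whitespace around the reasoning text changes its token count, which would break exact accounting against `thinking_budget`, as the tests below rely on. A quick illustration:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")

# Trailing whitespace the model generated inside the <think> section is
# made of real tokens; stripping it changes the measured length.
reasoning = "9.11 has the larger fractional part.\n\n"
print(len(tokenizer.encode(reasoning)))           # counts the trailing "\n\n"
print(len(tokenizer.encode(reasoning.rstrip())))  # typically fewer tokens
```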
56 changes: 54 additions & 2 deletions python/sglang/srt/sampling/sampling_batch_info.py
@@ -30,8 +30,13 @@ class SamplingBatchInfo:
# Whether any request needs min_p sampling
need_min_p_sampling: bool

# Use thinking_budget to truncate thinking
num_thinking_tokens: Optional[torch.Tensor] = None
think_end_ids: Optional[torch.Tensor] = None
thinking_budgets: Optional[torch.Tensor] = None

# Masking tensors for grammar-guided structured outputs
vocab_size: int
vocab_size: int = 0
grammars: Optional[List] = None
vocab_mask: Optional[torch.Tensor] = None
apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
@@ -76,7 +81,22 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float
).to(device, non_blocking=True)

if any(hasattr(r.tokenizer, "think_end_id") for r in reqs):
Collaborator: For a non-reasoning model, can we skip this check? Do we need to identify whether the model is a reasoning model from its architecture?

Contributor (author): I am thinking about it. Right now I don't know how to decide, from a request alone, whether a model is a reasoning model.

Contributor (author): As far as I can tell, there seems to be no better way. The information that can be obtained here is very limited: this code is just doing sampling and does not know the specific model architecture.
think_end_ids = torch.tensor(
[getattr(r.tokenizer, "think_end_id", -1) for r in reqs],
dtype=torch.int64,
).to(device, non_blocking=True)
num_thinking_tokens = torch.tensor([0 for _ in reqs], dtype=torch.int64).to(
device, non_blocking=True
)
thinking_budgets = torch.tensor(
[r.sampling_params.thinking_budget or -1 for r in reqs],
Contributor: When thinking_budget=0, this will be assigned a value of -1 (`0 or -1` evaluates to `-1` in Python). It is recommended to modify this; otherwise a zero budget will not be fully supported.

Contributor (author): Updated in 200c598, thanks for pointing it out.

dtype=torch.int64,
).to(device, non_blocking=True)
else:
think_end_ids = None
num_thinking_tokens = None
thinking_budgets = None
# Check if any request has custom logit processor
has_custom_logit_processor = (
batch.enable_custom_logit_processor # check the flag first.
@@ -132,6 +152,9 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
top_ps=top_ps,
top_ks=top_ks,
min_ps=min_ps,
think_end_ids=think_end_ids,
num_thinking_tokens=num_thinking_tokens,
thinking_budgets=thinking_budgets,
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
vocab_size=vocab_size,
@@ -146,6 +169,35 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
def __len__(self):
return len(self.temperatures)

def apply_thinking_budgets(self, next_token_logits: torch.Tensor):
has_budget = self.thinking_budgets > 0
Contributor: This doesn't support the thinking_budget=0 scenario?

Contributor (author): Hi, we now support thinking_budget=0 as of 30aa15f.

if not has_budget.any():
return
torch.where(
has_budget,
self.num_thinking_tokens + 1,
self.num_thinking_tokens,
out=self.num_thinking_tokens,
)
should_stop = has_budget & (
self.num_thinking_tokens - 1 > self.thinking_budgets
)
next_token_logits.masked_fill_(should_stop.unsqueeze(0), float("-inf"))
batch_indices = torch.nonzero(should_stop, as_tuple=True)[0]
if len(batch_indices) > 0:
end_token_indices = self.think_end_ids[batch_indices]
next_token_logits[batch_indices, end_token_indices] = 0.0

def update_thinking_budgets(self, next_token_ids: torch.Tensor):
if not torch.any(self.thinking_budgets > 0):
return
torch.where(
next_token_ids == self.think_end_ids,
torch.tensor(-1, device=self.thinking_budgets.device),
self.thinking_budgets,
out=self.thinking_budgets,
)

def update_regex_vocab_mask(self):
if not self.grammars:
self.vocab_mask = None
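To see the budget mechanism in isolation, here is a self-contained sketch of the same idea on toy tensors (batch size 2, vocabulary size 5, token id 4 standing in for `think_end_id`; shapes and names are illustrative, and the per-row mask uses `unsqueeze(1)` so it broadcasts across the vocabulary dimension):

```python
import torch

vocab_size = 5
think_end_id = 4  # stand-in for the tokenizer's </think> token id

# Request 0 has a thinking budget of 3; request 1 has none (-1).
thinking_budgets = torch.tensor([3, -1])
num_thinking_tokens = torch.tensor([0, 0])

for step in range(6):
    logits = torch.randn(2, vocab_size)

    # Count one more thinking token for each budgeted request.
    has_budget = thinking_budgets > 0
    num_thinking_tokens = torch.where(
        has_budget, num_thinking_tokens + 1, num_thinking_tokens
    )

    # Once a request runs past its budget, mask every logit to -inf and
    # re-open only </think>, forcing the thinking section to close.
    should_stop = has_budget & (num_thinking_tokens - 1 > thinking_budgets)
    logits.masked_fill_(should_stop.unsqueeze(1), float("-inf"))
    logits[should_stop, think_end_id] = 0.0

    next_token_ids = logits.argmax(dim=-1)

    # A request that emitted </think> (forced or not) leaves tracking.
    thinking_budgets = torch.where(
        next_token_ids == think_end_id,
        torch.tensor(-1),
        thinking_budgets,
    )
    print(step, next_token_ids.tolist(), thinking_budgets.tolist())
```

Request 0 generates thinking tokens until its counter passes the budget, at which point the only token left open is `</think>`; its budget is then set to -1 and the forcing stops.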
2 changes: 2 additions & 0 deletions python/sglang/srt/sampling/sampling_params.py
@@ -30,6 +30,7 @@ class SamplingParams:
def __init__(
self,
max_new_tokens: int = 128,
thinking_budget: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: float = 1.0,
@@ -58,6 +59,7 @@ def __init__(
self.stop_token_ids = set(stop_token_ids)
else:
self.stop_token_ids = None
self.thinking_budget = thinking_budget
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
1 change: 1 addition & 0 deletions test/srt/run_suite.py
@@ -61,6 +61,7 @@ class TestFile:
TestFile("test_radix_attention.py", 167),
TestFile("test_reasoning_content.py", 89),
TestFile("test_enable_thinking.py", 70),
TestFile("test_thinking_budget.py", 60),
TestFile("test_regex_constrained.py", 64),
TestFile("test_release_memory_occupation.py", 44),
TestFile("test_request_length_validation.py", 31),
95 changes: 95 additions & 0 deletions test/srt/test_thinking_budget.py
@@ -0,0 +1,95 @@
"""
Usage:
python3 -m unittest test_thinking_budget.TestThinkingBudget.test_chat_completion_with_thinking_budget_20
python3 -m unittest test_thinking_budget.TestThinkingBudget.test_chat_completion_with_thinking_budget_200
"""

import unittest

import requests
from transformers import AutoTokenizer

from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)


class TestThinkingBudget(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "Qwen/Qwen3-8B"
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model)
cls.base_url = DEFAULT_URL_FOR_TEST
cls.api_key = "sk-1234"
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key,
other_args=[
"--reasoning-parser",
"qwen3",
],
)

@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)

def test_chat_completion_with_thinking_budget_20(self):
response = requests.post(
f"{self.base_url}/v1/chat/completions",
headers={"Authorization": f"Bearer {self.api_key}"},
json={
"model": self.model,
"messages": [
{"role": "user", "content": "9.11 and 9.8, which is greater?"}
],
"temperature": 0,
"separate_reasoning": True,
"chat_template_kwargs": {"enable_thinking": True},
"thinking_budget": 20,
},
)
self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
data = response.json()
reasoning_content = data["choices"][0]["message"]["reasoning_content"]
tokens = self.tokenizer.encode(reasoning_content)
self.assertEqual(
len(tokens),
20,
f"Reasoning content length: {len(tokens)} not equal to 20, tokens: {tokens}, reasoning_content: {reasoning_content}",
)

def test_chat_completion_with_thinking_budget_200(self):
response = requests.post(
f"{self.base_url}/v1/chat/completions",
headers={"Authorization": f"Bearer {self.api_key}"},
json={
"model": self.model,
"messages": [
{"role": "user", "content": "9.11 and 9.8, which is greater?"}
],
"temperature": 0,
"separate_reasoning": True,
"chat_template_kwargs": {"enable_thinking": True},
"thinking_budget": 200,
},
)
self.assertEqual(response.status_code, 200, f"Failed with: {response.text}")
data = response.json()
reasoning_content = data["choices"][0]["message"]["reasoning_content"]
tokens = self.tokenizer.encode(reasoning_content)
self.assertEqual(
len(tokens),
200,
f"Reasoning content length {len(tokens)} not equal to 200, tokens: {tokens}, reasoning_content: {reasoning_content}",
)


if __name__ == "__main__":
unittest.main()