Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 67 additions & 21 deletions vllm/v1/sample/logits_processor/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,16 @@ def apply_with_spec_decode(


class ThinkingTokenBudgetLogitsProcessor(LogitsProcessor):
"""Limits the number of tokens allowed inside a 'thinking' section."""
"""Limits the number of tokens allowed inside a 'thinking' section.

Includes a soft budget zone over the last 30% of the token budget.
Instead of a hard cut, the logit for the end-of-thinking token is
progressively boosted relative to the model's own logit distribution,
encouraging a natural stopping point before the hard force at 100%.
"""

# Fraction of budget where the soft zone begins (0.7 = last 30%).
_SOFT_ZONE_START_FRAC = 0.7

def __init__(
self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool
Expand All @@ -316,6 +325,7 @@ def __init__(
# Key: request_index, Value: state dict containing:
# "in_think": bool - currently in thinking mode
# "in_end": bool - currently forcing end tokens output
# "soft_progress": float - 0..1 ramp through soft budget zone
# "check_count_down": int - steps remaining until next think
# start/end token parsing
# "think_count": int - number of thinking tokens generated
Expand All @@ -331,6 +341,10 @@ def __init__(
self.force_token_ids = torch.full(
(max_num_reqs,), -1, dtype=torch.long, device=device
)
# Pre-built tensor for end token IDs (used in apply).
self._end_ids = torch.tensor(
self.reasoning_end_token_ids, device=device, dtype=torch.long
) if self.reasoning_end_token_ids else None

@staticmethod
def _find_last_sequence_index(target_list: list[int], token_ids: list[int]) -> int:
Expand Down Expand Up @@ -373,8 +387,9 @@ def _init_state_entry(
think_count = 0

return {
"in_think": in_think, # Currently in thinking mode
"in_think": in_think,
"in_end": in_think and thinking_token_budget == 0,
"soft_progress": 0.0, # 0..1 ramp through the soft zone
"check_count_down": thinking_token_budget,
"think_count": think_count, # Number of tokens in thinking section
"end_count": 0, # Number of end tokens forced so far
Expand All @@ -387,24 +402,33 @@ def _init_state_entry(

def _update_think_state(self, state: dict[str, Any]):
"""Updates the state based on newly generated output tokens."""
if not state.get("in_end", False) and state.get("check_count_down", 0) > 0:
# Skip the countdown optimisation for budgeted requests so that
# think_count is updated every token (needed for the soft zone).
has_budget = state.get("thinking_token_budget", 0) > 0
if not state.get("in_end") and state.get("check_count_down", 0) > 0 and not has_budget:
state["check_count_down"] -= 1
return

output = state.get("output_tok_ids", [])
if not output:
return

# Track previous output length for incremental processing
prev_length = state.get("prev_output_length", 0)
current_length = len(output)

if current_length <= prev_length:
return

# Process only newly added tokens
new_tokens = output[prev_length:]
state["prev_output_length"] = current_length
# Strip trailing -1 sentinels that vLLM v1 appends as
# placeholders before the sampler fills in the real token.
effective_length = current_length
while effective_length > prev_length and output[effective_length - 1] == -1:
effective_length -= 1
if effective_length <= prev_length:
return

new_tokens = output[prev_length:effective_length]
state["prev_output_length"] = effective_length

# Check if new tokens contain think start or end sequences
start_len = len(self.reasoning_start_token_ids)
Expand Down Expand Up @@ -435,6 +459,7 @@ def _update_think_state(self, state: dict[str, Any]):
else:
# Case: ...<start>...<end>... - exiting think mode
state["in_think"] = False
state["soft_progress"] = 0.0
state["think_count"] = 0
elif recent_start_pos >= 0:
# Found think start - entering think mode
Expand All @@ -445,29 +470,37 @@ def _update_think_state(self, state: dict[str, Any]):
elif recent_end_pos >= 0:
# Found think end - exiting think mode
state["in_think"] = False
state["soft_progress"] = 0.0
state["think_count"] = 0
elif state["in_think"]:
# Continue thinking mode, increment count by new tokens
state["think_count"] += len(new_tokens)

# Set countdown based on current state
if state["in_think"]:
remaining_budget = max(
0, state["thinking_token_budget"] - state["think_count"]
)
state["check_count_down"] = max(0, remaining_budget - 1)
else:
state["check_count_down"] = state["thinking_token_budget"]
# Compute soft zone boundaries and transition state.
budget = state["thinking_token_budget"]
count = state["think_count"]
soft_start = int(budget * self._SOFT_ZONE_START_FRAC)

# Check if need to transition to end mode
if (
state["in_think"]
and state["think_count"] >= state["thinking_token_budget"]
):
if state["in_think"] and count >= budget:
# Hard force -- safety net if soft zone didn't work.
state["in_think"] = False
state["in_end"] = True
state["soft_progress"] = 0.0
state["end_count"] = 0
state["check_count_down"] = state["thinking_token_budget"]
state["check_count_down"] = budget
elif state["in_think"] and budget > 0 and count >= soft_start:
# Soft zone: ramp progress from 0.0 to 1.0.
state["soft_progress"] = (
(count - soft_start) / max(1, budget - soft_start))
state["check_count_down"] = 0
elif state["in_think"]:
# Before soft zone.
state["soft_progress"] = 0.0
state["check_count_down"] = max(
0, soft_start - count - 1)
else:
state["soft_progress"] = 0.0
state["check_count_down"] = budget
else:
# In end mode
state["end_count"] += 1
Expand Down Expand Up @@ -535,6 +568,19 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
self.force_token_ids[i] = self.reasoning_end_token_ids[
state["end_count"]
]
elif state and state.get("soft_progress", 0) > 0 \
and self._end_ids is not None:
# Adaptive soft bias: boost </think> relative to the
# current gap between the top logit and the end token.
# progress=0% -> target = end_logit (no change)
# progress=50% -> target = top_logit (equal)
# progress=100%-> target = top + gap (dominates)
top_logit = logits[i].max().item()
end_logit = logits[i, self._end_ids[0]].item()
gap = max(top_logit - end_logit, 1.0)
target = end_logit + 2.0 * gap * state["soft_progress"]
logits[i, self._end_ids] = torch.clamp(
logits[i, self._end_ids], min=target)

# Check in CPU first not to sync with GPU
has_active_thinking = any(
Expand Down
Loading