
Conversation

@thyecust (Contributor) commented May 12, 2025

Motivation

See #6089

Modifications

Removed the incorrect diff in protocol.py. See the version 2 commit.

Checklist

@thyecust thyecust changed the title Thinking budget2 feat: add thinking_budget (version 2) May 12, 2025
@thyecust (Contributor, Author) commented

@zhyncs Could you provide more information about #6181 (comment)? I cannot reproduce that.

@thyecust (Contributor, Author) commented
12 results - 4 files

sglang • python/sglang/srt/model_executor/model_runner.py:
   ...

sglang • python/sglang/srt/openai_api/adapter.py:
   557                  "min_new_tokens": request.min_tokens,
   558:                 "thinking_budget": request.thinking_budget,
   559                  "stop": request.stop,

  1130              "min_new_tokens": request.min_tokens,
  1131:             "thinking_budget": request.thinking_budget,
  1132              "stop": stop,

sglang • python/sglang/srt/sampling/sampling_batch_info.py:
   ...

sglang • python/sglang/srt/sampling/sampling_params.py:
   ...

There are only two references to request.thinking_budget: one in ChatCompletionRequest handling and one in CompletionRequest handling. The definitions are in protocol.py, so I think they are safe.

@thyecust (Contributor, Author) commented

@minleminzui Could you help review this? thx

@minleminzui minleminzui self-assigned this May 12, 2025
@minleminzui (Collaborator) commented

> There are only two references to request.thinking_budget: one in ChatCompletionRequest handling and one in CompletionRequest handling. The definitions are in protocol.py, so I think they are safe.

@CatherineSue, @ispobock, @sleepcoo please help review this PR, and please ensure it doesn't affect your internal services.

@CatherineSue (Collaborator) left a comment

It seems we allow users to pass this parameter with a non-thinking model? I think it is better to isolate this from non-thinking models to prevent potential side effects in the future.

@thyecust (Contributor, Author) commented May 12, 2025

> It seems we allow users to pass this parameter with a non-thinking model? I think it is better to isolate this from non-thinking models to prevent potential side effects in the future.

In sampling_batch_info.py:

    if any(hasattr(r.tokenizer, "think_end_id") for r in reqs):
        think_end_ids = torch.tensor(
            [getattr(r.tokenizer, "think_end_id", -1) for r in reqs],
            dtype=torch.int64,
        ).to(device, non_blocking=True)
        num_thinking_tokens = torch.tensor([0 for _ in reqs], dtype=torch.int64).to(
            device, non_blocking=True
        )

For a non-thinking model, r.tokenizer has no think_end_id attribute, so nothing will happen.

Better isolation introduces more code. Is silently ignoring the parameter acceptable? @CatherineSue
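
To make the no-op concrete, here is a minimal self-contained sketch; both tokenizer classes and the token id are hypothetical stand-ins, not actual sglang code:

    # Minimal sketch of the guard behavior; classes and token id are
    # hypothetical stand-ins, not sglang code.

    class NonThinkingTokenizer:
        pass  # no think_end_id attribute, as for a non-thinking model

    class ThinkingTokenizer:
        think_end_id = 12345  # hypothetical end-of-thinking token id

    tokenizers = [NonThinkingTokenizer(), NonThinkingTokenizer()]

    # With only non-thinking tokenizers the branch is skipped entirely,
    # so thinking_budget is silently ignored for such requests.
    if any(hasattr(t, "think_end_id") for t in tokenizers):
        print("thinking-budget tensors are built")
    else:
        print("no-op: no tokenizer exposes think_end_id")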

[r.sampling_params.min_p for r in reqs], dtype=torch.float
).to(device, non_blocking=True)

if any(hasattr(r.tokenizer, "think_end_id") for r in reqs):
A collaborator commented:

For a non-reasoning model, can we skip this check? Do we need to identify whether the model is a reasoning model from the architecture?

@thyecust (Contributor, Author) replied:

I am thinking about it. Right now I don't see how to decide whether a model is a reasoning model from within a request.

@thyecust (Contributor, Author) replied:

As far as I can tell, there is no better way. The information that can be obtained here is very limited: this code is just doing sampling and does not know the specific model architecture.

"temperature": request.temperature,
"max_new_tokens": request.max_tokens or request.max_completion_tokens,
"min_new_tokens": request.min_tokens,
"thinking_budget": request.thinking_budget,
A collaborator commented:

It seems this is only valid when enable_thinking = True. We need to add that validation.

@thyecust (Contributor, Author) replied:

Done in 1c741a6.
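
For reference, a minimal sketch of what such a validation could look like; the enable_thinking flag living in chat_template_kwargs is an assumption based on this thread, not necessarily what 1c741a6 does:

    # Hypothetical validation sketch; field names are assumptions.
    from typing import Optional


    def validate_thinking_budget(
        thinking_budget: Optional[int],
        chat_template_kwargs: Optional[dict],
    ) -> None:
        # Reject thinking_budget unless the request explicitly enables thinking.
        enable_thinking = (chat_template_kwargs or {}).get("enable_thinking", False)
        if thinking_budget is not None and not enable_thinking:
            raise ValueError(
                "thinking_budget is only supported when enable_thinking is True"
            )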

top_k: int = -1
min_p: float = 0.0
min_tokens: int = 0
thinking_budget: Optional[int] = None
A collaborator commented:

Will it be used in a non-chat scenario? Is there any usage example or reference for that?

@thyecust (Contributor, Author) replied:

I am not familiar with CompletionRequest. Removed thinking_budget from it in 186b1aa.

    def apply_thinking_budgets(self, next_token_logits: torch.Tensor):
        if self.thinking_budgets is None:
            return
        has_budget = self.thinking_budgets > 0
A contributor commented:

This doesn't support the thinking_budget = 0 scenario?

@thyecust (Contributor, Author) replied:

Hi, we now support thinking_budget=0 as of 30aa15f.
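
A plausible shape of that fix, as a sketch under the assumption that -1 means "no budget" (this is not the literal content of 30aa15f): comparing with >= 0 makes a zero budget force the think-end token on the very first decoded token.

    import torch


    def apply_thinking_budgets(thinking_budgets, num_thinking_tokens,
                               think_end_ids, next_token_logits):
        # Sketch: >= 0 treats thinking_budget=0 as a real budget, while -1
        # still means "no budget" and leaves the logits untouched.
        has_budget = thinking_budgets >= 0
        if not torch.any(has_budget):
            return
        # Requests whose thinking-token count has reached the budget are
        # forced to emit the end-of-thinking token by masking everything else.
        exhausted = has_budget & (num_thinking_tokens >= thinking_budgets)
        batch_indices = torch.nonzero(exhausted, as_tuple=True)[0]
        end_token_indices = think_end_ids[batch_indices]
        next_token_logits[batch_indices] = float("-inf")
        next_token_logits[batch_indices, end_token_indices] = 0.0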

        next_token_logits[batch_indices, end_token_indices] = 0.0

    def update_thinking_budgets(self, next_token_ids: torch.Tensor):
        if self.thinking_budgets is None or not torch.any(self.thinking_budgets > 0):
A contributor commented:

This doesn't support the thinking_budget = 0 scenario?

@thyecust (Contributor, Author) replied:

Hi, we now support thinking_budget=0 as of 30aa15f.
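
The guard in the update path would need the same >= 0 comparison; a heavily simplified sketch (the token-counting logic here is an assumption, not the actual sglang code):

    import torch


    def update_thinking_budgets(thinking_budgets, num_thinking_tokens,
                                next_token_ids, think_end_ids):
        # Sketch: skip only when no request is budgeted at all; >= 0 keeps
        # zero-budget requests in the accounting (-1 means unbudgeted).
        if thinking_budgets is None or not torch.any(thinking_budgets >= 0):
            return
        # Count another thinking token for budgeted requests that have not
        # yet emitted their end-of-thinking token.
        still_thinking = (thinking_budgets >= 0) & (next_token_ids != think_end_ids)
        num_thinking_tokens += still_thinking.long()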

            device, non_blocking=True
        )
        thinking_budgets = torch.tensor(
            [r.sampling_params.thinking_budget or -1 for r in reqs],
A contributor commented:

When thinking_budget=0, this will be assigned -1, because 0 is falsy in Python and `0 or -1` evaluates to -1. It is recommended to fix this; otherwise thinking_budget=0 will not be fully supported.
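
A sketch of the explicit fix for that truthiness trap:

    # `or` falls through on falsy values, so a zero budget is lost:
    budget = 0
    assert (budget or -1) == -1

    # An explicit None check preserves thinking_budget=0 (sketch):
    fixed = budget if budget is not None else -1
    assert fixed == 0

In the list comprehension above, that means replacing `r.sampling_params.thinking_budget or -1` with `r.sampling_params.thinking_budget if r.sampling_params.thinking_budget is not None else -1`.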

@thyecust (Contributor, Author) replied:

Updated in 200c598. Thanks for pointing that out.
