From 8ce4561e619815d4a976ec3b0f73ec17c137d44d Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sat, 12 Jul 2025 09:14:00 +0000 Subject: [PATCH 01/61] feat: limit thinking tokens Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/config/__init__.py | 3 + vllm/config/reasoning.py | 31 ++++++ vllm/config/vllm.py | 3 + vllm/entrypoints/openai/protocol.py | 4 +- vllm/sampling_params.py | 6 ++ vllm/v1/engine/core.py | 1 + vllm/v1/sample/logits_processor/builtin.py | 105 +++++++++++++++++++++ vllm/v1/worker/gpu_input_batch.py | 2 + vllm/v1/worker/gpu_model_runner.py | 20 ++++ 9 files changed, 174 insertions(+), 1 deletion(-) create mode 100644 vllm/config/reasoning.py diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index c909265c071d..c6b259c3772e 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -25,6 +25,7 @@ from vllm.config.speculative import SpeculativeConfig from vllm.config.speech_to_text import SpeechToTextConfig from vllm.config.structured_outputs import StructuredOutputsConfig +from vllm.config.reasoning import ReasoningConfig from vllm.config.utils import (ConfigType, SupportsMetricsInfo, config, get_attr_docs, is_init_field, update_config) from vllm.config.vllm import (VllmConfig, get_cached_compilation_config, @@ -80,6 +81,8 @@ "ParallelConfig", # From vllm.config.pooler "PoolerConfig", + # From vllm.config.reasoning + "ReasoningConfig", # From vllm.config.scheduler "RunnerType", "SchedulerConfig", diff --git a/vllm/config/reasoning.py b/vllm/config/reasoning.py new file mode 100644 index 000000000000..220c739b3fed --- /dev/null +++ b/vllm/config/reasoning.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Optional + +from pydantic.dataclasses import dataclass + +from vllm.config.utils import config + + +@config +@dataclass +class ReasoningConfig: + 
"""Configuration for reasoning models.""" + + think_start_str: Optional[str] = None + """String that indicates the start of reasoning.""" + think_end_str: Optional[str] = None + """String that indicates the end of reasoning.""" + think_start_token_ids: Optional[list[int]] = None + """Token ID that indicates the start of reasoning.""" + think_end_token_ids: Optional[list[int]] = None + """Token ID that indicates the end of reasoning.""" + + def is_thinking_enabled(self) -> bool: + """Check if both start and end thinking token IDs + are set to enable thinking token budget logic.""" + return (self.think_start_token_ids is not None + and self.think_end_token_ids is not None + and len(self.think_start_token_ids) > 0 + and len(self.think_end_token_ids) > 0) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index ac40b0fd4783..ae5425b17769 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -30,6 +30,7 @@ from .model import ModelConfig from .observability import ObservabilityConfig from .parallel import ParallelConfig +from .reasoning import ReasoningConfig from .scheduler import SchedulerConfig from .speculative import SpeculativeConfig from .structured_outputs import StructuredOutputsConfig @@ -100,6 +101,8 @@ class VllmConfig: """The configurations for distributed KV cache transfer.""" kv_events_config: Optional[KVEventsConfig] = None """The configurations for event publishing.""" + reasoning_config: Optional[ReasoningConfig] = None + """The configurations for reasoning model.""" # some opaque config, only used to provide additional information # for the hash computation, mainly used for testing, debugging or out of # tree config registration. 
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 9d51372887c2..682d7468a4a5 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -488,6 +488,7 @@ class ChatCompletionRequest(OpenAIBaseModel): prompt_logprobs: Optional[int] = None allowed_token_ids: Optional[list[int]] = None bad_words: list[str] = Field(default_factory=list) + max_think_tokens: Optional[int] = None # --8<-- [end:chat-completion-sampling-params] # --8<-- [start:chat-completion-extra-params] @@ -794,7 +795,8 @@ def to_sampling_params( else RequestOutputKind.FINAL_ONLY, structured_outputs=self.structured_outputs, logit_bias=self.logit_bias, - bad_words=self.bad_words, + bad_words= self.bad_words, + max_think_tokens=self.max_think_tokens, allowed_token_ids=self.allowed_token_ids, extra_args=extra_args or None, ) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index f424682f9dfa..8a07b7843efb 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -213,6 +213,9 @@ class SamplingParams( generated token can complete the sequence.""" _bad_words_token_ids: Optional[list[list[int]]] = None + max_think_tokens: Optional[int] = None + """Maximum number of tokens allowed for thinking operations.""" + @staticmethod def from_optional( n: Optional[int] = 1, @@ -228,6 +231,7 @@ def from_optional( stop: Optional[Union[str, list[str]]] = None, stop_token_ids: Optional[list[int]] = None, bad_words: Optional[list[str]] = None, + max_think_tokens: Optional[int] = None, include_stop_str_in_output: bool = False, ignore_eos: bool = False, max_tokens: Optional[int] = 16, @@ -282,6 +286,7 @@ def from_optional( stop=stop, stop_token_ids=stop_token_ids, bad_words=bad_words, + max_think_tokens=max_think_tokens, include_stop_str_in_output=include_stop_str_in_output, ignore_eos=ignore_eos, max_tokens=max_tokens, @@ -565,6 +570,7 @@ def __repr__(self) -> str: f"stop={self.stop}, " f"stop_token_ids={self.stop_token_ids}, 
" f"bad_words={self.bad_words}, " + f"max_think_tokens={self.max_think_tokens}, " f"include_stop_str_in_output={self.include_stop_str_in_output}, " f"ignore_eos={self.ignore_eos}, " f"max_tokens={self.max_tokens}, " diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 3ee804f10c17..9379db401ad8 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -97,6 +97,7 @@ def __init__(self, self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks)) + # EngineCore holds StructuredOutputManager to handle and it has vllm config as an arg. self.structured_output_manager = StructuredOutputManager(vllm_config) # Setup scheduler. diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index fc655d993cb4..09ce19d5aeba 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -273,3 +273,108 @@ def process_dict_updates( req_entries[a_index] = b_entry return updated + + +class MaxThinkTokensLogitsProcessor(LogitsProcessor): + """A logits processor that limits the maximum number of thinking tokens.""" + + def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device: torch.device): + """ + Args: + think_start_token_id (int): Token ID for the start of thinking section. + think_end_token_id (int): Token ID for the end of thinking section. + pin_memory (bool): Whether to use pinned memory for tensors. + device (torch.device): Device to use for tensor operations. 
+ """ + super().__init__() + self.think_start_token_id = reasoning_config.think_start_token_id + self.think_end_token_id = reasoning_config.think_end_token_id + self.pin_memory = pin_memory + self.device = device + self._state = {} + + def _find_last_token_index(self, tokens, token_id): + try: + return len(tokens) - tokens[::-1].index(token_id) - 1 + except ValueError: + return -1 + + def is_argmax_invariant(self) -> bool: + """This logits processor can change the outcome of greedy sampling + by forcing that the thinking section ends after a certain number of tokens.""" + return False + + def update_state(self, batch_update: Optional[BatchUpdate]): + if batch_update is None: + return + + for index, params, prompt_tok_ids, output_tok_ids in batch_update.added: + max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None + + if max_think_tokens is None: + continue + + last_think_start_idx = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id) + last_think_end_idx = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id) + + in_think = False + count = 0 + + if last_think_start_idx > last_think_end_idx: + in_think = True + count = len(prompt_tok_ids) - (last_think_start_idx + 1) + + self._state[index] = { + "in_think": in_think, + "count": count, + "prompt_tok_ids": prompt_tok_ids, + "output_tok_ids": output_tok_ids, + "max_think_tokens": max_think_tokens, + } + + for index in batch_update.removed: + self._state.pop(index, None) + + for i1, i2, direction in batch_update.moved: + if direction == MoveDirectionality.SWAP: + self._state[i1], self._state[i2] = self._state[i2], self._state[i1] + else: + self._state[i2] = self._state.pop(i1, None) + + def apply(self, logits: torch.Tensor) -> torch.Tensor: + batch_size = logits.size(0) + if batch_size == 0: + return logits + + mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device) + end_token_id = self.think_end_token_id + + for index in 
range(batch_size): + state = self._state.get(index, None) + if not state or not state.get("output_tok_ids"): + continue + + last_tok = state["output_tok_ids"][-1] + in_think = state["in_think"] + count = state["count"] + + if last_tok == self.think_start_token_id: + in_think = True + count = 0 + elif last_tok == self.think_end_token_id: + in_think = False + count = 0 + elif in_think: + count += 1 + + state["in_think"] = in_think + state["count"] = count + + if state["in_think"] and state["count"] >= state["max_think_tokens"]: + mask[index] = True + + if mask.any(): + logits[mask] = -float("inf") + logits[mask, end_token_id] = 0.0 + + return logits diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 67fb9864b19c..ba0462351314 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -9,6 +9,7 @@ import torch from typing_extensions import deprecated +from vllm.config import ReasoningConfig from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems from vllm.pooling_params import PoolingParams @@ -91,6 +92,7 @@ def __init__( is_spec_decode: bool = False, is_pooling_model: bool = False, num_speculative_tokens: int = 0, + reasoning_config: ReasoningConfig = None, ): self.is_pooling_model = is_pooling_model self.is_spec_decode = is_spec_decode diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index efb4a8c0054f..0ba0f515397d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -114,6 +114,10 @@ gather_mm_placeholders, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) +from vllm.config import ReasoningConfig +from vllm.reasoning import ReasoningParserManager +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs + if TYPE_CHECKING: from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.v1.core.sched.output 
import SchedulerOutput @@ -191,6 +195,20 @@ def __init__( self.speculative_config = vllm_config.speculative_config self.observability_config = vllm_config.observability_config + if self.vllm_config.decoding_config.reasoning_backend in ('deepseek_r1', 'qwen'): + tokenizer = init_tokenizer_from_configs( + model_config=self.vllm_config.model_config, + scheduler_config=self.vllm_config.scheduler_config, + lora_config=self.vllm_config.lora_config, + ).get_lora_tokenizer(None) + reasoning_backend = \ + self.vllm_config.decoding_config.reasoning_backend + reasoner_cls = ReasoningParserManager.get_reasoning_parser( + reasoning_backend) + reasoning_parser = reasoner_cls(tokenizer=tokenizer) + self.vllm_config.reasoning_config = ReasoningConfig(think_start_token_id=reasoning_parser.think_start_token_id, + think_end_token_id=reasoning_parser.think_end_token_id) + from vllm.model_executor.models.utils import set_cpu_offload_max_bytes set_cpu_offload_max_bytes( int(self.cache_config.cpu_offload_gb * 1024**3)) @@ -326,6 +344,7 @@ def __init__( self.is_pooling_model, self.vllm_config.model_config.logits_processors), is_pooling_model=self.is_pooling_model, + reasoning_config=self.vllm_config.reasoning_config, ) self.use_async_scheduling = self.scheduler_config.async_scheduling @@ -3864,6 +3883,7 @@ def may_reinitialize_input_batch(self, num_speculative_tokens=( self.vllm_config.speculative_config.num_speculative_tokens if self.vllm_config.speculative_config else 0), + reasoning_config=self.vllm_config.reasoning_config, ) def _allocate_kv_cache_tensors( From b815e9c248405c4742f19d86b01cb90da7372b0b Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sat, 12 Jul 2025 09:24:57 +0000 Subject: [PATCH 02/61] remove comment Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/v1/engine/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 9379db401ad8..3ee804f10c17 
100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -97,7 +97,6 @@ def __init__(self, self.collective_rpc("initialize_cache", args=(num_gpu_blocks, num_cpu_blocks)) - # EngineCore holds StructuredOutputManager to handle and it has vllm config as an arg. self.structured_output_manager = StructuredOutputManager(vllm_config) # Setup scheduler. From 2001c36364564a7fd3054f677246f7d26e36da7b Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Mon, 14 Jul 2025 04:06:23 +0000 Subject: [PATCH 03/61] update states only in update_state method Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/v1/sample/logits_processor/builtin.py | 98 ++++++++++------------ 1 file changed, 44 insertions(+), 54 deletions(-) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 09ce19d5aeba..cd7fe576fde2 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -291,9 +291,9 @@ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device: self.think_end_token_id = reasoning_config.think_end_token_id self.pin_memory = pin_memory self.device = device - self._state = {} + self._state: dict[int, dict[str, Any]] = {} - def _find_last_token_index(self, tokens, token_id): + def _find_last_token_index(self, tokens: list[int], token_id: int) -> int: try: return len(tokens) - tokens[::-1].index(token_id) - 1 except ValueError: @@ -305,71 +305,61 @@ def is_argmax_invariant(self) -> bool: return False def update_state(self, batch_update: Optional[BatchUpdate]): - if batch_update is None: - return - - for index, params, prompt_tok_ids, output_tok_ids in batch_update.added: - max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None - - if max_think_tokens is None: + if batch_update: + for index, params, prompt_tok_ids, output_tok_ids in batch_update.added: + 
max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None + if max_think_tokens is not None: + last_start = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id) + last_end = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id) + in_think = last_start > last_end + count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0 + + self._state[index] = { + "in_think": in_think, + "count": count, + "prompt_tok_ids": prompt_tok_ids, + "output_tok_ids": output_tok_ids, + "max_think_tokens": max_think_tokens, + } + + for index in batch_update.removed: + self._state.pop(index, None) + + for i1, i2, direction in batch_update.moved: + if direction == MoveDirectionality.SWAP: + self._state[i1], self._state[i2] = self._state[i2], self._state[i1] + else: + self._state[i2] = self._state.pop(i1, None) + + # Update in_think and count for all active requests + for state in self._state.values(): + output = state["output_tok_ids"] + if not output: continue - last_think_start_idx = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id) - last_think_end_idx = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id) - - in_think = False - count = 0 - - if last_think_start_idx > last_think_end_idx: - in_think = True - count = len(prompt_tok_ids) - (last_think_start_idx + 1) - - self._state[index] = { - "in_think": in_think, - "count": count, - "prompt_tok_ids": prompt_tok_ids, - "output_tok_ids": output_tok_ids, - "max_think_tokens": max_think_tokens, - } - - for index in batch_update.removed: - self._state.pop(index, None) - - for i1, i2, direction in batch_update.moved: - if direction == MoveDirectionality.SWAP: - self._state[i1], self._state[i2] = self._state[i2], self._state[i1] - else: - self._state[i2] = self._state.pop(i1, None) + last_tok = output[-1] + if last_tok == self.think_start_token_id: + state["in_think"] = True + state["count"] = 0 + elif last_tok == 
self.think_end_token_id: + state["in_think"] = False + state["count"] = 0 + elif state["in_think"]: + state["count"] += 1 def apply(self, logits: torch.Tensor) -> torch.Tensor: batch_size = logits.size(0) - if batch_size == 0: + if not self._state: return logits mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device) end_token_id = self.think_end_token_id for index in range(batch_size): - state = self._state.get(index, None) - if not state or not state.get("output_tok_ids"): + state = self._state.get(index) + if not state: continue - last_tok = state["output_tok_ids"][-1] - in_think = state["in_think"] - count = state["count"] - - if last_tok == self.think_start_token_id: - in_think = True - count = 0 - elif last_tok == self.think_end_token_id: - in_think = False - count = 0 - elif in_think: - count += 1 - - state["in_think"] = in_think - state["count"] = count - if state["in_think"] and state["count"] >= state["max_think_tokens"]: mask[index] = True From c71cf86594d5303c6f5573ffeb424981401f7328 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Mon, 14 Jul 2025 04:40:44 +0000 Subject: [PATCH 04/61] make precommit and lint Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/v1/sample/logits_processor/builtin.py | 41 +++++++++++++--------- vllm/v1/worker/gpu_model_runner.py | 17 +++++---- 2 files changed, 33 insertions(+), 25 deletions(-) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index cd7fe576fde2..14df0ccc9024 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -278,13 +278,14 @@ def process_dict_updates( class MaxThinkTokensLogitsProcessor(LogitsProcessor): """A logits processor that limits the maximum number of thinking tokens.""" - def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device: torch.device): + def __init__(self, reasoning_config: 
ReasoningConfig, pin_memory: bool, + device: torch.device): """ Args: - think_start_token_id (int): Token ID for the start of thinking section. - think_end_token_id (int): Token ID for the end of thinking section. - pin_memory (bool): Whether to use pinned memory for tensors. - device (torch.device): Device to use for tensor operations. + reasoning_config: Configuration for reasoning, which includes + the token IDs for thinking start and end. + pin_memory (bool): Whether to use pinned memory for tensors. + device (torch.device): Device to use for tensor operations. """ super().__init__() self.think_start_token_id = reasoning_config.think_start_token_id @@ -300,19 +301,25 @@ def _find_last_token_index(self, tokens: list[int], token_id: int) -> int: return -1 def is_argmax_invariant(self) -> bool: - """This logits processor can change the outcome of greedy sampling - by forcing that the thinking section ends after a certain number of tokens.""" + """This logits processor can change the outcome of + greedy sampling by forcing that the thinking section + ends after a certain number of tokens.""" return False def update_state(self, batch_update: Optional[BatchUpdate]): if batch_update: - for index, params, prompt_tok_ids, output_tok_ids in batch_update.added: - max_think_tokens = params.max_think_tokens if isinstance(params, SamplingParams) else None + for (index, params, prompt_tok_ids, + output_tok_ids) in batch_update.added: + max_think_tokens = (params.max_think_tokens if isinstance( + params, SamplingParams) else None) if max_think_tokens is not None: - last_start = self._find_last_token_index(prompt_tok_ids, self.think_start_token_id) - last_end = self._find_last_token_index(prompt_tok_ids, self.think_end_token_id) + last_start = self._find_last_token_index( + prompt_tok_ids, self.think_start_token_id) + last_end = self._find_last_token_index( + prompt_tok_ids, self.think_end_token_id) in_think = last_start > last_end - count = len(prompt_tok_ids) - (last_start + 
1) if in_think else 0 + count = len(prompt_tok_ids) - (last_start + + 1) if in_think else 0 self._state[index] = { "in_think": in_think, @@ -323,13 +330,14 @@ def update_state(self, batch_update: Optional[BatchUpdate]): } for index in batch_update.removed: - self._state.pop(index, None) + self._state.pop(index, {}) for i1, i2, direction in batch_update.moved: if direction == MoveDirectionality.SWAP: - self._state[i1], self._state[i2] = self._state[i2], self._state[i1] + self._state[i1], self._state[i2] = self._state[ + i2], self._state[i1] else: - self._state[i2] = self._state.pop(i1, None) + self._state[i2] = self._state.pop(i1, {}) # Update in_think and count for all active requests for state in self._state.values(): @@ -360,7 +368,8 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: if not state: continue - if state["in_think"] and state["count"] >= state["max_think_tokens"]: + if state["in_think"] and state["count"] >= state[ + "max_think_tokens"]: mask[index] = True if mask.any(): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 0ba0f515397d..1db36477a813 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -24,8 +24,8 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.compilation.monitor import set_cudagraph_capturing_enabled -from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig, - get_layers_from_vllm_config, update_config) +from vllm.config import (CompilationLevel, CUDAGraphMode, ReasoningConfig, + VllmConfig, get_layers_from_vllm_config, update_config) from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) @@ -57,6 +57,7 @@ PlaceholderRange) from vllm.multimodal.utils import group_mm_kwargs_by_modality from vllm.pooling_params import PoolingParams +from vllm.reasoning import 
ReasoningParserManager from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask @@ -114,10 +115,6 @@ gather_mm_placeholders, sanity_check_mm_encoder_outputs, scatter_mm_placeholders) -from vllm.config import ReasoningConfig -from vllm.reasoning import ReasoningParserManager -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs - if TYPE_CHECKING: from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.v1.core.sched.output import SchedulerOutput @@ -195,7 +192,8 @@ def __init__( self.speculative_config = vllm_config.speculative_config self.observability_config = vllm_config.observability_config - if self.vllm_config.decoding_config.reasoning_backend in ('deepseek_r1', 'qwen'): + if self.vllm_config.decoding_config.reasoning_backend in ( + 'deepseek_r1', 'qwen'): tokenizer = init_tokenizer_from_configs( model_config=self.vllm_config.model_config, scheduler_config=self.vllm_config.scheduler_config, @@ -206,8 +204,9 @@ def __init__( reasoner_cls = ReasoningParserManager.get_reasoning_parser( reasoning_backend) reasoning_parser = reasoner_cls(tokenizer=tokenizer) - self.vllm_config.reasoning_config = ReasoningConfig(think_start_token_id=reasoning_parser.think_start_token_id, - think_end_token_id=reasoning_parser.think_end_token_id) + self.vllm_config.reasoning_config = ReasoningConfig( + think_start_token_id=reasoning_parser.think_start_token_id, + think_end_token_id=reasoning_parser.think_end_token_id) from vllm.model_executor.models.utils import set_cpu_offload_max_bytes set_cpu_offload_max_bytes( From 7ae072558dd149c847be9364fdb56751651e509f Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Wed, 16 Jul 2025 06:33:40 +0000 Subject: [PATCH 05/61] support think start/end as token sequences Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- 
vllm/engine/arg_utils.py | 16 ++- vllm/v1/sample/logits_processor/builtin.py | 122 ++++++++++++++------- vllm/v1/worker/gpu_model_runner.py | 16 +-- 3 files changed, 104 insertions(+), 50 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ec61fc4b9b06..4e808873ada4 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -28,10 +28,10 @@ KVTransferConfig, LoadConfig, LogprobsMode, LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig, ModelDType, ObservabilityConfig, ParallelConfig, - PoolerConfig, PrefixCachingHashAlgo, RunnerOption, - SchedulerConfig, SchedulerPolicy, SpeculativeConfig, - StructuredOutputsConfig, TaskOption, TokenizerMode, - VllmConfig, get_attr_docs) + PoolerConfig, PrefixCachingHashAlgo, ReasoningConfig, + RunnerOption, SchedulerConfig, SchedulerPolicy, + SpeculativeConfig, StructuredOutputsConfig, TaskOption, + TokenizerMode, VllmConfig, get_attr_docs) from vllm.config.multimodal import MMCacheType, MultiModalConfig from vllm.config.parallel import ExpertPlacementStrategy from vllm.config.utils import get_field @@ -456,6 +456,8 @@ class EngineArgs: kv_transfer_config: Optional[KVTransferConfig] = None kv_events_config: Optional[KVEventsConfig] = None + reasoning_config: Optional[ReasoningConfig] = None + generation_config: str = ModelConfig.generation_config enable_sleep_mode: bool = ModelConfig.enable_sleep_mode override_generation_config: dict[str, Any] = \ @@ -933,6 +935,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **vllm_kwargs["kv_events_config"]) vllm_group.add_argument("--compilation-config", "-O", **vllm_kwargs["compilation_config"]) + vllm_group.add_argument("--reasoning-config", + **vllm_kwargs["reasoning_config"]) vllm_group.add_argument("--additional-config", **vllm_kwargs["additional_config"]) vllm_group.add_argument('--structured-outputs-config', @@ -1430,6 +1434,9 @@ def create_engine_config( collect_detailed_traces=self.collect_detailed_traces, ) 
+ reasoning_config_dict = json.loads(self.reasoning_config) + reasoning_config = ReasoningConfig(**reasoning_config_dict) + config = VllmConfig( model_config=model_config, cache_config=cache_config, @@ -1444,6 +1451,7 @@ def create_engine_config( compilation_config=self.compilation_config, kv_transfer_config=self.kv_transfer_config, kv_events_config=self.kv_events_config, + reasoning_config=reasoning_config, additional_config=self.additional_config, ) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 14df0ccc9024..1ad9447d17ba 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -294,12 +294,84 @@ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, self.device = device self._state: dict[int, dict[str, Any]] = {} - def _find_last_token_index(self, tokens: list[int], token_id: int) -> int: + def _find_first_token_index(self, target_list: list[int], token_id: int) -> int: + """ + Find the last occurrence of a single token in the list of tokens. + + Args: + target_list (list[int]): The list of token IDs. + token_id (int): The token ID to find. + """ try: - return len(tokens) - tokens[::-1].index(token_id) - 1 + return len(target_list) - target_list[::-1].index(token_id) - 1 except ValueError: return -1 + def _find_last_sequence_index(self, target_list: list[int], token_ids: list[int]) -> int: + """ + Find the last occurrence of the sequence of token_ids in tokens. + + Args: + target_list (list[int]): The list of token IDs. + token_ids (list[int]): The sequence of token IDs to find. 
+ """ + index = self._find_first_token_index(target_list, token_ids[0]) + if index != -1: + i = 1 + for token_id in token_ids[1:]: + if index + i >= len(target_list) or target_list[index + i] != token_id: + return -1 + i += 1 + index += 1 + + return index + + def _init_state_entry(self, prompt_tok_ids, max_think_tokens): + last_start = self._find_last_sequence_index( + prompt_tok_ids, self.think_start_token_id) + last_end = self._find_last_sequence_index( + prompt_tok_ids, self.think_end_token_id) + in_think = last_start > last_end + think_count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0 + + return { + "in_think": in_think, + "in_end": False, + "think_count": think_count, + "end_count": 0, + "prompt_tok_ids": prompt_tok_ids, + "output_tok_ids": [], + "max_think_tokens": max_think_tokens, + } + + def _update_think_state(self, state): + output = state["output_tok_ids"] + if not output: + return + + sliced_output1 = output[-1 + len(self.think_start_token_id):] + sliced_output2 = output[-1 + len(self.think_end_token_id):] + + if self._find_last_sequence_index(sliced_output1, self.think_start_token_id) != -1: + state["in_think"] = True + state["think_count"] = 0 + elif self._find_last_sequence_index(sliced_output2, self.think_end_token_id) != -1: + state["in_think"] = False + state["think_count"] = 0 + else: + state["think_count"] += 1 + + if state["in_end"]: + state["end_count"] += 1 + if state["end_count"] >= len(self.think_end_token_id): + state["in_end"] = False + state["end_count"] = 0 + else: + if state["in_think"] and state["think_count"] >= state["max_think_tokens"]: + state["in_think"] = False + state["in_end"] = True + state["end_count"] = 0 + def is_argmax_invariant(self) -> bool: """This logits processor can change the outcome of greedy sampling by forcing that the thinking section @@ -308,52 +380,25 @@ def is_argmax_invariant(self) -> bool: def update_state(self, batch_update: Optional[BatchUpdate]): if batch_update: - for (index, params, 
prompt_tok_ids, - output_tok_ids) in batch_update.added: + for (index, params, prompt_tok_ids, output_tok_ids) in batch_update.added: max_think_tokens = (params.max_think_tokens if isinstance( params, SamplingParams) else None) if max_think_tokens is not None: - last_start = self._find_last_token_index( - prompt_tok_ids, self.think_start_token_id) - last_end = self._find_last_token_index( - prompt_tok_ids, self.think_end_token_id) - in_think = last_start > last_end - count = len(prompt_tok_ids) - (last_start + - 1) if in_think else 0 - - self._state[index] = { - "in_think": in_think, - "count": count, - "prompt_tok_ids": prompt_tok_ids, - "output_tok_ids": output_tok_ids, - "max_think_tokens": max_think_tokens, - } + self._state[index] = self._init_state_entry( + prompt_tok_ids, max_think_tokens) + self._state[index]["output_tok_ids"] = output_tok_ids for index in batch_update.removed: self._state.pop(index, {}) for i1, i2, direction in batch_update.moved: if direction == MoveDirectionality.SWAP: - self._state[i1], self._state[i2] = self._state[ - i2], self._state[i1] + self._state[i1], self._state[i2] = self._state[i2], self._state[i1] else: self._state[i2] = self._state.pop(i1, {}) - # Update in_think and count for all active requests for state in self._state.values(): - output = state["output_tok_ids"] - if not output: - continue - - last_tok = output[-1] - if last_tok == self.think_start_token_id: - state["in_think"] = True - state["count"] = 0 - elif last_tok == self.think_end_token_id: - state["in_think"] = False - state["count"] = 0 - elif state["in_think"]: - state["count"] += 1 + self._update_think_state(state) def apply(self, logits: torch.Tensor) -> torch.Tensor: batch_size = logits.size(0) @@ -368,12 +413,13 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: if not state: continue - if state["in_think"] and state["count"] >= state[ - "max_think_tokens"]: + force_end_token_id = None + if state["in_end"]: + force_end_token_id = 
self.think_end_token_id[state["end_count"]] mask[index] = True if mask.any(): logits[mask] = -float("inf") - logits[mask, end_token_id] = 0.0 + logits[mask, force_end_token_id] = 0.0 return logits diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1db36477a813..85a2ef159e72 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -199,14 +199,14 @@ def __init__( scheduler_config=self.vllm_config.scheduler_config, lora_config=self.vllm_config.lora_config, ).get_lora_tokenizer(None) - reasoning_backend = \ - self.vllm_config.decoding_config.reasoning_backend - reasoner_cls = ReasoningParserManager.get_reasoning_parser( - reasoning_backend) - reasoning_parser = reasoner_cls(tokenizer=tokenizer) - self.vllm_config.reasoning_config = ReasoningConfig( - think_start_token_id=reasoning_parser.think_start_token_id, - think_end_token_id=reasoning_parser.think_end_token_id) + reasoning_config = self.vllm_config.reasoning_config + if reasoning_config is not None: + reasoning_config.think_start_token_id = \ + tokenizer.convert_tokens_to_ids( + tokenizer.tokenize(reasoning_config.think_start_str)) + reasoning_config.think_end_token_id = \ + tokenizer.convert_tokens_to_ids( + tokenizer.tokenize(reasoning_config.think_end_str)) from vllm.model_executor.models.utils import set_cpu_offload_max_bytes set_cpu_offload_max_bytes( From 03d3495c6add61ff5271da95113e3d6ebdc23f14 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Thu, 17 Jul 2025 13:47:22 +0000 Subject: [PATCH 06/61] refactor and change logic faster Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/v1/sample/logits_processor/builtin.py | 90 +++++++++------------- vllm/v1/worker/gpu_model_runner.py | 4 +- 2 files changed, 39 insertions(+), 55 deletions(-) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py 
index 1ad9447d17ba..cbf6e046a46c 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -276,7 +276,7 @@ def process_dict_updates( class MaxThinkTokensLogitsProcessor(LogitsProcessor): - """A logits processor that limits the maximum number of thinking tokens.""" + """Limits the number of tokens allowed inside a 'thinking' section.""" def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device: torch.device): @@ -288,82 +288,68 @@ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device (torch.device): Device to use for tensor operations. """ super().__init__() - self.think_start_token_id = reasoning_config.think_start_token_id - self.think_end_token_id = reasoning_config.think_end_token_id + self.think_start_token_ids = reasoning_config.think_start_token_ids + self.think_end_token_ids = reasoning_config.think_end_token_ids self.pin_memory = pin_memory self.device = device self._state: dict[int, dict[str, Any]] = {} - def _find_first_token_index(self, target_list: list[int], token_id: int) -> int: + @staticmethod + def _find_last_sequence_index(target_list: list[int], token_ids: list[int]) -> int: """ - Find the last occurrence of a single token in the list of tokens. + Returns the index of the last occurrence of token_ids in target_list. Args: target_list (list[int]): The list of token IDs. - token_id (int): The token ID to find. + token_ids (list[int]): The sequence of token IDs to find. """ - try: - return len(target_list) - target_list[::-1].index(token_id) - 1 - except ValueError: + if not token_ids: return -1 - def _find_last_sequence_index(self, target_list: list[int], token_ids: list[int]) -> int: - """ - Find the last occurrence of the sequence of token_ids in tokens. + for i in range(len(target_list) - len(token_ids), -1, -1): + if target_list[i:i + len(token_ids)] == token_ids: + return i + return -1 - Args: - target_list (list[int]): The list of token IDs. 
- token_ids (list[int]): The sequence of token IDs to find. - """ - index = self._find_first_token_index(target_list, token_ids[0]) - if index != -1: - i = 1 - for token_id in token_ids[1:]: - if index + i >= len(target_list) or target_list[index + i] != token_id: - return -1 - i += 1 - index += 1 - - return index - - def _init_state_entry(self, prompt_tok_ids, max_think_tokens): + def _init_state_entry(self, prompt_tok_ids: list[int], max_think_tokens: int) -> dict[str, Any]: + """Initializes the tracking state for a given sequence index.""" last_start = self._find_last_sequence_index( - prompt_tok_ids, self.think_start_token_id) + prompt_tok_ids, self.think_start_token_ids) last_end = self._find_last_sequence_index( - prompt_tok_ids, self.think_end_token_id) + prompt_tok_ids, self.think_end_token_ids) in_think = last_start > last_end think_count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0 return { - "in_think": in_think, - "in_end": False, - "think_count": think_count, - "end_count": 0, + "in_think": in_think, # Currently in thinking mode + "in_end": False, # Currently forcing end tokens + "think_count": think_count, # Number of tokens in thinking section + "end_count": 0, # Number of end tokens forced so far "prompt_tok_ids": prompt_tok_ids, "output_tok_ids": [], "max_think_tokens": max_think_tokens, } - def _update_think_state(self, state): + def _update_think_state(self, state: dict[str, Any]): + """Updates the state based on generated output tokens.""" output = state["output_tok_ids"] if not output: return - sliced_output1 = output[-1 + len(self.think_start_token_id):] - sliced_output2 = output[-1 + len(self.think_end_token_id):] - - if self._find_last_sequence_index(sliced_output1, self.think_start_token_id) != -1: + # Check if recent output matches start or end sequences + if output[-len(self.think_start_token_ids):] == self.think_start_token_ids: state["in_think"] = True state["think_count"] = 0 - elif 
self._find_last_sequence_index(sliced_output2, self.think_end_token_id) != -1: + elif output[-len(self.think_end_token_ids):] == self.think_end_token_ids: state["in_think"] = False state["think_count"] = 0 - else: + elif state["in_think"]: state["think_count"] += 1 + # Transition into end mode if thinking token limit exceeded if state["in_end"]: state["end_count"] += 1 - if state["end_count"] >= len(self.think_end_token_id): + if state["end_count"] >= len(self.think_end_token_ids): state["in_end"] = False state["end_count"] = 0 else: @@ -406,20 +392,18 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: return logits mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device) - end_token_id = self.think_end_token_id - - for index in range(batch_size): - state = self._state.get(index) - if not state: - continue + force_token_ids = torch.full((batch_size,), -1, dtype=torch.long, device=logits.device) - force_end_token_id = None - if state["in_end"]: - force_end_token_id = self.think_end_token_id[state["end_count"]] - mask[index] = True + for i in range(batch_size): + state = self._state.get(i) + if state and state["in_end"]: + mask[i] = True + force_token_ids[i] = self.think_end_token_ids[state["end_count"]] if mask.any(): logits[mask] = -float("inf") - logits[mask, force_end_token_id] = 0.0 + row_indices = torch.arange(batch_size, device=logits.device)[mask] + col_indices = force_token_ids[mask] + logits[row_indices, col_indices] = 0.0 return logits diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 85a2ef159e72..21c6d5b7efce 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -201,10 +201,10 @@ def __init__( ).get_lora_tokenizer(None) reasoning_config = self.vllm_config.reasoning_config if reasoning_config is not None: - reasoning_config.think_start_token_id = \ + reasoning_config.think_start_token_ids = \ tokenizer.convert_tokens_to_ids( 
tokenizer.tokenize(reasoning_config.think_start_str)) - reasoning_config.think_end_token_id = \ + reasoning_config.think_end_token_ids = \ tokenizer.convert_tokens_to_ids( tokenizer.tokenize(reasoning_config.think_end_str)) From 5442d0c37ccbfae504b0040c2bd7808932552790 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Fri, 18 Jul 2025 10:38:23 +0000 Subject: [PATCH 07/61] rename parameter and logit processor Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/entrypoints/openai/protocol.py | 4 +- vllm/sampling_params.py | 8 +- vllm/v1/sample/logits_processor/builtin.py | 99 +++++++++++----------- 3 files changed, 56 insertions(+), 55 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 682d7468a4a5..aa3247c07907 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -488,7 +488,7 @@ class ChatCompletionRequest(OpenAIBaseModel): prompt_logprobs: Optional[int] = None allowed_token_ids: Optional[list[int]] = None bad_words: list[str] = Field(default_factory=list) - max_think_tokens: Optional[int] = None + thinking_token_budget: Optional[int] = None # --8<-- [end:chat-completion-sampling-params] # --8<-- [start:chat-completion-extra-params] @@ -796,7 +796,7 @@ def to_sampling_params( structured_outputs=self.structured_outputs, logit_bias=self.logit_bias, bad_words= self.bad_words, - max_think_tokens=self.max_think_tokens, + thinking_token_budget=self.thinking_token_budget, allowed_token_ids=self.allowed_token_ids, extra_args=extra_args or None, ) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 8a07b7843efb..ce4418b83bc9 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -213,7 +213,7 @@ class SamplingParams( generated token can complete the sequence.""" _bad_words_token_ids: Optional[list[list[int]]] = None - max_think_tokens: Optional[int] = 
None + thinking_token_budget: Optional[int] = None """Maximum number of tokens allowed for thinking operations.""" @staticmethod @@ -231,7 +231,7 @@ def from_optional( stop: Optional[Union[str, list[str]]] = None, stop_token_ids: Optional[list[int]] = None, bad_words: Optional[list[str]] = None, - max_think_tokens: Optional[int] = None, + thinking_token_budget: Optional[int] = None, include_stop_str_in_output: bool = False, ignore_eos: bool = False, max_tokens: Optional[int] = 16, @@ -286,7 +286,7 @@ def from_optional( stop=stop, stop_token_ids=stop_token_ids, bad_words=bad_words, - max_think_tokens=max_think_tokens, + thinking_token_budget=thinking_token_budget, include_stop_str_in_output=include_stop_str_in_output, ignore_eos=ignore_eos, max_tokens=max_tokens, @@ -570,7 +570,7 @@ def __repr__(self) -> str: f"stop={self.stop}, " f"stop_token_ids={self.stop_token_ids}, " f"bad_words={self.bad_words}, " - f"max_think_tokens={self.max_think_tokens}, " + f"thinking_token_budget={self.thinking_token_budget}, " f"include_stop_str_in_output={self.include_stop_str_in_output}, " f"ignore_eos={self.ignore_eos}, " f"max_tokens={self.max_tokens}, " diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index cbf6e046a46c..d03b0496a288 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -233,49 +233,7 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: return logits -def process_dict_updates( - req_entries: dict[int, T], batch_update: Optional[BatchUpdate], - new_state: Callable[[SamplingParams, Optional[list[int]], list[int]], - Optional[T]] -) -> bool: - """Utility function to update dict state for sparse LogitsProcessors.""" - - if not batch_update: - # Nothing to do. 
- return False - - updated = False - for index, params, prompt_tok_ids, output_tok_ids in batch_update.added: - if (state := new_state(params, prompt_tok_ids, - output_tok_ids)) is not None: - req_entries[index] = state - updated = True - elif req_entries.pop(index, None) is not None: - updated = True - - if req_entries: - # Process removed requests. - for index in batch_update.removed: - if req_entries.pop(index, None): - updated = True - - # Process moved requests, unidirectional (a->b) and - # swapped (a<->b) - for a_index, b_index, direct in batch_update.moved: - a_entry = req_entries.pop(a_index, None) - b_entry = req_entries.pop(b_index, None) - if a_entry is not None: - req_entries[b_index] = a_entry - updated = True - if b_entry is not None: - updated = True - if direct == MoveDirectionality.SWAP: - req_entries[a_index] = b_entry - - return updated - - -class MaxThinkTokensLogitsProcessor(LogitsProcessor): +class ThinkingTokenBudgetLogitsProcessor(LogitsProcessor): """Limits the number of tokens allowed inside a 'thinking' section.""" def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, @@ -311,7 +269,7 @@ def _find_last_sequence_index(target_list: list[int], token_ids: list[int]) -> i return i return -1 - def _init_state_entry(self, prompt_tok_ids: list[int], max_think_tokens: int) -> dict[str, Any]: + def _init_state_entry(self, prompt_tok_ids: list[int], thinking_token_budget: int) -> dict[str, Any]: """Initializes the tracking state for a given sequence index.""" last_start = self._find_last_sequence_index( prompt_tok_ids, self.think_start_token_ids) @@ -327,7 +285,7 @@ def _init_state_entry(self, prompt_tok_ids: list[int], max_think_tokens: int) -> "end_count": 0, # Number of end tokens forced so far "prompt_tok_ids": prompt_tok_ids, "output_tok_ids": [], - "max_think_tokens": max_think_tokens, + "thinking_token_budget": thinking_token_budget, } def _update_think_state(self, state: dict[str, Any]): @@ -353,7 +311,7 @@ def 
_update_think_state(self, state: dict[str, Any]): state["in_end"] = False state["end_count"] = 0 else: - if state["in_think"] and state["think_count"] >= state["max_think_tokens"]: + if state["in_think"] and state["think_count"] >= state["thinking_token_budget"]: state["in_think"] = False state["in_end"] = True state["end_count"] = 0 @@ -367,11 +325,11 @@ def is_argmax_invariant(self) -> bool: def update_state(self, batch_update: Optional[BatchUpdate]): if batch_update: for (index, params, prompt_tok_ids, output_tok_ids) in batch_update.added: - max_think_tokens = (params.max_think_tokens if isinstance( + thinking_token_budget = (params.thinking_token_budget if isinstance( params, SamplingParams) else None) - if max_think_tokens is not None: + if thinking_token_budget is not None: self._state[index] = self._init_state_entry( - prompt_tok_ids, max_think_tokens) + prompt_tok_ids, thinking_token_budget) self._state[index]["output_tok_ids"] = output_tok_ids for index in batch_update.removed: @@ -407,3 +365,46 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: logits[row_indices, col_indices] = 0.0 return logits + + +def process_dict_updates( + req_entries: dict[int, T], batch_update: Optional[BatchUpdate], + new_state: Callable[[SamplingParams, Optional[list[int]], list[int]], + Optional[T]] +) -> bool: + """Utility function to update dict state for sparse LogitsProcessors.""" + + if not batch_update: + # Nothing to do. + return False + + updated = False + for index, params, prompt_tok_ids, output_tok_ids in batch_update.added: + if (state := new_state(params, prompt_tok_ids, + output_tok_ids)) is not None: + req_entries[index] = state + updated = True + elif req_entries.pop(index, None) is not None: + updated = True + + if req_entries: + # Process removed requests. 
+ for index in batch_update.removed: + if req_entries.pop(index, None): + updated = True + + # Process moved requests, unidirectional (a->b) and + # swapped (a<->b) + for a_index, b_index, direct in batch_update.moved: + a_entry = req_entries.pop(a_index, None) + b_entry = req_entries.pop(b_index, None) + if a_entry is not None: + req_entries[b_index] = a_entry + updated = True + if b_entry is not None: + updated = True + if direct == MoveDirectionality.SWAP: + req_entries[a_index] = b_entry + + return updated + From 283a07a65b1ccb30db1a3a36a4abcf8fd11dc932 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Fri, 18 Jul 2025 12:39:39 +0000 Subject: [PATCH 08/61] add reasoning effort param Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/entrypoints/openai/protocol.py | 2 ++ vllm/sampling_params.py | 5 +++ vllm/v1/sample/logits_processor/builtin.py | 39 +++++++++++++++++++--- 3 files changed, 42 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index aa3247c07907..0254ceec1439 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -488,6 +488,7 @@ class ChatCompletionRequest(OpenAIBaseModel): prompt_logprobs: Optional[int] = None allowed_token_ids: Optional[list[int]] = None bad_words: list[str] = Field(default_factory=list) + reasoning_effort: Optional[str] = None thinking_token_budget: Optional[int] = None # --8<-- [end:chat-completion-sampling-params] @@ -796,6 +797,7 @@ def to_sampling_params( structured_outputs=self.structured_outputs, logit_bias=self.logit_bias, bad_words= self.bad_words, + reasoning_effort=self.reasoning_effort, thinking_token_budget=self.thinking_token_budget, allowed_token_ids=self.allowed_token_ids, extra_args=extra_args or None, diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index ce4418b83bc9..3d5a591e8ed1 100644 --- 
a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -213,6 +213,8 @@ class SamplingParams( generated token can complete the sequence.""" _bad_words_token_ids: Optional[list[list[int]]] = None + # Fields used for reasoning + reasoning_effort: Optional[str] = None thinking_token_budget: Optional[int] = None """Maximum number of tokens allowed for thinking operations.""" @@ -231,6 +233,7 @@ def from_optional( stop: Optional[Union[str, list[str]]] = None, stop_token_ids: Optional[list[int]] = None, bad_words: Optional[list[str]] = None, + reasoning_effort: Optional[str] = None, thinking_token_budget: Optional[int] = None, include_stop_str_in_output: bool = False, ignore_eos: bool = False, @@ -286,6 +289,7 @@ def from_optional( stop=stop, stop_token_ids=stop_token_ids, bad_words=bad_words, + reasoning_effort=reasoning_effort, thinking_token_budget=thinking_token_budget, include_stop_str_in_output=include_stop_str_in_output, ignore_eos=ignore_eos, @@ -570,6 +574,7 @@ def __repr__(self) -> str: f"stop={self.stop}, " f"stop_token_ids={self.stop_token_ids}, " f"bad_words={self.bad_words}, " + f"reasoning_effort={self.reasoning_effort}, " f"thinking_token_budget={self.thinking_token_budget}, " f"include_stop_str_in_output={self.include_stop_str_in_output}, " f"ignore_eos={self.ignore_eos}, " diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index d03b0496a288..d5f12216ba70 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -246,8 +246,17 @@ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, device (torch.device): Device to use for tensor operations. 
""" super().__init__() - self.think_start_token_ids = reasoning_config.think_start_token_ids - self.think_end_token_ids = reasoning_config.think_end_token_ids + self.reasoning_effort_to_token_budget = { + "low": 1024, + "medium": 2048, + "high": 8192, + } + self.think_start_token_ids = getattr(reasoning_config, "think_start_token_ids", []) + self.think_end_token_ids = getattr(reasoning_config, "think_end_token_ids", []) + self.reasoning_effort_to_token_budget['low'] = getattr(reasoning_config, "low_effort_token_budget", self.reasoning_effort_to_token_budget['low']) + self.reasoning_effort_to_token_budget['medium'] = getattr(reasoning_config, "medium_effort_token_budget", self.reasoning_effort_to_token_budget['medium']) + self.reasoning_effort_to_token_budget['high'] = getattr(reasoning_config, "high_effort_token_budget", self.reasoning_effort_to_token_budget['high']) + self.pin_memory = pin_memory self.device = device self._state: dict[int, dict[str, Any]] = {} @@ -269,6 +278,24 @@ def _find_last_sequence_index(target_list: list[int], token_ids: list[int]) -> i return i return -1 + def _resolve_thinking_token_budget(self, reasoning_effort: Optional[str], thinking_token_budget: Optional[int]) -> int: + """ + Determines the final thinking token budget. + Priority: + 1. If explicit thinking token budget is given, use it. + 2. Otherwise, use reasoning_effort mapping. 
+ """ + if thinking_token_budget is not None: + return thinking_token_budget + + if reasoning_effort is not None: + budget = self.reasoning_effort_to_token_budget.get(reasoning_effort) + if budget is not None: + raise ValueError(f"Unknown reasoning_effort: {reasoning_effort}") + return budget + + return None + def _init_state_entry(self, prompt_tok_ids: list[int], thinking_token_budget: int) -> dict[str, Any]: """Initializes the tracking state for a given sequence index.""" last_start = self._find_last_sequence_index( @@ -325,11 +352,15 @@ def is_argmax_invariant(self) -> bool: def update_state(self, batch_update: Optional[BatchUpdate]): if batch_update: for (index, params, prompt_tok_ids, output_tok_ids) in batch_update.added: + reasoning_effort = (params.reasoning_effort if isinstance( + params, SamplingParams) else None) thinking_token_budget = (params.thinking_token_budget if isinstance( params, SamplingParams) else None) - if thinking_token_budget is not None: + resolved_thinking_token_budget = self._resolve_thinking_token_budget( + reasoning_effort, thinking_token_budget) + if thinking_token_budget is not None or reasoning_effort is not None: self._state[index] = self._init_state_entry( - prompt_tok_ids, thinking_token_budget) + prompt_tok_ids, resolved_thinking_token_budget) self._state[index]["output_tok_ids"] = output_tok_ids for index in batch_update.removed: From 3780d55c4321f8038547062c4dd8a4a2262445b9 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Fri, 18 Jul 2025 12:52:37 +0000 Subject: [PATCH 09/61] remove constraint of the reasoning model Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/v1/worker/gpu_model_runner.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 21c6d5b7efce..e1a14b2c4bd7 100644 --- 
a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -192,21 +192,19 @@ def __init__( self.speculative_config = vllm_config.speculative_config self.observability_config = vllm_config.observability_config - if self.vllm_config.decoding_config.reasoning_backend in ( - 'deepseek_r1', 'qwen'): + reasoning_config = self.vllm_config.reasoning_config + if reasoning_config is not None: tokenizer = init_tokenizer_from_configs( model_config=self.vllm_config.model_config, scheduler_config=self.vllm_config.scheduler_config, lora_config=self.vllm_config.lora_config, ).get_lora_tokenizer(None) - reasoning_config = self.vllm_config.reasoning_config - if reasoning_config is not None: - reasoning_config.think_start_token_ids = \ - tokenizer.convert_tokens_to_ids( - tokenizer.tokenize(reasoning_config.think_start_str)) - reasoning_config.think_end_token_ids = \ - tokenizer.convert_tokens_to_ids( - tokenizer.tokenize(reasoning_config.think_end_str)) + reasoning_config.think_start_token_ids = \ + tokenizer.convert_tokens_to_ids( + tokenizer.tokenize(reasoning_config.think_start_str)) + reasoning_config.think_end_token_ids = \ + tokenizer.convert_tokens_to_ids( + tokenizer.tokenize(reasoning_config.think_end_str)) from vllm.model_executor.models.utils import set_cpu_offload_max_bytes set_cpu_offload_max_bytes( From 7a509fbcf4d6905099b8b172e49c9457a2a273f5 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sat, 19 Jul 2025 08:28:45 +0000 Subject: [PATCH 10/61] update logit processor Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/v1/sample/logits_processor/builtin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index d5f12216ba70..7b15eb48e5bc 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ 
-290,7 +290,7 @@ def _resolve_thinking_token_budget(self, reasoning_effort: Optional[str], thinki if reasoning_effort is not None: budget = self.reasoning_effort_to_token_budget.get(reasoning_effort) - if budget is not None: + if budget is None: raise ValueError(f"Unknown reasoning_effort: {reasoning_effort}") return budget @@ -358,7 +358,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]): params, SamplingParams) else None) resolved_thinking_token_budget = self._resolve_thinking_token_budget( reasoning_effort, thinking_token_budget) - if thinking_token_budget is not None or reasoning_effort is not None: + if resolved_thinking_token_budget is not None: self._state[index] = self._init_state_entry( prompt_tok_ids, resolved_thinking_token_budget) self._state[index]["output_tok_ids"] = output_tok_ids From a44e95691d45b442f484898554d0235049f246f0 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sat, 19 Jul 2025 08:48:44 +0000 Subject: [PATCH 11/61] pass ruff Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/v1/sample/logits_processor/builtin.py | 63 +++++++++++++++------- vllm/v1/worker/gpu_model_runner.py | 1 - 2 files changed, 43 insertions(+), 21 deletions(-) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 7b15eb48e5bc..b697ce01f64b 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -251,18 +251,27 @@ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, "medium": 2048, "high": 8192, } - self.think_start_token_ids = getattr(reasoning_config, "think_start_token_ids", []) - self.think_end_token_ids = getattr(reasoning_config, "think_end_token_ids", []) - self.reasoning_effort_to_token_budget['low'] = getattr(reasoning_config, "low_effort_token_budget", self.reasoning_effort_to_token_budget['low']) - 
self.reasoning_effort_to_token_budget['medium'] = getattr(reasoning_config, "medium_effort_token_budget", self.reasoning_effort_to_token_budget['medium']) - self.reasoning_effort_to_token_budget['high'] = getattr(reasoning_config, "high_effort_token_budget", self.reasoning_effort_to_token_budget['high']) + self.think_start_token_ids = getattr( + reasoning_config, "think_start_token_ids", []) + self.think_end_token_ids = getattr( + reasoning_config, "think_end_token_ids", []) + self.reasoning_effort_to_token_budget['low'] = getattr( + reasoning_config, "low_effort_token_budget", + self.reasoning_effort_to_token_budget['low']) + self.reasoning_effort_to_token_budget['medium'] = getattr( + reasoning_config, "medium_effort_token_budget", + self.reasoning_effort_to_token_budget['medium']) + self.reasoning_effort_to_token_budget['high'] = getattr( + reasoning_config, "high_effort_token_budget", + self.reasoning_effort_to_token_budget['high']) self.pin_memory = pin_memory self.device = device self._state: dict[int, dict[str, Any]] = {} @staticmethod - def _find_last_sequence_index(target_list: list[int], token_ids: list[int]) -> int: + def _find_last_sequence_index( + target_list: list[int], token_ids: list[int]) -> int: """ Returns the index of the last occurrence of token_ids in target_list. @@ -278,7 +287,9 @@ def _find_last_sequence_index(target_list: list[int], token_ids: list[int]) -> i return i return -1 - def _resolve_thinking_token_budget(self, reasoning_effort: Optional[str], thinking_token_budget: Optional[int]) -> int: + def _resolve_thinking_token_budget( + self, reasoning_effort: Optional[str], + thinking_token_budget: Optional[int]) -> int: """ Determines the final thinking token budget. 
Priority: @@ -291,12 +302,15 @@ def _resolve_thinking_token_budget(self, reasoning_effort: Optional[str], thinki if reasoning_effort is not None: budget = self.reasoning_effort_to_token_budget.get(reasoning_effort) if budget is None: - raise ValueError(f"Unknown reasoning_effort: {reasoning_effort}") + raise ValueError( + f"Unknown reasoning_effort: {reasoning_effort}") return budget return None - def _init_state_entry(self, prompt_tok_ids: list[int], thinking_token_budget: int) -> dict[str, Any]: + def _init_state_entry( + self, prompt_tok_ids: list[int], + thinking_token_budget: int) -> dict[str, Any]: """Initializes the tracking state for a given sequence index.""" last_start = self._find_last_sequence_index( prompt_tok_ids, self.think_start_token_ids) @@ -322,10 +336,12 @@ def _update_think_state(self, state: dict[str, Any]): return # Check if recent output matches start or end sequences - if output[-len(self.think_start_token_ids):] == self.think_start_token_ids: + if output[-len(self.think_start_token_ids):] \ + == self.think_start_token_ids: state["in_think"] = True state["think_count"] = 0 - elif output[-len(self.think_end_token_ids):] == self.think_end_token_ids: + elif output[-len(self.think_end_token_ids):] \ + == self.think_end_token_ids: state["in_think"] = False state["think_count"] = 0 elif state["in_think"]: @@ -338,7 +354,8 @@ def _update_think_state(self, state: dict[str, Any]): state["in_end"] = False state["end_count"] = 0 else: - if state["in_think"] and state["think_count"] >= state["thinking_token_budget"]: + if state["in_think"] and state["think_count"] \ + >= state["thinking_token_budget"]: state["in_think"] = False state["in_end"] = True state["end_count"] = 0 @@ -351,13 +368,16 @@ def is_argmax_invariant(self) -> bool: def update_state(self, batch_update: Optional[BatchUpdate]): if batch_update: - for (index, params, prompt_tok_ids, output_tok_ids) in batch_update.added: + for (index, params, prompt_tok_ids, output_tok_ids) \ + in 
batch_update.added: reasoning_effort = (params.reasoning_effort if isinstance( params, SamplingParams) else None) - thinking_token_budget = (params.thinking_token_budget if isinstance( - params, SamplingParams) else None) - resolved_thinking_token_budget = self._resolve_thinking_token_budget( - reasoning_effort, thinking_token_budget) + thinking_token_budget = (params.thinking_token_budget + if isinstance( + params, SamplingParams) else None) + resolved_thinking_token_budget = \ + self._resolve_thinking_token_budget( + reasoning_effort, thinking_token_budget) if resolved_thinking_token_budget is not None: self._state[index] = self._init_state_entry( prompt_tok_ids, resolved_thinking_token_budget) @@ -368,7 +388,8 @@ def update_state(self, batch_update: Optional[BatchUpdate]): for i1, i2, direction in batch_update.moved: if direction == MoveDirectionality.SWAP: - self._state[i1], self._state[i2] = self._state[i2], self._state[i1] + self._state[i1], self._state[i2] = \ + self._state[i2], self._state[i1] else: self._state[i2] = self._state.pop(i1, {}) @@ -381,13 +402,15 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: return logits mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device) - force_token_ids = torch.full((batch_size,), -1, dtype=torch.long, device=logits.device) + force_token_ids = torch.full((batch_size,), -1, + dtype=torch.long, device=logits.device) for i in range(batch_size): state = self._state.get(i) if state and state["in_end"]: mask[i] = True - force_token_ids[i] = self.think_end_token_ids[state["end_count"]] + force_token_ids[i] = \ + self.think_end_token_ids[state["end_count"]] if mask.any(): logits[mask] = -float("inf") diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e1a14b2c4bd7..e8d9ac22fbb4 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -57,7 +57,6 @@ PlaceholderRange) from vllm.multimodal.utils import group_mm_kwargs_by_modality 
from vllm.pooling_params import PoolingParams -from vllm.reasoning import ReasoningParserManager from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask From 0272a72cf31cc4ed3df036ccfd2a1ea2812961e0 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sat, 19 Jul 2025 09:27:09 +0000 Subject: [PATCH 12/61] pass precommit Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/v1/sample/logits_processor/builtin.py | 50 +++++++++++----------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index b697ce01f64b..6ecc08f84988 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -251,10 +251,10 @@ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, "medium": 2048, "high": 8192, } - self.think_start_token_ids = getattr( - reasoning_config, "think_start_token_ids", []) - self.think_end_token_ids = getattr( - reasoning_config, "think_end_token_ids", []) + self.think_start_token_ids = getattr(reasoning_config, + "think_start_token_ids", []) + self.think_end_token_ids = getattr(reasoning_config, + "think_end_token_ids", []) self.reasoning_effort_to_token_budget['low'] = getattr( reasoning_config, "low_effort_token_budget", self.reasoning_effort_to_token_budget['low']) @@ -270,8 +270,8 @@ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, self._state: dict[int, dict[str, Any]] = {} @staticmethod - def _find_last_sequence_index( - target_list: list[int], token_ids: list[int]) -> int: + def _find_last_sequence_index(target_list: list[int], + token_ids: list[int]) -> int: """ Returns the index of the last occurrence of token_ids in target_list. 
@@ -288,8 +288,8 @@ def _find_last_sequence_index( return -1 def _resolve_thinking_token_budget( - self, reasoning_effort: Optional[str], - thinking_token_budget: Optional[int]) -> int: + self, reasoning_effort: Optional[str], + thinking_token_budget: Optional[int]) -> int: """ Determines the final thinking token budget. Priority: @@ -300,7 +300,8 @@ def _resolve_thinking_token_budget( return thinking_token_budget if reasoning_effort is not None: - budget = self.reasoning_effort_to_token_budget.get(reasoning_effort) + budget = self.reasoning_effort_to_token_budget.get( + reasoning_effort) if budget is None: raise ValueError( f"Unknown reasoning_effort: {reasoning_effort}") @@ -308,22 +309,21 @@ def _resolve_thinking_token_budget( return None - def _init_state_entry( - self, prompt_tok_ids: list[int], - thinking_token_budget: int) -> dict[str, Any]: + def _init_state_entry(self, prompt_tok_ids: list[int], + thinking_token_budget: int) -> dict[str, Any]: """Initializes the tracking state for a given sequence index.""" - last_start = self._find_last_sequence_index( - prompt_tok_ids, self.think_start_token_ids) - last_end = self._find_last_sequence_index( - prompt_tok_ids, self.think_end_token_ids) + last_start = self._find_last_sequence_index(prompt_tok_ids, + self.think_start_token_ids) + last_end = self._find_last_sequence_index(prompt_tok_ids, + self.think_end_token_ids) in_think = last_start > last_end think_count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0 return { - "in_think": in_think, # Currently in thinking mode - "in_end": False, # Currently forcing end tokens - "think_count": think_count, # Number of tokens in thinking section - "end_count": 0, # Number of end tokens forced so far + "in_think": in_think, # Currently in thinking mode + "in_end": False, # Currently forcing end tokens + "think_count": think_count, # Number of tokens in thinking section + "end_count": 0, # Number of end tokens forced so far "prompt_tok_ids": prompt_tok_ids, 
"output_tok_ids": [], "thinking_token_budget": thinking_token_budget, @@ -373,8 +373,8 @@ def update_state(self, batch_update: Optional[BatchUpdate]): reasoning_effort = (params.reasoning_effort if isinstance( params, SamplingParams) else None) thinking_token_budget = (params.thinking_token_budget - if isinstance( - params, SamplingParams) else None) + if isinstance(params, SamplingParams) + else None) resolved_thinking_token_budget = \ self._resolve_thinking_token_budget( reasoning_effort, thinking_token_budget) @@ -402,8 +402,10 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: return logits mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device) - force_token_ids = torch.full((batch_size,), -1, - dtype=torch.long, device=logits.device) + force_token_ids = torch.full((batch_size,), + -1, + dtype=torch.long, + device=logits.device) for i in range(batch_size): state = self._state.get(i) From 79c70617c9be6b0f41e09df17b05c2b326bae65c Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sat, 19 Jul 2025 09:35:16 +0000 Subject: [PATCH 13/61] fix format Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/v1/sample/logits_processor/builtin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 6ecc08f84988..19372375b793 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -323,7 +323,7 @@ def _init_state_entry(self, prompt_tok_ids: list[int], "in_think": in_think, # Currently in thinking mode "in_end": False, # Currently forcing end tokens "think_count": think_count, # Number of tokens in thinking section - "end_count": 0, # Number of end tokens forced so far + "end_count": 0, # Number of end tokens forced so far "prompt_tok_ids": prompt_tok_ids, "output_tok_ids": [], "thinking_token_budget": 
thinking_token_budget, @@ -402,7 +402,7 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: return logits mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device) - force_token_ids = torch.full((batch_size,), + force_token_ids = torch.full((batch_size, ), -1, dtype=torch.long, device=logits.device) From 44f2acb5b3d6a4de8497a2ee4d6919df4c87f50b Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Mon, 21 Jul 2025 05:11:48 +0000 Subject: [PATCH 14/61] fix: loads none error Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/engine/arg_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4e808873ada4..ce048901c56d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1434,8 +1434,9 @@ def create_engine_config( collect_detailed_traces=self.collect_detailed_traces, ) - reasoning_config_dict = json.loads(self.reasoning_config) - reasoning_config = ReasoningConfig(**reasoning_config_dict) + if self.reasoning_config is not None: + reasoning_config_dict = json.loads(self.reasoning_config) + reasoning_config = ReasoningConfig(**reasoning_config_dict) config = VllmConfig( model_config=model_config, From 47da3789fe9ffea2897f166aa099ba540093ab1a Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Mon, 21 Jul 2025 05:19:04 +0000 Subject: [PATCH 15/61] fix return type Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/v1/sample/logits_processor/builtin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 19372375b793..0e51f9001156 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -289,7 +289,7 @@ def 
_find_last_sequence_index(target_list: list[int], def _resolve_thinking_token_budget( self, reasoning_effort: Optional[str], - thinking_token_budget: Optional[int]) -> int: + thinking_token_budget: Optional[int]) -> Optional[int]: """ Determines the final thinking token budget. Priority: From 11ac0ef042b9bdcaeba5249d1fbc049dbfbaf652 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Mon, 21 Jul 2025 05:40:52 +0000 Subject: [PATCH 16/61] fix error Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/engine/arg_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ce048901c56d..bfc41dc84625 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1434,6 +1434,7 @@ def create_engine_config( collect_detailed_traces=self.collect_detailed_traces, ) + reasoning_config = None if self.reasoning_config is not None: reasoning_config_dict = json.loads(self.reasoning_config) reasoning_config = ReasoningConfig(**reasoning_config_dict) From 7fe7fe40da504c1df3beffca4c841ae636d628a9 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Mon, 21 Jul 2025 14:37:01 +0000 Subject: [PATCH 17/61] update ReasoningConfig handling Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/engine/arg_utils.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index bfc41dc84625..ead647a7353a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1434,11 +1434,6 @@ def create_engine_config( collect_detailed_traces=self.collect_detailed_traces, ) - reasoning_config = None - if self.reasoning_config is not None: - reasoning_config_dict = json.loads(self.reasoning_config) - reasoning_config = ReasoningConfig(**reasoning_config_dict) - config = 
VllmConfig( model_config=model_config, cache_config=cache_config, @@ -1453,7 +1448,7 @@ def create_engine_config( compilation_config=self.compilation_config, kv_transfer_config=self.kv_transfer_config, kv_events_config=self.kv_events_config, - reasoning_config=reasoning_config, + reasoning_config=self.reasoning_config, additional_config=self.additional_config, ) From 336efe62bea95a1b2692520dbb6e9ce4c904c032 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Mon, 21 Jul 2025 14:59:01 +0000 Subject: [PATCH 18/61] fix config and EngineArgs Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/engine/arg_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index ead647a7353a..08b4c5d510dc 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -456,7 +456,8 @@ class EngineArgs: kv_transfer_config: Optional[KVTransferConfig] = None kv_events_config: Optional[KVEventsConfig] = None - reasoning_config: Optional[ReasoningConfig] = None + reasoning_config: ReasoningConfig = get_field(VllmConfig, + "reasoning_config") generation_config: str = ModelConfig.generation_config enable_sleep_mode: bool = ModelConfig.enable_sleep_mode From 4b64abff75638c8b40bc464b6accd3b9b024e03e Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Tue, 22 Jul 2025 07:44:59 +0000 Subject: [PATCH 19/61] simplify reasoning config checks and fix errors Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/v1/worker/gpu_model_runner.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index e8d9ac22fbb4..05dee9de849e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -198,12 +198,14 @@ def 
__init__( scheduler_config=self.vllm_config.scheduler_config, lora_config=self.vllm_config.lora_config, ).get_lora_tokenizer(None) - reasoning_config.think_start_token_ids = \ - tokenizer.convert_tokens_to_ids( - tokenizer.tokenize(reasoning_config.think_start_str)) - reasoning_config.think_end_token_ids = \ - tokenizer.convert_tokens_to_ids( - tokenizer.tokenize(reasoning_config.think_end_str)) + if reasoning_config.think_start_str is not None: + reasoning_config.think_start_token_ids = \ + tokenizer.convert_tokens_to_ids( + tokenizer.tokenize(reasoning_config.think_start_str)) + if reasoning_config.think_end_str is not None: + reasoning_config.think_end_token_ids = \ + tokenizer.convert_tokens_to_ids( + tokenizer.tokenize(reasoning_config.think_end_str)) from vllm.model_executor.models.utils import set_cpu_offload_max_bytes set_cpu_offload_max_bytes( From ace7c4f168bc099ec23e56907e30e72a672f59aa Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sun, 27 Jul 2025 07:32:55 +0000 Subject: [PATCH 20/61] reafctor ThinkingTokenBudgetLogitsProcessor Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/v1/sample/logits_processor/builtin.py | 63 +++++++++++----------- 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 0e51f9001156..216e82fcc590 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -237,7 +237,7 @@ class ThinkingTokenBudgetLogitsProcessor(LogitsProcessor): """Limits the number of tokens allowed inside a 'thinking' section.""" def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, - device: torch.device): + device: torch.device, max_num_reqs: int): """ Args: reasoning_config: Configuration for reasoning, which includes @@ -269,6 +269,13 @@ def __init__(self, reasoning_config: 
ReasoningConfig, pin_memory: bool, self.device = device self._state: dict[int, dict[str, Any]] = {} + # Preallocate reusable tensors + self.mask = torch.zeros(max_num_reqs, dtype=torch.bool, device=device) + self.force_token_ids = torch.full((max_num_reqs, ), + -1, + dtype=torch.long, + device=device) + @staticmethod def _find_last_sequence_index(target_list: list[int], token_ids: list[int]) -> int: @@ -281,7 +288,6 @@ def _find_last_sequence_index(target_list: list[int], """ if not token_ids: return -1 - for i in range(len(target_list) - len(token_ids), -1, -1): if target_list[i:i + len(token_ids)] == token_ids: return i @@ -296,18 +302,16 @@ def _resolve_thinking_token_budget( 1. If explicit thinking token budget is given, use it. 2. Otherwise, use reasoning_effort mapping. """ + budget = None if thinking_token_budget is not None: - return thinking_token_budget - - if reasoning_effort is not None: + budget = thinking_token_budget + elif reasoning_effort is not None: budget = self.reasoning_effort_to_token_budget.get( reasoning_effort) if budget is None: raise ValueError( f"Unknown reasoning_effort: {reasoning_effort}") - return budget - - return None + return budget def _init_state_entry(self, prompt_tok_ids: list[int], thinking_token_budget: int) -> dict[str, Any]: @@ -338,12 +342,10 @@ def _update_think_state(self, state: dict[str, Any]): # Check if recent output matches start or end sequences if output[-len(self.think_start_token_ids):] \ == self.think_start_token_ids: - state["in_think"] = True - state["think_count"] = 0 + state.update({"in_think": True, "think_count": 0}) elif output[-len(self.think_end_token_ids):] \ == self.think_end_token_ids: - state["in_think"] = False - state["think_count"] = 0 + state.update({"in_think": False, "think_count": 0}) elif state["in_think"]: state["think_count"] += 1 @@ -351,14 +353,15 @@ def _update_think_state(self, state: dict[str, Any]): if state["in_end"]: state["end_count"] += 1 if state["end_count"] >= 
len(self.think_end_token_ids): - state["in_end"] = False - state["end_count"] = 0 + state.update({"in_end": False, "end_count": 0}) else: if state["in_think"] and state["think_count"] \ >= state["thinking_token_budget"]: - state["in_think"] = False - state["in_end"] = True - state["end_count"] = 0 + state.update({ + "in_think": False, + "in_end": True, + "end_count": 0 + }) def is_argmax_invariant(self) -> bool: """This logits processor can change the outcome of @@ -397,28 +400,24 @@ def update_state(self, batch_update: Optional[BatchUpdate]): self._update_think_state(state) def apply(self, logits: torch.Tensor) -> torch.Tensor: - batch_size = logits.size(0) if not self._state: return logits - mask = torch.zeros(batch_size, dtype=torch.bool, device=logits.device) - force_token_ids = torch.full((batch_size, ), - -1, - dtype=torch.long, - device=logits.device) + batch_size = logits.size(0) + self.mask[:batch_size] = False for i in range(batch_size): state = self._state.get(i) if state and state["in_end"]: - mask[i] = True - force_token_ids[i] = \ - self.think_end_token_ids[state["end_count"]] - - if mask.any(): - logits[mask] = -float("inf") - row_indices = torch.arange(batch_size, device=logits.device)[mask] - col_indices = force_token_ids[mask] - logits[row_indices, col_indices] = 0.0 + self.mask[i] = True + self.force_token_ids[i] = \ + self.think_end_token_ids[state["end_count"]] + + current_mask = self.mask[:batch_size] + if current_mask.any(): + logits[current_mask] = -float("inf") + logits[current_mask, + self.force_token_ids[:batch_size][current_mask]] = 0.0 return logits From 43dd44082aa6c236fa5d1bff770a63ee20545961 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sun, 27 Jul 2025 08:11:16 +0000 Subject: [PATCH 21/61] fix import error from rebase Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/v1/worker/gpu_model_runner.py | 1 + 1 file changed, 1 insertion(+) 
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 05dee9de849e..9c7ae5c8eaf5 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -60,6 +60,7 @@ from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, GiB_bytes, cdiv, check_use_alibi, get_dtype_size, is_pin_memory_available, From 9ee7f2f925d8898135fe0972f6e47fd77e4ad4f3 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sat, 16 Aug 2025 10:16:52 +0000 Subject: [PATCH 22/61] fix: remove duplicate reasoning_effort field in ChatCompletionRequest Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/entrypoints/openai/protocol.py | 3 +-- vllm/sampling_params.py | 6 +++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 0254ceec1439..1b40270f5918 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -465,6 +465,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ChatCompletionNamedToolChoiceParam, ]] = "none" reasoning_effort: Optional[Literal["low", "medium", "high"]] = None + thinking_token_budget: Optional[int] = None include_reasoning: bool = True # NOTE this will be ignored by vLLM -- the model determines the behavior @@ -488,8 +489,6 @@ class ChatCompletionRequest(OpenAIBaseModel): prompt_logprobs: Optional[int] = None allowed_token_ids: Optional[list[int]] = None bad_words: list[str] = Field(default_factory=list) - reasoning_effort: Optional[str] = None - thinking_token_budget: Optional[int] = None # --8<-- [end:chat-completion-sampling-params] # --8<-- 
[start:chat-completion-extra-params] diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 3d5a591e8ed1..ac0885da31e9 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -6,7 +6,7 @@ from dataclasses import field from enum import Enum, IntEnum from functools import cached_property -from typing import Annotated, Any, Optional, Union +from typing import Annotated, Any, Literal, Optional, Union import msgspec from pydantic.dataclasses import dataclass @@ -214,7 +214,7 @@ class SamplingParams( _bad_words_token_ids: Optional[list[list[int]]] = None # Fields used for reasoning - reasoning_effort: Optional[str] = None + reasoning_effort: Optional[Literal["low", "medium", "high"]] = None thinking_token_budget: Optional[int] = None """Maximum number of tokens allowed for thinking operations.""" @@ -233,7 +233,7 @@ def from_optional( stop: Optional[Union[str, list[str]]] = None, stop_token_ids: Optional[list[int]] = None, bad_words: Optional[list[str]] = None, - reasoning_effort: Optional[str] = None, + reasoning_effort: Optional[Literal["low", "medium", "high"]] = None, thinking_token_budget: Optional[int] = None, include_stop_str_in_output: bool = False, ignore_eos: bool = False, From 117ca92167cb796e7ecee41df19ff2d099773416 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sun, 17 Aug 2025 07:39:59 +0000 Subject: [PATCH 23/61] fix runtime error after rebase Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/v1/sample/logits_processor/__init__.py | 4 +++- vllm/v1/sample/logits_processor/builtin.py | 18 ++++++++---------- vllm/v1/worker/gpu_model_runner.py | 4 ++-- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index 10cad5b53071..08d97cf9468f 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ 
b/vllm/v1/sample/logits_processor/__init__.py @@ -16,6 +16,7 @@ from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor, MinPLogitsProcessor, MinTokensLogitsProcessor, + ThinkingTokenBudgetLogitsProcessor, process_dict_updates) from vllm.v1.sample.logits_processor.interface import (BatchUpdate, LogitsProcessor, @@ -39,6 +40,7 @@ MinTokensLogitsProcessor, LogitBiasLogitsProcessor, MinPLogitsProcessor, + ThinkingTokenBudgetLogitsProcessor, ] @@ -290,5 +292,5 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: "MinTokensLogitsProcessor", "BatchUpdate", "BatchUpdateBuilder", "MoveDirectionality", "LogitsProcessors", "build_logitsprocs", "STR_POOLING_REJECTS_LOGITSPROCS", "LOGITSPROCS_GROUP", - "AdapterLogitsProcessor" + "AdapterLogitsProcessor", "ThinkingTokenBudgetLogitsProcessor" ] diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 216e82fcc590..16cd9cfa9032 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Sequence -from typing import TYPE_CHECKING, Callable, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar import torch @@ -236,8 +236,8 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: class ThinkingTokenBudgetLogitsProcessor(LogitsProcessor): """Limits the number of tokens allowed inside a 'thinking' section.""" - def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, - device: torch.device, max_num_reqs: int): + def __init__(self, vllm_config: "VllmConfig", device: torch.device, + is_pin_memory: bool): """ Args: reasoning_config: Configuration for reasoning, which includes @@ -245,7 +245,8 @@ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, pin_memory (bool): Whether to use pinned memory for 
tensors. device (torch.device): Device to use for tensor operations. """ - super().__init__() + reasoning_config = vllm_config.reasoning_config + max_num_reqs = vllm_config.scheduler_config.max_num_seqs self.reasoning_effort_to_token_budget = { "low": 1024, "medium": 2048, @@ -265,7 +266,7 @@ def __init__(self, reasoning_config: ReasoningConfig, pin_memory: bool, reasoning_config, "high_effort_token_budget", self.reasoning_effort_to_token_budget['high']) - self.pin_memory = pin_memory + self.pin_memory = is_pin_memory self.device = device self._state: dict[int, dict[str, Any]] = {} @@ -373,11 +374,8 @@ def update_state(self, batch_update: Optional[BatchUpdate]): if batch_update: for (index, params, prompt_tok_ids, output_tok_ids) \ in batch_update.added: - reasoning_effort = (params.reasoning_effort if isinstance( - params, SamplingParams) else None) - thinking_token_budget = (params.thinking_token_budget - if isinstance(params, SamplingParams) - else None) + reasoning_effort = params.reasoning_effort + thinking_token_budget = params.thinking_token_budget resolved_thinking_token_budget = \ self._resolve_thinking_token_budget( reasoning_effort, thinking_token_budget) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9c7ae5c8eaf5..f9dfb3064570 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -24,8 +24,8 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.cuda_graph import CUDAGraphWrapper from vllm.compilation.monitor import set_cudagraph_capturing_enabled -from vllm.config import (CompilationLevel, CUDAGraphMode, ReasoningConfig, - VllmConfig, get_layers_from_vllm_config, update_config) +from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig, + get_layers_from_vllm_config, update_config) from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) From 
60a275f46080a1be8a7b92d0c1dac474778dbd06 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Mon, 18 Aug 2025 13:06:48 +0000 Subject: [PATCH 24/61] check reasoning is enabled Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/v1/sample/logits_processor/builtin.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 16cd9cfa9032..1cabbd50e3ba 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -247,6 +247,11 @@ def __init__(self, vllm_config: "VllmConfig", device: torch.device, """ reasoning_config = vllm_config.reasoning_config max_num_reqs = vllm_config.scheduler_config.max_num_seqs + + # Check if thinking is enabled + self.is_enabled = (reasoning_config is not None + and reasoning_config.is_thinking_enabled()) + self.reasoning_effort_to_token_budget = { "low": 1024, "medium": 2048, @@ -371,6 +376,8 @@ def is_argmax_invariant(self) -> bool: return False def update_state(self, batch_update: Optional[BatchUpdate]): + if not self.is_enabled: + return if batch_update: for (index, params, prompt_tok_ids, output_tok_ids) \ in batch_update.added: @@ -398,7 +405,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]): self._update_think_state(state) def apply(self, logits: torch.Tensor) -> torch.Tensor: - if not self._state: + if not self.is_enabled or not self._state: return logits batch_size = logits.size(0) From f4afba91e155aefc5cc93d3bd244f121a3e7da03 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Tue, 19 Aug 2025 13:33:54 +0000 Subject: [PATCH 25/61] add test and implement processor with incremental token processing optimization Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- 
.../v1/logits_processors/test_correctness.py | 192 ++++++++++++++++-- vllm/v1/sample/logits_processor/builtin.py | 60 ++++-- 2 files changed, 224 insertions(+), 28 deletions(-) diff --git a/tests/v1/logits_processors/test_correctness.py b/tests/v1/logits_processors/test_correctness.py index 43caef79b02f..6c61571a5ee6 100644 --- a/tests/v1/logits_processors/test_correctness.py +++ b/tests/v1/logits_processors/test_correctness.py @@ -20,13 +20,10 @@ from vllm.sampling_params import SamplingParams from vllm.utils import is_pin_memory_available # yapf: disable -from vllm.v1.sample.logits_processor import (BatchUpdate, BatchUpdateBuilder, - LogitBiasLogitsProcessor, - LogitsProcessor, - MinPLogitsProcessor, - MinTokensLogitsProcessor, - MoveDirectionality, - build_logitsprocs) +from vllm.v1.sample.logits_processor import ( + BatchUpdate, BatchUpdateBuilder, LogitBiasLogitsProcessor, LogitsProcessor, + MinPLogitsProcessor, MinTokensLogitsProcessor, MoveDirectionality, + ThinkingTokenBudgetLogitsProcessor, build_logitsprocs) # yapf: enable from vllm.v1.sample.metadata import SamplingMetadata @@ -43,6 +40,14 @@ REQS_PER_LOGITPROC = 50 STR_NO_LOGITPROC = "none" +# ThinkingTokenBudgetLogitsProcessor testing constants +REASONING_EFFORT = "low" +THINK_START_TOKEN_ID = 999 +THINK_END_TOKEN_ID = 998 +LOW_EFFORT_TOKEN_BUDGET = 5 +MEDIUM_EFFORT_TOKEN_BUDGET = 10 +HIGH_EFFORT_TOKEN_BUDGET = 20 + # LogitsProcessor subclass or "none" LogitprocType = Union[type[LogitsProcessor], str] @@ -62,10 +67,24 @@ def __init__(self, workload_index: int, logitproc_type: LogitprocType): self.workload_index = workload_index self.logitproc_type = logitproc_type # Number of output tokens is randomly 0 or twice the min-tokens - # threshold which will be used in testing. Output token values - # don't matter *for these tests* so use 0 as a dummy value - self.out_tokens = ([0] * - (MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2))) + # threshold which will be used in testing. 
+ # Generate diverse random tokens for all processors (more realistic) + num_tokens = MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2) + if num_tokens > 0: + # Use diverse random tokens + self.out_tokens = [ + random.randint(1, 950) for _ in range(num_tokens) + ] + # Set first token for ThinkingTokenBudget testing + is_thinking_processor = (logitproc_type + is ThinkingTokenBudgetLogitsProcessor or + (hasattr(logitproc_type, '__name__') + and logitproc_type.__name__ + == 'ThinkingTokenBudgetLogitsProcessor')) + if is_thinking_processor: + self.out_tokens[0] = THINK_START_TOKEN_ID + else: + self.out_tokens = [] self.prompt_tokens = [] self.params = _sampling_params_from_logitproc(logitproc_type) @@ -75,6 +94,18 @@ def __str__(self): return f"MyClass({summ})" +class MockReasoningConfig: + """Mock reasoning config for testing ThinkingTokenBudgetLogitsProcessor.""" + think_start_token_ids = [THINK_START_TOKEN_ID] + think_end_token_ids = [THINK_END_TOKEN_ID] + low_effort_token_budget = LOW_EFFORT_TOKEN_BUDGET + medium_effort_token_budget = MEDIUM_EFFORT_TOKEN_BUDGET + high_effort_token_budget = HIGH_EFFORT_TOKEN_BUDGET + + def is_thinking_enabled(self) -> bool: + return True + + def _generate_fake_sampling_metadata( num_output_tokens: int, batch_size: int, @@ -92,8 +123,12 @@ def _generate_fake_sampling_metadata( vocab_size, size=np.random.randint( 1, MAX_NUM_PROMPT_TOKENS)).tolist()) + + vllm_config = VllmConfig() + vllm_config.reasoning_config = MockReasoningConfig() + logitsprocs = build_logitsprocs( - vllm_config=VllmConfig(), + vllm_config=vllm_config, device=device, is_pin_memory=PIN_MEMORY_AVAILABLE, is_pooling_model=False, @@ -368,6 +403,120 @@ def _min_tokens_validate( step_idx=step_idx) +def _thinking_budget_params(kwargs: dict) -> None: + """Set SamplingParams kwargs for thinking token budget tests""" + kwargs["reasoning_effort"] = REASONING_EFFORT + + +def _thinking_budget_validate( + test_fakes: LogitsprocsTestFakes, + persistent_batch: 
list[LogitsProcsRequestParams], + logits_new: torch.Tensor, + batch_index: int, + request_params: LogitsProcsRequestParams, + step_idx: int, +) -> None: + """Validate thinking token budget processor behavior""" + # Get the ThinkingTokenBudgetLogitsProcessor instance + tb_processor: ThinkingTokenBudgetLogitsProcessor = next( + test_fakes.get_logitsprocs_by_cls(ThinkingTokenBudgetLogitsProcessor)) + + # Get current request state + state = tb_processor._state.get(batch_index) + params = request_params.params + + # Validate reasoning effort configuration + if hasattr(params, 'reasoning_effort') and params.reasoning_effort: + # State should exist for requests with reasoning_effort + if state is None: + _raise_error_invalid(msg_suffix=( + f"Expected state for batch {batch_index} " + f"with reasoning_effort={params.reasoning_effort}"), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx) + + # Validate budget calculation + expected_budget = { + "low": LOW_EFFORT_TOKEN_BUDGET, + "medium": MEDIUM_EFFORT_TOKEN_BUDGET, + "high": HIGH_EFFORT_TOKEN_BUDGET + }.get(params.reasoning_effort, LOW_EFFORT_TOKEN_BUDGET) + + actual_budget = state["thinking_token_budget"] + + if actual_budget != expected_budget: + _raise_error_invalid( + msg_suffix=(f"Budget mismatch: expected {expected_budget}, " + f"got {actual_budget} " + f"for effort {params.reasoning_effort}"), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx) + + # Check if we're in thinking mode and validate token counting + output_tokens = request_params.out_tokens + + # Find if thinking has started in output tokens + thinking_started = False + start_tokens = tb_processor.think_start_token_ids + + if len(start_tokens) > 0: + for i in range(len(output_tokens) - len(start_tokens) + 1): + if output_tokens[i:i + len(start_tokens)] == start_tokens: + thinking_started = True + break + + if thinking_started: + # If budget is exceeded, validate end token forcing + think_count = 
state["think_count"] + budget = state["thinking_token_budget"] + + if think_count >= budget: + if not state["in_end"]: + _raise_error_invalid( + msg_suffix=(f"Budget exceeded ({think_count} >= " + f"{budget}) but not " + "forcing end tokens"), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx) + + # Validate that only end tokens are allowed + end_tokens = tb_processor.think_end_token_ids + if len(end_tokens) > 0: + expected_end_token_id = end_tokens[min( + state["end_count"], + len(end_tokens) - 1)] + + # Check logits masking + batch_logits = logits_new[batch_index] + for token_id in range(len(batch_logits)): + logit_value = batch_logits[token_id] + + if token_id == expected_end_token_id: + # End token should not be masked + if logit_value == -float("inf"): + _raise_error_invalid( + msg_suffix=( + f"End token {token_id} should not be " + "masked but is"), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx) + else: + # All other tokens should be masked when forcing end + if logit_value != -float("inf"): + _raise_error_invalid( + msg_suffix=( + f"Token {token_id} should be masked " + f"when forcing end tokens, but " + f"logit={logit_value}"), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx) + + def _none_validate( test_fakes: LogitsprocsTestFakes, persistent_batch: list[LogitsProcsRequestParams], @@ -413,16 +562,27 @@ class LogitsprocTestHelpers(NamedTuple): MinTokensLogitsProcessor: LogitsprocTestHelpers(gen_request_fxn=_min_tokens_params, eval_fxn=_min_tokens_validate), + ThinkingTokenBudgetLogitsProcessor: + LogitsprocTestHelpers(gen_request_fxn=_thinking_budget_params, + eval_fxn=_thinking_budget_validate), } def _get_test_cases() -> list[list[str]]: """Each test case is a set of logitsprocs""" logitsprocs_types = list(logitsprocs_test_mapping.keys()) - return [[STR_NO_LOGITPROC]] + [[logitproc_type, STR_NO_LOGITPROC] - for logitproc_type in logitsprocs_types - if 
logitproc_type != STR_NO_LOGITPROC - ] + [logitsprocs_types] + + # Isolate ThinkingTokenBudgetLogitsProcessor from all other processors + # to avoid unexpected modification of logits interference + thinking_processor = ThinkingTokenBudgetLogitsProcessor + other_processors = [ + p for p in logitsprocs_types + if p != STR_NO_LOGITPROC and p != thinking_processor + ] + + return ([[STR_NO_LOGITPROC]] + [[logitproc_type, STR_NO_LOGITPROC] + for logitproc_type in other_processors] + + [other_processors] + [[thinking_processor]]) def _generate_fake_step_update( diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 1cabbd50e3ba..aaa95ad03baa 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -337,23 +337,56 @@ def _init_state_entry(self, prompt_tok_ids: list[int], "prompt_tok_ids": prompt_tok_ids, "output_tok_ids": [], "thinking_token_budget": thinking_token_budget, + "prev_output_length": + 0, # Track previous output length for incremental updates } def _update_think_state(self, state: dict[str, Any]): - """Updates the state based on generated output tokens.""" - output = state["output_tok_ids"] + """Updates the state based on newly generated output tokens.""" + output = state.get("output_tok_ids", []) if not output: return - # Check if recent output matches start or end sequences - if output[-len(self.think_start_token_ids):] \ - == self.think_start_token_ids: - state.update({"in_think": True, "think_count": 0}) - elif output[-len(self.think_end_token_ids):] \ - == self.think_end_token_ids: - state.update({"in_think": False, "think_count": 0}) + # Track previous output length for incremental processing + prev_length = state.get("prev_output_length", 0) + current_length = len(output) + + if current_length <= prev_length: + return + + # Process only newly added tokens + new_tokens = output[prev_length:] + state["prev_output_length"] = current_length + + # Check if 
new tokens contain think start or end sequences + start_len = len(self.think_start_token_ids) + end_len = len(self.think_end_token_ids) + + # Look for think sequences in recent tokens (including boundary) + # Check overlapping regions where sequences might span boundaries + check_start_idx = max(0, prev_length - max(start_len, end_len) + 1) + recent_tokens = output[check_start_idx:] + + # Find any think start/end sequences in recent tokens + recent_start_pos = self._find_last_sequence_index( + recent_tokens, self.think_start_token_ids) + recent_end_pos = self._find_last_sequence_index( + recent_tokens, self.think_end_token_ids) + + # Update state based on recent sequences + if recent_start_pos >= 0: + # Found think start in recent tokens + absolute_start_pos = check_start_idx + recent_start_pos + state["in_think"] = True + state["think_count"] = current_length - (absolute_start_pos + + start_len) + elif recent_end_pos >= 0: + # Found think end in recent tokens + state["in_think"] = False + state["think_count"] = 0 elif state["in_think"]: - state["think_count"] += 1 + # Continue thinking mode, increment count by new tokens + state["think_count"] += len(new_tokens) # Transition into end mode if thinking token limit exceeded if state["in_end"]: @@ -396,8 +429,11 @@ def update_state(self, batch_update: Optional[BatchUpdate]): for i1, i2, direction in batch_update.moved: if direction == MoveDirectionality.SWAP: - self._state[i1], self._state[i2] = \ - self._state[i2], self._state[i1] + state1 = self._state.get(i1, {}) + state2 = self._state.get(i2, {}) + if state1 or state2: + self._state[i1] = state2 + self._state[i2] = state1 else: self._state[i2] = self._state.pop(i1, {}) From 937112003f89160de24b53eaf97a92b80cd1d3e1 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Wed, 20 Aug 2025 11:32:27 +0000 Subject: [PATCH 26/61] remove connection between reasoning_effort and thinking_token_budget Signed-off-by: Sungjae Lee 
<33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- .../v1/logits_processors/test_correctness.py | 31 +++++-------- vllm/entrypoints/openai/protocol.py | 1 - vllm/sampling_params.py | 7 +-- vllm/v1/sample/logits_processor/builtin.py | 43 ++----------------- 4 files changed, 14 insertions(+), 68 deletions(-) diff --git a/tests/v1/logits_processors/test_correctness.py b/tests/v1/logits_processors/test_correctness.py index 6c61571a5ee6..788ab872b0d3 100644 --- a/tests/v1/logits_processors/test_correctness.py +++ b/tests/v1/logits_processors/test_correctness.py @@ -41,12 +41,9 @@ STR_NO_LOGITPROC = "none" # ThinkingTokenBudgetLogitsProcessor testing constants -REASONING_EFFORT = "low" +THINKING_TOKEN_BUDGET = 5 THINK_START_TOKEN_ID = 999 THINK_END_TOKEN_ID = 998 -LOW_EFFORT_TOKEN_BUDGET = 5 -MEDIUM_EFFORT_TOKEN_BUDGET = 10 -HIGH_EFFORT_TOKEN_BUDGET = 20 # LogitsProcessor subclass or "none" LogitprocType = Union[type[LogitsProcessor], str] @@ -98,9 +95,6 @@ class MockReasoningConfig: """Mock reasoning config for testing ThinkingTokenBudgetLogitsProcessor.""" think_start_token_ids = [THINK_START_TOKEN_ID] think_end_token_ids = [THINK_END_TOKEN_ID] - low_effort_token_budget = LOW_EFFORT_TOKEN_BUDGET - medium_effort_token_budget = MEDIUM_EFFORT_TOKEN_BUDGET - high_effort_token_budget = HIGH_EFFORT_TOKEN_BUDGET def is_thinking_enabled(self) -> bool: return True @@ -405,7 +399,7 @@ def _min_tokens_validate( def _thinking_budget_params(kwargs: dict) -> None: """Set SamplingParams kwargs for thinking token budget tests""" - kwargs["reasoning_effort"] = REASONING_EFFORT + kwargs["thinking_token_budget"] = THINKING_TOKEN_BUDGET def _thinking_budget_validate( @@ -425,31 +419,26 @@ def _thinking_budget_validate( state = tb_processor._state.get(batch_index) params = request_params.params - # Validate reasoning effort configuration - if hasattr(params, 'reasoning_effort') and params.reasoning_effort: - # State should exist for requests with reasoning_effort + # 
Validate thinking token budget configuration + if hasattr(params, + 'thinking_token_budget') and params.thinking_token_budget: + # State should exist for requests with thinking_token_budget if state is None: _raise_error_invalid(msg_suffix=( f"Expected state for batch {batch_index} " - f"with reasoning_effort={params.reasoning_effort}"), + f"with thinking_token_budget={params.thinking_token_budget}"), batch_index=batch_index, request_params=request_params, step_idx=step_idx) - # Validate budget calculation - expected_budget = { - "low": LOW_EFFORT_TOKEN_BUDGET, - "medium": MEDIUM_EFFORT_TOKEN_BUDGET, - "high": HIGH_EFFORT_TOKEN_BUDGET - }.get(params.reasoning_effort, LOW_EFFORT_TOKEN_BUDGET) - + # Validate budget matches what was set + expected_budget = params.thinking_token_budget actual_budget = state["thinking_token_budget"] if actual_budget != expected_budget: _raise_error_invalid( msg_suffix=(f"Budget mismatch: expected {expected_budget}, " - f"got {actual_budget} " - f"for effort {params.reasoning_effort}"), + f"got {actual_budget}"), batch_index=batch_index, request_params=request_params, step_idx=step_idx) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 1b40270f5918..83c031019487 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -796,7 +796,6 @@ def to_sampling_params( structured_outputs=self.structured_outputs, logit_bias=self.logit_bias, bad_words= self.bad_words, - reasoning_effort=self.reasoning_effort, thinking_token_budget=self.thinking_token_budget, allowed_token_ids=self.allowed_token_ids, extra_args=extra_args or None, diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index ac0885da31e9..ce4418b83bc9 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -6,7 +6,7 @@ from dataclasses import field from enum import Enum, IntEnum from functools import cached_property -from typing import Annotated, Any, Literal, Optional, Union +from 
typing import Annotated, Any, Optional, Union import msgspec from pydantic.dataclasses import dataclass @@ -213,8 +213,6 @@ class SamplingParams( generated token can complete the sequence.""" _bad_words_token_ids: Optional[list[list[int]]] = None - # Fields used for reasoning - reasoning_effort: Optional[Literal["low", "medium", "high"]] = None thinking_token_budget: Optional[int] = None """Maximum number of tokens allowed for thinking operations.""" @@ -233,7 +231,6 @@ def from_optional( stop: Optional[Union[str, list[str]]] = None, stop_token_ids: Optional[list[int]] = None, bad_words: Optional[list[str]] = None, - reasoning_effort: Optional[Literal["low", "medium", "high"]] = None, thinking_token_budget: Optional[int] = None, include_stop_str_in_output: bool = False, ignore_eos: bool = False, @@ -289,7 +286,6 @@ def from_optional( stop=stop, stop_token_ids=stop_token_ids, bad_words=bad_words, - reasoning_effort=reasoning_effort, thinking_token_budget=thinking_token_budget, include_stop_str_in_output=include_stop_str_in_output, ignore_eos=ignore_eos, @@ -574,7 +570,6 @@ def __repr__(self) -> str: f"stop={self.stop}, " f"stop_token_ids={self.stop_token_ids}, " f"bad_words={self.bad_words}, " - f"reasoning_effort={self.reasoning_effort}, " f"thinking_token_budget={self.thinking_token_budget}, " f"include_stop_str_in_output={self.include_stop_str_in_output}, " f"ignore_eos={self.ignore_eos}, " diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index aaa95ad03baa..4a2e9035c80c 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -252,24 +252,10 @@ def __init__(self, vllm_config: "VllmConfig", device: torch.device, self.is_enabled = (reasoning_config is not None and reasoning_config.is_thinking_enabled()) - self.reasoning_effort_to_token_budget = { - "low": 1024, - "medium": 2048, - "high": 8192, - } self.think_start_token_ids = getattr(reasoning_config, 
"think_start_token_ids", []) self.think_end_token_ids = getattr(reasoning_config, "think_end_token_ids", []) - self.reasoning_effort_to_token_budget['low'] = getattr( - reasoning_config, "low_effort_token_budget", - self.reasoning_effort_to_token_budget['low']) - self.reasoning_effort_to_token_budget['medium'] = getattr( - reasoning_config, "medium_effort_token_budget", - self.reasoning_effort_to_token_budget['medium']) - self.reasoning_effort_to_token_budget['high'] = getattr( - reasoning_config, "high_effort_token_budget", - self.reasoning_effort_to_token_budget['high']) self.pin_memory = is_pin_memory self.device = device @@ -299,26 +285,6 @@ def _find_last_sequence_index(target_list: list[int], return i return -1 - def _resolve_thinking_token_budget( - self, reasoning_effort: Optional[str], - thinking_token_budget: Optional[int]) -> Optional[int]: - """ - Determines the final thinking token budget. - Priority: - 1. If explicit thinking token budget is given, use it. - 2. Otherwise, use reasoning_effort mapping. 
- """ - budget = None - if thinking_token_budget is not None: - budget = thinking_token_budget - elif reasoning_effort is not None: - budget = self.reasoning_effort_to_token_budget.get( - reasoning_effort) - if budget is None: - raise ValueError( - f"Unknown reasoning_effort: {reasoning_effort}") - return budget - def _init_state_entry(self, prompt_tok_ids: list[int], thinking_token_budget: int) -> dict[str, Any]: """Initializes the tracking state for a given sequence index.""" @@ -414,14 +380,11 @@ def update_state(self, batch_update: Optional[BatchUpdate]): if batch_update: for (index, params, prompt_tok_ids, output_tok_ids) \ in batch_update.added: - reasoning_effort = params.reasoning_effort thinking_token_budget = params.thinking_token_budget - resolved_thinking_token_budget = \ - self._resolve_thinking_token_budget( - reasoning_effort, thinking_token_budget) - if resolved_thinking_token_budget is not None: + + if thinking_token_budget is not None: self._state[index] = self._init_state_entry( - prompt_tok_ids, resolved_thinking_token_budget) + prompt_tok_ids, thinking_token_budget) self._state[index]["output_tok_ids"] = output_tok_ids for index in batch_update.removed: From 4b9b87d826f3a994e246868a91ae743423c93767 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sat, 23 Aug 2025 05:19:52 +0000 Subject: [PATCH 27/61] fix: support corner cases Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/v1/sample/logits_processor/builtin.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 4a2e9035c80c..cde22dbe3276 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -293,7 +293,11 @@ def _init_state_entry(self, prompt_tok_ids: list[int], last_end = 
self._find_last_sequence_index(prompt_tok_ids, self.think_end_token_ids) in_think = last_start > last_end - think_count = len(prompt_tok_ids) - (last_start + 1) if in_think else 0 + if in_think: + think_count = len(prompt_tok_ids) - ( + last_start + len(self.think_start_token_ids)) + else: + think_count = 0 return { "in_think": in_think, # Currently in thinking mode @@ -340,7 +344,18 @@ def _update_think_state(self, state: dict[str, Any]): recent_tokens, self.think_end_token_ids) # Update state based on recent sequences - if recent_start_pos >= 0: + if recent_start_pos >= 0 and recent_end_pos >= 0: + if recent_start_pos > recent_end_pos: + # Case: ......... + absolute_start_pos = check_start_idx + recent_start_pos + state["in_think"] = True + state["think_count"] = current_length - (absolute_start_pos + + start_len) + else: + # Case: ......... + state["in_think"] = False + state["think_count"] = 0 + elif recent_start_pos >= 0: # Found think start in recent tokens absolute_start_pos = check_start_idx + recent_start_pos state["in_think"] = True From 93afdf073a080a152fd8d883d76a8cc0f0d7b25d Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sat, 23 Aug 2025 06:00:11 +0000 Subject: [PATCH 28/61] cleanup unused parameters Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/v1/worker/gpu_input_batch.py | 2 -- vllm/v1/worker/gpu_model_runner.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index ba0462351314..67fb9864b19c 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -9,7 +9,6 @@ import torch from typing_extensions import deprecated -from vllm.config import ReasoningConfig from vllm.lora.request import LoRARequest from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalKwargsItems from vllm.pooling_params import PoolingParams @@ -92,7 +91,6 @@ def 
__init__( is_spec_decode: bool = False, is_pooling_model: bool = False, num_speculative_tokens: int = 0, - reasoning_config: ReasoningConfig = None, ): self.is_pooling_model = is_pooling_model self.is_spec_decode = is_spec_decode diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f9dfb3064570..d56aac80c153 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -343,7 +343,6 @@ def __init__( self.is_pooling_model, self.vllm_config.model_config.logits_processors), is_pooling_model=self.is_pooling_model, - reasoning_config=self.vllm_config.reasoning_config, ) self.use_async_scheduling = self.scheduler_config.async_scheduling @@ -3882,7 +3881,6 @@ def may_reinitialize_input_batch(self, num_speculative_tokens=( self.vllm_config.speculative_config.num_speculative_tokens if self.vllm_config.speculative_config else 0), - reasoning_config=self.vllm_config.reasoning_config, ) def _allocate_kv_cache_tensors( From 24334b20ef6bfb441b8abfb6c59991357a5d0e5c Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sat, 23 Aug 2025 10:21:54 +0000 Subject: [PATCH 29/61] optimize speed up performance while apply logit processor Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/v1/sample/logits_processor/builtin.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index cde22dbe3276..d51d0d6bc824 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -401,6 +401,9 @@ def update_state(self, batch_update: Optional[BatchUpdate]): self._state[index] = self._init_state_entry( prompt_tok_ids, thinking_token_budget) self._state[index]["output_tok_ids"] = output_tok_ids + else: + # Remove state if no thinking budget + self._state.pop(index, None) for 
index in batch_update.removed: self._state.pop(index, {}) @@ -432,11 +435,17 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: self.force_token_ids[i] = \ self.think_end_token_ids[state["end_count"]] - current_mask = self.mask[:batch_size] - if current_mask.any(): - logits[current_mask] = -float("inf") - logits[current_mask, - self.force_token_ids[:batch_size][current_mask]] = 0.0 + # Check in CPU first not to sync with GPU + has_active_thinking = any( + state.get("in_end", False) for state in self._state.values()) + + if has_active_thinking: + current_mask = self.mask[:batch_size] + active_indices = current_mask.nonzero(as_tuple=False).view(-1) + if len(active_indices) > 0: + force_tokens = self.force_token_ids[active_indices] + # Apply a large value for the end thinking token id index + logits[active_indices, force_tokens] = 1e9 return logits From 0efea7577973ecad85e0c57c5828926c95268358 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Thu, 4 Sep 2025 16:22:57 +0000 Subject: [PATCH 30/61] utilize logits processor when it is needed, not every step for speed up Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/v1/sample/logits_processor/builtin.py | 67 +++++++++++++--------- 1 file changed, 41 insertions(+), 26 deletions(-) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index d51d0d6bc824..e1f271b43529 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -302,6 +302,7 @@ def _init_state_entry(self, prompt_tok_ids: list[int], return { "in_think": in_think, # Currently in thinking mode "in_end": False, # Currently forcing end tokens + "check_count_down": thinking_token_budget, "think_count": think_count, # Number of tokens in thinking section "end_count": 0, # Number of end tokens forced so far "prompt_tok_ids": prompt_tok_ids, @@ -313,6 +314,11 @@ def 
_init_state_entry(self, prompt_tok_ids: list[int], def _update_think_state(self, state: dict[str, Any]): """Updates the state based on newly generated output tokens.""" + if not state.get("in_end", False) and state.get("check_count_down", + 0) > 0: + state["check_count_down"] -= 1 + return + output = state.get("output_tok_ids", []) if not output: return @@ -344,43 +350,52 @@ def _update_think_state(self, state: dict[str, Any]): recent_tokens, self.think_end_token_ids) # Update state based on recent sequences - if recent_start_pos >= 0 and recent_end_pos >= 0: - if recent_start_pos > recent_end_pos: - # Case: ......... + if not state["in_end"]: + if recent_start_pos >= 0 and recent_end_pos >= 0: + if recent_start_pos > recent_end_pos: + # Case: ......... + absolute_start_pos = check_start_idx + recent_start_pos + state["in_think"] = True + state["think_count"] = current_length - ( + absolute_start_pos + start_len) + state["check_count_down"] = state[ + "thinking_token_budget"] - state["think_count"] + else: + # Case: ......... + state["in_think"] = False + state["think_count"] = 0 + elif recent_start_pos >= 0: + # Found think start in recent tokens absolute_start_pos = check_start_idx + recent_start_pos state["in_think"] = True state["think_count"] = current_length - (absolute_start_pos + start_len) - else: - # Case: ......... 
+ state["check_count_down"] = state[ + "thinking_token_budget"] - state["think_count"] + elif recent_end_pos >= 0: + # Found think end in recent tokens state["in_think"] = False state["think_count"] = 0 - elif recent_start_pos >= 0: - # Found think start in recent tokens - absolute_start_pos = check_start_idx + recent_start_pos - state["in_think"] = True - state["think_count"] = current_length - (absolute_start_pos + - start_len) - elif recent_end_pos >= 0: - # Found think end in recent tokens - state["in_think"] = False - state["think_count"] = 0 - elif state["in_think"]: - # Continue thinking mode, increment count by new tokens - state["think_count"] += len(new_tokens) - - # Transition into end mode if thinking token limit exceeded - if state["in_end"]: - state["end_count"] += 1 - if state["end_count"] >= len(self.think_end_token_ids): - state.update({"in_end": False, "end_count": 0}) - else: + state["check_count_down"] = state["thinking_token_budget"] + elif state["in_think"]: + # Continue thinking mode, increment count by new tokens + state["think_count"] += len(new_tokens) + if state["in_think"] and state["think_count"] \ >= state["thinking_token_budget"]: state.update({ "in_think": False, "in_end": True, - "end_count": 0 + "end_count": 0, + "check_count_down": state["thinking_token_budget"] + }) + else: + state["end_count"] += 1 + if state["end_count"] >= len(self.think_end_token_ids): + state.update({ + "in_end": False, + "end_count": 0, + "check_count_down": state["thinking_token_budget"] }) def is_argmax_invariant(self) -> bool: From 81362dc4ee81e7ef6fd547de2276d7ca7c3bc459 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Fri, 5 Sep 2025 11:06:46 +0000 Subject: [PATCH 31/61] refactor processor Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/v1/sample/logits_processor/builtin.py | 48 ++++++++++++---------- 1 file changed, 26 insertions(+), 22 deletions(-) 
diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index e1f271b43529..cc109ea97fcc 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -353,43 +353,48 @@ def _update_think_state(self, state: dict[str, Any]): if not state["in_end"]: if recent_start_pos >= 0 and recent_end_pos >= 0: if recent_start_pos > recent_end_pos: - # Case: ......... + # Case: ......... - entering think mode absolute_start_pos = check_start_idx + recent_start_pos + new_think_count = current_length - (absolute_start_pos + + start_len) state["in_think"] = True - state["think_count"] = current_length - ( - absolute_start_pos + start_len) - state["check_count_down"] = state[ - "thinking_token_budget"] - state["think_count"] + state["think_count"] = new_think_count else: - # Case: ......... + # Case: ......... - exiting think mode state["in_think"] = False state["think_count"] = 0 elif recent_start_pos >= 0: - # Found think start in recent tokens + # Found think start - entering think mode absolute_start_pos = check_start_idx + recent_start_pos + new_think_count = current_length - (absolute_start_pos + + start_len) state["in_think"] = True - state["think_count"] = current_length - (absolute_start_pos + - start_len) - state["check_count_down"] = state[ - "thinking_token_budget"] - state["think_count"] + state["think_count"] = new_think_count elif recent_end_pos >= 0: - # Found think end in recent tokens + # Found think end - exiting think mode state["in_think"] = False state["think_count"] = 0 - state["check_count_down"] = state["thinking_token_budget"] elif state["in_think"]: # Continue thinking mode, increment count by new tokens state["think_count"] += len(new_tokens) - if state["in_think"] and state["think_count"] \ - >= state["thinking_token_budget"]: - state.update({ - "in_think": False, - "in_end": True, - "end_count": 0, - "check_count_down": state["thinking_token_budget"] - }) + # Set 
countdown based on current state + if state["in_think"]: + remaining_budget = max( + 0, state["thinking_token_budget"] - state["think_count"]) + state["check_count_down"] = remaining_budget + else: + state["check_count_down"] = state["thinking_token_budget"] + + # Check if need to transition to end mode + if state["in_think"] and state["think_count"] >= state[ + "thinking_token_budget"]: + state["in_think"] = False + state["in_end"] = True + state["end_count"] = 0 + state["check_count_down"] = state["thinking_token_budget"] else: + # In end mode state["end_count"] += 1 if state["end_count"] >= len(self.think_end_token_ids): state.update({ @@ -505,4 +510,3 @@ def process_dict_updates( req_entries[a_index] = b_entry return updated - From 8312aa850ac151a2e1e85945c56e4cfee5d8a951 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Wed, 17 Sep 2025 09:21:32 +0000 Subject: [PATCH 32/61] add comment on state Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/v1/sample/logits_processor/builtin.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index cc109ea97fcc..1b9f118e2807 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -259,6 +259,18 @@ def __init__(self, vllm_config: "VllmConfig", device: torch.device, self.pin_memory = is_pin_memory self.device = device + # Per-request state tracking for thinking token management + # Key: request_index, Value: state dict containing: + # "in_think": bool - currently in thinking mode + # "in_end": bool - currently forcing end tokens output + # "check_count_down": int - steps remaining until next think + # start/end token parsing + # "think_count": int - number of thinking tokens generated + # "end_count": int - number of end tokens forced so far + # "thinking_token_budget": int - 
max allowed thinking tokens + # "output_tok_ids": list[int] - generated output tokens + # "prev_output_length": int - previous output length for + # incremental processing self._state: dict[int, dict[str, Any]] = {} # Preallocate reusable tensors From 3b5df9b5fa58d53a4c6fd21d3700d708b80d3573 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Wed, 17 Sep 2025 10:24:27 +0000 Subject: [PATCH 33/61] fix tokenizer init bug Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/v1/worker/gpu_model_runner.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d56aac80c153..45cc51bc1d87 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -60,7 +60,7 @@ from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, GiB_bytes, cdiv, check_use_alibi, get_dtype_size, is_pin_memory_available, @@ -195,10 +195,7 @@ def __init__( reasoning_config = self.vllm_config.reasoning_config if reasoning_config is not None: tokenizer = init_tokenizer_from_configs( - model_config=self.vllm_config.model_config, - scheduler_config=self.vllm_config.scheduler_config, - lora_config=self.vllm_config.lora_config, - ).get_lora_tokenizer(None) + model_config=self.vllm_config.model_config) if reasoning_config.think_start_str is not None: reasoning_config.think_start_token_ids = \ tokenizer.convert_tokens_to_ids( From 88fa857d78d5884e800279fceefbc7918271dea8 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Wed, 17 
Sep 2025 10:42:46 +0000 Subject: [PATCH 34/61] make precommit Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Signed-off-by: Sungjae Lee --- vllm/v1/sample/logits_processor/__init__.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index 08d97cf9468f..8fc41432b66d 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -13,11 +13,9 @@ from vllm.logger import init_logger from vllm.logits_process import LogitsProcessor as RequestLogitsProcessor from vllm.sampling_params import SamplingParams -from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor, - MinPLogitsProcessor, - MinTokensLogitsProcessor, - ThinkingTokenBudgetLogitsProcessor, - process_dict_updates) +from vllm.v1.sample.logits_processor.builtin import ( + LogitBiasLogitsProcessor, MinPLogitsProcessor, MinTokensLogitsProcessor, + ThinkingTokenBudgetLogitsProcessor, process_dict_updates) from vllm.v1.sample.logits_processor.interface import (BatchUpdate, LogitsProcessor, MoveDirectionality) From 998b19aa027505f5eff0cf7b172dfcbfdfb4ac6d Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Thu, 18 Sep 2025 01:21:49 +0000 Subject: [PATCH 35/61] fix change condition of using tokenizer Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/v1/worker/gpu_model_runner.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 45cc51bc1d87..1a67c3260372 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -193,17 +193,18 @@ def __init__( self.observability_config = vllm_config.observability_config reasoning_config = self.vllm_config.reasoning_config - if reasoning_config is not None: + 
has_reasoning_strings = (reasoning_config.think_start_str is not None + and reasoning_config.think_end_str + is not None) + if has_reasoning_strings: tokenizer = init_tokenizer_from_configs( model_config=self.vllm_config.model_config) - if reasoning_config.think_start_str is not None: - reasoning_config.think_start_token_ids = \ - tokenizer.convert_tokens_to_ids( - tokenizer.tokenize(reasoning_config.think_start_str)) - if reasoning_config.think_end_str is not None: - reasoning_config.think_end_token_ids = \ - tokenizer.convert_tokens_to_ids( - tokenizer.tokenize(reasoning_config.think_end_str)) + reasoning_config.think_start_token_ids = \ + tokenizer.convert_tokens_to_ids( + tokenizer.tokenize(reasoning_config.think_start_str)) + reasoning_config.think_end_token_ids = \ + tokenizer.convert_tokens_to_ids( + tokenizer.tokenize(reasoning_config.think_end_str)) from vllm.model_executor.models.utils import set_cpu_offload_max_bytes set_cpu_offload_max_bytes( From 3fadb67c20b37d826119eb2fbec20b7ecd9789fa Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Fri, 3 Oct 2025 03:26:27 +0000 Subject: [PATCH 36/61] make precommit Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/config/__init__.py | 2 +- vllm/engine/arg_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index c6b259c3772e..0d2c9b1fe38a 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -21,11 +21,11 @@ from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig, ParallelConfig) from vllm.config.pooler import PoolerConfig +from vllm.config.reasoning import ReasoningConfig from vllm.config.scheduler import RunnerType, SchedulerConfig, SchedulerPolicy from vllm.config.speculative import SpeculativeConfig from vllm.config.speech_to_text import SpeechToTextConfig from vllm.config.structured_outputs import StructuredOutputsConfig -from 
vllm.config.reasoning import ReasoningConfig from vllm.config.utils import (ConfigType, SupportsMetricsInfo, config, get_attr_docs, is_init_field, update_config) from vllm.config.vllm import (VllmConfig, get_cached_compilation_config, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 08b4c5d510dc..b351f0602297 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -30,8 +30,8 @@ ModelDType, ObservabilityConfig, ParallelConfig, PoolerConfig, PrefixCachingHashAlgo, ReasoningConfig, RunnerOption, SchedulerConfig, SchedulerPolicy, - SpeculativeConfig, StructuredOutputsConfig, TaskOption, - TokenizerMode, VllmConfig, get_attr_docs) + SpeculativeConfig, StructuredOutputsConfig, + TaskOption, TokenizerMode, VllmConfig, get_attr_docs) from vllm.config.multimodal import MMCacheType, MultiModalConfig from vllm.config.parallel import ExpertPlacementStrategy from vllm.config.utils import get_field From 9a91759fb99f9e47c38374fd498c5370165f860e Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Fri, 3 Oct 2025 03:51:18 +0000 Subject: [PATCH 37/61] make precommit Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/v1/sample/logits_processor/builtin.py | 26 +++++++++++++--------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 1b9f118e2807..a37d9d26ee11 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -297,19 +297,25 @@ def _find_last_sequence_index(target_list: list[int], return i return -1 - def _init_state_entry(self, prompt_tok_ids: list[int], + def _init_state_entry(self, prompt_tok_ids: Optional[list[int]], thinking_token_budget: int) -> dict[str, Any]: """Initializes the tracking state for a given sequence index.""" - last_start = self._find_last_sequence_index(prompt_tok_ids, - 
self.think_start_token_ids) - last_end = self._find_last_sequence_index(prompt_tok_ids, - self.think_end_token_ids) - in_think = last_start > last_end - if in_think: - think_count = len(prompt_tok_ids) - ( - last_start + len(self.think_start_token_ids)) - else: + if prompt_tok_ids is None: + last_start = -1 + last_end = -1 + in_think = False think_count = 0 + else: + last_start = self._find_last_sequence_index( + prompt_tok_ids, self.think_start_token_ids) + last_end = self._find_last_sequence_index(prompt_tok_ids, + self.think_end_token_ids) + in_think = last_start > last_end + if in_think: + think_count = len(prompt_tok_ids) - ( + last_start + len(self.think_start_token_ids)) + else: + think_count = 0 return { "in_think": in_think, # Currently in thinking mode From 899e4a97e868a0f5fce3abe18f34dfd0b65dad31 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Fri, 3 Oct 2025 03:56:28 +0000 Subject: [PATCH 38/61] fix: support zero thinking token budget Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/v1/sample/logits_processor/builtin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index a37d9d26ee11..f719824dbec2 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -319,7 +319,7 @@ def _init_state_entry(self, prompt_tok_ids: Optional[list[int]], return { "in_think": in_think, # Currently in thinking mode - "in_end": False, # Currently forcing end tokens + "in_end": in_think and thinking_token_budget == 0, "check_count_down": thinking_token_budget, "think_count": think_count, # Number of tokens in thinking section "end_count": 0, # Number of end tokens forced so far From 86526fbcb3b16fd22112379b5427b5f51288f15c Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Fri, 3 Oct 2025 06:27:30 +0000 
Subject: [PATCH 39/61] refactor: move reasoning token initialization to config level Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/config/reasoning.py | 15 +++++++++++++++ vllm/config/vllm.py | 4 ++++ vllm/v1/worker/gpu_model_runner.py | 15 --------------- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/vllm/config/reasoning.py b/vllm/config/reasoning.py index 220c739b3fed..c1e67e864129 100644 --- a/vllm/config/reasoning.py +++ b/vllm/config/reasoning.py @@ -5,7 +5,9 @@ from pydantic.dataclasses import dataclass +from vllm.config.model import ModelConfig from vllm.config.utils import config +from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs @config @@ -29,3 +31,16 @@ def is_thinking_enabled(self) -> bool: and self.think_end_token_ids is not None and len(self.think_start_token_ids) > 0 and len(self.think_end_token_ids) > 0) + + def initialize_token_ids(self, model_config: ModelConfig) -> None: + """Initialize reasoning token IDs from strings using the tokenizer.""" + if (self.think_start_str is not None + and self.think_end_str is not None): + + tokenizer = init_tokenizer_from_configs(model_config=model_config) + + # Convert reasoning strings to token IDs + self.think_start_token_ids = tokenizer.convert_tokens_to_ids( + tokenizer.tokenize(self.think_start_str)) + self.think_end_token_ids = tokenizer.convert_tokens_to_ids( + tokenizer.tokenize(self.think_end_str)) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index ae5425b17769..628ff96e2a67 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -475,6 +475,10 @@ def __post_init__(self): if not self.instance_id: self.instance_id = random_uuid()[:5] + if (self.reasoning_config is not None + and self.model_config is not None): + self.reasoning_config.initialize_token_ids(self.model_config) + if (envs.VLLM_USE_V1 and not self.scheduler_config.disable_hybrid_kv_cache_manager): # logger should only print warning message for hybrid 
models. As we diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1a67c3260372..efb4a8c0054f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -60,7 +60,6 @@ from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors from vllm.tasks import GenerationTask, PoolingTask, SupportedTask -from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, GiB_bytes, cdiv, check_use_alibi, get_dtype_size, is_pin_memory_available, @@ -192,20 +191,6 @@ def __init__( self.speculative_config = vllm_config.speculative_config self.observability_config = vllm_config.observability_config - reasoning_config = self.vllm_config.reasoning_config - has_reasoning_strings = (reasoning_config.think_start_str is not None - and reasoning_config.think_end_str - is not None) - if has_reasoning_strings: - tokenizer = init_tokenizer_from_configs( - model_config=self.vllm_config.model_config) - reasoning_config.think_start_token_ids = \ - tokenizer.convert_tokens_to_ids( - tokenizer.tokenize(reasoning_config.think_start_str)) - reasoning_config.think_end_token_ids = \ - tokenizer.convert_tokens_to_ids( - tokenizer.tokenize(reasoning_config.think_end_str)) - from vllm.model_executor.models.utils import set_cpu_offload_max_bytes set_cpu_offload_max_bytes( int(self.cache_config.cpu_offload_gb * 1024**3)) From 18a61b98226595384c8738a73c907906c3ffe4e8 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sun, 12 Oct 2025 09:09:43 +0000 Subject: [PATCH 40/61] ruff Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- .pre-commit-config.yaml | 12 - pyproject.toml | 127 +- .../v1/logits_processors/test_correctness.py | 476 ++++--- vllm/config/__init__.py | 62 +- vllm/config/reasoning.py | 20 +- vllm/config/vllm.py | 380 +++--- vllm/engine/arg_utils.py | 
1212 +++++++++-------- vllm/entrypoints/openai/protocol.py | 873 +++++++----- vllm/sampling_params.py | 202 +-- vllm/v1/sample/logits_processor/__init__.py | 109 +- vllm/v1/sample/logits_processor/builtin.py | 213 +-- 11 files changed, 2133 insertions(+), 1553 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8ca414ee4269..ea63ef1f528c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,28 +6,16 @@ default_stages: - manual # Run in CI exclude: 'vllm/third_party/.*' repos: -- repo: https://github.com/google/yapf - rev: v0.43.0 - hooks: - - id: yapf - args: [--in-place, --verbose] - # Keep the same list from yapfignore here to avoid yapf failing without any inputs - exclude: '(.buildkite|benchmarks|build|examples)/.*' - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.11.7 hooks: - id: ruff args: [--output-format, github, --fix] - id: ruff-format - files: ^(.buildkite|benchmarks|examples)/.* - repo: https://github.com/crate-ci/typos rev: v1.35.5 hooks: - id: typos -- repo: https://github.com/PyCQA/isort - rev: 6.0.1 - hooks: - - id: isort - repo: https://github.com/pre-commit/mirrors-clang-format rev: v20.1.3 hooks: diff --git a/pyproject.toml b/pyproject.toml index 034a21f1c12b..2b416d3206c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,27 +52,106 @@ lora_filesystem_resolver = "vllm.plugins.lora_resolvers.filesystem_resolver:regi where = ["."] include = ["vllm*"] -[tool.yapfignore] -ignore_patterns = [ - ".buildkite/**", - "benchmarks/**", - "build/**", - "examples/**", -] - -[tool.ruff] -# Allow lines to be as long as 80. -line-length = 80 - [tool.ruff.lint.per-file-ignores] "vllm/third_party/**" = ["ALL"] "vllm/version.py" = ["F401"] "vllm/_version.py" = ["ALL"] -# Python 3.8 typing - skip V0 code -"vllm/attention/**/*.py" = ["UP006", "UP035"] -"vllm/engine/**/*.py" = ["UP006", "UP035"] -"vllm/executor/**/*.py" = ["UP006", "UP035"] -"vllm/worker/**/*.py" = ["UP006", "UP035"] +# TEMPORARY! 
These ignores will be fixed forward +## Line length violations +"csrc/cutlass_extensions/vllm_cutlass_library_extension.py" = ["E501"] +"tests/compile/piecewise/test_simple.py" = ["E501"] +"tests/compile/piecewise/test_toy_llama.py" = ["E501", "B023"] +"tests/entrypoints/conftest.py" = ["E501"] +"tests/entrypoints/openai/test_audio.py" = ["E501"] +"tests/entrypoints/openai/test_chat.py" = ["E501"] +"tests/entrypoints/openai/test_chat_template.py" = ["E501"] +"tests/entrypoints/openai/test_chat_with_tool_reasoning.py" = ["E501"] +"tests/entrypoints/openai/test_completion_with_function_calling.py" = ["E501"] +"tests/entrypoints/openai/test_video.py" = ["E501"] +"tests/entrypoints/openai/test_vision.py" = ["E501"] +"tests/entrypoints/test_chat_utils.py" = ["E501"] +"tests/kernels/moe/modular_kernel_tools/common.py" = ["E501"] +"tests/models/language/generation/test_gemma.py" = ["E501"] +"tests/models/language/generation/test_mistral.py" = ["E501"] +"tests/models/multimodal/generation/test_ultravox.py" = ["E501"] +"tests/models/multimodal/generation/test_voxtral.py" = ["E501"] +"tests/models/multimodal/generation/vlm_utils/custom_inputs.py" = ["E501"] +"tests/tool_use/test_tool_choice_required.py" = ["E501"] +"tests/v1/attention/utils.py" = ["E501"] +"tests/v1/entrypoints/openai/responses/test_image.py" = ["E501"] +"tests/v1/kv_connector/nixl_integration/test_accuracy.py" = ["E501"] +"tests/v1/kv_connector/unit/test_offloading_connector.py" = ["E501"] +"tests/v1/logits_processors/test_custom_offline.py" = ["E501"] +"vllm/attention/ops/pallas_kv_cache_update.py" = ["E501"] +"vllm/compilation/collective_fusion.py" = ["E501"] +"vllm/compilation/wrapper.py" = ["E501"] +"vllm/config/vllm.py" = ["E501"] +"vllm/distributed/device_communicators/all2all.py" = ["E501"] +"vllm/entrypoints/openai/protocol.py" = ["E501"] +"vllm/lora/layers/vocal_parallel_embedding.py" = ["E501"] +"vllm/model_executor/model_loader/bitsandbytes_loader.py" = ["E501"] 
+"vllm/model_executor/models/bailing_moe.py" = ["E501"] +"vllm/model_executor/models/hyperclovax_vision.py" = ["E501"] +"vllm/model_executor/models/llama4_eagle.py" = ["E501"] +"vllm/model_executor/models/longcat_flash_mtp.py" = ["E501"] +"vllm/model_executor/models/phi4mm.py" = ["E501"] +"vllm/model_executor/models/qwen3_next.py" = ["E501"] +"vllm/model_executor/layers/quantization/ptpc_fp8.py" = ["E501"] +"vllm/v1/attention/backends/mla/common.py" = ["E501"] +"vllm/v1/engine/utils.py" = ["E501"] +"vllm/v1/utils.py" = ["E501"] +"vllm/v1/worker/gpu_model_runner.py" = ["E501"] +## Simplification rules +"tests/distributed/test_expert_placement.py" = ["SIM108"] +"tests/kernels/attention/test_cutlass_mla_decode.py" = ["SIM108"] +"tests/kernels/attention/test_flashmla.py" = ["SIM108"] +"tests/kernels/attention/test_lightning_attn.py" = ["SIM108"] +"tests/kernels/moe/test_pplx_moe.py" = ["SIM108"] +"tests/kernels/quantization/test_cutlass_scaled_mm.py" = ["SIM108"] +"tests/kernels/test_onednn.py" = ["SIM108"] +"tests/kernels/utils.py" = ["SIM108"] +"tests/multimodal/test_processing.py" = ["SIM108"] +"vllm/attention/ops/triton_reshape_and_cache_flash.py" = ["SIM108"] +"vllm/distributed/parallel_state.py" = ["SIM108"] +"vllm/entrypoints/chat_utils.py" = ["SIM108"] +"vllm/entrypoints/llm.py" = ["SIM108"] +"vllm/model_executor/layers/batch_invariant.py" = ["SIM108"] +"vllm/model_executor/layers/fla/ops/chunk_o.py" = ["SIM108"] +"vllm/model_executor/layers/fused_moe/fused_moe.py" = ["SIM108"] +"vllm/model_executor/layers/fused_moe/layer.py" = ["SIM108"] +"vllm/model_executor/layers/fused_moe/modular_kernel.py" = ["SIM108"] +"vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py" = ["SIM108"] +"vllm/model_executor/layers/layernorm.py" = ["SIM108"] +"vllm/model_executor/layers/lightning_attn.py" = ["SIM108"] +"vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py" = ["SIM103"] 
+"vllm/model_executor/layers/quantization/compressed_tensors/utils.py" = ["SIM110"] +"vllm/model_executor/layers/quantization/quark/utils.py" = ["SIM110"] +"vllm/utils/__init__.py" = ["SIM108"] +"vllm/v1/sample/ops/bad_words.py" = ["SIM108"] +"vllm/v1/sample/rejection_sampler.py" = ["SIM108"] +"vllm/v1/worker/tpu_model_runner.py" = ["SIM108"] +"vllm/_custom_ops.py" = ["SIM108"] +"tools/profiler/print_layerwise_table.py" = ["SIM118"] +## Loop variable binding issues +"tests/kernels/mamba/test_mamba_ssm_ssd.py" = ["B023"] +## Type annotation modernization and other rules +"vllm/attention/backends/abstract.py" = ["UP035", "UP006"] +"vllm/attention/layer.py" = ["UP035", "UP006"] +"vllm/attention/layers/chunked_local_attention.py" = ["UP035", "UP006"] +"vllm/attention/ops/flashmla.py" = ["UP035", "UP006"] +"vllm/attention/ops/paged_attn.py" = ["UP035", "UP006"] +"vllm/engine/arg_utils.py" = ["UP035", "UP006"] +"vllm/engine/metrics.py" = ["UP035", "UP006"] +"vllm/engine/metrics_types.py" = ["UP035", "UP006"] +"vllm/executor/executor_base.py" = ["UP035", "UP006"] +"vllm/executor/msgspec_utils.py" = ["UP035", "UP006"] +"vllm/executor/ray_distributed_executor.py" = ["UP035", "UP006", "SIM108", "SIM112"] +"vllm/executor/ray_utils.py" = ["UP035", "UP006"] +"vllm/executor/uniproc_executor.py" = ["UP035", "UP006"] +"vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py" = ["UP035"] +## Type comparison issues +"vllm/multimodal/inputs.py" = ["E721"] +# End of temporary ignores [tool.ruff.lint] select = [ @@ -87,7 +166,7 @@ select = [ # flake8-simplify "SIM", # isort - # "I", + "I", # flake8-logging-format "G", ] @@ -104,21 +183,15 @@ ignore = [ "UP007", ] +[tool.ruff.format] +docstring-code-format = true + [tool.mypy] plugins = ['pydantic.mypy'] ignore_missing_imports = true check_untyped_defs = true follow_imports = "silent" -[tool.isort] -skip_glob = [ - ".buildkite/*", - "benchmarks/*", - "examples/*", -] -use_parentheses = true -skip_gitignore = true - 
[tool.pytest.ini_options] markers = [ "slow_test", diff --git a/tests/v1/logits_processors/test_correctness.py b/tests/v1/logits_processors/test_correctness.py index 788ab872b0d3..5a2214d5145d 100644 --- a/tests/v1/logits_processors/test_correctness.py +++ b/tests/v1/logits_processors/test_correctness.py @@ -10,20 +10,32 @@ import torch from tests.utils import create_new_process_for_each_test -from tests.v1.sample.utils import (LogitsprocsTestFakes, create_fake_logits, - create_penalty_tensor, - create_prompt_tokens_tensor, - fake_apply_logitsprocs, - fake_update_logitsprocs_state) +from tests.v1.sample.utils import ( + LogitsprocsTestFakes, + create_fake_logits, + create_penalty_tensor, + create_prompt_tokens_tensor, + fake_apply_logitsprocs, + fake_update_logitsprocs_state, +) from vllm.config import VllmConfig from vllm.platforms import current_platform from vllm.sampling_params import SamplingParams from vllm.utils import is_pin_memory_available + # yapf: disable from vllm.v1.sample.logits_processor import ( - BatchUpdate, BatchUpdateBuilder, LogitBiasLogitsProcessor, LogitsProcessor, - MinPLogitsProcessor, MinTokensLogitsProcessor, MoveDirectionality, - ThinkingTokenBudgetLogitsProcessor, build_logitsprocs) + BatchUpdate, + BatchUpdateBuilder, + LogitBiasLogitsProcessor, + LogitsProcessor, + MinPLogitsProcessor, + MinTokensLogitsProcessor, + MoveDirectionality, + ThinkingTokenBudgetLogitsProcessor, + build_logitsprocs, +) + # yapf: enable from vllm.v1.sample.metadata import SamplingMetadata @@ -51,9 +63,10 @@ class LogitsProcsRequestParams: """Encapsulates key params for a single request in a batch. 
- + Params can be customized based on the enabled logitproc """ + workload_index: int logitproc_type: LogitprocType # Logitproc enabled, specified by str id out_tokens: list[int] # Output tokens required for min tokens test @@ -69,15 +82,15 @@ def __init__(self, workload_index: int, logitproc_type: LogitprocType): num_tokens = MIN_TOKENS_LEN_THRESHOLD * random.randint(0, 2) if num_tokens > 0: # Use diverse random tokens - self.out_tokens = [ - random.randint(1, 950) for _ in range(num_tokens) - ] + self.out_tokens = [random.randint(1, 950) for _ in range(num_tokens)] # Set first token for ThinkingTokenBudget testing - is_thinking_processor = (logitproc_type - is ThinkingTokenBudgetLogitsProcessor or - (hasattr(logitproc_type, '__name__') - and logitproc_type.__name__ - == 'ThinkingTokenBudgetLogitsProcessor')) + is_thinking_processor = ( + logitproc_type is ThinkingTokenBudgetLogitsProcessor + or ( + hasattr(logitproc_type, "__name__") + and logitproc_type.__name__ == "ThinkingTokenBudgetLogitsProcessor" + ) + ) if is_thinking_processor: self.out_tokens[0] = THINK_START_TOKEN_ID else: @@ -87,12 +100,13 @@ def __init__(self, workload_index: int, logitproc_type: LogitprocType): def __str__(self): """For debugging""" - summ = ', '.join(f'{k}={v}' for k, v in vars(self).items()) + summ = ", ".join(f"{k}={v}" for k, v in vars(self).items()) return f"MyClass({summ})" class MockReasoningConfig: """Mock reasoning config for testing ThinkingTokenBudgetLogitsProcessor.""" + think_start_token_ids = [THINK_START_TOKEN_ID] think_end_token_ids = [THINK_END_TOKEN_ID] @@ -111,12 +125,13 @@ def _generate_fake_sampling_metadata( prompt_token_ids: list[list[int]] = [] for _ in range(batch_size): output_token_ids.append( - np.random.randint(0, vocab_size, size=num_output_tokens).tolist()) + np.random.randint(0, vocab_size, size=num_output_tokens).tolist() + ) prompt_token_ids.append( - np.random.randint(0, - vocab_size, - size=np.random.randint( - 1, MAX_NUM_PROMPT_TOKENS)).tolist()) 
+ np.random.randint( + 0, vocab_size, size=np.random.randint(1, MAX_NUM_PROMPT_TOKENS) + ).tolist() + ) vllm_config = VllmConfig() vllm_config.reasoning_config = MockReasoningConfig() @@ -128,15 +143,16 @@ def _generate_fake_sampling_metadata( is_pooling_model=False, ) fake_sampling_metadata = SamplingMetadata( - temperature=torch.full((batch_size, ), 0.0), + temperature=torch.full((batch_size,), 0.0), all_greedy=True, all_random=False, top_p=None, top_k=None, generators={}, max_num_logprobs=0, - prompt_token_ids=create_prompt_tokens_tensor(prompt_token_ids, - vocab_size, device), + prompt_token_ids=create_prompt_tokens_tensor( + prompt_token_ids, vocab_size, device + ), output_token_ids=output_token_ids, frequency_penalties=create_penalty_tensor(batch_size, 0.0, device), presence_penalties=create_penalty_tensor(batch_size, 0.0, device), @@ -144,7 +160,8 @@ def _generate_fake_sampling_metadata( no_penalties=True, allowed_token_ids_mask=None, bad_words_token_ids={}, - logitsprocs=logitsprocs) + logitsprocs=logitsprocs, + ) return fake_sampling_metadata @@ -156,15 +173,15 @@ def _generate_test_fakes(batch_size: int, device: str) -> LogitsprocsTestFakes: fake_logits[i, 0] = 10.0 # High logit for first token fake_logits[i, 1:] = 1e-2 # Others remain low sampling_metadata = _generate_fake_sampling_metadata( - NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device)) + NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device) + ) return LogitsprocsTestFakes( logits=fake_logits, sampling_metadata=sampling_metadata, ) -def _sampling_params_from_logitproc( - logitproc_type: LogitprocType) -> SamplingParams: +def _sampling_params_from_logitproc(logitproc_type: LogitprocType) -> SamplingParams: """Customize request SamplingParams for a specified logitproc""" # SamplingParams for req with no logitproc kwargs = {"min_p": 0.0, "logit_bias": None, "min_tokens": 0} @@ -179,7 +196,7 @@ def _generate_mixed_logitsprocs_batch_params( ) -> list[LogitsProcsRequestParams]: 
"""Define key params for a batch of requests with a different logitproc enabled per request. - + The batch will have `reqs_per_logitproc` repeats for all `logitsprocs_types` under test, including the case where no logitsproc is enabled. The batch is randomly shuffled. The @@ -202,7 +219,8 @@ def _generate_mixed_logitsprocs_batch_params( return [ LogitsProcsRequestParams( workload_index=idx, - logitproc_type=logitsprocs_types[pdx // reqs_per_logitproc]) + logitproc_type=logitsprocs_types[pdx // reqs_per_logitproc], + ) for idx, pdx in enumerate(batch_perm) ] @@ -214,10 +232,12 @@ def _raise_error_invalid( step_idx: int, err_cls: type[Exception] = ValueError, ) -> None: - raise err_cls(f"Validation failed for step={step_idx}, " - f"batch_index={batch_index}, " - f"workload_index={request_params.workload_index}, " - f"req_params={request_params}. Reason: {msg_suffix}") + raise err_cls( + f"Validation failed for step={step_idx}, " + f"batch_index={batch_index}, " + f"workload_index={request_params.workload_index}, " + f"req_params={request_params}. 
Reason: {msg_suffix}" + ) def _logit_bias_params(kwargs: dict) -> None: @@ -237,8 +257,7 @@ def _logit_bias_validate( ) -> None: """Validate logit bias logitproc applied correctly""" logit_bias = request_params.params.logit_bias - logits_old = ( - test_fakes.logits[persistent_batch[batch_index].workload_index].cpu()) + logits_old = test_fakes.logits[persistent_batch[batch_index].workload_index].cpu() logits_new = logits_new[batch_index].cpu() for token_id in range(VOCAB_SIZE): logit_old_value = logits_old[token_id] @@ -247,22 +266,28 @@ def _logit_bias_validate( bias_value = logit_bias[token_id] exp_value = bias_value + logit_old_value if logit_new_value != pytest.approx(exp_value): - _raise_error_invalid(msg_suffix=( - f"Biased token {token_id} logit value {logit_new_value} " - f"does not match expected value {exp_value} " - f"given bias {bias_value}"), - batch_index=batch_index, - request_params=request_params, - step_idx=step_idx) + _raise_error_invalid( + msg_suffix=( + f"Biased token {token_id} logit value {logit_new_value} " + f"does not match expected value {exp_value} " + f"given bias {bias_value}" + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + ) else: if logit_new_value != pytest.approx(logit_old_value): - _raise_error_invalid(msg_suffix=( - f"Unbiased token {token_id} logit value {logit_new_value} " - f"does not match expected value {logit_old_value}"), - batch_index=batch_index, - request_params=request_params, - step_idx=step_idx) + _raise_error_invalid( + msg_suffix=( + f"Unbiased token {token_id} logit value {logit_new_value} " + f"does not match expected value {logit_old_value}" + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + ) def _min_p_params(kwargs: dict) -> None: @@ -288,26 +313,27 @@ def _min_p_validate( msg_suffix="Invalid: dominant token 0 masked (-inf)", batch_index=batch_index, request_params=request_params, - step_idx=step_idx) + step_idx=step_idx, + ) else: 
if request_params.params.min_p > 0.0: # Non-dominant tokens should be masked when min_p > 0 if logits_for_token != -float("inf"): _raise_error_invalid( - msg_suffix= - f"Invalid: non-dominant token {token_id} not masked", + msg_suffix=f"Invalid: non-dominant token {token_id} not masked", batch_index=batch_index, request_params=request_params, - step_idx=step_idx) + step_idx=step_idx, + ) else: # No masking when min_p is 0 if logits_for_token == -float("inf"): _raise_error_invalid( - msg_suffix= - f"Invalid: token {token_id} masked when min_p=0.0", + msg_suffix=f"Invalid: token {token_id} masked when min_p=0.0", batch_index=batch_index, request_params=request_params, - step_idx=step_idx) + step_idx=step_idx, + ) def _min_tokens_params(kwargs: dict) -> None: @@ -332,7 +358,8 @@ def _min_tokens_validate( min_reached = ref_num_out_tokens >= MIN_TOKENS_LEN_THRESHOLD ref_all_stop_token_ids = request_params.params.all_stop_token_ids mt_lp: MinTokensLogitsProcessor = next( - test_fakes.get_logitsprocs_by_cls(MinTokensLogitsProcessor)) + test_fakes.get_logitsprocs_by_cls(MinTokensLogitsProcessor) + ) assert isinstance(mt_lp, MinTokensLogitsProcessor) min_tok = mt_lp.min_toks.get(batch_index, None) @@ -341,38 +368,50 @@ def _min_tokens_validate( (_, out_tok, all_stop_token_ids) = min_tok num_out_tokens = len(out_tok) if num_out_tokens != ref_num_out_tokens: - _raise_error_invalid(msg_suffix=( - "Number of output tokens in min-token logit processor " - f"request metadata ({num_out_tokens}) does not match " - f"reference ({ref_num_out_tokens})."), - batch_index=batch_index, - request_params=request_params, - step_idx=step_idx) + _raise_error_invalid( + msg_suffix=( + "Number of output tokens in min-token logit processor " + f"request metadata ({num_out_tokens}) does not match " + f"reference ({ref_num_out_tokens})." 
+ ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + ) if ref_all_stop_token_ids != all_stop_token_ids: - _raise_error_invalid(msg_suffix=( - "Stop token ids do not match reference; all_stop_token_ids: " - f"{sorted(all_stop_token_ids)}, ref_all_stop_token_ids: " - f"{sorted(ref_all_stop_token_ids)}"), - batch_index=batch_index, - request_params=request_params, - step_idx=step_idx) + _raise_error_invalid( + msg_suffix=( + "Stop token ids do not match reference; all_stop_token_ids: " + f"{sorted(all_stop_token_ids)}, ref_all_stop_token_ids: " + f"{sorted(ref_all_stop_token_ids)}" + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + ) if min_reached: - _raise_error_invalid(msg_suffix=( - "Expected min-tokens request with min reached, but batch " - "index is recognized by min-tokens logits processor."), - batch_index=batch_index, - request_params=request_params, - step_idx=step_idx, - err_cls=RuntimeError) + _raise_error_invalid( + msg_suffix=( + "Expected min-tokens request with min reached, but batch " + "index is recognized by min-tokens logits processor." + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + err_cls=RuntimeError, + ) elif not min_reached: - _raise_error_invalid(msg_suffix=( - "Expected min-tokens request with min not reached, but batch " - "index is not recognized by min-tokens logits processor."), - batch_index=batch_index, - request_params=request_params, - step_idx=step_idx, - err_cls=RuntimeError) + _raise_error_invalid( + msg_suffix=( + "Expected min-tokens request with min not reached, but batch " + "index is not recognized by min-tokens logits processor." 
+ ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + err_cls=RuntimeError, + ) # Validate min-token logits for token_id in range(VOCAB_SIZE): @@ -380,21 +419,27 @@ def _min_tokens_validate( if token_id in ref_all_stop_token_ids and not min_reached: if logits_for_token != -float("inf"): _raise_error_invalid( - msg_suffix=(f"Token {token_id} is a stop token and " - "the sequence has not reached min length, " - "but the token is not masked " - f"(logit={logits_for_token})"), + msg_suffix=( + f"Token {token_id} is a stop token and " + "the sequence has not reached min length, " + "but the token is not masked " + f"(logit={logits_for_token})" + ), batch_index=batch_index, request_params=request_params, - step_idx=step_idx) + step_idx=step_idx, + ) else: if logits_for_token == -float("inf"): _raise_error_invalid( - msg_suffix=(f"Token {token_id} should not be masked but " - f"is (output len={ref_num_out_tokens})"), + msg_suffix=( + f"Token {token_id} should not be masked but " + f"is (output len={ref_num_out_tokens})" + ), batch_index=batch_index, request_params=request_params, - step_idx=step_idx) + step_idx=step_idx, + ) def _thinking_budget_params(kwargs: dict) -> None: @@ -413,23 +458,26 @@ def _thinking_budget_validate( """Validate thinking token budget processor behavior""" # Get the ThinkingTokenBudgetLogitsProcessor instance tb_processor: ThinkingTokenBudgetLogitsProcessor = next( - test_fakes.get_logitsprocs_by_cls(ThinkingTokenBudgetLogitsProcessor)) + test_fakes.get_logitsprocs_by_cls(ThinkingTokenBudgetLogitsProcessor) + ) # Get current request state state = tb_processor._state.get(batch_index) params = request_params.params # Validate thinking token budget configuration - if hasattr(params, - 'thinking_token_budget') and params.thinking_token_budget: + if hasattr(params, "thinking_token_budget") and params.thinking_token_budget: # State should exist for requests with thinking_token_budget if state is None: - 
_raise_error_invalid(msg_suffix=( - f"Expected state for batch {batch_index} " - f"with thinking_token_budget={params.thinking_token_budget}"), - batch_index=batch_index, - request_params=request_params, - step_idx=step_idx) + _raise_error_invalid( + msg_suffix=( + f"Expected state for batch {batch_index} " + f"with thinking_token_budget={params.thinking_token_budget}" + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + ) # Validate budget matches what was set expected_budget = params.thinking_token_budget @@ -437,11 +485,13 @@ def _thinking_budget_validate( if actual_budget != expected_budget: _raise_error_invalid( - msg_suffix=(f"Budget mismatch: expected {expected_budget}, " - f"got {actual_budget}"), + msg_suffix=( + f"Budget mismatch: expected {expected_budget}, got {actual_budget}" + ), batch_index=batch_index, request_params=request_params, - step_idx=step_idx) + step_idx=step_idx, + ) # Check if we're in thinking mode and validate token counting output_tokens = request_params.out_tokens @@ -452,7 +502,7 @@ def _thinking_budget_validate( if len(start_tokens) > 0: for i in range(len(output_tokens) - len(start_tokens) + 1): - if output_tokens[i:i + len(start_tokens)] == start_tokens: + if output_tokens[i : i + len(start_tokens)] == start_tokens: thinking_started = True break @@ -464,19 +514,22 @@ def _thinking_budget_validate( if think_count >= budget: if not state["in_end"]: _raise_error_invalid( - msg_suffix=(f"Budget exceeded ({think_count} >= " - f"{budget}) but not " - "forcing end tokens"), + msg_suffix=( + f"Budget exceeded ({think_count} >= " + f"{budget}) but not " + "forcing end tokens" + ), batch_index=batch_index, request_params=request_params, - step_idx=step_idx) + step_idx=step_idx, + ) # Validate that only end tokens are allowed end_tokens = tb_processor.think_end_token_ids if len(end_tokens) > 0: - expected_end_token_id = end_tokens[min( - state["end_count"], - len(end_tokens) - 1)] + expected_end_token_id = 
end_tokens[ + min(state["end_count"], len(end_tokens) - 1) + ] # Check logits masking batch_logits = logits_new[batch_index] @@ -489,10 +542,12 @@ def _thinking_budget_validate( _raise_error_invalid( msg_suffix=( f"End token {token_id} should not be " - "masked but is"), + "masked but is" + ), batch_index=batch_index, request_params=request_params, - step_idx=step_idx) + step_idx=step_idx, + ) else: # All other tokens should be masked when forcing end if logit_value != -float("inf"): @@ -500,10 +555,12 @@ def _thinking_budget_validate( msg_suffix=( f"Token {token_id} should be masked " f"when forcing end tokens, but " - f"logit={logit_value}"), + f"logit={logit_value}" + ), batch_index=batch_index, request_params=request_params, - step_idx=step_idx) + step_idx=step_idx, + ) def _none_validate( @@ -515,45 +572,46 @@ def _none_validate( step_idx: int, ) -> None: """Validate that no logits processors are applied""" - logits = ( - test_fakes.logits[persistent_batch[batch_index].workload_index].cpu()) + logits = test_fakes.logits[persistent_batch[batch_index].workload_index].cpu() ref_logits = logits_new[batch_index] if not torch.all(ref_logits == logits): - mismatch_toks = (ref_logits - != logits).nonzero(as_tuple=True)[0].tolist() + mismatch_toks = (ref_logits != logits).nonzero(as_tuple=True)[0].tolist() mismatch_strs = [] for token in mismatch_toks: val = float(logits[token]) ref_val = float(ref_logits[token]) mismatch_strs.append(f"({token=},{val=},{ref_val=})") - _raise_error_invalid(msg_suffix=( - f"Unexpected modification of logits: {','.join(mismatch_strs)}"), - batch_index=batch_index, - request_params=request_params, - step_idx=step_idx) + _raise_error_invalid( + msg_suffix=( + f"Unexpected modification of logits: {','.join(mismatch_strs)}" + ), + batch_index=batch_index, + request_params=request_params, + step_idx=step_idx, + ) class LogitsprocTestHelpers(NamedTuple): """Supports setting up and validating logitsprocs unit tests.""" + eval_fxn: Callable 
gen_request_fxn: Optional[Callable] = None logitsprocs_test_mapping = { - STR_NO_LOGITPROC: - LogitsprocTestHelpers(eval_fxn=_none_validate), - LogitBiasLogitsProcessor: - LogitsprocTestHelpers(gen_request_fxn=_logit_bias_params, - eval_fxn=_logit_bias_validate), - MinPLogitsProcessor: - LogitsprocTestHelpers(gen_request_fxn=_min_p_params, - eval_fxn=_min_p_validate), - MinTokensLogitsProcessor: - LogitsprocTestHelpers(gen_request_fxn=_min_tokens_params, - eval_fxn=_min_tokens_validate), - ThinkingTokenBudgetLogitsProcessor: - LogitsprocTestHelpers(gen_request_fxn=_thinking_budget_params, - eval_fxn=_thinking_budget_validate), + STR_NO_LOGITPROC: LogitsprocTestHelpers(eval_fxn=_none_validate), + LogitBiasLogitsProcessor: LogitsprocTestHelpers( + gen_request_fxn=_logit_bias_params, eval_fxn=_logit_bias_validate + ), + MinPLogitsProcessor: LogitsprocTestHelpers( + gen_request_fxn=_min_p_params, eval_fxn=_min_p_validate + ), + MinTokensLogitsProcessor: LogitsprocTestHelpers( + gen_request_fxn=_min_tokens_params, eval_fxn=_min_tokens_validate + ), + ThinkingTokenBudgetLogitsProcessor: LogitsprocTestHelpers( + gen_request_fxn=_thinking_budget_params, eval_fxn=_thinking_budget_validate + ), } @@ -565,13 +623,17 @@ def _get_test_cases() -> list[list[str]]: # to avoid unexpected modification of logits interference thinking_processor = ThinkingTokenBudgetLogitsProcessor other_processors = [ - p for p in logitsprocs_types + p + for p in logitsprocs_types if p != STR_NO_LOGITPROC and p != thinking_processor ] - return ([[STR_NO_LOGITPROC]] + [[logitproc_type, STR_NO_LOGITPROC] - for logitproc_type in other_processors] + - [other_processors] + [[thinking_processor]]) + return ( + [[STR_NO_LOGITPROC]] + + [[logitproc_type, STR_NO_LOGITPROC] for logitproc_type in other_processors] + + [other_processors] + + [[thinking_processor]] + ) def _generate_fake_step_update( @@ -589,11 +651,18 @@ def _generate_fake_step_update( # Other 50%: add a limited number of reqs (less than the 
number # of workload reqs remaining, less than an arbitrary max) # If no workload reqs remain: 100% of steps have 0 adds - num_step_add = random.choice([ - 0, - random.randint(1, min(max_add_remove_per_step, - workload_reqs_remaining)) - ]) if workload_reqs_remaining else 0 + num_step_add = ( + random.choice( + [ + 0, + random.randint( + 1, min(max_add_remove_per_step, workload_reqs_remaining) + ), + ] + ) + if workload_reqs_remaining + else 0 + ) # 50% of steps: remove no requests # Other 50%: remove a limited number of reqs (less than the number @@ -601,9 +670,11 @@ def _generate_fake_step_update( # If persistent batch is empty: 100% of steps have 0 removals until # more requests are added. Assume that removed requests are always # drawn from the current batch, before new adds - num_step_remove = random.choice([ - 0, random.randint(1, min(max_add_remove_per_step, batch_size)) - ]) if batch_size else 0 + num_step_remove = ( + random.choice([0, random.randint(1, min(max_add_remove_per_step, batch_size))]) + if batch_size + else 0 + ) num_step_add_replace = min(num_step_add, num_step_remove) @@ -612,23 +683,34 @@ def _generate_fake_step_update( batch_update_builder.removed_append(removal) # Get added requests from workload - for add_req_params in workload_params[wdx:(wdx + num_step_add_replace)]: + for add_req_params in workload_params[wdx : (wdx + num_step_add_replace)]: # Replace as many removed requests as possible with added requests add_remove_idx = batch_update_builder.pop_removed() batch_update_builder.added.append( - (add_remove_idx, add_req_params.params, - add_req_params.prompt_tokens, add_req_params.out_tokens)) + ( + add_remove_idx, + add_req_params.params, + add_req_params.prompt_tokens, + add_req_params.out_tokens, + ) + ) persistent_batch[add_remove_idx] = add_req_params # Append remaining added requests to end of batch - add_reqs_append = workload_params[(wdx + - num_step_add_replace):(wdx + - num_step_add)] - batch_update_builder.added.extend([ - 
(adx + batch_size, add_req_params.params, add_req_params.prompt_tokens, - add_req_params.out_tokens) - for adx, add_req_params in enumerate(add_reqs_append) - ]) + add_reqs_append = workload_params[ + (wdx + num_step_add_replace) : (wdx + num_step_add) + ] + batch_update_builder.added.extend( + [ + ( + adx + batch_size, + add_req_params.params, + add_req_params.prompt_tokens, + add_req_params.out_tokens, + ) + for adx, add_req_params in enumerate(add_reqs_append) + ] + ) persistent_batch.extend(add_reqs_append) pre_condense_batch_size = len(persistent_batch) wdx += num_step_add # Update workload offset @@ -637,8 +719,10 @@ def _generate_fake_step_update( last_nonempty_index = pre_condense_batch_size - 1 condensed_to_idxs = set() while batch_update_builder.removed: - if (last_nonempty_index in batch_update_builder.removed - or last_nonempty_index in condensed_to_idxs): + if ( + last_nonempty_index in batch_update_builder.removed + or last_nonempty_index in condensed_to_idxs + ): last_nonempty_index -= 1 continue # last_nonempty_index is the highest persistent batch index that was @@ -653,11 +737,10 @@ def _generate_fake_step_update( # move last_nonempty_index -> first_empty_index batch_update_builder.pop_removed() condensed_to_idxs.add(first_empty_index) - persistent_batch[first_empty_index] = persistent_batch[ - last_nonempty_index] + persistent_batch[first_empty_index] = persistent_batch[last_nonempty_index] batch_update_builder.moved.append( - (last_nonempty_index, first_empty_index, - MoveDirectionality.UNIDIRECTIONAL)) + (last_nonempty_index, first_empty_index, MoveDirectionality.UNIDIRECTIONAL) + ) last_nonempty_index -= 1 @@ -673,18 +756,21 @@ def _generate_fake_step_update( k = random.randint(0, condensed_batch_size // 2) idxs = list(range(condensed_batch_size)) random.shuffle(idxs) - swaps = [ - tuple(sorted([idxs[2 * i], idxs[2 * i + 1]])) for i in range(k) - ] - batch_update_builder.moved.extend([ - (sw[0], sw[1], MoveDirectionality.SWAP) for sw in swaps 
- ]) + swaps = [tuple(sorted([idxs[2 * i], idxs[2 * i + 1]])) for i in range(k)] + batch_update_builder.moved.extend( + [(sw[0], sw[1], MoveDirectionality.SWAP) for sw in swaps] + ) for adx, bdx in swaps: - persistent_batch[adx], persistent_batch[bdx] = persistent_batch[ - bdx], persistent_batch[adx] - - return (batch_update_builder.get_and_reset(condensed_batch_size), wdx, - workload_size - wdx) + persistent_batch[adx], persistent_batch[bdx] = ( + persistent_batch[bdx], + persistent_batch[adx], + ) + + return ( + batch_update_builder.get_and_reset(condensed_batch_size), + wdx, + workload_size - wdx, + ) def _assert_valid( @@ -699,8 +785,10 @@ def _assert_valid( # Trivial case of empty persistent batch assert len(persistent_batch) == 0 if logits_w_lp.shape[0] != 0: - raise ValueError("Fake persistent batch is empty but logitsprocs " - f"output batch has shape {logits_w_lp.shape}") + raise ValueError( + "Fake persistent batch is empty but logitsprocs " + f"output batch has shape {logits_w_lp.shape}" + ) return # Validate logits for each fake request @@ -709,36 +797,40 @@ def _assert_valid( # Invoke the appropriate validation function for # the logitproc employed by this request fxn = logitsprocs_test_mapping[request_params.logitproc_type].eval_fxn - fxn(test_fakes=test_fakes, + fxn( + test_fakes=test_fakes, persistent_batch=persistent_batch, logits_new=logits_w_lp, batch_index=batch_index, request_params=request_params, - step_idx=step_idx) + step_idx=step_idx, + ) @create_new_process_for_each_test() @pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("reqs_per_logitproc", [REQS_PER_LOGITPROC]) @pytest.mark.parametrize("logitsprocs_under_test", _get_test_cases()) -def test_logitsprocs(device: str, reqs_per_logitproc: int, - logitsprocs_under_test: list[str]): +def test_logitsprocs( + device: str, reqs_per_logitproc: int, logitsprocs_under_test: list[str] +): random.seed(40) torch.set_default_device(device) # Define a shuffled batch of requests 
which individually use a different # logitproc, or no logitproc at all workload_params = _generate_mixed_logitsprocs_batch_params( - reqs_per_logitproc=reqs_per_logitproc, - logitsprocs_types=logitsprocs_under_test) + reqs_per_logitproc=reqs_per_logitproc, logitsprocs_types=logitsprocs_under_test + ) workload_size = len(workload_params) # Create fake test data structures for testing. test_fakes = _generate_test_fakes(workload_size, device) wdx = 0 # Next request index in workload to add - persistent_batch: list[LogitsProcsRequestParams] = [ - ] # Persistent batch state, as list of workload indices + persistent_batch: list[ + LogitsProcsRequestParams + ] = [] # Persistent batch state, as list of workload indices # Generate fake removed request indices from current persistent # batch before adds diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 0d2c9b1fe38a..f6de68f0676f 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -1,37 +1,61 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from vllm.config.cache import (BlockSize, CacheConfig, CacheDType, MambaDType, - PrefixCachingHashAlgo) -from vllm.config.compilation import (CompilationConfig, CompilationLevel, - CUDAGraphMode, PassConfig) +from vllm.config.cache import ( + BlockSize, + CacheConfig, + CacheDType, + MambaDType, + PrefixCachingHashAlgo, +) +from vllm.config.compilation import ( + CompilationConfig, + CompilationLevel, + CUDAGraphMode, + PassConfig, +) from vllm.config.device import Device, DeviceConfig from vllm.config.kv_events import KVEventsConfig from vllm.config.kv_transfer import KVTransferConfig from vllm.config.load import LoadConfig from vllm.config.lora import LoRAConfig -from vllm.config.model import (ConvertOption, HfOverrides, LogprobsMode, - ModelConfig, ModelDType, ModelImpl, - RunnerOption, TaskOption, TokenizerMode, - iter_architecture_defaults, - try_match_architecture_defaults) -from 
vllm.config.multimodal import (MMCacheType, MMEncoderTPMode, - MultiModalConfig) +from vllm.config.model import ( + ConvertOption, + HfOverrides, + LogprobsMode, + ModelConfig, + ModelDType, + ModelImpl, + RunnerOption, + TaskOption, + TokenizerMode, + iter_architecture_defaults, + try_match_architecture_defaults, +) +from vllm.config.multimodal import MMCacheType, MMEncoderTPMode, MultiModalConfig from vllm.config.observability import DetailedTraceModules, ObservabilityConfig -from vllm.config.parallel import (DistributedExecutorBackend, EPLBConfig, - ParallelConfig) +from vllm.config.parallel import DistributedExecutorBackend, EPLBConfig, ParallelConfig from vllm.config.pooler import PoolerConfig from vllm.config.reasoning import ReasoningConfig from vllm.config.scheduler import RunnerType, SchedulerConfig, SchedulerPolicy from vllm.config.speculative import SpeculativeConfig from vllm.config.speech_to_text import SpeechToTextConfig from vllm.config.structured_outputs import StructuredOutputsConfig -from vllm.config.utils import (ConfigType, SupportsMetricsInfo, config, - get_attr_docs, is_init_field, update_config) -from vllm.config.vllm import (VllmConfig, get_cached_compilation_config, - get_current_vllm_config, - get_layers_from_vllm_config, - set_current_vllm_config) +from vllm.config.utils import ( + ConfigType, + SupportsMetricsInfo, + config, + get_attr_docs, + is_init_field, + update_config, +) +from vllm.config.vllm import ( + VllmConfig, + get_cached_compilation_config, + get_current_vllm_config, + get_layers_from_vllm_config, + set_current_vllm_config, +) __all__ = [ # From vllm.config.cache diff --git a/vllm/config/reasoning.py b/vllm/config/reasoning.py index c1e67e864129..4ea20623a74e 100644 --- a/vllm/config/reasoning.py +++ b/vllm/config/reasoning.py @@ -27,20 +27,22 @@ class ReasoningConfig: def is_thinking_enabled(self) -> bool: """Check if both start and end thinking token IDs are set to enable thinking token budget logic.""" - return 
(self.think_start_token_ids is not None - and self.think_end_token_ids is not None - and len(self.think_start_token_ids) > 0 - and len(self.think_end_token_ids) > 0) + return ( + self.think_start_token_ids is not None + and self.think_end_token_ids is not None + and len(self.think_start_token_ids) > 0 + and len(self.think_end_token_ids) > 0 + ) def initialize_token_ids(self, model_config: ModelConfig) -> None: """Initialize reasoning token IDs from strings using the tokenizer.""" - if (self.think_start_str is not None - and self.think_end_str is not None): - + if self.think_start_str is not None and self.think_end_str is not None: tokenizer = init_tokenizer_from_configs(model_config=model_config) # Convert reasoning strings to token IDs self.think_start_token_ids = tokenizer.convert_tokens_to_ids( - tokenizer.tokenize(self.think_start_str)) + tokenizer.tokenize(self.think_start_str) + ) self.think_end_token_ids = tokenizer.convert_tokens_to_ids( - tokenizer.tokenize(self.think_end_str)) + tokenizer.tokenize(self.think_end_str) + ) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 80cd81a6535d..0781811e71b2 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -39,8 +39,7 @@ if TYPE_CHECKING: from transformers import PretrainedConfig - from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + from vllm.model_executor.layers.quantization.base_config import QuantizationConfig else: PretrainedConfig = Any @@ -75,14 +74,14 @@ class VllmConfig: speculative_config: Optional[SpeculativeConfig] = None """Speculative decoding configuration.""" structured_outputs_config: StructuredOutputsConfig = field( - default_factory=StructuredOutputsConfig) + default_factory=StructuredOutputsConfig + ) """Structured outputs configuration.""" observability_config: Optional[ObservabilityConfig] = None """Observability configuration.""" quant_config: Optional[QuantizationConfig] = None """Quantization configuration.""" - compilation_config: 
CompilationConfig = field( - default_factory=CompilationConfig) + compilation_config: CompilationConfig = field(default_factory=CompilationConfig) """`torch.compile` and cudagraph capture configuration for the model. As a shorthand, `-O` can be used to directly specify the compilation @@ -130,6 +129,7 @@ def compute_hash(self) -> str: # summarize vllm config vllm_factors: list[Any] = [] from vllm import __version__ + vllm_factors.append(__version__) vllm_factors.append(envs.VLLM_USE_V1) if self.model_config: @@ -161,8 +161,7 @@ def compute_hash(self) -> str: # LoRA creates static buffers based on max_num_batched_tokens. # The tensor sizes and strides get captured in the torch.compile # graph explicitly. - vllm_factors.append( - str(self.scheduler_config.max_num_batched_tokens)) + vllm_factors.append(str(self.scheduler_config.max_num_batched_tokens)) else: vllm_factors.append("None") if self.speculative_config: @@ -200,8 +199,9 @@ def compute_hash(self) -> str: vllm_factors.append("None") factors.append(vllm_factors) - hash_str = hashlib.md5(str(factors).encode(), - usedforsecurity=False).hexdigest()[:10] + hash_str = hashlib.md5( + str(factors).encode(), usedforsecurity=False + ).hexdigest()[:10] return hash_str def pad_for_cudagraph(self, batch_size: int) -> int: @@ -213,13 +213,14 @@ def pad_for_cudagraph(self, batch_size: int) -> int: @staticmethod def _get_quantization_config( - model_config: ModelConfig, - load_config: LoadConfig) -> Optional[QuantizationConfig]: + model_config: ModelConfig, load_config: LoadConfig + ) -> Optional[QuantizationConfig]: """Get the quantization config.""" from vllm.platforms import current_platform + if model_config.quantization is not None: - from vllm.model_executor.model_loader.weight_utils import ( - get_quant_config) + from vllm.model_executor.model_loader.weight_utils import get_quant_config + quant_config = get_quant_config(model_config, load_config) capability_tuple = current_platform.get_device_capability() @@ -230,27 
+231,30 @@ def _get_quantization_config( f"The quantization method {model_config.quantization} " "is not supported for the current GPU. Minimum " f"capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}.") + f"Current capability: {capability}." + ) supported_dtypes = quant_config.get_supported_act_dtypes() if model_config.dtype not in supported_dtypes: raise ValueError( f"{model_config.dtype} is not supported for quantization " f"method {model_config.quantization}. Supported dtypes: " - f"{supported_dtypes}") + f"{supported_dtypes}" + ) quant_config.maybe_update_config(model_config.model) return quant_config return None @staticmethod def get_quantization_config( - model_config: ModelConfig, - load_config: LoadConfig) -> Optional[QuantizationConfig]: + model_config: ModelConfig, load_config: LoadConfig + ) -> Optional[QuantizationConfig]: import copy # For some reason, the _ version of this modifies the model_config # object, so using deepcopy to avoid this problem. - return VllmConfig._get_quantization_config(copy.deepcopy(model_config), - load_config) + return VllmConfig._get_quantization_config( + copy.deepcopy(model_config), load_config + ) def with_hf_config( self, @@ -267,15 +271,13 @@ def with_hf_config( return replace(self, model_config=model_config) def __post_init__(self): - """Verify configs are valid & consistent with each other. 
- """ + """Verify configs are valid & consistent with each other.""" self.try_verify_and_update_config() if self.model_config is not None: self.model_config.verify_with_parallel_config(self.parallel_config) - self.model_config.verify_dual_chunk_attention_config( - self.load_config) + self.model_config.verify_dual_chunk_attention_config(self.load_config) self.cache_config.verify_with_parallel_config(self.parallel_config) @@ -285,29 +287,35 @@ def __post_init__(self): if self.quant_config is None and self.model_config is not None: self.quant_config = VllmConfig._get_quantization_config( - self.model_config, self.load_config) + self.model_config, self.load_config + ) from vllm.platforms import current_platform - if self.model_config is not None and \ - self.scheduler_config.chunked_prefill_enabled and \ - self.model_config.dtype == torch.float32 and \ - current_platform.get_device_capability() == (7, 5): + + if ( + self.model_config is not None + and self.scheduler_config.chunked_prefill_enabled + and self.model_config.dtype == torch.float32 + and current_platform.get_device_capability() == (7, 5) + ): logger.warning_once( "Turing devices tensor cores do not support float32 matmul. " "To workaround this limitation, vLLM will set 'ieee' input " - "precision for chunked prefill triton kernels.") + "precision for chunked prefill triton kernels." + ) # If the user does not explicitly set a compilation level, then # we use the default level. The default level depends on other # settings (see the below code). 
if self.compilation_config.level is None: if envs.VLLM_USE_V1: - if (self.model_config is not None - and not self.model_config.enforce_eager): + if ( + self.model_config is not None + and not self.model_config.enforce_eager + ): self.compilation_config.level = CompilationLevel.PIECEWISE else: - self.compilation_config.level = \ - CompilationLevel.NO_COMPILATION + self.compilation_config.level = CompilationLevel.NO_COMPILATION else: # NB: Passing both --enforce-eager and a compilation level @@ -317,8 +325,7 @@ def __post_init__(self): # async tp is built on top of sequence parallelism # and requires it to be enabled. if self.compilation_config.pass_config.enable_async_tp: - self.compilation_config.pass_config.enable_sequence_parallelism = \ - True + self.compilation_config.pass_config.enable_sequence_parallelism = True if self.compilation_config.pass_config.enable_sequence_parallelism: self.compilation_config.custom_ops.append("+rms_norm") @@ -326,25 +333,27 @@ def __post_init__(self): # if cudagraph_mode is not explicitly set by users, set default # value if self.compilation_config.cudagraph_mode is None: - if envs.VLLM_USE_V1 and self.compilation_config.level \ - == CompilationLevel.PIECEWISE: + if ( + envs.VLLM_USE_V1 + and self.compilation_config.level == CompilationLevel.PIECEWISE + ): # default to full and piecewise for most models - self.compilation_config.cudagraph_mode = \ + self.compilation_config.cudagraph_mode = ( CUDAGraphMode.FULL_AND_PIECEWISE + ) # pooling models and encoder-decoder models # do not support full cudagraphs - if self.model_config is not None and \ - (self.model_config.pooler_config is not None - or self.model_config.is_encoder_decoder): - self.compilation_config.cudagraph_mode = \ - CUDAGraphMode.PIECEWISE + if self.model_config is not None and ( + self.model_config.pooler_config is not None + or self.model_config.is_encoder_decoder + ): + self.compilation_config.cudagraph_mode = CUDAGraphMode.PIECEWISE else: 
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE # disable cudagraph when enforce eager execution - if self.model_config is not None and \ - self.model_config.enforce_eager: + if self.model_config is not None and self.model_config.enforce_eager: logger.info("Cudagraph is disabled under eager mode") self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE elif envs.VLLM_USE_V1: @@ -355,18 +364,21 @@ def __post_init__(self): self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE if self.cache_config.kv_sharing_fast_prefill: - - if self.speculative_config is not None and \ - self.speculative_config.use_eagle(): + if ( + self.speculative_config is not None + and self.speculative_config.use_eagle() + ): raise NotImplementedError( "Fast prefill optimization for KV sharing is not " "compatible with EAGLE as EAGLE requires correct logits " "for all tokens while fast prefill gives incorrect logits " - "for prompt tokens.") + "for prompt tokens." + ) logger.warning_once( "--kv-sharing-fast-prefill requires changes on model side for " - "correctness and to realize prefill savings. ") + "correctness and to realize prefill savings. " + ) disable_chunked_prefill_reasons: list[str] = [] @@ -375,41 +387,51 @@ def __post_init__(self): pooling_type = self.model_config.pooler_config.pooling_type if pooling_type is None or pooling_type.lower() != "last": disable_chunked_prefill_reasons.append( - "Only \"last\" pooling supports chunked " - "prefill and prefix caching; disabling both.") + 'Only "last" pooling supports chunked ' + "prefill and prefix caching; disabling both." + ) if not getattr(self.model_config.hf_config, "is_causal", True): disable_chunked_prefill_reasons.append( "Only models using causal attention supports chunked " - "prefill and prefix caching; disabling both.") + "prefill and prefix caching; disabling both." 
+ ) elif self.model_config.is_encoder_decoder: from vllm.multimodal import MULTIMODAL_REGISTRY - self.scheduler_config.max_num_encoder_input_tokens = \ + + self.scheduler_config.max_num_encoder_input_tokens = ( MULTIMODAL_REGISTRY.get_encdec_max_encoder_len(self.model_config) + ) logger.debug( "Encoder-decoder model detected: setting " "`max_num_encoder_input_tokens` to encoder length (%s)", - self.scheduler_config.max_num_encoder_input_tokens) - if (self.model_config.architecture - == "WhisperForConditionalGeneration" - and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") - != "spawn"): + self.scheduler_config.max_num_encoder_input_tokens, + ) + if ( + self.model_config.architecture == "WhisperForConditionalGeneration" + and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn" + ): logger.warning( "Whisper is known to have issues with " "forked workers. If startup is hanging, " "try setting 'VLLM_WORKER_MULTIPROC_METHOD' " - "to 'spawn'.") + "to 'spawn'." + ) # Final off-switch for CP/APC: # Disable for (a) collected blockers, (b) encoder–decoder, or # (c) explicit CP=False when APC wasn't requested. # Do NOT disable merely because the resolved CP flag is False. 
- apc_requested = (self.cache_config is not None - and self.cache_config.enable_prefix_caching) - if (disable_chunked_prefill_reasons - or (self.model_config is not None - and self.model_config.is_encoder_decoder) - or (self.scheduler_config.enable_chunked_prefill is False - and not apc_requested)): + apc_requested = ( + self.cache_config is not None and self.cache_config.enable_prefix_caching + ) + if ( + disable_chunked_prefill_reasons + or (self.model_config is not None and self.model_config.is_encoder_decoder) + or ( + self.scheduler_config.enable_chunked_prefill is False + and not apc_requested + ) + ): for reason in disable_chunked_prefill_reasons: logger.info(reason) self.scheduler_config.chunked_prefill_enabled = False @@ -418,76 +440,88 @@ def __post_init__(self): if self.cache_config is not None: self.cache_config.enable_prefix_caching = False - if (self.kv_events_config is not None - and self.kv_events_config.enable_kv_cache_events - and not self.cache_config.enable_prefix_caching): + if ( + self.kv_events_config is not None + and self.kv_events_config.enable_kv_cache_events + and not self.cache_config.enable_prefix_caching + ): logger.warning( "KV cache events are on, but prefix caching is not enabled." - "Use --enable-prefix-caching to enable.") - if (self.kv_events_config is not None - and self.kv_events_config.publisher != "null" - and not self.kv_events_config.enable_kv_cache_events): - logger.warning("KV cache events are disabled," - "but the scheduler is configured to publish them." - "Modify KVEventsConfig.enable_kv_cache_events" - "to True to enable.") + "Use --enable-prefix-caching to enable." + ) + if ( + self.kv_events_config is not None + and self.kv_events_config.publisher != "null" + and not self.kv_events_config.enable_kv_cache_events + ): + logger.warning( + "KV cache events are disabled," + "but the scheduler is configured to publish them." + "Modify KVEventsConfig.enable_kv_cache_events" + "to True to enable." 
+ ) current_platform.check_and_update_config(self) # Do this after all the updates to compilation_config.level - if envs.VLLM_USE_V1 and \ - self.compilation_config.level == CompilationLevel.PIECEWISE: + if ( + envs.VLLM_USE_V1 + and self.compilation_config.level == CompilationLevel.PIECEWISE + ): self.compilation_config.set_splitting_ops_for_v1() # final check of cudagraph mode after all possible updates if envs.VLLM_USE_V1 and current_platform.is_cuda_alike(): - if self.compilation_config.cudagraph_mode.has_full_cudagraphs()\ - and self.model_config is not None and \ - not self.model_config.disable_cascade_attn and\ - not self.compilation_config.cudagraph_mode.\ - has_piecewise_cudagraphs(): + if ( + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + and self.model_config is not None + and not self.model_config.disable_cascade_attn + and not self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs() + ): logger.warning_once( "No piecewise cudagraph for executing cascade attention." 
" Will fall back to eager execution if a batch runs " - "into cascade attentions") - - if self.compilation_config.cudagraph_mode\ - .requires_piecewise_compilation(): - assert self.compilation_config.level == \ - CompilationLevel.PIECEWISE, \ - "Compilation level should be CompilationLevel.PIECEWISE "\ - "when cudagraph_mode piecewise cudagraphs is used, "\ + "into cascade attentions" + ) + + if self.compilation_config.cudagraph_mode.requires_piecewise_compilation(): + assert self.compilation_config.level == CompilationLevel.PIECEWISE, ( + "Compilation level should be CompilationLevel.PIECEWISE " + "when cudagraph_mode piecewise cudagraphs is used, " f"cudagraph_mode={self.compilation_config.cudagraph_mode}" + ) # final migrate the deprecated flags - self.compilation_config.use_cudagraph = self.compilation_config.\ - cudagraph_mode!= CUDAGraphMode.NONE - self.compilation_config.full_cuda_graph = self.compilation_config.\ - cudagraph_mode.has_full_cudagraphs() + self.compilation_config.use_cudagraph = ( + self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + ) + self.compilation_config.full_cuda_graph = ( + self.compilation_config.cudagraph_mode.has_full_cudagraphs() + ) if self.parallel_config.enable_dbo: a2a_backend = envs.VLLM_ALL2ALL_BACKEND - assert a2a_backend in \ - ["deepep_low_latency", "deepep_high_throughput"], \ - "Microbatching currently only supports the deepep_low_latency and "\ - f"deepep_high_throughput all2all backend. {a2a_backend} is not "\ - "supported. To fix set the VLLM_ALL2ALL_BACKEND environment "\ - "variable to deepep_low_latency or deepep_high_throughput and "\ - "install the DeepEP kernels." + assert a2a_backend in ["deepep_low_latency", "deepep_high_throughput"], ( + "Microbatching currently only supports the deepep_low_latency and " + f"deepep_high_throughput all2all backend. {a2a_backend} is not " + "supported. 
To fix set the VLLM_ALL2ALL_BACKEND environment " + "variable to deepep_low_latency or deepep_high_throughput and " + "install the DeepEP kernels." + ) if not self.model_config.disable_cascade_attn: self.model_config.disable_cascade_attn = True - logger.warning_once( - "Disabling cascade attention when DBO is enabled.") + logger.warning_once("Disabling cascade attention when DBO is enabled.") if not self.instance_id: self.instance_id = random_uuid()[:5] - if (self.reasoning_config is not None - and self.model_config is not None): + if self.reasoning_config is not None and self.model_config is not None: self.reasoning_config.initialize_token_ids(self.model_config) - if (envs.VLLM_USE_V1 - and not self.scheduler_config.disable_hybrid_kv_cache_manager): + if ( + envs.VLLM_USE_V1 + and not self.scheduler_config.disable_hybrid_kv_cache_manager + ): # logger should only print warning message for hybrid models. As we # can't know whether the model is hybrid or not now, so we don't log # warning message here and will log it later. @@ -500,15 +534,18 @@ def __post_init__(self): if self.kv_events_config is not None: # Hybrid KV cache manager is not compatible with KV events. self.scheduler_config.disable_hybrid_kv_cache_manager = True - if self.model_config is not None and \ - self.model_config.attention_chunk_size is not None: - if self.speculative_config is not None and \ - self.speculative_config.use_eagle(): + if ( + self.model_config is not None + and self.model_config.attention_chunk_size is not None + ): + if ( + self.speculative_config is not None + and self.speculative_config.use_eagle() + ): # Hybrid KV cache manager is not yet supported with chunked # local attention + eagle. 
self.scheduler_config.disable_hybrid_kv_cache_manager = True - elif \ - not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: + elif not envs.VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: logger.warning( "There is a latency regression when using chunked local" " attention with the hybrid KV cache manager. Disabling" @@ -520,14 +557,17 @@ def __post_init__(self): self.scheduler_config.disable_hybrid_kv_cache_manager = True if self.compilation_config.debug_dump_path: - self.compilation_config.debug_dump_path = \ + self.compilation_config.debug_dump_path = ( self.compilation_config.debug_dump_path.absolute().expanduser() + ) if envs.VLLM_DEBUG_DUMP_PATH is not None: env_path = Path(envs.VLLM_DEBUG_DUMP_PATH).absolute().expanduser() if self.compilation_config.debug_dump_path: logger.warning( "Config-specified debug dump path is overridden" - " by VLLM_DEBUG_DUMP_PATH to %s", env_path) + " by VLLM_DEBUG_DUMP_PATH to %s", + env_path, + ) self.compilation_config.debug_dump_path = env_path def has_blocked_weights(): @@ -547,23 +587,26 @@ def has_blocked_weights(): if "none" not in custom_ops and "-quant_fp8" not in custom_ops: custom_ops.append("+quant_fp8") - def update_sizes_for_sequence_parallelism(self, - possible_sizes: list) -> list: + def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list: # remove the sizes that not multiple of tp_size when # enable sequence parallelism removed_sizes = [ - size for size in possible_sizes + size + for size in possible_sizes if size % self.parallel_config.tensor_parallel_size != 0 ] if removed_sizes: logger.warning( "Batch sizes %s are removed because they are not " "multiple of tp_size %d when " - "sequence parallelism is enabled", removed_sizes, - self.parallel_config.tensor_parallel_size) + "sequence parallelism is enabled", + removed_sizes, + self.parallel_config.tensor_parallel_size, + ) return [ - size for size in possible_sizes + size + for size in possible_sizes if size % 
self.parallel_config.tensor_parallel_size == 0 ] @@ -607,13 +650,13 @@ def _set_cudagraph_sizes(self): # calculate the default `batch_size_capture_list` batch_size_capture_list = [] - if self.model_config is not None and \ - not self.model_config.enforce_eager: + if self.model_config is not None and not self.model_config.enforce_eager: cuda_graph_sizes = self.scheduler_config.cuda_graph_sizes if len(cuda_graph_sizes) == 1: max_graph_size = cuda_graph_sizes[0] - assert max_graph_size >= 1, "Maximum cudagraph size should be" \ - " greater than or equal to 1." + assert max_graph_size >= 1, ( + "Maximum cudagraph size should be greater than or equal to 1." + ) batch_size_capture_list = [ i for i in [1, 2, 4] if i <= max_graph_size ] + list(range(8, max_graph_size + 1, 8)) @@ -621,18 +664,19 @@ def _set_cudagraph_sizes(self): batch_size_capture_list = sorted(cuda_graph_sizes) else: raise TypeError(f"Invalid value for {cuda_graph_sizes=}.") - if self.parallel_config.tensor_parallel_size > 1 and \ - self.compilation_config.pass_config.enable_sequence_parallelism: - batch_size_capture_list = \ - self.update_sizes_for_sequence_parallelism(batch_size_capture_list) + if ( + self.parallel_config.tensor_parallel_size > 1 + and self.compilation_config.pass_config.enable_sequence_parallelism + ): + batch_size_capture_list = self.update_sizes_for_sequence_parallelism( + batch_size_capture_list + ) max_num_tokens = self.scheduler_config.max_num_batched_tokens batch_size_capture_list = [ - size for size in batch_size_capture_list - if size <= max_num_tokens + size for size in batch_size_capture_list if size <= max_num_tokens ] - self.compilation_config.init_with_cudagraph_sizes( - batch_size_capture_list) + self.compilation_config.init_with_cudagraph_sizes(batch_size_capture_list) def recalculate_max_model_len(self, max_model_len: int): # Can only be called in try_verify_and_update_config @@ -655,7 +699,10 @@ def try_verify_and_update_config(self): return from 
vllm.model_executor.models.config import ( - MODELS_CONFIG_MAP, HybridAttentionMambaModelConfig) + MODELS_CONFIG_MAP, + HybridAttentionMambaModelConfig, + ) + cls = MODELS_CONFIG_MAP.get(architecture, None) if cls is not None: cls.verify_and_update_config(self) @@ -665,21 +712,26 @@ def try_verify_and_update_config(self): if self.model_config.convert_type == "classify": # Maybe convert ForCausalLM into ForSequenceClassification model. - from vllm.model_executor.models.adapters import ( - SequenceClassificationConfig) + from vllm.model_executor.models.adapters import SequenceClassificationConfig + SequenceClassificationConfig.verify_and_update_config(self) if hasattr(self.model_config, "model_weights") and is_runai_obj_uri( - self.model_config.model_weights): + self.model_config.model_weights + ): if self.load_config.load_format == "auto": - logger.info("Detected Run:ai model config. " - "Overriding `load_format` to 'runai_streamer'") + logger.info( + "Detected Run:ai model config. " + "Overriding `load_format` to 'runai_streamer'" + ) self.load_config.load_format = "runai_streamer" elif self.load_config.load_format != "runai_streamer": - raise ValueError(f"To load a model from S3, 'load_format' " - f"must be 'runai_streamer', " - f"but got '{self.load_config.load_format}'. " - f"Model: {self.model_config.model}") + raise ValueError( + f"To load a model from S3, 'load_format' " + f"must be 'runai_streamer', " + f"but got '{self.load_config.load_format}'. 
" + f"Model: {self.model_config.model}" + ) def compile_debug_dump_path(self) -> Optional[Path]: """Returns a rank-aware path for dumping @@ -690,8 +742,11 @@ def compile_debug_dump_path(self) -> Optional[Path]: tp_rank = self.parallel_config.rank dp_rank = self.parallel_config.data_parallel_rank data_parallel_size = self.parallel_config.data_parallel_size - append_path = f"rank_{tp_rank}" if data_parallel_size == 1 \ + append_path = ( + f"rank_{tp_rank}" + if data_parallel_size == 1 else f"rank_{tp_rank}_dp_{dp_rank}" + ) path = self.compilation_config.debug_dump_path / append_path return path @@ -724,7 +779,8 @@ def __str__(self): f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, " f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa f"pooler_config={self.model_config.pooler_config!r}, " - f"compilation_config={self.compilation_config!r}") + f"compilation_config={self.compilation_config!r}" + ) _current_vllm_config: Optional[VllmConfig] = None @@ -732,9 +788,9 @@ def __str__(self): @contextmanager -def set_current_vllm_config(vllm_config: VllmConfig, - check_compile=False, - prefix: Optional[str] = None): +def set_current_vllm_config( + vllm_config: VllmConfig, check_compile=False, prefix: Optional[str] = None +): """ Temporarily set the current vLLM config. Used during model initialization. 
@@ -746,6 +802,7 @@ def set_current_vllm_config(vllm_config: VllmConfig, old_vllm_config = _current_vllm_config old_prefix = _current_prefix from vllm.compilation.counter import compilation_counter + num_models_seen = compilation_counter.num_models_seen try: _current_vllm_config = vllm_config @@ -757,9 +814,11 @@ def set_current_vllm_config(vllm_config: VllmConfig, if check_compile: vllm_config.compilation_config.custom_op_log_check() - if check_compile and \ - vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \ - and compilation_counter.num_models_seen == num_models_seen: + if ( + check_compile + and vllm_config.compilation_config.level == CompilationLevel.PIECEWISE + and compilation_counter.num_models_seen == num_models_seen + ): # If the model supports compilation, # compilation_counter.num_models_seen should be increased # by at least 1. @@ -769,7 +828,8 @@ def set_current_vllm_config(vllm_config: VllmConfig, "`torch.compile` is turned on, but the model %s" " does not support it. Please open an issue on GitHub" " if you want it to be supported.", - vllm_config.model_config.model) + vllm_config.model_config.model, + ) finally: _current_vllm_config = old_vllm_config _current_prefix = old_prefix @@ -797,9 +857,10 @@ def get_current_vllm_config() -> VllmConfig: def get_layers_from_vllm_config( - vllm_config: VllmConfig, - layer_type: type[T], - layer_names: Optional[list[str]] = None) -> dict[str, T]: + vllm_config: VllmConfig, + layer_type: type[T], + layer_names: Optional[list[str]] = None, +) -> dict[str, T]: """ Get layers from the vLLM config. 
@@ -810,8 +871,7 @@ def get_layers_from_vllm_config( """ if layer_names is None: - layer_names = list( - vllm_config.compilation_config.static_forward_context.keys()) + layer_names = list(vllm_config.compilation_config.static_forward_context.keys()) forward_context = vllm_config.compilation_config.static_forward_context diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 76ad5c615762..c664d90f96d1 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -10,9 +10,22 @@ import sys from dataclasses import MISSING, dataclass, fields, is_dataclass from itertools import permutations -from typing import (TYPE_CHECKING, Annotated, Any, Callable, Dict, List, - Literal, Optional, Type, TypeVar, Union, cast, get_args, - get_origin) +from typing import ( + TYPE_CHECKING, + Annotated, + Any, + Callable, + Dict, + List, + Literal, + Optional, + Type, + TypeVar, + Union, + cast, + get_args, + get_origin, +) import huggingface_hub import regex as re @@ -21,17 +34,43 @@ from typing_extensions import TypeIs, deprecated import vllm.envs as envs -from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, - ConfigType, ConvertOption, DetailedTraceModules, - Device, DeviceConfig, DistributedExecutorBackend, - EPLBConfig, HfOverrides, KVEventsConfig, - KVTransferConfig, LoadConfig, LogprobsMode, - LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig, - ModelDType, ObservabilityConfig, ParallelConfig, - PoolerConfig, PrefixCachingHashAlgo, ReasoningConfig, - RunnerOption, SchedulerConfig, SchedulerPolicy, - SpeculativeConfig, StructuredOutputsConfig, - TaskOption, TokenizerMode, VllmConfig, get_attr_docs) +from vllm.config import ( + BlockSize, + CacheConfig, + CacheDType, + CompilationConfig, + ConfigType, + ConvertOption, + DetailedTraceModules, + Device, + DeviceConfig, + DistributedExecutorBackend, + EPLBConfig, + HfOverrides, + KVEventsConfig, + KVTransferConfig, + LoadConfig, + LogprobsMode, + LoRAConfig, + MambaDType, + 
MMEncoderTPMode, + ModelConfig, + ModelDType, + ObservabilityConfig, + ParallelConfig, + PoolerConfig, + PrefixCachingHashAlgo, + ReasoningConfig, + RunnerOption, + SchedulerConfig, + SchedulerPolicy, + SpeculativeConfig, + StructuredOutputsConfig, + TaskOption, + TokenizerMode, + VllmConfig, + get_attr_docs, +) from vllm.config.multimodal import MMCacheType, MultiModalConfig from vllm.config.parallel import ExpertPlacementStrategy from vllm.config.utils import get_field @@ -41,11 +80,13 @@ from vllm.ray.lazy_utils import is_ray_initialized from vllm.reasoning import ReasoningParserManager from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 -from vllm.transformers_utils.config import (get_model_path, is_interleaved, - maybe_override_with_speculators) +from vllm.transformers_utils.config import ( + get_model_path, + is_interleaved, + maybe_override_with_speculators, +) from vllm.transformers_utils.utils import check_gguf_file -from vllm.utils import (FlexibleArgumentParser, GiB_bytes, get_ip, - is_in_ray_actor) +from vllm.utils import FlexibleArgumentParser, GiB_bytes, get_ip, is_in_ray_actor from vllm.v1.sample.logits_processor import LogitsProcessor # yapf: enable @@ -70,20 +111,18 @@ def parse_type(return_type: Callable[[str], T]) -> Callable[[str], T]: - def _parse_type(val: str) -> T: try: return return_type(val) except ValueError as e: raise argparse.ArgumentTypeError( - f"Value {val} cannot be converted to {return_type}.") from e + f"Value {val} cannot be converted to {return_type}." 
+ ) from e return _parse_type -def optional_type( - return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]: - +def optional_type(return_type: Callable[[str], T]) -> Callable[[str], Optional[T]]: def _optional_type(val: str) -> Optional[T]: if val == "" or val == "None": return None @@ -124,7 +163,8 @@ def literal_to_kwargs(type_hints: set[TypeHint]) -> dict[str, Any]: if not all(isinstance(option, option_type) for option in options): raise ValueError( "All options must be of the same type. " - f"Got {options} with types {[type(c) for c in options]}") + f"Got {options} with types {[type(c) for c in options]}" + ) kwarg = "metavar" if contains_type(type_hints, str) else "choices" return {"type": option_type, kwarg: sorted(options)} @@ -191,8 +231,9 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]: kwargs[name] = {"default": default, "help": help} # Set other kwargs based on the type hints - json_tip = ("Should either be a valid JSON string or JSON keys passed " - "individually.") + json_tip = ( + "Should either be a valid JSON string or JSON keys passed individually." + ) if dataclass_cls is not None: def parse_dataclass(val: str, cls=dataclass_cls) -> Any: @@ -214,7 +255,8 @@ def parse_dataclass(val: str, cls=dataclass_cls) -> Any: tuple_type = types[0] assert all(t is tuple_type for t in types if t is not Ellipsis), ( "All non-Ellipsis tuple elements must be of the same " - f"type. Got {types}.") + f"type. Got {types}." 
+ ) kwargs[name]["type"] = tuple_type kwargs[name]["nargs"] = "+" if Ellipsis in types else len(types) elif contains_type(type_hints, list): @@ -240,19 +282,20 @@ def parse_dataclass(val: str, cls=dataclass_cls) -> Any: kwargs[name]["help"] += f"\n\n{human_readable_int.__doc__}" elif contains_type(type_hints, float): kwargs[name]["type"] = float - elif (contains_type(type_hints, dict) - and (contains_type(type_hints, str) - or any(is_not_builtin(th) for th in type_hints))): + elif contains_type(type_hints, dict) and ( + contains_type(type_hints, str) + or any(is_not_builtin(th) for th in type_hints) + ): kwargs[name]["type"] = union_dict_and_str elif contains_type(type_hints, dict): kwargs[name]["type"] = parse_type(json.loads) kwargs[name]["help"] += f"\n\n{json_tip}" - elif (contains_type(type_hints, str) - or any(is_not_builtin(th) for th in type_hints)): + elif contains_type(type_hints, str) or any( + is_not_builtin(th) for th in type_hints + ): kwargs[name]["type"] = str else: - raise ValueError( - f"Unsupported type {type_hints} for argument {name}.") + raise ValueError(f"Unsupported type {type_hints} for argument {name}.") # If the type hint was a sequence of literals, use the helper function # to update the type and choices @@ -284,9 +327,9 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]: @dataclass class EngineArgs: """Arguments for vLLM engine.""" + model: str = ModelConfig.model - served_model_name: Optional[Union[ - str, List[str]]] = ModelConfig.served_model_name + served_model_name: Optional[Union[str, List[str]]] = ModelConfig.served_model_name tokenizer: Optional[str] = ModelConfig.tokenizer hf_config_path: Optional[str] = ModelConfig.hf_config_path runner: RunnerOption = ModelConfig.runner @@ -297,8 +340,7 @@ class EngineArgs: tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode trust_remote_code: bool = ModelConfig.trust_remote_code allowed_local_media_path: str = ModelConfig.allowed_local_media_path - allowed_media_domains: Optional[ - 
list[str]] = ModelConfig.allowed_media_domains + allowed_media_domains: Optional[list[str]] = ModelConfig.allowed_media_domains download_dir: Optional[str] = LoadConfig.download_dir safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy load_format: Union[str, LoadFormats] = LoadConfig.load_format @@ -307,19 +349,17 @@ class EngineArgs: kv_cache_dtype: CacheDType = CacheConfig.cache_dtype seed: Optional[int] = ModelConfig.seed max_model_len: Optional[int] = ModelConfig.max_model_len - cuda_graph_sizes: list[int] = get_field(SchedulerConfig, - "cuda_graph_sizes") + cuda_graph_sizes: list[int] = get_field(SchedulerConfig, "cuda_graph_sizes") # Note: Specifying a custom executor backend by passing a class # is intended for expert use only. The API may change without # notice. - distributed_executor_backend: Optional[Union[ - str, DistributedExecutorBackend, - Type[ExecutorBase]]] = ParallelConfig.distributed_executor_backend + distributed_executor_backend: Optional[ + Union[str, DistributedExecutorBackend, Type[ExecutorBase]] + ] = ParallelConfig.distributed_executor_backend # number of P/D disaggregation (or other disaggregation) workers pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size tensor_parallel_size: int = ParallelConfig.tensor_parallel_size - decode_context_parallel_size: int = \ - ParallelConfig.decode_context_parallel_size + decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size data_parallel_size: int = ParallelConfig.data_parallel_size data_parallel_rank: Optional[int] = None data_parallel_start_rank: Optional[int] = None @@ -330,38 +370,37 @@ class EngineArgs: data_parallel_backend: str = ParallelConfig.data_parallel_backend enable_expert_parallel: bool = ParallelConfig.enable_expert_parallel enable_dbo: bool = ParallelConfig.enable_dbo - dbo_decode_token_threshold: int = \ - ParallelConfig.dbo_decode_token_threshold - dbo_prefill_token_threshold: int = \ - 
ParallelConfig.dbo_prefill_token_threshold + dbo_decode_token_threshold: int = ParallelConfig.dbo_decode_token_threshold + dbo_prefill_token_threshold: int = ParallelConfig.dbo_prefill_token_threshold eplb_config: EPLBConfig = get_field(ParallelConfig, "eplb_config") enable_eplb: bool = ParallelConfig.enable_eplb - expert_placement_strategy: ExpertPlacementStrategy = \ + expert_placement_strategy: ExpertPlacementStrategy = ( ParallelConfig.expert_placement_strategy + ) _api_process_count: int = ParallelConfig._api_process_count _api_process_rank: int = ParallelConfig._api_process_rank num_redundant_experts: int = EPLBConfig.num_redundant_experts eplb_window_size: int = EPLBConfig.window_size eplb_step_interval: int = EPLBConfig.step_interval eplb_log_balancedness: bool = EPLBConfig.log_balancedness - max_parallel_loading_workers: Optional[ - int] = ParallelConfig.max_parallel_loading_workers + max_parallel_loading_workers: Optional[int] = ( + ParallelConfig.max_parallel_loading_workers + ) block_size: Optional[BlockSize] = CacheConfig.block_size enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching - prefix_caching_hash_algo: PrefixCachingHashAlgo = \ + prefix_caching_hash_algo: PrefixCachingHashAlgo = ( CacheConfig.prefix_caching_hash_algo + ) disable_sliding_window: bool = ModelConfig.disable_sliding_window disable_cascade_attn: bool = ModelConfig.disable_cascade_attn swap_space: float = CacheConfig.swap_space cpu_offload_gb: float = CacheConfig.cpu_offload_gb gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization kv_cache_memory_bytes: Optional[int] = CacheConfig.kv_cache_memory_bytes - max_num_batched_tokens: Optional[ - int] = SchedulerConfig.max_num_batched_tokens + max_num_batched_tokens: Optional[int] = SchedulerConfig.max_num_batched_tokens max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills - long_prefill_token_threshold: int = \ 
- SchedulerConfig.long_prefill_token_threshold + long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs max_logprobs: int = ModelConfig.max_logprobs logprobs_mode: LogprobsMode = ModelConfig.logprobs_mode @@ -376,20 +415,22 @@ class EngineArgs: quantization: Optional[QuantizationMethods] = ModelConfig.quantization enforce_eager: bool = ModelConfig.enforce_eager disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce - limit_mm_per_prompt: dict[str, Union[int, dict[str, int]]] = \ - get_field(MultiModalConfig, "limit_per_prompt") + limit_mm_per_prompt: dict[str, Union[int, dict[str, int]]] = get_field( + MultiModalConfig, "limit_per_prompt" + ) interleave_mm_strings: bool = MultiModalConfig.interleave_mm_strings - media_io_kwargs: dict[str, dict[str, - Any]] = get_field(MultiModalConfig, - "media_io_kwargs") - mm_processor_kwargs: Optional[Dict[str, Any]] = \ - MultiModalConfig.mm_processor_kwargs + media_io_kwargs: dict[str, dict[str, Any]] = get_field( + MultiModalConfig, "media_io_kwargs" + ) + mm_processor_kwargs: Optional[Dict[str, Any]] = MultiModalConfig.mm_processor_kwargs disable_mm_preprocessor_cache: bool = False # DEPRECATED mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb - mm_processor_cache_type: Optional[MMCacheType] = \ + mm_processor_cache_type: Optional[MMCacheType] = ( MultiModalConfig.mm_processor_cache_type - mm_shm_cache_max_object_size_mb: int = \ + ) + mm_shm_cache_max_object_size_mb: int = ( MultiModalConfig.mm_shm_cache_max_object_size_mb + ) mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode io_processor_plugin: Optional[str] = None skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling @@ -399,31 +440,28 @@ class EngineArgs: enable_lora_bias: bool = LoRAConfig.bias_enabled max_loras: int = LoRAConfig.max_loras max_lora_rank: int = LoRAConfig.max_lora_rank - default_mm_loras: 
Optional[Dict[str, str]] = \ - LoRAConfig.default_mm_loras + default_mm_loras: Optional[Dict[str, str]] = LoRAConfig.default_mm_loras fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras max_cpu_loras: Optional[int] = LoRAConfig.max_cpu_loras lora_dtype: Optional[Union[str, torch.dtype]] = LoRAConfig.lora_dtype lora_extra_vocab_size: int = LoRAConfig.lora_extra_vocab_size ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight - num_gpu_blocks_override: Optional[ - int] = CacheConfig.num_gpu_blocks_override + num_gpu_blocks_override: Optional[int] = CacheConfig.num_gpu_blocks_override num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots - model_loader_extra_config: dict = \ - get_field(LoadConfig, "model_loader_extra_config") - ignore_patterns: Optional[Union[str, - List[str]]] = LoadConfig.ignore_patterns + model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config") + ignore_patterns: Optional[Union[str, List[str]]] = LoadConfig.ignore_patterns - enable_chunked_prefill: Optional[ - bool] = SchedulerConfig.enable_chunked_prefill + enable_chunked_prefill: Optional[bool] = SchedulerConfig.enable_chunked_prefill disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input disable_hybrid_kv_cache_manager: bool = ( - SchedulerConfig.disable_hybrid_kv_cache_manager) + SchedulerConfig.disable_hybrid_kv_cache_manager + ) structured_outputs_config: StructuredOutputsConfig = get_field( - VllmConfig, "structured_outputs_config") + VllmConfig, "structured_outputs_config" + ) reasoning_parser: str = StructuredOutputsConfig.reasoning_parser # Deprecated guided decoding fields guided_decoding_backend: Optional[str] = None @@ -431,38 +469,38 @@ class EngineArgs: guided_decoding_disable_any_whitespace: Optional[bool] = None guided_decoding_disable_additional_properties: Optional[bool] = None - logits_processor_pattern: Optional[ - str] = ModelConfig.logits_processor_pattern + logits_processor_pattern: Optional[str] 
= ModelConfig.logits_processor_pattern speculative_config: Optional[Dict[str, Any]] = None - show_hidden_metrics_for_version: Optional[str] = \ + show_hidden_metrics_for_version: Optional[str] = ( ObservabilityConfig.show_hidden_metrics_for_version - otlp_traces_endpoint: Optional[str] = \ - ObservabilityConfig.otlp_traces_endpoint - collect_detailed_traces: Optional[list[DetailedTraceModules]] = \ + ) + otlp_traces_endpoint: Optional[str] = ObservabilityConfig.otlp_traces_endpoint + collect_detailed_traces: Optional[list[DetailedTraceModules]] = ( ObservabilityConfig.collect_detailed_traces + ) scheduling_policy: SchedulerPolicy = SchedulerConfig.policy scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls pooler_config: Optional[PoolerConfig] = ModelConfig.pooler_config - override_pooler_config: Optional[Union[dict, PoolerConfig]] = \ + override_pooler_config: Optional[Union[dict, PoolerConfig]] = ( ModelConfig.override_pooler_config - compilation_config: CompilationConfig = \ - get_field(VllmConfig, "compilation_config") + ) + compilation_config: CompilationConfig = get_field(VllmConfig, "compilation_config") worker_cls: str = ParallelConfig.worker_cls worker_extension_cls: str = ParallelConfig.worker_extension_cls kv_transfer_config: Optional[KVTransferConfig] = None kv_events_config: Optional[KVEventsConfig] = None - reasoning_config: ReasoningConfig = get_field(VllmConfig, - "reasoning_config") + reasoning_config: ReasoningConfig = get_field(VllmConfig, "reasoning_config") generation_config: str = ModelConfig.generation_config enable_sleep_mode: bool = ModelConfig.enable_sleep_mode - override_generation_config: dict[str, Any] = \ - get_field(ModelConfig, "override_generation_config") + override_generation_config: dict[str, Any] = get_field( + ModelConfig, "override_generation_config" + ) model_impl: str = ModelConfig.model_impl override_attention_dtype: str = ModelConfig.override_attention_dtype @@ -470,8 +508,7 @@ class EngineArgs: 
mamba_cache_dtype: MambaDType = CacheConfig.mamba_cache_dtype mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype - additional_config: dict[str, Any] = \ - get_field(VllmConfig, "additional_config") + additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config") use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load pt_load_map_location: str = LoadConfig.pt_load_map_location @@ -479,34 +516,36 @@ class EngineArgs: # DEPRECATED enable_multimodal_encoder_data_parallel: bool = False - logits_processors: Optional[list[Union[ - str, type[LogitsProcessor]]]] = ModelConfig.logits_processors + logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = ( + ModelConfig.logits_processors + ) """Custom logitproc types""" async_scheduling: bool = SchedulerConfig.async_scheduling - kv_sharing_fast_prefill: bool = \ - CacheConfig.kv_sharing_fast_prefill + kv_sharing_fast_prefill: bool = CacheConfig.kv_sharing_fast_prefill def __post_init__(self): # support `EngineArgs(compilation_config={...})` # without having to manually construct a # CompilationConfig object if isinstance(self.compilation_config, dict): - self.compilation_config = CompilationConfig( - **self.compilation_config) + self.compilation_config = CompilationConfig(**self.compilation_config) if isinstance(self.eplb_config, dict): self.eplb_config = EPLBConfig(**self.eplb_config) # Setup plugins from vllm.plugins import load_general_plugins + load_general_plugins() # when use hf offline,replace model id to local model path if huggingface_hub.constants.HF_HUB_OFFLINE: model_id = self.model self.model = get_model_path(self.model, self.revision) logger.info( - "HF_HUB_OFFLINE is True, replace model_id [%s] " \ - "to model_path [%s]",model_id, self.model) + "HF_HUB_OFFLINE is True, replace model_id [%s] to model_path [%s]", + model_id, + self.model, + ) @staticmethod def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: @@ -518,86 +557,92 @@ def 
add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: title="ModelConfig", description=ModelConfig.__doc__, ) - if not ('serve' in sys.argv[1:] and '--help' in sys.argv[1:]): + if not ("serve" in sys.argv[1:] and "--help" in sys.argv[1:]): model_group.add_argument("--model", **model_kwargs["model"]) model_group.add_argument("--runner", **model_kwargs["runner"]) model_group.add_argument("--convert", **model_kwargs["convert"]) - model_group.add_argument("--task", - **model_kwargs["task"], - deprecated=True) + model_group.add_argument("--task", **model_kwargs["task"], deprecated=True) model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"]) - model_group.add_argument("--tokenizer-mode", - **model_kwargs["tokenizer_mode"]) - model_group.add_argument("--trust-remote-code", - **model_kwargs["trust_remote_code"]) + model_group.add_argument("--tokenizer-mode", **model_kwargs["tokenizer_mode"]) + model_group.add_argument( + "--trust-remote-code", **model_kwargs["trust_remote_code"] + ) model_group.add_argument("--dtype", **model_kwargs["dtype"]) model_group.add_argument("--seed", **model_kwargs["seed"]) - model_group.add_argument("--hf-config-path", - **model_kwargs["hf_config_path"]) - model_group.add_argument("--allowed-local-media-path", - **model_kwargs["allowed_local_media_path"]) - model_group.add_argument("--allowed-media-domains", - **model_kwargs["allowed_media_domains"]) + model_group.add_argument("--hf-config-path", **model_kwargs["hf_config_path"]) + model_group.add_argument( + "--allowed-local-media-path", **model_kwargs["allowed_local_media_path"] + ) + model_group.add_argument( + "--allowed-media-domains", **model_kwargs["allowed_media_domains"] + ) model_group.add_argument("--revision", **model_kwargs["revision"]) - model_group.add_argument("--code-revision", - **model_kwargs["code_revision"]) - model_group.add_argument("--rope-scaling", - **model_kwargs["rope_scaling"]) + model_group.add_argument("--code-revision", 
**model_kwargs["code_revision"]) + model_group.add_argument("--rope-scaling", **model_kwargs["rope_scaling"]) model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"]) - model_group.add_argument("--tokenizer-revision", - **model_kwargs["tokenizer_revision"]) - model_group.add_argument("--max-model-len", - **model_kwargs["max_model_len"]) - model_group.add_argument("--quantization", "-q", - **model_kwargs["quantization"]) - model_group.add_argument("--enforce-eager", - **model_kwargs["enforce_eager"]) - model_group.add_argument("--max-logprobs", - **model_kwargs["max_logprobs"]) - model_group.add_argument("--logprobs-mode", - **model_kwargs["logprobs_mode"]) - model_group.add_argument("--disable-sliding-window", - **model_kwargs["disable_sliding_window"]) - model_group.add_argument("--disable-cascade-attn", - **model_kwargs["disable_cascade_attn"]) - model_group.add_argument("--skip-tokenizer-init", - **model_kwargs["skip_tokenizer_init"]) - model_group.add_argument("--enable-prompt-embeds", - **model_kwargs["enable_prompt_embeds"]) - model_group.add_argument("--served-model-name", - **model_kwargs["served_model_name"]) - model_group.add_argument("--config-format", - **model_kwargs["config_format"]) + model_group.add_argument( + "--tokenizer-revision", **model_kwargs["tokenizer_revision"] + ) + model_group.add_argument("--max-model-len", **model_kwargs["max_model_len"]) + model_group.add_argument("--quantization", "-q", **model_kwargs["quantization"]) + model_group.add_argument("--enforce-eager", **model_kwargs["enforce_eager"]) + model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"]) + model_group.add_argument("--logprobs-mode", **model_kwargs["logprobs_mode"]) + model_group.add_argument( + "--disable-sliding-window", **model_kwargs["disable_sliding_window"] + ) + model_group.add_argument( + "--disable-cascade-attn", **model_kwargs["disable_cascade_attn"] + ) + model_group.add_argument( + "--skip-tokenizer-init", 
**model_kwargs["skip_tokenizer_init"] + ) + model_group.add_argument( + "--enable-prompt-embeds", **model_kwargs["enable_prompt_embeds"] + ) + model_group.add_argument( + "--served-model-name", **model_kwargs["served_model_name"] + ) + model_group.add_argument("--config-format", **model_kwargs["config_format"]) # This one is a special case because it can bool # or str. TODO: Handle this in get_kwargs - model_group.add_argument("--hf-token", - type=str, - nargs="?", - const=True, - default=model_kwargs["hf_token"]["default"], - help=model_kwargs["hf_token"]["help"]) - model_group.add_argument("--hf-overrides", - **model_kwargs["hf_overrides"]) - model_group.add_argument("--pooler-config", - **model_kwargs["pooler_config"]) - model_group.add_argument("--override-pooler-config", - **model_kwargs["override_pooler_config"], - deprecated=True) - model_group.add_argument("--logits-processor-pattern", - **model_kwargs["logits_processor_pattern"]) - model_group.add_argument("--generation-config", - **model_kwargs["generation_config"]) - model_group.add_argument("--override-generation-config", - **model_kwargs["override_generation_config"]) - model_group.add_argument("--enable-sleep-mode", - **model_kwargs["enable_sleep_mode"]) + model_group.add_argument( + "--hf-token", + type=str, + nargs="?", + const=True, + default=model_kwargs["hf_token"]["default"], + help=model_kwargs["hf_token"]["help"], + ) + model_group.add_argument("--hf-overrides", **model_kwargs["hf_overrides"]) + model_group.add_argument("--pooler-config", **model_kwargs["pooler_config"]) + model_group.add_argument( + "--override-pooler-config", + **model_kwargs["override_pooler_config"], + deprecated=True, + ) + model_group.add_argument( + "--logits-processor-pattern", **model_kwargs["logits_processor_pattern"] + ) + model_group.add_argument( + "--generation-config", **model_kwargs["generation_config"] + ) + model_group.add_argument( + "--override-generation-config", 
**model_kwargs["override_generation_config"] + ) + model_group.add_argument( + "--enable-sleep-mode", **model_kwargs["enable_sleep_mode"] + ) model_group.add_argument("--model-impl", **model_kwargs["model_impl"]) - model_group.add_argument("--override-attention-dtype", - **model_kwargs["override_attention_dtype"]) - model_group.add_argument("--logits-processors", - **model_kwargs["logits_processors"]) - model_group.add_argument("--io-processor-plugin", - **model_kwargs["io_processor_plugin"]) + model_group.add_argument( + "--override-attention-dtype", **model_kwargs["override_attention_dtype"] + ) + model_group.add_argument( + "--logits-processors", **model_kwargs["logits_processors"] + ) + model_group.add_argument( + "--io-processor-plugin", **model_kwargs["io_processor_plugin"] + ) # Model loading arguments load_kwargs = get_kwargs(LoadConfig) @@ -606,18 +651,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: description=LoadConfig.__doc__, ) load_group.add_argument("--load-format", **load_kwargs["load_format"]) - load_group.add_argument("--download-dir", - **load_kwargs["download_dir"]) - load_group.add_argument("--safetensors-load-strategy", - **load_kwargs["safetensors_load_strategy"]) - load_group.add_argument("--model-loader-extra-config", - **load_kwargs["model_loader_extra_config"]) - load_group.add_argument("--ignore-patterns", - **load_kwargs["ignore_patterns"]) - load_group.add_argument("--use-tqdm-on-load", - **load_kwargs["use_tqdm_on_load"]) - load_group.add_argument('--pt-load-map-location', - **load_kwargs["pt_load_map_location"]) + load_group.add_argument("--download-dir", **load_kwargs["download_dir"]) + load_group.add_argument( + "--safetensors-load-strategy", **load_kwargs["safetensors_load_strategy"] + ) + load_group.add_argument( + "--model-loader-extra-config", **load_kwargs["model_loader_extra_config"] + ) + load_group.add_argument("--ignore-patterns", **load_kwargs["ignore_patterns"]) + 
load_group.add_argument("--use-tqdm-on-load", **load_kwargs["use_tqdm_on_load"]) + load_group.add_argument( + "--pt-load-map-location", **load_kwargs["pt_load_map_location"] + ) # Structured outputs arguments structured_outputs_kwargs = get_kwargs(StructuredOutputsConfig) @@ -629,7 +674,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--reasoning-parser", # This choice is a special case because it's not static choices=list(ReasoningParserManager.reasoning_parsers), - **structured_outputs_kwargs["reasoning_parser"]) + **structured_outputs_kwargs["reasoning_parser"], + ) # Deprecated guided decoding arguments for arg, type in [ ("--guided-decoding-backend", str), @@ -641,7 +687,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: arg, type=type, help=(f"[DEPRECATED] {arg} will be removed in v0.12.0."), - deprecated=True) + deprecated=True, + ) # Parallel arguments parallel_kwargs = get_kwargs(ParallelConfig) @@ -651,111 +698,128 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ) parallel_group.add_argument( "--distributed-executor-backend", - **parallel_kwargs["distributed_executor_backend"]) + **parallel_kwargs["distributed_executor_backend"], + ) parallel_group.add_argument( - "--pipeline-parallel-size", "-pp", - **parallel_kwargs["pipeline_parallel_size"]) - parallel_group.add_argument("--tensor-parallel-size", "-tp", - **parallel_kwargs["tensor_parallel_size"]) + "--pipeline-parallel-size", + "-pp", + **parallel_kwargs["pipeline_parallel_size"], + ) parallel_group.add_argument( - "--decode-context-parallel-size", "-dcp", - **parallel_kwargs["decode_context_parallel_size"]) - parallel_group.add_argument("--data-parallel-size", "-dp", - **parallel_kwargs["data_parallel_size"]) + "--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"] + ) parallel_group.add_argument( - '--data-parallel-rank', - '-dpn', + "--decode-context-parallel-size", + "-dcp", + 
**parallel_kwargs["decode_context_parallel_size"], + ) + parallel_group.add_argument( + "--data-parallel-size", "-dp", **parallel_kwargs["data_parallel_size"] + ) + parallel_group.add_argument( + "--data-parallel-rank", + "-dpn", type=int, - help='Data parallel rank of this instance. ' - 'When set, enables external load balancer mode.') - parallel_group.add_argument('--data-parallel-start-rank', - '-dpr', - type=int, - help='Starting data parallel rank ' - 'for secondary nodes.') - parallel_group.add_argument('--data-parallel-size-local', - '-dpl', - type=int, - help='Number of data parallel replicas ' - 'to run on this node.') - parallel_group.add_argument('--data-parallel-address', - '-dpa', - type=str, - help='Address of data parallel cluster ' - 'head-node.') - parallel_group.add_argument('--data-parallel-rpc-port', - '-dpp', - type=int, - help='Port for data parallel RPC ' - 'communication.') - parallel_group.add_argument('--data-parallel-backend', - '-dpb', - type=str, - default='mp', - help='Backend for data parallel, either ' - '"mp" or "ray".') + help="Data parallel rank of this instance. 
" + "When set, enables external load balancer mode.", + ) parallel_group.add_argument( - "--data-parallel-hybrid-lb", - **parallel_kwargs["data_parallel_hybrid_lb"]) + "--data-parallel-start-rank", + "-dpr", + type=int, + help="Starting data parallel rank for secondary nodes.", + ) + parallel_group.add_argument( + "--data-parallel-size-local", + "-dpl", + type=int, + help="Number of data parallel replicas to run on this node.", + ) + parallel_group.add_argument( + "--data-parallel-address", + "-dpa", + type=str, + help="Address of data parallel cluster head-node.", + ) + parallel_group.add_argument( + "--data-parallel-rpc-port", + "-dpp", + type=int, + help="Port for data parallel RPC communication.", + ) + parallel_group.add_argument( + "--data-parallel-backend", + "-dpb", + type=str, + default="mp", + help='Backend for data parallel, either "mp" or "ray".', + ) parallel_group.add_argument( - "--enable-expert-parallel", - **parallel_kwargs["enable_expert_parallel"]) - parallel_group.add_argument("--enable-dbo", - **parallel_kwargs["enable_dbo"]) + "--data-parallel-hybrid-lb", **parallel_kwargs["data_parallel_hybrid_lb"] + ) + parallel_group.add_argument( + "--enable-expert-parallel", **parallel_kwargs["enable_expert_parallel"] + ) + parallel_group.add_argument("--enable-dbo", **parallel_kwargs["enable_dbo"]) parallel_group.add_argument( "--dbo-decode-token-threshold", - **parallel_kwargs["dbo_decode_token_threshold"]) + **parallel_kwargs["dbo_decode_token_threshold"], + ) parallel_group.add_argument( "--dbo-prefill-token-threshold", - **parallel_kwargs["dbo_prefill_token_threshold"]) - parallel_group.add_argument("--enable-eplb", - **parallel_kwargs["enable_eplb"]) - parallel_group.add_argument("--eplb-config", - **parallel_kwargs["eplb_config"]) + **parallel_kwargs["dbo_prefill_token_threshold"], + ) + parallel_group.add_argument("--enable-eplb", **parallel_kwargs["enable_eplb"]) + parallel_group.add_argument("--eplb-config", **parallel_kwargs["eplb_config"]) 
parallel_group.add_argument( "--expert-placement-strategy", - **parallel_kwargs["expert_placement_strategy"]) + **parallel_kwargs["expert_placement_strategy"], + ) parallel_group.add_argument( "--num-redundant-experts", type=int, - help= - "[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.", - deprecated=True) + help="[DEPRECATED] --num-redundant-experts will be removed in v0.12.0.", + deprecated=True, + ) parallel_group.add_argument( "--eplb-window-size", type=int, help="[DEPRECATED] --eplb-window-size will be removed in v0.12.0.", - deprecated=True) + deprecated=True, + ) parallel_group.add_argument( "--eplb-step-interval", type=int, - help= - "[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.", - deprecated=True) + help="[DEPRECATED] --eplb-step-interval will be removed in v0.12.0.", + deprecated=True, + ) parallel_group.add_argument( "--eplb-log-balancedness", action=argparse.BooleanOptionalAction, - help= - "[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.", - deprecated=True) + help="[DEPRECATED] --eplb-log-balancedness will be removed in v0.12.0.", + deprecated=True, + ) parallel_group.add_argument( "--max-parallel-loading-workers", - **parallel_kwargs["max_parallel_loading_workers"]) + **parallel_kwargs["max_parallel_loading_workers"], + ) parallel_group.add_argument( - "--ray-workers-use-nsight", - **parallel_kwargs["ray_workers_use_nsight"]) + "--ray-workers-use-nsight", **parallel_kwargs["ray_workers_use_nsight"] + ) parallel_group.add_argument( "--disable-custom-all-reduce", - **parallel_kwargs["disable_custom_all_reduce"]) - parallel_group.add_argument("--worker-cls", - **parallel_kwargs["worker_cls"]) - parallel_group.add_argument("--worker-extension-cls", - **parallel_kwargs["worker_extension_cls"]) + **parallel_kwargs["disable_custom_all_reduce"], + ) + parallel_group.add_argument("--worker-cls", **parallel_kwargs["worker_cls"]) + parallel_group.add_argument( + "--worker-extension-cls", 
**parallel_kwargs["worker_extension_cls"] + ) parallel_group.add_argument( "--enable-multimodal-encoder-data-parallel", action="store_true", - deprecated=True) + deprecated=True, + ) # KV cache arguments cache_kwargs = get_kwargs(CacheConfig) @@ -764,29 +828,36 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: description=CacheConfig.__doc__, ) cache_group.add_argument("--block-size", **cache_kwargs["block_size"]) - cache_group.add_argument("--gpu-memory-utilization", - **cache_kwargs["gpu_memory_utilization"]) - cache_group.add_argument("--kv-cache-memory-bytes", - **cache_kwargs["kv_cache_memory_bytes"]) + cache_group.add_argument( + "--gpu-memory-utilization", **cache_kwargs["gpu_memory_utilization"] + ) + cache_group.add_argument( + "--kv-cache-memory-bytes", **cache_kwargs["kv_cache_memory_bytes"] + ) cache_group.add_argument("--swap-space", **cache_kwargs["swap_space"]) - cache_group.add_argument("--kv-cache-dtype", - **cache_kwargs["cache_dtype"]) - cache_group.add_argument("--num-gpu-blocks-override", - **cache_kwargs["num_gpu_blocks_override"]) - cache_group.add_argument("--enable-prefix-caching", - **cache_kwargs["enable_prefix_caching"]) - cache_group.add_argument("--prefix-caching-hash-algo", - **cache_kwargs["prefix_caching_hash_algo"]) - cache_group.add_argument("--cpu-offload-gb", - **cache_kwargs["cpu_offload_gb"]) - cache_group.add_argument("--calculate-kv-scales", - **cache_kwargs["calculate_kv_scales"]) - cache_group.add_argument("--kv-sharing-fast-prefill", - **cache_kwargs["kv_sharing_fast_prefill"]) - cache_group.add_argument("--mamba-cache-dtype", - **cache_kwargs["mamba_cache_dtype"]) - cache_group.add_argument("--mamba-ssm-cache-dtype", - **cache_kwargs["mamba_ssm_cache_dtype"]) + cache_group.add_argument("--kv-cache-dtype", **cache_kwargs["cache_dtype"]) + cache_group.add_argument( + "--num-gpu-blocks-override", **cache_kwargs["num_gpu_blocks_override"] + ) + cache_group.add_argument( + 
"--enable-prefix-caching", **cache_kwargs["enable_prefix_caching"] + ) + cache_group.add_argument( + "--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"] + ) + cache_group.add_argument("--cpu-offload-gb", **cache_kwargs["cpu_offload_gb"]) + cache_group.add_argument( + "--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"] + ) + cache_group.add_argument( + "--kv-sharing-fast-prefill", **cache_kwargs["kv_sharing_fast_prefill"] + ) + cache_group.add_argument( + "--mamba-cache-dtype", **cache_kwargs["mamba_cache_dtype"] + ) + cache_group.add_argument( + "--mamba-ssm-cache-dtype", **cache_kwargs["mamba_ssm_cache_dtype"] + ) # Multimodal related configs multimodal_kwargs = get_kwargs(MultiModalConfig) @@ -794,35 +865,41 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: title="MultiModalConfig", description=MultiModalConfig.__doc__, ) - multimodal_group.add_argument("--limit-mm-per-prompt", - **multimodal_kwargs["limit_per_prompt"]) - multimodal_group.add_argument("--media-io-kwargs", - **multimodal_kwargs["media_io_kwargs"]) multimodal_group.add_argument( - "--mm-processor-kwargs", - **multimodal_kwargs["mm_processor_kwargs"]) + "--limit-mm-per-prompt", **multimodal_kwargs["limit_per_prompt"] + ) + multimodal_group.add_argument( + "--media-io-kwargs", **multimodal_kwargs["media_io_kwargs"] + ) + multimodal_group.add_argument( + "--mm-processor-kwargs", **multimodal_kwargs["mm_processor_kwargs"] + ) + multimodal_group.add_argument( + "--mm-processor-cache-gb", **multimodal_kwargs["mm_processor_cache_gb"] + ) multimodal_group.add_argument( - "--mm-processor-cache-gb", - **multimodal_kwargs["mm_processor_cache_gb"]) - multimodal_group.add_argument("--disable-mm-preprocessor-cache", - action="store_true", - deprecated=True) + "--disable-mm-preprocessor-cache", action="store_true", deprecated=True + ) multimodal_group.add_argument( - "--mm-processor-cache-type", - **multimodal_kwargs["mm_processor_cache_type"]) + 
"--mm-processor-cache-type", **multimodal_kwargs["mm_processor_cache_type"] + ) multimodal_group.add_argument( "--mm-shm-cache-max-object-size-mb", - **multimodal_kwargs["mm_shm_cache_max_object_size_mb"]) + **multimodal_kwargs["mm_shm_cache_max_object_size_mb"], + ) multimodal_group.add_argument( - "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"]) + "--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"] + ) + multimodal_group.add_argument( + "--interleave-mm-strings", **multimodal_kwargs["interleave_mm_strings"] + ) multimodal_group.add_argument( - "--interleave-mm-strings", - **multimodal_kwargs["interleave_mm_strings"]) - multimodal_group.add_argument("--skip-mm-profiling", - **multimodal_kwargs["skip_mm_profiling"]) + "--skip-mm-profiling", **multimodal_kwargs["skip_mm_profiling"] + ) multimodal_group.add_argument( - "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"]) + "--video-pruning-rate", **multimodal_kwargs["video_pruning_rate"] + ) # LoRA related configs lora_kwargs = get_kwargs(LoRAConfig) @@ -833,24 +910,23 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: lora_group.add_argument( "--enable-lora", action=argparse.BooleanOptionalAction, - help="If True, enable handling of LoRA adapters.") - lora_group.add_argument("--enable-lora-bias", - **lora_kwargs["bias_enabled"]) + help="If True, enable handling of LoRA adapters.", + ) + lora_group.add_argument("--enable-lora-bias", **lora_kwargs["bias_enabled"]) lora_group.add_argument("--max-loras", **lora_kwargs["max_loras"]) - lora_group.add_argument("--max-lora-rank", - **lora_kwargs["max_lora_rank"]) - lora_group.add_argument("--lora-extra-vocab-size", - **lora_kwargs["lora_extra_vocab_size"]) + lora_group.add_argument("--max-lora-rank", **lora_kwargs["max_lora_rank"]) + lora_group.add_argument( + "--lora-extra-vocab-size", **lora_kwargs["lora_extra_vocab_size"] + ) lora_group.add_argument( "--lora-dtype", **lora_kwargs["lora_dtype"], ) - 
lora_group.add_argument("--max-cpu-loras", - **lora_kwargs["max_cpu_loras"]) - lora_group.add_argument("--fully-sharded-loras", - **lora_kwargs["fully_sharded_loras"]) - lora_group.add_argument("--default-mm-loras", - **lora_kwargs["default_mm_loras"]) + lora_group.add_argument("--max-cpu-loras", **lora_kwargs["max_cpu_loras"]) + lora_group.add_argument( + "--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"] + ) + lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"]) # Observability arguments observability_kwargs = get_kwargs(ObservabilityConfig) @@ -860,21 +936,22 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: ) observability_group.add_argument( "--show-hidden-metrics-for-version", - **observability_kwargs["show_hidden_metrics_for_version"]) + **observability_kwargs["show_hidden_metrics_for_version"], + ) observability_group.add_argument( - "--otlp-traces-endpoint", - **observability_kwargs["otlp_traces_endpoint"]) + "--otlp-traces-endpoint", **observability_kwargs["otlp_traces_endpoint"] + ) # TODO: generalise this special case choices = observability_kwargs["collect_detailed_traces"]["choices"] metavar = f"{{{','.join(choices)}}}" observability_kwargs["collect_detailed_traces"]["metavar"] = metavar observability_kwargs["collect_detailed_traces"]["choices"] += [ - ",".join(p) - for p in permutations(get_args(DetailedTraceModules), r=2) + ",".join(p) for p in permutations(get_args(DetailedTraceModules), r=2) ] observability_group.add_argument( "--collect-detailed-traces", - **observability_kwargs["collect_detailed_traces"]) + **observability_kwargs["collect_detailed_traces"], + ) # Scheduler arguments scheduler_kwargs = get_kwargs(SchedulerConfig) @@ -883,40 +960,49 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: description=SchedulerConfig.__doc__, ) scheduler_group.add_argument( - "--max-num-batched-tokens", - **scheduler_kwargs["max_num_batched_tokens"]) - 
scheduler_group.add_argument("--max-num-seqs", - **scheduler_kwargs["max_num_seqs"]) + "--max-num-batched-tokens", **scheduler_kwargs["max_num_batched_tokens"] + ) scheduler_group.add_argument( - "--max-num-partial-prefills", - **scheduler_kwargs["max_num_partial_prefills"]) + "--max-num-seqs", **scheduler_kwargs["max_num_seqs"] + ) + scheduler_group.add_argument( + "--max-num-partial-prefills", **scheduler_kwargs["max_num_partial_prefills"] + ) scheduler_group.add_argument( "--max-long-partial-prefills", - **scheduler_kwargs["max_long_partial_prefills"]) - scheduler_group.add_argument('--cuda-graph-sizes', - **scheduler_kwargs["cuda_graph_sizes"]) + **scheduler_kwargs["max_long_partial_prefills"], + ) + scheduler_group.add_argument( + "--cuda-graph-sizes", **scheduler_kwargs["cuda_graph_sizes"] + ) scheduler_group.add_argument( "--long-prefill-token-threshold", - **scheduler_kwargs["long_prefill_token_threshold"]) - scheduler_group.add_argument("--num-lookahead-slots", - **scheduler_kwargs["num_lookahead_slots"]) + **scheduler_kwargs["long_prefill_token_threshold"], + ) + scheduler_group.add_argument( + "--num-lookahead-slots", **scheduler_kwargs["num_lookahead_slots"] + ) # multi-step scheduling has been removed; corresponding arguments # are no longer supported. 
- scheduler_group.add_argument("--scheduling-policy", - **scheduler_kwargs["policy"]) scheduler_group.add_argument( - "--enable-chunked-prefill", - **scheduler_kwargs["enable_chunked_prefill"]) + "--scheduling-policy", **scheduler_kwargs["policy"] + ) scheduler_group.add_argument( - "--disable-chunked-mm-input", - **scheduler_kwargs["disable_chunked_mm_input"]) - scheduler_group.add_argument("--scheduler-cls", - **scheduler_kwargs["scheduler_cls"]) + "--enable-chunked-prefill", **scheduler_kwargs["enable_chunked_prefill"] + ) + scheduler_group.add_argument( + "--disable-chunked-mm-input", **scheduler_kwargs["disable_chunked_mm_input"] + ) + scheduler_group.add_argument( + "--scheduler-cls", **scheduler_kwargs["scheduler_cls"] + ) scheduler_group.add_argument( "--disable-hybrid-kv-cache-manager", - **scheduler_kwargs["disable_hybrid_kv_cache_manager"]) - scheduler_group.add_argument("--async-scheduling", - **scheduler_kwargs["async_scheduling"]) + **scheduler_kwargs["disable_hybrid_kv_cache_manager"], + ) + scheduler_group.add_argument( + "--async-scheduling", **scheduler_kwargs["async_scheduling"] + ) # vLLM arguments vllm_kwargs = get_kwargs(VllmConfig) @@ -928,25 +1014,30 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: # create_engine_config. So we set the type to a JSON string here to # delay the Pydantic validation that comes with SpeculativeConfig. 
vllm_kwargs["speculative_config"]["type"] = optional_type(json.loads) - vllm_group.add_argument("--speculative-config", - **vllm_kwargs["speculative_config"]) - vllm_group.add_argument("--kv-transfer-config", - **vllm_kwargs["kv_transfer_config"]) - vllm_group.add_argument('--kv-events-config', - **vllm_kwargs["kv_events_config"]) - vllm_group.add_argument("--compilation-config", "-O", - **vllm_kwargs["compilation_config"]) - vllm_group.add_argument("--reasoning-config", - **vllm_kwargs["reasoning_config"]) - vllm_group.add_argument("--additional-config", - **vllm_kwargs["additional_config"]) - vllm_group.add_argument('--structured-outputs-config', - **vllm_kwargs["structured_outputs_config"]) + vllm_group.add_argument( + "--speculative-config", **vllm_kwargs["speculative_config"] + ) + vllm_group.add_argument( + "--kv-transfer-config", **vllm_kwargs["kv_transfer_config"] + ) + vllm_group.add_argument("--kv-events-config", **vllm_kwargs["kv_events_config"]) + vllm_group.add_argument( + "--compilation-config", "-O", **vllm_kwargs["compilation_config"] + ) + vllm_group.add_argument("--reasoning-config", **vllm_kwargs["reasoning_config"]) + vllm_group.add_argument( + "--additional-config", **vllm_kwargs["additional_config"] + ) + vllm_group.add_argument( + "--structured-outputs-config", **vllm_kwargs["structured_outputs_config"] + ) # Other arguments - parser.add_argument('--disable-log-stats', - action='store_true', - help='Disable logging statistics.') + parser.add_argument( + "--disable-log-stats", + action="store_true", + help="Disable logging statistics.", + ) return parser @@ -955,10 +1046,9 @@ def from_cli_args(cls, args: argparse.Namespace): # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] # Set the attributes from the parsed arguments. 
- engine_args = cls(**{ - attr: getattr(args, attr) - for attr in attrs if hasattr(args, attr) - }) + engine_args = cls( + **{attr: getattr(args, attr) for attr in attrs if hasattr(args, attr)} + ) return engine_args def create_model_config(self) -> ModelConfig: @@ -967,15 +1057,20 @@ def create_model_config(self) -> ModelConfig: self.quantization = self.load_format = "gguf" # NOTE: This is to allow model loading from S3 in CI - if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3 - and self.model in MODELS_ON_S3 and self.load_format == "auto"): + if ( + not isinstance(self, AsyncEngineArgs) + and envs.VLLM_CI_USE_S3 + and self.model in MODELS_ON_S3 + and self.load_format == "auto" + ): self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}" if self.disable_mm_preprocessor_cache: logger.warning( "`--disable-mm-preprocessor-cache` is deprecated " "and will be removed in v0.13. " - "Please use `--mm-processor-cache-gb 0` instead.", ) + "Please use `--mm-processor-cache-gb 0` instead.", + ) self.mm_processor_cache_gb = 0 elif envs.VLLM_MM_INPUT_CACHE_GIB != 4: @@ -992,7 +1087,8 @@ def create_model_config(self) -> ModelConfig: logger.warning( "--enable-multimodal-encoder-data-parallel` is deprecated " "and will be removed in v0.13. " - "Please use `--mm-encoder-tp-mode data` instead.") + "Please use `--mm-encoder-tp-mode data` instead." + ) self.mm_encoder_tp_mode = "data" @@ -1034,8 +1130,7 @@ def create_model_config(self) -> ModelConfig: mm_processor_kwargs=self.mm_processor_kwargs, mm_processor_cache_gb=self.mm_processor_cache_gb, mm_processor_cache_type=self.mm_processor_cache_type, - mm_shm_cache_max_object_size_mb=self. 
- mm_shm_cache_max_object_size_mb, + mm_shm_cache_max_object_size_mb=self.mm_shm_cache_max_object_size_mb, mm_encoder_tp_mode=self.mm_encoder_tp_mode, pooler_config=self.pooler_config, override_pooler_config=self.override_pooler_config, @@ -1051,33 +1146,34 @@ def create_model_config(self) -> ModelConfig: ) def validate_tensorizer_args(self): - from vllm.model_executor.model_loader.tensorizer import ( - TensorizerConfig) + from vllm.model_executor.model_loader.tensorizer import TensorizerConfig + for key in self.model_loader_extra_config: if key in TensorizerConfig._fields: - self.model_loader_extra_config["tensorizer_config"][ - key] = self.model_loader_extra_config[key] + self.model_loader_extra_config["tensorizer_config"][key] = ( + self.model_loader_extra_config[key] + ) def create_load_config(self) -> LoadConfig: - if self.quantization == "bitsandbytes": self.load_format = "bitsandbytes" if self.load_format == "tensorizer": if hasattr(self.model_loader_extra_config, "to_serializable"): self.model_loader_extra_config = ( - self.model_loader_extra_config.to_serializable()) + self.model_loader_extra_config.to_serializable() + ) self.model_loader_extra_config["tensorizer_config"] = {} - self.model_loader_extra_config["tensorizer_config"][ - "tensorizer_dir"] = self.model + self.model_loader_extra_config["tensorizer_config"]["tensorizer_dir"] = ( + self.model + ) self.validate_tensorizer_args() return LoadConfig( load_format=self.load_format, download_dir=self.download_dir, safetensors_load_strategy=self.safetensors_load_strategy, - device="cpu" - if is_online_quantization(self.quantization) else None, + device="cpu" if is_online_quantization(self.quantization) else None, model_loader_extra_config=self.model_loader_extra_config, ignore_patterns=self.ignore_patterns, use_tqdm_on_load=self.use_tqdm_on_load, @@ -1105,12 +1201,14 @@ def create_speculative_config( # Note(Shangming): These parameters are not obtained from the cli arg # '--speculative-config' and must be 
passed in when creating the engine # config. - self.speculative_config.update({ - "target_model_config": target_model_config, - "target_parallel_config": target_parallel_config, - "enable_chunked_prefill": enable_chunked_prefill, - "disable_log_stats": disable_log_stats, - }) + self.speculative_config.update( + { + "target_model_config": target_model_config, + "target_parallel_config": target_parallel_config, + "enable_chunked_prefill": enable_chunked_prefill, + "disable_log_stats": disable_log_stats, + } + ) return SpeculativeConfig(**self.speculative_config) def create_engine_config( @@ -1133,21 +1231,21 @@ def create_engine_config( """ current_platform.pre_register_and_update() - device_config = DeviceConfig( - device=cast(Device, current_platform.device_type)) + device_config = DeviceConfig(device=cast(Device, current_platform.device_type)) model_config = self.create_model_config() self.model = model_config.model self.tokenizer = model_config.tokenizer - (self.model, self.tokenizer, - self.speculative_config) = maybe_override_with_speculators( - model=self.model, - tokenizer=self.tokenizer, - revision=self.revision, - trust_remote_code=self.trust_remote_code, - vllm_speculative_config=self.speculative_config, - ) + (self.model, self.tokenizer, self.speculative_config) = ( + maybe_override_with_speculators( + model=self.model, + tokenizer=self.tokenizer, + revision=self.revision, + trust_remote_code=self.trust_remote_code, + vllm_speculative_config=self.speculative_config, + ) + ) # * If VLLM_USE_V1 is unset, we enable V1 for "supported features" # and fall back to V0 for experimental or unsupported features. @@ -1169,12 +1267,17 @@ def create_engine_config( # Set default arguments for V1 Engine. 
self._set_default_args(usage_context, model_config) # Disable chunked prefill for POWER (ppc64le)/ARM/s390x/RISCV CPUs in V1 - if current_platform.is_cpu() and current_platform.get_cpu_architecture( - ) in (CpuArchEnum.POWERPC, CpuArchEnum.S390X, CpuArchEnum.ARM, - CpuArchEnum.RISCV): - logger.info("Chunked prefill is not supported for ARM and POWER, " - "S390X and RISC-V CPUs; " - "disabling it for V1 backend.") + if current_platform.is_cpu() and current_platform.get_cpu_architecture() in ( + CpuArchEnum.POWERPC, + CpuArchEnum.S390X, + CpuArchEnum.ARM, + CpuArchEnum.RISCV, + ): + logger.info( + "Chunked prefill is not supported for ARM and POWER, " + "S390X and RISC-V CPUs; " + "disabling it for V1 backend." + ) self.enable_chunked_prefill = False assert self.enable_chunked_prefill is not None @@ -1190,8 +1293,7 @@ def create_engine_config( # because the world size does not change by dcp, it simply # reuses the GPUs of TP group, and split one TP group into # tp_size//dcp_size DCP groups. - assert self.tensor_parallel_size % self.decode_context_parallel_size \ - == 0, ( + assert self.tensor_parallel_size % self.decode_context_parallel_size == 0, ( f"tp_size={self.tensor_parallel_size} must be divisible by" f"dcp_size={self.decode_context_parallel_size}." ) @@ -1220,6 +1322,7 @@ def create_engine_config( # of a Ray task, therefore we check is_ray_initialized() # as opposed to is_in_ray_actor(). import ray + ray_runtime_env = ray.get_runtime_context().runtime_env logger.info("Using ray runtime env: %s", ray_runtime_env) @@ -1235,15 +1338,15 @@ def create_engine_config( placement_group = ray.util.get_current_placement_group() assert not headless or not self.data_parallel_hybrid_lb, ( - "data_parallel_hybrid_lb is not applicable in " - "headless mode") + "data_parallel_hybrid_lb is not applicable in headless mode" + ) data_parallel_external_lb = self.data_parallel_rank is not None # Local DP rank = 1, use pure-external LB. 
if data_parallel_external_lb: assert self.data_parallel_size_local in (1, None), ( - "data_parallel_size_local must be 1 when data_parallel_rank " - "is set") + "data_parallel_size_local must be 1 when data_parallel_rank is set" + ) data_parallel_size_local = 1 # Use full external lb if we have local_size of 1. self.data_parallel_hybrid_lb = False @@ -1266,8 +1369,8 @@ def create_engine_config( self.data_parallel_rank = self.data_parallel_start_rank or 0 else: assert not self.data_parallel_hybrid_lb, ( - "data_parallel_size_local must be set to use " - "data_parallel_hybrid_lb.") + "data_parallel_size_local must be set to use data_parallel_hybrid_lb." + ) # Local DP size defaults to global DP size if not set. data_parallel_size_local = self.data_parallel_size @@ -1278,39 +1381,46 @@ def create_engine_config( if self.data_parallel_backend == "ray": host_ip = get_ip() logger.info( - "Using host IP %s as ray-based data parallel address", - host_ip) + "Using host IP %s as ray-based data parallel address", host_ip + ) data_parallel_address = host_ip else: assert self.data_parallel_backend == "mp", ( "data_parallel_backend can only be ray or mp, got %s", - self.data_parallel_backend) + self.data_parallel_backend, + ) data_parallel_address = ParallelConfig.data_parallel_master_ip else: data_parallel_address = self.data_parallel_address # This port is only used when there are remote data parallel engines, # otherwise the local IPC transport is used. - data_parallel_rpc_port = self.data_parallel_rpc_port if ( + data_parallel_rpc_port = ( self.data_parallel_rpc_port - is not None) else ParallelConfig.data_parallel_rpc_port + if (self.data_parallel_rpc_port is not None) + else ParallelConfig.data_parallel_rpc_port + ) if self.async_scheduling: # Async scheduling does not work with the uniprocess backend. 
if self.distributed_executor_backend is None: self.distributed_executor_backend = "mp" - logger.info("Defaulting to mp-based distributed executor " - "backend for async scheduling.") + logger.info( + "Defaulting to mp-based distributed executor " + "backend for async scheduling." + ) if self.pipeline_parallel_size > 1: - raise ValueError("Async scheduling is not supported with " - "pipeline-parallel-size > 1.") + raise ValueError( + "Async scheduling is not supported with pipeline-parallel-size > 1." + ) # Currently, async scheduling does not support speculative decoding. # TODO(woosuk): Support it. if self.speculative_config is not None: raise ValueError( "Currently, speculative decoding is not supported with " - "async scheduling.") + "async scheduling." + ) # Forward the deprecated CLI args to the EPLB config. if self.num_redundant_experts is not None: @@ -1377,33 +1487,38 @@ def create_engine_config( disable_chunked_mm_input=self.disable_chunked_mm_input, is_multimodal_model=model_config.is_multimodal_model, is_encoder_decoder=model_config.is_encoder_decoder, - send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER - and parallel_config.use_ray), + send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER and parallel_config.use_ray), policy=self.scheduling_policy, scheduler_cls=self.scheduler_cls, max_num_partial_prefills=self.max_num_partial_prefills, max_long_partial_prefills=self.max_long_partial_prefills, long_prefill_token_threshold=self.long_prefill_token_threshold, - disable_hybrid_kv_cache_manager=self. 
- disable_hybrid_kv_cache_manager, + disable_hybrid_kv_cache_manager=self.disable_hybrid_kv_cache_manager, async_scheduling=self.async_scheduling, ) if not model_config.is_multimodal_model and self.default_mm_loras: raise ValueError( "Default modality-specific LoRA(s) were provided for a " - "non multimodal model") - - lora_config = LoRAConfig( - bias_enabled=self.enable_lora_bias, - max_lora_rank=self.max_lora_rank, - max_loras=self.max_loras, - default_mm_loras=self.default_mm_loras, - fully_sharded_loras=self.fully_sharded_loras, - lora_extra_vocab_size=self.lora_extra_vocab_size, - lora_dtype=self.lora_dtype, - max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras - and self.max_cpu_loras > 0 else None) if self.enable_lora else None + "non multimodal model" + ) + + lora_config = ( + LoRAConfig( + bias_enabled=self.enable_lora_bias, + max_lora_rank=self.max_lora_rank, + max_loras=self.max_loras, + default_mm_loras=self.default_mm_loras, + fully_sharded_loras=self.fully_sharded_loras, + lora_extra_vocab_size=self.lora_extra_vocab_size, + lora_dtype=self.lora_dtype, + max_cpu_loras=self.max_cpu_loras + if self.max_cpu_loras and self.max_cpu_loras > 0 + else None, + ) + if self.enable_lora + else None + ) # bitsandbytes pre-quantized model need a specific model loader if model_config.quantization == "bitsandbytes": @@ -1413,27 +1528,27 @@ def create_engine_config( # Pass reasoning_parser into StructuredOutputsConfig if self.reasoning_parser: - self.structured_outputs_config.reasoning_parser = \ - self.reasoning_parser + self.structured_outputs_config.reasoning_parser = self.reasoning_parser # Forward the deprecated CLI args to the StructuredOutputsConfig so_config = self.structured_outputs_config if self.guided_decoding_backend is not None: - so_config.guided_decoding_backend = \ - self.guided_decoding_backend + so_config.guided_decoding_backend = self.guided_decoding_backend if self.guided_decoding_disable_fallback is not None: - 
so_config.guided_decoding_disable_fallback = \ - self.guided_decoding_disable_fallback + so_config.guided_decoding_disable_fallback = ( + self.guided_decoding_disable_fallback + ) if self.guided_decoding_disable_any_whitespace is not None: - so_config.guided_decoding_disable_any_whitespace = \ - self.guided_decoding_disable_any_whitespace + so_config.guided_decoding_disable_any_whitespace = ( + self.guided_decoding_disable_any_whitespace + ) if self.guided_decoding_disable_additional_properties is not None: - so_config.guided_decoding_disable_additional_properties = \ - self.guided_decoding_disable_additional_properties + so_config.guided_decoding_disable_additional_properties = ( + self.guided_decoding_disable_additional_properties + ) observability_config = ObservabilityConfig( - show_hidden_metrics_for_version=( - self.show_hidden_metrics_for_version), + show_hidden_metrics_for_version=(self.show_hidden_metrics_for_version), otlp_traces_endpoint=self.otlp_traces_endpoint, collect_detailed_traces=self.collect_detailed_traces, ) @@ -1464,25 +1579,28 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: ############################################################# # Unsupported Feature Flags on V1. - if (self.logits_processor_pattern - != EngineArgs.logits_processor_pattern): - _raise_or_fallback(feature_name="--logits-processor-pattern", - recommend_to_remove=False) + if self.logits_processor_pattern != EngineArgs.logits_processor_pattern: + _raise_or_fallback( + feature_name="--logits-processor-pattern", recommend_to_remove=False + ) return False # No Mamba or Encoder-Decoder so far. if not model_config.is_v1_compatible: - _raise_or_fallback(feature_name=model_config.architectures, - recommend_to_remove=False) + _raise_or_fallback( + feature_name=model_config.architectures, recommend_to_remove=False + ) return False # No Concurrent Partial Prefills so far. 
- if (self.max_num_partial_prefills - != SchedulerConfig.max_num_partial_prefills - or self.max_long_partial_prefills - != SchedulerConfig.max_long_partial_prefills): - _raise_or_fallback(feature_name="Concurrent Partial Prefill", - recommend_to_remove=False) + if ( + self.max_num_partial_prefills != SchedulerConfig.max_num_partial_prefills + or self.max_long_partial_prefills + != SchedulerConfig.max_long_partial_prefills + ): + _raise_or_fallback( + feature_name="Concurrent Partial Prefill", recommend_to_remove=False + ) return False # V1 supports N-gram, Medusa, and Eagle speculative decoding. @@ -1497,7 +1615,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: raise NotImplementedError( "Draft model speculative decoding is not supported yet. " "Please consider using other speculative decoding methods " - "such as ngram, medusa, eagle, or mtp.") + "such as ngram, medusa, eagle, or mtp." + ) V1_BACKENDS = [ "FLASH_ATTN", @@ -1516,8 +1635,10 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: "XFORMERS", "ROCM_ATTN", ] - if (envs.is_set("VLLM_ATTENTION_BACKEND") - and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS): + if ( + envs.is_set("VLLM_ATTENTION_BACKEND") + and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS + ): name = f"VLLM_ATTENTION_BACKEND={envs.VLLM_ATTENTION_BACKEND}" _raise_or_fallback(feature_name=name, recommend_to_remove=True) return False @@ -1526,30 +1647,36 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: # Experimental Features - allow users to opt in. 
if self.pipeline_parallel_size > 1: - supports_pp = getattr(self.distributed_executor_backend, - 'supports_pp', False) + supports_pp = getattr( + self.distributed_executor_backend, "supports_pp", False + ) if not supports_pp and self.distributed_executor_backend not in ( - ParallelConfig.distributed_executor_backend, "ray", "mp", - "external_launcher"): - name = "Pipeline Parallelism without Ray distributed " \ - "executor or multiprocessing executor or external " \ - "launcher" - _raise_or_fallback(feature_name=name, - recommend_to_remove=False) + ParallelConfig.distributed_executor_backend, + "ray", + "mp", + "external_launcher", + ): + name = ( + "Pipeline Parallelism without Ray distributed " + "executor or multiprocessing executor or external " + "launcher" + ) + _raise_or_fallback(feature_name=name, recommend_to_remove=False) return False - if (current_platform.is_cpu() - and model_config.get_sliding_window() is not None): - _raise_or_fallback(feature_name="sliding window (CPU backend)", - recommend_to_remove=False) + if current_platform.is_cpu() and model_config.get_sliding_window() is not None: + _raise_or_fallback( + feature_name="sliding window (CPU backend)", recommend_to_remove=False + ) return False ############################################################# return True - def _set_default_args(self, usage_context: UsageContext, - model_config: ModelConfig) -> None: + def _set_default_args( + self, usage_context: UsageContext, model_config: ModelConfig + ) -> None: """Set Default Arguments for V1 Engine.""" # V1 always uses chunked prefills and prefix caching @@ -1560,12 +1687,12 @@ def _set_default_args(self, usage_context: UsageContext, # TODO: When prefix caching supports prompt embeds inputs, this # check can be removed. 
- if (self.enable_prompt_embeds - and self.enable_prefix_caching is not False): + if self.enable_prompt_embeds and self.enable_prefix_caching is not False: logger.warning( "--enable-prompt-embeds and --enable-prefix-caching " "are not supported together in V1. Prefix caching has " - "been disabled.") + "been disabled." + ) self.enable_prefix_caching = False if self.enable_prefix_caching is None: @@ -1576,15 +1703,15 @@ def _set_default_args(self, usage_context: UsageContext, else: self.enable_prefix_caching = True else: - pooling_type = model_config.pooler_config.pooling_type is_causal = getattr(model_config.hf_config, "is_causal", True) - incremental_prefill_supported = (pooling_type is not None - and pooling_type.lower() == "last" - and is_causal) + incremental_prefill_supported = ( + pooling_type is not None + and pooling_type.lower() == "last" + and is_causal + ) - action = "Enabling" if \ - incremental_prefill_supported else "Disabling" + action = "Enabling" if incremental_prefill_supported else "Disabling" if self.enable_chunked_prefill is None: self.enable_chunked_prefill = incremental_prefill_supported @@ -1618,6 +1745,7 @@ def _set_default_args(self, usage_context: UsageContext, # throughput, see PR #17885 for more details. # So here we do an extra device name check to prevent such regression. from vllm.usage.usage_lib import UsageContext + if device_memory >= 70 * GiB_bytes and "a100" not in device_name: # For GPUs like H100 and MI300x, use larger default values. default_max_num_batched_tokens = { @@ -1643,15 +1771,15 @@ def _set_default_args(self, usage_context: UsageContext, if current_platform.is_tpu(): default_max_num_batched_tokens_tpu = { UsageContext.LLM_CLASS: { - 'V6E': 2048, - 'V5E': 1024, - 'V5P': 512, + "V6E": 2048, + "V5E": 1024, + "V5P": 512, }, UsageContext.OPENAI_API_SERVER: { - 'V6E': 1024, - 'V5E': 512, - 'V5P': 256, - } + "V6E": 1024, + "V5E": 512, + "V5P": 256, + }, } # cpu specific default values. 
@@ -1667,47 +1795,58 @@ def _set_default_args(self, usage_context: UsageContext, } use_context_value = usage_context.value if usage_context else None - if (self.max_num_batched_tokens is None - and usage_context in default_max_num_batched_tokens): + if ( + self.max_num_batched_tokens is None + and usage_context in default_max_num_batched_tokens + ): if current_platform.is_tpu(): chip_name = current_platform.get_device_name() - if chip_name in default_max_num_batched_tokens_tpu[ - usage_context]: - self.max_num_batched_tokens = \ - default_max_num_batched_tokens_tpu[ - usage_context][chip_name] + if chip_name in default_max_num_batched_tokens_tpu[usage_context]: + self.max_num_batched_tokens = default_max_num_batched_tokens_tpu[ + usage_context + ][chip_name] else: - self.max_num_batched_tokens = \ - default_max_num_batched_tokens[usage_context] + self.max_num_batched_tokens = default_max_num_batched_tokens[ + usage_context + ] else: if not self.enable_chunked_prefill: self.max_num_batched_tokens = model_config.max_model_len else: - self.max_num_batched_tokens = \ - default_max_num_batched_tokens[usage_context] + self.max_num_batched_tokens = default_max_num_batched_tokens[ + usage_context + ] logger.debug( "Setting max_num_batched_tokens to %d for %s usage context.", - self.max_num_batched_tokens, use_context_value) + self.max_num_batched_tokens, + use_context_value, + ) - if (self.max_num_seqs is None - and usage_context in default_max_num_seqs): - self.max_num_seqs = min(default_max_num_seqs[usage_context], - self.max_num_batched_tokens or sys.maxsize) + if self.max_num_seqs is None and usage_context in default_max_num_seqs: + self.max_num_seqs = min( + default_max_num_seqs[usage_context], + self.max_num_batched_tokens or sys.maxsize, + ) - logger.debug("Setting max_num_seqs to %d for %s usage context.", - self.max_num_seqs, use_context_value) + logger.debug( + "Setting max_num_seqs to %d for %s usage context.", + self.max_num_seqs, + use_context_value, + ) 
@dataclass class AsyncEngineArgs(EngineArgs): """Arguments for asynchronous vLLM engine.""" + enable_log_requests: bool = False @property @deprecated( "`disable_log_requests` is deprecated and has been replaced with " "`enable_log_requests`. This will be removed in v0.12.0. Please use " - "`enable_log_requests` instead.") + "`enable_log_requests` instead." + ) def disable_log_requests(self) -> bool: return not self.enable_log_requests @@ -1715,28 +1854,34 @@ def disable_log_requests(self) -> bool: @deprecated( "`disable_log_requests` is deprecated and has been replaced with " "`enable_log_requests`. This will be removed in v0.12.0. Please use " - "`enable_log_requests` instead.") + "`enable_log_requests` instead." + ) def disable_log_requests(self, value: bool): self.enable_log_requests = not value @staticmethod - def add_cli_args(parser: FlexibleArgumentParser, - async_args_only: bool = False) -> FlexibleArgumentParser: + def add_cli_args( + parser: FlexibleArgumentParser, async_args_only: bool = False + ) -> FlexibleArgumentParser: # Initialize plugin to update the parser, for example, The plugin may # add a new kind of quantization method to --quantization argument or # a new device to --device argument. 
load_general_plugins() if not async_args_only: parser = EngineArgs.add_cli_args(parser) - parser.add_argument('--enable-log-requests', - action=argparse.BooleanOptionalAction, - default=AsyncEngineArgs.enable_log_requests, - help='Enable logging requests.') - parser.add_argument('--disable-log-requests', - action=argparse.BooleanOptionalAction, - default=not AsyncEngineArgs.enable_log_requests, - help='[DEPRECATED] Disable logging requests.', - deprecated=True) + parser.add_argument( + "--enable-log-requests", + action=argparse.BooleanOptionalAction, + default=AsyncEngineArgs.enable_log_requests, + help="Enable logging requests.", + ) + parser.add_argument( + "--disable-log-requests", + action=argparse.BooleanOptionalAction, + default=not AsyncEngineArgs.enable_log_requests, + help="[DEPRECATED] Disable logging requests.", + deprecated=True, + ) current_platform.pre_register_and_update(parser) return parser @@ -1744,7 +1889,8 @@ def add_cli_args(parser: FlexibleArgumentParser, def _raise_or_fallback(feature_name: str, recommend_to_remove: bool): if envs.is_set("VLLM_USE_V1") and envs.VLLM_USE_V1: raise NotImplementedError( - f"VLLM_USE_V1=1 is not supported with {feature_name}.") + f"VLLM_USE_V1=1 is not supported with {feature_name}." + ) msg = f"{feature_name} is not supported by the V1 Engine. " msg += "Falling back to V0. 
" if recommend_to_remove: @@ -1763,17 +1909,17 @@ def human_readable_int(value): - '25.6k' -> 25,600 """ value = value.strip() - match = re.fullmatch(r'(\d+(?:\.\d+)?)([kKmMgGtT])', value) + match = re.fullmatch(r"(\d+(?:\.\d+)?)([kKmMgGtT])", value) if match: decimal_multiplier = { - 'k': 10**3, - 'm': 10**6, - 'g': 10**9, + "k": 10**3, + "m": 10**6, + "g": 10**9, } binary_multiplier = { - 'K': 2**10, - 'M': 2**20, - 'G': 2**30, + "K": 2**10, + "M": 2**20, + "G": 2**30, } number, suffix = match.groups() @@ -1786,9 +1932,11 @@ def human_readable_int(value): try: return int(number) * mult except ValueError as e: - raise argparse.ArgumentTypeError("Decimals are not allowed " \ - f"with binary suffixes like {suffix}. Did you mean to use " \ - f"{number}{suffix.lower()} instead?") from e + raise argparse.ArgumentTypeError( + "Decimals are not allowed " + f"with binary suffixes like {suffix}. Did you mean to use " + f"{number}{suffix.lower()} instead?" + ) from e # Regular plain number. return int(value) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 83c031019487..2b18196629e6 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -6,70 +6,84 @@ import json import time from http import HTTPStatus -from typing import (Annotated, Any, ClassVar, Generic, Literal, Optional, - TypeVar, Union) +from typing import Annotated, Any, ClassVar, Generic, Literal, Optional, TypeVar, Union import regex as re import torch from fastapi import HTTPException, UploadFile + # yapf: disable from openai.types.chat.chat_completion_audio import ( - ChatCompletionAudio as OpenAIChatCompletionAudio) -from openai.types.chat.chat_completion_message import ( - Annotation as OpenAIAnnotation) + ChatCompletionAudio as OpenAIChatCompletionAudio, +) +from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation from openai.types.responses import ( ResponseCodeInterpreterCallCodeDeltaEvent, 
ResponseCodeInterpreterCallCodeDoneEvent, ResponseCodeInterpreterCallCompletedEvent, ResponseCodeInterpreterCallInProgressEvent, - ResponseCodeInterpreterCallInterpretingEvent) -from openai.types.responses import ( - ResponseCompletedEvent as OpenAIResponseCompletedEvent) -from openai.types.responses import (ResponseContentPartAddedEvent, - ResponseContentPartDoneEvent) + ResponseCodeInterpreterCallInterpretingEvent, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseFunctionToolCall, + ResponseInputItemParam, + ResponseOutputItem, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponsePrompt, + ResponseReasoningItem, + ResponseReasoningTextDeltaEvent, + ResponseReasoningTextDoneEvent, + ResponseStatus, + ResponseWebSearchCallCompletedEvent, + ResponseWebSearchCallInProgressEvent, + ResponseWebSearchCallSearchingEvent, +) from openai.types.responses import ( - ResponseCreatedEvent as OpenAIResponseCreatedEvent) -from openai.types.responses import ResponseFunctionToolCall + ResponseCompletedEvent as OpenAIResponseCompletedEvent, +) +from openai.types.responses import ResponseCreatedEvent as OpenAIResponseCreatedEvent from openai.types.responses import ( - ResponseInProgressEvent as OpenAIResponseInProgressEvent) -from openai.types.responses import (ResponseInputItemParam, ResponseOutputItem, - ResponseOutputItemAddedEvent, - ResponseOutputItemDoneEvent, - ResponsePrompt, ResponseReasoningItem, - ResponseReasoningTextDeltaEvent, - ResponseReasoningTextDoneEvent, - ResponseStatus, - ResponseWebSearchCallCompletedEvent, - ResponseWebSearchCallInProgressEvent, - ResponseWebSearchCallSearchingEvent) + ResponseInProgressEvent as OpenAIResponseInProgressEvent, +) + # yapf: enable from openai.types.responses.response_reasoning_item import ( - Content as ResponseReasoningTextContent) + Content as ResponseReasoningTextContent, +) # Backward compatibility for OpenAI client versions try: # For older openai versions (< 1.100.0) from 
openai.types.responses import ResponseTextConfig except ImportError: # For newer openai versions (>= 1.100.0) - from openai.types.responses import (ResponseFormatTextConfig as - ResponseTextConfig) + from openai.types.responses import ResponseFormatTextConfig as ResponseTextConfig from openai.types.responses.response import IncompleteDetails, ToolChoice from openai.types.responses.tool import Tool from openai.types.shared import Metadata, Reasoning -from pydantic import (BaseModel, ConfigDict, Field, TypeAdapter, - ValidationInfo, field_validator, model_validator) +from pydantic import ( + BaseModel, + ConfigDict, + Field, + TypeAdapter, + ValidationInfo, + field_validator, + model_validator, +) from typing_extensions import TypeAlias from vllm import envs -from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam, - make_tool_call_id) -from vllm.entrypoints.score_utils import (ScoreContentPartParam, - ScoreMultiModalParam) +from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, make_tool_call_id +from vllm.entrypoints.score_utils import ScoreContentPartParam, ScoreMultiModalParam from vllm.logger import init_logger from vllm.logprobs import Logprob from vllm.pooling_params import PoolingParams -from vllm.sampling_params import (BeamSearchParams, RequestOutputKind, - SamplingParams, StructuredOutputsParams) +from vllm.sampling_params import ( + BeamSearchParams, + RequestOutputKind, + SamplingParams, + StructuredOutputsParams, +) from vllm.utils import random_uuid, resolve_obj_by_qualname logger = init_logger(__name__) @@ -103,8 +117,7 @@ def __log_extra_fields__(cls, data, handler): # Compare against both field names and aliases if any(k not in field_names for k in data): logger.warning( - "The following fields were present in the request " - "but ignored: %s", + "The following fields were present in the request but ignored: %s", data.keys() - field_names, ) return result @@ -173,7 +186,7 @@ class JsonSchemaResponseFormat(OpenAIBaseModel): 
description: Optional[str] = None # schema is the field in openai but that causes conflicts with pydantic so # instead use json_schema with an alias - json_schema: Optional[dict[str, Any]] = Field(default=None, alias='schema') + json_schema: Optional[dict[str, Any]] = Field(default=None, alias="schema") strict: Optional[bool] = None @@ -181,8 +194,9 @@ class StructuralTag(OpenAIBaseModel): begin: str # schema is the field, but that causes conflicts with pydantic so # instead use structural_tag_schema with an alias - structural_tag_schema: Optional[dict[str, Any]] = Field(default=None, - alias="schema") + structural_tag_schema: Optional[dict[str, Any]] = Field( + default=None, alias="schema" + ) end: str @@ -239,18 +253,19 @@ class LogitsProcessorConstructor(BaseModel): LogitsProcessors = list[Union[str, LogitsProcessorConstructor]] -def get_logits_processors(processors: Optional[LogitsProcessors], - pattern: Optional[str]) -> Optional[list[Any]]: +def get_logits_processors( + processors: Optional[LogitsProcessors], pattern: Optional[str] +) -> Optional[list[Any]]: if processors and pattern: logits_processors = [] for processor in processors: - qualname = processor if isinstance(processor, - str) else processor.qualname + qualname = processor if isinstance(processor, str) else processor.qualname if not re.match(pattern, qualname): raise ValueError( f"Logits processor '{qualname}' is not allowed by this " "server. See --logits-processor-pattern engine argument " - "for more information.") + "for more information." 
+ ) try: logits_processor = resolve_obj_by_qualname(qualname) except Exception as e: @@ -258,37 +273,41 @@ def get_logits_processors(processors: Optional[LogitsProcessors], f"Logits processor '{qualname}' could not be resolved: {e}" ) from e if isinstance(processor, LogitsProcessorConstructor): - logits_processor = logits_processor(*processor.args or [], - **processor.kwargs or {}) + logits_processor = logits_processor( + *processor.args or [], **processor.kwargs or {} + ) logits_processors.append(logits_processor) return logits_processors elif processors: raise ValueError( "The `logits_processors` argument is not supported by this " "server. See --logits-processor-pattern engine argument " - "for more information.") + "for more information." + ) return None -ResponseInputOutputItem: TypeAlias = Union[ResponseInputItemParam, - ResponseReasoningItem, - ResponseFunctionToolCall] +ResponseInputOutputItem: TypeAlias = Union[ + ResponseInputItemParam, ResponseReasoningItem, ResponseFunctionToolCall +] class ResponsesRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/responses/create background: Optional[bool] = False - include: Optional[list[ - Literal[ - "code_interpreter_call.outputs", - "computer_call_output.output.image_url", - "file_search_call.results", - "message.input_image.image_url", - "message.output_text.logprobs", - "reasoning.encrypted_content", - ], - ]] = None + include: Optional[ + list[ + Literal[ + "code_interpreter_call.outputs", + "computer_call_output.output.image_url", + "file_search_call.results", + "message.input_image.image_url", + "message.output_text.logprobs", + "reasoning.encrypted_content", + ], + ] + ] = None input: Union[str, list[ResponseInputOutputItem]] instructions: Optional[str] = None max_output_tokens: Optional[int] = None @@ -299,8 +318,7 @@ class ResponsesRequest(OpenAIBaseModel): previous_response_id: Optional[str] = None prompt: Optional[ResponsePrompt] = 
None reasoning: Optional[Reasoning] = None - service_tier: Literal["auto", "default", "flex", "scale", - "priority"] = "auto" + service_tier: Literal["auto", "default", "flex", "scale", "priority"] = "auto" store: Optional[bool] = True stream: Optional[bool] = False temperature: Optional[float] = None @@ -318,7 +336,8 @@ class ResponsesRequest(OpenAIBaseModel): description=( "The request_id related to this request. If the caller does " "not set it, a random_uuid will be generated. This id is used " - "through out the inference process and return in response."), + "through out the inference process and return in response." + ), ) mm_processor_kwargs: Optional[dict[str, Any]] = Field( default=None, @@ -329,7 +348,8 @@ class ResponsesRequest(OpenAIBaseModel): description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." + ), ) cache_salt: Optional[str] = Field( default=None, @@ -339,14 +359,18 @@ class ResponsesRequest(OpenAIBaseModel): "environments. The salt should be random, protected from " "access by 3rd parties, and long enough to be " "unpredictable (e.g., 43 characters base64-encoded, corresponding " - "to 256 bit). Not supported by vLLM engine V0.")) + "to 256 bit). Not supported by vLLM engine V0." + ), + ) enable_response_messages: bool = Field( default=False, description=( "Dictates whether or not to return messages as part of the " "response object. Currently only supported for non-streaming " - "non-background and gpt-oss only. ")) + "non-background and gpt-oss only. 
" + ), + ) # --8<-- [end:responses-extra-params] _DEFAULT_SAMPLING_PARAMS = { @@ -367,20 +391,25 @@ def to_sampling_params( default_sampling_params = default_sampling_params or {} if (temperature := self.temperature) is None: temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] + ) if (top_p := self.top_p) is None: top_p = default_sampling_params.get( - "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"] + ) stop_token_ids = default_sampling_params.get("stop_token_ids") # Structured output structured_outputs = None if self.text is not None and self.text.format is not None: response_format = self.text.format - if (response_format.type == "json_schema" - and response_format.schema_ is not None): + if ( + response_format.type == "json_schema" + and response_format.schema_ is not None + ): structured_outputs = StructuredOutputsParams( - json=response_format.schema_) + json=response_format.schema_ + ) elif response_format.type == "json_object": raise NotImplementedError("json_object is not supported") @@ -389,11 +418,11 @@ def to_sampling_params( temperature=temperature, top_p=top_p, max_tokens=max_tokens, - logprobs=self.top_logprobs - if self.is_include_output_logprobs() else None, + logprobs=self.top_logprobs if self.is_include_output_logprobs() else None, stop_token_ids=stop_token_ids, - output_kind=(RequestOutputKind.DELTA - if self.stream else RequestOutputKind.FINAL_ONLY), + output_kind=( + RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY + ), structured_outputs=structured_outputs, ) @@ -401,17 +430,17 @@ def is_include_output_logprobs(self) -> bool: """Check if the request includes output logprobs.""" if self.include is None: return False - return isinstance( - self.include, - list) and "message.output_text.logprobs" in self.include + return ( + isinstance(self.include, list) + 
and "message.output_text.logprobs" in self.include + ) @model_validator(mode="before") def validate_background(cls, data): if not data.get("background"): return data if not data.get("store", True): - raise ValueError( - "background can only be used when `store` is true") + raise ValueError("background can only be used when `store` is true") return data @model_validator(mode="before") @@ -426,11 +455,12 @@ def check_cache_salt_support(cls, data): if not envs.VLLM_USE_V1: raise ValueError( "Parameter 'cache_salt' is not supported with " - "this instance of vLLM, which uses engine V0.") - if not isinstance(data["cache_salt"], - str) or not data["cache_salt"]: - raise ValueError("Parameter 'cache_salt' must be a " - "non-empty string if provided.") + "this instance of vLLM, which uses engine V0." + ) + if not isinstance(data["cache_salt"], str) or not data["cache_salt"]: + raise ValueError( + "Parameter 'cache_salt' must be a non-empty string if provided." + ) return data @@ -445,8 +475,8 @@ class ChatCompletionRequest(OpenAIBaseModel): top_logprobs: Optional[int] = 0 max_tokens: Optional[int] = Field( default=None, - deprecated= - 'max_tokens is deprecated in favor of the max_completion_tokens field') + deprecated="max_tokens is deprecated in favor of the max_completion_tokens field", + ) max_completion_tokens: Optional[int] = None n: Optional[int] = 1 presence_penalty: Optional[float] = 0.0 @@ -458,12 +488,14 @@ class ChatCompletionRequest(OpenAIBaseModel): temperature: Optional[float] = None top_p: Optional[float] = None tools: Optional[list[ChatCompletionToolsParam]] = None - tool_choice: Optional[Union[ - Literal["none"], - Literal["auto"], - Literal["required"], - ChatCompletionNamedToolChoiceParam, - ]] = "none" + tool_choice: Optional[ + Union[ + Literal["none"], + Literal["auto"], + Literal["required"], + ChatCompletionNamedToolChoiceParam, + ] + ] = "none" reasoning_effort: Optional[Literal["low", "medium", "high"]] = None thinking_token_budget: Optional[int] 
= None include_reasoning: bool = True @@ -496,23 +528,26 @@ class ChatCompletionRequest(OpenAIBaseModel): default=False, description=( "If true, the new message will be prepended with the last message " - "if they belong to the same role."), + "if they belong to the same role." + ), ) add_generation_prompt: bool = Field( default=True, - description= - ("If true, the generation prompt will be added to the chat template. " - "This is a parameter used by chat template in tokenizer config of the " - "model."), + description=( + "If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model." + ), ) continue_final_message: bool = Field( default=False, - description= - ("If this is set, the chat will be formatted so that the final " - "message in the chat is open-ended, without any EOS tokens. The " - "model will continue this message rather than starting a new one. " - "This allows you to \"prefill\" part of the model's response for it. " - "Cannot be used at the same time as `add_generation_prompt`."), + description=( + "If this is set, the chat will be formatted so that the final " + "message in the chat is open-ended, without any EOS tokens. The " + "model will continue this message rather than starting a new one. " + 'This allows you to "prefill" part of the model\'s response for it. ' + "Cannot be used at the same time as `add_generation_prompt`." + ), ) add_special_tokens: bool = Field( default=False, @@ -521,16 +556,18 @@ class ChatCompletionRequest(OpenAIBaseModel): "on top of what is added by the chat template. " "For most models, the chat template takes care of adding the " "special tokens so this should be set to false (as is the " - "default)."), + "default)." 
+ ), ) documents: Optional[list[dict[str, str]]] = Field( default=None, - description= - ("A list of dicts representing documents that will be accessible to " - "the model if it is performing RAG (retrieval-augmented generation)." - " If the template does not support RAG, this argument will have no " - "effect. We recommend that each document should be a dict containing " - "\"title\" and \"text\" keys."), + description=( + "A list of dicts representing documents that will be accessible to " + "the model if it is performing RAG (retrieval-augmented generation)." + " If the template does not support RAG, this argument will have no " + "effect. We recommend that each document should be a dict containing " + '"title" and "text" keys.' + ), ) chat_template: Optional[str] = Field( default=None, @@ -538,13 +575,15 @@ class ChatCompletionRequest(OpenAIBaseModel): "A Jinja template to use for this conversion. " "As of transformers v4.44, default chat template is no longer " "allowed, so you must provide a chat template if the tokenizer " - "does not define one."), + "does not define one." + ), ) chat_template_kwargs: Optional[dict[str, Any]] = Field( default=None, description=( "Additional keyword args to pass to the template renderer. " - "Will be accessible by the chat template."), + "Will be accessible by the chat template." + ), ) mm_processor_kwargs: Optional[dict[str, Any]] = Field( default=None, @@ -559,42 +598,48 @@ class ChatCompletionRequest(OpenAIBaseModel): description=( "`guided_json` is deprecated. " "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " - "Please pass `json` to `structured_outputs` instead."), + "Please pass `json` to `structured_outputs` instead." + ), ) guided_regex: Optional[str] = Field( default=None, description=( "`guided_regex` is deprecated. " "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. 
" - "Please pass `regex` to `structured_outputs` instead."), + "Please pass `regex` to `structured_outputs` instead." + ), ) guided_choice: Optional[list[str]] = Field( default=None, description=( "`guided_choice` is deprecated. " "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " - "Please pass `choice` to `structured_outputs` instead."), + "Please pass `choice` to `structured_outputs` instead." + ), ) guided_grammar: Optional[str] = Field( default=None, description=( "`guided_grammar` is deprecated. " "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " - "Please pass `grammar` to `structured_outputs` instead."), + "Please pass `grammar` to `structured_outputs` instead." + ), ) structural_tag: Optional[str] = Field( default=None, description=( "`structural_tag` is deprecated. " "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " - "Please pass `structural_tag` to `structured_outputs` instead."), + "Please pass `structural_tag` to `structured_outputs` instead." + ), ) guided_decoding_backend: Optional[str] = Field( default=None, description=( "`guided_decoding_backend` is deprecated. " "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " - "Please remove it from your request."), + "Please remove it from your request." + ), ) guided_whitespace_pattern: Optional[str] = Field( default=None, @@ -609,14 +654,16 @@ class ChatCompletionRequest(OpenAIBaseModel): description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." + ), ) request_id: str = Field( default_factory=lambda: f"{random_uuid()}", description=( "The request_id related to this request. If the caller does " "not set it, a random_uuid will be generated. 
This id is used " - "through out the inference process and return in response."), + "through out the inference process and return in response." + ), ) logits_processors: Optional[LogitsProcessors] = Field( default=None, @@ -628,13 +675,17 @@ class ChatCompletionRequest(OpenAIBaseModel): "'args' and 'kwargs' fields containing positional and keyword " "arguments. For example: {'qualname': " "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': " - "{'param': 'value'}}.")) + "{'param': 'value'}}." + ), + ) return_tokens_as_token_ids: Optional[bool] = Field( default=None, description=( "If specified with 'logprobs', tokens are represented " " as strings of the form 'token_id:{token_id}' so that tokens " - "that are not JSON-encodable can be identified.")) + "that are not JSON-encodable can be identified." + ), + ) return_token_ids: Optional[bool] = Field( default=None, description=( @@ -642,7 +693,9 @@ class ChatCompletionRequest(OpenAIBaseModel): "generated text. In streaming mode, prompt_token_ids is included " "only in the first chunk, and token_ids contains the delta tokens " "for each chunk. This is useful for debugging or when you " - "need to map generated text back to input tokens.")) + "need to map generated text back to input tokens." + ), + ) cache_salt: Optional[str] = Field( default=None, description=( @@ -651,15 +704,20 @@ class ChatCompletionRequest(OpenAIBaseModel): "environments. The salt should be random, protected from " "access by 3rd parties, and long enough to be " "unpredictable (e.g., 43 characters base64-encoded, corresponding " - "to 256 bit). Not supported by vLLM engine V0.")) + "to 256 bit). Not supported by vLLM engine V0." 
+ ), + ) kv_transfer_params: Optional[dict[str, Any]] = Field( default=None, - description="KVTransfer parameters used for disaggregated serving.") + description="KVTransfer parameters used for disaggregated serving.", + ) vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field( default=None, - description=("Additional request parameters with string or " - "numeric values, used by custom extensions."), + description=( + "Additional request parameters with string or " + "numeric values, used by custom extensions." + ), ) # --8<-- [end:chat-completion-extra-params] @@ -674,13 +732,13 @@ class ChatCompletionRequest(OpenAIBaseModel): } def to_beam_search_params( - self, max_tokens: int, - default_sampling_params: dict) -> BeamSearchParams: - + self, max_tokens: int, default_sampling_params: dict + ) -> BeamSearchParams: n = self.n if self.n is not None else 1 if (temperature := self.temperature) is None: temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] + ) return BeamSearchParams( beam_width=n, @@ -697,7 +755,6 @@ def to_sampling_params( logits_processor_pattern: Optional[str], default_sampling_params: dict, ) -> SamplingParams: - # Default parameters if (repetition_penalty := self.repetition_penalty) is None: repetition_penalty = default_sampling_params.get( @@ -706,16 +763,20 @@ def to_sampling_params( ) if (temperature := self.temperature) is None: temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] + ) if (top_p := self.top_p) is None: top_p = default_sampling_params.get( - "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"] + ) if (top_k := self.top_k) is None: top_k = default_sampling_params.get( - "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]) + "top_k", 
self._DEFAULT_SAMPLING_PARAMS["top_k"] + ) if (min_p := self.min_p) is None: min_p = default_sampling_params.get( - "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]) + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"] + ) prompt_logprobs = self.prompt_logprobs if prompt_logprobs is None and self.echo: @@ -754,10 +815,10 @@ def to_sampling_params( elif response_format.type == "structural_tag": structural_tag = response_format assert structural_tag is not None and isinstance( - structural_tag, StructuralTagResponseFormat) + structural_tag, StructuralTagResponseFormat + ) s_tag_obj = structural_tag.model_dump(by_alias=True) - self.structured_outputs.structural_tag = json.dumps( - s_tag_obj) + self.structured_outputs.structural_tag = json.dumps(s_tag_obj) # Set structured output params for tool calling if json_schema_from_tool is not None: @@ -787,15 +848,17 @@ def to_sampling_params( min_tokens=self.min_tokens, skip_special_tokens=self.skip_special_tokens, spaces_between_special_tokens=self.spaces_between_special_tokens, - logits_processors=get_logits_processors(self.logits_processors, - logits_processor_pattern), + logits_processors=get_logits_processors( + self.logits_processors, logits_processor_pattern + ), include_stop_str_in_output=self.include_stop_str_in_output, truncate_prompt_tokens=self.truncate_prompt_tokens, - output_kind=RequestOutputKind.DELTA if self.stream \ - else RequestOutputKind.FINAL_ONLY, + output_kind=RequestOutputKind.DELTA + if self.stream + else RequestOutputKind.FINAL_ONLY, structured_outputs=self.structured_outputs, logit_bias=self.logit_bias, - bad_words= self.bad_words, + bad_words=self.bad_words, thinking_token_budget=self.thinking_token_budget, allowed_token_ids=self.allowed_token_ids, extra_args=extra_args or None, @@ -811,8 +874,7 @@ def _get_json_schema_from_tool(self) -> Optional[Union[str, dict]]: tool_name = self.tool_choice.function.name tools = {tool.function.name: tool.function for tool in self.tools} if tool_name not in tools: - 
raise ValueError( - f"Tool '{tool_name}' has not been passed in `tools`.") + raise ValueError(f"Tool '{tool_name}' has not been passed in `tools`.") tool = tools[tool_name] return tool.parameters @@ -824,37 +886,31 @@ def _get_json_schema_from_tool(self) -> Optional[Union[str, dict]]: def get_tool_schema(tool: ChatCompletionToolsParam) -> dict: return { "properties": { - "name": { - "type": "string", - "enum": [tool.function.name] - }, + "name": {"type": "string", "enum": [tool.function.name]}, # parameters are always generated as '{}' in the final # output if they are missing from the request # (i.e. are None or '{}') so the schema is # updated to produce an empty object in that case "parameters": tool.function.parameters - if tool.function.parameters else { - "type": "object", - "properties": {} - } + if tool.function.parameters + else {"type": "object", "properties": {}}, }, - "required": ["name", "parameters"] + "required": ["name", "parameters"], } - def get_tool_schema_defs( - tools: list[ChatCompletionToolsParam]) -> dict: + def get_tool_schema_defs(tools: list[ChatCompletionToolsParam]) -> dict: all_defs = dict[str, dict[str, Any]]() for tool in tools: if tool.function.parameters is None: continue defs = tool.function.parameters.pop("$defs", {}) for def_name, def_schema in defs.items(): - if def_name in all_defs and all_defs[ - def_name] != def_schema: + if def_name in all_defs and all_defs[def_name] != def_schema: raise ValueError( f"Tool definition '{def_name}' has " "multiple schemas, which is not " - "supported.") + "supported." 
+ ) else: all_defs[def_name] = def_schema return all_defs @@ -864,8 +920,8 @@ def get_tool_schema_defs( "minItems": 1, "items": { "type": "object", - "anyOf": [get_tool_schema(tool) for tool in self.tools] - } + "anyOf": [get_tool_schema(tool) for tool in self.tools], + }, } json_schema_defs = get_tool_schema_defs(self.tools) if json_schema_defs: @@ -878,8 +934,7 @@ def get_tool_schema_defs( @classmethod def validate_stream_options(cls, data): if data.get("stream_options") and not data.get("stream"): - raise ValueError( - "Stream options can only be defined when `stream=True`.") + raise ValueError("Stream options can only be defined when `stream=True`.") return data @@ -887,24 +942,22 @@ def validate_stream_options(cls, data): @classmethod def check_logprobs(cls, data): if (prompt_logprobs := data.get("prompt_logprobs")) is not None: - if data.get("stream") and (prompt_logprobs > 0 - or prompt_logprobs == -1): + if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1): raise ValueError( - "`prompt_logprobs` are not available when `stream=True`.") + "`prompt_logprobs` are not available when `stream=True`." + ) if prompt_logprobs < 0 and prompt_logprobs != -1: - raise ValueError( - "`prompt_logprobs` must be a positive value or -1.") + raise ValueError("`prompt_logprobs` must be a positive value or -1.") if prompt_logprobs == -1 and not envs.VLLM_USE_V1: - raise ValueError("`prompt_logprobs=-1` is only supported with " - "vLLM engine V1.") + raise ValueError( + "`prompt_logprobs=-1` is only supported with vLLM engine V1." 
+ ) if (top_logprobs := data.get("top_logprobs")) is not None: if top_logprobs < 0 and top_logprobs != -1: - raise ValueError( - "`top_logprobs` must be a positive value or -1.") + raise ValueError("`top_logprobs` must be a positive value or -1.") - if (top_logprobs == -1 - or top_logprobs > 0) and not data.get("logprobs"): + if (top_logprobs == -1 or top_logprobs > 0) and not data.get("logprobs"): raise ValueError( "when using `top_logprobs`, `logprobs` must be set to true." ) @@ -920,30 +973,32 @@ def check_structured_outputs_count(cls, data): if data.get("structured_outputs", None) is None: return data - structured_outputs_kwargs = data['structured_outputs'] + structured_outputs_kwargs = data["structured_outputs"] count = sum( structured_outputs_kwargs.get(k) is not None - for k in ("json", "regex", "choice")) + for k in ("json", "regex", "choice") + ) # you can only use one kind of constraints for structured outputs if count > 1: raise ValueError( "You can only use one kind of constraints for structured " - "outputs ('json', 'regex' or 'choice').") + "outputs ('json', 'regex' or 'choice')." + ) # you can only either use structured outputs or tools, not both if count > 1 and data.get("tool_choice", "none") not in ( - "none", - "auto", - "required", + "none", + "auto", + "required", ): raise ValueError( "You can only either use constraints for structured outputs " - "or tools, not both.") + "or tools, not both." 
+ ) return data @model_validator(mode="before") @classmethod def check_tool_usage(cls, data): - # if "tool_choice" is not specified but tools are provided, # default to "auto" tool_choice if "tool_choice" not in data and data.get("tools"): @@ -955,52 +1010,58 @@ def check_tool_usage(cls, data): # if "tool_choice" is specified -- validation if "tool_choice" in data and data["tool_choice"] is not None: - # ensure that if "tool choice" is specified, tools are present if "tools" not in data or data["tools"] is None: - raise ValueError( - "When using `tool_choice`, `tools` must be set.") + raise ValueError("When using `tool_choice`, `tools` must be set.") # make sure that tool choice is either a named tool # OR that it's set to "auto" or "required" - if data["tool_choice"] not in [ - "auto", "required" - ] and not isinstance(data["tool_choice"], dict): + if data["tool_choice"] not in ["auto", "required"] and not isinstance( + data["tool_choice"], dict + ): raise ValueError( - f'Invalid value for `tool_choice`: {data["tool_choice"]}! '\ - 'Only named tools, "none", "auto" or "required" '\ - 'are supported.' + f"Invalid value for `tool_choice`: {data['tool_choice']}! " + 'Only named tools, "none", "auto" or "required" ' + "are supported." ) # if tool_choice is "required" but the "tools" list is empty, # override the data to behave like "none" to align with # OpenAI’s behavior. 
- if data["tool_choice"] == "required" and isinstance( - data["tools"], list) and len(data["tools"]) == 0: + if ( + data["tool_choice"] == "required" + and isinstance(data["tools"], list) + and len(data["tools"]) == 0 + ): data["tool_choice"] = "none" del data["tools"] return data # ensure that if "tool_choice" is specified as an object, # it matches a valid tool - correct_usage_message = 'Correct usage: `{"type": "function",' \ + correct_usage_message = ( + 'Correct usage: `{"type": "function",' ' "function": {"name": "my_function"}}`' + ) if isinstance(data["tool_choice"], dict): valid_tool = False function = data["tool_choice"].get("function") if not isinstance(function, dict): raise ValueError( f"Invalid value for `function`: `{function}` in " - f"`tool_choice`! {correct_usage_message}") + f"`tool_choice`! {correct_usage_message}" + ) if "name" not in function: - raise ValueError(f"Expected field `name` in `function` in " - f"`tool_choice`! {correct_usage_message}") + raise ValueError( + f"Expected field `name` in `function` in " + f"`tool_choice`! {correct_usage_message}" + ) function_name = function["name"] - if not isinstance(function_name, - str) or len(function_name) == 0: + if not isinstance(function_name, str) or len(function_name) == 0: raise ValueError( f"Invalid `name` in `function`: `{function_name}`" - f" in `tool_choice`! {correct_usage_message}") + f" in `tool_choice`! 
{correct_usage_message}" + ) for tool in data["tools"]: if tool["function"]["name"] == function_name: valid_tool = True @@ -1008,16 +1069,18 @@ def check_tool_usage(cls, data): if not valid_tool: raise ValueError( "The tool specified in `tool_choice` does not match any" - " of the specified `tools`") + " of the specified `tools`" + ) return data @model_validator(mode="before") @classmethod def check_generation_prompt(cls, data): - if data.get("continue_final_message") and data.get( - "add_generation_prompt"): - raise ValueError("Cannot set both `continue_final_message` and " - "`add_generation_prompt` to True.") + if data.get("continue_final_message") and data.get("add_generation_prompt"): + raise ValueError( + "Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True." + ) return data @model_validator(mode="before") @@ -1027,11 +1090,12 @@ def check_cache_salt_support(cls, data): if not envs.VLLM_USE_V1: raise ValueError( "Parameter 'cache_salt' is not supported with " - "this instance of vLLM, which uses engine V0.") - if not isinstance(data["cache_salt"], - str) or not data["cache_salt"]: - raise ValueError("Parameter 'cache_salt' must be a " - "non-empty string if provided.") + "this instance of vLLM, which uses engine V0." + ) + if not isinstance(data["cache_salt"], str) or not data["cache_salt"]: + raise ValueError( + "Parameter 'cache_salt' must be a non-empty string if provided." + ) return data @@ -1080,7 +1144,8 @@ class CompletionRequest(OpenAIBaseModel): default=True, description=( "If true (the default), special tokens (e.g. BOS) will be added to " - "the prompt."), + "the prompt." + ), ) response_format: Optional[AnyResponseFormat] = Field( default=None, @@ -1099,35 +1164,40 @@ class CompletionRequest(OpenAIBaseModel): description=( "`guided_json` is deprecated. " "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. 
" - "Please pass `json` to `structured_outputs` instead."), + "Please pass `json` to `structured_outputs` instead." + ), ) guided_regex: Optional[str] = Field( default=None, description=( "`guided_regex` is deprecated. " "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " - "Please pass `regex` to `structured_outputs` instead."), + "Please pass `regex` to `structured_outputs` instead." + ), ) guided_choice: Optional[list[str]] = Field( default=None, description=( "`guided_choice` is deprecated. " "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " - "Please pass `choice` to `structured_outputs` instead."), + "Please pass `choice` to `structured_outputs` instead." + ), ) guided_grammar: Optional[str] = Field( default=None, description=( "`guided_grammar` is deprecated. " "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " - "Please pass `grammar` to `structured_outputs` instead."), + "Please pass `grammar` to `structured_outputs` instead." + ), ) guided_decoding_backend: Optional[str] = Field( default=None, description=( "`guided_decoding_backend` is deprecated. " "This will be removed in v0.12.0 or v1.0.0, whichever is soonest. " - "Please remove it from your request."), + "Please remove it from your request." + ), ) guided_whitespace_pattern: Optional[str] = Field( default=None, @@ -1142,14 +1212,16 @@ class CompletionRequest(OpenAIBaseModel): description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." + ), ) request_id: str = Field( default_factory=lambda: f"{random_uuid()}", description=( "The request_id related to this request. If the caller does " "not set it, a random_uuid will be generated. 
This id is used " - "through out the inference process and return in response."), + "through out the inference process and return in response." + ), ) logits_processors: Optional[LogitsProcessors] = Field( default=None, @@ -1161,14 +1233,18 @@ class CompletionRequest(OpenAIBaseModel): "'args' and 'kwargs' fields containing positional and keyword " "arguments. For example: {'qualname': " "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': " - "{'param': 'value'}}.")) + "{'param': 'value'}}." + ), + ) return_tokens_as_token_ids: Optional[bool] = Field( default=None, description=( "If specified with 'logprobs', tokens are represented " " as strings of the form 'token_id:{token_id}' so that tokens " - "that are not JSON-encodable can be identified.")) + "that are not JSON-encodable can be identified." + ), + ) return_token_ids: Optional[bool] = Field( default=None, description=( @@ -1176,7 +1252,9 @@ class CompletionRequest(OpenAIBaseModel): "generated text. In streaming mode, prompt_token_ids is included " "only in the first chunk, and token_ids contains the delta tokens " "for each chunk. This is useful for debugging or when you " - "need to map generated text back to input tokens.")) + "need to map generated text back to input tokens." + ), + ) cache_salt: Optional[str] = Field( default=None, @@ -1186,16 +1264,21 @@ class CompletionRequest(OpenAIBaseModel): "environments. The salt should be random, protected from " "access by 3rd parties, and long enough to be " "unpredictable (e.g., 43 characters base64-encoded, corresponding " - "to 256 bit). Not supported by vLLM engine V0.")) + "to 256 bit). Not supported by vLLM engine V0." 
+ ), + ) kv_transfer_params: Optional[dict[str, Any]] = Field( default=None, - description="KVTransfer parameters used for disaggregated serving.") + description="KVTransfer parameters used for disaggregated serving.", + ) vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field( default=None, - description=("Additional request parameters with string or " - "numeric values, used by custom extensions."), + description=( + "Additional request parameters with string or " + "numeric values, used by custom extensions." + ), ) # --8<-- [end:completion-extra-params] @@ -1214,7 +1297,6 @@ def to_beam_search_params( max_tokens: int, default_sampling_params: Optional[dict] = None, ) -> BeamSearchParams: - if default_sampling_params is None: default_sampling_params = {} n = self.n if self.n is not None else 1 @@ -1237,7 +1319,6 @@ def to_sampling_params( logits_processor_pattern: Optional[str], default_sampling_params: Optional[dict] = None, ) -> SamplingParams: - if default_sampling_params is None: default_sampling_params = {} @@ -1249,16 +1330,20 @@ def to_sampling_params( ) if (temperature := self.temperature) is None: temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] + ) if (top_p := self.top_p) is None: top_p = default_sampling_params.get( - "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"] + ) if (top_k := self.top_k) is None: top_k = default_sampling_params.get( - "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]) + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"] + ) if (min_p := self.min_p) is None: min_p = default_sampling_params.get( - "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]) + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"] + ) prompt_logprobs = self.prompt_logprobs if prompt_logprobs is None and self.echo: @@ -1279,9 +1364,11 @@ def to_sampling_params( if len(kwargs) > 0: 
self.structured_outputs = StructuredOutputsParams(**kwargs) - if (self.structured_outputs is not None - and self.response_format is not None - and self.response_format.type == "json_object"): + if ( + self.structured_outputs is not None + and self.response_format is not None + and self.response_format.type == "json_object" + ): self.structured_outputs.json_object = True extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {} @@ -1309,16 +1396,18 @@ def to_sampling_params( skip_special_tokens=self.skip_special_tokens, spaces_between_special_tokens=self.spaces_between_special_tokens, include_stop_str_in_output=self.include_stop_str_in_output, - logits_processors=get_logits_processors(self.logits_processors, - logits_processor_pattern), + logits_processors=get_logits_processors( + self.logits_processors, logits_processor_pattern + ), truncate_prompt_tokens=self.truncate_prompt_tokens, - output_kind=RequestOutputKind.DELTA if self.stream \ - else RequestOutputKind.FINAL_ONLY, + output_kind=RequestOutputKind.DELTA + if self.stream + else RequestOutputKind.FINAL_ONLY, structured_outputs=self.structured_outputs, logit_bias=self.logit_bias, allowed_token_ids=self.allowed_token_ids, extra_args=extra_args or None, - ) + ) @model_validator(mode="before") @classmethod @@ -1326,31 +1415,33 @@ def check_structured_outputs_count(cls, data): if data.get("structured_outputs", None) is None: return data - structured_outputs_kwargs = data['structured_outputs'] + structured_outputs_kwargs = data["structured_outputs"] count = sum( structured_outputs_kwargs.get(k) is not None - for k in ("json", "regex", "choice")) + for k in ("json", "regex", "choice") + ) if count > 1: raise ValueError( "You can only use one kind of constraints for structured " - "outputs ('json', 'regex' or 'choice').") + "outputs ('json', 'regex' or 'choice')." 
+ ) return data @model_validator(mode="before") @classmethod def check_logprobs(cls, data): if (prompt_logprobs := data.get("prompt_logprobs")) is not None: - if data.get("stream") and (prompt_logprobs > 0 - or prompt_logprobs == -1): + if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1): raise ValueError( - "`prompt_logprobs` are not available when `stream=True`.") + "`prompt_logprobs` are not available when `stream=True`." + ) if prompt_logprobs < 0 and prompt_logprobs != -1: - raise ValueError( - "`prompt_logprobs` must be a positive value or -1.") + raise ValueError("`prompt_logprobs` must be a positive value or -1.") if prompt_logprobs == -1 and not envs.VLLM_USE_V1: - raise ValueError("`prompt_logprobs=-1` is only supported with " - "vLLM engine V1.") + raise ValueError( + "`prompt_logprobs=-1` is only supported with vLLM engine V1." + ) if (logprobs := data.get("logprobs")) is not None and logprobs < 0: raise ValueError("`logprobs` must be a positive value.") @@ -1360,8 +1451,7 @@ def check_logprobs(cls, data): @classmethod def validate_stream_options(cls, data): if data.get("stream_options") and not data.get("stream"): - raise ValueError( - "Stream options can only be defined when `stream=True`.") + raise ValueError("Stream options can only be defined when `stream=True`.") return data @@ -1371,11 +1461,10 @@ def validate_prompt_and_prompt_embeds(cls, data): prompt = data.get("prompt") prompt_embeds = data.get("prompt_embeds") - prompt_is_empty = (prompt is None - or (isinstance(prompt, str) and prompt == "")) - embeds_is_empty = (prompt_embeds is None - or (isinstance(prompt_embeds, list) - and len(prompt_embeds) == 0)) + prompt_is_empty = prompt is None or (isinstance(prompt, str) and prompt == "") + embeds_is_empty = prompt_embeds is None or ( + isinstance(prompt_embeds, list) and len(prompt_embeds) == 0 + ) if prompt_is_empty and embeds_is_empty: raise ValueError( @@ -1391,11 +1480,12 @@ def check_cache_salt_support(cls, data): if 
not envs.VLLM_USE_V1: raise ValueError( "Parameter 'cache_salt' is not supported with " - "this instance of vLLM, which uses engine V0.") - if not isinstance(data["cache_salt"], - str) or not data["cache_salt"]: - raise ValueError("Parameter 'cache_salt' must be a " - "non-empty string if provided.") + "this instance of vLLM, which uses engine V0." + ) + if not isinstance(data["cache_salt"], str) or not data["cache_salt"]: + raise ValueError( + "Parameter 'cache_salt' must be a non-empty string if provided." + ) return data @@ -1414,21 +1504,24 @@ class EmbeddingCompletionRequest(OpenAIBaseModel): default=True, description=( "If true (the default), special tokens (e.g. BOS) will be added to " - "the prompt."), + "the prompt." + ), ) priority: int = Field( default=0, description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." + ), ) request_id: str = Field( default_factory=lambda: f"{random_uuid()}", description=( "The request_id related to this request. If the caller does " "not set it, a random_uuid will be generated. This id is used " - "through out the inference process and return in response."), + "through out the inference process and return in response." + ), ) normalize: Optional[bool] = None @@ -1438,7 +1531,8 @@ def to_pooling_params(self): return PoolingParams( truncate_prompt_tokens=self.truncate_prompt_tokens, dimensions=self.dimensions, - normalize=self.normalize) + normalize=self.normalize, + ) class EmbeddingChatRequest(OpenAIBaseModel): @@ -1453,10 +1547,11 @@ class EmbeddingChatRequest(OpenAIBaseModel): # --8<-- [start:chat-embedding-extra-params] add_generation_prompt: bool = Field( default=False, - description= - ("If true, the generation prompt will be added to the chat template. 
" - "This is a parameter used by chat template in tokenizer config of the " - "model."), + description=( + "If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model." + ), ) add_special_tokens: bool = Field( @@ -1466,7 +1561,8 @@ class EmbeddingChatRequest(OpenAIBaseModel): "on top of what is added by the chat template. " "For most models, the chat template takes care of adding the " "special tokens so this should be set to false (as is the " - "default)."), + "default)." + ), ) chat_template: Optional[str] = Field( default=None, @@ -1474,13 +1570,15 @@ class EmbeddingChatRequest(OpenAIBaseModel): "A Jinja template to use for this conversion. " "As of transformers v4.44, default chat template is no longer " "allowed, so you must provide a chat template if the tokenizer " - "does not define one."), + "does not define one." + ), ) chat_template_kwargs: Optional[dict[str, Any]] = Field( default=None, description=( "Additional keyword args to pass to the template renderer. " - "Will be accessible by the chat template."), + "Will be accessible by the chat template." + ), ) mm_processor_kwargs: Optional[dict[str, Any]] = Field( default=None, @@ -1491,14 +1589,16 @@ class EmbeddingChatRequest(OpenAIBaseModel): description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." + ), ) request_id: str = Field( default_factory=lambda: f"{random_uuid()}", description=( "The request_id related to this request. If the caller does " "not set it, a random_uuid will be generated. This id is used " - "through out the inference process and return in response."), + "through out the inference process and return in response." 
+ ), ) normalize: Optional[bool] = None # --8<-- [end:chat-embedding-extra-params] @@ -1506,17 +1606,19 @@ class EmbeddingChatRequest(OpenAIBaseModel): @model_validator(mode="before") @classmethod def check_generation_prompt(cls, data): - if data.get("continue_final_message") and data.get( - "add_generation_prompt"): - raise ValueError("Cannot set both `continue_final_message` and " - "`add_generation_prompt` to True.") + if data.get("continue_final_message") and data.get("add_generation_prompt"): + raise ValueError( + "Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True." + ) return data def to_pooling_params(self): return PoolingParams( truncate_prompt_tokens=self.truncate_prompt_tokens, dimensions=self.dimensions, - normalize=self.normalize) + normalize=self.normalize, + ) EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] @@ -1548,7 +1650,6 @@ def to_pooling_params(self): class IOProcessorResponse(OpenAIBaseModel, Generic[T]): - request_id: Optional[str] = None """ The request_id associated with this response @@ -1562,8 +1663,7 @@ class IOProcessorResponse(OpenAIBaseModel, Generic[T]): """ -PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest, - IOProcessorRequest] +PoolingRequest = Union[PoolingCompletionRequest, PoolingChatRequest, IOProcessorRequest] class ScoreRequest(OpenAIBaseModel): @@ -1584,7 +1684,8 @@ class ScoreRequest(OpenAIBaseModel): description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." 
+ ), ) activation: Optional[bool] = None @@ -1594,7 +1695,8 @@ class ScoreRequest(OpenAIBaseModel): def to_pooling_params(self): return PoolingParams( truncate_prompt_tokens=self.truncate_prompt_tokens, - activation=self.activation) + activation=self.activation, + ) class RerankRequest(OpenAIBaseModel): @@ -1616,7 +1718,8 @@ class RerankRequest(OpenAIBaseModel): description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." + ), ) activation: Optional[bool] = None @@ -1626,7 +1729,8 @@ class RerankRequest(OpenAIBaseModel): def to_pooling_params(self): return PoolingParams( truncate_prompt_tokens=self.truncate_prompt_tokens, - activation=self.activation) + activation=self.activation, + ) class RerankDocument(BaseModel): @@ -1655,8 +1759,7 @@ class CompletionLogProbs(OpenAIBaseModel): text_offset: list[int] = Field(default_factory=list) token_logprobs: list[Optional[float]] = Field(default_factory=list) tokens: list[str] = Field(default_factory=list) - top_logprobs: list[Optional[dict[str, - float]]] = Field(default_factory=list) + top_logprobs: list[Optional[dict[str, float]]] = Field(default_factory=list) class CompletionResponseChoice(OpenAIBaseModel): @@ -1669,7 +1772,8 @@ class CompletionResponseChoice(OpenAIBaseModel): description=( "The stop string or token id that caused the completion " "to stop, None if the completion finished for some other reason " - "including encountering the EOS token"), + "including encountering the EOS token" + ), ) token_ids: Optional[list[int]] = None # For response prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None @@ -1682,14 +1786,16 @@ class CompletionResponse(OpenAIBaseModel): created: int = Field(default_factory=lambda: int(time.time())) model: str choices: list[CompletionResponseChoice] - service_tier: 
Optional[Literal["auto", "default", "flex", "scale", - "priority"]] = None + service_tier: Optional[Literal["auto", "default", "flex", "scale", "priority"]] = ( + None + ) system_fingerprint: Optional[str] = None usage: UsageInfo # vLLM-specific fields that are not in OpenAI spec kv_transfer_params: Optional[dict[str, Any]] = Field( - default=None, description="KVTransfer parameters.") + default=None, description="KVTransfer parameters." + ) class CompletionResponseStreamChoice(OpenAIBaseModel): @@ -1702,7 +1808,8 @@ class CompletionResponseStreamChoice(OpenAIBaseModel): description=( "The stop string or token id that caused the completion " "to stop, None if the completion finished for some other reason " - "including encountering the EOS token"), + "including encountering the EOS token" + ), ) # not part of the OpenAI spec but for tracing the tokens # prompt tokens is put into choice to align with CompletionResponseChoice @@ -1776,7 +1883,8 @@ class ClassificationRequest(OpenAIBaseModel): description=( "The priority of the request (lower means earlier handling; " "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling."), + "if the served model does not use priority scheduling." 
+ ), ) activation: Optional[bool] = None @@ -1786,7 +1894,8 @@ class ClassificationRequest(OpenAIBaseModel): def to_pooling_params(self): return PoolingParams( truncate_prompt_tokens=self.truncate_prompt_tokens, - activation=self.activation) + activation=self.activation, + ) class ClassificationData(OpenAIBaseModel): @@ -1890,8 +1999,9 @@ class ChatCompletionResponse(OpenAIBaseModel): created: int = Field(default_factory=lambda: int(time.time())) model: str choices: list[ChatCompletionResponseChoice] - service_tier: Optional[Literal["auto", "default", "flex", "scale", - "priority"]] = None + service_tier: Optional[Literal["auto", "default", "flex", "scale", "priority"]] = ( + None + ) system_fingerprint: Optional[str] = None usage: UsageInfo @@ -1899,7 +2009,8 @@ class ChatCompletionResponse(OpenAIBaseModel): prompt_logprobs: Optional[list[Optional[dict[int, Logprob]]]] = None prompt_token_ids: Optional[list[int]] = None kv_transfer_params: Optional[dict[str, Any]] = Field( - default=None, description="KVTransfer parameters.") + default=None, description="KVTransfer parameters." 
+ ) class DeltaMessage(OpenAIBaseModel): @@ -2009,10 +2120,9 @@ def from_request( input_messages: Optional[list[ChatCompletionMessageParam]] = None, output_messages: Optional[list[ChatCompletionMessageParam]] = None, ) -> "ResponsesResponse": - incomplete_details: Optional[IncompleteDetails] = None - if status == 'incomplete': - incomplete_details = IncompleteDetails(reason='max_output_tokens') + if status == "incomplete": + incomplete_details = IncompleteDetails(reason="max_output_tokens") # TODO: implement the other reason for incomplete_details, # which is content_filter # incomplete_details = IncompleteDetails(reason='content_filter') @@ -2127,8 +2237,9 @@ class ResponseInProgressEvent(OpenAIResponseInProgressEvent): ResponseCodeInterpreterCallCompletedEvent, ] -BatchRequestInputBody = Union[ChatCompletionRequest, EmbeddingRequest, - ScoreRequest, RerankRequest] +BatchRequestInputBody = Union[ + ChatCompletionRequest, EmbeddingRequest, ScoreRequest, RerankRequest +] class BatchRequestInput(OpenAIBaseModel): @@ -2153,7 +2264,7 @@ class BatchRequestInput(OpenAIBaseModel): # The parameters of the request. body: BatchRequestInputBody - @field_validator('body', mode='plain') + @field_validator("body", mode="plain") @classmethod def check_type_for_url(cls, value: Any, info: ValidationInfo): # Use url to disambiguate models @@ -2177,8 +2288,9 @@ class BatchResponseData(OpenAIBaseModel): request_id: str # The body of the response. - body: Optional[Union[ChatCompletionResponse, EmbeddingResponse, - ScoreResponse, RerankResponse]] = None + body: Optional[ + Union[ChatCompletionResponse, EmbeddingResponse, ScoreResponse, RerankResponse] + ] = None class BatchRequestOutput(OpenAIBaseModel): @@ -2207,12 +2319,14 @@ class TokenizeCompletionRequest(OpenAIBaseModel): default=True, description=( "If true (the default), special tokens (e.g. BOS) will be added to " - "the prompt."), + "the prompt." 
+ ), ) return_token_strs: Optional[bool] = Field( default=False, - description=("If true, also return the token strings " - "corresponding to the token ids."), + description=( + "If true, also return the token strings corresponding to the token ids." + ), ) @@ -2222,24 +2336,27 @@ class TokenizeChatRequest(OpenAIBaseModel): add_generation_prompt: bool = Field( default=True, - description= - ("If true, the generation prompt will be added to the chat template. " - "This is a parameter used by chat template in tokenizer config of the " - "model."), + description=( + "If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model." + ), ) return_token_strs: Optional[bool] = Field( default=False, - description=("If true, also return the token strings " - "corresponding to the token ids."), + description=( + "If true, also return the token strings corresponding to the token ids." + ), ) continue_final_message: bool = Field( default=False, - description= - ("If this is set, the chat will be formatted so that the final " - "message in the chat is open-ended, without any EOS tokens. The " - "model will continue this message rather than starting a new one. " - "This allows you to \"prefill\" part of the model's response for it. " - "Cannot be used at the same time as `add_generation_prompt`."), + description=( + "If this is set, the chat will be formatted so that the final " + "message in the chat is open-ended, without any EOS tokens. The " + "model will continue this message rather than starting a new one. " + 'This allows you to "prefill" part of the model\'s response for it. ' + "Cannot be used at the same time as `add_generation_prompt`." + ), ) add_special_tokens: bool = Field( default=False, @@ -2248,7 +2365,8 @@ class TokenizeChatRequest(OpenAIBaseModel): "on top of what is added by the chat template. 
" "For most models, the chat template takes care of adding the " "special tokens so this should be set to false (as is the " - "default)."), + "default)." + ), ) chat_template: Optional[str] = Field( default=None, @@ -2256,13 +2374,15 @@ class TokenizeChatRequest(OpenAIBaseModel): "A Jinja template to use for this conversion. " "As of transformers v4.44, default chat template is no longer " "allowed, so you must provide a chat template if the tokenizer " - "does not define one."), + "does not define one." + ), ) chat_template_kwargs: Optional[dict[str, Any]] = Field( default=None, description=( "Additional keyword args to pass to the template renderer. " - "Will be accessible by the chat template."), + "Will be accessible by the chat template." + ), ) mm_processor_kwargs: Optional[dict[str, Any]] = Field( default=None, @@ -2276,10 +2396,11 @@ class TokenizeChatRequest(OpenAIBaseModel): @model_validator(mode="before") @classmethod def check_generation_prompt(cls, data): - if data.get("continue_final_message") and data.get( - "add_generation_prompt"): - raise ValueError("Cannot set both `continue_final_message` and " - "`add_generation_prompt` to True.") + if data.get("continue_final_message") and data.get("add_generation_prompt"): + raise ValueError( + "Cannot set both `continue_final_message` and " + "`add_generation_prompt` to True." + ) return data @@ -2323,8 +2444,7 @@ class UnloadLoRAAdapterRequest(BaseModel): ## Protocols for Audio -AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", - "vtt"] +AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"] class TranscriptionRequest(OpenAIBaseModel): @@ -2366,7 +2486,8 @@ class TranscriptionRequest(OpenAIBaseModel): ## TODO (varun) : Support if set to 0, certain thresholds are met !! 
timestamp_granularities: list[Literal["word", "segment"]] = Field( - alias="timestamp_granularities[]", default=[]) + alias="timestamp_granularities[]", default=[] + ) """The timestamp granularities to populate for this transcription. `response_format` must be set `verbose_json` to use timestamp granularities. @@ -2386,8 +2507,10 @@ class TranscriptionRequest(OpenAIBaseModel): vllm_xargs: Optional[dict[str, Union[str, int, float]]] = Field( default=None, - description=("Additional request parameters with string or " - "numeric values, used by custom extensions."), + description=( + "Additional request parameters with string or " + "numeric values, used by custom extensions." + ), ) # --8<-- [end:transcription-extra-params] @@ -2444,10 +2567,8 @@ class TranscriptionRequest(OpenAIBaseModel): } def to_sampling_params( - self, - default_max_tokens: int, - default_sampling_params: Optional[dict] = None) -> SamplingParams: - + self, default_max_tokens: int, default_sampling_params: Optional[dict] = None + ) -> SamplingParams: max_tokens = default_max_tokens if default_sampling_params is None: @@ -2456,35 +2577,42 @@ def to_sampling_params( # Default parameters if (temperature := self.temperature) is None: temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] + ) if (top_p := self.top_p) is None: top_p = default_sampling_params.get( - "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"] + ) if (top_k := self.top_k) is None: top_k = default_sampling_params.get( - "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"]) + "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"] + ) if (min_p := self.min_p) is None: min_p = default_sampling_params.get( - "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"]) + "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"] + ) if (repetition_penalty := self.repetition_penalty) is None: repetition_penalty = 
default_sampling_params.get( "repetition_penalty", - self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"]) - - return SamplingParams.from_optional(temperature=temperature, - max_tokens=max_tokens, - seed=self.seed, - top_p=top_p, - top_k=top_k, - min_p=min_p, - frequency_penalty=self.frequency_penalty, - repetition_penalty=repetition_penalty, - presence_penalty=self.presence_penalty, - output_kind=RequestOutputKind.DELTA - if self.stream \ - else RequestOutputKind.FINAL_ONLY, - extra_args=self.vllm_xargs) + self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"], + ) + + return SamplingParams.from_optional( + temperature=temperature, + max_tokens=max_tokens, + seed=self.seed, + top_p=top_p, + top_k=top_k, + min_p=min_p, + frequency_penalty=self.frequency_penalty, + repetition_penalty=repetition_penalty, + presence_penalty=self.presence_penalty, + output_kind=RequestOutputKind.DELTA + if self.stream + else RequestOutputKind.FINAL_ONLY, + extra_args=self.vllm_xargs, + ) @model_validator(mode="before") @classmethod @@ -2498,8 +2626,7 @@ def validate_transcription_request(cls, data): stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"] stream = data.get("stream", False) if any(bool(data.get(so, False)) for so in stream_opts) and not stream: - raise ValueError( - "Stream options can only be defined when `stream=True`.") + raise ValueError("Stream options can only be defined when `stream=True`.") return data @@ -2677,10 +2804,8 @@ class TranslationRequest(OpenAIBaseModel): } def to_sampling_params( - self, - default_max_tokens: int, - default_sampling_params: Optional[dict] = None) -> SamplingParams: - + self, default_max_tokens: int, default_sampling_params: Optional[dict] = None + ) -> SamplingParams: max_tokens = default_max_tokens if default_sampling_params is None: @@ -2688,14 +2813,17 @@ def to_sampling_params( # Default parameters if (temperature := self.temperature) is None: temperature = default_sampling_params.get( - "temperature", 
self._DEFAULT_SAMPLING_PARAMS["temperature"]) + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] + ) - return SamplingParams.from_optional(temperature=temperature, - max_tokens=max_tokens, - seed=self.seed, - output_kind=RequestOutputKind.DELTA - if self.stream \ - else RequestOutputKind.FINAL_ONLY) + return SamplingParams.from_optional( + temperature=temperature, + max_tokens=max_tokens, + seed=self.seed, + output_kind=RequestOutputKind.DELTA + if self.stream + else RequestOutputKind.FINAL_ONLY, + ) @model_validator(mode="before") @classmethod @@ -2703,8 +2831,7 @@ def validate_stream_options(cls, data): stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"] stream = data.get("stream", False) if any(bool(data.get(so, False)) for so in stream_opts) and not stream: - raise ValueError( - "Stream options can only be defined when `stream=True`.") + raise ValueError("Stream options can only be defined when `stream=True`.") return data diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index ce4418b83bc9..9be23e0cef2b 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Sampling parameters for text generation.""" + import copy import warnings from dataclasses import field @@ -50,26 +51,32 @@ class StructuredOutputsParams: def __post_init__(self): """Validate that some fields are mutually exclusive.""" - count = sum([ - self.json is not None, self.regex is not None, self.choice - is not None, self.grammar is not None, self.json_object is not None - ]) + count = sum( + [ + self.json is not None, + self.regex is not None, + self.choice is not None, + self.grammar is not None, + self.json_object is not None, + ] + ) if count > 1: raise ValueError( "You can only use one kind of structured outputs constraint " - f"but multiple are specified: {self.__dict__}") + f"but multiple are specified: 
{self.__dict__}" + ) @dataclass class GuidedDecodingParams(StructuredOutputsParams): - def __post_init__(self): warnings.warn( "GuidedDecodingParams is deprecated. This will be removed in " "v0.12.0 or v1.0.0, which ever is soonest. Please use " "StructuredOutputsParams instead.", DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) return super().__post_init__() @@ -83,10 +90,11 @@ class RequestOutputKind(Enum): class SamplingParams( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - # required for @cached_property. - dict=True): # type: ignore[call-arg] + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. + dict=True, +): # type: ignore[call-arg] """Sampling parameters for text generation. Overall, we follow the sampling parameters from the OpenAI text completion @@ -178,8 +186,7 @@ class SamplingParams( optionally prompt tokens as a first argument.""" include_stop_str_in_output: bool = False """Whether to include the stop strings in output text.""" - truncate_prompt_tokens: Optional[Annotated[int, - msgspec.Meta(ge=-1)]] = None + truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=-1)]] = None """If set to -1, will use the truncation size supported by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left truncation). 
If set to `None`, truncation is disabled.""" @@ -242,9 +249,7 @@ def from_optional( skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, logits_processors: Optional[list[LogitsProcessor]] = None, - truncate_prompt_tokens: Optional[Annotated[int, - msgspec.Meta( - ge=-1)]] = None, + truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=-1)]] = None, output_kind: RequestOutputKind = RequestOutputKind.CUMULATIVE, structured_outputs: Optional[StructuredOutputsParams] = None, guided_decoding: Optional[GuidedDecodingParams] = None, @@ -265,19 +270,19 @@ def from_optional( "v0.12.0 or v1.0.0, which ever is soonest. Please use " "structured_outputs instead.", DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) structured_outputs = guided_decoding guided_decoding = None return SamplingParams( n=1 if n is None else n, best_of=best_of, - presence_penalty=0.0 - if presence_penalty is None else presence_penalty, - frequency_penalty=0.0 - if frequency_penalty is None else frequency_penalty, + presence_penalty=0.0 if presence_penalty is None else presence_penalty, + frequency_penalty=0.0 if frequency_penalty is None else frequency_penalty, repetition_penalty=1.0 - if repetition_penalty is None else repetition_penalty, + if repetition_penalty is None + else repetition_penalty, temperature=1.0 if temperature is None else temperature, top_p=1.0 if top_p is None else top_p, top_k=top_k, @@ -316,7 +321,8 @@ def __post_init__(self) -> None: if self.best_of < self.n: raise ValueError( f"best_of must be greater than or equal to n, " - f"got n={self.n} and best_of={self.best_of}.") + f"got n={self.n} and best_of={self.best_of}." + ) if not self._real_n: self._real_n = self.n self.n = self.best_of @@ -325,7 +331,10 @@ def __post_init__(self) -> None: logger.warning( "temperature %s is less than %s, which may cause numerical " "errors nan or inf in tensors. 
We have maxed it out to %s.", - self.temperature, _MAX_TEMP, _MAX_TEMP) + self.temperature, + _MAX_TEMP, + _MAX_TEMP, + ) self.temperature = max(self.temperature, _MAX_TEMP) if self.seed == -1: @@ -371,101 +380,116 @@ def __post_init__(self) -> None: "v0.12.0 or v1.0.0, which ever is soonest. Please use " "structured_outputs instead.", DeprecationWarning, - stacklevel=2) + stacklevel=2, + ) self.structured_outputs = self.guided_decoding self.guided_decoding = None def _verify_args(self) -> None: if not isinstance(self.n, int): - raise ValueError(f"n must be an int, but is of " - f"type {type(self.n)}") + raise ValueError(f"n must be an int, but is of type {type(self.n)}") if self.n < 1: raise ValueError(f"n must be at least 1, got {self.n}.") if self.best_of is not None: if not isinstance(self.best_of, int): raise ValueError( - f"best_of must be an integer, got {type(self.best_of)}") + f"best_of must be an integer, got {type(self.best_of)}" + ) if self.best_of < 1: - raise ValueError( - f"best_of must be at least 1, got {self.best_of}") + raise ValueError(f"best_of must be at least 1, got {self.best_of}") if self.best_of < self.n: raise ValueError( f"best_of must be greater than or equal to n, " - f"got n={self.n} and best_of={self.best_of}.") + f"got n={self.n} and best_of={self.best_of}." + ) if not -2.0 <= self.presence_penalty <= 2.0: - raise ValueError("presence_penalty must be in [-2, 2], got " - f"{self.presence_penalty}.") + raise ValueError( + f"presence_penalty must be in [-2, 2], got {self.presence_penalty}." + ) if not -2.0 <= self.frequency_penalty <= 2.0: - raise ValueError("frequency_penalty must be in [-2, 2], got " - f"{self.frequency_penalty}.") + raise ValueError( + f"frequency_penalty must be in [-2, 2], got {self.frequency_penalty}." + ) if self.repetition_penalty <= 0.0: raise ValueError( "repetition_penalty must be greater than zero, got " - f"{self.repetition_penalty}.") + f"{self.repetition_penalty}." 
+ ) if self.temperature < 0.0: raise ValueError( - f"temperature must be non-negative, got {self.temperature}.") + f"temperature must be non-negative, got {self.temperature}." + ) if not 0.0 < self.top_p <= 1.0: raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") # quietly accept -1 as disabled, but prefer 0 if self.top_k < -1: - raise ValueError(f"top_k must be 0 (disable), or at least 1, " - f"got {self.top_k}.") + raise ValueError( + f"top_k must be 0 (disable), or at least 1, got {self.top_k}." + ) if not isinstance(self.top_k, int): raise TypeError( - f"top_k must be an integer, got {type(self.top_k).__name__}") + f"top_k must be an integer, got {type(self.top_k).__name__}" + ) if not 0.0 <= self.min_p <= 1.0: - raise ValueError("min_p must be in [0, 1], got " - f"{self.min_p}.") + raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.") if self.max_tokens is not None and self.max_tokens < 1: - raise ValueError( - f"max_tokens must be at least 1, got {self.max_tokens}.") + raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.") if self.min_tokens < 0: - raise ValueError(f"min_tokens must be greater than or equal to 0, " - f"got {self.min_tokens}.") + raise ValueError( + f"min_tokens must be greater than or equal to 0, got {self.min_tokens}." + ) if self.max_tokens is not None and self.min_tokens > self.max_tokens: raise ValueError( f"min_tokens must be less than or equal to " - f"max_tokens={self.max_tokens}, got {self.min_tokens}.") - if (self.logprobs is not None and self.logprobs != -1 - and self.logprobs < 0): + f"max_tokens={self.max_tokens}, got {self.min_tokens}." + ) + if self.logprobs is not None and self.logprobs != -1 and self.logprobs < 0: raise ValueError( - f"logprobs must be non-negative or -1, got {self.logprobs}.") - if (self.prompt_logprobs is not None and self.prompt_logprobs != -1 - and self.prompt_logprobs < 0): + f"logprobs must be non-negative or -1, got {self.logprobs}." 
+ ) + if ( + self.prompt_logprobs is not None + and self.prompt_logprobs != -1 + and self.prompt_logprobs < 0 + ): raise ValueError( f"prompt_logprobs must be non-negative or -1, got " - f"{self.prompt_logprobs}.") - if (self.truncate_prompt_tokens is not None - and (self.truncate_prompt_tokens == 0 - or self.truncate_prompt_tokens < -1)): + f"{self.prompt_logprobs}." + ) + if self.truncate_prompt_tokens is not None and ( + self.truncate_prompt_tokens == 0 or self.truncate_prompt_tokens < -1 + ): raise ValueError( f"truncate_prompt_tokens must be an integer >= 1 or -1, " - f"got {self.truncate_prompt_tokens}") + f"got {self.truncate_prompt_tokens}" + ) assert isinstance(self.stop_token_ids, list) if not all(isinstance(st_id, int) for st_id in self.stop_token_ids): - raise ValueError(f"stop_token_ids must contain only integers, " - f"got {self.stop_token_ids}.") + raise ValueError( + f"stop_token_ids must contain only integers, got {self.stop_token_ids}." + ) assert isinstance(self.stop, list) if any(not stop_str for stop_str in self.stop): raise ValueError("stop cannot contain an empty string.") if self.stop and not self.detokenize: raise ValueError( "stop strings are only supported when detokenize is True. " - "Set detokenize=True to use stop.") + "Set detokenize=True to use stop." 
+ ) if self.best_of != self._real_n and self.output_kind == ( - RequestOutputKind.DELTA): + RequestOutputKind.DELTA + ): raise ValueError("best_of must equal n to use output_kind=DELTA") def _verify_greedy_sampling(self) -> None: if self.n > 1: - raise ValueError("n must be 1 when using greedy sampling, " - f"got {self.n}.") + raise ValueError(f"n must be 1 when using greedy sampling, got {self.n}.") def update_from_generation_config( - self, - generation_config: dict[str, Any], - model_eos_token_id: Optional[int] = None) -> None: + self, + generation_config: dict[str, Any], + model_eos_token_id: Optional[int] = None, + ) -> None: """Update if there are non-default values from generation_config""" if model_eos_token_id is not None: @@ -499,30 +523,33 @@ def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None: for add_prefix_space in [False, True]: prefix = " " if add_prefix_space else "" prompt = prefix + bad_word.lstrip() - prompt_token_ids = tokenizer.encode(text=prompt, - add_special_tokens=False) + prompt_token_ids = tokenizer.encode( + text=prompt, add_special_tokens=False + ) # If no space at the beginning # or if prefix space produces a new word token if (not add_prefix_space) or ( - add_prefix_space and prompt_token_ids[0] - != self._bad_words_token_ids[-1][0] - and len(prompt_token_ids) == len( - self._bad_words_token_ids[-1])): + add_prefix_space + and prompt_token_ids[0] != self._bad_words_token_ids[-1][0] + and len(prompt_token_ids) == len(self._bad_words_token_ids[-1]) + ): self._bad_words_token_ids.append(prompt_token_ids) invalid_token_ids = [ - token_id for bad_words_token_ids in self._bad_words_token_ids + token_id + for bad_words_token_ids in self._bad_words_token_ids for token_id in bad_words_token_ids if token_id < 0 or token_id > tokenizer.max_token_id ] if len(invalid_token_ids) > 0: raise ValueError( - f"The model vocabulary size is {tokenizer.max_token_id+1}," + f"The model vocabulary size is {tokenizer.max_token_id + 1}," f" but 
the following tokens" f" were specified as bad: {invalid_token_ids}." f" All token id values should be integers satisfying:" - f" 0 <= token_id <= {tokenizer.max_token_id}.") + f" 0 <= token_id <= {tokenizer.max_token_id}." + ) @cached_property def sampling_type(self) -> SamplingType: @@ -550,10 +577,14 @@ def clone(self) -> "SamplingParams": See https://github.com/vllm-project/vllm/issues/3087 """ - logit_processor_refs = None if self.logits_processors is None else { - id(lp): lp.clone() if hasattr(lp, 'clone') else lp - for lp in self.logits_processors - } + logit_processor_refs = ( + None + if self.logits_processors is None + else { + id(lp): lp.clone() if hasattr(lp, "clone") else lp + for lp in self.logits_processors + } + ) return copy.deepcopy(self, memo=logit_processor_refs) def __repr__(self) -> str: @@ -582,15 +613,18 @@ def __repr__(self) -> str: f"{self.spaces_between_special_tokens}, " f"truncate_prompt_tokens={self.truncate_prompt_tokens}, " f"structured_outputs={self.structured_outputs}, " - f"extra_args={self.extra_args})") + f"extra_args={self.extra_args})" + ) class BeamSearchParams( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - # required for @cached_property. - dict=True): # type: ignore[call-arg] + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + # required for @cached_property. 
+ dict=True, +): # type: ignore[call-arg] """Beam search parameters for text generation.""" + beam_width: int max_tokens: int ignore_eos: bool = False diff --git a/vllm/v1/sample/logits_processor/__init__.py b/vllm/v1/sample/logits_processor/__init__.py index 8fc41432b66d..0d4df239c868 100644 --- a/vllm/v1/sample/logits_processor/__init__.py +++ b/vllm/v1/sample/logits_processor/__init__.py @@ -14,13 +14,18 @@ from vllm.logits_process import LogitsProcessor as RequestLogitsProcessor from vllm.sampling_params import SamplingParams from vllm.v1.sample.logits_processor.builtin import ( - LogitBiasLogitsProcessor, MinPLogitsProcessor, MinTokensLogitsProcessor, - ThinkingTokenBudgetLogitsProcessor, process_dict_updates) -from vllm.v1.sample.logits_processor.interface import (BatchUpdate, - LogitsProcessor, - MoveDirectionality) -from vllm.v1.sample.logits_processor.state import (BatchUpdateBuilder, - LogitsProcessors) + LogitBiasLogitsProcessor, + MinPLogitsProcessor, + MinTokensLogitsProcessor, + ThinkingTokenBudgetLogitsProcessor, + process_dict_updates, +) +from vllm.v1.sample.logits_processor.interface import ( + BatchUpdate, + LogitsProcessor, + MoveDirectionality, +) +from vllm.v1.sample.logits_processor.state import BatchUpdateBuilder, LogitsProcessors if TYPE_CHECKING: from vllm.config import VllmConfig @@ -29,10 +34,11 @@ # Error message when the user tries to initialize vLLM with a pooling model # and custom logitsproces -STR_POOLING_REJECTS_LOGITSPROCS = ("Pooling models do not support custom" - " logits processors.") +STR_POOLING_REJECTS_LOGITSPROCS = ( + "Pooling models do not support custom logits processors." 
+) -LOGITSPROCS_GROUP = 'vllm.logits_processors' +LOGITSPROCS_GROUP = "vllm.logits_processors" BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [ MinTokensLogitsProcessor, @@ -54,27 +60,29 @@ def _load_logitsprocs_plugins() -> list[type[LogitsProcessor]]: installed_logitsprocs_plugins = entry_points(group=LOGITSPROCS_GROUP) if len(installed_logitsprocs_plugins) == 0: - logger.debug("No logitsprocs plugins installed (group %s).", - LOGITSPROCS_GROUP) + logger.debug("No logitsprocs plugins installed (group %s).", LOGITSPROCS_GROUP) return [] # Load logitsprocs plugins - logger.debug("Loading installed logitsprocs plugins (group %s):", - LOGITSPROCS_GROUP) + logger.debug("Loading installed logitsprocs plugins (group %s):", LOGITSPROCS_GROUP) classes: list[type[LogitsProcessor]] = [] for entrypoint in installed_logitsprocs_plugins: try: - logger.debug("- Loading logitproc plugin entrypoint=%s target=%s", - entrypoint.name, entrypoint.value) + logger.debug( + "- Loading logitproc plugin entrypoint=%s target=%s", + entrypoint.name, + entrypoint.value, + ) classes.append(entrypoint.load()) except Exception as e: raise RuntimeError( - f"Failed to load LogitsProcessor plugin {entrypoint}") from e + f"Failed to load LogitsProcessor plugin {entrypoint}" + ) from e return classes def _load_logitsprocs_by_fqcns( - logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]] + logits_processors: Optional[Sequence[Union[str, type[LogitsProcessor]]]], ) -> list[type[LogitsProcessor]]: """Load logit processor types, identifying them by fully-qualified class names (FQCNs). 
@@ -99,13 +107,14 @@ def _load_logitsprocs_by_fqcns( logger.debug( "%s additional custom logits processors specified, checking whether " - "they need to be loaded.", len(logits_processors)) + "they need to be loaded.", + len(logits_processors), + ) classes: list[type[LogitsProcessor]] = [] for ldx, logitproc in enumerate(logits_processors): if isinstance(logitproc, type): - logger.debug(" - Already-loaded logit processor: %s", - logitproc.__name__) + logger.debug(" - Already-loaded logit processor: %s", logitproc.__name__) if not issubclass(logitproc, LogitsProcessor): raise ValueError( f"{logitproc.__name__} is not a subclass of LogitsProcessor" @@ -131,8 +140,7 @@ def _load_logitsprocs_by_fqcns( if not isinstance(obj, type): raise ValueError("Loaded logit processor must be a type.") if not issubclass(obj, LogitsProcessor): - raise ValueError( - f"{obj.__name__} must be a subclass of LogitsProcessor") + raise ValueError(f"{obj.__name__} must be a subclass of LogitsProcessor") classes.append(obj) return classes @@ -155,13 +163,13 @@ def _load_custom_logitsprocs( A list of all loaded logitproc types """ from vllm.platforms import current_platform + if current_platform.is_tpu(): # No logitsprocs specified by caller # TODO(andy) - vLLM V1 on TPU does not support custom logitsprocs return [] - return (_load_logitsprocs_plugins() + - _load_logitsprocs_by_fqcns(logits_processors)) + return _load_logitsprocs_plugins() + _load_logitsprocs_by_fqcns(logits_processors) def build_logitsprocs( @@ -174,23 +182,28 @@ def build_logitsprocs( if is_pooling_model: if custom_logitsprocs: raise ValueError(STR_POOLING_REJECTS_LOGITSPROCS) - logger.debug("Skipping logits processor loading because pooling models" - " do not support logits processors.") + logger.debug( + "Skipping logits processor loading because pooling models" + " do not support logits processors." 
+ ) return LogitsProcessors() custom_logitsprocs_classes = _load_custom_logitsprocs(custom_logitsprocs) return LogitsProcessors( - ctor(vllm_config, device, is_pin_memory) for ctor in itertools.chain( - BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes)) + ctor(vllm_config, device, is_pin_memory) + for ctor in itertools.chain( + BUILTIN_LOGITS_PROCESSORS, custom_logitsprocs_classes + ) + ) class AdapterLogitsProcessor(LogitsProcessor): """Wrapper for per-request logits processors - + To wrap a specific per-request logits processor, * Subclass `AdapterLogitsProcessor` * Implement `self.is_argmax_invariant()` base-class method * Implement `self.new_req_logits_processor(params)` - + `self.__init__(vllm_config, device, is_pin_memory)` does not need to be overridden in general. However, to implement custom constructor behavior - especially any logic which operates on or stores `vllm_config`, `device`, @@ -199,8 +212,9 @@ class AdapterLogitsProcessor(LogitsProcessor): `super().__init__(vllm_config, device, is_pin_memory)` """ - def __init__(self, vllm_config: "VllmConfig", device: torch.device, - is_pin_memory: bool): + def __init__( + self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool + ): """Subclass must invoke `super().__init__(vllm_config, device, is_pin_memory)`. 
@@ -236,7 +250,7 @@ def new_req_logits_processor( Returns: None if logits processor should not be applied to request; otherwise returns a `RequestLogitsProcessor` instance - + """ raise NotImplementedError @@ -257,11 +271,14 @@ def _new_state( Returns: logits processor partial[Tensor] or None - + """ if req_lp := self.new_req_logits_processor(params): - args = [prompt_ids, output_ids] if (len( - inspect.signature(req_lp).parameters) == 3) else [output_ids] + args = ( + [prompt_ids, output_ids] + if (len(inspect.signature(req_lp).parameters) == 3) + else [output_ids] + ) return partial(req_lp, *args) return None @@ -286,9 +303,17 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: __all__ = [ - "LogitsProcessor", "LogitBiasLogitsProcessor", "MinPLogitsProcessor", - "MinTokensLogitsProcessor", "BatchUpdate", "BatchUpdateBuilder", - "MoveDirectionality", "LogitsProcessors", "build_logitsprocs", - "STR_POOLING_REJECTS_LOGITSPROCS", "LOGITSPROCS_GROUP", - "AdapterLogitsProcessor", "ThinkingTokenBudgetLogitsProcessor" + "LogitsProcessor", + "LogitBiasLogitsProcessor", + "MinPLogitsProcessor", + "MinTokensLogitsProcessor", + "BatchUpdate", + "BatchUpdateBuilder", + "MoveDirectionality", + "LogitsProcessors", + "build_logitsprocs", + "STR_POOLING_REJECTS_LOGITSPROCS", + "LOGITSPROCS_GROUP", + "AdapterLogitsProcessor", + "ThinkingTokenBudgetLogitsProcessor", ] diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index f719824dbec2..625681f0a328 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -6,9 +6,11 @@ import torch from vllm import SamplingParams -from vllm.v1.sample.logits_processor.interface import (BatchUpdate, - LogitsProcessor, - MoveDirectionality) +from vllm.v1.sample.logits_processor.interface import ( + BatchUpdate, + LogitsProcessor, + MoveDirectionality, +) if TYPE_CHECKING: from vllm.config import VllmConfig @@ -17,25 +19,24 @@ class 
MinPLogitsProcessor(LogitsProcessor): - - def __init__(self, vllm_config: "VllmConfig", device: torch.device, - is_pin_memory: bool): + def __init__( + self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool + ): max_num_reqs = vllm_config.scheduler_config.max_num_seqs self.min_p_count: int = 0 - self.min_p_cpu_tensor = torch.zeros((max_num_reqs, ), - dtype=torch.float32, - device="cpu", - pin_memory=is_pin_memory) + self.min_p_cpu_tensor = torch.zeros( + (max_num_reqs,), dtype=torch.float32, device="cpu", pin_memory=is_pin_memory + ) self.min_p_cpu = self.min_p_cpu_tensor.numpy() self.use_double_tensor = torch.device(device).type != "cpu" if self.use_double_tensor: # Pre-allocated device tensor - self.min_p_device: torch.Tensor = torch.empty((max_num_reqs, ), - dtype=torch.float32, - device=device) + self.min_p_device: torch.Tensor = torch.empty( + (max_num_reqs,), dtype=torch.float32, device=device + ) else: self.min_p_device = self.min_p_cpu_tensor # Current slice of the device tensor @@ -93,8 +94,7 @@ def update_state(self, batch_update: Optional[BatchUpdate]): if self.min_p_count and (needs_update or self.min_p.shape[0] != size): self.min_p = self.min_p_device[:size] if self.use_double_tensor: - self.min_p.copy_(self.min_p_cpu_tensor[:size], - non_blocking=True) + self.min_p.copy_(self.min_p_cpu_tensor[:size], non_blocking=True) self.min_p.unsqueeze_(1) def apply(self, logits: torch.Tensor) -> torch.Tensor: @@ -104,28 +104,27 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: # Convert logits to probability distribution probability_values = torch.nn.functional.softmax(logits, dim=-1) # Calculate maximum probabilities per sequence - max_probabilities = torch.amax(probability_values, - dim=-1, - keepdim=True) + max_probabilities = torch.amax(probability_values, dim=-1, keepdim=True) # Adjust min_p adjusted_min_p = max_probabilities.mul_(self.min_p) # Identify valid tokens using threshold comparison invalid_token_mask = 
probability_values < adjusted_min_p # Apply mask using boolean indexing - logits[invalid_token_mask] = -float('inf') + logits[invalid_token_mask] = -float("inf") return logits class LogitBiasLogitsProcessor(LogitsProcessor): - def __init__(self, _, device: torch.device, is_pin_memory: bool): self.device = device self.pin_memory = is_pin_memory self.biases: dict[int, dict[int, float]] = {} self.bias_tensor: torch.Tensor = torch.tensor(()) - self.logits_slice = (self._device_tensor([], torch.int32), - self._device_tensor([], torch.int32)) + self.logits_slice = ( + self._device_tensor([], torch.int32), + self._device_tensor([], torch.int32), + ) def is_argmax_invariant(self) -> bool: """Logit bias can rebalance token probabilities and change the @@ -134,8 +133,8 @@ def is_argmax_invariant(self) -> bool: def update_state(self, batch_update: Optional[BatchUpdate]): needs_update = process_dict_updates( - self.biases, batch_update, - lambda params, _, __: params.logit_bias or None) + self.biases, batch_update, lambda params, _, __: params.logit_bias or None + ) # Update tensors if needed. 
if needs_update: @@ -148,15 +147,15 @@ def update_state(self, batch_update: Optional[BatchUpdate]): biases.extend(lb.values()) self.bias_tensor = self._device_tensor(biases, torch.float32) - self.logits_slice = (self._device_tensor(reqs, torch.int32), - self._device_tensor(tok_ids, torch.int32)) + self.logits_slice = ( + self._device_tensor(reqs, torch.int32), + self._device_tensor(tok_ids, torch.int32), + ) def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor: - return (torch.tensor(data, - device="cpu", - dtype=dtype, - pin_memory=self.pin_memory).to(device=self.device, - non_blocking=True)) + return torch.tensor( + data, device="cpu", dtype=dtype, pin_memory=self.pin_memory + ).to(device=self.device, non_blocking=True) def apply(self, logits: torch.Tensor) -> torch.Tensor: if self.biases: @@ -165,20 +164,19 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: class MinTokensLogitsProcessor(LogitsProcessor): - - def __init__(self, vllm_config: "VllmConfig", device: torch.device, - is_pin_memory: bool): + def __init__( + self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool + ): # index -> (min_toks, output_token_ids, stop_token_ids) self.device = device self.pin_memory = is_pin_memory self.min_toks: dict[int, tuple[int, Sequence[int], set[int]]] = {} # (req_idx_tensor,eos_tok_id_tensor) - self.logits_slice: tuple[torch.Tensor, - torch.Tensor] = (self._device_tensor( - [], torch.int32), - self._device_tensor( - [], torch.int32)) + self.logits_slice: tuple[torch.Tensor, torch.Tensor] = ( + self._device_tensor([], torch.int32), + self._device_tensor([], torch.int32), + ) def is_argmax_invariant(self) -> bool: """By censoring stop tokens, min-tokens can change the outcome @@ -187,8 +185,7 @@ def is_argmax_invariant(self) -> bool: @staticmethod def add_request( - params: SamplingParams, _: Optional[list[int]], - output_tok_ids: list[int] + params: SamplingParams, _: Optional[list[int]], output_tok_ids: list[int] ) -> 
Optional[tuple[int, Sequence[int], set[int]]]: min_tokens = params.min_tokens if not min_tokens or len(output_tok_ids) >= min_tokens: @@ -196,13 +193,16 @@ def add_request( return min_tokens, output_tok_ids, params.all_stop_token_ids def update_state(self, batch_update: Optional[BatchUpdate]): - needs_update = process_dict_updates(self.min_toks, batch_update, - self.add_request) + needs_update = process_dict_updates( + self.min_toks, batch_update, self.add_request + ) if self.min_toks: # Check for any requests that have attained their min tokens. - to_remove = tuple(index for index, (min_toks, out_tok_ids, - _) in self.min_toks.items() - if len(out_tok_ids) >= min_toks) + to_remove = tuple( + index + for index, (min_toks, out_tok_ids, _) in self.min_toks.items() + if len(out_tok_ids) >= min_toks + ) if to_remove: needs_update = True for index in to_remove: @@ -216,15 +216,15 @@ def update_state(self, batch_update: Optional[BatchUpdate]): reqs.extend([req] * len(stop_tok_ids)) tok_ids.extend(stop_tok_ids) - self.logits_slice = (self._device_tensor(reqs, torch.int32), - self._device_tensor(tok_ids, torch.int32)) + self.logits_slice = ( + self._device_tensor(reqs, torch.int32), + self._device_tensor(tok_ids, torch.int32), + ) def _device_tensor(self, data: list, dtype: torch.dtype) -> torch.Tensor: - return (torch.tensor(data, - device="cpu", - dtype=dtype, - pin_memory=self.pin_memory).to(device=self.device, - non_blocking=True)) + return torch.tensor( + data, device="cpu", dtype=dtype, pin_memory=self.pin_memory + ).to(device=self.device, non_blocking=True) def apply(self, logits: torch.Tensor) -> torch.Tensor: if self.min_toks: @@ -236,8 +236,9 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: class ThinkingTokenBudgetLogitsProcessor(LogitsProcessor): """Limits the number of tokens allowed inside a 'thinking' section.""" - def __init__(self, vllm_config: "VllmConfig", device: torch.device, - is_pin_memory: bool): + def __init__( + self, vllm_config: 
"VllmConfig", device: torch.device, is_pin_memory: bool + ): """ Args: reasoning_config: Configuration for reasoning, which includes @@ -249,13 +250,14 @@ def __init__(self, vllm_config: "VllmConfig", device: torch.device, max_num_reqs = vllm_config.scheduler_config.max_num_seqs # Check if thinking is enabled - self.is_enabled = (reasoning_config is not None - and reasoning_config.is_thinking_enabled()) + self.is_enabled = ( + reasoning_config is not None and reasoning_config.is_thinking_enabled() + ) - self.think_start_token_ids = getattr(reasoning_config, - "think_start_token_ids", []) - self.think_end_token_ids = getattr(reasoning_config, - "think_end_token_ids", []) + self.think_start_token_ids = getattr( + reasoning_config, "think_start_token_ids", [] + ) + self.think_end_token_ids = getattr(reasoning_config, "think_end_token_ids", []) self.pin_memory = is_pin_memory self.device = device @@ -275,14 +277,12 @@ def __init__(self, vllm_config: "VllmConfig", device: torch.device, # Preallocate reusable tensors self.mask = torch.zeros(max_num_reqs, dtype=torch.bool, device=device) - self.force_token_ids = torch.full((max_num_reqs, ), - -1, - dtype=torch.long, - device=device) + self.force_token_ids = torch.full( + (max_num_reqs,), -1, dtype=torch.long, device=device + ) @staticmethod - def _find_last_sequence_index(target_list: list[int], - token_ids: list[int]) -> int: + def _find_last_sequence_index(target_list: list[int], token_ids: list[int]) -> int: """ Returns the index of the last occurrence of token_ids in target_list. 
@@ -293,12 +293,13 @@ def _find_last_sequence_index(target_list: list[int], if not token_ids: return -1 for i in range(len(target_list) - len(token_ids), -1, -1): - if target_list[i:i + len(token_ids)] == token_ids: + if target_list[i : i + len(token_ids)] == token_ids: return i return -1 - def _init_state_entry(self, prompt_tok_ids: Optional[list[int]], - thinking_token_budget: int) -> dict[str, Any]: + def _init_state_entry( + self, prompt_tok_ids: Optional[list[int]], thinking_token_budget: int + ) -> dict[str, Any]: """Initializes the tracking state for a given sequence index.""" if prompt_tok_ids is None: last_start = -1 @@ -307,13 +308,16 @@ def _init_state_entry(self, prompt_tok_ids: Optional[list[int]], think_count = 0 else: last_start = self._find_last_sequence_index( - prompt_tok_ids, self.think_start_token_ids) - last_end = self._find_last_sequence_index(prompt_tok_ids, - self.think_end_token_ids) + prompt_tok_ids, self.think_start_token_ids + ) + last_end = self._find_last_sequence_index( + prompt_tok_ids, self.think_end_token_ids + ) in_think = last_start > last_end if in_think: think_count = len(prompt_tok_ids) - ( - last_start + len(self.think_start_token_ids)) + last_start + len(self.think_start_token_ids) + ) else: think_count = 0 @@ -326,14 +330,13 @@ def _init_state_entry(self, prompt_tok_ids: Optional[list[int]], "prompt_tok_ids": prompt_tok_ids, "output_tok_ids": [], "thinking_token_budget": thinking_token_budget, - "prev_output_length": - 0, # Track previous output length for incremental updates + "prev_output_length": 0, + # Track previous output length for incremental updates } def _update_think_state(self, state: dict[str, Any]): """Updates the state based on newly generated output tokens.""" - if not state.get("in_end", False) and state.get("check_count_down", - 0) > 0: + if not state.get("in_end", False) and state.get("check_count_down", 0) > 0: state["check_count_down"] -= 1 return @@ -363,9 +366,11 @@ def _update_think_state(self, 
state: dict[str, Any]): # Find any think start/end sequences in recent tokens recent_start_pos = self._find_last_sequence_index( - recent_tokens, self.think_start_token_ids) + recent_tokens, self.think_start_token_ids + ) recent_end_pos = self._find_last_sequence_index( - recent_tokens, self.think_end_token_ids) + recent_tokens, self.think_end_token_ids + ) # Update state based on recent sequences if not state["in_end"]: @@ -373,8 +378,7 @@ def _update_think_state(self, state: dict[str, Any]): if recent_start_pos > recent_end_pos: # Case: ......... - entering think mode absolute_start_pos = check_start_idx + recent_start_pos - new_think_count = current_length - (absolute_start_pos + - start_len) + new_think_count = current_length - (absolute_start_pos + start_len) state["in_think"] = True state["think_count"] = new_think_count else: @@ -384,8 +388,7 @@ def _update_think_state(self, state: dict[str, Any]): elif recent_start_pos >= 0: # Found think start - entering think mode absolute_start_pos = check_start_idx + recent_start_pos - new_think_count = current_length - (absolute_start_pos + - start_len) + new_think_count = current_length - (absolute_start_pos + start_len) state["in_think"] = True state["think_count"] = new_think_count elif recent_end_pos >= 0: @@ -399,14 +402,17 @@ def _update_think_state(self, state: dict[str, Any]): # Set countdown based on current state if state["in_think"]: remaining_budget = max( - 0, state["thinking_token_budget"] - state["think_count"]) + 0, state["thinking_token_budget"] - state["think_count"] + ) state["check_count_down"] = remaining_budget else: state["check_count_down"] = state["thinking_token_budget"] # Check if need to transition to end mode - if state["in_think"] and state["think_count"] >= state[ - "thinking_token_budget"]: + if ( + state["in_think"] + and state["think_count"] >= state["thinking_token_budget"] + ): state["in_think"] = False state["in_end"] = True state["end_count"] = 0 @@ -415,11 +421,13 @@ def 
_update_think_state(self, state: dict[str, Any]): # In end mode state["end_count"] += 1 if state["end_count"] >= len(self.think_end_token_ids): - state.update({ - "in_end": False, - "end_count": 0, - "check_count_down": state["thinking_token_budget"] - }) + state.update( + { + "in_end": False, + "end_count": 0, + "check_count_down": state["thinking_token_budget"], + } + ) def is_argmax_invariant(self) -> bool: """This logits processor can change the outcome of @@ -431,13 +439,13 @@ def update_state(self, batch_update: Optional[BatchUpdate]): if not self.is_enabled: return if batch_update: - for (index, params, prompt_tok_ids, output_tok_ids) \ - in batch_update.added: + for index, params, prompt_tok_ids, output_tok_ids in batch_update.added: thinking_token_budget = params.thinking_token_budget if thinking_token_budget is not None: self._state[index] = self._init_state_entry( - prompt_tok_ids, thinking_token_budget) + prompt_tok_ids, thinking_token_budget + ) self._state[index]["output_tok_ids"] = output_tok_ids else: # Remove state if no thinking budget @@ -470,12 +478,12 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: state = self._state.get(i) if state and state["in_end"]: self.mask[i] = True - self.force_token_ids[i] = \ - self.think_end_token_ids[state["end_count"]] + self.force_token_ids[i] = self.think_end_token_ids[state["end_count"]] # Check in CPU first not to sync with GPU has_active_thinking = any( - state.get("in_end", False) for state in self._state.values()) + state.get("in_end", False) for state in self._state.values() + ) if has_active_thinking: current_mask = self.mask[:batch_size] @@ -489,9 +497,9 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor: def process_dict_updates( - req_entries: dict[int, T], batch_update: Optional[BatchUpdate], - new_state: Callable[[SamplingParams, Optional[list[int]], list[int]], - Optional[T]] + req_entries: dict[int, T], + batch_update: Optional[BatchUpdate], + new_state: Callable[[SamplingParams, 
Optional[list[int]], list[int]], Optional[T]], ) -> bool: """Utility function to update dict state for sparse LogitsProcessors.""" @@ -501,8 +509,7 @@ def process_dict_updates( updated = False for index, params, prompt_tok_ids, output_tok_ids in batch_update.added: - if (state := new_state(params, prompt_tok_ids, - output_tok_ids)) is not None: + if (state := new_state(params, prompt_tok_ids, output_tok_ids)) is not None: req_entries[index] = state updated = True elif req_entries.pop(index, None) is not None: From de53277d37b25e5cd958cc701d6f2892b82d9244 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sat, 1 Nov 2025 13:14:23 +0000 Subject: [PATCH 41/61] make is_thinking_enabled property Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/config/reasoning.py | 10 +++++----- vllm/v1/sample/logits_processor/builtin.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/config/reasoning.py b/vllm/config/reasoning.py index 4ea20623a74e..8030f7ba6fa1 100644 --- a/vllm/config/reasoning.py +++ b/vllm/config/reasoning.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional from pydantic.dataclasses import dataclass @@ -15,15 +14,16 @@ class ReasoningConfig: """Configuration for reasoning models.""" - think_start_str: Optional[str] = None + think_start_str: str | None = None """String that indicates the start of reasoning.""" - think_end_str: Optional[str] = None + think_end_str: str | None = None """String that indicates the end of reasoning.""" - think_start_token_ids: Optional[list[int]] = None + think_start_token_ids: list[int] | None = None """Token ID that indicates the start of reasoning.""" - think_end_token_ids: Optional[list[int]] = None + think_end_token_ids: list[int] | None = None """Token ID that indicates the end of reasoning.""" + @property def is_thinking_enabled(self) 
-> bool: """Check if both start and end thinking token IDs are set to enable thinking token budget logic.""" diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 3a5e87ef7348..62ccd78fb888 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -251,7 +251,7 @@ def __init__( # Check if thinking is enabled self.is_enabled = ( - reasoning_config is not None and reasoning_config.is_thinking_enabled() + reasoning_config is not None and reasoning_config.is_thinking_enabled ) self.think_start_token_ids = getattr( From c2155751a3e55d2a07ae755d4a1a7597440eb19a Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sun, 2 Nov 2025 07:34:30 +0000 Subject: [PATCH 42/61] fix readthedocs failed Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/v1/sample/logits_processor/builtin.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 62ccd78fb888..5bacfc55e482 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -239,13 +239,6 @@ class ThinkingTokenBudgetLogitsProcessor(LogitsProcessor): def __init__( self, vllm_config: "VllmConfig", device: torch.device, is_pin_memory: bool ): - """ - Args: - reasoning_config: Configuration for reasoning, which includes - the token IDs for thinking start and end. - pin_memory (bool): Whether to use pinned memory for tensors. - device (torch.device): Device to use for tensor operations. 
- """ reasoning_config = vllm_config.reasoning_config max_num_reqs = vllm_config.scheduler_config.max_num_seqs From 2a5e6c0c9bdfb7efc2bb6fb82377bcab66d40218 Mon Sep 17 00:00:00 2001 From: Chauncey Date: Tue, 6 Jan 2026 00:05:21 +0800 Subject: [PATCH 43/61] Update vllm/config/reasoning.py Signed-off-by: Chauncey --- vllm/config/reasoning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config/reasoning.py b/vllm/config/reasoning.py index 8030f7ba6fa1..9997eea59da2 100644 --- a/vllm/config/reasoning.py +++ b/vllm/config/reasoning.py @@ -6,7 +6,7 @@ from vllm.config.model import ModelConfig from vllm.config.utils import config -from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs +from vllm.tokenizers import cached_tokenizer_from_config @config From e8c020df62a2faec21083721366cfc6a6588de10 Mon Sep 17 00:00:00 2001 From: Chauncey Date: Tue, 6 Jan 2026 00:05:54 +0800 Subject: [PATCH 44/61] Update vllm/config/reasoning.py Signed-off-by: Chauncey --- vllm/config/reasoning.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config/reasoning.py b/vllm/config/reasoning.py index 9997eea59da2..f22c41d38aca 100644 --- a/vllm/config/reasoning.py +++ b/vllm/config/reasoning.py @@ -37,7 +37,7 @@ def is_thinking_enabled(self) -> bool: def initialize_token_ids(self, model_config: ModelConfig) -> None: """Initialize reasoning token IDs from strings using the tokenizer.""" if self.think_start_str is not None and self.think_end_str is not None: - tokenizer = init_tokenizer_from_configs(model_config=model_config) + tokenizer = cached_tokenizer_from_config(model_config=model_config) # Convert reasoning strings to token IDs self.think_start_token_ids = tokenizer.convert_tokens_to_ids( From d3b06cb12a3e3160ac57bd71d7af5ac9f9c7dd0f Mon Sep 17 00:00:00 2001 From: Harry Mellor <19981378+hmellor@users.noreply.github.com> Date: Tue, 10 Feb 2026 16:39:13 +0100 Subject: [PATCH 45/61] Remove unused import from reasoning.py Remove 
unused import of dataclass from pydantic. Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Made-with: Cursor --- vllm/config/reasoning.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/config/reasoning.py b/vllm/config/reasoning.py index f22c41d38aca..07ae47d5fb52 100644 --- a/vllm/config/reasoning.py +++ b/vllm/config/reasoning.py @@ -1,16 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -from pydantic.dataclasses import dataclass - from vllm.config.model import ModelConfig from vllm.config.utils import config from vllm.tokenizers import cached_tokenizer_from_config @config -@dataclass class ReasoningConfig: """Configuration for reasoning models.""" From be1e8b6961257c422463e02721fbb7688f2bbd7d Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Fri, 27 Feb 2026 04:19:33 +0000 Subject: [PATCH 46/61] make thinking budget logits processor working with async scheduling option Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/v1/worker/gpu_model_runner.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a3e0adfae218..a2e3b0a709f0 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -573,7 +573,12 @@ def __init__( ), # We currently don't know whether a particular custom logits processor # uses output token ids so we set this conservatively. - logitsprocs_need_output_token_ids=bool(custom_logitsprocs), + # ThinkingTokenBudgetLogitsProcessor also needs output token ids to + # correctly track think start/end token sequences in async scheduling. 
+ logitsprocs_need_output_token_ids=bool(custom_logitsprocs) or ( + self.vllm_config.reasoning_config is not None + and self.vllm_config.reasoning_config.is_thinking_enabled + ), is_pooling_model=self.is_pooling_model, cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size, ) From 5cfa548e242cd1eaa9083fad67a7aff37ddb3eae Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Fri, 27 Feb 2026 04:32:29 +0000 Subject: [PATCH 47/61] make precommit Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/v1/worker/gpu_model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index a2e3b0a709f0..250c3170593a 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -575,7 +575,8 @@ def __init__( # uses output token ids so we set this conservatively. # ThinkingTokenBudgetLogitsProcessor also needs output token ids to # correctly track think start/end token sequences in async scheduling. 
- logitsprocs_need_output_token_ids=bool(custom_logitsprocs) or ( + logitsprocs_need_output_token_ids=bool(custom_logitsprocs) + or ( self.vllm_config.reasoning_config is not None and self.vllm_config.reasoning_config.is_thinking_enabled ), From c035ea081177649a80404c3f6a94b1d04dd1e7e9 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Fri, 27 Feb 2026 04:37:17 +0000 Subject: [PATCH 48/61] remove obsolete file Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/entrypoints/openai/protocol.py | 2568 --------------------------- 1 file changed, 2568 deletions(-) delete mode 100644 vllm/entrypoints/openai/protocol.py diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py deleted file mode 100644 index 931e1d29dd62..000000000000 --- a/vllm/entrypoints/openai/protocol.py +++ /dev/null @@ -1,2568 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -# Adapted from -# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py -import json -import time -from http import HTTPStatus -from typing import Annotated, Any, ClassVar, Literal, TypeAlias - -import regex as re -import torch -from fastapi import HTTPException, UploadFile -from openai.types.chat.chat_completion_audio import ( - ChatCompletionAudio as OpenAIChatCompletionAudio, -) -from openai.types.chat.chat_completion_message import Annotation as OpenAIAnnotation -from openai.types.responses import ( - ResponseCodeInterpreterCallCodeDeltaEvent, - ResponseCodeInterpreterCallCodeDoneEvent, - ResponseCodeInterpreterCallCompletedEvent, - ResponseCodeInterpreterCallInProgressEvent, - ResponseCodeInterpreterCallInterpretingEvent, - ResponseContentPartAddedEvent, - ResponseContentPartDoneEvent, - ResponseFunctionToolCall, - ResponseInputItemParam, - ResponseMcpCallArgumentsDeltaEvent, -
ResponseMcpCallArgumentsDoneEvent, - ResponseMcpCallCompletedEvent, - ResponseMcpCallInProgressEvent, - ResponseOutputItem, - ResponseOutputItemAddedEvent, - ResponseOutputItemDoneEvent, - ResponsePrompt, - ResponseReasoningTextDeltaEvent, - ResponseReasoningTextDoneEvent, - ResponseStatus, - ResponseWebSearchCallCompletedEvent, - ResponseWebSearchCallInProgressEvent, - ResponseWebSearchCallSearchingEvent, -) -from openai.types.responses import ( - ResponseCompletedEvent as OpenAIResponseCompletedEvent, -) -from openai.types.responses import ResponseCreatedEvent as OpenAIResponseCreatedEvent -from openai.types.responses import ( - ResponseInProgressEvent as OpenAIResponseInProgressEvent, -) -from openai.types.responses.response_reasoning_item import ( - Content as ResponseReasoningTextContent, -) -from openai_harmony import Message as OpenAIHarmonyMessage - -# Backward compatibility for OpenAI client versions -try: # For older openai versions (< 1.100.0) - from openai.types.responses import ResponseTextConfig -except ImportError: # For newer openai versions (>= 1.100.0) - from openai.types.responses import ResponseFormatTextConfig as ResponseTextConfig - - -from openai.types.responses.response import IncompleteDetails, ToolChoice -from openai.types.responses.tool import Tool -from openai.types.shared import Metadata, Reasoning -from pydantic import ( - BaseModel, - ConfigDict, - Field, - ValidationError, - field_serializer, - model_validator, -) - -from vllm.entrypoints.chat_utils import ChatCompletionMessageParam, make_tool_call_id -from vllm.exceptions import VLLMValidationError -from vllm.logger import init_logger -from vllm.logprobs import Logprob -from vllm.sampling_params import ( - BeamSearchParams, - RequestOutputKind, - SamplingParams, - StructuredOutputsParams, -) -from vllm.utils import random_uuid -from vllm.utils.import_utils import resolve_obj_by_qualname - -logger = init_logger(__name__) - -_LONG_INFO = torch.iinfo(torch.long) - - -class 
OpenAIBaseModel(BaseModel): - # OpenAI API does allow extra fields - model_config = ConfigDict(extra="allow") - - # Cache class field names - field_names: ClassVar[set[str] | None] = None - - @model_validator(mode="wrap") - @classmethod - def __log_extra_fields__(cls, data, handler): - result = handler(data) - if not isinstance(data, dict): - return result - field_names = cls.field_names - if field_names is None: - # Get all class field names and their potential aliases - field_names = set() - for field_name, field in cls.model_fields.items(): - field_names.add(field_name) - if alias := getattr(field, "alias", None): - field_names.add(alias) - cls.field_names = field_names - - # Compare against both field names and aliases - if any(k not in field_names for k in data): - logger.warning( - "The following fields were present in the request but ignored: %s", - data.keys() - field_names, - ) - return result - - -class ErrorInfo(OpenAIBaseModel): - message: str - type: str - param: str | None = None - code: int - - -class ErrorResponse(OpenAIBaseModel): - error: ErrorInfo - - -class ModelPermission(OpenAIBaseModel): - id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") - object: str = "model_permission" - created: int = Field(default_factory=lambda: int(time.time())) - allow_create_engine: bool = False - allow_sampling: bool = True - allow_logprobs: bool = True - allow_search_indices: bool = False - allow_view: bool = True - allow_fine_tuning: bool = False - organization: str = "*" - group: str | None = None - is_blocking: bool = False - - -class ModelCard(OpenAIBaseModel): - id: str - object: str = "model" - created: int = Field(default_factory=lambda: int(time.time())) - owned_by: str = "vllm" - root: str | None = None - parent: str | None = None - max_model_len: int | None = None - permission: list[ModelPermission] = Field(default_factory=list) - - -class ModelList(OpenAIBaseModel): - object: str = "list" - data: list[ModelCard] = 
Field(default_factory=list) - - -class PromptTokenUsageInfo(OpenAIBaseModel): - cached_tokens: int | None = None - - -class UsageInfo(OpenAIBaseModel): - prompt_tokens: int = 0 - total_tokens: int = 0 - completion_tokens: int | None = 0 - prompt_tokens_details: PromptTokenUsageInfo | None = None - - -class RequestResponseMetadata(BaseModel): - request_id: str - final_usage_info: UsageInfo | None = None - - -class JsonSchemaResponseFormat(OpenAIBaseModel): - name: str - description: str | None = None - # schema is the field in openai but that causes conflicts with pydantic so - # instead use json_schema with an alias - json_schema: dict[str, Any] | None = Field(default=None, alias="schema") - strict: bool | None = None - - -class LegacyStructuralTag(OpenAIBaseModel): - begin: str - # schema is the field, but that causes conflicts with pydantic so - # instead use structural_tag_schema with an alias - structural_tag_schema: dict[str, Any] | None = Field(default=None, alias="schema") - end: str - - -class LegacyStructuralTagResponseFormat(OpenAIBaseModel): - type: Literal["structural_tag"] - structures: list[LegacyStructuralTag] - triggers: list[str] - - -class StructuralTagResponseFormat(OpenAIBaseModel): - type: Literal["structural_tag"] - format: Any - - -AnyStructuralTagResponseFormat: TypeAlias = ( - LegacyStructuralTagResponseFormat | StructuralTagResponseFormat -) - - -class ResponseFormat(OpenAIBaseModel): - # type must be "json_schema", "json_object", or "text" - type: Literal["text", "json_object", "json_schema"] - json_schema: JsonSchemaResponseFormat | None = None - - -AnyResponseFormat: TypeAlias = ( - ResponseFormat | StructuralTagResponseFormat | LegacyStructuralTagResponseFormat -) - - -class StreamOptions(OpenAIBaseModel): - include_usage: bool | None = True - continuous_usage_stats: bool | None = False - - -class FunctionDefinition(OpenAIBaseModel): - name: str - description: str | None = None - parameters: dict[str, Any] | None = None - - -class 
ChatCompletionToolsParam(OpenAIBaseModel): - type: Literal["function"] = "function" - function: FunctionDefinition - - -class ChatCompletionNamedFunction(OpenAIBaseModel): - name: str - - -class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel): - function: ChatCompletionNamedFunction - type: Literal["function"] = "function" - - -# extra="forbid" is a workaround to have kwargs as a field, -# see https://github.com/pydantic/pydantic/issues/3125 -class LogitsProcessorConstructor(BaseModel): - qualname: str - args: list[Any] | None = None - kwargs: dict[str, Any] | None = None - - model_config = ConfigDict(extra="forbid") - - -LogitsProcessors = list[str | LogitsProcessorConstructor] - - -def get_logits_processors( - processors: LogitsProcessors | None, pattern: str | None -) -> list[Any] | None: - if processors and pattern: - logits_processors = [] - for processor in processors: - qualname = processor if isinstance(processor, str) else processor.qualname - if not re.match(pattern, qualname): - raise ValueError( - f"Logits processor '{qualname}' is not allowed by this " - "server. See --logits-processor-pattern engine argument " - "for more information." - ) - try: - logits_processor = resolve_obj_by_qualname(qualname) - except Exception as e: - raise ValueError( - f"Logits processor '{qualname}' could not be resolved: {e}" - ) from e - if isinstance(processor, LogitsProcessorConstructor): - logits_processor = logits_processor( - *processor.args or [], **processor.kwargs or {} - ) - logits_processors.append(logits_processor) - return logits_processors - elif processors: - raise ValueError( - "The `logits_processors` argument is not supported by this " - "server. See --logits-processor-pattern engine argument " - "for more information." 
- ) - return None - - -ResponseInputOutputItem: TypeAlias = ResponseInputItemParam | ResponseOutputItem - - -class ResponsesRequest(OpenAIBaseModel): - # Ordered by official OpenAI API documentation - # https://platform.openai.com/docs/api-reference/responses/create - background: bool | None = False - include: ( - list[ - Literal[ - "code_interpreter_call.outputs", - "computer_call_output.output.image_url", - "file_search_call.results", - "message.input_image.image_url", - "message.output_text.logprobs", - "reasoning.encrypted_content", - ], - ] - | None - ) = None - input: str | list[ResponseInputOutputItem] - instructions: str | None = None - max_output_tokens: int | None = None - max_tool_calls: int | None = None - metadata: Metadata | None = None - model: str | None = None - logit_bias: dict[str, float] | None = None - parallel_tool_calls: bool | None = True - previous_response_id: str | None = None - prompt: ResponsePrompt | None = None - reasoning: Reasoning | None = None - service_tier: Literal["auto", "default", "flex", "scale", "priority"] = "auto" - store: bool | None = True - stream: bool | None = False - temperature: float | None = None - text: ResponseTextConfig | None = None - tool_choice: ToolChoice = "auto" - tools: list[Tool] = Field(default_factory=list) - top_logprobs: int | None = 0 - top_p: float | None = None - top_k: int | None = None - truncation: Literal["auto", "disabled"] | None = "disabled" - user: str | None = None - - # --8<-- [start:responses-extra-params] - request_id: str = Field( - default_factory=lambda: f"resp_{random_uuid()}", - description=( - "The request_id related to this request. If the caller does " - "not set it, a random_uuid will be generated. This id is used " - "through out the inference process and return in response." 
- ), - ) - mm_processor_kwargs: dict[str, Any] | None = Field( - default=None, - description=("Additional kwargs to pass to the HF processor."), - ) - priority: int = Field( - default=0, - description=( - "The priority of the request (lower means earlier handling; " - "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling." - ), - ) - cache_salt: str | None = Field( - default=None, - description=( - "If specified, the prefix cache will be salted with the provided " - "string to prevent an attacker to guess prompts in multi-user " - "environments. The salt should be random, protected from " - "access by 3rd parties, and long enough to be " - "unpredictable (e.g., 43 characters base64-encoded, corresponding " - "to 256 bit)." - ), - ) - - enable_response_messages: bool = Field( - default=False, - description=( - "Dictates whether or not to return messages as part of the " - "response object. Currently only supported for" - "non-background and gpt-oss only. 
" - ), - ) - # similar to input_messages / output_messages in ResponsesResponse - # we take in previous_input_messages (ie in harmony format) - # this cannot be used in conjunction with previous_response_id - # TODO: consider supporting non harmony messages as well - previous_input_messages: list[OpenAIHarmonyMessage | dict] | None = None - # --8<-- [end:responses-extra-params] - - _DEFAULT_SAMPLING_PARAMS = { - "temperature": 1.0, - "top_p": 1.0, - "top_k": 0, - } - - def to_sampling_params( - self, - default_max_tokens: int, - default_sampling_params: dict | None = None, - ) -> SamplingParams: - if self.max_output_tokens is None: - max_tokens = default_max_tokens - else: - max_tokens = min(self.max_output_tokens, default_max_tokens) - - default_sampling_params = default_sampling_params or {} - if (temperature := self.temperature) is None: - temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] - ) - if (top_p := self.top_p) is None: - top_p = default_sampling_params.get( - "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"] - ) - if (top_k := self.top_k) is None: - top_k = default_sampling_params.get( - "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"] - ) - stop_token_ids = default_sampling_params.get("stop_token_ids") - - # Structured output - structured_outputs = None - if self.text is not None and self.text.format is not None: - response_format = self.text.format - if ( - response_format.type == "json_schema" - and response_format.schema_ is not None - ): - structured_outputs = StructuredOutputsParams( - json=response_format.schema_ - ) - elif response_format.type == "json_object": - raise NotImplementedError("json_object is not supported") - - # TODO: add more parameters - return SamplingParams.from_optional( - temperature=temperature, - top_p=top_p, - top_k=top_k, - max_tokens=max_tokens, - logprobs=self.top_logprobs if self.is_include_output_logprobs() else None, - stop_token_ids=stop_token_ids, - output_kind=( 
- RequestOutputKind.DELTA if self.stream else RequestOutputKind.FINAL_ONLY - ), - structured_outputs=structured_outputs, - logit_bias=self.logit_bias, - skip_clone=True, # Created fresh per request, safe to skip clone - ) - - def is_include_output_logprobs(self) -> bool: - """Check if the request includes output logprobs.""" - if self.include is None: - return False - return ( - isinstance(self.include, list) - and "message.output_text.logprobs" in self.include - ) - - @model_validator(mode="before") - def validate_background(cls, data): - if not data.get("background"): - return data - if not data.get("store", True): - raise ValueError("background can only be used when `store` is true") - return data - - @model_validator(mode="before") - def validate_prompt(cls, data): - if data.get("prompt") is not None: - raise VLLMValidationError( - "prompt template is not supported", parameter="prompt" - ) - return data - - @model_validator(mode="before") - def check_cache_salt_support(cls, data): - if data.get("cache_salt") is not None and ( - not isinstance(data["cache_salt"], str) or not data["cache_salt"] - ): - raise ValueError( - "Parameter 'cache_salt' must be a non-empty string if provided." - ) - return data - - @model_validator(mode="before") - def function_call_parsing(cls, data): - """Parse function_call dictionaries into ResponseFunctionToolCall objects. - This ensures Pydantic can properly resolve union types in the input field. - Function calls provided as dicts are converted to ResponseFunctionToolCall - objects before validation, while invalid structures are left for Pydantic - to reject with appropriate error messages. 
- """ - - input_data = data.get("input") - - # Early return for None, strings, or bytes - # (strings are iterable but shouldn't be processed) - if input_data is None or isinstance(input_data, (str, bytes)): - return data - - # Convert iterators (like ValidatorIterator) to list - if not isinstance(input_data, list): - try: - input_data = list(input_data) - except TypeError: - # Not iterable, leave as-is for Pydantic to handle - return data - - processed_input = [] - for item in input_data: - if isinstance(item, dict) and item.get("type") == "function_call": - try: - processed_input.append(ResponseFunctionToolCall(**item)) - except ValidationError: - # Let Pydantic handle validation for malformed function calls - logger.debug( - "Failed to parse function_call to ResponseFunctionToolCall, " - "leaving for Pydantic validation" - ) - processed_input.append(item) - else: - processed_input.append(item) - - data["input"] = processed_input - return data - - -class ChatCompletionRequest(OpenAIBaseModel): - # Ordered by official OpenAI API documentation - # https://platform.openai.com/docs/api-reference/chat/create - messages: list[ChatCompletionMessageParam] - model: str | None = None - frequency_penalty: float | None = 0.0 - logit_bias: dict[str, float] | None = None - logprobs: bool | None = False - top_logprobs: int | None = 0 - max_tokens: int | None = Field( - default=None, - deprecated="max_tokens is deprecated in favor of " - "the max_completion_tokens field", - ) - max_completion_tokens: int | None = None - n: int | None = 1 - presence_penalty: float | None = 0.0 - response_format: AnyResponseFormat | None = None - seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) - stop: str | list[str] | None = [] - stream: bool | None = False - stream_options: StreamOptions | None = None - temperature: float | None = None - top_p: float | None = None - tools: list[ChatCompletionToolsParam] | None = None - tool_choice: ( - Literal["none"] - | Literal["auto"] - | 
Literal["required"] - | ChatCompletionNamedToolChoiceParam - | None - ) = "none" - reasoning_effort: Literal["low", "medium", "high"] | None = None - thinking_token_budget: int | None = None - include_reasoning: bool = True - parallel_tool_calls: bool | None = True - - # NOTE this will be ignored by vLLM - user: str | None = None - - # --8<-- [start:chat-completion-sampling-params] - use_beam_search: bool = False - top_k: int | None = None - min_p: float | None = None - repetition_penalty: float | None = None - length_penalty: float = 1.0 - stop_token_ids: list[int] | None = [] - include_stop_str_in_output: bool = False - ignore_eos: bool = False - min_tokens: int = 0 - skip_special_tokens: bool = True - spaces_between_special_tokens: bool = True - truncate_prompt_tokens: Annotated[int, Field(ge=-1, le=_LONG_INFO.max)] | None = ( - None - ) - prompt_logprobs: int | None = None - allowed_token_ids: list[int] | None = None - bad_words: list[str] = Field(default_factory=list) - # --8<-- [end:chat-completion-sampling-params] - - # --8<-- [start:chat-completion-extra-params] - echo: bool = Field( - default=False, - description=( - "If true, the new message will be prepended with the last message " - "if they belong to the same role." - ), - ) - add_generation_prompt: bool = Field( - default=True, - description=( - "If true, the generation prompt will be added to the chat template. " - "This is a parameter used by chat template in tokenizer config of the " - "model." - ), - ) - continue_final_message: bool = Field( - default=False, - description=( - "If this is set, the chat will be formatted so that the final " - "message in the chat is open-ended, without any EOS tokens. The " - "model will continue this message rather than starting a new one. " - 'This allows you to "prefill" part of the model\'s response for it. ' - "Cannot be used at the same time as `add_generation_prompt`." 
- ), - ) - add_special_tokens: bool = Field( - default=False, - description=( - "If true, special tokens (e.g. BOS) will be added to the prompt " - "on top of what is added by the chat template. " - "For most models, the chat template takes care of adding the " - "special tokens so this should be set to false (as is the " - "default)." - ), - ) - documents: list[dict[str, str]] | None = Field( - default=None, - description=( - "A list of dicts representing documents that will be accessible to " - "the model if it is performing RAG (retrieval-augmented generation)." - " If the template does not support RAG, this argument will have no " - "effect. We recommend that each document should be a dict containing " - '"title" and "text" keys.' - ), - ) - chat_template: str | None = Field( - default=None, - description=( - "A Jinja template to use for this conversion. " - "As of transformers v4.44, default chat template is no longer " - "allowed, so you must provide a chat template if the tokenizer " - "does not define one." - ), - ) - chat_template_kwargs: dict[str, Any] | None = Field( - default=None, - description=( - "Additional keyword args to pass to the template renderer. " - "Will be accessible by the chat template." - ), - ) - mm_processor_kwargs: dict[str, Any] | None = Field( - default=None, - description=("Additional kwargs to pass to the HF processor."), - ) - structured_outputs: StructuredOutputsParams | None = Field( - default=None, - description="Additional kwargs for structured outputs", - ) - priority: int = Field( - default=0, - description=( - "The priority of the request (lower means earlier handling; " - "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling." - ), - ) - request_id: str = Field( - default_factory=random_uuid, - description=( - "The request_id related to this request. If the caller does " - "not set it, a random_uuid will be generated. 
This id is used " - "through out the inference process and return in response." - ), - ) - logits_processors: LogitsProcessors | None = Field( - default=None, - description=( - "A list of either qualified names of logits processors, or " - "constructor objects, to apply when sampling. A constructor is " - "a JSON object with a required 'qualname' field specifying the " - "qualified name of the processor class/factory, and optional " - "'args' and 'kwargs' fields containing positional and keyword " - "arguments. For example: {'qualname': " - "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': " - "{'param': 'value'}}." - ), - ) - return_tokens_as_token_ids: bool | None = Field( - default=None, - description=( - "If specified with 'logprobs', tokens are represented " - " as strings of the form 'token_id:{token_id}' so that tokens " - "that are not JSON-encodable can be identified." - ), - ) - return_token_ids: bool | None = Field( - default=None, - description=( - "If specified, the result will include token IDs alongside the " - "generated text. In streaming mode, prompt_token_ids is included " - "only in the first chunk, and token_ids contains the delta tokens " - "for each chunk. This is useful for debugging or when you " - "need to map generated text back to input tokens." - ), - ) - cache_salt: str | None = Field( - default=None, - description=( - "If specified, the prefix cache will be salted with the provided " - "string to prevent an attacker to guess prompts in multi-user " - "environments. The salt should be random, protected from " - "access by 3rd parties, and long enough to be " - "unpredictable (e.g., 43 characters base64-encoded, corresponding " - "to 256 bit)." 
- ), - ) - kv_transfer_params: dict[str, Any] | None = Field( - default=None, - description="KVTransfer parameters used for disaggregated serving.", - ) - - vllm_xargs: dict[str, str | int | float | list[str | int | float]] | None = Field( - default=None, - description=( - "Additional request parameters with (list of) string or " - "numeric values, used by custom extensions." - ), - ) - - # --8<-- [end:chat-completion-extra-params] - - # Default sampling parameters for chat completion requests - _DEFAULT_SAMPLING_PARAMS: dict = { - "repetition_penalty": 1.0, - "temperature": 1.0, - "top_p": 1.0, - "top_k": 0, - "min_p": 0.0, - } - - def to_beam_search_params( - self, max_tokens: int, default_sampling_params: dict - ) -> BeamSearchParams: - n = self.n if self.n is not None else 1 - if (temperature := self.temperature) is None: - temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] - ) - - return BeamSearchParams( - beam_width=n, - max_tokens=max_tokens, - ignore_eos=self.ignore_eos, - temperature=temperature, - length_penalty=self.length_penalty, - include_stop_str_in_output=self.include_stop_str_in_output, - ) - - def to_sampling_params( - self, - max_tokens: int, - logits_processor_pattern: str | None, - default_sampling_params: dict, - ) -> SamplingParams: - # Default parameters - if (repetition_penalty := self.repetition_penalty) is None: - repetition_penalty = default_sampling_params.get( - "repetition_penalty", - self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"], - ) - if (temperature := self.temperature) is None: - temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] - ) - if (top_p := self.top_p) is None: - top_p = default_sampling_params.get( - "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"] - ) - if (top_k := self.top_k) is None: - top_k = default_sampling_params.get( - "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"] - ) - if (min_p := self.min_p) is 
None: - min_p = default_sampling_params.get( - "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"] - ) - - prompt_logprobs = self.prompt_logprobs - if prompt_logprobs is None and self.echo: - prompt_logprobs = self.top_logprobs - - response_format = self.response_format - if response_format is not None: - # If structured outputs wasn't already enabled, - # we must enable it for these features to work - if self.structured_outputs is None: - self.structured_outputs = StructuredOutputsParams() - - # Set structured output params for response format - if response_format.type == "json_object": - self.structured_outputs.json_object = True - elif response_format.type == "json_schema": - json_schema = response_format.json_schema - assert json_schema is not None - self.structured_outputs.json = json_schema.json_schema - elif response_format.type == "structural_tag": - structural_tag = response_format - assert structural_tag is not None and isinstance( - structural_tag, - ( - LegacyStructuralTagResponseFormat, - StructuralTagResponseFormat, - ), - ) - s_tag_obj = structural_tag.model_dump(by_alias=True) - self.structured_outputs.structural_tag = json.dumps(s_tag_obj) - - extra_args: dict[str, Any] = self.vllm_xargs if self.vllm_xargs else {} - if self.kv_transfer_params: - # Pass in kv_transfer_params via extra_args - extra_args["kv_transfer_params"] = self.kv_transfer_params - return SamplingParams.from_optional( - n=self.n, - presence_penalty=self.presence_penalty, - frequency_penalty=self.frequency_penalty, - repetition_penalty=repetition_penalty, - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - seed=self.seed, - stop=self.stop, - stop_token_ids=self.stop_token_ids, - logprobs=self.top_logprobs if self.logprobs else None, - prompt_logprobs=prompt_logprobs, - ignore_eos=self.ignore_eos, - max_tokens=max_tokens, - min_tokens=self.min_tokens, - skip_special_tokens=self.skip_special_tokens, - 
spaces_between_special_tokens=self.spaces_between_special_tokens, - logits_processors=get_logits_processors( - self.logits_processors, logits_processor_pattern - ), - include_stop_str_in_output=self.include_stop_str_in_output, - truncate_prompt_tokens=self.truncate_prompt_tokens, - output_kind=RequestOutputKind.DELTA - if self.stream - else RequestOutputKind.FINAL_ONLY, - structured_outputs=self.structured_outputs, - logit_bias=self.logit_bias, - bad_words=self.bad_words, - thinking_token_budget=self.thinking_token_budget, - allowed_token_ids=self.allowed_token_ids, - extra_args=extra_args or None, - skip_clone=True, # Created fresh per request, safe to skip clone - ) - - @model_validator(mode="before") - @classmethod - def validate_stream_options(cls, data): - if data.get("stream_options") and not data.get("stream"): - raise VLLMValidationError( - "Stream options can only be defined when `stream=True`.", - parameter="stream_options", - ) - - return data - - @model_validator(mode="before") - @classmethod - def check_logprobs(cls, data): - if (prompt_logprobs := data.get("prompt_logprobs")) is not None: - if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1): - raise VLLMValidationError( - "`prompt_logprobs` are not available when `stream=True`.", - parameter="prompt_logprobs", - ) - - if prompt_logprobs < 0 and prompt_logprobs != -1: - raise VLLMValidationError( - "`prompt_logprobs` must be a positive value or -1.", - parameter="prompt_logprobs", - value=prompt_logprobs, - ) - if (top_logprobs := data.get("top_logprobs")) is not None: - if top_logprobs < 0 and top_logprobs != -1: - raise VLLMValidationError( - "`top_logprobs` must be a positive value or -1.", - parameter="top_logprobs", - value=top_logprobs, - ) - - if (top_logprobs == -1 or top_logprobs > 0) and not data.get("logprobs"): - raise VLLMValidationError( - "when using `top_logprobs`, `logprobs` must be set to true.", - parameter="top_logprobs", - ) - - return data - - 
@model_validator(mode="before") - @classmethod - def check_structured_outputs_count(cls, data): - if isinstance(data, ValueError): - raise data - - if data.get("structured_outputs", None) is None: - return data - - structured_outputs_kwargs = data["structured_outputs"] - count = sum( - structured_outputs_kwargs.get(k) is not None - for k in ("json", "regex", "choice") - ) - # you can only use one kind of constraints for structured outputs - if count > 1: - raise ValueError( - "You can only use one kind of constraints for structured " - "outputs ('json', 'regex' or 'choice')." - ) - # you can only either use structured outputs or tools, not both - if count > 1 and data.get("tool_choice", "none") not in ( - "none", - "auto", - "required", - ): - raise ValueError( - "You can only either use constraints for structured outputs " - "or tools, not both." - ) - return data - - @model_validator(mode="before") - @classmethod - def check_tool_usage(cls, data): - # if "tool_choice" is not specified but tools are provided, - # default to "auto" tool_choice - if "tool_choice" not in data and data.get("tools"): - data["tool_choice"] = "auto" - - # if "tool_choice" is "none" -- no validation is needed for tools - if "tool_choice" in data and data["tool_choice"] == "none": - return data - - # if "tool_choice" is specified -- validation - if "tool_choice" in data and data["tool_choice"] is not None: - # ensure that if "tool choice" is specified, tools are present - if "tools" not in data or data["tools"] is None: - raise ValueError("When using `tool_choice`, `tools` must be set.") - - # make sure that tool choice is either a named tool - # OR that it's set to "auto" or "required" - if data["tool_choice"] not in ["auto", "required"] and not isinstance( - data["tool_choice"], dict - ): - raise ValueError( - f"Invalid value for `tool_choice`: {data['tool_choice']}! " - 'Only named tools, "none", "auto" or "required" ' - "are supported." 
- ) - - # if tool_choice is "required" but the "tools" list is empty, - # override the data to behave like "none" to align with - # OpenAI’s behavior. - if ( - data["tool_choice"] == "required" - and isinstance(data["tools"], list) - and len(data["tools"]) == 0 - ): - data["tool_choice"] = "none" - del data["tools"] - return data - - # ensure that if "tool_choice" is specified as an object, - # it matches a valid tool - correct_usage_message = ( - 'Correct usage: `{"type": "function",' - ' "function": {"name": "my_function"}}`' - ) - if isinstance(data["tool_choice"], dict): - valid_tool = False - function = data["tool_choice"].get("function") - if not isinstance(function, dict): - raise ValueError( - f"Invalid value for `function`: `{function}` in " - f"`tool_choice`! {correct_usage_message}" - ) - if "name" not in function: - raise ValueError( - f"Expected field `name` in `function` in " - f"`tool_choice`! {correct_usage_message}" - ) - function_name = function["name"] - if not isinstance(function_name, str) or len(function_name) == 0: - raise ValueError( - f"Invalid `name` in `function`: `{function_name}`" - f" in `tool_choice`! {correct_usage_message}" - ) - for tool in data["tools"]: - if tool["function"]["name"] == function_name: - valid_tool = True - break - if not valid_tool: - raise ValueError( - "The tool specified in `tool_choice` does not match any" - " of the specified `tools`" - ) - return data - - @model_validator(mode="before") - @classmethod - def check_generation_prompt(cls, data): - if data.get("continue_final_message") and data.get("add_generation_prompt"): - raise ValueError( - "Cannot set both `continue_final_message` and " - "`add_generation_prompt` to True." 
- ) - return data - - @model_validator(mode="before") - @classmethod - def check_cache_salt_support(cls, data): - if data.get("cache_salt") is not None and ( - not isinstance(data["cache_salt"], str) or not data["cache_salt"] - ): - raise ValueError( - "Parameter 'cache_salt' must be a non-empty string if provided." - ) - return data - - -class CompletionRequest(OpenAIBaseModel): - # Ordered by official OpenAI API documentation - # https://platform.openai.com/docs/api-reference/completions/create - model: str | None = None - prompt: list[int] | list[list[int]] | str | list[str] | None = None - echo: bool | None = False - frequency_penalty: float | None = 0.0 - logit_bias: dict[str, float] | None = None - logprobs: int | None = None - max_tokens: int | None = 16 - n: int = 1 - presence_penalty: float | None = 0.0 - seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) - stop: str | list[str] | None = [] - stream: bool | None = False - stream_options: StreamOptions | None = None - suffix: str | None = None - temperature: float | None = None - top_p: float | None = None - user: str | None = None - - # --8<-- [start:completion-sampling-params] - use_beam_search: bool = False - top_k: int | None = None - min_p: float | None = None - repetition_penalty: float | None = None - length_penalty: float = 1.0 - stop_token_ids: list[int] | None = [] - include_stop_str_in_output: bool = False - ignore_eos: bool = False - min_tokens: int = 0 - skip_special_tokens: bool = True - spaces_between_special_tokens: bool = True - truncate_prompt_tokens: Annotated[int, Field(ge=-1, le=_LONG_INFO.max)] | None = ( - None - ) - allowed_token_ids: list[int] | None = None - prompt_logprobs: int | None = None - # --8<-- [end:completion-sampling-params] - - # --8<-- [start:completion-extra-params] - prompt_embeds: bytes | list[bytes] | None = None - add_special_tokens: bool = Field( - default=True, - description=( - "If true (the default), special tokens (e.g. 
BOS) will be added to " - "the prompt." - ), - ) - response_format: AnyResponseFormat | None = Field( - default=None, - description=( - "Similar to chat completion, this parameter specifies the format " - "of output. Only {'type': 'json_object'}, {'type': 'json_schema'}" - ", {'type': 'structural_tag'}, or {'type': 'text' } is supported." - ), - ) - structured_outputs: StructuredOutputsParams | None = Field( - default=None, - description="Additional kwargs for structured outputs", - ) - priority: int = Field( - default=0, - description=( - "The priority of the request (lower means earlier handling; " - "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling." - ), - ) - request_id: str = Field( - default_factory=random_uuid, - description=( - "The request_id related to this request. If the caller does " - "not set it, a random_uuid will be generated. This id is used " - "through out the inference process and return in response." - ), - ) - logits_processors: LogitsProcessors | None = Field( - default=None, - description=( - "A list of either qualified names of logits processors, or " - "constructor objects, to apply when sampling. A constructor is " - "a JSON object with a required 'qualname' field specifying the " - "qualified name of the processor class/factory, and optional " - "'args' and 'kwargs' fields containing positional and keyword " - "arguments. For example: {'qualname': " - "'my_module.MyLogitsProcessor', 'args': [1, 2], 'kwargs': " - "{'param': 'value'}}." - ), - ) - - return_tokens_as_token_ids: bool | None = Field( - default=None, - description=( - "If specified with 'logprobs', tokens are represented " - " as strings of the form 'token_id:{token_id}' so that tokens " - "that are not JSON-encodable can be identified." - ), - ) - return_token_ids: bool | None = Field( - default=None, - description=( - "If specified, the result will include token IDs alongside the " - "generated text. 
In streaming mode, prompt_token_ids is included " - "only in the first chunk, and token_ids contains the delta tokens " - "for each chunk. This is useful for debugging or when you " - "need to map generated text back to input tokens." - ), - ) - - cache_salt: str | None = Field( - default=None, - description=( - "If specified, the prefix cache will be salted with the provided " - "string to prevent an attacker to guess prompts in multi-user " - "environments. The salt should be random, protected from " - "access by 3rd parties, and long enough to be " - "unpredictable (e.g., 43 characters base64-encoded, corresponding " - "to 256 bit)." - ), - ) - - kv_transfer_params: dict[str, Any] | None = Field( - default=None, - description="KVTransfer parameters used for disaggregated serving.", - ) - - vllm_xargs: dict[str, str | int | float] | None = Field( - default=None, - description=( - "Additional request parameters with string or " - "numeric values, used by custom extensions." - ), - ) - - # --8<-- [end:completion-extra-params] - - # Default sampling parameters for completion requests - _DEFAULT_SAMPLING_PARAMS: dict = { - "repetition_penalty": 1.0, - "temperature": 1.0, - "top_p": 1.0, - "top_k": 0, - "min_p": 0.0, - } - - def to_beam_search_params( - self, - max_tokens: int, - default_sampling_params: dict | None = None, - ) -> BeamSearchParams: - if default_sampling_params is None: - default_sampling_params = {} - n = self.n if self.n is not None else 1 - - if (temperature := self.temperature) is None: - temperature = default_sampling_params.get("temperature", 1.0) - - return BeamSearchParams( - beam_width=n, - max_tokens=max_tokens, - ignore_eos=self.ignore_eos, - temperature=temperature, - length_penalty=self.length_penalty, - include_stop_str_in_output=self.include_stop_str_in_output, - ) - - def to_sampling_params( - self, - max_tokens: int, - logits_processor_pattern: str | None, - default_sampling_params: dict | None = None, - ) -> SamplingParams: - if 
default_sampling_params is None: - default_sampling_params = {} - - # Default parameters - if (repetition_penalty := self.repetition_penalty) is None: - repetition_penalty = default_sampling_params.get( - "repetition_penalty", - self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"], - ) - if (temperature := self.temperature) is None: - temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] - ) - if (top_p := self.top_p) is None: - top_p = default_sampling_params.get( - "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"] - ) - if (top_k := self.top_k) is None: - top_k = default_sampling_params.get( - "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"] - ) - if (min_p := self.min_p) is None: - min_p = default_sampling_params.get( - "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"] - ) - - prompt_logprobs = self.prompt_logprobs - if prompt_logprobs is None and self.echo: - prompt_logprobs = self.logprobs - - echo_without_generation = self.echo and self.max_tokens == 0 - - response_format = self.response_format - if response_format is not None: - # If structured outputs wasn't already enabled, - # we must enable it for these features to work - if self.structured_outputs is None: - self.structured_outputs = StructuredOutputsParams() - - # Set structured output params for response format - if response_format.type == "json_object": - self.structured_outputs.json_object = True - elif response_format.type == "json_schema": - json_schema = response_format.json_schema - assert json_schema is not None - self.structured_outputs.json = json_schema.json_schema - elif response_format.type == "structural_tag": - structural_tag = response_format - assert structural_tag is not None and isinstance( - structural_tag, - ( - LegacyStructuralTagResponseFormat, - StructuralTagResponseFormat, - ), - ) - s_tag_obj = structural_tag.model_dump(by_alias=True) - self.structured_outputs.structural_tag = json.dumps(s_tag_obj) - - extra_args: dict[str, Any] = 
self.vllm_xargs if self.vllm_xargs else {} - if self.kv_transfer_params: - # Pass in kv_transfer_params via extra_args - extra_args["kv_transfer_params"] = self.kv_transfer_params - return SamplingParams.from_optional( - n=self.n, - presence_penalty=self.presence_penalty, - frequency_penalty=self.frequency_penalty, - repetition_penalty=repetition_penalty, - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - seed=self.seed, - stop=self.stop, - stop_token_ids=self.stop_token_ids, - logprobs=self.logprobs, - ignore_eos=self.ignore_eos, - max_tokens=max_tokens if not echo_without_generation else 1, - min_tokens=self.min_tokens, - prompt_logprobs=prompt_logprobs, - skip_special_tokens=self.skip_special_tokens, - spaces_between_special_tokens=self.spaces_between_special_tokens, - include_stop_str_in_output=self.include_stop_str_in_output, - logits_processors=get_logits_processors( - self.logits_processors, logits_processor_pattern - ), - truncate_prompt_tokens=self.truncate_prompt_tokens, - output_kind=RequestOutputKind.DELTA - if self.stream - else RequestOutputKind.FINAL_ONLY, - structured_outputs=self.structured_outputs, - logit_bias=self.logit_bias, - allowed_token_ids=self.allowed_token_ids, - extra_args=extra_args or None, - skip_clone=True, # Created fresh per request, safe to skip clone - ) - - @model_validator(mode="before") - @classmethod - def check_structured_outputs_count(cls, data): - if data.get("structured_outputs", None) is None: - return data - - structured_outputs_kwargs = data["structured_outputs"] - count = sum( - structured_outputs_kwargs.get(k) is not None - for k in ("json", "regex", "choice") - ) - if count > 1: - raise VLLMValidationError( - "You can only use one kind of constraints for structured " - "outputs ('json', 'regex' or 'choice').", - parameter="structured_outputs", - ) - return data - - @model_validator(mode="before") - @classmethod - def check_logprobs(cls, data): - if (prompt_logprobs := 
data.get("prompt_logprobs")) is not None: - if data.get("stream") and (prompt_logprobs > 0 or prompt_logprobs == -1): - raise VLLMValidationError( - "`prompt_logprobs` are not available when `stream=True`.", - parameter="prompt_logprobs", - ) - - if prompt_logprobs < 0 and prompt_logprobs != -1: - raise VLLMValidationError( - "`prompt_logprobs` must be a positive value or -1.", - parameter="prompt_logprobs", - value=prompt_logprobs, - ) - if (logprobs := data.get("logprobs")) is not None and logprobs < 0: - raise VLLMValidationError( - "`logprobs` must be a positive value.", - parameter="logprobs", - value=logprobs, - ) - - return data - - @model_validator(mode="before") - @classmethod - def validate_stream_options(cls, data): - if data.get("stream_options") and not data.get("stream"): - raise VLLMValidationError( - "Stream options can only be defined when `stream=True`.", - parameter="stream_options", - ) - - return data - - @model_validator(mode="before") - @classmethod - def validate_prompt_and_prompt_embeds(cls, data): - prompt = data.get("prompt") - prompt_embeds = data.get("prompt_embeds") - - prompt_is_empty = prompt is None or (isinstance(prompt, str) and prompt == "") - embeds_is_empty = prompt_embeds is None or ( - isinstance(prompt_embeds, list) and len(prompt_embeds) == 0 - ) - - if prompt_is_empty and embeds_is_empty: - raise ValueError( - "Either prompt or prompt_embeds must be provided and non-empty." - ) - - return data - - @model_validator(mode="before") - @classmethod - def check_cache_salt_support(cls, data): - if data.get("cache_salt") is not None and ( - not isinstance(data["cache_salt"], str) or not data["cache_salt"] - ): - raise ValueError( - "Parameter 'cache_salt' must be a non-empty string if provided." 
- ) - return data - - -class CompletionLogProbs(OpenAIBaseModel): - text_offset: list[int] = Field(default_factory=list) - token_logprobs: list[float | None] = Field(default_factory=list) - tokens: list[str] = Field(default_factory=list) - top_logprobs: list[dict[str, float] | None] = Field(default_factory=list) - - -class CompletionResponseChoice(OpenAIBaseModel): - index: int - text: str - logprobs: CompletionLogProbs | None = None - finish_reason: str | None = None - stop_reason: int | str | None = Field( - default=None, - description=( - "The stop string or token id that caused the completion " - "to stop, None if the completion finished for some other reason " - "including encountering the EOS token" - ), - ) - token_ids: list[int] | None = None # For response - prompt_logprobs: list[dict[int, Logprob] | None] | None = None - prompt_token_ids: list[int] | None = None # For prompt - - -class CompletionResponse(OpenAIBaseModel): - id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") - object: Literal["text_completion"] = "text_completion" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: list[CompletionResponseChoice] - service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None - system_fingerprint: str | None = None - usage: UsageInfo - - # vLLM-specific fields that are not in OpenAI spec - kv_transfer_params: dict[str, Any] | None = Field( - default=None, description="KVTransfer parameters." 
- ) - - -class CompletionResponseStreamChoice(OpenAIBaseModel): - index: int - text: str - logprobs: CompletionLogProbs | None = None - finish_reason: str | None = None - stop_reason: int | str | None = Field( - default=None, - description=( - "The stop string or token id that caused the completion " - "to stop, None if the completion finished for some other reason " - "including encountering the EOS token" - ), - ) - # not part of the OpenAI spec but for tracing the tokens - # prompt tokens is put into choice to align with CompletionResponseChoice - prompt_token_ids: list[int] | None = None - token_ids: list[int] | None = None - - -class CompletionStreamResponse(OpenAIBaseModel): - id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") - object: str = "text_completion" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: list[CompletionResponseStreamChoice] - usage: UsageInfo | None = Field(default=None) - - -class FunctionCall(OpenAIBaseModel): - name: str - arguments: str - - -class ToolCall(OpenAIBaseModel): - id: str = Field(default_factory=make_tool_call_id) - type: Literal["function"] = "function" - function: FunctionCall - - -class DeltaFunctionCall(BaseModel): - name: str | None = None - arguments: str | None = None - - -# a tool call delta where everything is optional -class DeltaToolCall(OpenAIBaseModel): - id: str | None = None - type: Literal["function"] | None = None - index: int - function: DeltaFunctionCall | None = None - - -class ExtractedToolCallInformation(BaseModel): - # indicate if tools were called - tools_called: bool - - # extracted tool calls - tool_calls: list[ToolCall] - - # content - per OpenAI spec, content AND tool calls can be returned rarely - # But some models will do this intentionally - content: str | None = None - - -class ChatMessage(OpenAIBaseModel): - role: str - content: str | None = None - refusal: str | None = None - annotations: OpenAIAnnotation | None = None - audio: 
OpenAIChatCompletionAudio | None = None - function_call: FunctionCall | None = None - tool_calls: list[ToolCall] = Field(default_factory=list) - - # vLLM-specific fields that are not in OpenAI spec - reasoning: str | None = None - reasoning_content: str | None = None - """Deprecated: use `reasoning` instead.""" - - @model_validator(mode="after") - def handle_deprecated_reasoning_content(self): - """Copy reasoning to reasoning_content for backward compatibility.""" - self.reasoning_content = self.reasoning - return self - - -class ChatCompletionLogProb(OpenAIBaseModel): - token: str - logprob: float = -9999.0 - bytes: list[int] | None = None - - -class ChatCompletionLogProbsContent(ChatCompletionLogProb): - # Workaround: redefine fields name cache so that it's not - # shared with the super class. - field_names: ClassVar[set[str] | None] = None - top_logprobs: list[ChatCompletionLogProb] = Field(default_factory=list) - - -class ChatCompletionLogProbs(OpenAIBaseModel): - content: list[ChatCompletionLogProbsContent] | None = None - - -class ChatCompletionResponseChoice(OpenAIBaseModel): - index: int - message: ChatMessage - logprobs: ChatCompletionLogProbs | None = None - # per OpenAI spec this is the default - finish_reason: str | None = "stop" - # not part of the OpenAI spec but included in vLLM for legacy reasons - stop_reason: int | str | None = None - # not part of the OpenAI spec but is useful for tracing the tokens - # in agent scenarios - token_ids: list[int] | None = None - - -class ChatCompletionResponse(OpenAIBaseModel): - id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") - object: Literal["chat.completion"] = "chat.completion" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: list[ChatCompletionResponseChoice] - service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None - system_fingerprint: str | None = None - usage: UsageInfo - - # vLLM-specific fields that are not in 
OpenAI spec - prompt_logprobs: list[dict[int, Logprob] | None] | None = None - prompt_token_ids: list[int] | None = None - kv_transfer_params: dict[str, Any] | None = Field( - default=None, description="KVTransfer parameters." - ) - - -class DeltaMessage(OpenAIBaseModel): - role: str | None = None - content: str | None = None - reasoning: str | None = None - reasoning_content: str | None = None - """Deprecated: use `reasoning` instead.""" - tool_calls: list[DeltaToolCall] = Field(default_factory=list) - - @model_validator(mode="after") - def handle_deprecated_reasoning_content(self): - """Copy reasoning to reasoning_content for backward compatibility.""" - self.reasoning_content = self.reasoning - return self - - -class ChatCompletionResponseStreamChoice(OpenAIBaseModel): - index: int - delta: DeltaMessage - logprobs: ChatCompletionLogProbs | None = None - finish_reason: str | None = None - stop_reason: int | str | None = None - # not part of the OpenAI spec but for tracing the tokens - token_ids: list[int] | None = None - - -class ChatCompletionStreamResponse(OpenAIBaseModel): - id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") - object: Literal["chat.completion.chunk"] = "chat.completion.chunk" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: list[ChatCompletionResponseStreamChoice] - usage: UsageInfo | None = Field(default=None) - # not part of the OpenAI spec but for tracing the tokens - prompt_token_ids: list[int] | None = None - - -class TranscriptionResponseStreamChoice(OpenAIBaseModel): - delta: DeltaMessage - finish_reason: str | None = None - stop_reason: int | str | None = None - - -class TranscriptionStreamResponse(OpenAIBaseModel): - id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}") - object: Literal["transcription.chunk"] = "transcription.chunk" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: list[TranscriptionResponseStreamChoice] - 
usage: UsageInfo | None = Field(default=None) - - -class InputTokensDetails(OpenAIBaseModel): - cached_tokens: int - input_tokens_per_turn: list[int] = Field(default_factory=list) - cached_tokens_per_turn: list[int] = Field(default_factory=list) - - -class OutputTokensDetails(OpenAIBaseModel): - reasoning_tokens: int = 0 - tool_output_tokens: int = 0 - output_tokens_per_turn: list[int] = Field(default_factory=list) - tool_output_tokens_per_turn: list[int] = Field(default_factory=list) - - -class ResponseUsage(OpenAIBaseModel): - input_tokens: int - input_tokens_details: InputTokensDetails - output_tokens: int - output_tokens_details: OutputTokensDetails - total_tokens: int - - -def serialize_message(msg): - """ - Serializes a single message - """ - if isinstance(msg, dict): - return msg - elif hasattr(msg, "to_dict"): - return msg.to_dict() - else: - # fallback to pyandic dump - return msg.model_dump_json() - - -def serialize_messages(msgs): - """ - Serializes multiple messages - """ - return [serialize_message(msg) for msg in msgs] if msgs else None - - -class ResponseRawMessageAndToken(OpenAIBaseModel): - """Class to show the raw message. 
- If message / tokens diverge, tokens is the source of truth""" - - message: str - tokens: list[int] - type: Literal["raw_message_tokens"] = "raw_message_tokens" - - -ResponseInputOutputMessage: TypeAlias = ( - list[ChatCompletionMessageParam] | list[ResponseRawMessageAndToken] -) - - -class ResponsesResponse(OpenAIBaseModel): - id: str = Field(default_factory=lambda: f"resp_{random_uuid()}") - created_at: int = Field(default_factory=lambda: int(time.time())) - # error: Optional[ResponseError] = None - incomplete_details: IncompleteDetails | None = None - instructions: str | None = None - metadata: Metadata | None = None - model: str - object: Literal["response"] = "response" - output: list[ResponseOutputItem] - parallel_tool_calls: bool - temperature: float - tool_choice: ToolChoice - tools: list[Tool] - top_p: float - background: bool - max_output_tokens: int - max_tool_calls: int | None = None - previous_response_id: str | None = None - prompt: ResponsePrompt | None = None - reasoning: Reasoning | None = None - service_tier: Literal["auto", "default", "flex", "scale", "priority"] - status: ResponseStatus - text: ResponseTextConfig | None = None - top_logprobs: int | None = None - truncation: Literal["auto", "disabled"] - usage: ResponseUsage | None = None - user: str | None = None - - # --8<-- [start:responses-response-extra-params] - # These are populated when enable_response_messages is set to True - # NOTE: custom serialization is needed - # see serialize_input_messages and serialize_output_messages - input_messages: ResponseInputOutputMessage | None = Field( - default=None, - description=( - "If enable_response_messages, we can show raw token input to model." - ), - ) - output_messages: ResponseInputOutputMessage | None = Field( - default=None, - description=( - "If enable_response_messages, we can show raw token output of model." 
- ), - ) - # --8<-- [end:responses-response-extra-params] - - # NOTE: openAI harmony doesn't serialize TextContent properly, - # TODO: this fixes for TextContent, but need to verify for tools etc - # https://github.com/openai/harmony/issues/78 - @field_serializer("output_messages", when_used="json") - def serialize_output_messages(self, msgs, _info): - return serialize_messages(msgs) - - # NOTE: openAI harmony doesn't serialize TextContent properly, this fixes it - # https://github.com/openai/harmony/issues/78 - @field_serializer("input_messages", when_used="json") - def serialize_input_messages(self, msgs, _info): - return serialize_messages(msgs) - - @classmethod - def from_request( - cls, - request: ResponsesRequest, - sampling_params: SamplingParams, - model_name: str, - created_time: int, - output: list[ResponseOutputItem], - status: ResponseStatus, - usage: ResponseUsage | None = None, - input_messages: ResponseInputOutputMessage | None = None, - output_messages: ResponseInputOutputMessage | None = None, - ) -> "ResponsesResponse": - incomplete_details: IncompleteDetails | None = None - if status == "incomplete": - incomplete_details = IncompleteDetails(reason="max_output_tokens") - # TODO: implement the other reason for incomplete_details, - # which is content_filter - # incomplete_details = IncompleteDetails(reason='content_filter') - return cls( - id=request.request_id, - created_at=created_time, - incomplete_details=incomplete_details, - instructions=request.instructions, - metadata=request.metadata, - model=model_name, - output=output, - input_messages=input_messages, - output_messages=output_messages, - parallel_tool_calls=request.parallel_tool_calls, - temperature=sampling_params.temperature, - tool_choice=request.tool_choice, - tools=request.tools, - top_p=sampling_params.top_p, - background=request.background, - max_output_tokens=sampling_params.max_tokens, - max_tool_calls=request.max_tool_calls, - previous_response_id=request.previous_response_id, 
- prompt=request.prompt, - reasoning=request.reasoning, - service_tier=request.service_tier, - status=status, - text=request.text, - top_logprobs=sampling_params.logprobs, - truncation=request.truncation, - user=request.user, - usage=usage, - ) - - -# TODO: this code can be removed once -# https://github.com/openai/openai-python/issues/2634 has been resolved -class ResponseReasoningPartDoneEvent(OpenAIBaseModel): - content_index: int - """The index of the content part that is done.""" - - item_id: str - """The ID of the output item that the content part was added to.""" - - output_index: int - """The index of the output item that the content part was added to.""" - - part: ResponseReasoningTextContent - """The content part that is done.""" - - sequence_number: int - """The sequence number of this event.""" - - type: Literal["response.reasoning_part.done"] - """The type of the event. Always `response.reasoning_part.done`.""" - - -# TODO: this code can be removed once -# https://github.com/openai/openai-python/issues/2634 has been resolved -class ResponseReasoningPartAddedEvent(OpenAIBaseModel): - content_index: int - """The index of the content part that is done.""" - - item_id: str - """The ID of the output item that the content part was added to.""" - - output_index: int - """The index of the output item that the content part was added to.""" - - part: ResponseReasoningTextContent - """The content part that is done.""" - - sequence_number: int - """The sequence number of this event.""" - - type: Literal["response.reasoning_part.added"] - """The type of the event. 
Always `response.reasoning_part.added`.""" - - -# vLLM Streaming Events -# Note: we override the response type with the vLLM ResponsesResponse type -class ResponseCompletedEvent(OpenAIResponseCompletedEvent): - response: ResponsesResponse # type: ignore[override] - - -class ResponseCreatedEvent(OpenAIResponseCreatedEvent): - response: ResponsesResponse # type: ignore[override] - - -class ResponseInProgressEvent(OpenAIResponseInProgressEvent): - response: ResponsesResponse # type: ignore[override] - - -StreamingResponsesResponse: TypeAlias = ( - ResponseCreatedEvent - | ResponseInProgressEvent - | ResponseCompletedEvent - | ResponseOutputItemAddedEvent - | ResponseOutputItemDoneEvent - | ResponseContentPartAddedEvent - | ResponseContentPartDoneEvent - | ResponseReasoningTextDeltaEvent - | ResponseReasoningTextDoneEvent - | ResponseReasoningPartAddedEvent - | ResponseReasoningPartDoneEvent - | ResponseCodeInterpreterCallInProgressEvent - | ResponseCodeInterpreterCallCodeDeltaEvent - | ResponseWebSearchCallInProgressEvent - | ResponseWebSearchCallSearchingEvent - | ResponseWebSearchCallCompletedEvent - | ResponseCodeInterpreterCallCodeDoneEvent - | ResponseCodeInterpreterCallInterpretingEvent - | ResponseCodeInterpreterCallCompletedEvent - | ResponseMcpCallArgumentsDeltaEvent - | ResponseMcpCallArgumentsDoneEvent - | ResponseMcpCallInProgressEvent - | ResponseMcpCallCompletedEvent -) - - -class TokenizeCompletionRequest(OpenAIBaseModel): - model: str | None = None - prompt: str - - add_special_tokens: bool = Field( - default=True, - description=( - "If true (the default), special tokens (e.g. BOS) will be added to " - "the prompt." - ), - ) - return_token_strs: bool | None = Field( - default=False, - description=( - "If true, also return the token strings corresponding to the token ids." 
- ), - ) - - -class TokenizeChatRequest(OpenAIBaseModel): - model: str | None = None - messages: list[ChatCompletionMessageParam] - - add_generation_prompt: bool = Field( - default=True, - description=( - "If true, the generation prompt will be added to the chat template. " - "This is a parameter used by chat template in tokenizer config of the " - "model." - ), - ) - return_token_strs: bool | None = Field( - default=False, - description=( - "If true, also return the token strings corresponding to the token ids." - ), - ) - continue_final_message: bool = Field( - default=False, - description=( - "If this is set, the chat will be formatted so that the final " - "message in the chat is open-ended, without any EOS tokens. The " - "model will continue this message rather than starting a new one. " - 'This allows you to "prefill" part of the model\'s response for it. ' - "Cannot be used at the same time as `add_generation_prompt`." - ), - ) - add_special_tokens: bool = Field( - default=False, - description=( - "If true, special tokens (e.g. BOS) will be added to the prompt " - "on top of what is added by the chat template. " - "For most models, the chat template takes care of adding the " - "special tokens so this should be set to false (as is the " - "default)." - ), - ) - chat_template: str | None = Field( - default=None, - description=( - "A Jinja template to use for this conversion. " - "As of transformers v4.44, default chat template is no longer " - "allowed, so you must provide a chat template if the tokenizer " - "does not define one." - ), - ) - chat_template_kwargs: dict[str, Any] | None = Field( - default=None, - description=( - "Additional keyword args to pass to the template renderer. " - "Will be accessible by the chat template." 
- ), - ) - mm_processor_kwargs: dict[str, Any] | None = Field( - default=None, - description=("Additional kwargs to pass to the HF processor."), - ) - tools: list[ChatCompletionToolsParam] | None = Field( - default=None, - description=("A list of tools the model may call."), - ) - - @model_validator(mode="before") - @classmethod - def check_generation_prompt(cls, data): - if data.get("continue_final_message") and data.get("add_generation_prompt"): - raise ValueError( - "Cannot set both `continue_final_message` and " - "`add_generation_prompt` to True." - ) - return data - - -TokenizeRequest: TypeAlias = TokenizeCompletionRequest | TokenizeChatRequest - - -class TokenizeResponse(OpenAIBaseModel): - count: int - max_model_len: int - tokens: list[int] - token_strs: list[str] | None = None - - -class DetokenizeRequest(OpenAIBaseModel): - model: str | None = None - tokens: list[int] - - -class DetokenizeResponse(OpenAIBaseModel): - prompt: str - - -class TokenizerInfoResponse(OpenAIBaseModel): - """ - Response containing tokenizer configuration - equivalent to tokenizer_config.json - """ - - model_config = ConfigDict(extra="allow") - tokenizer_class: str - - -class LoadLoRAAdapterRequest(BaseModel): - lora_name: str - lora_path: str - - -class UnloadLoRAAdapterRequest(BaseModel): - lora_name: str - lora_int_id: int | None = Field(default=None) - - -## Protocols for Audio -AudioResponseFormat: TypeAlias = Literal["json", "text", "srt", "verbose_json", "vtt"] - - -class TranscriptionRequest(OpenAIBaseModel): - # Ordered by official OpenAI API documentation - # https://platform.openai.com/docs/api-reference/audio/createTranscription - - file: UploadFile - """ - The audio file object (not file name) to transcribe, in one of these - formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. - """ - - model: str | None = None - """ID of the model to use. - """ - - language: str | None = None - """The language of the input audio. 
- - Supplying the input language in - [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format - will improve accuracy and latency. - """ - - prompt: str = Field(default="") - """An optional text to guide the model's style or continue a previous audio - segment. - - The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting) - should match the audio language. - """ - - response_format: AudioResponseFormat = Field(default="json") - """ - The format of the output, in one of these options: `json`, `text`, `srt`, - `verbose_json`, or `vtt`. - """ - - ## TODO (varun) : Support if set to 0, certain thresholds are met !! - - timestamp_granularities: list[Literal["word", "segment"]] = Field( - alias="timestamp_granularities[]", default=[] - ) - """The timestamp granularities to populate for this transcription. - - `response_format` must be set `verbose_json` to use timestamp granularities. - Either or both of these options are supported: `word`, or `segment`. Note: - There is no additional latency for segment timestamps, but generating word - timestamps incurs additional latency. - """ - - stream: bool | None = False - """When set, it will enable output to be streamed in a similar fashion - as the Chat Completion endpoint. - """ - # --8<-- [start:transcription-extra-params] - # Flattened stream option to simplify form data. - stream_include_usage: bool | None = False - stream_continuous_usage_stats: bool | None = False - - vllm_xargs: dict[str, str | int | float] | None = Field( - default=None, - description=( - "Additional request parameters with string or " - "numeric values, used by custom extensions." - ), - ) - # --8<-- [end:transcription-extra-params] - - to_language: str | None = None - """The language of the output audio we transcribe to. - - Please note that this is not currently used by supported models at this - time, but it is a placeholder for future use, matching translation api. 
- """ - - # --8<-- [start:transcription-sampling-params] - temperature: float = Field(default=0.0) - """The sampling temperature, between 0 and 1. - - Higher values like 0.8 will make the output more random, while lower values - like 0.2 will make it more focused / deterministic. If set to 0, the model - will use [log probability](https://en.wikipedia.org/wiki/Log_probability) - to automatically increase the temperature until certain thresholds are hit. - """ - - top_p: float | None = None - """Enables nucleus (top-p) sampling, where tokens are selected from the - smallest possible set whose cumulative probability exceeds `p`. - """ - - top_k: int | None = None - """Limits sampling to the `k` most probable tokens at each step.""" - - min_p: float | None = None - """Filters out tokens with a probability lower than `min_p`, ensuring a - minimum likelihood threshold during sampling. - """ - - seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) - """The seed to use for sampling.""" - - frequency_penalty: float | None = 0.0 - """The frequency penalty to use for sampling.""" - - repetition_penalty: float | None = None - """The repetition penalty to use for sampling.""" - - presence_penalty: float | None = 0.0 - """The presence penalty to use for sampling.""" - - max_completion_tokens: int | None = None - """The maximum number of tokens to generate.""" - # --8<-- [end:transcription-sampling-params] - - # Default sampling parameters for transcription requests. 
- _DEFAULT_SAMPLING_PARAMS: dict = { - "repetition_penalty": 1.0, - "temperature": 1.0, - "top_p": 1.0, - "top_k": 0, - "min_p": 0.0, - } - - def to_sampling_params( - self, default_max_tokens: int, default_sampling_params: dict | None = None - ) -> SamplingParams: - max_tokens = default_max_tokens - - if default_sampling_params is None: - default_sampling_params = {} - - # Default parameters - if (temperature := self.temperature) is None: - temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] - ) - if (top_p := self.top_p) is None: - top_p = default_sampling_params.get( - "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"] - ) - if (top_k := self.top_k) is None: - top_k = default_sampling_params.get( - "top_k", self._DEFAULT_SAMPLING_PARAMS["top_k"] - ) - if (min_p := self.min_p) is None: - min_p = default_sampling_params.get( - "min_p", self._DEFAULT_SAMPLING_PARAMS["min_p"] - ) - - if (repetition_penalty := self.repetition_penalty) is None: - repetition_penalty = default_sampling_params.get( - "repetition_penalty", - self._DEFAULT_SAMPLING_PARAMS["repetition_penalty"], - ) - - return SamplingParams.from_optional( - temperature=temperature, - max_tokens=max_tokens, - seed=self.seed, - top_p=top_p, - top_k=top_k, - min_p=min_p, - frequency_penalty=self.frequency_penalty, - repetition_penalty=repetition_penalty, - presence_penalty=self.presence_penalty, - output_kind=RequestOutputKind.DELTA - if self.stream - else RequestOutputKind.FINAL_ONLY, - extra_args=self.vllm_xargs, - skip_clone=True, # Created fresh per request, safe to skip clone - ) - - @model_validator(mode="before") - @classmethod - def validate_transcription_request(cls, data): - if isinstance(data.get("file"), str): - raise HTTPException( - status_code=HTTPStatus.UNPROCESSABLE_ENTITY, - detail="Expected 'file' to be a file-like object, not 'str'.", - ) - - stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"] - stream = 
data.get("stream", False) - if any(bool(data.get(so, False)) for so in stream_opts) and not stream: - # Find which specific stream option was set - invalid_param = next( - (so for so in stream_opts if data.get(so, False)), - "stream_include_usage", - ) - raise VLLMValidationError( - "Stream options can only be defined when `stream=True`.", - parameter=invalid_param, - ) - - return data - - -# Transcription response objects -class TranscriptionUsageAudio(OpenAIBaseModel): - type: Literal["duration"] = "duration" - seconds: int - - -class TranscriptionResponse(OpenAIBaseModel): - text: str - """The transcribed text.""" - usage: TranscriptionUsageAudio - - -class TranscriptionWord(OpenAIBaseModel): - end: float - """End time of the word in seconds.""" - - start: float - """Start time of the word in seconds.""" - - word: str - """The text content of the word.""" - - -class TranscriptionSegment(OpenAIBaseModel): - id: int - """Unique identifier of the segment.""" - - avg_logprob: float | None = None - """Average logprob of the segment. - - If the value is lower than -1, consider the logprobs failed. - """ - - compression_ratio: float | None = None - """Compression ratio of the segment. - - If the value is greater than 2.4, consider the compression failed. - """ - - end: float - """End time of the segment in seconds.""" - - no_speech_prob: float | None = None - """Probability of no speech in the segment. - - If the value is higher than 1.0 and the `avg_logprob` is below -1, consider - this segment silent. 
- """ - - seek: int - """Seek offset of the segment.""" - - start: float - """Start time of the segment in seconds.""" - - temperature: float - """Temperature parameter used for generating the segment.""" - - text: str - """Text content of the segment.""" - - tokens: list[int] - """Array of token IDs for the text content.""" - - -class TranscriptionResponseVerbose(OpenAIBaseModel): - duration: str - """The duration of the input audio.""" - - language: str - """The language of the input audio.""" - - text: str - """The transcribed text.""" - - segments: list[TranscriptionSegment] | None = None - """Segments of the transcribed text and their corresponding details.""" - - words: list[TranscriptionWord] | None = None - """Extracted words and their corresponding timestamps.""" - - -TranscriptionResponseVariant: TypeAlias = ( - TranscriptionResponse | TranscriptionResponseVerbose -) - - -class TranslationResponseStreamChoice(OpenAIBaseModel): - delta: DeltaMessage - finish_reason: str | None = None - stop_reason: int | str | None = None - - -class TranslationStreamResponse(OpenAIBaseModel): - id: str = Field(default_factory=lambda: f"trsl-{random_uuid()}") - object: Literal["translation.chunk"] = "translation.chunk" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: list[TranslationResponseStreamChoice] - usage: UsageInfo | None = Field(default=None) - - -class TranslationRequest(OpenAIBaseModel): - # Ordered by official OpenAI API documentation - # https://platform.openai.com/docs/api-reference/audio/createTranslation - - file: UploadFile - """ - The audio file object (not file name) to translate, in one of these - formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. - """ - - model: str | None = None - """ID of the model to use. - """ - - prompt: str = Field(default="") - """An optional text to guide the model's style or continue a previous audio - segment. 
- - The [prompt](https://platform.openai.com/docs/guides/speech-to-text#prompting) - should match the audio language. - """ - - response_format: AudioResponseFormat = Field(default="json") - """ - The format of the output, in one of these options: `json`, `text`, `srt`, - `verbose_json`, or `vtt`. - """ - - # TODO support additional sampling parameters - # --8<-- [start:translation-sampling-params] - seed: int | None = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max) - """The seed to use for sampling.""" - - temperature: float = Field(default=0.0) - """The sampling temperature, between 0 and 1. - - Higher values like 0.8 will make the output more random, while lower values - like 0.2 will make it more focused / deterministic. If set to 0, the model - will use [log probability](https://en.wikipedia.org/wiki/Log_probability) - to automatically increase the temperature until certain thresholds are hit. - """ - # --8<-- [end:translation-sampling-params] - - # --8<-- [start:translation-extra-params] - language: str | None = None - """The language of the input audio we translate from. - - Supplying the input language in - [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format - will improve accuracy. - """ - - to_language: str | None = None - """The language of the input audio we translate to. - - Please note that this is not supported by all models, refer to the specific - model documentation for more details. - For instance, Whisper only supports `to_language=en`. - """ - - stream: bool | None = False - """Custom field not present in the original OpenAI definition. When set, - it will enable output to be streamed in a similar fashion as the Chat - Completion endpoint. - """ - # Flattened stream option to simplify form data. 
- stream_include_usage: bool | None = False - stream_continuous_usage_stats: bool | None = False - - max_completion_tokens: int | None = None - """The maximum number of tokens to generate.""" - # --8<-- [end:translation-extra-params] - - # Default sampling parameters for translation requests. - _DEFAULT_SAMPLING_PARAMS: dict = { - "temperature": 0, - } - - def to_sampling_params( - self, default_max_tokens: int, default_sampling_params: dict | None = None - ) -> SamplingParams: - max_tokens = default_max_tokens - - if default_sampling_params is None: - default_sampling_params = {} - # Default parameters - if (temperature := self.temperature) is None: - temperature = default_sampling_params.get( - "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"] - ) - - return SamplingParams.from_optional( - temperature=temperature, - max_tokens=max_tokens, - seed=self.seed, - output_kind=RequestOutputKind.DELTA - if self.stream - else RequestOutputKind.FINAL_ONLY, - skip_clone=True, # Created fresh per request, safe to skip clone - ) - - @model_validator(mode="before") - @classmethod - def validate_stream_options(cls, data): - stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"] - stream = data.get("stream", False) - if any(bool(data.get(so, False)) for so in stream_opts) and not stream: - # Find which specific stream option was set - invalid_param = next( - (so for so in stream_opts if data.get(so, False)), - "stream_include_usage", - ) - raise VLLMValidationError( - "Stream options can only be defined when `stream=True`.", - parameter=invalid_param, - ) - - return data - - -# Translation response objects -class TranslationResponse(OpenAIBaseModel): - text: str - """The translated text.""" - - -class TranslationWord(OpenAIBaseModel): - end: float - """End time of the word in seconds.""" - - start: float - """Start time of the word in seconds.""" - - word: str - """The text content of the word.""" - - -class TranslationSegment(OpenAIBaseModel): - id: 
int - """Unique identifier of the segment.""" - - avg_logprob: float | None = None - """Average logprob of the segment. - - If the value is lower than -1, consider the logprobs failed. - """ - - compression_ratio: float | None = None - """Compression ratio of the segment. - - If the value is greater than 2.4, consider the compression failed. - """ - - end: float - """End time of the segment in seconds.""" - - no_speech_prob: float | None = None - """Probability of no speech in the segment. - - If the value is higher than 1.0 and the `avg_logprob` is below -1, consider - this segment silent. - """ - - seek: int - """Seek offset of the segment.""" - - start: float - """Start time of the segment in seconds.""" - - temperature: float - """Temperature parameter used for generating the segment.""" - - text: str - """Text content of the segment.""" - - tokens: list[int] - """Array of token IDs for the text content.""" - - -class TranslationResponseVerbose(OpenAIBaseModel): - duration: str - """The duration of the input audio.""" - - language: str - """The language of the input audio.""" - - text: str - """The translated text.""" - - segments: list[TranslationSegment] | None = None - """Segments of the translated text and their corresponding details.""" - - words: list[TranslationWord] | None = None - """Extracted words and their corresponding timestamps.""" - - -TranslationResponseVariant: TypeAlias = TranslationResponse | TranslationResponseVerbose - - -####### Tokens IN <> Tokens OUT ####### -class GenerateRequest(BaseModel): - request_id: str = Field( - default_factory=random_uuid, - description=( - "The request_id related to this request. If the caller does " - "not set it, a random_uuid will be generated. This id is used " - "through out the inference process and return in response." 
- ), - ) - token_ids: list[int] - """The token ids to generate text from.""" - - # features: MultiModalFeatureSpec - # TODO (NickLucche): implement once Renderer work is completed - features: str | None = None - """The processed MM inputs for the model.""" - - sampling_params: SamplingParams - """The sampling parameters for the model.""" - - model: str | None = None - - stream: bool | None = False - stream_options: StreamOptions | None = None - cache_salt: str | None = Field( - default=None, - description=( - "If specified, the prefix cache will be salted with the provided " - "string to prevent an attacker to guess prompts in multi-user " - "environments. The salt should be random, protected from " - "access by 3rd parties, and long enough to be " - "unpredictable (e.g., 43 characters base64-encoded, corresponding " - "to 256 bit)." - ), - ) - priority: int = Field( - default=0, - description=( - "The priority of the request (lower means earlier handling; " - "default: 0). Any priority other than 0 will raise an error " - "if the served model does not use priority scheduling." - ), - ) - kv_transfer_params: dict[str, Any] | None = Field( - default=None, - description="KVTransfer parameters used for disaggregated serving.", - ) - - -class GenerateResponseChoice(BaseModel): - index: int - logprobs: ChatCompletionLogProbs | None = None - # per OpenAI spec this is the default - finish_reason: str | None = "stop" - token_ids: list[int] | None = None - - -class GenerateResponse(BaseModel): - request_id: str = Field( - default_factory=random_uuid, - description=( - "The request_id related to this request. If the caller does " - "not set it, a random_uuid will be generated. This id is used " - "through out the inference process and return in response." 
- ), - ) - choices: list[GenerateResponseChoice] - - prompt_logprobs: list[dict[int, Logprob] | None] | None = None - - kv_transfer_params: dict[str, Any] | None = Field( - default=None, - description="KVTransfer parameters used for disaggregated serving.", - ) From 651635c1f6f600ce8cbde7ede2b12b0cdaf8b767 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Fri, 27 Feb 2026 09:58:12 +0000 Subject: [PATCH 49/61] add docs for thinking budget control Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- docs/features/reasoning_outputs.md | 80 ++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/docs/features/reasoning_outputs.md b/docs/features/reasoning_outputs.md index 2bb7eeb311fc..6d982da661e6 100644 --- a/docs/features/reasoning_outputs.md +++ b/docs/features/reasoning_outputs.md @@ -240,6 +240,86 @@ response = client.chat.completions.create( ) ``` +## Thinking Budget Control + +Some models, such as [Qwen3](https://qwen.readthedocs.io/en/latest/getting_started/quickstart.html#thinking-budget), [DeepSeek](https://www.alibabacloud.com/help/en/model-studio/deep-thinking), and [Nemotron3](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16), support a thinking budget that limits the maximum number of tokens used for reasoning. + +Token counting starts from `think_start_str`. Once the reasoning token count reaches the configured `thinking_token_budget`, vLLM forces the model to produce `think_end_str`, effectively terminating the reasoning block. + +To use this feature: + +- `--reasoning-parser` enables reasoning extraction. +- `--reasoning-config` defines the reasoning boundary tokens (e.g., `think_start_str`, `think_end_str`). +- `thinking_token_budget` (a sampling parameter) sets the per-request reasoning token limit. + +If `thinking_token_budget` is not specified, no explicit reasoning limit is applied beyond normal generation constraints such as `max_tokens`. 
+ +`--reasoning-config` accepts a JSON object corresponding to +[ReasoningConfig][vllm.config.ReasoningConfig] with the following fields: + +| Field | Type | Description | +|-------|------|-------------| +| `think_start_str` | `str \| null` | String that marks the start of reasoning content | +| `think_end_str` | `str \| null` | String that marks the end of reasoning content. Can be the reasoning parser's think end token alone (e.g., ``), or a phrase that includes it as a suffix (e.g., `I have to give the solution based on the thinking directly now.`). | +| `think_start_token_ids` | `list[int] \| null` | Token IDs that mark the start of reasoning content. Use to configure the think start boundary via token IDs instead of a string. | +| `think_end_token_ids` | `list[int] \| null` | Token IDs that mark the end of reasoning content. Use to configure the think end boundary via token IDs instead of a string. | + +!!! note + `think_end_str` can include a transition phrase before the think end token. For example, setting `think_end_str` to `"I have to give the solution based on the thinking directly now."` instructs the model to emit that phrase when the budget is exhausted, making the reasoning termination more natural. + +### Online Serving + +```bash +vllm serve Qwen/Qwen3-0.6B \ + --reasoning-parser qwen3 \ + --reasoning-config '{"think_start_str": "", "think_end_str": "I have to give the solution based on the thinking directly now."}' +``` + +Then make a request with `thinking_token_budget` to limit the reasoning tokens: + +```bash +curl http://localhost:8000/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "Qwen/Qwen3-0.6B", + "messages": [ + { "role": "user", "content": "9.11 and 9.8, which is greater?" 
} + ], + "extra_body": { + "thinking_token_budget": 10 + } + }' +``` + +### Offline Inference + +```python +from vllm import LLM, SamplingParams +from vllm.config import ReasoningConfig + +llm = LLM( + model="Qwen/Qwen3-0.6B", + reasoning_config=ReasoningConfig( + think_start_str="", + think_end_str="I have to give the solution based on the thinking directly now.", + ), +) + +sampling_params = SamplingParams(thinking_token_budget=10) + +messages = [ + {"role": "user", "content": "9.11 and 9.8, which is greater?"}, +] + +outputs = llm.chat(messages, sampling_params=sampling_params) + +for output in outputs: + reasoning = output.outputs[0].reasoning + content = output.outputs[0].text + print("reasoning:", reasoning) + print("content:", content) +``` + ## Limitations - The reasoning content is only available for online serving's chat completion endpoint (`/v1/chat/completions`). From 7bd0db0eb8273d25de8f6b0b254220d48127b1f9 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sun, 1 Mar 2026 09:42:32 +0000 Subject: [PATCH 50/61] fix docs Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- docs/features/reasoning_outputs.md | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/docs/features/reasoning_outputs.md b/docs/features/reasoning_outputs.md index 6d982da661e6..1809e85b92c7 100644 --- a/docs/features/reasoning_outputs.md +++ b/docs/features/reasoning_outputs.md @@ -314,10 +314,7 @@ messages = [ outputs = llm.chat(messages, sampling_params=sampling_params) for output in outputs: - reasoning = output.outputs[0].reasoning - content = output.outputs[0].text - print("reasoning:", reasoning) - print("content:", content) + print("text:", output.outputs[0].text) ``` ## Limitations From 12023dc099e2e5e8faabaa288e909c4936a92e7e Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Tue, 3 Mar 2026 01:58:12 +0000 Subject: [PATCH 51/61] do not expose think start 
end token ids field Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- docs/features/reasoning_outputs.md | 2 -- vllm/config/reasoning.py | 28 ++++++++++++++++++++-------- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/docs/features/reasoning_outputs.md b/docs/features/reasoning_outputs.md index 1809e85b92c7..615389cafda4 100644 --- a/docs/features/reasoning_outputs.md +++ b/docs/features/reasoning_outputs.md @@ -261,8 +261,6 @@ If `thinking_token_budget` is not specified, no explicit reasoning limit is appl |-------|------|-------------| | `think_start_str` | `str \| null` | String that marks the start of reasoning content | | `think_end_str` | `str \| null` | String that marks the end of reasoning content. Can be the reasoning parser's think end token alone (e.g., ``), or a phrase that includes it as a suffix (e.g., `I have to give the solution based on the thinking directly now.`). | -| `think_start_token_ids` | `list[int] \| null` | Token IDs that mark the start of reasoning content. Use to configure the think start boundary via token IDs instead of a string. | -| `think_end_token_ids` | `list[int] \| null` | Token IDs that mark the end of reasoning content. Use to configure the think end boundary via token IDs instead of a string. | !!! note `think_end_str` can include a transition phrase before the think end token. For example, setting `think_end_str` to `"I have to give the solution based on the thinking directly now."` instructs the model to emit that phrase when the budget is exhausted, making the reasoning termination more natural. 
diff --git a/vllm/config/reasoning.py b/vllm/config/reasoning.py index 07ae47d5fb52..2ad36b89bdbe 100644 --- a/vllm/config/reasoning.py +++ b/vllm/config/reasoning.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import field + from vllm.config.model import ModelConfig from vllm.config.utils import config from vllm.tokenizers import cached_tokenizer_from_config @@ -8,21 +10,31 @@ @config class ReasoningConfig: - """Configuration for reasoning models.""" + """Configuration for reasoning models. + + Set `think_start_str` and `think_end_str` to the strings that delimit + the reasoning block (e.g. `""` and `""`). The + corresponding token IDs are derived automatically via + `initialize_token_ids` and are not intended to be set directly. + """ think_start_str: str | None = None """String that indicates the start of reasoning.""" think_end_str: str | None = None - """String that indicates the end of reasoning.""" - think_start_token_ids: list[int] | None = None - """Token ID that indicates the start of reasoning.""" - think_end_token_ids: list[int] | None = None - """Token ID that indicates the end of reasoning.""" + """String that indicates the end of reasoning content.""" + + think_start_token_ids: list[int] | None = field( + default=None, init=False, repr=False + ) + """Token IDs derived from `think_start_str`. Set automatically by + `initialize_token_ids`. Not intended to be configured directly.""" + think_end_token_ids: list[int] | None = field(default=None, init=False, repr=False) + """Token IDs derived from `think_end_str`. Set automatically by + `initialize_token_ids`. 
Not intended to be configured directly.""" @property def is_thinking_enabled(self) -> bool: - """Check if both start and end thinking token IDs - are set to enable thinking token budget logic.""" + """Check if thinking boundaries are configured.""" return ( self.think_start_token_ids is not None and self.think_end_token_ids is not None From e643d5b1b1d4bec3658e7a3fe4d7b0fafe3794d4 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sun, 8 Mar 2026 13:11:02 +0000 Subject: [PATCH 52/61] make think_start/end_str are required and remove is_thinking_enabled method Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- .../v1/logits_processors/test_correctness.py | 3 -- vllm/config/reasoning.py | 37 +++++++++---------- vllm/v1/sample/logits_processor/builtin.py | 4 +- vllm/v1/worker/gpu_model_runner.py | 5 +-- 4 files changed, 19 insertions(+), 30 deletions(-) diff --git a/tests/v1/logits_processors/test_correctness.py b/tests/v1/logits_processors/test_correctness.py index 04a0297c0141..792168877663 100644 --- a/tests/v1/logits_processors/test_correctness.py +++ b/tests/v1/logits_processors/test_correctness.py @@ -106,9 +106,6 @@ class MockReasoningConfig: think_start_token_ids = [THINK_START_TOKEN_ID] think_end_token_ids = [THINK_END_TOKEN_ID] - def is_thinking_enabled(self) -> bool: - return True - def _generate_fake_sampling_metadata( num_output_tokens: int, diff --git a/vllm/config/reasoning.py b/vllm/config/reasoning.py index 2ad36b89bdbe..a41436de104b 100644 --- a/vllm/config/reasoning.py +++ b/vllm/config/reasoning.py @@ -18,9 +18,9 @@ class ReasoningConfig: `initialize_token_ids` and are not intended to be set directly. 
""" - think_start_str: str | None = None + think_start_str: str = "" """String that indicates the start of reasoning.""" - think_end_str: str | None = None + think_end_str: str = "" """String that indicates the end of reasoning content.""" think_start_token_ids: list[int] | None = field( @@ -32,25 +32,22 @@ class ReasoningConfig: """Token IDs derived from `think_end_str`. Set automatically by `initialize_token_ids`. Not intended to be configured directly.""" - @property - def is_thinking_enabled(self) -> bool: - """Check if thinking boundaries are configured.""" - return ( - self.think_start_token_ids is not None - and self.think_end_token_ids is not None - and len(self.think_start_token_ids) > 0 - and len(self.think_end_token_ids) > 0 - ) - def initialize_token_ids(self, model_config: ModelConfig) -> None: """Initialize reasoning token IDs from strings using the tokenizer.""" - if self.think_start_str is not None and self.think_end_str is not None: - tokenizer = cached_tokenizer_from_config(model_config=model_config) + tokenizer = cached_tokenizer_from_config(model_config=model_config) - # Convert reasoning strings to token IDs - self.think_start_token_ids = tokenizer.convert_tokens_to_ids( - tokenizer.tokenize(self.think_start_str) - ) - self.think_end_token_ids = tokenizer.convert_tokens_to_ids( - tokenizer.tokenize(self.think_end_str) + # Convert reasoning strings to token IDs + self.think_start_token_ids = tokenizer.convert_tokens_to_ids( + tokenizer.tokenize(self.think_start_str) + ) + self.think_end_token_ids = tokenizer.convert_tokens_to_ids( + tokenizer.tokenize(self.think_end_str) + ) + + if not self.think_start_token_ids or not self.think_end_token_ids: + raise ValueError( + f"ReasoningConfig: failed to tokenize reasoning strings: " + f"think_start_str='{self.think_start_str}', " + f"think_end_str='{self.think_end_str}'. " + "Ensure the strings are valid tokens in the model's vocabulary." 
) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index 3978e3026d57..d213d0a87de0 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -301,9 +301,7 @@ def __init__( max_num_reqs = vllm_config.scheduler_config.max_num_seqs # Check if thinking is enabled - self.is_enabled = ( - reasoning_config is not None and reasoning_config.is_thinking_enabled - ) + self.is_enabled = reasoning_config is not None self.think_start_token_ids = getattr( reasoning_config, "think_start_token_ids", [] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index cbd30fc7817f..b8f2e65aa56e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -584,10 +584,7 @@ def __init__( # ThinkingTokenBudgetLogitsProcessor also needs output token ids to # correctly track think start/end token sequences in async scheduling. logitsprocs_need_output_token_ids=bool(custom_logitsprocs) - or ( - self.vllm_config.reasoning_config is not None - and self.vllm_config.reasoning_config.is_thinking_enabled - ), + or (self.vllm_config.reasoning_config is not None), is_pooling_model=self.is_pooling_model, cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size, ) From cceb341ac362f304446fde418c8733b5318aaef7 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sun, 8 Mar 2026 14:24:06 +0000 Subject: [PATCH 53/61] fix swap part Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/v1/sample/logits_processor/builtin.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index d213d0a87de0..d801f49b8a1c 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -505,13 +505,16 @@ def 
update_state(self, batch_update: BatchUpdate | None): for i1, i2, direction in batch_update.moved: if direction == MoveDirectionality.SWAP: - state1 = self._state.get(i1, {}) - state2 = self._state.get(i2, {}) - if state1 or state2: - self._state[i1] = state2 + state1 = self._state.pop(i1, None) + state2 = self._state.pop(i2, None) + if state1 is not None: self._state[i2] = state1 + if state2 is not None: + self._state[i1] = state2 else: - self._state[i2] = self._state.pop(i1, {}) + state = self._state.pop(i1, None) + if state is not None: + self._state[i2] = state for state in self._state.values(): self._update_think_state(state) From 43ae6c439fd8d63eedf466866a25ba273c84b689 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sun, 8 Mar 2026 15:45:55 +0000 Subject: [PATCH 54/61] fix: ensure reasoning token count exactly matches thinking_token_budget Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/v1/sample/logits_processor/builtin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/sample/logits_processor/builtin.py b/vllm/v1/sample/logits_processor/builtin.py index d801f49b8a1c..c92f334021fc 100644 --- a/vllm/v1/sample/logits_processor/builtin.py +++ b/vllm/v1/sample/logits_processor/builtin.py @@ -453,7 +453,7 @@ def _update_think_state(self, state: dict[str, Any]): remaining_budget = max( 0, state["thinking_token_budget"] - state["think_count"] ) - state["check_count_down"] = remaining_budget + state["check_count_down"] = max(0, remaining_budget - 1) else: state["check_count_down"] = state["thinking_token_budget"] From 29bb06902eccfd0e8bdbb0999f1c5c297b722116 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sun, 8 Mar 2026 15:46:12 +0000 Subject: [PATCH 55/61] add e2e test Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- .../openai/test_thinking_token_budget.py | 89 +++++++++++++++++++ 1 file 
changed, 89 insertions(+) create mode 100644 tests/v1/entrypoints/openai/test_thinking_token_budget.py diff --git a/tests/v1/entrypoints/openai/test_thinking_token_budget.py b/tests/v1/entrypoints/openai/test_thinking_token_budget.py new file mode 100644 index 000000000000..7ae8039fa5f8 --- /dev/null +++ b/tests/v1/entrypoints/openai/test_thinking_token_budget.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""E2E tests for thinking_token_budget with reasoning models.""" + +import openai +import pytest +import pytest_asyncio + +from tests.utils import RemoteOpenAIServer + +MODEL_NAME = "Qwen/Qwen3-0.6B" +MESSAGES = [{"role": "user", "content": "What is 1+1? Be concise."}] +THINK_BUDGET = 5 + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--reasoning-parser", + "qwen3", + "--reasoning-config", + '{"think_start_str": "", "think_end_str": ""}', + "--max-model-len", + "2048", + "--enforce-eager", + "--no-async-scheduling", + "--gpu-memory-utilization", + "0.1", + ] + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +async def test_thinking_token_budget_mixed_requests(client: openai.AsyncOpenAI): + """Test that mixed requests (some with thinking_token_budget, some without) + complete successfully without errors.""" + + response_with_budget = await client.chat.completions.create( + model=MODEL_NAME, + messages=MESSAGES, + max_tokens=100, + extra_body={"thinking_token_budget": THINK_BUDGET}, + ) + response_without_budget = await client.chat.completions.create( + model=MODEL_NAME, + messages=MESSAGES, + max_tokens=100, + ) + + msg_with = response_with_budget.choices[0].message + msg_without = response_without_budget.choices[0].message + + assert msg_with.content or getattr(msg_with, 
"reasoning", None) + assert msg_without.content or getattr(msg_without, "reasoning", None) + + +@pytest.mark.asyncio +async def test_thinking_token_budget_limits_reasoning(client: openai.AsyncOpenAI): + """Test that thinking_token_budget limits the number of reasoning tokens. + + In streaming mode each reasoning delta corresponds to one token, so + counting non-empty reasoning_content chunks gives the exact token count. + """ + + reasoning_token_count = 0 + stream = await client.chat.completions.create( + model=MODEL_NAME, + messages=MESSAGES, + max_tokens=100, + stream=True, + extra_body={"thinking_token_budget": THINK_BUDGET}, + ) + async for chunk in stream: + delta = chunk.choices[0].delta + if getattr(delta, "reasoning", None): + reasoning_token_count += 1 + + assert reasoning_token_count == THINK_BUDGET, ( + f"reasoning tokens ({reasoning_token_count}) != " + f"thinking_token_budget ({THINK_BUDGET})" + ) From 56ea93440369860eaf939897fa49de4931f2f516 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Sun, 8 Mar 2026 15:47:31 +0000 Subject: [PATCH 56/61] remove gpu util option from e2e test Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- tests/v1/entrypoints/openai/test_thinking_token_budget.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/v1/entrypoints/openai/test_thinking_token_budget.py b/tests/v1/entrypoints/openai/test_thinking_token_budget.py index 7ae8039fa5f8..f574b07b6b81 100644 --- a/tests/v1/entrypoints/openai/test_thinking_token_budget.py +++ b/tests/v1/entrypoints/openai/test_thinking_token_budget.py @@ -25,8 +25,6 @@ def server(): "2048", "--enforce-eager", "--no-async-scheduling", - "--gpu-memory-utilization", - "0.1", ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: yield remote_server From 00df8fe507f70f4b3caf42cf73c6b533a9dc87cc Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Mon, 9 Mar 2026 08:31:32 +0000 
Subject: [PATCH 57/61] make precommit Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- docs/features/reasoning_outputs.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/features/reasoning_outputs.md b/docs/features/reasoning_outputs.md index ce11db7cda5f..fcf1d9244d6e 100644 --- a/docs/features/reasoning_outputs.md +++ b/docs/features/reasoning_outputs.md @@ -257,10 +257,10 @@ If `thinking_token_budget` is not specified, no explicit reasoning limit is appl `--reasoning-config` accepts a JSON object corresponding to [ReasoningConfig][vllm.config.ReasoningConfig] with the following fields: -| Field | Type | Description | -|-------|------|-------------| -| `think_start_str` | `str \| null` | String that marks the start of reasoning content | -| `think_end_str` | `str \| null` | String that marks the end of reasoning content. Can be the reasoning parser's think end token alone (e.g., ``), or a phrase that includes it as a suffix (e.g., `I have to give the solution based on the thinking directly now.`). | +| Field | Type | Description | +|-------------------|----------------|--------------------------------------------------| +| `think_start_str` | `str \| null` | String that marks the start of reasoning content | +| `think_end_str` | `str \| null` | String that marks the end of reasoning content | !!! note `think_end_str` can include a transition phrase before the think end token. For example, setting `think_end_str` to `"I have to give the solution based on the thinking directly now."` instructs the model to emit that phrase when the budget is exhausted, making the reasoning termination more natural. 
From 8d5d70e3e5fa0058ec599822844677b9109a8e9e Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Fri, 13 Mar 2026 13:00:40 +0000 Subject: [PATCH 58/61] use tokenizer encode instead of convert_token_to_ids Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/config/reasoning.py | 8 ++++---- vllm/v1/worker/gpu_model_runner.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/config/reasoning.py b/vllm/config/reasoning.py index a41436de104b..9af2f253a3b1 100644 --- a/vllm/config/reasoning.py +++ b/vllm/config/reasoning.py @@ -37,11 +37,11 @@ def initialize_token_ids(self, model_config: ModelConfig) -> None: tokenizer = cached_tokenizer_from_config(model_config=model_config) # Convert reasoning strings to token IDs - self.think_start_token_ids = tokenizer.convert_tokens_to_ids( - tokenizer.tokenize(self.think_start_str) + self.think_start_token_ids = tokenizer.encode( + self.think_start_str, add_special_tokens=False ) - self.think_end_token_ids = tokenizer.convert_tokens_to_ids( - tokenizer.tokenize(self.think_end_str) + self.think_end_token_ids = tokenizer.encode( + self.think_end_str, add_special_tokens=False ) if not self.think_start_token_ids or not self.think_end_token_ids: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 24751d68a7b4..2923ccf93e24 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -622,7 +622,7 @@ def __init__( # ThinkingTokenBudgetLogitsProcessor also needs output token ids to # correctly track think start/end token sequences in async scheduling. 
logitsprocs_need_output_token_ids=bool(custom_logitsprocs) - or (self.vllm_config.reasoning_config is not None), + or self.vllm_config.reasoning_config is not None, is_pooling_model=self.is_pooling_model, cp_kv_cache_interleave_size=self.parallel_config.cp_kv_cache_interleave_size, ) From 45bed67db63b95de1d205653ab25a7bf4be9b774 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Fri, 13 Mar 2026 13:08:55 +0000 Subject: [PATCH 59/61] raise ValueError when thinking_token_budget is set but reasoning_config is not configured Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/v1/engine/input_processor.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/vllm/v1/engine/input_processor.py b/vllm/v1/engine/input_processor.py index aab560544635..b77b9277a48d 100644 --- a/vllm/v1/engine/input_processor.py +++ b/vllm/v1/engine/input_processor.py @@ -99,6 +99,16 @@ def _validate_params( self.structured_outputs_config, self.tokenizer, ) + + if ( + params.thinking_token_budget is not None + and self.vllm_config.reasoning_config is None + ): + raise ValueError( + "thinking_token_budget is set but reasoning_config is " + "not configured. Please set --reasoning-config to use " + "thinking_token_budget." 
+ ) elif isinstance(params, PoolingParams): supported_pooling_tasks = [ task for task in supported_tasks if task in POOLING_TASKS From 825217582e4266a2e325520a1cfb1171ba0aabc7 Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Fri, 13 Mar 2026 14:15:42 +0000 Subject: [PATCH 60/61] make sure that think start/end token ids are derived from string Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/config/reasoning.py | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/vllm/config/reasoning.py b/vllm/config/reasoning.py index 9af2f253a3b1..172cafe95ad6 100644 --- a/vllm/config/reasoning.py +++ b/vllm/config/reasoning.py @@ -23,28 +23,45 @@ class ReasoningConfig: think_end_str: str = "" """String that indicates the end of reasoning content.""" - think_start_token_ids: list[int] | None = field( + _think_start_token_ids: list[int] | None = field( default=None, init=False, repr=False ) - """Token IDs derived from `think_start_str`. Set automatically by + """Private backing field for `think_start_token_ids`. Set by `initialize_token_ids`. Not intended to be configured directly.""" - think_end_token_ids: list[int] | None = field(default=None, init=False, repr=False) - """Token IDs derived from `think_end_str`. Set automatically by + _think_end_token_ids: list[int] | None = field(default=None, init=False, repr=False) + """Private backing field for `think_end_token_ids`. Set by `initialize_token_ids`. Not intended to be configured directly.""" + @property + def think_start_token_ids(self) -> list[int] | None: + """Token IDs derived from `think_start_str`. Set automatically by + `initialize_token_ids`. Not intended to be configured directly.""" + return self._think_start_token_ids + + @property + def think_end_token_ids(self) -> list[int] | None: + """Token IDs derived from `think_end_str`. Set automatically by + `initialize_token_ids`. 
Not intended to be configured directly.""" + return self._think_end_token_ids + def initialize_token_ids(self, model_config: ModelConfig) -> None: """Initialize reasoning token IDs from strings using the tokenizer.""" + if ( + self._think_start_token_ids is not None + and self._think_end_token_ids is not None + ): + return + tokenizer = cached_tokenizer_from_config(model_config=model_config) - # Convert reasoning strings to token IDs - self.think_start_token_ids = tokenizer.encode( + self._think_start_token_ids = tokenizer.encode( self.think_start_str, add_special_tokens=False ) - self.think_end_token_ids = tokenizer.encode( + self._think_end_token_ids = tokenizer.encode( self.think_end_str, add_special_tokens=False ) - if not self.think_start_token_ids or not self.think_end_token_ids: + if not self._think_start_token_ids or not self._think_end_token_ids: raise ValueError( f"ReasoningConfig: failed to tokenize reasoning strings: " f"think_start_str='{self.think_start_str}', " From 4624a77e83137575b997777105745f2241f9a3de Mon Sep 17 00:00:00 2001 From: Sungjae Lee <33976427+llsj14@users.noreply.github.com> Date: Fri, 13 Mar 2026 14:19:45 +0000 Subject: [PATCH 61/61] add comment about automation related to ReasoningConfig Signed-off-by: Sungjae Lee <33976427+llsj14@users.noreply.github.com> --- vllm/config/reasoning.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/config/reasoning.py b/vllm/config/reasoning.py index 172cafe95ad6..872e05580908 100644 --- a/vllm/config/reasoning.py +++ b/vllm/config/reasoning.py @@ -18,6 +18,8 @@ class ReasoningConfig: `initialize_token_ids` and are not intended to be set directly. """ + # NOTE: These parameters are temporary, the intent is to derive them + # automatically from the reasoning parser in a future version. think_start_str: str = "" """String that indicates the start of reasoning.""" think_end_str: str = ""