Merged
2 changes: 2 additions & 0 deletions vllm/entrypoints/openai/api_server.py
@@ -1190,13 +1190,15 @@ async def init_app_state(
tool_parser=args.tool_call_parser,
reasoning_parser=args.reasoning_parser,
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
+ enable_force_include_usage=args.enable_force_include_usage,
) if model_config.runner_type == "generate" else None
state.openai_serving_completion = OpenAIServingCompletion(
engine_client,
model_config,
state.openai_serving_models,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
+ enable_force_include_usage=args.enable_force_include_usage,
) if model_config.runner_type == "generate" else None
state.openai_serving_pooling = OpenAIServingPooling(
engine_client,
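These two additions forward the parsed CLI flag into both generate-type handlers when the application state is built. Below is a minimal sketch of that wiring, using hypothetical simplified stand-ins (`build_handlers`, `SimpleNamespace` handlers) rather than the real `OpenAIServingChat`/`OpenAIServingCompletion` constructors:

```python
# Hypothetical, simplified stand-ins for the init_app_state wiring shown above:
# the flag from the parsed CLI args is forwarded to both generate-type
# handlers, and the handlers are skipped entirely for other runner types.
from argparse import Namespace
from types import SimpleNamespace


def build_handlers(args: Namespace, model_config: SimpleNamespace) -> dict:
    def make(kind: str) -> SimpleNamespace:
        # Stand-in for OpenAIServingChat / OpenAIServingCompletion; only the
        # keyword argument of interest is modeled here.
        return SimpleNamespace(
            kind=kind,
            enable_force_include_usage=args.enable_force_include_usage)

    is_generate = model_config.runner_type == "generate"
    return {
        "chat": make("chat") if is_generate else None,
        "completion": make("completion") if is_generate else None,
    }


args = Namespace(enable_force_include_usage=True)
handlers = build_handlers(args, SimpleNamespace(runner_type="generate"))
assert handlers["chat"].enable_force_include_usage is True
assert build_handlers(args, SimpleNamespace(runner_type="pooling"))["chat"] is None
```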
5 changes: 5 additions & 0 deletions vllm/entrypoints/openai/cli_args.py
@@ -272,6 +272,11 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
action='store_true',
default=False,
help="If set to True, enable prompt_tokens_details in usage.")
+ parser.add_argument(
+ "--enable-force-include-usage",
+ action='store_true',
+ default=False,
help="If set to True, including usage on every request.")
parser.add_argument(
"--enable-server-load-tracking",
action='store_true',
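The new option is a standard store_true flag: it defaults to False and becomes True only when passed on the command line. A self-contained sketch of that behavior, using plain argparse rather than vLLM's FlexibleArgumentParser:

```python
import argparse

# Reproduce the flag definition from the diff with plain argparse to show the
# default-off, opt-in behavior of a store_true option.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable-force-include-usage",
    action="store_true",
    default=False,
    help="If set to True, include usage on every request.")

assert parser.parse_args([]).enable_force_include_usage is False
assert parser.parse_args(
    ["--enable-force-include-usage"]).enable_force_include_usage is True
```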
19 changes: 15 additions & 4 deletions vllm/entrypoints/openai/serving_chat.py
@@ -64,12 +64,14 @@ def __init__(
enable_auto_tools: bool = False,
tool_parser: Optional[str] = None,
enable_prompt_tokens_details: bool = False,
+ enable_force_include_usage: bool = False,
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger,
- return_tokens_as_token_ids=return_tokens_as_token_ids)
+ return_tokens_as_token_ids=return_tokens_as_token_ids,
+ enable_force_include_usage=enable_force_include_usage)

self.response_role = response_role
self.chat_template = chat_template
@@ -110,6 +112,7 @@ def __init__(
"been registered") from e

self.enable_prompt_tokens_details = enable_prompt_tokens_details
+ self.enable_force_include_usage = enable_force_include_usage
self.default_sampling_params = (
self.model_config.get_diff_sampling_param())
if self.default_sampling_params:
@@ -261,8 +264,14 @@ async def create_chat_completion(
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
- request, result_generator, request_id, model_name,
- conversation, tokenizer, request_metadata)
+ request,
+ result_generator,
+ request_id,
+ model_name,
+ conversation,
+ tokenizer,
+ request_metadata,
+ enable_force_include_usage=self.enable_force_include_usage)

try:
return await self.chat_completion_full_generator(
@@ -405,6 +414,7 @@ async def chat_completion_stream_generator(
conversation: list[ConversationMessage],
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
+ enable_force_include_usage: bool,
) -> AsyncGenerator[str, None]:
created_time = int(time.time())
chunk_object_type: Final = "chat.completion.chunk"
@@ -471,7 +481,8 @@

stream_options = request.stream_options
if stream_options:
- include_usage = stream_options.include_usage
+ include_usage = stream_options.include_usage \
+ or enable_force_include_usage
include_continuous_usage = include_usage and \
stream_options.continuous_usage_stats
else:
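The net effect of the stream-generator change is that, whenever the client supplies stream_options, usage reporting can no longer be switched off: include_usage becomes the OR of the client's stream_options.include_usage and the server-side flag, while continuous_usage_stats stays a client-side opt-in. The branch taken when stream_options is absent lies outside the shown hunk and appears unchanged by this diff. A small stand-alone sketch of the decision logic (resolve_usage_flags and StreamOptions are hypothetical helpers, not vLLM code):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class StreamOptions:
    include_usage: bool = False
    continuous_usage_stats: bool = False


def resolve_usage_flags(
        stream_options: Optional[StreamOptions],
        enable_force_include_usage: bool) -> tuple[bool, bool]:
    # Mirrors the visible part of the diff: with stream_options present, the
    # server flag forces include_usage on; continuous usage stats remain a
    # client-side opt-in. The else branch is not shown in the hunk and is
    # assumed here to keep its pre-existing "no usage" default.
    if stream_options:
        include_usage = (stream_options.include_usage
                         or enable_force_include_usage)
        include_continuous_usage = (include_usage
                                    and stream_options.continuous_usage_stats)
    else:
        include_usage = False
        include_continuous_usage = False
    return include_usage, include_continuous_usage


# Client sent stream_options but did not ask for usage; the flag forces it on.
assert resolve_usage_flags(StreamOptions(), True) == (True, False)
# Without the flag, the client's explicit choices are honored as before.
assert resolve_usage_flags(
    StreamOptions(include_usage=True, continuous_usage_stats=True),
    False) == (True, True)
```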
11 changes: 8 additions & 3 deletions vllm/entrypoints/openai/serving_completion.py
@@ -52,12 +52,14 @@ def __init__(
*,
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
+ enable_force_include_usage: bool = False,
):
super().__init__(engine_client=engine_client,
model_config=model_config,
models=models,
request_logger=request_logger,
- return_tokens_as_token_ids=return_tokens_as_token_ids)
+ return_tokens_as_token_ids=return_tokens_as_token_ids,
+ enable_force_include_usage=enable_force_include_usage)
self.default_sampling_params = (
self.model_config.get_diff_sampling_param())
if self.default_sampling_params:
@@ -227,7 +229,8 @@ async def create_completion(
model_name,
num_prompts=num_prompts,
tokenizer=tokenizer,
- request_metadata=request_metadata)
+ request_metadata=request_metadata,
+ enable_force_include_usage=self.enable_force_include_usage)

# Non-streaming response
final_res_batch: list[Optional[RequestOutput]] = [None] * num_prompts
@@ -289,6 +292,7 @@ async def completion_stream_generator(
num_prompts: int,
tokenizer: AnyTokenizer,
request_metadata: RequestResponseMetadata,
+ enable_force_include_usage: bool,
) -> AsyncGenerator[str, None]:
num_choices = 1 if request.n is None else request.n
previous_text_lens = [0] * num_choices * num_prompts
@@ -298,7 +302,8 @@

stream_options = request.stream_options
if stream_options:
- include_usage = stream_options.include_usage
+ include_usage = stream_options.include_usage or \
+ enable_force_include_usage
include_continuous_usage = include_usage and \
stream_options.continuous_usage_stats
else:
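From the client's side, the practical consequence of the completion change mirrors the chat one: once stream_options are supplied, the server ORs its own flag into include_usage, so a stream should end with a usage-bearing chunk even if the client opted out. A hedged sketch with the openai Python package, assuming a local vLLM server started with --enable-force-include-usage; the base URL, API key, and model name are placeholders:

```python
# Assumes `pip install openai` and a vLLM OpenAI-compatible server on
# localhost:8000 that was launched with --enable-force-include-usage.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

stream = client.completions.create(
    model="my-model",  # placeholder model name
    prompt="Hello, my name is",
    max_tokens=16,
    stream=True,
    # The client explicitly opts out of usage; stream_options is passed via
    # extra_body so this works regardless of the installed SDK version.
    extra_body={"stream_options": {"include_usage": False}},
)

final_usage = None
for chunk in stream:
    if chunk.choices:
        print(chunk.choices[0].text, end="", flush=True)
    if chunk.usage is not None:
        # With the server-side flag set, a usage-bearing chunk is still
        # expected at the end of the stream.
        final_usage = chunk.usage

print("\nusage:", final_usage)
```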
6 changes: 4 additions & 2 deletions vllm/entrypoints/openai/serving_engine.py
@@ -132,7 +132,7 @@ def is_embeds_prompt(prompt: RequestPrompt) -> TypeIs[EmbedsPrompt]:

class RequestProcessingMixin(BaseModel):
"""
- Mixin for request processing,
+ Mixin for request processing,
handling prompt preparation and engine input.
"""
request_prompts: Optional[Sequence[RequestPrompt]] = []
@@ -144,7 +144,7 @@ class RequestProcessingMixin(BaseModel):

class ResponseGenerationMixin(BaseModel):
"""
- Mixin for response generation,
+ Mixin for response generation,
managing result generators and final batch results.
"""
result_generator: Optional[AsyncGenerator[tuple[int, Union[
@@ -208,6 +208,7 @@ def __init__(
*,
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
+ enable_force_include_usage: bool = False,
):
super().__init__()

@@ -219,6 +220,7 @@

self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids
+ self.enable_force_include_usage = enable_force_include_usage

self._tokenizer_executor = ThreadPoolExecutor(max_workers=1)

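The base-class change is what keeps this backwards compatible: enable_force_include_usage is a keyword-only parameter with a False default, so serving subclasses that never pass it (pooling, embedding, and so on) keep their previous behavior, while chat and completion forward the parsed flag and get the stored attribute. A short sketch of that pattern with hypothetical stand-in classes:

```python
# Hypothetical stand-ins illustrating the keyword-only, defaulted parameter
# added to the shared serving base class.
class ServingBase:
    def __init__(self, *, return_tokens_as_token_ids: bool = False,
                 enable_force_include_usage: bool = False) -> None:
        self.return_tokens_as_token_ids = return_tokens_as_token_ids
        # New attribute; defaults to False so subclasses that do not forward
        # the flag behave exactly as before.
        self.enable_force_include_usage = enable_force_include_usage


class ServingPooling(ServingBase):
    # Does not forward the new flag; inherits the False default.
    pass


class ServingChat(ServingBase):
    def __init__(self, *, enable_force_include_usage: bool = False) -> None:
        super().__init__(
            enable_force_include_usage=enable_force_include_usage)


assert ServingPooling().enable_force_include_usage is False
assert ServingChat(enable_force_include_usage=True).enable_force_include_usage is True
```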