From c213bb516ae90aabe274e33177f4841dcf922668 Mon Sep 17 00:00:00 2001 From: Sai Kiran Polisetty Date: Fri, 6 Jun 2025 18:01:40 +0530 Subject: [PATCH 01/14] Update --- .../openai_frontend/engine/triton_engine.py | 34 +++++++++++++++++++ .../openai_frontend/engine/utils/triton.py | 2 ++ 2 files changed, 36 insertions(+) diff --git a/python/openai/openai_frontend/engine/triton_engine.py b/python/openai/openai_frontend/engine/triton_engine.py index 32357cfdd4..c9e3e3145c 100644 --- a/python/openai/openai_frontend/engine/triton_engine.py +++ b/python/openai/openai_frontend/engine/triton_engine.py @@ -30,6 +30,7 @@ import json import time import uuid +import ctypes from dataclasses import dataclass from typing import ( Any, @@ -65,6 +66,7 @@ ChatCompletionStreamResponseDelta, ChatCompletionToolChoiceOption1, Choice, + CompletionUsage, CreateChatCompletionRequest, CreateChatCompletionResponse, CreateChatCompletionStreamResponse, @@ -225,6 +227,37 @@ async def chat( backend=metadata.backend, ) + prompt_tokens = None + completion_tokens = None + usage = None + + if ( + "num_input_tokens" in response.outputs + and "num_output_tokens" in response.outputs + ): + input_token_tensor = response.outputs["num_input_tokens"] + output_token_tensor = response.outputs["num_output_tokens"] + + if input_token_tensor.data_type == tritonserver.DataType.UINT32: + prompt_tokens_ptr = ctypes.cast( + input_token_tensor.data_ptr, ctypes.POINTER(ctypes.c_uint32) + ) + prompt_tokens = prompt_tokens_ptr[0] + + if output_token_tensor.data_type == tritonserver.DataType.UINT32: + completion_tokens_ptr = ctypes.cast( + output_token_tensor.data_ptr, ctypes.POINTER(ctypes.c_uint32) + ) + completion_tokens = completion_tokens_ptr[0] + + if prompt_tokens is not None and completion_tokens is not None: + total_tokens = prompt_tokens + completion_tokens + usage = CompletionUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ) + return CreateChatCompletionResponse( id=request_id, choices=[ @@ -239,6 +272,7 @@ async def chat( model=request.model, system_fingerprint=None, object=ObjectType.chat_completion, + usage=usage, ) def _get_chat_completion_response_message( diff --git a/python/openai/openai_frontend/engine/utils/triton.py b/python/openai/openai_frontend/engine/utils/triton.py index ed16c2c1b8..e323393ef9 100644 --- a/python/openai/openai_frontend/engine/utils/triton.py +++ b/python/openai/openai_frontend/engine/utils/triton.py @@ -100,6 +100,8 @@ def _create_vllm_inference_request( # Pass sampling_parameters as serialized JSON string input to support List # fields like 'stop' that aren't supported by TRITONSERVER_Parameters yet. 
inputs["sampling_parameters"] = [sampling_parameters] + inputs["return_num_input_tokens"] = np.bool_([True]) + inputs["return_num_output_tokens"] = np.bool_([True]) return model.create_request(inputs=inputs) From 4a340b69d7255c0ecb0c6baddbb6d1a6684b40a4 Mon Sep 17 00:00:00 2001 From: Sai Kiran Polisetty Date: Thu, 19 Jun 2025 11:24:24 +0530 Subject: [PATCH 02/14] Add "usage" support --- .../openai_frontend/engine/triton_engine.py | 90 ++++++++----- .../openai_frontend/engine/utils/triton.py | 39 ++++++ .../openai/openai_frontend/schemas/openai.py | 24 ++-- python/openai/tests/test_completions.py | 14 +- python/openai/tests/test_openai_client.py | 121 ++++++++++++++++++ 5 files changed, 243 insertions(+), 45 deletions(-) diff --git a/python/openai/openai_frontend/engine/triton_engine.py b/python/openai/openai_frontend/engine/triton_engine.py index c9e3e3145c..c1c6f48ee8 100644 --- a/python/openai/openai_frontend/engine/triton_engine.py +++ b/python/openai/openai_frontend/engine/triton_engine.py @@ -52,6 +52,7 @@ _create_trtllm_inference_request, _create_vllm_inference_request, _get_output, + _get_usage_from_response, _get_vllm_lora_names, _validate_triton_responses_non_streaming, ) @@ -227,36 +228,7 @@ async def chat( backend=metadata.backend, ) - prompt_tokens = None - completion_tokens = None - usage = None - - if ( - "num_input_tokens" in response.outputs - and "num_output_tokens" in response.outputs - ): - input_token_tensor = response.outputs["num_input_tokens"] - output_token_tensor = response.outputs["num_output_tokens"] - - if input_token_tensor.data_type == tritonserver.DataType.UINT32: - prompt_tokens_ptr = ctypes.cast( - input_token_tensor.data_ptr, ctypes.POINTER(ctypes.c_uint32) - ) - prompt_tokens = prompt_tokens_ptr[0] - - if output_token_tensor.data_type == tritonserver.DataType.UINT32: - completion_tokens_ptr = ctypes.cast( - output_token_tensor.data_ptr, ctypes.POINTER(ctypes.c_uint32) - ) - completion_tokens = completion_tokens_ptr[0] - - if prompt_tokens is not None and completion_tokens is not None: - total_tokens = prompt_tokens + completion_tokens - usage = CompletionUsage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - ) + usage = _get_usage_from_response(response) return CreateChatCompletionResponse( id=request_id, @@ -345,7 +317,7 @@ async def completion( created = int(time.time()) if request.stream: return self._streaming_completion_iterator( - request_id, created, request.model, responses + request_id, created, request, responses ) # Response validation with decoupled models in mind @@ -354,6 +326,8 @@ async def completion( response = responses[0] text = _get_output(response) + usage = _get_usage_from_response(response) + choice = Choice( finish_reason=FinishReason.stop, index=0, @@ -367,6 +341,7 @@ async def completion( object=ObjectType.text_completion, created=created, model=request.model, + usage=usage, ) # TODO: This behavior should be tested further @@ -447,6 +422,7 @@ def _get_streaming_chat_response_chunk( request_id: str, created: int, model: str, + usage: Optional[CompletionUsage] = None, ) -> CreateChatCompletionStreamResponse: return CreateChatCompletionStreamResponse( id=request_id, @@ -455,6 +431,7 @@ def _get_streaming_chat_response_chunk( model=model, system_fingerprint=None, object=ObjectType.chat_completion_chunk, + usage=usage, ) def _get_first_streaming_chat_response( @@ -470,7 +447,7 @@ def _get_first_streaming_chat_response( finish_reason=None, ) chunk = 
self._get_streaming_chat_response_chunk( - choice, request_id, created, model + choice, request_id, created, model, usage=None ) return chunk @@ -496,6 +473,7 @@ async def _streaming_chat_iterator( ) previous_text = "" + usage_payload: Optional[CompletionUsage] = None chunk = self._get_first_streaming_chat_response( request_id, created, model, role @@ -505,6 +483,11 @@ async def _streaming_chat_iterator( async for response in responses: delta_text = _get_output(response) + # If this is the backend's final response for the entire inference call, + # attempt to get token counts. For vLLM, these are sent with the final packet. + if response.final: + usage_payload = _get_usage_from_response(response) + ( response_delta, finish_reason, @@ -537,11 +520,26 @@ async def _streaming_chat_iterator( finish_reason=finish_reason, ) + # All intermediate chunks have usage=None. chunk = self._get_streaming_chat_response_chunk( - choice, request_id, created, model + choice, request_id, created, model, usage=None ) yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + # After the loop, send the final usage chunk if requested via stream_options. + if request.stream_options and request.stream_options.include_usage: + if usage_payload: + final_usage_chunk = CreateChatCompletionStreamResponse( + id=request_id, + choices=[], + created=created, + model=model, + system_fingerprint=None, + object=ObjectType.chat_completion_chunk, + usage=usage_payload, + ) + yield f"data: {final_usage_chunk.model_dump_json(exclude_unset=True)}\n\n" + yield "data: [DONE]\n\n" def _get_streaming_response_delta( @@ -724,9 +722,19 @@ def _verify_chat_tool_call_settings(self, request: CreateChatCompletionRequest): ) async def _streaming_completion_iterator( - self, request_id: str, created: int, model: str, responses: AsyncIterable + self, + request_id: str, + created: int, + request: CreateCompletionRequest, + responses: AsyncIterable, ) -> AsyncIterator[str]: + model = request.model + usage_payload: Optional[CompletionUsage] = None + async for response in responses: + if response.final: + usage_payload = _get_usage_from_response(response) + text = _get_output(response) choice = Choice( finish_reason=FinishReason.stop if response.final else None, @@ -741,10 +749,24 @@ async def _streaming_completion_iterator( object=ObjectType.text_completion, created=created, model=model, + usage=usage_payload, ) yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" + if request.stream_options and request.stream_options.include_usage: + if usage_payload: + final_usage_chunk = CreateCompletionResponse( + id=request_id, + choices=[], + system_fingerprint=None, + object=ObjectType.text_completion, + created=created, + model=model, + usage=usage_payload, + ) + yield f"data: {final_usage_chunk.model_dump_json(exclude_unset=True)}\n\n" + yield "data: [DONE]\n\n" def _validate_completion_request( diff --git a/python/openai/openai_frontend/engine/utils/triton.py b/python/openai/openai_frontend/engine/utils/triton.py index e323393ef9..1bd5f20a84 100644 --- a/python/openai/openai_frontend/engine/utils/triton.py +++ b/python/openai/openai_frontend/engine/utils/triton.py @@ -37,7 +37,9 @@ ChatCompletionNamedToolChoice, ChatCompletionToolChoiceOption1, CreateChatCompletionRequest, + CreateCompletionResponse, CreateCompletionRequest, + CompletionUsage, ) @@ -186,6 +188,43 @@ def _to_string(tensor: tritonserver.Tensor) -> str: return _construct_string_from_pointer(tensor.data_ptr + 4, tensor.size - 4) +def _get_usage_from_response( + response: 
tritonserver._api._response.InferenceResponse, +) -> Optional[CompletionUsage]: + """Extracts token usage statistics from a Triton inference response.""" + prompt_tokens = None + completion_tokens = None + + if ( + "num_input_tokens" in response.outputs + and "num_output_tokens" in response.outputs + ): + input_token_tensor = response.outputs["num_input_tokens"] + output_token_tensor = response.outputs["num_output_tokens"] + + if input_token_tensor.data_type == tritonserver.DataType.UINT32: + prompt_tokens_ptr = ctypes.cast( + input_token_tensor.data_ptr, ctypes.POINTER(ctypes.c_uint32) + ) + prompt_tokens = prompt_tokens_ptr[0] + + if output_token_tensor.data_type == tritonserver.DataType.UINT32: + completion_tokens_ptr = ctypes.cast( + output_token_tensor.data_ptr, ctypes.POINTER(ctypes.c_uint32) + ) + completion_tokens = completion_tokens_ptr[0] + + if prompt_tokens is not None and completion_tokens is not None: + total_tokens = prompt_tokens + completion_tokens + return CompletionUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ) + + return None + + # TODO: Use tritonserver.InferenceResponse when support is published def _get_output(response: tritonserver._api._response.InferenceResponse) -> str: if "text_output" in response.outputs: diff --git a/python/openai/openai_frontend/schemas/openai.py b/python/openai/openai_frontend/schemas/openai.py index a2438e8394..8bcc188fbf 100644 --- a/python/openai/openai_frontend/schemas/openai.py +++ b/python/openai/openai_frontend/schemas/openai.py @@ -133,6 +133,10 @@ class CreateCompletionRequest(BaseModel): False, description="Whether to stream back partial progress. If set, tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", ) + stream_options: Optional[StreamOptions] = Field( + None, + description="Options for streaming response. Only set this when you set `stream: true`.", + ) suffix: Optional[str] = Field( None, description="The suffix that comes after a completion of inserted text.\n\nThis parameter is only supported for `gpt-3.5-turbo-instruct`.\n", @@ -467,6 +471,13 @@ class ResponseFormat(BaseModel): ) +class StreamOptions(BaseModel): + include_usage: Optional[bool] = Field( + False, + description='If set, an additional chunk will be streamed before the `data: [DONE]` message. The `usage` field on this chunk shows the token usage statistics for the entire request, and the `choices` field will always be an empty array. All other chunks will also include a `usage` field, but with a null value.', + ) + + class FunctionCall3(Enum): none = "none" auto = "auto" @@ -526,14 +537,6 @@ class Logprobs2(BaseModel): ) -class ChatCompletionFinishReason(Enum): - stop = "stop" - length = "length" - tool_calls = "tool_calls" - content_filter = "content_filter" - function_call = "function_call" - - class ChatCompletionStreamingResponseChoice(BaseModel): delta: ChatCompletionStreamResponseDelta logprobs: Optional[Logprobs2] = Field( @@ -573,6 +576,7 @@ class CreateChatCompletionStreamResponse(BaseModel): object: Object4 = Field( ..., description="The object type, which is always `chat.completion.chunk`." 
) + usage: Optional[CompletionUsage] = None class CreateChatCompletionImageResponse(BaseModel): @@ -885,6 +889,10 @@ class CreateChatCompletionRequest(BaseModel): False, description="If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", ) + stream_options: Optional[StreamOptions] = Field( + None, + description="Options for streaming response. Only set this when you set `stream: true`.", + ) temperature: Optional[confloat(ge=0.0, le=2.0)] = Field( 0.7, description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n", diff --git a/python/openai/tests/test_completions.py b/python/openai/tests/test_completions.py index 1a411a0124..ab189c4e77 100644 --- a/python/openai/tests/test_completions.py +++ b/python/openai/tests/test_completions.py @@ -43,11 +43,19 @@ def test_completions_defaults(self, client, model: str, prompt: str): print("Response:", response.json()) assert response.status_code == 200 + response_json = response.json() # NOTE: Could be improved to look for certain quality of response, # or tested with dummy identity model. - assert response.json()["choices"][0]["text"].strip() - # "usage" currently not supported - assert not response.json()["usage"] + assert response_json["choices"][0]["text"].strip() + # "usage" is now supported + usage = response_json.get("usage") + assert usage is not None + assert isinstance(usage["prompt_tokens"], int) + assert isinstance(usage["completion_tokens"], int) + assert isinstance(usage["total_tokens"], int) + assert usage["prompt_tokens"] > 0 + assert usage["completion_tokens"] > 0 + assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] @pytest.mark.parametrize( "sampling_parameter, value", diff --git a/python/openai/tests/test_openai_client.py b/python/openai/tests/test_openai_client.py index 618e4052ed..81676de1b2 100644 --- a/python/openai/tests/test_openai_client.py +++ b/python/openai/tests/test_openai_client.py @@ -60,6 +60,15 @@ def test_openai_client_completion( assert completion.choices[0].text assert completion.choices[0].finish_reason == "stop" + usage = completion.usage + assert usage is not None + assert isinstance(usage.prompt_tokens, int) + assert isinstance(usage.completion_tokens, int) + assert isinstance(usage.total_tokens, int) + assert usage.prompt_tokens > 0 + assert usage.completion_tokens > 0 + assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens + def test_openai_client_chat_completion( self, client: openai.OpenAI, model: str, messages: List[dict] ): @@ -72,6 +81,15 @@ def test_openai_client_chat_completion( assert chat_completion.choices[0].message.content assert chat_completion.choices[0].finish_reason == "stop" + usage = chat_completion.usage + assert usage is not None + assert isinstance(usage.prompt_tokens, int) + assert isinstance(usage.completion_tokens, int) + assert isinstance(usage.total_tokens, int) + assert usage.prompt_tokens > 0 + assert usage.completion_tokens > 0 + assert usage.total_tokens == usage.prompt_tokens + 
usage.completion_tokens + @pytest.mark.parametrize("echo", [False, True]) def test_openai_client_completion_echo( self, client: openai.OpenAI, echo: bool, backend: str, model: str, prompt: str @@ -128,6 +146,15 @@ async def test_openai_client_completion( assert completion.choices[0].text assert completion.choices[0].finish_reason == "stop" + usage = completion.usage + assert usage is not None + assert isinstance(usage.prompt_tokens, int) + assert isinstance(usage.completion_tokens, int) + assert isinstance(usage.total_tokens, int) + assert usage.prompt_tokens > 0 + assert usage.completion_tokens > 0 + assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens + @pytest.mark.asyncio async def test_openai_client_chat_completion( self, client: openai.AsyncOpenAI, model: str, messages: List[dict] @@ -139,6 +166,16 @@ async def test_openai_client_chat_completion( assert chat_completion.choices[0].message.content assert chat_completion.choices[0].finish_reason == "stop" + + usage = chat_completion.usage + assert usage is not None + assert isinstance(usage.prompt_tokens, int) + assert isinstance(usage.completion_tokens, int) + assert isinstance(usage.total_tokens, int) + assert usage.prompt_tokens > 0 + assert usage.completion_tokens > 0 + assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens + print(f"Chat completion results: {chat_completion}") @pytest.mark.asyncio @@ -245,3 +282,87 @@ async def test_chat_streaming( assert len(chunks) > 1 streamed_output = "".join(chunks) assert streamed_output == output + + @pytest.mark.asyncio + async def test_chat_streaming_usage_option( + self, client: openai.AsyncOpenAI, model: str, messages: List[dict] + ): + # First, run with include_usage=False to establish a baseline chunk count. + stream_false = await client.chat.completions.create( + model=model, + messages=messages, + max_tokens=16, + stream=True, + stream_options={"include_usage": False}, + ) + chunks_false = [chunk async for chunk in stream_false] + for chunk in chunks_false: + assert chunk.usage is None, "Usage should be null when include_usage=False" + + # Now, run with include_usage=True. + stream_true = await client.chat.completions.create( + model=model, + messages=messages, + max_tokens=16, + stream=True, + stream_options={"include_usage": True}, + ) + chunks_true = [chunk async for chunk in stream_true] + + # Verify that we received exactly one extra chunk. + assert len(chunks_true) == len(chunks_false) + 1 + + # Verify the final chunk has usage data and empty choices. + final_chunk = chunks_true[-1] + assert final_chunk.usage is not None + assert len(final_chunk.choices) == 0 + usage = final_chunk.usage + assert isinstance(usage.prompt_tokens, int) and usage.prompt_tokens > 0 + assert isinstance(usage.completion_tokens, int) and usage.completion_tokens > 0 + assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens + + # Verify other chunks have no usage data. + for chunk in chunks_true[:-1]: + assert chunk.usage is None + + @pytest.mark.asyncio + async def test_completion_streaming_usage_option( + self, client: openai.AsyncOpenAI, model: str, prompt: str + ): + # First, run with include_usage=False to establish a baseline chunk count. 
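+        # Every chunk in this baseline stream is expected to report usage=None.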
+ stream_false = await client.completions.create( + model=model, + prompt=prompt, + max_tokens=10, + stream=True, + stream_options={"include_usage": False}, + ) + chunks_false = [chunk async for chunk in stream_false] + for chunk in chunks_false: + assert chunk.usage is None + + # Now, run with include_usage=True. + stream_true = await client.completions.create( + model=model, + prompt=prompt, + max_tokens=10, + stream=True, + stream_options={"include_usage": True}, + ) + chunks_true = [chunk async for chunk in stream_true] + + # Verify that we received exactly one extra chunk. + assert len(chunks_true) == len(chunks_false) + 1 + + # Verify the final chunk has usage data and empty choices. + final_chunk = chunks_true[-1] + assert final_chunk.usage is not None + assert len(final_chunk.choices) == 0 + usage = final_chunk.usage + assert isinstance(usage.prompt_tokens, int) and usage.prompt_tokens > 0 + assert isinstance(usage.completion_tokens, int) and usage.completion_tokens > 0 + assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens + + # Verify other chunks have no usage data. + for chunk in chunks_true[:-1]: + assert chunk.usage is None From b396e4527a0fd7483066a89583f1472b8a825afb Mon Sep 17 00:00:00 2001 From: Sai Kiran Polisetty Date: Fri, 20 Jun 2025 15:37:03 +0530 Subject: [PATCH 03/14] Update --- .../openai_frontend/engine/triton_engine.py | 11 +++--- .../openai_frontend/engine/utils/triton.py | 10 ++++-- python/openai/tests/test_chat_completions.py | 23 +++++++++--- python/openai/tests/test_completions.py | 36 +++++++++++-------- 4 files changed, 55 insertions(+), 25 deletions(-) diff --git a/python/openai/openai_frontend/engine/triton_engine.py b/python/openai/openai_frontend/engine/triton_engine.py index c1c6f48ee8..2ba1062618 100644 --- a/python/openai/openai_frontend/engine/triton_engine.py +++ b/python/openai/openai_frontend/engine/triton_engine.py @@ -228,7 +228,7 @@ async def chat( backend=metadata.backend, ) - usage = _get_usage_from_response(response) + usage = _get_usage_from_response(response, metadata.backend) return CreateChatCompletionResponse( id=request_id, @@ -317,7 +317,7 @@ async def completion( created = int(time.time()) if request.stream: return self._streaming_completion_iterator( - request_id, created, request, responses + request_id, created, request, responses, metadata.backend ) # Response validation with decoupled models in mind @@ -326,7 +326,7 @@ async def completion( response = responses[0] text = _get_output(response) - usage = _get_usage_from_response(response) + usage = _get_usage_from_response(response, metadata.backend) choice = Choice( finish_reason=FinishReason.stop, @@ -486,7 +486,7 @@ async def _streaming_chat_iterator( # If this is the backend's final response for the entire inference call, # attempt to get token counts. For vLLM, these are sent with the final packet. 
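+            # Only the vLLM backend currently emits the "num_input_tokens"/"num_output_tokens"
+            # tensors, so usage_payload stays None for other backends.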
if response.final: - usage_payload = _get_usage_from_response(response) + usage_payload = _get_usage_from_response(response, backend) ( response_delta, @@ -727,13 +727,14 @@ async def _streaming_completion_iterator( created: int, request: CreateCompletionRequest, responses: AsyncIterable, + backend: str, ) -> AsyncIterator[str]: model = request.model usage_payload: Optional[CompletionUsage] = None async for response in responses: if response.final: - usage_payload = _get_usage_from_response(response) + usage_payload = _get_usage_from_response(response, backend) text = _get_output(response) choice = Choice( diff --git a/python/openai/openai_frontend/engine/utils/triton.py b/python/openai/openai_frontend/engine/utils/triton.py index 1bd5f20a84..631ac7b22b 100644 --- a/python/openai/openai_frontend/engine/utils/triton.py +++ b/python/openai/openai_frontend/engine/utils/triton.py @@ -37,7 +37,6 @@ ChatCompletionNamedToolChoice, ChatCompletionToolChoiceOption1, CreateChatCompletionRequest, - CreateCompletionResponse, CreateCompletionRequest, CompletionUsage, ) @@ -190,8 +189,15 @@ def _to_string(tensor: tritonserver.Tensor) -> str: def _get_usage_from_response( response: tritonserver._api._response.InferenceResponse, + backend: str, ) -> Optional[CompletionUsage]: - """Extracts token usage statistics from a Triton inference response.""" + """ + Extracts token usage statistics from a Triton inference response. + Only vLLM backend currently provides these output tensors. + """ + if backend != "vllm": + return None + prompt_tokens = None completion_tokens = None diff --git a/python/openai/tests/test_chat_completions.py b/python/openai/tests/test_chat_completions.py index 532d898788..ed6daf9e92 100644 --- a/python/openai/tests/test_chat_completions.py +++ b/python/openai/tests/test_chat_completions.py @@ -52,7 +52,7 @@ def test_chat_completions_defaults(self, client, model: str, messages: List[dict assert message["content"].strip() assert message["role"] == "assistant" # "usage" currently not supported - assert not response.json()["usage"] + assert response.json()["usage"] def test_chat_completions_system_prompt(self, client, model: str): # NOTE: Currently just sanity check that there are no issues when a @@ -497,9 +497,24 @@ def test_request_logprobs(self): def test_request_logit_bias(self): pass - @pytest.mark.skip(reason="Not Implemented Yet") - def test_usage_response(self): - pass + def test_usage_response(self, client, model: str, messages: List[dict], backend: str): + if backend != "vllm": + pytest.skip("Usage reporting is currently only supported for vLLM backend") + + response = client.post( + "/v1/chat/completions", + json={"model": model, "messages": messages}, + ) + + assert response.status_code == 200 + usage = response.json().get("usage") + assert usage is not None + assert isinstance(usage["prompt_tokens"], int) + assert isinstance(usage["completion_tokens"], int) + assert isinstance(usage["total_tokens"], int) + assert usage["prompt_tokens"] > 0 + assert usage["completion_tokens"] > 0 + assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] # For tests that won't use the same pytest fixture for server startup across diff --git a/python/openai/tests/test_completions.py b/python/openai/tests/test_completions.py index ab189c4e77..3043b02ce8 100644 --- a/python/openai/tests/test_completions.py +++ b/python/openai/tests/test_completions.py @@ -43,19 +43,12 @@ def test_completions_defaults(self, client, model: str, prompt: str): print("Response:", 
response.json()) assert response.status_code == 200 - response_json = response.json() # NOTE: Could be improved to look for certain quality of response, # or tested with dummy identity model. - assert response_json["choices"][0]["text"].strip() - # "usage" is now supported - usage = response_json.get("usage") - assert usage is not None - assert isinstance(usage["prompt_tokens"], int) - assert isinstance(usage["completion_tokens"], int) - assert isinstance(usage["total_tokens"], int) - assert usage["prompt_tokens"] > 0 - assert usage["completion_tokens"] > 0 - assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] + assert response.json()["choices"][0]["text"].strip() + # "usage" is now validated in its own test. + # Depending on backend, it may or may not be present. + assert "usage" in response.json() @pytest.mark.parametrize( "sampling_parameter, value", @@ -375,6 +368,21 @@ def test_lora(self): def test_multi_lora(self): pass - @pytest.mark.skip(reason="Not Implemented Yet") - def test_usage_response(self): - pass + def test_usage_response(self, client, model: str, prompt: str, backend: str): + if backend != "vllm": + pytest.skip("Usage reporting is currently only supported for vLLM backend") + + response = client.post( + "/v1/completions", + json={"model": model, "prompt": prompt}, + ) + + assert response.status_code == 200 + usage = response.json().get("usage") + assert usage is not None + assert isinstance(usage["prompt_tokens"], int) + assert isinstance(usage["completion_tokens"], int) + assert isinstance(usage["total_tokens"], int) + assert usage["prompt_tokens"] > 0 + assert usage["completion_tokens"] > 0 + assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] From 657578aab6ade815b649c499a85b279b92a085d2 Mon Sep 17 00:00:00 2001 From: Sai Kiran Polisetty Date: Fri, 20 Jun 2025 16:48:15 +0530 Subject: [PATCH 04/14] Update --- python/openai/tests/test_chat_completions.py | 23 +++++++++++++++----- python/openai/tests/test_completions.py | 15 ++++++++----- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/python/openai/tests/test_chat_completions.py b/python/openai/tests/test_chat_completions.py index ed6daf9e92..c1c01c701b 100644 --- a/python/openai/tests/test_chat_completions.py +++ b/python/openai/tests/test_chat_completions.py @@ -41,18 +41,25 @@ class TestChatCompletions: def client(self, fastapi_client_class_scope): yield fastapi_client_class_scope - def test_chat_completions_defaults(self, client, model: str, messages: List[dict]): + def test_chat_completions_defaults( + self, client, model: str, messages: List[dict], backend: str + ): response = client.post( "/v1/chat/completions", json={"model": model, "messages": messages}, ) assert response.status_code == 200 - message = response.json()["choices"][0]["message"] + response_json = response.json() + message = response_json["choices"][0]["message"] assert message["content"].strip() assert message["role"] == "assistant" - # "usage" currently not supported - assert response.json()["usage"] + + usage = response_json.get("usage") + if backend == "vllm": + assert usage is not None + else: + assert usage is None def test_chat_completions_system_prompt(self, client, model: str): # NOTE: Currently just sanity check that there are no issues when a @@ -497,7 +504,9 @@ def test_request_logprobs(self): def test_request_logit_bias(self): pass - def test_usage_response(self, client, model: str, messages: List[dict], backend: str): + def test_usage_response( + self, 
client, model: str, messages: List[dict], backend: str + ): if backend != "vllm": pytest.skip("Usage reporting is currently only supported for vLLM backend") @@ -514,7 +523,9 @@ def test_usage_response(self, client, model: str, messages: List[dict], backend: assert isinstance(usage["total_tokens"], int) assert usage["prompt_tokens"] > 0 assert usage["completion_tokens"] > 0 - assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] + assert ( + usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] + ) # For tests that won't use the same pytest fixture for server startup across diff --git a/python/openai/tests/test_completions.py b/python/openai/tests/test_completions.py index 3043b02ce8..e1e32097a5 100644 --- a/python/openai/tests/test_completions.py +++ b/python/openai/tests/test_completions.py @@ -35,7 +35,7 @@ class TestCompletions: def client(self, fastapi_client_class_scope): yield fastapi_client_class_scope - def test_completions_defaults(self, client, model: str, prompt: str): + def test_completions_defaults(self, client, model: str, prompt: str, backend: str): response = client.post( "/v1/completions", json={"model": model, "prompt": prompt}, @@ -46,9 +46,12 @@ def test_completions_defaults(self, client, model: str, prompt: str): # NOTE: Could be improved to look for certain quality of response, # or tested with dummy identity model. assert response.json()["choices"][0]["text"].strip() - # "usage" is now validated in its own test. - # Depending on backend, it may or may not be present. - assert "usage" in response.json() + + usage = response.json().get("usage") + if backend == "vllm": + assert usage is not None + else: + assert usage is None @pytest.mark.parametrize( "sampling_parameter, value", @@ -385,4 +388,6 @@ def test_usage_response(self, client, model: str, prompt: str, backend: str): assert isinstance(usage["total_tokens"], int) assert usage["prompt_tokens"] > 0 assert usage["completion_tokens"] > 0 - assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] + assert ( + usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] + ) From 25cb0a7cd0822f7323e4836b749af1224eb46201 Mon Sep 17 00:00:00 2001 From: Sai Kiran Polisetty Date: Fri, 20 Jun 2025 17:17:59 +0530 Subject: [PATCH 05/14] Update --- python/openai/tests/test_chat_completions.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/openai/tests/test_chat_completions.py b/python/openai/tests/test_chat_completions.py index c1c01c701b..b994f66c7f 100644 --- a/python/openai/tests/test_chat_completions.py +++ b/python/openai/tests/test_chat_completions.py @@ -50,12 +50,11 @@ def test_chat_completions_defaults( ) assert response.status_code == 200 - response_json = response.json() - message = response_json["choices"][0]["message"] + message = response.json()["choices"][0]["message"] assert message["content"].strip() assert message["role"] == "assistant" - usage = response_json.get("usage") + usage = response.json().get("usage") if backend == "vllm": assert usage is not None else: From cd669d9767fa4605a7e7b0c5c95803d468a0301e Mon Sep 17 00:00:00 2001 From: Sai Kiran Polisetty Date: Fri, 20 Jun 2025 18:25:26 +0530 Subject: [PATCH 06/14] Update tests --- .../openai_frontend/engine/triton_engine.py | 5 +- .../openai_frontend/engine/utils/triton.py | 3 +- python/openai/tests/test_openai_client.py | 130 +++++++++++++++--- 3 files changed, 116 insertions(+), 22 deletions(-) diff --git 
a/python/openai/openai_frontend/engine/triton_engine.py b/python/openai/openai_frontend/engine/triton_engine.py index 2ba1062618..6ea62ffe8a 100644 --- a/python/openai/openai_frontend/engine/triton_engine.py +++ b/python/openai/openai_frontend/engine/triton_engine.py @@ -30,7 +30,6 @@ import json import time import uuid -import ctypes from dataclasses import dataclass from typing import ( Any, @@ -473,7 +472,7 @@ async def _streaming_chat_iterator( ) previous_text = "" - usage_payload: Optional[CompletionUsage] = None + usage_payload = None chunk = self._get_first_streaming_chat_response( request_id, created, model, role @@ -730,7 +729,7 @@ async def _streaming_completion_iterator( backend: str, ) -> AsyncIterator[str]: model = request.model - usage_payload: Optional[CompletionUsage] = None + usage_payload = None async for response in responses: if response.final: diff --git a/python/openai/openai_frontend/engine/utils/triton.py b/python/openai/openai_frontend/engine/utils/triton.py index 631ac7b22b..19bd2c39fb 100644 --- a/python/openai/openai_frontend/engine/utils/triton.py +++ b/python/openai/openai_frontend/engine/utils/triton.py @@ -193,8 +193,9 @@ def _get_usage_from_response( ) -> Optional[CompletionUsage]: """ Extracts token usage statistics from a Triton inference response. - Only vLLM backend currently provides these output tensors. """ + # TODO: Remove this check once TRT-LLM backend supports both "num_input_tokens" + # and "num_output_tokens", and update the test cases accordingly. if backend != "vllm": return None diff --git a/python/openai/tests/test_openai_client.py b/python/openai/tests/test_openai_client.py index 81676de1b2..b36d07365c 100644 --- a/python/openai/tests/test_openai_client.py +++ b/python/openai/tests/test_openai_client.py @@ -285,84 +285,178 @@ async def test_chat_streaming( @pytest.mark.asyncio async def test_chat_streaming_usage_option( - self, client: openai.AsyncOpenAI, model: str, messages: List[dict] + self, client: openai.AsyncOpenAI, model: str, messages: List[dict], backend: str ): - # First, run with include_usage=False to establish a baseline chunk count. + if backend != "vllm": + pytest.skip("Usage reporting is currently only supported for vLLM backend") + + # Get usage and content from a non-streaming call stream_false = await client.chat.completions.create( model=model, messages=messages, max_tokens=16, + temperature=0, + seed=0, + stream=False, + ) + usage_stream_false = stream_false.usage + stream_false_output = stream_false.choices[0].message.content + assert usage_stream_false is not None + assert stream_false_output is not None + + # First, run with include_usage=False. + stream_options_false = await client.chat.completions.create( + model=model, + messages=messages, + max_tokens=16, + temperature=0, + seed=0, stream=True, stream_options={"include_usage": False}, ) - chunks_false = [chunk async for chunk in stream_false] + chunks_false = [chunk async for chunk in stream_options_false] for chunk in chunks_false: assert chunk.usage is None, "Usage should be null when include_usage=False" + stream_options_false_output = "".join( + c.choices[0].delta.content + for c in chunks_false + if c.choices and c.choices[0].delta.content + ) # Now, run with include_usage=True. 
- stream_true = await client.chat.completions.create( + stream_options_true = await client.chat.completions.create( model=model, messages=messages, max_tokens=16, + temperature=0, + seed=0, stream=True, stream_options={"include_usage": True}, ) - chunks_true = [chunk async for chunk in stream_true] + chunks_true = [chunk async for chunk in stream_options_true] + content_chunks = [c for c in chunks_true if c.usage is None] # Verify that we received exactly one extra chunk. assert len(chunks_true) == len(chunks_false) + 1 + # Verify content is consistent + stream_options_true_output = "".join( + c.choices[0].delta.content + for c in content_chunks + if c.choices and c.choices[0].delta.content + ) + assert stream_options_true_output == stream_false_output + assert stream_options_true_output == stream_options_false_output + # Verify the final chunk has usage data and empty choices. final_chunk = chunks_true[-1] assert final_chunk.usage is not None assert len(final_chunk.choices) == 0 - usage = final_chunk.usage - assert isinstance(usage.prompt_tokens, int) and usage.prompt_tokens > 0 - assert isinstance(usage.completion_tokens, int) and usage.completion_tokens > 0 - assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens + usage_stream_options_true = final_chunk.usage + assert ( + isinstance(usage_stream_options_true.prompt_tokens, int) + and usage_stream_options_true.prompt_tokens > 0 + ) + assert ( + isinstance(usage_stream_options_true.completion_tokens, int) + and usage_stream_options_true.completion_tokens > 0 + ) + assert ( + usage_stream_options_true.total_tokens + == usage_stream_options_true.prompt_tokens + + usage_stream_options_true.completion_tokens + ) # Verify other chunks have no usage data. for chunk in chunks_true[:-1]: assert chunk.usage is None + # Assert usage is consistent between streaming and non-streaming calls + assert usage_stream_false.model_dump() == usage_stream_options_true.model_dump() + @pytest.mark.asyncio async def test_completion_streaming_usage_option( - self, client: openai.AsyncOpenAI, model: str, prompt: str + self, client: openai.AsyncOpenAI, model: str, prompt: str, backend: str ): - # First, run with include_usage=False to establish a baseline chunk count. + if backend != "vllm": + pytest.skip("Usage reporting is currently only supported for vLLM backend") + + # Get baseline usage and content from a non-streaming call stream_false = await client.completions.create( model=model, prompt=prompt, max_tokens=10, + temperature=0.0, + stream=False, + seed=0, + ) + usage_stream_false = stream_false.usage + stream_false_output = stream_false.choices[0].text + assert usage_stream_false is not None + assert stream_false_output is not None + + # First, run with include_usage=False to establish a baseline chunk count and content. + stream_options_false = await client.completions.create( + model=model, + prompt=prompt, + max_tokens=10, + temperature=0.0, + seed=0, stream=True, stream_options={"include_usage": False}, ) - chunks_false = [chunk async for chunk in stream_false] + chunks_false = [chunk async for chunk in stream_options_false] for chunk in chunks_false: assert chunk.usage is None + stream_options_false_output = "".join( + c.choices[0].text for c in chunks_false if c.choices and c.choices[0].text + ) # Now, run with include_usage=True. 
- stream_true = await client.completions.create( + stream_options_true = await client.completions.create( model=model, prompt=prompt, max_tokens=10, + temperature=0.0, stream=True, + seed=0, stream_options={"include_usage": True}, ) - chunks_true = [chunk async for chunk in stream_true] + chunks_true = [chunk async for chunk in stream_options_true] + content_chunks = [c for c in chunks_true if c.usage is None] # Verify that we received exactly one extra chunk. assert len(chunks_true) == len(chunks_false) + 1 + # Verify content is consistent + stream_options_true_output = "".join( + c.choices[0].text for c in content_chunks if c.choices and c.choices[0].text + ) + assert stream_options_true_output == stream_false_output + assert stream_options_true_output == stream_options_false_output + # Verify the final chunk has usage data and empty choices. final_chunk = chunks_true[-1] assert final_chunk.usage is not None assert len(final_chunk.choices) == 0 - usage = final_chunk.usage - assert isinstance(usage.prompt_tokens, int) and usage.prompt_tokens > 0 - assert isinstance(usage.completion_tokens, int) and usage.completion_tokens > 0 - assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens + usage_stream_options_true = final_chunk.usage + assert ( + isinstance(usage_stream_options_true.prompt_tokens, int) + and usage_stream_options_true.prompt_tokens > 0 + ) + assert ( + isinstance(usage_stream_options_true.completion_tokens, int) + and usage_stream_options_true.completion_tokens > 0 + ) + assert ( + usage_stream_options_true.total_tokens + == usage_stream_options_true.prompt_tokens + + usage_stream_options_true.completion_tokens + ) # Verify other chunks have no usage data. for chunk in chunks_true[:-1]: assert chunk.usage is None + + # Assert usage is consistent between streaming and non-streaming calls + assert usage_stream_false.model_dump() == usage_stream_options_true.model_dump() From a31c3e289b5fd8a477575391d613c8ffc30e34f4 Mon Sep 17 00:00:00 2001 From: Sai Kiran Polisetty Date: Fri, 20 Jun 2025 18:49:42 +0530 Subject: [PATCH 07/14] Update --- .../openai/openai_frontend/schemas/openai.py | 6 +-- python/openai/tests/test_chat_completions.py | 2 +- python/openai/tests/test_completions.py | 2 +- python/openai/tests/test_openai_client.py | 53 +++++++++++-------- 4 files changed, 36 insertions(+), 27 deletions(-) diff --git a/python/openai/openai_frontend/schemas/openai.py b/python/openai/openai_frontend/schemas/openai.py index 8bcc188fbf..a0baf24e8f 100644 --- a/python/openai/openai_frontend/schemas/openai.py +++ b/python/openai/openai_frontend/schemas/openai.py @@ -135,7 +135,7 @@ class CreateCompletionRequest(BaseModel): ) stream_options: Optional[StreamOptions] = Field( None, - description="Options for streaming response. Only set this when you set `stream: true`.", + description="Options for streaming responses. Only use when `stream` is set to `true`.", ) suffix: Optional[str] = Field( None, @@ -474,7 +474,7 @@ class ResponseFormat(BaseModel): class StreamOptions(BaseModel): include_usage: Optional[bool] = Field( False, - description='If set, an additional chunk will be streamed before the `data: [DONE]` message. The `usage` field on this chunk shows the token usage statistics for the entire request, and the `choices` field will always be an empty array. All other chunks will also include a `usage` field, but with a null value.', + description="If enabled, an additional chunk is sent before the `data: [DONE]` message. 
That chunk’s `usage` field reports the total token usage for the request and its `choices` array is always empty. All other chunks include a `usage` field with a null value.", ) @@ -891,7 +891,7 @@ class CreateChatCompletionRequest(BaseModel): ) stream_options: Optional[StreamOptions] = Field( None, - description="Options for streaming response. Only set this when you set `stream: true`.", + description="Options for streaming responses. Only use when `stream` is set to `true`.", ) temperature: Optional[confloat(ge=0.0, le=2.0)] = Field( 0.7, diff --git a/python/openai/tests/test_chat_completions.py b/python/openai/tests/test_chat_completions.py index b994f66c7f..8b904617af 100644 --- a/python/openai/tests/test_chat_completions.py +++ b/python/openai/tests/test_chat_completions.py @@ -507,7 +507,7 @@ def test_usage_response( self, client, model: str, messages: List[dict], backend: str ): if backend != "vllm": - pytest.skip("Usage reporting is currently only supported for vLLM backend") + pytest.skip("Usage reporting is currently available only for the vLLM backend.") response = client.post( "/v1/chat/completions", diff --git a/python/openai/tests/test_completions.py b/python/openai/tests/test_completions.py index e1e32097a5..3871edb2d5 100644 --- a/python/openai/tests/test_completions.py +++ b/python/openai/tests/test_completions.py @@ -373,7 +373,7 @@ def test_multi_lora(self): def test_usage_response(self, client, model: str, prompt: str, backend: str): if backend != "vllm": - pytest.skip("Usage reporting is currently only supported for vLLM backend") + pytest.skip("Usage reporting is currently available only for the vLLM backend.") response = client.post( "/v1/completions", diff --git a/python/openai/tests/test_openai_client.py b/python/openai/tests/test_openai_client.py index b36d07365c..ae82a5a93e 100644 --- a/python/openai/tests/test_openai_client.py +++ b/python/openai/tests/test_openai_client.py @@ -272,6 +272,7 @@ async def test_chat_streaming( chunks.append(delta.content) if chunk.choices[0].finish_reason is not None: finish_reason_count += 1 + assert chunk.usage is None # finish reason should only return in last block assert finish_reason_count == 1 @@ -288,15 +289,19 @@ async def test_chat_streaming_usage_option( self, client: openai.AsyncOpenAI, model: str, messages: List[dict], backend: str ): if backend != "vllm": - pytest.skip("Usage reporting is currently only supported for vLLM backend") + pytest.skip("Usage reporting is currently available only for the vLLM backend.") + seed = 0 + temperature = 0.0 + max_tokens = 16 + # Get usage and content from a non-streaming call stream_false = await client.chat.completions.create( model=model, messages=messages, - max_tokens=16, - temperature=0, - seed=0, + max_tokens=max_tokens, + temperature=temperature, + seed=seed, stream=False, ) usage_stream_false = stream_false.usage @@ -308,9 +313,9 @@ async def test_chat_streaming_usage_option( stream_options_false = await client.chat.completions.create( model=model, messages=messages, - max_tokens=16, - temperature=0, - seed=0, + max_tokens=max_tokens, + temperature=temperature, + seed=seed, stream=True, stream_options={"include_usage": False}, ) @@ -327,9 +332,9 @@ async def test_chat_streaming_usage_option( stream_options_true = await client.chat.completions.create( model=model, messages=messages, - max_tokens=16, - temperature=0, - seed=0, + max_tokens=max_tokens, + temperature=temperature, + seed=seed, stream=True, stream_options={"include_usage": True}, ) @@ -379,29 +384,33 @@ async 
def test_completion_streaming_usage_option( self, client: openai.AsyncOpenAI, model: str, prompt: str, backend: str ): if backend != "vllm": - pytest.skip("Usage reporting is currently only supported for vLLM backend") + pytest.skip("Usage reporting is currently available only for the vLLM backend.") - # Get baseline usage and content from a non-streaming call + seed = 0 + temperature = 0.0 + max_tokens = 16 + + # Get usage and content from a non-streaming call stream_false = await client.completions.create( model=model, prompt=prompt, - max_tokens=10, - temperature=0.0, + max_tokens=max_tokens, + temperature=temperature, stream=False, - seed=0, + seed=seed, ) usage_stream_false = stream_false.usage stream_false_output = stream_false.choices[0].text assert usage_stream_false is not None assert stream_false_output is not None - # First, run with include_usage=False to establish a baseline chunk count and content. + # First, run with include_usage=False. stream_options_false = await client.completions.create( model=model, prompt=prompt, - max_tokens=10, - temperature=0.0, - seed=0, + max_tokens=max_tokens, + temperature=temperature, + seed=seed, stream=True, stream_options={"include_usage": False}, ) @@ -416,10 +425,10 @@ async def test_completion_streaming_usage_option( stream_options_true = await client.completions.create( model=model, prompt=prompt, - max_tokens=10, - temperature=0.0, + max_tokens=max_tokens, + temperature=temperature, stream=True, - seed=0, + seed=seed, stream_options={"include_usage": True}, ) chunks_true = [chunk async for chunk in stream_options_true] From 05338ee0d2b04d02e597c1e2f20178ba38d555a0 Mon Sep 17 00:00:00 2001 From: Sai Kiran Polisetty Date: Tue, 24 Jun 2025 18:26:13 +0530 Subject: [PATCH 08/14] Update --- .../openai_frontend/engine/triton_engine.py | 39 ++++++++++++------- .../openai_frontend/engine/utils/triton.py | 38 +++++++++++++++++- python/openai/tests/test_lora.py | 12 ++++++ 3 files changed, 74 insertions(+), 15 deletions(-) diff --git a/python/openai/openai_frontend/engine/triton_engine.py b/python/openai/openai_frontend/engine/triton_engine.py index 6ea62ffe8a..28740e98f2 100644 --- a/python/openai/openai_frontend/engine/triton_engine.py +++ b/python/openai/openai_frontend/engine/triton_engine.py @@ -48,6 +48,7 @@ from engine.utils.tokenizer import get_tokenizer from engine.utils.tool_call_parsers import ToolCallParser, ToolParserManager from engine.utils.triton import ( + _StreamingUsageAccumulator, _create_trtllm_inference_request, _create_vllm_inference_request, _get_output, @@ -472,7 +473,13 @@ async def _streaming_chat_iterator( ) previous_text = "" - usage_payload = None + include_usage = ( + # TODO: Remove backend check condition once tensorrt-llm backend also supports usage + backend == "vllm" + and request.stream_options + and request.stream_options.include_usage + ) + usage_accumulator = _StreamingUsageAccumulator(backend) chunk = self._get_first_streaming_chat_response( request_id, created, model, role @@ -481,11 +488,8 @@ async def _streaming_chat_iterator( async for response in responses: delta_text = _get_output(response) - - # If this is the backend's final response for the entire inference call, - # attempt to get token counts. For vLLM, these are sent with the final packet. 
- if response.final: - usage_payload = _get_usage_from_response(response, backend) + if include_usage: + usage_accumulator.update(response) ( response_delta, @@ -525,8 +529,9 @@ async def _streaming_chat_iterator( ) yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" - # After the loop, send the final usage chunk if requested via stream_options. - if request.stream_options and request.stream_options.include_usage: + # Send the final usage chunk if requested via stream_options. + if include_usage: + usage_payload = usage_accumulator.get_final_usage() if usage_payload: final_usage_chunk = CreateChatCompletionStreamResponse( id=request_id, @@ -729,11 +734,17 @@ async def _streaming_completion_iterator( backend: str, ) -> AsyncIterator[str]: model = request.model - usage_payload = None + include_usage = ( + # TODO: Remove backend check condition once tensorrt-llm backend also supports usage + backend == "vllm" + and request.stream_options + and request.stream_options.include_usage + ) + usage_accumulator = _StreamingUsageAccumulator(backend) async for response in responses: - if response.final: - usage_payload = _get_usage_from_response(response, backend) + if include_usage: + usage_accumulator.update(response) text = _get_output(response) choice = Choice( @@ -749,12 +760,14 @@ async def _streaming_completion_iterator( object=ObjectType.text_completion, created=created, model=model, - usage=usage_payload, + usage=None, ) yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n" - if request.stream_options and request.stream_options.include_usage: + # Send the final usage chunk if requested via stream_options. + if include_usage: + usage_payload = usage_accumulator.get_final_usage() if usage_payload: final_usage_chunk = CreateCompletionResponse( id=request_id, diff --git a/python/openai/openai_frontend/engine/utils/triton.py b/python/openai/openai_frontend/engine/utils/triton.py index 19bd2c39fb..7ecffdec08 100644 --- a/python/openai/openai_frontend/engine/utils/triton.py +++ b/python/openai/openai_frontend/engine/utils/triton.py @@ -27,7 +27,7 @@ import json import os import re -from dataclasses import asdict +from dataclasses import asdict, dataclass, field from typing import Iterable, List, Optional, Union import numpy as np @@ -187,6 +187,40 @@ def _to_string(tensor: tritonserver.Tensor) -> str: return _construct_string_from_pointer(tensor.data_ptr + 4, tensor.size - 4) +@dataclass +class _StreamingUsageAccumulator: + """Helper class to accumulate token usage from a streaming response.""" + + backend: str + prompt_tokens: int = 0 + completion_tokens: int = 0 + _prompt_tokens_set: bool = field(init=False, default=False) + + def update(self, response: tritonserver.InferenceResponse): + """Extracts usage from a response and updates the token counts.""" + usage = _get_usage_from_response(response, self.backend) + if usage: + # The prompt_tokens is received with every chunk but should only be set once. + if not self._prompt_tokens_set: + self.prompt_tokens = usage.prompt_tokens + self._prompt_tokens_set = True + self.completion_tokens += usage.completion_tokens + + def get_final_usage(self) -> Optional[CompletionUsage]: + """ + Returns the final populated CompletionUsage object if any tokens were tracked. + """ + # If _prompt_tokens_set is True, it means we have received and processed + # at least one valid usage payload. 
+ if self._prompt_tokens_set: + return CompletionUsage( + prompt_tokens=self.prompt_tokens, + completion_tokens=self.completion_tokens, + total_tokens=self.prompt_tokens + self.completion_tokens, + ) + return None + + def _get_usage_from_response( response: tritonserver._api._response.InferenceResponse, backend: str, @@ -195,7 +229,7 @@ def _get_usage_from_response( Extracts token usage statistics from a Triton inference response. """ # TODO: Remove this check once TRT-LLM backend supports both "num_input_tokens" - # and "num_output_tokens", and update the test cases accordingly. + # and "num_output_tokens", and also update the test cases accordingly. if backend != "vllm": return None diff --git a/python/openai/tests/test_lora.py b/python/openai/tests/test_lora.py index 2c89789bbc..4a5ff10361 100644 --- a/python/openai/tests/test_lora.py +++ b/python/openai/tests/test_lora.py @@ -131,6 +131,18 @@ def _create_model_repository_mock_llm(self): name: "exclude_input_in_output" data_type: TYPE_BOOL dims: [ 1 ] + }, + { + name: "return_num_input_tokens" + data_type: TYPE_BOOL + dims: [1] + optional: true + }, + { + name: "return_num_output_tokens" + data_type: TYPE_BOOL + dims: [1] + optional: true } ] output [ From 75412f2be993ae8e5e7a1cbfbe10345e327dbb6f Mon Sep 17 00:00:00 2001 From: Sai Kiran Polisetty Date: Tue, 24 Jun 2025 23:12:43 +0530 Subject: [PATCH 09/14] Update tests --- .../openai_frontend/engine/triton_engine.py | 1 - python/openai/tests/test_openai_client.py | 88 +++++++++++-------- 2 files changed, 52 insertions(+), 37 deletions(-) diff --git a/python/openai/openai_frontend/engine/triton_engine.py b/python/openai/openai_frontend/engine/triton_engine.py index 2e26487884..5cdacc5719 100644 --- a/python/openai/openai_frontend/engine/triton_engine.py +++ b/python/openai/openai_frontend/engine/triton_engine.py @@ -531,7 +531,6 @@ async def _streaming_chat_iterator( finish_reason=finish_reason, ) - # All intermediate chunks have usage=None. 
chunk = self._get_streaming_chat_response_chunk( choice, request_id, created, model, usage=None ) diff --git a/python/openai/tests/test_openai_client.py b/python/openai/tests/test_openai_client.py index 47eb1ce63d..09be380e63 100644 --- a/python/openai/tests/test_openai_client.py +++ b/python/openai/tests/test_openai_client.py @@ -49,7 +49,7 @@ def test_openai_client_models(self, client: openai.OpenAI, backend: str): raise Exception(f"Unexpected backend {backend=}") def test_openai_client_completion( - self, client: openai.OpenAI, model: str, prompt: str + self, client: openai.OpenAI, model: str, prompt: str, backend: str ): completion = client.completions.create( prompt=prompt, @@ -61,16 +61,19 @@ def test_openai_client_completion( assert completion.choices[0].finish_reason == "stop" usage = completion.usage - assert usage is not None - assert isinstance(usage.prompt_tokens, int) - assert isinstance(usage.completion_tokens, int) - assert isinstance(usage.total_tokens, int) - assert usage.prompt_tokens > 0 - assert usage.completion_tokens > 0 - assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens + if backend == "vllm": + assert usage is not None + assert isinstance(usage.prompt_tokens, int) + assert isinstance(usage.completion_tokens, int) + assert isinstance(usage.total_tokens, int) + assert usage.prompt_tokens > 0 + assert usage.completion_tokens > 0 + assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens + else: + assert usage is None def test_openai_client_chat_completion( - self, client: openai.OpenAI, model: str, messages: List[dict] + self, client: openai.OpenAI, model: str, messages: List[dict], backend: str ): chat_completion = client.chat.completions.create( messages=messages, @@ -82,13 +85,16 @@ def test_openai_client_chat_completion( assert chat_completion.choices[0].finish_reason == "stop" usage = chat_completion.usage - assert usage is not None - assert isinstance(usage.prompt_tokens, int) - assert isinstance(usage.completion_tokens, int) - assert isinstance(usage.total_tokens, int) - assert usage.prompt_tokens > 0 - assert usage.completion_tokens > 0 - assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens + if backend == "vllm": + assert usage is not None + assert isinstance(usage.prompt_tokens, int) + assert isinstance(usage.completion_tokens, int) + assert isinstance(usage.total_tokens, int) + assert usage.prompt_tokens > 0 + assert usage.completion_tokens > 0 + assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens + else: + assert usage is None @pytest.mark.parametrize("echo", [False, True]) def test_openai_client_completion_echo( @@ -135,7 +141,7 @@ async def test_openai_client_models(self, client: openai.AsyncOpenAI, backend: s @pytest.mark.asyncio async def test_openai_client_completion( - self, client: openai.AsyncOpenAI, model: str, prompt: str + self, client: openai.AsyncOpenAI, model: str, prompt: str, backend: str ): completion = await client.completions.create( prompt=prompt, @@ -147,17 +153,20 @@ async def test_openai_client_completion( assert completion.choices[0].finish_reason == "stop" usage = completion.usage - assert usage is not None - assert isinstance(usage.prompt_tokens, int) - assert isinstance(usage.completion_tokens, int) - assert isinstance(usage.total_tokens, int) - assert usage.prompt_tokens > 0 - assert usage.completion_tokens > 0 - assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens + if backend == "vllm": + assert usage is not None + 
assert isinstance(usage.prompt_tokens, int) + assert isinstance(usage.completion_tokens, int) + assert isinstance(usage.total_tokens, int) + assert usage.prompt_tokens > 0 + assert usage.completion_tokens > 0 + assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens + else: + assert usage is None @pytest.mark.asyncio async def test_openai_client_chat_completion( - self, client: openai.AsyncOpenAI, model: str, messages: List[dict] + self, client: openai.AsyncOpenAI, model: str, messages: List[dict], backend: str ): chat_completion = await client.chat.completions.create( messages=messages, @@ -168,13 +177,16 @@ async def test_openai_client_chat_completion( assert chat_completion.choices[0].finish_reason == "stop" usage = chat_completion.usage - assert usage is not None - assert isinstance(usage.prompt_tokens, int) - assert isinstance(usage.completion_tokens, int) - assert isinstance(usage.total_tokens, int) - assert usage.prompt_tokens > 0 - assert usage.completion_tokens > 0 - assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens + if backend == "vllm": + assert usage is not None + assert isinstance(usage.prompt_tokens, int) + assert isinstance(usage.completion_tokens, int) + assert isinstance(usage.total_tokens, int) + assert usage.prompt_tokens > 0 + assert usage.completion_tokens > 0 + assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens + else: + assert usage is None print(f"Chat completion results: {chat_completion}") @@ -291,12 +303,14 @@ async def test_chat_streaming_usage_option( self, client: openai.AsyncOpenAI, model: str, messages: List[dict], backend: str ): if backend != "vllm": - pytest.skip("Usage reporting is currently available only for the vLLM backend.") + pytest.skip( + "Usage reporting is currently available only for the vLLM backend." + ) seed = 0 temperature = 0.0 max_tokens = 16 - + # Get usage and content from a non-streaming call stream_false = await client.chat.completions.create( model=model, @@ -386,12 +400,14 @@ async def test_completion_streaming_usage_option( self, client: openai.AsyncOpenAI, model: str, prompt: str, backend: str ): if backend != "vllm": - pytest.skip("Usage reporting is currently available only for the vLLM backend.") + pytest.skip( + "Usage reporting is currently available only for the vLLM backend." 
+ ) seed = 0 temperature = 0.0 max_tokens = 16 - + # Get usage and content from a non-streaming call stream_false = await client.completions.create( model=model, From 498ad7abdc4dbec9793db20ee3c3d2310a766e07 Mon Sep 17 00:00:00 2001 From: Sai Kiran Polisetty Date: Wed, 25 Jun 2025 10:49:00 +0000 Subject: [PATCH 10/14] Update --- python/openai/README.md | 22 +++++++++++++------ .../openai_frontend/engine/triton_engine.py | 2 +- .../openai_frontend/engine/utils/triton.py | 2 +- python/openai/tests/test_chat_completions.py | 4 +++- python/openai/tests/test_completions.py | 4 +++- 5 files changed, 23 insertions(+), 11 deletions(-) diff --git a/python/openai/README.md b/python/openai/README.md index 3d770caa2e..c5e05cee90 100644 --- a/python/openai/README.md +++ b/python/openai/README.md @@ -98,7 +98,7 @@ curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/ ```json { - "id": "cmpl-6930b296-7ef8-11ef-bdd1-107c6149ca79", + "id": "cmpl-0242093d-51ae-11f0-b339-e7480668bfbe",, "choices": [ { "finish_reason": "stop", @@ -113,11 +113,15 @@ curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/ "logprobs": null } ], - "created": 1727679085, + "created": 1750846825, "model": "llama-3.1-8b-instruct", "system_fingerprint": null, "object": "chat.completion", - "usage": null + "usage": { + "completion_tokens": 7, + "prompt_tokens": 42, + "total_tokens": 49 + } } ``` @@ -138,20 +142,24 @@ curl -s http://localhost:9000/v1/completions -H 'Content-Type: application/json' ```json { - "id": "cmpl-d51df75c-7ef8-11ef-bdd1-107c6149ca79", + "id": "cmpl-58fba3a0-51ae-11f0-859d-e7480668bfbe", "choices": [ { "finish_reason": "stop", "index": 0, "logprobs": null, - "text": " a field of computer science that focuses on developing algorithms that allow computers to learn from" + "text": " an amazing field that can truly understand the hidden patterns that exist in the data," } ], - "created": 1727679266, + "created": 1750846970, "model": "llama-3.1-8b-instruct", "system_fingerprint": null, "object": "text_completion", - "usage": null + "usage": { + "completion_tokens": 16, + "prompt_tokens": 4, + "total_tokens": 20 + } } ``` diff --git a/python/openai/openai_frontend/engine/triton_engine.py b/python/openai/openai_frontend/engine/triton_engine.py index 5cdacc5719..c0af2e4b93 100644 --- a/python/openai/openai_frontend/engine/triton_engine.py +++ b/python/openai/openai_frontend/engine/triton_engine.py @@ -48,12 +48,12 @@ from engine.utils.tokenizer import get_tokenizer from engine.utils.tool_call_parsers import ToolCallParser, ToolParserManager from engine.utils.triton import ( - _StreamingUsageAccumulator, _create_trtllm_inference_request, _create_vllm_inference_request, _get_output, _get_usage_from_response, _get_vllm_lora_names, + _StreamingUsageAccumulator, _validate_triton_responses_non_streaming, ) from schemas.openai import ( diff --git a/python/openai/openai_frontend/engine/utils/triton.py b/python/openai/openai_frontend/engine/utils/triton.py index 42a8aa4a87..3104c49911 100644 --- a/python/openai/openai_frontend/engine/utils/triton.py +++ b/python/openai/openai_frontend/engine/utils/triton.py @@ -36,9 +36,9 @@ from schemas.openai import ( ChatCompletionNamedToolChoice, ChatCompletionToolChoiceOption1, + CompletionUsage, CreateChatCompletionRequest, CreateCompletionRequest, - CompletionUsage, ) diff --git a/python/openai/tests/test_chat_completions.py b/python/openai/tests/test_chat_completions.py index 54a3372327..5402be451d 100644 --- 
a/python/openai/tests/test_chat_completions.py +++ b/python/openai/tests/test_chat_completions.py @@ -540,7 +540,9 @@ def test_usage_response( self, client, model: str, messages: List[dict], backend: str ): if backend != "vllm": - pytest.skip("Usage reporting is currently available only for the vLLM backend.") + pytest.skip( + "Usage reporting is currently available only for the vLLM backend." + ) response = client.post( "/v1/chat/completions", diff --git a/python/openai/tests/test_completions.py b/python/openai/tests/test_completions.py index 3871edb2d5..9ec3ffe7f7 100644 --- a/python/openai/tests/test_completions.py +++ b/python/openai/tests/test_completions.py @@ -373,7 +373,9 @@ def test_multi_lora(self): def test_usage_response(self, client, model: str, prompt: str, backend: str): if backend != "vllm": - pytest.skip("Usage reporting is currently available only for the vLLM backend.") + pytest.skip( + "Usage reporting is currently available only for the vLLM backend." + ) response = client.post( "/v1/completions", From 07ca9c748cc534414888c494b4da6c25f55d3579 Mon Sep 17 00:00:00 2001 From: Sai Kiran Polisetty Date: Wed, 25 Jun 2025 22:36:06 +0530 Subject: [PATCH 11/14] Update python/openai/README.md Co-authored-by: richardhuo-nv --- python/openai/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/openai/README.md b/python/openai/README.md index c5e05cee90..39b10e201b 100644 --- a/python/openai/README.md +++ b/python/openai/README.md @@ -98,7 +98,7 @@ curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/ ```json { - "id": "cmpl-0242093d-51ae-11f0-b339-e7480668bfbe",, + "id": "cmpl-0242093d-51ae-11f0-b339-e7480668bfbe", "choices": [ { "finish_reason": "stop", From 960290b60bb9d7c4d1887d9b33e6330f3b8df08a Mon Sep 17 00:00:00 2001 From: Sai Kiran Polisetty Date: Thu, 26 Jun 2025 23:58:35 +0530 Subject: [PATCH 12/14] Update request validation --- .../openai_frontend/engine/triton_engine.py | 34 ++++++++--- python/openai/tests/test_openai_client.py | 57 +++++++++++++++++++ 2 files changed, 83 insertions(+), 8 deletions(-) diff --git a/python/openai/openai_frontend/engine/triton_engine.py b/python/openai/openai_frontend/engine/triton_engine.py index c0af2e4b93..351397d223 100644 --- a/python/openai/openai_frontend/engine/triton_engine.py +++ b/python/openai/openai_frontend/engine/triton_engine.py @@ -482,10 +482,7 @@ async def _streaming_chat_iterator( previous_text = "" include_usage = ( - # TODO: Remove backend check condition once tensorrt-llm backend also supports usage - backend == "vllm" - and request.stream_options - and request.stream_options.include_usage + request.stream_options and request.stream_options.include_usage ) usage_accumulator = _StreamingUsageAccumulator(backend) @@ -697,6 +694,18 @@ def _validate_chat_request( self._verify_chat_tool_call_settings(request=request) + if request.stream_options and not request.stream: + raise Exception("`stream_options` can only be used when `stream` is True") + + if ( + request.stream_options + and request.stream_options.include_usage + and metadata.backend != "vllm" + ): + raise Exception( + "`stream_options.include_usage` is currently only supported for the vLLM backend" + ) + def _verify_chat_tool_call_settings(self, request: CreateChatCompletionRequest): if ( request.tool_choice @@ -742,10 +751,7 @@ async def _streaming_completion_iterator( ) -> AsyncIterator[str]: model = request.model include_usage = ( - # TODO: Remove backend check condition once tensorrt-llm backend also 
supports usage - backend == "vllm" - and request.stream_options - and request.stream_options.include_usage + request.stream_options and request.stream_options.include_usage ) usage_accumulator = _StreamingUsageAccumulator(backend) @@ -839,6 +845,18 @@ def _validate_completion_request( if request.logit_bias is not None or request.logprobs is not None: raise Exception("logit bias and log probs not supported") + if request.stream_options and not request.stream: + raise Exception("`stream_options` can only be used when `stream` is True") + + if ( + request.stream_options + and request.stream_options.include_usage + and metadata.backend != "vllm" + ): + raise Exception( + "`stream_options.include_usage` is currently only supported for the vLLM backend" + ) + def _should_stream_with_auto_tool_parsing( self, request: CreateChatCompletionRequest ): diff --git a/python/openai/tests/test_openai_client.py b/python/openai/tests/test_openai_client.py index 09be380e63..c591e1c45a 100644 --- a/python/openai/tests/test_openai_client.py +++ b/python/openai/tests/test_openai_client.py @@ -487,3 +487,60 @@ async def test_completion_streaming_usage_option( # Assert usage is consistent between streaming and non-streaming calls assert usage_stream_false.model_dump() == usage_stream_options_true.model_dump() + + @pytest.mark.asyncio + async def test_stream_options_without_streaming( + self, client: openai.AsyncOpenAI, model: str, prompt: str + ): + with pytest.raises(openai.BadRequestError) as e: + await client.completions.create( + model=model, + prompt=prompt, + stream=False, + stream_options={"include_usage": True}, + ) + assert "`stream_options` can only be used when `stream` is True" in str( + e.value + ) + + with pytest.raises(openai.BadRequestError) as e: + await client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": prompt}], + stream=False, + stream_options={"include_usage": True}, + ) + assert "`stream_options` can only be used when `stream` is True" in str( + e.value + ) + + @pytest.mark.asyncio + async def test_streaming_usage_unsupported_backend( + self, client: openai.AsyncOpenAI, model: str, messages: List[dict], backend: str + ): + if backend == "vllm": + pytest.skip("This test is for backends that do not support usage reporting.") + + with pytest.raises(openai.BadRequestError) as e: + await client.completions.create( + model=model, + prompt="Test prompt", + stream=True, + stream_options={"include_usage": True}, + ) + assert ( + "`stream_options.include_usage` is currently only supported for the vLLM backend" + in str(e.value) + ) + + with pytest.raises(openai.BadRequestError) as e: + await client.chat.completions.create( + model=model, + messages=messages, + stream=True, + stream_options={"include_usage": True}, + ) + assert ( + "`stream_options.include_usage` is currently only supported for the vLLM backend" + in str(e.value) + ) From 77354e0e6ca8e5b327d697d2e5e8ffe16354f8d9 Mon Sep 17 00:00:00 2001 From: Sai Kiran Polisetty Date: Thu, 26 Jun 2025 23:59:10 +0530 Subject: [PATCH 13/14] Update formatting --- python/openai/openai_frontend/engine/triton_engine.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/python/openai/openai_frontend/engine/triton_engine.py b/python/openai/openai_frontend/engine/triton_engine.py index 351397d223..499cc623e7 100644 --- a/python/openai/openai_frontend/engine/triton_engine.py +++ b/python/openai/openai_frontend/engine/triton_engine.py @@ -481,9 +481,7 @@ async def _streaming_chat_iterator( ) previous_text 
= "" - include_usage = ( - request.stream_options and request.stream_options.include_usage - ) + include_usage = request.stream_options and request.stream_options.include_usage usage_accumulator = _StreamingUsageAccumulator(backend) chunk = self._get_first_streaming_chat_response( @@ -750,9 +748,7 @@ async def _streaming_completion_iterator( backend: str, ) -> AsyncIterator[str]: model = request.model - include_usage = ( - request.stream_options and request.stream_options.include_usage - ) + include_usage = request.stream_options and request.stream_options.include_usage usage_accumulator = _StreamingUsageAccumulator(backend) async for response in responses: From 3168c27a9f13a3e187a1ee5aea88c75f67249a3d Mon Sep 17 00:00:00 2001 From: Sai Kiran Polisetty Date: Fri, 27 Jun 2025 00:01:26 +0530 Subject: [PATCH 14/14] Update formatting --- python/openai/tests/test_openai_client.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/python/openai/tests/test_openai_client.py b/python/openai/tests/test_openai_client.py index c591e1c45a..1a1001329b 100644 --- a/python/openai/tests/test_openai_client.py +++ b/python/openai/tests/test_openai_client.py @@ -499,9 +499,7 @@ async def test_stream_options_without_streaming( stream=False, stream_options={"include_usage": True}, ) - assert "`stream_options` can only be used when `stream` is True" in str( - e.value - ) + assert "`stream_options` can only be used when `stream` is True" in str(e.value) with pytest.raises(openai.BadRequestError) as e: await client.chat.completions.create( @@ -510,16 +508,16 @@ async def test_stream_options_without_streaming( stream=False, stream_options={"include_usage": True}, ) - assert "`stream_options` can only be used when `stream` is True" in str( - e.value - ) + assert "`stream_options` can only be used when `stream` is True" in str(e.value) @pytest.mark.asyncio async def test_streaming_usage_unsupported_backend( self, client: openai.AsyncOpenAI, model: str, messages: List[dict], backend: str ): if backend == "vllm": - pytest.skip("This test is for backends that do not support usage reporting.") + pytest.skip( + "This test is for backends that do not support usage reporting." + ) with pytest.raises(openai.BadRequestError) as e: await client.completions.create(
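Putting the streaming changes and the new request validation together, the expected client-side behaviour can be sketched as follows. This is a hedged example rather than part of the patches: the base URL and model name follow the README example earlier in this series, the prompt and message text are made up, and it assumes a locally running frontend with a vLLM-backed model.

```python
# Client-side sketch of the usage behaviour described in these patches.
# Base URL, API key, model name, and prompts are illustrative assumptions.
import openai

client = openai.OpenAI(base_url="http://localhost:9000/v1", api_key="unused")
model = "llama-3.1-8b-instruct"

# Streamed usage must be requested explicitly via stream_options; with a
# vLLM-backed model, only the final chunk is expected to carry `usage`.
stream = client.chat.completions.create(
    model=model,
    messages=[{"role": "user", "content": "What is machine learning?"}],
    max_tokens=16,
    stream=True,
    stream_options={"include_usage": True},
)
usage = None
for chunk in stream:
    if chunk.usage is not None:
        usage = chunk.usage
print(usage)  # prompt_tokens / completion_tokens / total_tokens

# Passing stream_options on a non-streaming request is rejected by the
# validation added in this series.
try:
    client.completions.create(
        model=model,
        prompt="Machine learning is",
        stream=False,
        stream_options={"include_usage": True},
    )
except openai.BadRequestError as error:
    print(error)  # `stream_options` can only be used when `stream` is True
```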