22 changes: 15 additions & 7 deletions python/openai/README.md
@@ -98,7 +98,7 @@ curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/

```json
{
"id": "cmpl-6930b296-7ef8-11ef-bdd1-107c6149ca79",
"id": "cmpl-0242093d-51ae-11f0-b339-e7480668bfbe",
"choices": [
{
"finish_reason": "stop",
@@ -113,11 +113,15 @@ curl -s http://localhost:9000/v1/chat/completions -H 'Content-Type: application/
"logprobs": null
}
],
"created": 1727679085,
"created": 1750846825,
"model": "llama-3.1-8b-instruct",
"system_fingerprint": null,
"object": "chat.completion",
"usage": null
"usage": {
"completion_tokens": 7,
"prompt_tokens": 42,
"total_tokens": 49
}
}
```
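
For a quick check of the populated `usage` block, here is a minimal client-side sketch (assuming the `openai` Python package and the frontend running on `localhost:9000` as in the curl example above; with this change, usage is reported for the vLLM backend):

```python
from openai import OpenAI

# Point the standard OpenAI client at the local OpenAI-compatible frontend.
client = OpenAI(base_url="http://localhost:9000/v1", api_key="unused")

response = client.chat.completions.create(
    model="llama-3.1-8b-instruct",
    messages=[{"role": "user", "content": "What is machine learning?"}],
)

# usage is now populated instead of null.
usage = response.usage
print(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)
```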

@@ -138,20 +142,24 @@ curl -s http://localhost:9000/v1/completions -H 'Content-Type: application/json'

```json
{
"id": "cmpl-d51df75c-7ef8-11ef-bdd1-107c6149ca79",
"id": "cmpl-58fba3a0-51ae-11f0-859d-e7480668bfbe",
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"text": " a field of computer science that focuses on developing algorithms that allow computers to learn from"
"text": " an amazing field that can truly understand the hidden patterns that exist in the data,"
}
],
"created": 1727679266,
"created": 1750846970,
"model": "llama-3.1-8b-instruct",
"system_fingerprint": null,
"object": "text_completion",
"usage": null
"usage": {
"completion_tokens": 16,
"prompt_tokens": 4,
"total_tokens": 20
}
}
```
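
Usage is also reported for streaming requests when `stream_options.include_usage` is set (honored for the vLLM backend, per the TODOs in the engine changes below): the frontend sends one final chunk with an empty `choices` list and the aggregated `usage` right before `data: [DONE]`. A minimal sketch of reading that chunk from the raw SSE stream, assuming `requests` is installed and using an illustrative prompt:

```python
import json

import requests

resp = requests.post(
    "http://localhost:9000/v1/completions",
    json={
        "model": "llama-3.1-8b-instruct",
        "prompt": "Machine learning is",
        "stream": True,
        "stream_options": {"include_usage": True},
    },
    stream=True,
)
for line in resp.iter_lines():
    if not line or not line.startswith(b"data: "):
        continue
    payload = line[len(b"data: "):]
    if payload == b"[DONE]":
        break
    chunk = json.loads(payload)
    # Content chunks carry "choices"; the final usage chunk has an empty
    # "choices" list and a populated "usage" object.
    if not chunk["choices"] and chunk.get("usage"):
        print(chunk["usage"])
```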

76 changes: 72 additions & 4 deletions python/openai/openai_frontend/engine/triton_engine.py
@@ -51,7 +51,9 @@
_create_trtllm_inference_request,
_create_vllm_inference_request,
_get_output,
_get_usage_from_response,
_get_vllm_lora_names,
_StreamingUsageAccumulator,
_validate_triton_responses_non_streaming,
)
from schemas.openai import (
@@ -65,6 +67,7 @@
ChatCompletionStreamResponseDelta,
ChatCompletionToolChoiceOption1,
Choice,
CompletionUsage,
CreateChatCompletionRequest,
CreateChatCompletionResponse,
CreateChatCompletionStreamResponse,
@@ -229,6 +232,8 @@ async def chat(
backend=metadata.backend,
)

usage = _get_usage_from_response(response, metadata.backend)

return CreateChatCompletionResponse(
id=request_id,
choices=[
@@ -243,6 +248,7 @@
model=request.model,
system_fingerprint=None,
object=ObjectType.chat_completion,
usage=usage,
)

def _get_chat_completion_response_message(
@@ -319,7 +325,7 @@ async def completion(
created = int(time.time())
if request.stream:
return self._streaming_completion_iterator(
request_id, created, request.model, responses
request_id, created, request, responses, metadata.backend
)

# Response validation with decoupled models in mind
@@ -328,6 +334,8 @@
response = responses[0]
text = _get_output(response)

usage = _get_usage_from_response(response, metadata.backend)

choice = Choice(
finish_reason=FinishReason.stop,
index=0,
@@ -341,6 +349,7 @@
object=ObjectType.text_completion,
created=created,
model=request.model,
usage=usage,
)

# TODO: This behavior should be tested further
@@ -421,6 +430,7 @@ def _get_streaming_chat_response_chunk(
request_id: str,
created: int,
model: str,
usage: Optional[CompletionUsage] = None,
) -> CreateChatCompletionStreamResponse:
return CreateChatCompletionStreamResponse(
id=request_id,
@@ -429,6 +439,7 @@
model=model,
system_fingerprint=None,
object=ObjectType.chat_completion_chunk,
usage=usage,
)

def _get_first_streaming_chat_response(
@@ -444,7 +455,7 @@ def _get_first_streaming_chat_response(
finish_reason=None,
)
chunk = self._get_streaming_chat_response_chunk(
choice, request_id, created, model
choice, request_id, created, model, usage=None
)
return chunk

@@ -470,6 +481,13 @@ async def _streaming_chat_iterator(
)

previous_text = ""
include_usage = (
# TODO: Remove backend check condition once tensorrt-llm backend also supports usage
backend == "vllm"
and request.stream_options
and request.stream_options.include_usage
)
usage_accumulator = _StreamingUsageAccumulator(backend)

chunk = self._get_first_streaming_chat_response(
request_id, created, model, role
@@ -478,6 +496,8 @@

async for response in responses:
delta_text = _get_output(response)
if include_usage:
usage_accumulator.update(response)

(
response_delta,
@@ -512,10 +532,25 @@
)

chunk = self._get_streaming_chat_response_chunk(
choice, request_id, created, model
choice, request_id, created, model, usage=None
)
yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"

# Send the final usage chunk if requested via stream_options.
if include_usage:
usage_payload = usage_accumulator.get_final_usage()
if usage_payload:
final_usage_chunk = CreateChatCompletionStreamResponse(
id=request_id,
choices=[],
created=created,
model=model,
system_fingerprint=None,
object=ObjectType.chat_completion_chunk,
usage=usage_payload,
)
yield f"data: {final_usage_chunk.model_dump_json(exclude_unset=True)}\n\n"

yield "data: [DONE]\n\n"

def _get_streaming_response_delta(
@@ -698,9 +733,26 @@ def _verify_chat_tool_call_settings(self, request: CreateChatCompletionRequest):
)

async def _streaming_completion_iterator(
self, request_id: str, created: int, model: str, responses: AsyncIterable
self,
request_id: str,
created: int,
request: CreateCompletionRequest,
responses: AsyncIterable,
backend: str,
) -> AsyncIterator[str]:
model = request.model
include_usage = (
# TODO: Remove backend check condition once tensorrt-llm backend also supports usage
backend == "vllm"
and request.stream_options
and request.stream_options.include_usage
)
usage_accumulator = _StreamingUsageAccumulator(backend)

async for response in responses:
if include_usage:
usage_accumulator.update(response)

text = _get_output(response)
choice = Choice(
finish_reason=FinishReason.stop if response.final else None,
@@ -715,10 +767,26 @@ async def _streaming_completion_iterator(
object=ObjectType.text_completion,
created=created,
model=model,
usage=None,
)

yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"

# Send the final usage chunk if requested via stream_options.
if include_usage:
usage_payload = usage_accumulator.get_final_usage()
if usage_payload:
final_usage_chunk = CreateCompletionResponse(
id=request_id,
choices=[],
system_fingerprint=None,
object=ObjectType.text_completion,
created=created,
model=model,
usage=usage_payload,
)
yield f"data: {final_usage_chunk.model_dump_json(exclude_unset=True)}\n\n"

yield "data: [DONE]\n\n"

def _validate_completion_request(
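
On the client side, the same `stream_options.include_usage` flag drives `_streaming_chat_iterator` above. A minimal, illustrative sketch of consuming the final usage chunk with the `openai` Python client (assuming the local endpoint and model name from the README examples):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:9000/v1", api_key="unused")

stream = client.chat.completions.create(
    model="llama-3.1-8b-instruct",
    messages=[{"role": "user", "content": "What is machine learning?"}],
    stream=True,
    stream_options={"include_usage": True},
)
for chunk in stream:
    if chunk.choices:
        print(chunk.choices[0].delta.content or "", end="")
    elif chunk.usage:
        # Final chunk: empty choices, aggregated prompt/completion token counts.
        print(
            f"\nprompt={chunk.usage.prompt_tokens} "
            f"completion={chunk.usage.completion_tokens} "
            f"total={chunk.usage.total_tokens}"
        )
```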
84 changes: 83 additions & 1 deletion python/openai/openai_frontend/engine/utils/triton.py
@@ -27,7 +27,7 @@
import json
import os
import re
from dataclasses import asdict
from dataclasses import asdict, dataclass, field
from typing import Iterable, List, Optional, Union

import numpy as np
@@ -36,6 +36,7 @@
from schemas.openai import (
ChatCompletionNamedToolChoice,
ChatCompletionToolChoiceOption1,
CompletionUsage,
CreateChatCompletionRequest,
CreateCompletionRequest,
)
@@ -121,6 +122,8 @@ def _create_vllm_inference_request(
# Pass sampling_parameters as serialized JSON string input to support List
# fields like 'stop' that aren't supported by TRITONSERVER_Parameters yet.
inputs["sampling_parameters"] = [sampling_parameters]
inputs["return_num_input_tokens"] = np.bool_([True])
inputs["return_num_output_tokens"] = np.bool_([True])
return model.create_request(inputs=inputs)


@@ -221,6 +224,85 @@ def _to_string(tensor: tritonserver.Tensor) -> str:
return _construct_string_from_pointer(tensor.data_ptr + 4, tensor.size - 4)


@dataclass
class _StreamingUsageAccumulator:
"""Helper class to accumulate token usage from a streaming response."""

backend: str
prompt_tokens: int = 0
completion_tokens: int = 0
_prompt_tokens_set: bool = field(init=False, default=False)

def update(self, response: tritonserver.InferenceResponse):
"""Extracts usage from a response and updates the token counts."""
usage = _get_usage_from_response(response, self.backend)
if usage:
# prompt_tokens is received with every chunk but should only be set once.
if not self._prompt_tokens_set:
self.prompt_tokens = usage.prompt_tokens
self._prompt_tokens_set = True
self.completion_tokens += usage.completion_tokens

def get_final_usage(self) -> Optional[CompletionUsage]:
"""
Returns the final populated CompletionUsage object if any tokens were tracked.
"""
# If _prompt_tokens_set is True, it means we have received and processed
# at least one valid usage payload.
if self._prompt_tokens_set:
return CompletionUsage(
prompt_tokens=self.prompt_tokens,
completion_tokens=self.completion_tokens,
total_tokens=self.prompt_tokens + self.completion_tokens,
)
return None


def _get_usage_from_response(
response: tritonserver._api._response.InferenceResponse,
backend: str,
) -> Optional[CompletionUsage]:
"""
Extracts token usage statistics from a Triton inference response.
"""
# TODO: Remove this check once TRT-LLM backend supports both "num_input_tokens"
# and "num_output_tokens", and also update the test cases accordingly.
if backend != "vllm":
return None

prompt_tokens = None
completion_tokens = None

if (
"num_input_tokens" in response.outputs
and "num_output_tokens" in response.outputs
):
input_token_tensor = response.outputs["num_input_tokens"]
output_token_tensor = response.outputs["num_output_tokens"]

if input_token_tensor.data_type == tritonserver.DataType.UINT32:
prompt_tokens_ptr = ctypes.cast(
input_token_tensor.data_ptr, ctypes.POINTER(ctypes.c_uint32)
)
prompt_tokens = prompt_tokens_ptr[0]

if output_token_tensor.data_type == tritonserver.DataType.UINT32:
completion_tokens_ptr = ctypes.cast(
output_token_tensor.data_ptr, ctypes.POINTER(ctypes.c_uint32)
)
completion_tokens = completion_tokens_ptr[0]

if prompt_tokens is not None and completion_tokens is not None:
total_tokens = prompt_tokens + completion_tokens
return CompletionUsage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)

return None


# TODO: Use tritonserver.InferenceResponse when support is published
def _get_output(response: tritonserver._api._response.InferenceResponse) -> str:
if "text_output" in response.outputs: