81 changes: 80 additions & 1 deletion tensorrt_llm/serve/harmony_adapter.py
@@ -57,7 +57,8 @@ def __init__(self,
# Normal case: filter based on available tools
self.should_filter_tools = True
self.available_tools = {
tool.get("function", {}).get("name", "")
tool.get("function", {}).get("name", "") if tool.get(
"name", None) is None else tool.get("name")
for tool in available_tools
}
self.available_tools.discard("")
@@ -78,6 +79,9 @@ def __init__(self,

logger.debug("Created HarmonyStreamState for request %s", request_id)

def get_parser(self) -> StreamableParser:
return self.parser

def process_token_batch(self, tokens: list[int]) -> list[dict[str, Any]]:
"""
Process a batch of tokens while maintaining parsing state.
@@ -125,6 +129,42 @@ def process_token_batch(self, tokens: list[int]) -> list[dict[str, Any]]:

return deltas

def process_token_batch_to_messages(self,
tokens: list[int]) -> list[Message]:
"""
Process a batch of tokens while maintaining parsing state.
Returns OpenAI Messages for the Responses API.
"""
self.tokens_processed += len(tokens)

for token in tokens:
# Store previous state for transition detection
prev_channel = self.parser.current_channel
prev_recipient = self.parser.current_recipient

# Process the token
self.parser.process(token)

# Detect channel/recipient transitions AFTER processing each token
channel_changed = prev_channel != self.parser.current_channel
recipient_changed = prev_recipient != self.parser.current_recipient

if channel_changed or recipient_changed:
# Mark any active tool calls as completed if we're leaving a tool call
if prev_channel == "commentary" and prev_recipient and "functions." in str(
prev_recipient):
func_name = str(prev_recipient).split("functions.")[-1]
for tool_id, tool_info in self.tool_calls.items():
if tool_info["name"] == func_name and tool_info.get(
"active", True):
tool_info["active"] = False

# Reset channel state for new channel
self.channel_started = False
self.current_channel_state = None

return self.parser.messages

def _create_closing_token_delta(self) -> dict[str, Any] | None:
"""Create closing token delta for channel transition."""
if not self.current_channel_state or not self.channel_started:
@@ -317,6 +357,9 @@ def __init__(
"<|constrain|>": 200009,
}

def get_stream_state(self, request_id: str) -> HarmonyStreamState | None:
return self._stream_states.get(request_id, None)

def get_stop_tokens(self) -> list[int]:
"""
Return the list of stop token IDs for Harmony format.
@@ -1214,6 +1257,42 @@ def stateful_stream_harmony_tokens_to_openai_deltas(
# Return empty deltas to continue processing
return []

def stateful_stream_harmony_tokens_to_openai_messages(
self,
request_id: str,
tokens: list[int],
available_tools: list[dict[str, Any]] | None = None,
tool_choice: str | None = None) -> list[Message]:
"""
Process tokens using stateful parsing.

This method maintains persistent state across multiple calls for the same request,
ensuring proper channel transitions and tool call handling.

Args:
request_id: Request ID to maintain state per request
tokens: New tokens from this iteration
available_tools: Available tools for filtering
tool_choice: Tool choice setting used when creating the stream state

Returns:
List of OpenAI Messages
"""
stream_state = self._stream_states.get(request_id, None)
if stream_state is None:
stream_state = self.create_stream_state(request_id, available_tools,
tool_choice)

try:
messages = stream_state.process_token_batch_to_messages(tokens)
return messages
except (HarmonyError, UnicodeDecodeError, ValueError):
logger.error(
f"Streaming: Failed to process token batch of {len(tokens)} tokens for request {request_id}",
)
logger.debug(f"Problematic streaming tokens: {tokens}")

return []

def create_openai_streaming_response(
self,
request_id: str,
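
For orientation, a minimal usage sketch of the new message-oriented streaming path added above. This is illustrative only and not part of the diff: the adapter class name HarmonyAdapter, its no-argument constructor, and the generation_iterator() token source are assumptions made for the example; only get_stream_state, get_parser, and stateful_stream_harmony_tokens_to_openai_messages come from this change (the per-request stream state is created internally on first use).

# Illustrative sketch only -- not part of this diff. Assumes the adapter class
# in harmony_adapter.py is named HarmonyAdapter and can be constructed without
# arguments; generation_iterator() is a hypothetical token source.
from tensorrt_llm.serve.harmony_adapter import HarmonyAdapter

adapter = HarmonyAdapter()
request_id = "resp_example"
tools = [{"type": "function", "name": "get_weather", "parameters": {}}]

for token_batch in generation_iterator():
    # Each call reuses the per-request HarmonyStreamState, so channel and
    # tool-call transitions carry over between token batches.
    messages = adapter.stateful_stream_harmony_tokens_to_openai_messages(
        request_id, token_batch, available_tools=tools, tool_choice="auto")
    for message in messages:
        ...  # map each parsed Message onto a Responses API output item

# The underlying parser stays reachable for end-of-stream bookkeeping.
state = adapter.get_stream_state(request_id)
if state is not None:
    parser = state.get_parser()
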
211 changes: 210 additions & 1 deletion tensorrt_llm/serve/openai_protocol.py
@@ -11,9 +11,16 @@
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
from openai.types.chat import \
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam
from openai.types.responses import (ResponseFunctionToolCall,
ResponseInputItemParam, ResponseOutputItem,
ResponsePrompt, ResponseReasoningItem,
ResponseStatus, ResponseTextConfig)
from openai.types.responses.response import ToolChoice
from openai.types.responses.tool import Tool
from openai.types.shared import Metadata, Reasoning
from openai_harmony import ReasoningEffort
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Annotated, Required, TypedDict
from typing_extensions import Annotated, Required, TypeAlias, TypedDict

from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams
@@ -665,6 +672,208 @@ def check_suffix(cls, data):
return data


ResponseInputOutputItem: TypeAlias = Union[ResponseInputItemParam,
ResponseReasoningItem,
ResponseFunctionToolCall]


class ResponsesRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/responses/create
background: Optional[bool] = False
include: Optional[list[
Literal[
"code_interpreter_call.outputs",
"computer_call_output.output.image_url",
"file_search_call.results",
"message.input_image.image_url",
"message.output_text.logprobs",
"reasoning.encrypted_content",
],
]] = None
input: Union[str, list[ResponseInputOutputItem]]
instructions: Optional[str] = None
max_output_tokens: Optional[int] = None
max_tool_calls: Optional[int] = None
metadata: Optional[Metadata] = None
model: str
parallel_tool_calls: Optional[bool] = False
previous_response_id: Optional[str] = None
prompt: Optional[ResponsePrompt] = None
reasoning: Optional[Reasoning] = None
service_tier: Literal["auto", "default", "flex", "scale",
"priority"] = "auto"
store: Optional[bool] = True
stream: Optional[bool] = False
temperature: Optional[float] = None
text: Optional[ResponseTextConfig] = None
tool_choice: ToolChoice = "auto"
tools: list[Tool] = Field(default_factory=list)
top_logprobs: Optional[int] = 0
top_p: Optional[float] = None
truncation: Optional[Literal["auto", "disabled"]] = "disabled"
user: Optional[str] = None

request_id: str = Field(
default_factory=lambda: f"resp_{str(uuid.uuid4().hex)}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."),
)

_DEFAULT_SAMPLING_PARAMS = {
"temperature": 1.0,
"top_p": 1.0,
}

def to_sampling_params(
self,
default_max_tokens: int,
default_sampling_params: Optional[dict] = None,
) -> SamplingParams:
if self.max_output_tokens is None:
max_tokens = default_max_tokens
else:
max_tokens = min(self.max_output_tokens, default_max_tokens)

default_sampling_params = default_sampling_params or {}
if (temperature := self.temperature) is None:
temperature = default_sampling_params.get(
"temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
if (top_p := self.top_p) is None:
top_p = default_sampling_params.get(
"top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"])
stop_token_ids = default_sampling_params.get("stop_token_ids")

# Structured output
guided_decoding = None
if self.text is not None and self.text.format is not None:
response_format = self.text.format
if response_format.type == "json_schema":
guided_decoding = GuidedDecodingParams(
json=response_format.schema_)
elif response_format.type == "json_object":
raise NotImplementedError("json_object is not supported")

return SamplingParams(
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
logprobs=self.top_logprobs,
stop_token_ids=stop_token_ids,
guided_decoding=guided_decoding,
)

@model_validator(mode="before")
@classmethod
def validate_background(cls, data):
if not data.get("background"):
return data
if not data.get("store", True):
raise ValueError("background can only be used when `store` is true")
return data

@model_validator(mode="before")
@classmethod
def validate_prompt(cls, data):
if data.get("prompt") is not None:
raise ValueError("prompt template is not supported")
return data
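
A small sketch, illustrative only and not part of the diff, of how a ResponsesRequest maps onto SamplingParams; the model name and the defaults passed in are placeholder values chosen for the example.

# Illustrative sketch only -- not part of this diff.
request = ResponsesRequest(
    model="example-model",    # placeholder model name
    input="What is the capital of France?",
    max_output_tokens=256,
    temperature=0.7,
)

# default_max_tokens and default_sampling_params would normally come from the
# server's model configuration.
sampling_params = request.to_sampling_params(
    default_max_tokens=1024,
    default_sampling_params={"top_p": 0.95},
)
# Resolves to temperature=0.7 (from the request), top_p=0.95 (from the
# defaults), max_tokens=min(256, 1024) == 256, logprobs=0 (top_logprobs default).
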


class InputTokensDetails(OpenAIBaseModel):
cached_tokens: int


class OutputTokensDetails(OpenAIBaseModel):
reasoning_tokens: int


class ResponseUsage(OpenAIBaseModel):
input_tokens: int
input_tokens_details: InputTokensDetails
output_tokens: int
output_tokens_details: OutputTokensDetails
total_tokens: int


class ResponsesResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"resp_{str(uuid.uuid4().hex)}")
created_at: int = Field(default_factory=lambda: int(time.time()))
# error: Optional[ResponseError] = None
# incomplete_details: Optional[IncompleteDetails] = None
instructions: Optional[str] = None
metadata: Optional[Metadata] = None
model: str
object: Literal["response"] = "response"
output: list[ResponseOutputItem]
parallel_tool_calls: bool
temperature: float
tool_choice: ToolChoice
tools: list[Tool]
top_p: float
background: bool
max_output_tokens: int
max_tool_calls: Optional[int] = None
previous_response_id: Optional[str] = None
prompt: Optional[ResponsePrompt] = None
reasoning: Optional[Reasoning] = None
service_tier: Literal["auto", "default", "flex", "scale", "priority"]
status: ResponseStatus
text: Optional[ResponseTextConfig] = None
top_logprobs: int
truncation: Literal["auto", "disabled"]
usage: Optional[ResponseUsage] = None
user: Optional[str] = None

@classmethod
def from_request(
cls,
request: ResponsesRequest,
sampling_params: SamplingParams,
model_name: str,
created_time: int,
output: list[ResponseOutputItem],
status: ResponseStatus,
usage: Optional[ResponseUsage] = None,
) -> "ResponsesResponse":
return cls(
id=request.request_id,
created_at=created_time,
instructions=request.instructions,
metadata=request.metadata,
model=model_name,
output=output,
parallel_tool_calls=request.parallel_tool_calls,
temperature=sampling_params.temperature,
tool_choice=request.tool_choice,
tools=request.tools,
top_p=sampling_params.top_p,
background=request.background,
max_output_tokens=sampling_params.max_tokens,
max_tool_calls=request.max_tool_calls,
previous_response_id=request.previous_response_id,
prompt=request.prompt,
reasoning=request.reasoning,
service_tier=request.service_tier,
status=status,
text=request.text,
top_logprobs=sampling_params.logprobs,
truncation=request.truncation,
user=request.user,
usage=usage,
)
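
Once generation finishes, the response envelope can be assembled from the request and the resolved sampling parameters. The sketch below continues the one above and is illustrative only: output_items stands in for whatever list of ResponseOutputItem objects was built from the parsed harmony messages, and the usage numbers are invented.

# Illustrative sketch only -- continues the sketch above, not part of this diff.
response = ResponsesResponse.from_request(
    request=request,                  # the ResponsesRequest from the earlier sketch
    sampling_params=sampling_params,  # from request.to_sampling_params(...)
    model_name="example-model",
    created_time=int(time.time()),
    output=output_items,              # list[ResponseOutputItem] built from parsed messages
    status="completed",
    usage=ResponseUsage(
        input_tokens=32,
        input_tokens_details=InputTokensDetails(cached_tokens=0),
        output_tokens=128,
        output_tokens_details=OutputTokensDetails(reasoning_tokens=64),
        total_tokens=160,
    ),
)
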


class ResponsesStreamResponse(OpenAIBaseModel):
response: ResponsesResponse
sequence_number: int
type: Literal["response.created", "response.in_progress",
"response.completed", "response.failed",
"response.incomplete"]


def encode_opaque_state(opaque_state: Optional[bytes]) -> Optional[str]:
if opaque_state is None:
return None