diff --git a/tensorrt_llm/serve/harmony_adapter.py b/tensorrt_llm/serve/harmony_adapter.py index c4703873de9..a46e7c5ed45 100644 --- a/tensorrt_llm/serve/harmony_adapter.py +++ b/tensorrt_llm/serve/harmony_adapter.py @@ -57,7 +57,8 @@ def __init__(self, # Normal case: filter based on available tools self.should_filter_tools = True self.available_tools = { - tool.get("function", {}).get("name", "") + tool.get("function", {}).get("name", "") if tool.get( + "name", None) is None else tool.get("name") for tool in available_tools } self.available_tools.discard("") @@ -78,6 +79,9 @@ def __init__(self, logger.debug("Created HarmonyStreamState for request %s", request_id) + def get_parser(self) -> StreamableParser: + return self.parser + def process_token_batch(self, tokens: list[int]) -> list[dict[str, Any]]: """ Process a batch of tokens while maintaining parsing state. @@ -125,6 +129,42 @@ def process_token_batch(self, tokens: list[int]) -> list[dict[str, Any]]: return deltas + def process_token_batch_to_messages(self, + tokens: list[int]) -> list[Message]: + """ + Process a batch of tokens while maintaining parsing state. + Returns OpenAI Messages for Responses API + """ + self.tokens_processed += len(tokens) + + for token in tokens: + # Store previous state for transition detection + prev_channel = self.parser.current_channel + prev_recipient = self.parser.current_recipient + + # Process the token + self.parser.process(token) + + # Detect channel/recipient transitions AFTER processing each token + channel_changed = prev_channel != self.parser.current_channel + recipient_changed = prev_recipient != self.parser.current_recipient + + if channel_changed or recipient_changed: + # Mark any active tool calls as completed if we're leaving a tool call + if prev_channel == "commentary" and prev_recipient and "functions." in str( + prev_recipient): + func_name = str(prev_recipient).split("functions.")[-1] + for tool_id, tool_info in self.tool_calls.items(): + if tool_info["name"] == func_name and tool_info.get( + "active", True): + tool_info["active"] = False + + # Reset channel state for new channel + self.channel_started = False + self.current_channel_state = None + + return self.parser.messages + def _create_closing_token_delta(self) -> dict[str, Any] | None: """Create closing token delta for channel transition.""" if not self.current_channel_state or not self.channel_started: @@ -317,6 +357,9 @@ def __init__( "<|constrain|>": 200009, } + def get_stream_state(self, request_id: str) -> HarmonyStreamState | None: + return self._stream_states.get(request_id, None) + def get_stop_tokens(self) -> list[int]: """ Return the list of stop token IDs for Harmony format. @@ -1214,6 +1257,42 @@ def stateful_stream_harmony_tokens_to_openai_deltas( # Return empty deltas to continue processing return [] + def stateful_stream_harmony_tokens_to_openai_messages( + self, + request_id: str, + tokens: list[int], + available_tools: list[dict[str, Any]] | None = None, + tool_choice: str | None = None) -> list[Message]: + """ + Process tokens using stateful parsing. + + This method maintains persistent state across multiple calls for the same request, + ensuring proper channel transitions and tool call handling. 
+ + Args: + request_id: Request ID to maintain state per request + tokens: New tokens from this iteration + available_tools: Available tools for filtering + + Returns: + List of OpenAI Messages + """ + stream_state = self._stream_states.get(request_id, None) + if stream_state is None: + stream_state = self.create_stream_state(request_id, available_tools, + tool_choice) + + try: + messages = stream_state.process_token_batch_to_messages(tokens) + return messages + except (HarmonyError, UnicodeDecodeError, ValueError): + logger.error( + f"Streaming: Failed to process token batch of {len(tokens)} tokens for request {request_id}", + ) + logger.debug(f"Problematic streaming tokens: {tokens}") + + return [] + def create_openai_streaming_response( self, request_id: str, diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py index f02178bacf5..acfbff14d23 100644 --- a/tensorrt_llm/serve/openai_protocol.py +++ b/tensorrt_llm/serve/openai_protocol.py @@ -11,9 +11,16 @@ ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam from openai.types.chat import \ ChatCompletionMessageParam as OpenAIChatCompletionMessageParam +from openai.types.responses import (ResponseFunctionToolCall, + ResponseInputItemParam, ResponseOutputItem, + ResponsePrompt, ResponseReasoningItem, + ResponseStatus, ResponseTextConfig) +from openai.types.responses.response import ToolChoice +from openai.types.responses.tool import Tool +from openai.types.shared import Metadata, Reasoning from openai_harmony import ReasoningEffort from pydantic import BaseModel, ConfigDict, Field, model_validator -from typing_extensions import Annotated, Required, TypedDict +from typing_extensions import Annotated, Required, TypeAlias, TypedDict from tensorrt_llm.executor.request import LoRARequest from tensorrt_llm.llmapi import DisaggregatedParams as LlmDisaggregatedParams @@ -665,6 +672,208 @@ def check_suffix(cls, data): return data +ResponseInputOutputItem: TypeAlias = Union[ResponseInputItemParam, + ResponseReasoningItem, + ResponseFunctionToolCall] + + +class ResponsesRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/responses/create + background: Optional[bool] = False + include: Optional[list[ + Literal[ + "code_interpreter_call.outputs", + "computer_call_output.output.image_url", + "file_search_call.results", + "message.input_image.image_url", + "message.output_text.logprobs", + "reasoning.encrypted_content", + ], + ]] = None + input: Union[str, list[ResponseInputOutputItem]] + instructions: Optional[str] = None + max_output_tokens: Optional[int] = None + max_tool_calls: Optional[int] = None + metadata: Optional[Metadata] = None + model: str + parallel_tool_calls: Optional[bool] = False + previous_response_id: Optional[str] = None + prompt: Optional[ResponsePrompt] = None + reasoning: Optional[Reasoning] = None + service_tier: Literal["auto", "default", "flex", "scale", + "priority"] = "auto" + store: Optional[bool] = True + stream: Optional[bool] = False + temperature: Optional[float] = None + text: Optional[ResponseTextConfig] = None + tool_choice: ToolChoice = "auto" + tools: list[Tool] = Field(default_factory=list) + top_logprobs: Optional[int] = 0 + top_p: Optional[float] = None + truncation: Optional[Literal["auto", "disabled"]] = "disabled" + user: Optional[str] = None + + request_id: str = Field( + default_factory=lambda: f"resp_{str(uuid.uuid4().hex)}", + description=( + "The request_id related to this 
request. If the caller does " + "not set it, a random_uuid will be generated. This id is used " + "through out the inference process and return in response."), + ) + + _DEFAULT_SAMPLING_PARAMS = { + "temperature": 1.0, + "top_p": 1.0, + } + + def to_sampling_params( + self, + default_max_tokens: int, + default_sampling_params: Optional[dict] = None, + ) -> SamplingParams: + if self.max_output_tokens is None: + max_tokens = default_max_tokens + else: + max_tokens = min(self.max_output_tokens, default_max_tokens) + + default_sampling_params = default_sampling_params or {} + if (temperature := self.temperature) is None: + temperature = default_sampling_params.get( + "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"]) + if (top_p := self.top_p) is None: + top_p = default_sampling_params.get( + "top_p", self._DEFAULT_SAMPLING_PARAMS["top_p"]) + stop_token_ids = default_sampling_params.get("stop_token_ids") + + # Structured output + guided_decoding = None + if self.text is not None and self.text.format is not None: + response_format = self.text.format + if response_format.type == "json_schema": + guided_decoding = GuidedDecodingParams( + json=response_format.schema_) + elif response_format.type == "json_object": + raise NotImplementedError("json_object is not supported") + + return SamplingParams( + temperature=temperature, + top_p=top_p, + max_tokens=max_tokens, + logprobs=self.top_logprobs, + stop_token_ids=stop_token_ids, + guided_decoding=guided_decoding, + ) + + @model_validator(mode="before") + @classmethod + def validate_background(cls, data): + if not data.get("background"): + return data + if not data.get("store", True): + raise ValueError("background can only be used when `store` is true") + return data + + @model_validator(mode="before") + @classmethod + def validate_prompt(cls, data): + if data.get("prompt") is not None: + raise ValueError("prompt template is not supported") + return data + + +class InputTokensDetails(OpenAIBaseModel): + cached_tokens: int + + +class OutputTokensDetails(OpenAIBaseModel): + reasoning_tokens: int + + +class ResponseUsage(OpenAIBaseModel): + input_tokens: int + input_tokens_details: InputTokensDetails + output_tokens: int + output_tokens_details: OutputTokensDetails + total_tokens: int + + +class ResponsesResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"resp_{str(uuid.uuid4().hex)}") + created_at: int = Field(default_factory=lambda: int(time.time())) + # error: Optional[ResponseError] = None + # incomplete_details: Optional[IncompleteDetails] = None + instructions: Optional[str] = None + metadata: Optional[Metadata] = None + model: str + object: Literal["response"] = "response" + output: list[ResponseOutputItem] + parallel_tool_calls: bool + temperature: float + tool_choice: ToolChoice + tools: list[Tool] + top_p: float + background: bool + max_output_tokens: int + max_tool_calls: Optional[int] = None + previous_response_id: Optional[str] = None + prompt: Optional[ResponsePrompt] = None + reasoning: Optional[Reasoning] = None + service_tier: Literal["auto", "default", "flex", "scale", "priority"] + status: ResponseStatus + text: Optional[ResponseTextConfig] = None + top_logprobs: int + truncation: Literal["auto", "disabled"] + usage: Optional[ResponseUsage] = None + user: Optional[str] = None + + @classmethod + def from_request( + cls, + request: ResponsesRequest, + sampling_params: SamplingParams, + model_name: str, + created_time: int, + output: list[ResponseOutputItem], + status: ResponseStatus, + usage: 
Optional[ResponseUsage] = None, + ) -> "ResponsesResponse": + return cls( + id=request.request_id, + created_at=created_time, + instructions=request.instructions, + metadata=request.metadata, + model=model_name, + output=output, + parallel_tool_calls=request.parallel_tool_calls, + temperature=sampling_params.temperature, + tool_choice=request.tool_choice, + tools=request.tools, + top_p=sampling_params.top_p, + background=request.background, + max_output_tokens=sampling_params.max_tokens, + max_tool_calls=request.max_tool_calls, + previous_response_id=request.previous_response_id, + prompt=request.prompt, + reasoning=request.reasoning, + service_tier=request.service_tier, + status=status, + text=request.text, + top_logprobs=sampling_params.logprobs, + truncation=request.truncation, + user=request.user, + usage=usage, + ) + + +class ResponsesStreamResponse(OpenAIBaseModel): + response: ResponsesResponse + sequence_number: int + type: Literal["response.created", "response.in_progress", + "response.completed", "response.failed", + "response.incomplete"] + + def encode_opaque_state(opaque_state: Optional[bytes]) -> Optional[str]: if opaque_state is None: return None diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 6fa7ee1952b..de245046359 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -40,12 +40,20 @@ CompletionResponse, CompletionResponseChoice, ErrorResponse, ModelCard, - ModelList, UsageInfo, + ModelList, ResponsesRequest, + UsageInfo, to_llm_disaggregated_params) from tensorrt_llm.serve.postprocess_handlers import ( ChatPostprocArgs, CompletionPostprocArgs, chat_response_post_processor, chat_stream_post_processor, completion_response_post_processor, completion_stream_post_processor) +from tensorrt_llm.serve.responses_utils import ConversationHistoryStore +from tensorrt_llm.serve.responses_utils import \ + create_response as responses_api_create_response +from tensorrt_llm.serve.responses_utils import \ + process_streaming_events as responses_api_process_streaming_events +from tensorrt_llm.serve.responses_utils import \ + request_preprocess as responses_api_request_preprocess from tensorrt_llm.version import __version__ as VERSION from .._utils import nvtx_mark, set_prometheus_multiproc_dir @@ -82,6 +90,12 @@ def __init__(self, logger.debug("Failed to load AutoConfig for %s", hf_tokenizer_path) self.model_config = None + # Enable response storage for Responses API + self.enable_store = True + if len(os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE", "")) > 0: + self.enable_store = False + self.conversation_store = ConversationHistoryStore() + model_dir = Path(model) if model_dir.exists() and model_dir.is_dir(): self.model = model_dir.name @@ -165,6 +179,20 @@ def create_error_response( return JSONResponse(content=error_response.model_dump(), status_code=error_response.code) + def _create_invalid_response_id_error(self, response_id: str) -> Response: + return self.create_error_response( + err_type="InvalidRequestError", + message=(f"Invalid 'response_id': '{response_id}'. 
" + "Expected an ID that begins with 'resp'."), + ) + + def _create_response_id_not_found_error(self, response_id: str) -> Response: + return self.create_error_response( + err_type="InvalidRequestError", + message=f"Response with id '{response_id}' not found.", + status_code=HTTPStatus.NOT_FOUND, + ) + def register_routes(self): self.app.add_api_route("/health", self.health, methods=["GET"]) self.app.add_api_route("/health_generate", self.health_generate, methods=["GET"]) @@ -181,6 +209,9 @@ def register_routes(self): self.app.add_api_route("/v1/chat/completions", self.openai_chat if not self.use_harmony else self.chat_harmony, methods=["POST"]) + self.app.add_api_route("/v1/responses", + self.openai_responses, + methods=["POST"]) if self.llm.args.return_perf_metrics: # register /prometheus/metrics self.mount_metrics() @@ -739,6 +770,80 @@ async def chat_harmony(self, request: ChatCompletionRequest, raw_request: Reques logger.debug("Error details: %s", traceback.format_exc()) return self.create_error_response(message=str(e), err_type="internal_error") + async def openai_responses(self, request: ResponsesRequest, raw_request: Request) -> Response: + async def create_stream_response(generator, request: ResponsesRequest, sampling_params) -> AsyncGenerator[str, None]: + async for event_data in responses_api_process_streaming_events( + request=request, + sampling_params=sampling_params, + generator=generator, + harmony_adapter=self.harmony_adapter, + model_name=self.model, + conversation_store=self.conversation_store, + enable_store=self.enable_store + ): + yield event_data + + try: + if not self.use_harmony: + raise NotImplementedError("Responses API only supports harmony format for now") + + # Initialize HarmonyAdapter + # NOTE: WAR for Disagg failure, may affect perf if no warmup + if not self.harmony_adapter: + self.harmony_adapter = HarmonyAdapter() + + if request.background: + logger.warning("Request.background is not supported yet, will fallback to foreground processing.") + + # Get prev response + prev_response = None + if self.enable_store: + prev_response_id = request.previous_response_id + if prev_response_id is not None: + if not prev_response_id.startswith("resp_"): + return self._create_invalid_response_id_error(prev_response_id) + + prev_response = await self.conversation_store.load_response(prev_response_id) + if prev_response is None: + logger.debug(f"response_id {prev_response_id} not found") + return self._create_response_id_not_found_error(prev_response_id) + + input_tokens, sampling_params = await responses_api_request_preprocess( + request, prev_response, self.harmony_adapter, self.conversation_store, self.enable_store) + + promise = self.llm.generate_async( + inputs=input_tokens, + sampling_params=sampling_params, + streaming=request.stream, + ) + + asyncio.create_task(self.await_disconnected(raw_request, promise)) + + if request.stream: + return StreamingResponse( + create_stream_response(promise, request, sampling_params), + media_type="text/event-stream" + ) + else: + return await responses_api_create_response( + generator=promise, + request=request, + sampling_params=sampling_params, + model_name=self.model, + conversation_store=self.conversation_store, + generation_result=None, + enable_store=self.enable_store) + except CppExecutorError: + logger.error(traceback.format_exc()) + # If internal executor error is raised, shutdown the server + signal.raise_signal(signal.SIGINT) + except Exception as e: + logger.error(traceback.format_exc()) + return 
self.create_error_response(str(e)) + + return JSONResponse(content={"detail": "None"}) + + async def __call__(self, host, port): # Store the binding address for server registration self.binding_addr = f"http://{host}:{port}" diff --git a/tensorrt_llm/serve/responses_utils.py b/tensorrt_llm/serve/responses_utils.py new file mode 100644 index 00000000000..d4a6af268c4 --- /dev/null +++ b/tensorrt_llm/serve/responses_utils.py @@ -0,0 +1,848 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import asyncio +import json +import os +import time +import uuid +from collections.abc import AsyncGenerator +from copy import copy +from typing import Literal, Optional, OrderedDict, Union + +# yapf: disable +from openai.types.responses import (ResponseCompletedEvent, + ResponseContentPartAddedEvent, + ResponseContentPartDoneEvent, + ResponseCreatedEvent, + ResponseFunctionToolCall, + ResponseInProgressEvent, ResponseOutputItem, + ResponseOutputItemAddedEvent, + ResponseOutputItemDoneEvent, + ResponseOutputMessage, ResponseOutputText, + ResponseReasoningItem, + ResponseReasoningTextDeltaEvent, + ResponseReasoningTextDoneEvent, + ResponseTextDeltaEvent, + ResponseTextDoneEvent) +# yapf: enable +from openai.types.responses.response_function_web_search import ( + ActionFind, ActionOpenPage, ActionSearch, ResponseFunctionWebSearch) +from openai.types.responses.response_reasoning_item import Content +from openai.types.responses.tool import Tool +from openai_harmony import (Author, Conversation, DeveloperContent, + HarmonyEncodingName, Message, ReasoningEffort, Role, + StreamState, SystemContent, TextContent, + ToolDescription, load_harmony_encoding) + +from tensorrt_llm.llmapi import SamplingParams +from tensorrt_llm.llmapi.llm import RequestOutput +from tensorrt_llm.logger import logger +from tensorrt_llm.serve.openai_protocol import (OpenAIBaseModel, + ResponseInputOutputItem, + ResponsesRequest, + ResponsesResponse) + +from .harmony_adapter import HarmonyAdapter + +REASONING_EFFORT = { + "high": ReasoningEffort.HIGH, + "medium": ReasoningEffort.MEDIUM, + "low": ReasoningEffort.LOW, +} + +ENABLE_RESPONSES_DEBUG_MSG = False + + +def responses_debug_log(msg): + if ENABLE_RESPONSES_DEBUG_MSG: + logger.debug(msg) + + +_harmony_encoding = None + + +def random_uuid(): + return str(uuid.uuid4().hex) + + +def get_encoding(): + global _harmony_encoding + if _harmony_encoding is None: + _harmony_encoding = load_harmony_encoding( + HarmonyEncodingName.HARMONY_GPT_OSS) + return _harmony_encoding + + +def decode_tokens(tokens): + return get_encoding().decode(tokens) + + +def parse_response_input( + input_msg: ResponseInputOutputItem, + prev_responses: list[Union[ResponseOutputItem, ResponseReasoningItem]] +) -> Message: + if not isinstance(input_msg, dict): + input_msg = input_msg.model_dump() + + responses_debug_log(f"------- Parsing input -----------") + responses_debug_log(input_msg) + responses_debug_log("") + + if "type" not in input_msg or input_msg["type"] == "message": + role = input_msg["role"] + content = input_msg["content"] + if role == "system": + # User is trying to set a system message. 
Change it to: + # <|start|>developer<|message|># Instructions + # {instructions}<|end|> + role = "developer" + text_prefix = "Instructions:\n" + else: + text_prefix = "" + if isinstance(content, str): + msg = Message.from_role_and_content(role, text_prefix + content) + elif isinstance(content, list): + contents = [ + TextContent(text=text_prefix + c["text"]) for c in content + ] + msg = Message.from_role_and_contents(role, contents) + else: + logger.warning("Responses API: Invalid input message type") + msg = None + elif input_msg["type"] == "function_call_output": + call_id = input_msg["call_id"] + call_response: Optional[ResponseFunctionToolCall] = None + for prev_response in reversed(prev_responses): + if isinstance(prev_response, ResponseFunctionToolCall + ) and prev_response.call_id == call_id: + call_response = prev_response + break + if call_response is None: + raise ValueError(f"No call message found for {call_id}") + msg = Message.from_author_and_content( + Author.new(Role.TOOL, f"functions.{call_response.name}"), + input_msg["output"]) + elif input_msg["type"] == "reasoning": + content = input_msg["content"] + assert len(content) == 1 + msg = Message.from_role_and_content(Role.ASSISTANT, content[0]["text"]) + elif input_msg["type"] == "function_call": + msg = Message.from_role_and_content(Role.ASSISTANT, + input_msg["arguments"]) + msg = msg.with_channel("commentary") + msg = msg.with_recipient(f"functions.{input_msg['name']}") + msg = msg.with_content_type("json") + else: + raise ValueError(f"Unknown input type: {input_msg['type']}") + return msg + + +class ConversationHistoryStore: + + def __init__(self, resp_capacity: int = 16, max_conversations=32): + self.response_capacity = resp_capacity + self.conversation_capacity = resp_capacity * 4 + self.max_conversations = max_conversations + + self.responses_lock = asyncio.Lock() + self.responses: OrderedDict[str, ResponsesResponse] = OrderedDict() + + self.conversations_lock = asyncio.Lock() + self.conversations: OrderedDict[str, list[Message]] = OrderedDict() + self.response_to_conversation: dict[str, str] = {} + self.conversation_to_response: dict[str, str] = {} + + async def load_response(self, resp_id: str) -> ResponsesResponse: + responses_debug_log(f"ConversationHistoryStore loading resp: {resp_id}") + async with self.responses_lock: + return self.responses.get(resp_id) + + async def store_response(self, + resp: ResponsesResponse, + resp_msgs: Optional[list[Message]] = [], + prev_resp_id: Optional[str] = None) -> None: + resp_id = resp.id + responses_debug_log(f"ConversationHistoryStore storing resp: {resp_id}") + async with self.responses_lock: + self.responses[resp_id] = resp + if len(self.responses) > self.response_capacity: + self._pop_response() + + async with self.conversations_lock: + conversation_id: str + if resp_id in self.response_to_conversation: + conversation_id = self.response_to_conversation[resp_id] + self.conversations[conversation_id].extend(resp_msgs) + elif prev_resp_id is not None: + conversation_id = self.response_to_conversation[prev_resp_id] + self.conversations[conversation_id].extend(resp_msgs) + while len(self.conversations[conversation_id] + ) > self.conversation_capacity: + self._pop_conversation(resp_id) + else: + conversation_id = random_uuid() + self.conversations[conversation_id] = resp_msgs + + responses_debug_log( + f" * storing at conversation id: {conversation_id}") + + self.response_to_conversation[resp_id] = conversation_id + self.conversation_to_response[conversation_id] = resp_id + 
self._update_visited_conversation(conversation_id) + + async def store_messages(self, resp_id: str, msgs: list[Message], + prev_resp_id: Optional[str]): + responses_debug_log(f"ConversationHistoryStore storing msg:") + for msg in msgs: + responses_debug_log(f" -> {msg.to_json()}") + + async with self.conversations_lock: + conversation_id: str + if prev_resp_id is not None: + conversation_id = self.response_to_conversation[prev_resp_id] + else: + conversation_id = random_uuid() + + responses_debug_log( + f" * storing at conversation: {conversation_id}") + self.conversations[conversation_id] = msgs + if len(self.conversations[conversation_id] + ) > self.conversation_capacity: + self._pop_conversation(resp_id) + + self.response_to_conversation[resp_id] = conversation_id + self.conversation_to_response[conversation_id] = resp_id + self._update_visited_conversation(conversation_id) + + async def append_messages(self, resp_id: str, msgs: list[Message]): + responses_debug_log(f"ConversationHistoryStore appending msgs:") + for msg in msgs: + responses_debug_log(f" -> {msg.to_json()}") + + async with self.conversations_lock: + assert resp_id in self.response_to_conversation + conversation_id = self.response_to_conversation[resp_id] + + responses_debug_log( + f" * appending at conversation: {conversation_id}") + self.conversations[conversation_id].extend(msgs) + if len(self.conversations[conversation_id] + ) > self.conversation_capacity: + self._pop_conversation(resp_id) + self._update_visited_conversation(conversation_id) + + async def get_conversation_history(self, resp_id: str) -> list[Message]: + responses_debug_log(f"ConversationHistoryStore getting prev_msgs:") + responses_debug_log(f" -> prev_resp_id: {resp_id}") + async with self.conversations_lock: + if resp_id in self.response_to_conversation: + conversation_id = self.response_to_conversation[resp_id] + self._update_visited_conversation(conversation_id) + return self.conversations.get(conversation_id, []) + + return [] + + def _update_visited_conversation(self, conversation_id) -> None: + if conversation_id not in self.conversations: + return + + self.conversations.move_to_end(conversation_id) + if len(self.conversations) > self.max_conversations: + removed_id, _ = self.conversations.popitem(last=False) + responses_debug_log( + f"ConversationHistoryStore Removing conversation {removed_id}") + removed_resp_id = self.conversation_to_response[removed_id] + # The responses may have been removed due to response capacity + if removed_resp_id in self.response_to_conversation: + self.response_to_conversation.pop(removed_resp_id) + self.conversation_to_response.pop(removed_id) + + def _pop_conversation(self, resp_id) -> None: + conversation_id = self.response_to_conversation.get(resp_id, None) + if conversation_id is None: + return + + conversation = self.conversations[conversation_id] + first_conversation_range = [] + for i, msg in enumerate(conversation): + if msg.author.role == Role.USER: + first_conversation_range.append(i) + elif msg.channel == "final": + first_conversation_range.append(i) + break + del conversation[ + first_conversation_range[0]:first_conversation_range[1] + 1] + + def _pop_response(self) -> None: + responses_debug_log(f"responses type: {type(self.responses)}") + resp_id, _ = self.responses.popitem(last=False) + if resp_id in self.response_to_conversation: + self.response_to_conversation.pop(resp_id) + + +def get_system_message( + model_identity: Optional[str] = None, + reasoning_effort: Optional[Literal["high", "medium", 
"low"]] = None, + start_date: Optional[str] = None, + browser_description: Optional[str] = None, + python_description: Optional[str] = None, +) -> Message: + sys_msg_content = SystemContent.new() + if model_identity is not None: + sys_msg_content = sys_msg_content.with_model_identity(model_identity) + if reasoning_effort is not None: + sys_msg_content = sys_msg_content.with_reasoning_effort( + REASONING_EFFORT[reasoning_effort]) + if start_date: + sys_msg_content = sys_msg_content.with_conversation_start_date( + start_date) + if browser_description is not None: + sys_msg_content = sys_msg_content.with_tools(browser_description) + if python_description is not None: + sys_msg_content = sys_msg_content.with_tools(python_description) + sys_msg = Message.from_role_and_content(Role.SYSTEM, sys_msg_content) + return sys_msg + + +def get_developer_message(instructions: Optional[str] = None, + tools: Optional[list[Tool]] = None) -> Message: + dev_msg_content = DeveloperContent.new() + if instructions is not None: + dev_msg_content = dev_msg_content.with_instructions(instructions) + if tools is not None: + function_tools = [] + for tool in tools: + if tool.type in ("web_search_preview", "code_interpreter"): + # These are built-in tools that are added to the system message. + pass + elif tool.type == "function": + function_tools.append(tool) + else: + raise ValueError(f"tool type {tool.type} not supported") + if function_tools: + function_tool_descriptions = [ + ToolDescription.new( + name=tool.name, + description=tool.description, + parameters=tool.parameters, + ) for tool in function_tools + ] + dev_msg_content = dev_msg_content.with_function_tools( + function_tool_descriptions) + dev_msg = Message.from_role_and_content(Role.DEVELOPER, dev_msg_content) + return dev_msg + + +def get_user_message(content: str) -> Message: + return Message.from_role_and_content(Role.USER, content) + + +def construct_harmony_messages( + request: ResponsesRequest, + prev_response: Optional[ResponsesResponse], + prev_msgs: list[Message] = [], +) -> list[Message]: + """Construct messages from request input, includes conversation history messages if exists.""" + messages: list[Message] = [] + if prev_response is None: + # New conversation. + reasoning_effort = (request.reasoning.effort + if request.reasoning else None) + sys_msg = get_system_message(reasoning_effort=reasoning_effort, ) + messages.append(sys_msg) + dev_msg = get_developer_message(request.instructions, request.tools) + messages.append(dev_msg) + else: + messages.extend(prev_msgs) + # Append the new input. + # Responses API supports simple text inputs without chat format. + if isinstance(request.input, str): + messages.append(get_user_message(request.input)) + else: + if prev_response is not None: + prev_outputs = copy(prev_response.output) + else: + prev_outputs = [] + for input_msg in request.input: + msg = parse_response_input(input_msg, prev_outputs) + if msg is not None: + messages.append(msg) + # User passes in a a tool call request and its output. We need + # to add the tool call request to prev_outputs so that the + # parse_response_input can find the tool call request when + # parsing the tool call output. 
+ if isinstance(input_msg, ResponseFunctionToolCall): + prev_outputs.append(input_msg) + return messages + + +def render_for_completion(messages: list[Message]) -> list[int]: + conversation = Conversation.from_messages(messages) + responses_debug_log("Rendering conversation:") + responses_debug_log(conversation.to_json()) + token_ids = get_encoding().render_conversation_for_completion( + conversation, Role.ASSISTANT) + return token_ids + + +def parse_output_tokens(tokens: list[int]) -> list[Message]: + return get_encoding().parse_messages_from_completion_tokens( + tokens, role=Role.ASSISTANT) + + +def parse_output_message(message: Message) -> list[ResponseOutputItem]: + """ + Parse a Harmony message into a list of output response items. + """ + if message.author.role != "assistant": + # This is a message from a tool to the assistant (e.g., search result). + # Don't include it in the final output for now. This aligns with + # OpenAI's behavior on models like o4-mini. + return [] + + output_items: list[ResponseOutputItem] = [] + recipient = message.recipient + if recipient is not None and recipient.startswith("browser."): + if len(message.content) != 1: + raise ValueError("Invalid number of contents in browser message") + content = message.content[0] + browser_call = json.loads(content.text) + # TODO: translate to url properly! + if recipient == "browser.search": + action = ActionSearch( + query=f"cursor:{browser_call.get('query', '')}", type="search") + elif recipient == "browser.open": + action = ActionOpenPage(url=f"cursor:{browser_call.get('url', '')}", + type="open_page") + elif recipient == "browser.find": + action = ActionFind(pattern=browser_call["pattern"], + url=f"cursor:{browser_call.get('url', '')}", + type="find") + else: + raise ValueError(f"Unknown browser action: {recipient}") + web_search_item = ResponseFunctionWebSearch( + id=f"ws_{random_uuid()}", + action=action, + status="completed", + type="web_search_call", + ) + output_items.append(web_search_item) + elif message.channel == "analysis": + for content in message.content: + reasoning_item = ResponseReasoningItem( + id=f"rs_{random_uuid()}", + summary=[], + type="reasoning", + content=[Content(text=content.text, type="reasoning_text")], + status=None, + ) + output_items.append(reasoning_item) + elif message.channel == "commentary": + if message.recipient is None: + pass + elif message.recipient.startswith("functions."): + function_name = message.recipient.split(".")[-1] + for content in message.content: + random_id = random_uuid() + response_item = ResponseFunctionToolCall( + arguments=content.text, + call_id=f"call_{random_id}", + type="function_call", + name=function_name, + id=f"fc_{random_id}", + ) + output_items.append(response_item) + elif message.recipient.startswith( + "python") or message.recipient.startswith("browser"): + for content in message.content: + reasoning_item = ResponseReasoningItem( + id=f"rs_{random_uuid()}", + summary=[], + type="reasoning", + content=[Content(text=content.text, type="reasoning_text")], + status=None, + ) + output_items.append(reasoning_item) + else: + raise ValueError(f"Unknown recipient: {message.recipient}") + elif message.channel == "final": + contents = [] + for content in message.content: + output_text = ResponseOutputText( + text=content.text, + annotations=[], # TODO + type="output_text", + logprobs=None, # TODO + ) + contents.append(output_text) + text_item = ResponseOutputMessage( + id=f"msg_{random_uuid()}", + content=contents, + role=message.author.role, + 
status="completed", + type="message", + ) + output_items.append(text_item) + else: + raise ValueError(f"Unknown channel: {message.channel}") + return output_items + + +def finish_reason_mapping(finish_reason: str) -> str: + match finish_reason: + case 'stop': + return 'completed' + case 'length': + return 'incomplete' + case 'timeout': + return 'failed' + case 'cancelled': + return 'cancelled' + + raise RuntimeError("Should never reach here!") + + +async def request_preprocess(request: ResponsesRequest, + prev_response: Optional[ResponsesResponse], + harmony_adapter: HarmonyAdapter, + conversation_store: ConversationHistoryStore, + enable_store=False): + # TODO: fix default_max_tokens + sampling_params = request.to_sampling_params( + default_max_tokens=int(16384), + default_sampling_params={ + "stop_token_ids": harmony_adapter.get_stop_tokens() + }) + + prev_response_id = request.previous_response_id + + # TODO: better way to enable metrics + if len(os.getenv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH", "")) > 0: + sampling_params.return_perf_metrics = True + + prev_msgs = [] + if enable_store: + prev_msgs = await conversation_store.get_conversation_history( + prev_response_id) + + responses_debug_log(f"Prev msgs:") + for msg in prev_msgs: + responses_debug_log(f" -> {msg.to_json()}") + + messages = construct_harmony_messages(request, + prev_response, + prev_msgs=prev_msgs) + + if enable_store and request.store: + # Remove reasoning messages to save token usage during multi-turn conversation + msgs_to_store = [msg for msg in messages if msg.channel != "analysis"] + await conversation_store.store_messages(request.request_id, + msgs_to_store, prev_response_id) + + input_tokens = render_for_completion(messages) + + responses_debug_log("======= Complete Inputs to model =======") + responses_debug_log(decode_tokens(input_tokens)) + responses_debug_log("========================================") + return input_tokens, sampling_params + + +async def create_response( + generator, + request: ResponsesRequest, + sampling_params, + model_name: str, + conversation_store: ConversationHistoryStore, + generation_result: RequestOutput = None, + enable_store=False, + create_time: int = None, +) -> ResponsesResponse: + + final_res: Optional[RequestOutput] = None + response_creation_time = create_time if create_time is not None else int( + time.time()) + prev_response_id = request.previous_response_id + + if generation_result is not None: + final_res = generation_result + else: + final_res = await generator + + if final_res is None: + raise RuntimeError("No output generated or provided") + + responses_debug_log("================================================") + responses_debug_log("RAW MODEL OUTPUT:") + responses_debug_log(final_res.outputs) + responses_debug_log("================================================") + + output_messages = parse_output_tokens(final_res.outputs[0].token_ids) + + responses_debug_log(f"output messages: {len(output_messages)}") + for msg in output_messages: + responses_debug_log(f" -> {msg.to_json()}") + + # prepare responses output + output_content = [] + for msg in output_messages: + output_content.extend(parse_output_message(msg)) + + response = ResponsesResponse.from_request( + request=request, + sampling_params=sampling_params, + model_name=model_name, + created_time=response_creation_time, + output=output_content, + status=finish_reason_mapping(final_res.outputs[0].finish_reason), + ) + + if enable_store and request.store: + await conversation_store.store_response(resp=response, + 
resp_msgs=output_messages, + prev_resp_id=prev_response_id) + + responses_debug_log("========== Response ===========") + responses_debug_log(response) + responses_debug_log("===============================") + return response + + +async def process_streaming_events( + request: ResponsesRequest, + sampling_params: SamplingParams, + generator, + harmony_adapter: HarmonyAdapter, + model_name: str, + conversation_store: ConversationHistoryStore, + create_time: int = None, + enable_store=False) -> AsyncGenerator[str, None]: + sequence_number = 0 + response_creation_time = create_time if create_time is not None else int( + time.time()) + final_res: Optional[RequestOutput] = None + + def _send_event(event: OpenAIBaseModel): + nonlocal sequence_number + # Set sequence_number if the event has this attribute + if hasattr(event, 'sequence_number'): + event.sequence_number = sequence_number + sequence_number += 1 + # Get event type from the event's type field if it exists + event_type = getattr(event, 'type', 'unknown') + return (f"event: {event_type}\n" + f"data: {event.model_dump_json(indent=None)}\n\n") + + current_content_index = 0 # FIXME: this number is never changed + current_output_index = 0 + current_item_id = "" # FIXME: this number is never changed + sent_output_item_added = False + + initial_response = ResponsesResponse.from_request( + request, + sampling_params, + model_name=model_name, + created_time=response_creation_time, + output=[], + status="in_progress", + usage=None, + ).model_dump() + yield _send_event( + ResponseCreatedEvent( + type="response.created", + sequence_number=-1, + response=initial_response, + )) + yield _send_event( + ResponseInProgressEvent( + type="response.in_progress", + sequence_number=-1, + response=initial_response, + )) + + tools = [tool.model_dump() for tool in request.tools] + stream_request_id = f"responses-api-{request.request_id}" + async for res in generator: + final_res = res + output = res.outputs[0] + + messages = harmony_adapter.stateful_stream_harmony_tokens_to_openai_messages( + stream_request_id, output.token_ids_diff, tools, + request.tool_choice) + stream_state = harmony_adapter.get_stream_state(stream_request_id) + assert stream_state is not None + parser = stream_state.get_parser() + + if parser.state == StreamState.EXPECT_START: + current_output_index += 1 + sent_output_item_added = False + + if len(messages) > 0: + previous_item = messages[-1] + if previous_item.recipient is not None: + # Deal with tool call here + pass + elif previous_item.channel == "analysis": + reasoning_item = ResponseReasoningItem( + type="reasoning", + content=[ + Content( + text=previous_item.content[0].text, + type="reasoning_text", + ), + ], + status="completed", + id=current_item_id, + summary=[], + ) + yield _send_event( + ResponseReasoningTextDoneEvent( + type="response.reasoning_text.done", + item_id=current_item_id, + sequence_number=-1, + output_index=current_output_index, + content_index=current_content_index, + text=previous_item.content[0].text, + )) + yield _send_event( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=reasoning_item, + )) + elif previous_item.channel == "final": + text_content = ResponseOutputText( + type="output_text", + text=previous_item.content[0].text, + annotations=[], + ) + yield _send_event( + ResponseTextDoneEvent( + type="response.output_text.done", + sequence_number=-1, + output_index=current_output_index, + content_index=current_content_index, 
+ text=previous_item.content[0].text, + logprobs=[], + item_id=current_item_id, + )) + yield _send_event( + ResponseContentPartDoneEvent( + type="response.content_part.done", + sequence_number=-1, + item_id=current_item_id, + output_index=current_output_index, + content_index=current_content_index, + part=text_content, + )) + yield _send_event( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=current_output_index, + item=ResponseOutputMessage( + id=current_item_id, + type="message", + role="assistant", + content=[text_content], + status="completed", + ), + )) + + if parser.last_content_delta: + if (parser.current_channel == "final" + and parser.current_recipient is None): + if not sent_output_item_added: + sent_output_item_added = True + yield _send_event( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=ResponseOutputMessage( + id=current_item_id, + type="message", + role="assistant", + content=[], + status="in_progress", + ), + )) + yield _send_event( + ResponseContentPartAddedEvent( + type="response.content_part.added", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + content_index=current_content_index, + part=ResponseOutputText( + type="output_text", + text="", + annotations=[], + logprobs=[], + ), + )) + yield _send_event( + ResponseTextDeltaEvent( + type="response.output_text.delta", + sequence_number=-1, + content_index=current_content_index, + output_index=current_output_index, + item_id=current_item_id, + delta=parser.last_content_delta, + # TODO, use logprobs from ctx.last_request_output + logprobs=[], + )) + elif (parser.current_channel == "analysis" + and parser.current_recipient is None): + if not sent_output_item_added: + sent_output_item_added = True + yield _send_event( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=current_output_index, + item=ResponseReasoningItem( + type="reasoning", + id=current_item_id, + summary=[], + status="in_progress", + ), + )) + yield _send_event( + ResponseContentPartAddedEvent( + type="response.content_part.added", + sequence_number=-1, + output_index=current_output_index, + item_id=current_item_id, + content_index=current_content_index, + part=ResponseOutputText( + type="output_text", + text="", + annotations=[], + logprobs=[], + ), + )) + yield _send_event( + ResponseReasoningTextDeltaEvent( + type="response.reasoning_text.delta", + item_id=current_item_id, + output_index=current_output_index, + content_index=current_content_index, + delta=parser.last_content_delta, + sequence_number=-1, + )) + + # TODO(JunyiXu-nv): support built-in tools(python/browser/code interpreter) + + final_response = await create_response(generator, request, sampling_params, + model_name, conversation_store, + final_res, enable_store, + response_creation_time) + + yield _send_event( + ResponseCompletedEvent( + type="response.completed", + sequence_number=-1, + response=final_response.model_dump(), + )) diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index e94e75a68ed..899d9968606 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -1618,6 +1618,13 @@ def test_openai_chat_harmony(llm_root, llm_venv): str(test_root / "_test_openai_chat_harmony.py")]) +def test_openai_responses(llm_root, llm_venv): + test_root = unittest_path() / "llmapi" / "apps" + 
llm_venv.run_cmd(
+        ["-m", "pytest",
+         str(test_root / "_test_openai_responses.py")])
+
+
 def test_openai_prometheus(llm_root, llm_venv):
     test_root = unittest_path() / "llmapi" / "apps"
     llm_venv.run_cmd(
diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml
index 04bf47b9c23..2d76e20e9de 100644
--- a/tests/integration/test_lists/test-db/l0_h100.yml
+++ b/tests/integration/test_lists/test-db/l0_h100.yml
@@ -104,6 +104,7 @@ l0_h100:
  - test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-enable_request_rate] # negative test
  - test_e2e.py::test_trtllm_bench_help_sanity[meta-llama/Llama-3.1-8B]
  - test_e2e.py::test_openai_chat_harmony
+ - test_e2e.py::test_openai_responses
  - test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
  # ------------- AutoDeploy tests ---------------
  - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype
diff --git a/tests/unittest/llmapi/apps/_test_openai_responses.py b/tests/unittest/llmapi/apps/_test_openai_responses.py
new file mode 100644
index 00000000000..beaa805383c
--- /dev/null
+++ b/tests/unittest/llmapi/apps/_test_openai_responses.py
@@ -0,0 +1,241 @@
+import json
+
+import openai
+import pytest
+from openai.types.responses import (ResponseCompletedEvent,
+                                    ResponseReasoningTextDeltaEvent,
+                                    ResponseTextDeltaEvent)
+
+from ..test_llm import get_model_path
+from .openai_server import RemoteOpenAIServer
+
+pytestmark = pytest.mark.threadleak(enabled=False)
+
+
+@pytest.fixture(scope="module", ids=["GPT-OSS-20B"])
+def model():
+    return "gpt_oss/gpt-oss-20b/"
+
+
+@pytest.fixture(scope="module")
+def server(model: str):
+    model_path = get_model_path(model)
+    with RemoteOpenAIServer(model_path) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def client(server: RemoteOpenAIServer):
+    return server.get_async_client()
+
+
+def check_reponse(response, prefix=""):
+    reasoning_exist, message_exist = False, False
+    for output in response.output:
+        if output.type == "reasoning":
+            reasoning_exist = True
+        elif output.type == "message":
+            message_exist = True
+
+    assert reasoning_exist, f"{prefix}Reasoning content does not exist!"
+    assert message_exist, f"{prefix}Message content does not exist!"
+ + +def check_tool_calling(response, first_resp=True, prefix=""): + reasoning_exist, tool_call_exist, message_exist = False, False, False + function_call = None + for output in response.output: + if output.type == "reasoning": + reasoning_exist = True + elif output.type == "function_call": + tool_call_exist = True + function_call = output + elif output.type == "message": + message_exist = True + + if first_resp: + assert reasoning_exist and tool_call_exist, f"{prefix}Invalid tool calling 1st response" + assert not message_exist, f"{prefix}Invalid tool calling 1st response" + + return function_call + else: + assert reasoning_exist and message_exist, f"{prefix}Invalid tool calling 2nd response" + assert not tool_call_exist, f"{prefix}Invalid tool calling 2nd response" + + +@pytest.mark.asyncio(loop_scope="module") +async def test_reasoning(client: openai.AsyncOpenAI, model: str): + response = await client.responses.create( + model=model, input="Which one is larger as numeric, 9.9 or 9.11?") + + check_reponse(response, "test_reasoning: ") + + +@pytest.mark.asyncio(loop_scope="module") +async def test_reasoning_effort(client: openai.AsyncOpenAI, model: str): + for effort in ["low", "medium", "high"]: + response = await client.responses.create( + model=model, + instructions="Use less than 1024 tokens for reasoning", + input="Which one is larger as numeric, 9.9 or 9.11?", + reasoning={"effort": effort}) + check_reponse(response, f"test_reasoning_effort_{effort}: ") + + +@pytest.mark.asyncio(loop_scope="module") +async def test_chat(client: openai.AsyncOpenAI, model: str): + response = await client.responses.create(model=model, + input=[{ + "role": + "developer", + "content": + "Respond in Chinese." + }, { + "role": "user", + "content": "Hello!" + }, { + "role": + "assistant", + "content": + "Hello! How can I help you?" + }, { + "role": "user", + "content": "Tell me a joke." + }]) + check_reponse(response, "test_chat: ") + + +@pytest.mark.asyncio(loop_scope="module") +async def test_multi_turn_chat(client: openai.AsyncOpenAI, model: str): + response = await client.responses.create(model=model, + input="What is the answer of 1+1?") + check_reponse(response, "test_multi_turn_chat_1: ") + + response_2 = await client.responses.create( + model=model, + input="What is the answer of previous question?", + previous_response_id=response.id) + check_reponse(response_2, "test_multi_turn_chat_2: ") + + +def get_current_weather(location: str, format: str = "celsius") -> dict: + return {"sunny": True, "temperature": 20 if format == "celsius" else 68} + + +@pytest.mark.asyncio(loop_scope="module") +async def test_tool_calls(client: openai.AsyncOpenAI, model: str): + tool_get_current_weather = { + "type": "function", + "name": "get_current_weather", + "description": "Gets the current weather in the provided location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "format": { + "type": "string", + "description": "default: celsius", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + } + } + messages = [{"role": "user", "content": "What is the weather like in SF?"}] + response = await client.responses.create( + model=model, + input=messages, + tools=[tool_get_current_weather], + ) + messages.extend(response.output) + function_call = check_tool_calling(response, True, "test_tool_calls: ") + + assert function_call.name == "get_current_weather" + + args = json.loads(function_call.arguments) + answer = get_current_weather(**args) + messages.append({ + "type": "function_call_output", + "call_id": function_call.call_id, + "output": json.dumps(answer), + }) + + response = await client.responses.create(model=model, + input=messages, + tools=[tool_get_current_weather]) + + check_tool_calling(response, False, "test_tool_calls: ") + + +@pytest.mark.asyncio(loop_scope="module") +async def test_streaming(client: openai.AsyncOpenAI, model: str): + stream = await client.responses.create( + model=model, + input="Explain the theory of relativity in brief.", + stream=True, + ) + + reasoning_deltas, message_deltas = list(), list() + async for event in stream: + if isinstance(event, ResponseTextDeltaEvent): + message_deltas.append(event.delta) + elif isinstance(event, ResponseReasoningTextDeltaEvent): + reasoning_deltas.append(event.delta) + + full_response = "".join(message_deltas) + full_reasoning_response = "".join(reasoning_deltas) + assert full_response + assert full_reasoning_response + + +@pytest.mark.asyncio(loop_scope="module") +async def test_streaming_tool_call(client: openai.AsyncOpenAI, model: str): + tool_get_current_weather = { + "type": "function", + "name": "get_current_weather", + "description": "Gets the current weather in the provided location.", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "format": { + "type": "string", + "description": "default: celsius", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + } + } + messages = [{"role": "user", "content": "What is the weather like in SF?"}] + stream = await client.responses.create( + model=model, + input=messages, + tools=[tool_get_current_weather], + stream=True, + ) + + function_call = None + reasoning_deltas = list() + async for event in stream: + if isinstance(event, ResponseCompletedEvent): + for output in event.response.output: + if output.type == "function_call": + function_call = output + elif isinstance(event, ResponseReasoningTextDeltaEvent): + reasoning_deltas.append(event.delta) + + reasoning = "".join(reasoning_deltas) + tool_args = json.loads(function_call.arguments) + + assert function_call.name == "get_current_weather", "wrong function calling name" + assert tool_args, "tool args not exists!" + assert reasoning, "reasoning not exists!" + + get_current_weather(**tool_args)