diff --git a/tensorrt_llm/serve/harmony_adapter.py b/tensorrt_llm/serve/harmony_adapter.py
index a46e7c5ed45..2949965d729 100644
--- a/tensorrt_llm/serve/harmony_adapter.py
+++ b/tensorrt_llm/serve/harmony_adapter.py
@@ -6,7 +6,7 @@
 import time
 import traceback
 import uuid
-from typing import Any, AsyncGenerator, Literal
+from typing import Any, List, Literal

 from openai_harmony import (Author, Conversation, DeveloperContent,
                             HarmonyEncodingName, HarmonyError, Message,
@@ -14,15 +14,15 @@
                             SystemContent, TextContent, ToolDescription,
                             load_harmony_encoding)

-from tensorrt_llm.llmapi import RequestOutput
 from tensorrt_llm.logger import logger

 # yapf: disable
-from .openai_protocol import (ChatCompletionMessageParam, ChatCompletionRequest,
+from .openai_protocol import (ChatCompletionMessageParam,
                               ChatCompletionResponse,
                               ChatCompletionResponseChoice,
                               ChatCompletionResponseStreamChoice,
-                              ChatCompletionStreamResponse, ChatMessage,
+                              ChatCompletionStreamResponse,
+                              ChatCompletionToolsParam, ChatMessage,
                               DeltaFunctionCall, DeltaMessage,
                               DeltaToolCall, UsageInfo)

@@ -1485,36 +1485,72 @@ def _is_tool_call_allowed(self, tool_call: dict[str, Any],
         return True


-async def handle_streaming_response(
-    harmony_adapter: HarmonyAdapter,
-    generator: RequestOutput,
-    request_id: str,
-    request: ChatCompletionRequest,
-) -> AsyncGenerator[str, None]:
-    """Handle streaming response with harmony format."""
+_SERVE_HARMONY_ADAPTER: HarmonyAdapter = None
+
+
+def get_harmony_adapter():
+    global _SERVE_HARMONY_ADAPTER
+    if _SERVE_HARMONY_ADAPTER is None:
+        _SERVE_HARMONY_ADAPTER = HarmonyAdapter()
+
+    return _SERVE_HARMONY_ADAPTER
+
+
+def handle_streaming_response(tools: List[ChatCompletionToolsParam],
+                              tool_choice: str, outputs: List, model: str,
+                              request_id: str, done: bool,
+                              num_prompt_tokens: int):
     first_iteration = True
-    async for res in generator:
-        output = res.outputs[0]
+    output = outputs[0]

-        # Convert tools to dictionary format for harmony adapter (standard pattern)
-        tools_dict = None
-        if request.tools:
-            tools_dict = [tool.model_dump() for tool in request.tools]
+    # Convert tools to dictionary format for harmony adapter (standard pattern)
+    tools_dict = None
+    harmony_adapter = get_harmony_adapter()
+    if tools:
+        tools_dict = [tool.model_dump() for tool in tools]

-        # Get tool_choice from request - if "none", don't pass tools to parser
-        tool_choice = getattr(request, 'tool_choice', None)
-        if tool_choice == "none":
-            tools_for_parser = None
-        else:
-            tools_for_parser = tools_dict
+    # Get tool_choice from request - if "none", don't pass tools to parser
+    if tool_choice == "none":
+        tools_for_parser = None
+    else:
+        tools_for_parser = tools_dict

-        # Create OpenAI streaming responses
-        try:
+    # Create OpenAI streaming responses
+    try:
+        res = []
+        if done:
+            # Clean up state
+            harmony_adapter.cleanup_stream_state(request_id)
+
+            usage_info = _create_usage_info(num_prompt_tokens, outputs)
+
+            # Send final message with finish_reason
+            final_response = ChatCompletionStreamResponse(
+                model=model,
+                choices=[
+                    ChatCompletionResponseStreamChoice(
+                        index=0,
+                        delta=DeltaMessage(),
+                        finish_reason=output.finish_reason,
+                        stop_reason=output.stop_reason)
+                ],
+            )
+
+            final_response_json = final_response.model_dump_json(
+                exclude_none=True)
+            final_usage_chunk = ChatCompletionStreamResponse(choices=[],
+                                                             model=model,
+                                                             usage=usage_info)
+            final_usage_json = final_usage_chunk.model_dump_json(
+                exclude_none=True)
+            res.append(f"data: {final_response_json}\n\n")
+            res.append(f"data: {final_usage_json}\n\n")
+        else:
             responses = harmony_adapter.create_openai_streaming_response(
                 request_id=request_id,
                 tokens=output.token_ids_diff,
                 available_tools=tools_for_parser,
-                model_name=request.model,
+                model_name=model,
                 tool_choice=tool_choice)
             # Send first response after receiving the first output
             if first_iteration:
@@ -1525,64 +1561,44 @@ async def handle_streaming_response(
                     delta=first_delta)

                 first_response = ChatCompletionStreamResponse(
-                    model=request.model,
+                    model=model,
                     choices=[choice],
                 )
                 response_json = first_response.model_dump_json(
                     exclude_none=True)
-                yield f"data: {response_json}\n\n"
+                res.append(f"data: {response_json}\n\n")

-            for response in responses:
-                yield response
+            res.extend(responses)

-        except Exception as e:
-            logger.error(f"Failed to create OpenAI streaming response: {e}")
-            logger.debug(f"Streaming error details: {traceback.format_exc()}")
-            # Clean up state
-            harmony_adapter.cleanup_stream_state(request_id)
-            raise e
-
-    # Clean up state
-    harmony_adapter.cleanup_stream_state(request_id)
+        return res

-    # Send final message with finish_reason
-    output = generator.outputs[0]
-    final_response = ChatCompletionStreamResponse(
-        model=request.model,
-        choices=[
-            ChatCompletionResponseStreamChoice(
-                index=0,
-                delta=DeltaMessage(),
-                finish_reason=output.finish_reason,
-                stop_reason=output.stop_reason)
-        ])
+    except Exception as e:
+        logger.error(f"Failed to create OpenAI streaming response: {e}")
+        logger.debug(f"Streaming error details: {traceback.format_exc()}")
+        # Clean up state
+        harmony_adapter.cleanup_stream_state(request_id)
+        raise e

-    yield f"data: {final_response.model_dump_json(exclude_unset=True)}\n\n"
-    yield "data: [DONE]\n\n"
-
-async def handle_non_streaming_response(
-        harmony_adapter: HarmonyAdapter, promise: RequestOutput,
-        request: ChatCompletionRequest) -> ChatCompletionResponse:
+def handle_non_streaming_response(tools: List[ChatCompletionToolsParam],
+                                  tool_choice: str, outputs: List, model: str,
+                                  num_prompt_tokens: int):
     """Handle non-streaming response with harmony format."""
-    # Get final result
-    await promise
-
     # Parse harmony output to OpenAI format
     # Convert tools to dictionary format for harmony adapter (standard pattern)
     tools_dict = None
-    if request.tools:
-        tools_dict = [tool.model_dump() for tool in request.tools]
+    harmony_adapter = get_harmony_adapter()
+    if tools:
+        tools_dict = [tool.model_dump() for tool in tools]

     # Get tool_choice from request - if "none", don't pass tools to parser
-    tool_choice = getattr(request, 'tool_choice', None)
     if tool_choice == "none":
         tools_for_parser = None
     else:
         tools_for_parser = tools_dict

-    output = promise.outputs[0]
+    output = outputs[0]
     parsed_output = harmony_adapter.harmony_output_to_openai(
         output.token_ids, tools_for_parser, tool_choice)

@@ -1597,11 +1613,11 @@ async def handle_non_streaming_response(
         output.finish_reason)

     # Create usage info from metrics (RequestOutput doesn't have usage in v1)
-    usage_info = _create_usage_info(promise)
+    usage_info = _create_usage_info(num_prompt_tokens, outputs)

     # Create response
     response = ChatCompletionResponse(
-        model=request.model,
+        model=model,
         choices=[
             ChatCompletionResponseChoice(
                 index=0,
@@ -1613,7 +1629,6 @@
     # Optional: Log if harmony parsing failed (for debugging)
     if parsed_output.get('_harmony_parsing_failed'):
         logger.warning("⚠️ Harmony parsing fell back to raw text decoding")
-        logger.debug(f"request\n\n{request}")
         logger.debug(f"response\n\n{response}\n")

     return response
@@ -1646,15 +1661,10 @@ def _determine_finish_reason(parsed_output: dict[str, Any],
     return reason


-def _create_usage_info(final_res: RequestOutput) -> UsageInfo:
+def _create_usage_info(num_prompt_tokens, outputs) -> UsageInfo:
     """Create usage info from RequestOutput following serving_chat.py pattern."""
-    # Calculate prompt tokens from prompt_token_ids and encoder_prompt_token_ids
-    assert final_res.prompt_token_ids is not None
-    num_prompt_tokens = len(final_res.prompt_token_ids)
-
     # Calculate completion tokens from all outputs
-    num_generated_tokens = sum(
-        len(output.token_ids) for output in final_res.outputs)
+    num_generated_tokens = sum(len(output.token_ids) for output in outputs)

     # Create usage info
     usage = UsageInfo(prompt_tokens=num_prompt_tokens,
diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py
index de245046359..cd1a8546b0d 100644
--- a/tensorrt_llm/serve/openai_server.py
+++ b/tensorrt_llm/serve/openai_server.py
@@ -44,9 +44,10 @@
                                  UsageInfo,
                                  to_llm_disaggregated_params)
 from tensorrt_llm.serve.postprocess_handlers import (
-    ChatPostprocArgs, CompletionPostprocArgs, chat_response_post_processor,
-    chat_stream_post_processor, completion_response_post_processor,
-    completion_stream_post_processor)
+    ChatCompletionPostprocArgs, ChatPostprocArgs, CompletionPostprocArgs,
+    chat_harmony_post_processor, chat_harmony_streaming_post_processor,
+    chat_response_post_processor, chat_stream_post_processor,
+    completion_response_post_processor, completion_stream_post_processor)
 from tensorrt_llm.serve.responses_utils import ConversationHistoryStore
 from tensorrt_llm.serve.responses_utils import \
     create_response as responses_api_create_response
@@ -57,8 +58,7 @@
 from tensorrt_llm.version import __version__ as VERSION

 from .._utils import nvtx_mark, set_prometheus_multiproc_dir
-from .harmony_adapter import (HarmonyAdapter, handle_non_streaming_response,
-                              handle_streaming_response,
+from .harmony_adapter import (HarmonyAdapter, get_harmony_adapter,
                               maybe_transform_reasoning_effort)
 # yapf: enale

@@ -117,7 +117,11 @@ def __init__(self,

         # gpt-oss
         self.harmony_adapter: HarmonyAdapter | None = None
-        self.use_harmony = self.model_config.model_type == "gpt_oss"
+        disable_harmony = os.getenv("DISABLE_HARMONY_ADAPTER", "0") == "1"
+        if disable_harmony:
+            self.use_harmony = False
+        else:
+            self.use_harmony = (self.model_config.model_type == "gpt_oss")

         @asynccontextmanager
         async def lifespan(app: FastAPI):
@@ -703,11 +707,35 @@ async def chat_harmony(self, request: ChatCompletionRequest, raw_request: Reques
         Chat Completion API with harmony format support.
         Supports both streaming and non-streaming modes.
         """
+
+        async def create_harmony_response(
+                promise: RequestOutput, postproc_params: PostprocParams) -> ChatCompletionResponse:
+            await promise.aresult()
+            if self.postproc_worker_enabled:
+                chat_response =promise.outputs[0]._postprocess_result
+            else:
+                post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
+                chat_response = post_processor(promise, args)
+
+            return chat_response
+
+        async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams):
+            if not self.postproc_worker_enabled:
+                post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
+
+            async for res in promise:
+                pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
+                # await self._extract_metrics(res)
+                for pp_res in pp_results:
+                    yield pp_res
+
+            yield "data: [DONE]\n\n"
+
         try:
             # Initialize HarmonyAdapter
             # NOTE: WAR for Disagg failure, may affect perf if no warmup
             if not self.harmony_adapter:
-                self.harmony_adapter = HarmonyAdapter()
+                self.harmony_adapter = get_harmony_adapter()
             # Convert Pydantic models to dictionaries for JSON serialization (standard pattern)
             tools_dict = None
             if request.tools:
@@ -742,27 +770,37 @@ async def chat_harmony(self, request: ChatCompletionRequest, raw_request: Reques
                 vocab_size=self.tokenizer.tokenizer.vocab_size)
             sampling_params.detokenize = False  # Harmony adapter handles detokenization

+            postproc_args = ChatCompletionPostprocArgs.from_request(request)
+            postproc_params = PostprocParams(
+                post_processor=chat_harmony_streaming_post_processor
+                if request.stream else chat_harmony_post_processor,
+                postproc_args=postproc_args,
+            )
+
             # Generate
             promise = self.llm.generate_async(
                 inputs=harmony_tokens,
                 sampling_params=sampling_params,
+                _postproc_params=postproc_params if self.postproc_worker_enabled else None,
                 streaming=bool(request.stream),
                 lora_request=request.lora_request,
             )
+            postproc_args.request_id = promise.request_id
+
+            if not self.postproc_worker_enabled:
+                postproc_args.num_prompt_tokens = len(promise.prompt_token_ids)
+
             # Disconnect cancellation
             asyncio.create_task(self.await_disconnected(raw_request, promise))

             # Handle streaming
             if request.stream:
                 return StreamingResponse(
-                    handle_streaming_response(
-                        self.harmony_adapter, promise,
-                        str(promise.request_id), request,
-                    ),
+                    content=create_streaming_generator(promise, postproc_params),
                     media_type="text/event-stream"
                 )
             else:
-                response = await handle_non_streaming_response(self.harmony_adapter, promise, request)
+                response = await create_harmony_response(promise, postproc_params)
                 return JSONResponse(response.model_dump())

         except Exception as e:
diff --git a/tensorrt_llm/serve/postprocess_handlers.py b/tensorrt_llm/serve/postprocess_handlers.py
index 07db6e27a75..0fbcedb9dac 100644
--- a/tensorrt_llm/serve/postprocess_handlers.py
+++ b/tensorrt_llm/serve/postprocess_handlers.py
@@ -9,6 +9,8 @@
                                 ReasoningParserFactory)
 from ..llmapi.tokenizer import TransformersTokenizer
 # yapf: disable
+from .harmony_adapter import (handle_non_streaming_response,
+                              handle_streaming_response)
 from .openai_protocol import (ChatCompletionLogProbs,
                               ChatCompletionLogProbsContent,
                               ChatCompletionNamedToolChoiceParam,
@@ -24,7 +26,8 @@
                               FunctionCall, StreamOptions, ToolCall,
                               UsageInfo, to_disaggregated_params)
-# yapf: enale
+# yapf: enable
+

 @dataclass(kw_only=True)
 class ChatPostprocArgs(PostprocArgs):
@@ -57,8 +60,7 @@ def from_request(cls, request: ChatCompletionRequest):
         )


-def create_logprobs(token_ids: List[int],
-                    tokenizer: TransformersTokenizer,
+def create_logprobs(token_ids: List[int], tokenizer: TransformersTokenizer,
                     logprobs: List[float]) -> ChatCompletionLogProbs:
     assert len(token_ids) == len(logprobs), \
         "token_ids and logprobs have different lengths"
@@ -75,12 +77,14 @@ def create_logprobs(token_ids: List[int],
     return chat_logprobs


-def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str, streaming: bool) -> Tuple[bool, str, str]:
+def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
+                           streaming: bool) -> Tuple[bool, str, str]:
     reasoning_parser = None
     if args.reasoning_parser is not None:
         if output_index not in args.reasoning_parser_dict:
-            args.reasoning_parser_dict[output_index] = ReasoningParserFactory.create_reasoning_parser(
-                args.reasoning_parser)
+            args.reasoning_parser_dict[
+                output_index] = ReasoningParserFactory.create_reasoning_parser(
+                    args.reasoning_parser)
         reasoning_parser = args.reasoning_parser_dict[output_index]

     in_reasoning = False
@@ -97,7 +101,8 @@ def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,


 @nvtx_range_debug("chat_stream_post_processor")
-def chat_stream_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs) -> List[str]:
+def chat_stream_post_processor(rsp: GenerationResultBase,
+                               args: ChatPostprocArgs) -> List[str]:

     def yield_first_chat(num_tokens: int,
                          idx: int,
@@ -128,9 +133,13 @@ def yield_first_chat(num_tokens: int,
         include_continuous_usage = False
     if args.first_iteration:
         for i in range(args.num_choices):
-            res.append(f"data: {yield_first_chat(prompt_tokens, i, role=args.role)} \n\n")
+            res.append(
+                f"data: {yield_first_chat(prompt_tokens, i, role=args.role)} \n\n"
+            )
             if args.echo and args.last_message_content:
-                res.append(f"data: {yield_first_chat(prompt_tokens, i, content=args.last_message_content)} \n\n")
+                res.append(
+                    f"data: {yield_first_chat(prompt_tokens, i, content=args.last_message_content)} \n\n"
+                )
         args.first_iteration = False

     for output in rsp.outputs:
@@ -158,14 +167,18 @@ def yield_first_chat(num_tokens: int,
             delta_message = DeltaMessage(
                 content=delta_text, reasoning_content=reasoning_delta_text)

-            choice = ChatCompletionResponseStreamChoice(index=i,
-                                                        delta=delta_message,
-                                                        finish_reason=None,
-                                                        avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None))
+            choice = ChatCompletionResponseStreamChoice(
+                index=i,
+                delta=delta_message,
+                finish_reason=None,
+                avg_decoded_tokens_per_iter=getattr(rsp,
+                                                    'avg_decoded_tokens_per_iter',
+                                                    None))
             if args.return_logprobs:
                 logprobs = output.logprobs_diff
                 token_ids = output.token_ids_diff
-                choice.logprobs = create_logprobs(token_ids, args.tokenizer, logprobs)
+                choice.logprobs = create_logprobs(token_ids, args.tokenizer,
+                                                  logprobs)
             if output.finish_reason is not None:
                 choice.finish_reason = output.finish_reason
                 choice.stop_reason = output.stop_reason
@@ -179,57 +192,62 @@ def yield_first_chat(num_tokens: int,
             res.append(f"data: {data}\n\n")

     if include_usage and rsp._done:
-        completion_tokens = sum(output.length
-                                for output in rsp.outputs)
+        completion_tokens = sum(output.length for output in rsp.outputs)
         final_usage = UsageInfo(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
             total_tokens=prompt_tokens + completion_tokens,
         )
-        final_usage_chunk = ChatCompletionStreamResponse(
-            choices=[], model=args.model, usage=final_usage)
+        final_usage_chunk = ChatCompletionStreamResponse(choices=[],
+                                                         model=args.model,
+                                                         usage=final_usage)
         final_usage_data = final_usage_chunk.model_dump_json()
         res.append(f"data: {final_usage_data}\n\n")
     return res


 @nvtx_range_debug("chat_response_post_processor")
-def chat_response_post_processor(rsp: GenerationResultBase, args: ChatPostprocArgs) -> ChatCompletionResponse:
+def chat_response_post_processor(
+        rsp: GenerationResultBase,
+        args: ChatPostprocArgs) -> ChatCompletionResponse:
     choices: List[ChatCompletionResponseChoice] = []
     role = args.role
     for output in rsp.outputs:
         _, text, reasoning_text = apply_reasoning_parser(
             args, output.index, output.text, False)
-        if args.tool_choice and isinstance(
-                args.tool_choice,
-                ChatCompletionNamedToolChoiceParam):
+        if args.tool_choice and isinstance(args.tool_choice,
+                                           ChatCompletionNamedToolChoiceParam):
             message = ChatMessage(
                 role=role,
                 content="",
                 tool_calls=[
                     ToolCall(function=FunctionCall(
-                        name=args.tool_choice.function.name,
-                        arguments=text))
+                        name=args.tool_choice.function.name, arguments=text))
                 ])
         else:
             if text is None:
                 text = ""
-            message = ChatMessage(
-                role=role, content=text, reasoning_content=reasoning_text)
-        disaggregated_params = to_disaggregated_params(output.disaggregated_params)
+            message = ChatMessage(role=role,
+                                  content=text,
+                                  reasoning_content=reasoning_text)
+        disaggregated_params = to_disaggregated_params(
+            output.disaggregated_params)
         choice = ChatCompletionResponseChoice(
             index=output.index,
             message=message,
             finish_reason=output.finish_reason,
             stop_reason=output.stop_reason,
             disaggregated_params=disaggregated_params,
-            avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None),
+            avg_decoded_tokens_per_iter=getattr(rsp,
+                                                'avg_decoded_tokens_per_iter',
+                                                None),
         )
         if args.return_logprobs:
-            choice.logprobs = create_logprobs(output.token_ids, args.tokenizer, output.logprobs)
+            choice.logprobs = create_logprobs(output.token_ids, args.tokenizer,
+                                              output.logprobs)
         choices.append(choice)

     if args.echo and args.last_message_content:
@@ -238,8 +256,7 @@ def chat_response_post_processor(rsp: GenerationResultBase, args: ChatPostprocAr
             choice.message.content = full_message

     num_prompt_tokens = args.num_prompt_tokens
-    num_generated_tokens = sum(
-        len(output.token_ids) for output in rsp.outputs)
+    num_generated_tokens = sum(len(output.token_ids) for output in rsp.outputs)
     usage = UsageInfo(
         prompt_tokens=num_prompt_tokens,
         completion_tokens=num_generated_tokens,
@@ -275,7 +292,8 @@ def from_request(cls, request: CompletionRequest):


 @nvtx_range_debug("completion_stream_post_processor")
-def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args: CompletionPostprocArgs) -> List[str]:
+def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase,
+                                     args: CompletionPostprocArgs) -> List[str]:
     res: List[str] = []
     prompt_tokens = args.num_prompt_tokens
     if stream_option := args.stream_options:
@@ -293,9 +311,11 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args:
             index=args.prompt_idx * args.num_choices + output.index,
             text=delta_text if args.detokenize else "",
             token_ids=None if args.detokenize else output.token_ids_diff,
-            finish_reason = output.finish_reason,
-            stop_reason = output.stop_reason,
-            avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None),
+            finish_reason=output.finish_reason,
+            stop_reason=output.stop_reason,
+            avg_decoded_tokens_per_iter=getattr(rsp,
+                                                'avg_decoded_tokens_per_iter',
+                                                None),
         )
         chunk = CompletionStreamResponse(model=args.model, choices=[choice])
         if include_continuous_usage:
@@ -306,16 +326,16 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args:
         res.append(f"data: {data}\n\n")

     if include_usage and rsp._done:
-        completion_tokens = sum(output.length
-                                for output in rsp.outputs)
+        completion_tokens = sum(output.length for output in rsp.outputs)
         final_usage = UsageInfo(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
             total_tokens=prompt_tokens + completion_tokens,
         )
-        final_usage_chunk = ChatCompletionStreamResponse(
-            choices=[], model=args.model, usage=final_usage)
+        final_usage_chunk = ChatCompletionStreamResponse(choices=[],
+                                                         model=args.model,
+                                                         usage=final_usage)
         final_usage_data = final_usage_chunk.model_dump_json()
         res.append(f"data: {final_usage_data}\n\n")

     args.first_iteration = False
@@ -323,7 +343,9 @@ def completion_stream_post_processor(rsp: DetokenizedGenerationResultBase, args:


 @nvtx_range_debug("completion_response_post_processor")
-def completion_response_post_processor(rsp: GenerationResult, args: CompletionPostprocArgs) -> CompletionResponse:
+def completion_response_post_processor(
+        rsp: GenerationResult,
+        args: CompletionPostprocArgs) -> CompletionResponse:
     prompt_tokens = args.num_prompt_tokens
     completion_tokens = 0
     choices = []
@@ -331,23 +353,75 @@ def completion_response_post_processor(rsp: GenerationResult, args: CompletionPo
         text = output.text
         if args.echo:
             text = args.prompt + text
-        disaggregated_params = to_disaggregated_params(output.disaggregated_params)
+        disaggregated_params = to_disaggregated_params(
+            output.disaggregated_params)
         choice = CompletionResponseChoice(
             text=text if args.detokenize else "",
             token_ids=None if args.detokenize else output.token_ids,
             index=args.prompt_idx * args.num_choices + output.index,
             disaggregated_params=disaggregated_params,
-            context_logits=None if rsp.context_logits is None else rsp.context_logits.tolist(),
+            context_logits=None
+            if rsp.context_logits is None else rsp.context_logits.tolist(),
             stop_reason=output.stop_reason,
             finish_reason=output.finish_reason,
-            avg_decoded_tokens_per_iter=getattr(rsp, 'avg_decoded_tokens_per_iter', None),
+            avg_decoded_tokens_per_iter=getattr(rsp,
+                                                'avg_decoded_tokens_per_iter',
+                                                None),
         )
         completion_tokens += output.length
         choices.append(choice)

     usage = UsageInfo(prompt_tokens=prompt_tokens,
-                       completion_tokens=completion_tokens,
-                       total_tokens=completion_tokens + prompt_tokens)
-    response = CompletionResponse(choices=choices, model=args.model, usage=usage)
+                      completion_tokens=completion_tokens,
+                      total_tokens=completion_tokens + prompt_tokens)
+    response = CompletionResponse(choices=choices,
+                                  model=args.model,
+                                  usage=usage)
+    return response
+
+
+@dataclass(kw_only=True)
+class ChatCompletionPostprocArgs(PostprocArgs):
+    model: str
+    tools: Optional[List[ChatCompletionToolsParam]]
+    tool_choice: Optional[Union[Literal["none", "auto"],
+                                ChatCompletionNamedToolChoiceParam]]
+    request_id: Optional[int] = None
+
+    @classmethod
+    def from_request(cls, request: ChatCompletionRequest):
+        return cls(
+            model=request.model,
+            tools=request.tools,
+            tool_choice=request.tool_choice,
+        )
+
+
+@nvtx_range_debug("chat_harmony_post_processor")
+def chat_harmony_post_processor(
+        rsp: GenerationResult,
+        args: ChatCompletionPostprocArgs) -> ChatCompletionResponse:
+    response = handle_non_streaming_response(
+        tools=args.tools,
+        tool_choice=args.tool_choice,
+        outputs=rsp.outputs,
+        model=args.model,
+        num_prompt_tokens=args.num_prompt_tokens,
+    )
+    return response
+
+
+@nvtx_range_debug("chat_harmony_streaming_post_processor")
+def chat_harmony_streaming_post_processor(
+        rsp: GenerationResult, args: ChatCompletionPostprocArgs) -> List[str]:
+    response = handle_streaming_response(
+        tools=args.tools,
+        tool_choice=args.tool_choice,
+        outputs=rsp.outputs,
+        model=args.model,
+        request_id=args.request_id,
+        done=rsp._done,
+        num_prompt_tokens=args.num_prompt_tokens,
+    )
     return response
diff --git a/tests/unittest/llmapi/apps/_test_openai_chat_harmony.py b/tests/unittest/llmapi/apps/_test_openai_chat_harmony.py
index 0204a04acff..ba6c7d53379 100644
--- a/tests/unittest/llmapi/apps/_test_openai_chat_harmony.py
+++ b/tests/unittest/llmapi/apps/_test_openai_chat_harmony.py
@@ -147,6 +147,10 @@ async def test_streaming(client: openai.AsyncOpenAI, model: str):
     collected_chunks = []
     collected_messages = []
     async for chunk in response:
+        # Last streaming response will only contains usage info
+        if len(chunk.choices) <= 0:
+            continue
+
         collected_chunks.append(chunk)
         collected_messages.append(chunk.choices[0].delta)

@@ -198,6 +202,10 @@ async def test_streaming_tool_call(client: openai.AsyncOpenAI, model: str):
     reasoning_chunks: list[str] = []
     tool_arg_chunks: list[str] = []
     async for chunk in response:
+        # Last streaming response will only contains usage info
+        if len(chunk.choices) <= 0:
+            continue
+
        delta = chunk.choices[0].delta
        if hasattr(delta, "tool_calls") and delta.tool_calls:
            function = delta.tool_calls[0].function
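Below is a brief illustrative sketch (not part of the patch) of the client-side pattern the updated tests rely on: with the reworked streaming path, the final SSE chunk carries only usage information and has an empty `choices` list, so consumers should skip it before reading `delta`. The base URL, API key, and model name are placeholder assumptions.

```python
# Sketch of a client loop that tolerates the trailing usage-only chunk
# (choices == []) emitted by the harmony streaming path; mirrors the guard
# added to the tests above. Base URL, API key, and model name are placeholders.
import asyncio

import openai


async def main() -> None:
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="dummy")
    stream = await client.chat.completions.create(
        model="gpt-oss",  # placeholder model name
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
    )
    text_parts: list[str] = []
    async for chunk in stream:
        # The final chunk carries only usage info and has no choices.
        if len(chunk.choices) <= 0:
            continue
        delta = chunk.choices[0].delta
        if delta.content:
            text_parts.append(delta.content)
    print("".join(text_parts))


if __name__ == "__main__":
    asyncio.run(main())
```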