156 changes: 83 additions & 73 deletions tensorrt_llm/serve/harmony_adapter.py
@@ -6,23 +6,23 @@
import time
import traceback
import uuid
from typing import Any, AsyncGenerator, Literal
from typing import Any, List, Literal

from openai_harmony import (Author, Conversation, DeveloperContent,
HarmonyEncodingName, HarmonyError, Message,
ReasoningEffort, Role, StreamableParser,
SystemContent, TextContent, ToolDescription,
load_harmony_encoding)

from tensorrt_llm.llmapi import RequestOutput
from tensorrt_llm.logger import logger

# yapf: disable
from .openai_protocol import (ChatCompletionMessageParam, ChatCompletionRequest,
from .openai_protocol import (ChatCompletionMessageParam,
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage,
ChatCompletionStreamResponse,
ChatCompletionToolsParam, ChatMessage,
DeltaFunctionCall, DeltaMessage, DeltaToolCall,
UsageInfo)

@@ -1485,36 +1485,72 @@ def _is_tool_call_allowed(self, tool_call: dict[str, Any],
return True


async def handle_streaming_response(
harmony_adapter: HarmonyAdapter,
generator: RequestOutput,
request_id: str,
request: ChatCompletionRequest,
) -> AsyncGenerator[str, None]:
"""Handle streaming response with harmony format."""
_SERVE_HARMONY_ADAPTER: HarmonyAdapter = None


def get_harmony_adapter():
global _SERVE_HARMONY_ADAPTER
if _SERVE_HARMONY_ADAPTER is None:
_SERVE_HARMONY_ADAPTER = HarmonyAdapter()

return _SERVE_HARMONY_ADAPTER


def handle_streaming_response(tools: List[ChatCompletionToolsParam],
tool_choice: str, outputs: List, model: str,
request_id: str, done: bool,
num_prompt_tokens: int):
first_iteration = True
async for res in generator:
output = res.outputs[0]
output = outputs[0]

# Convert tools to dictionary format for harmony adapter (standard pattern)
tools_dict = None
if request.tools:
tools_dict = [tool.model_dump() for tool in request.tools]
# Convert tools to dictionary format for harmony adapter (standard pattern)
tools_dict = None
harmony_adapter = get_harmony_adapter()
if tools:
tools_dict = [tool.model_dump() for tool in tools]

# Get tool_choice from request - if "none", don't pass tools to parser
tool_choice = getattr(request, 'tool_choice', None)
if tool_choice == "none":
tools_for_parser = None
else:
tools_for_parser = tools_dict
# Get tool_choice from request - if "none", don't pass tools to parser
if tool_choice == "none":
tools_for_parser = None
else:
tools_for_parser = tools_dict

# Create OpenAI streaming responses
try:
# Create OpenAI streaming responses
try:
res = []
if done:
# Clean up state
harmony_adapter.cleanup_stream_state(request_id)

usage_info = _create_usage_info(num_prompt_tokens, outputs)

# Send final message with finish_reason
final_response = ChatCompletionStreamResponse(
model=model,
choices=[
ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason=output.finish_reason,
stop_reason=output.stop_reason)
],
)

final_response_json = final_response.model_dump_json(
exclude_none=True)
final_usage_chunk = ChatCompletionStreamResponse(choices=[],
model=model,
usage=usage_info)
final_usage_json = final_usage_chunk.model_dump_json(
exclude_none=True)
res.append(f"data: {final_response_json}\n\n")
res.append(f"data: {final_usage_json}\n\n")
else:
responses = harmony_adapter.create_openai_streaming_response(
request_id=request_id,
tokens=output.token_ids_diff,
available_tools=tools_for_parser,
model_name=request.model,
model_name=model,
tool_choice=tool_choice)
# Send first response after receiving the first output
if first_iteration:
@@ -1525,64 +1561,44 @@ async def handle_streaming_response(
delta=first_delta)

first_response = ChatCompletionStreamResponse(
model=request.model,
model=model,
choices=[choice],
)

response_json = first_response.model_dump_json(
exclude_none=True)
yield f"data: {response_json}\n\n"
res.append(f"data: {response_json}\n\n")

for response in responses:
yield response
res.extend(responses)

except Exception as e:
logger.error(f"Failed to create OpenAI streaming response: {e}")
logger.debug(f"Streaming error details: {traceback.format_exc()}")
# Clean up state
harmony_adapter.cleanup_stream_state(request_id)
raise e

# Clean up state
harmony_adapter.cleanup_stream_state(request_id)
return res

# Send final message with finish_reason
output = generator.outputs[0]
final_response = ChatCompletionStreamResponse(
model=request.model,
choices=[
ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(),
finish_reason=output.finish_reason,
stop_reason=output.stop_reason)
])
except Exception as e:
logger.error(f"Failed to create OpenAI streaming response: {e}")
logger.debug(f"Streaming error details: {traceback.format_exc()}")
# Clean up state
harmony_adapter.cleanup_stream_state(request_id)
raise e

yield f"data: {final_response.model_dump_json(exclude_unset=True)}\n\n"
yield "data: [DONE]\n\n"


async def handle_non_streaming_response(
harmony_adapter: HarmonyAdapter, promise: RequestOutput,
request: ChatCompletionRequest) -> ChatCompletionResponse:
def handle_non_streaming_response(tools: List[ChatCompletionToolsParam],
tool_choice: str, outputs: List, model: str,
num_prompt_tokens: int):
"""Handle non-streaming response with harmony format."""
# Get final result
await promise

# Parse harmony output to OpenAI format
# Convert tools to dictionary format for harmony adapter (standard pattern)
tools_dict = None
if request.tools:
tools_dict = [tool.model_dump() for tool in request.tools]
harmony_adapter = get_harmony_adapter()
if tools:
tools_dict = [tool.model_dump() for tool in tools]

# Get tool_choice from request - if "none", don't pass tools to parser
tool_choice = getattr(request, 'tool_choice', None)
if tool_choice == "none":
tools_for_parser = None
else:
tools_for_parser = tools_dict

output = promise.outputs[0]
output = outputs[0]
parsed_output = harmony_adapter.harmony_output_to_openai(
output.token_ids, tools_for_parser, tool_choice)

@@ -1597,11 +1613,11 @@ async def handle_non_streaming_response(
output.finish_reason)

# Create usage info from metrics (RequestOutput doesn't have usage in v1)
usage_info = _create_usage_info(promise)
usage_info = _create_usage_info(num_prompt_tokens, outputs)

# Create response
response = ChatCompletionResponse(
model=request.model,
model=model,
choices=[
ChatCompletionResponseChoice(
index=0,
@@ -1613,7 +1629,6 @@ async def handle_non_streaming_response(
# Optional: Log if harmony parsing failed (for debugging)
if parsed_output.get('_harmony_parsing_failed'):
logger.warning("⚠️ Harmony parsing fell back to raw text decoding")
logger.debug(f"request\n\n{request}")
logger.debug(f"response\n\n{response}\n")

return response
@@ -1646,15 +1661,10 @@ def _determine_finish_reason(parsed_output: dict[str, Any],
return reason


def _create_usage_info(final_res: RequestOutput) -> UsageInfo:
def _create_usage_info(num_prompt_tokens, outputs) -> UsageInfo:
"""Create usage info from RequestOutput following serving_chat.py pattern."""
# Calculate prompt tokens from prompt_token_ids and encoder_prompt_token_ids
assert final_res.prompt_token_ids is not None
num_prompt_tokens = len(final_res.prompt_token_ids)

# Calculate completion tokens from all outputs
num_generated_tokens = sum(
len(output.token_ids) for output in final_res.outputs)
num_generated_tokens = sum(len(output.token_ids) for output in outputs)

# Create usage info
usage = UsageInfo(prompt_tokens=num_prompt_tokens,
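Editor's note: the refactor turns both handlers into synchronous, stateless helpers. They now take plain values (tools, tool_choice, the current outputs, model name, request id, token counts) instead of the RequestOutput and ChatCompletionRequest objects, so they can run inside post-processing workers. Below is a minimal sketch of the new calling convention; it assumes a result object exposing outputs, prompt_token_ids and finished, which is not part of this diff — the real callers are the post-processors wired up in openai_server.py.

# Sketch only (not part of this PR): illustrates the stateless calling
# convention of the refactored handlers shown above.
from tensorrt_llm.serve.harmony_adapter import (handle_non_streaming_response,
                                                handle_streaming_response)


def stream_step(request, result, request_id: str) -> list:
    # Returns the SSE chunk strings ("data: ...\n\n") for one generation step.
    return handle_streaming_response(
        tools=request.tools,
        tool_choice=getattr(request, "tool_choice", None),
        outputs=result.outputs,
        model=request.model,
        request_id=request_id,
        done=result.finished,  # True only on the last step: emits the finish_reason and usage chunks
        num_prompt_tokens=len(result.prompt_token_ids),
    )


def final_response(request, result):
    # Builds the complete ChatCompletionResponse once generation has finished.
    return handle_non_streaming_response(
        tools=request.tools,
        tool_choice=getattr(request, "tool_choice", None),
        outputs=result.outputs,
        model=request.model,
        num_prompt_tokens=len(result.prompt_token_ids),
    )

Note that handle_streaming_response no longer emits the "data: [DONE]" sentinel; the caller (create_streaming_generator in openai_server.py below) appends it after the final chunk.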
62 changes: 50 additions & 12 deletions tensorrt_llm/serve/openai_server.py
@@ -44,9 +44,10 @@
UsageInfo,
to_llm_disaggregated_params)
from tensorrt_llm.serve.postprocess_handlers import (
ChatPostprocArgs, CompletionPostprocArgs, chat_response_post_processor,
chat_stream_post_processor, completion_response_post_processor,
completion_stream_post_processor)
ChatCompletionPostprocArgs, ChatPostprocArgs, CompletionPostprocArgs,
chat_harmony_post_processor, chat_harmony_streaming_post_processor,
chat_response_post_processor, chat_stream_post_processor,
completion_response_post_processor, completion_stream_post_processor)
from tensorrt_llm.serve.responses_utils import ConversationHistoryStore
from tensorrt_llm.serve.responses_utils import \
create_response as responses_api_create_response
@@ -57,8 +58,7 @@
from tensorrt_llm.version import __version__ as VERSION

from .._utils import nvtx_mark, set_prometheus_multiproc_dir
from .harmony_adapter import (HarmonyAdapter, handle_non_streaming_response,
handle_streaming_response,
from .harmony_adapter import (HarmonyAdapter, get_harmony_adapter,
maybe_transform_reasoning_effort)

# yapf: enable
@@ -117,7 +117,11 @@ def __init__(self,

# gpt-oss
self.harmony_adapter: HarmonyAdapter | None = None
self.use_harmony = self.model_config.model_type == "gpt_oss"
disable_harmony = os.getenv("DISABLE_HARMONY_ADAPTER", "0") == "1"
if disable_harmony:
self.use_harmony = False
else:
self.use_harmony = (self.model_config.model_type == "gpt_oss")

@asynccontextmanager
async def lifespan(app: FastAPI):
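Editor's note: the new DISABLE_HARMONY_ADAPTER environment variable acts as a kill switch: even for a gpt_oss checkpoint, setting it to "1" routes requests through the standard chat completion path. A small sketch of the gate added above (the helper function name is illustrative, not from the diff):

import os


def should_use_harmony(model_type: str) -> bool:
    # Mirrors the logic added to OpenAIServer.__init__.
    if os.getenv("DISABLE_HARMONY_ADAPTER", "0") == "1":
        return False  # kill switch: never use the harmony adapter
    return model_type == "gpt_oss"  # harmony only applies to gpt-oss models

Exporting DISABLE_HARMONY_ADAPTER=1 in the serving environment before launch is therefore enough to disable the harmony code path.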
Expand Down Expand Up @@ -703,11 +707,35 @@ async def chat_harmony(self, request: ChatCompletionRequest, raw_request: Reques
Chat Completion API with harmony format support.
Supports both streaming and non-streaming modes.
"""

async def create_harmony_response(
promise: RequestOutput, postproc_params: PostprocParams) -> ChatCompletionResponse:
await promise.aresult()
if self.postproc_worker_enabled:
chat_response = promise.outputs[0]._postprocess_result
else:
post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
chat_response = post_processor(promise, args)

return chat_response

async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams):
if not self.postproc_worker_enabled:
post_processor, args = postproc_params.post_processor, postproc_params.postproc_args

async for res in promise:
pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args)
# await self._extract_metrics(res)
for pp_res in pp_results:
yield pp_res

yield "data: [DONE]\n\n"

try:
# Initialize HarmonyAdapter
# NOTE: WAR for Disagg failure, may affect perf if no warmup
if not self.harmony_adapter:
self.harmony_adapter = HarmonyAdapter()
self.harmony_adapter = get_harmony_adapter()
# Convert Pydantic models to dictionaries for JSON serialization (standard pattern)
tools_dict = None
if request.tools:
@@ -742,27 +770,37 @@ async def chat_harmony(self, request: ChatCompletionRequest, raw_request: Reques
vocab_size=self.tokenizer.tokenizer.vocab_size)
sampling_params.detokenize = False # Harmony adapter handles detokenization

postproc_args = ChatCompletionPostprocArgs.from_request(request)
postproc_params = PostprocParams(
post_processor=chat_harmony_streaming_post_processor
if request.stream else chat_harmony_post_processor,
postproc_args=postproc_args,
)

# Generate
promise = self.llm.generate_async(
inputs=harmony_tokens,
sampling_params=sampling_params,
_postproc_params=postproc_params if self.postproc_worker_enabled else None,
streaming=bool(request.stream),
lora_request=request.lora_request,
)
postproc_args.request_id = promise.request_id

if not self.postproc_worker_enabled:
postproc_args.num_prompt_tokens = len(promise.prompt_token_ids)

# Disconnect cancellation
asyncio.create_task(self.await_disconnected(raw_request, promise))

# Handle streaming
if request.stream:
return StreamingResponse(
handle_streaming_response(
self.harmony_adapter, promise,
str(promise.request_id), request,
),
content=create_streaming_generator(promise, postproc_params),
media_type="text/event-stream"
)
else:
response = await handle_non_streaming_response(self.harmony_adapter, promise, request)
response = await create_harmony_response(promise, postproc_params)
return JSONResponse(response.model_dump())

except Exception as e:
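Editor's note: the harmony parsing itself now happens in chat_harmony_post_processor and chat_harmony_streaming_post_processor, imported from postprocess_handlers, whose implementation is not part of this excerpt. Based on the fields populated in chat_harmony above (from_request, request_id, num_prompt_tokens), a plausible shape for the ChatCompletionPostprocArgs container is sketched below; treat it as an assumption, not the actual class.

# Hypothetical sketch of the args object used by the harmony post-processors;
# the real class lives in tensorrt_llm/serve/postprocess_handlers.py and may differ.
from dataclasses import dataclass
from typing import List, Optional

from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest,
                                                ChatCompletionToolsParam)


@dataclass
class ChatCompletionPostprocArgsSketch:
    model: str
    tools: Optional[List[ChatCompletionToolsParam]]
    tool_choice: Optional[str]
    # Filled in after generate_async() returns (see chat_harmony above).
    request_id: Optional[str] = None
    num_prompt_tokens: Optional[int] = None

    @classmethod
    def from_request(cls, request: ChatCompletionRequest) -> "ChatCompletionPostprocArgsSketch":
        return cls(model=request.model,
                   tools=request.tools,
                   tool_choice=getattr(request, "tool_choice", None))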