From 1efb295835a492ae84281af3fa7f727c23e39fc3 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Sat, 6 Dec 2025 00:40:47 -0800 Subject: [PATCH 1/2] [model-gateway] refactor oai streaming --- sgl-model-gateway/src/routers/openai/mod.rs | 1 - .../src/routers/openai/streaming.rs | 734 ++++++++++-------- sgl-model-gateway/src/routers/openai/utils.rs | 63 -- 3 files changed, 416 insertions(+), 382 deletions(-) delete mode 100644 sgl-model-gateway/src/routers/openai/utils.rs diff --git a/sgl-model-gateway/src/routers/openai/mod.rs b/sgl-model-gateway/src/routers/openai/mod.rs index 04358d73e019..3288657afdb7 100644 --- a/sgl-model-gateway/src/routers/openai/mod.rs +++ b/sgl-model-gateway/src/routers/openai/mod.rs @@ -14,7 +14,6 @@ pub mod provider; mod responses; mod router; mod streaming; -mod utils; // Re-export the main types for external use pub use provider::{Provider, ProviderError, ProviderRegistry}; diff --git a/sgl-model-gateway/src/routers/openai/streaming.rs b/sgl-model-gateway/src/routers/openai/streaming.rs index 5b5c547ec016..b61484950753 100644 --- a/sgl-model-gateway/src/routers/openai/streaming.rs +++ b/sgl-model-gateway/src/routers/openai/streaming.rs @@ -7,7 +7,7 @@ //! - MCP tool execution loops within streaming responses //! - Event transformation and output index remapping -use std::{borrow::Cow, io, sync::Arc}; +use std::{borrow::Cow, collections::HashMap, io, sync::Arc}; use axum::{ body::Body, @@ -28,10 +28,9 @@ use super::{ mcp::{ build_resume_payload, ensure_request_mcp_client, execute_streaming_tool_calls, inject_mcp_metadata_streaming, prepare_mcp_payload_for_streaming, - send_mcp_list_tools_events, McpLoopConfig, ToolLoopState, + send_mcp_list_tools_events, FunctionCallInProgress, McpLoopConfig, ToolLoopState, }, responses::{mask_tools_as_mcp, patch_streaming_response_json, rewrite_streaming_block}, - utils::{FunctionCallInProgress, OutputIndexMapper, StreamAction}, }; use crate::{ protocols::{ @@ -44,6 +43,135 @@ use crate::{ routers::header_utils::{apply_request_headers, preserve_response_headers}, }; +// ============================================================================ +// Stream Action Enum +// ============================================================================ + +/// Action to take based on streaming event processing +#[derive(Debug)] +pub(crate) enum StreamAction { + Forward, // Pass event to client + Buffer, // Accumulate for tool execution + ExecuteTools, // Function call complete, execute now +} + +// ============================================================================ +// Output Index Mapper +// ============================================================================ + +/// Maps upstream output indices to sequential downstream indices +#[derive(Debug, Default)] +pub(crate) struct OutputIndexMapper { + next_index: usize, + // Map upstream output_index -> remapped output_index + assigned: HashMap, +} + +impl OutputIndexMapper { + pub fn with_start(next_index: usize) -> Self { + Self { + next_index, + assigned: HashMap::new(), + } + } + + pub fn ensure_mapping(&mut self, upstream_index: usize) -> usize { + *self.assigned.entry(upstream_index).or_insert_with(|| { + let assigned = self.next_index; + self.next_index += 1; + assigned + }) + } + + pub fn lookup(&self, upstream_index: usize) -> Option { + self.assigned.get(&upstream_index).copied() + } + + pub fn allocate_synthetic(&mut self) -> usize { + let assigned = self.next_index; + self.next_index += 1; + assigned + } + + pub fn next_index(&self) -> usize { + self.next_index + } +} + +// ============================================================================ +// Helper Functions +// ============================================================================ + +/// Extract output_index from a JSON value +#[inline] +fn extract_output_index(value: &Value) -> Option { + value.get("output_index")?.as_u64().map(|v| v as usize) +} + +/// Get event type from event name or parsed JSON, returning a reference to avoid allocation +#[inline] +fn get_event_type<'a>(event_name: Option<&'a str>, parsed: &'a Value) -> &'a str { + event_name + .or_else(|| parsed.get("type").and_then(|v| v.as_str())) + .unwrap_or("") +} + +// ============================================================================ +// Chunk Processor +// ============================================================================ + +/// Processes incoming byte chunks into complete SSE blocks. +/// Handles buffering of partial chunks and CRLF normalization. +pub(super) struct ChunkProcessor { + pending: String, +} + +impl ChunkProcessor { + pub fn new() -> Self { + Self { + pending: String::new(), + } + } + + /// Append a chunk to the buffer, normalizing line endings + pub fn push_chunk(&mut self, chunk: &[u8]) { + let chunk_str = match std::str::from_utf8(chunk) { + Ok(s) => Cow::Borrowed(s), + Err(_) => Cow::Owned(String::from_utf8_lossy(chunk).into_owned()), + }; + // Normalize CRLF to LF + if chunk_str.contains("\r\n") { + self.pending.push_str(&chunk_str.replace("\r\n", "\n")); + } else { + self.pending.push_str(&chunk_str); + } + } + + /// Extract the next complete SSE block from the buffer, if available + pub fn next_block(&mut self) -> Option { + let pos = self.pending.find("\n\n")?; + let block = self.pending[..pos].to_string(); + self.pending.drain(..pos + 2); + + if block.trim().is_empty() { + // Skip empty blocks, try next + self.next_block() + } else { + Some(block) + } + } + + /// Check if there's remaining content in the buffer + pub fn has_remaining(&self) -> bool { + !self.pending.trim().is_empty() + } + + /// Take any remaining content from the buffer + pub fn take_remaining(&mut self) -> String { + std::mem::take(&mut self.pending) + } +} + // ============================================================================ // Streaming Response Accumulator // ============================================================================ @@ -156,17 +284,7 @@ impl StreamingResponseAccumulator { } }; - let event_type = event_name - .map(|s| s.to_string()) - .or_else(|| { - parsed - .get("type") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - }) - .unwrap_or_default(); - - match event_type.as_str() { + match get_event_type(event_name, &parsed) { ResponseEvent::CREATED => { if self.initial_response.is_none() { if let Some(response) = parsed.get("response") { @@ -180,13 +298,9 @@ impl StreamingResponseAccumulator { } } OutputItemEvent::DONE => { - if let (Some(index), Some(item)) = ( - parsed - .get("output_index") - .and_then(|v| v.as_u64()) - .map(|v| v as usize), - parsed.get("item"), - ) { + if let (Some(index), Some(item)) = + (extract_output_index(&parsed), parsed.get("item")) + { self.output_items.push((index, item.clone())); } } @@ -287,130 +401,26 @@ impl StreamingToolHandler { Err(_) => return StreamAction::Forward, }; - let event_type = event_name - .map(|s| s.to_string()) - .or_else(|| { - parsed - .get("type") - .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - }) - .unwrap_or_default(); - - match event_type.as_str() { + match get_event_type(event_name, &parsed) { ResponseEvent::CREATED => { if self.original_response_id.is_none() { - if let Some(response_obj) = parsed.get("response").and_then(|v| v.as_object()) { - if let Some(id) = response_obj.get("id").and_then(|v| v.as_str()) { - self.original_response_id = Some(id.to_string()); - } - } + self.original_response_id = parsed + .get("response") + .and_then(|v| v.get("id")) + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); } StreamAction::Forward } ResponseEvent::COMPLETED => StreamAction::Forward, - OutputItemEvent::ADDED => { - if let Some(idx) = parsed.get("output_index").and_then(|v| v.as_u64()) { - self.ensure_output_index(idx as usize); - } - - // Check if this is a function_call item being added - if let Some(item) = parsed.get("item") { - if let Some(item_type) = item.get("type").and_then(|v| v.as_str()) { - if is_function_call_type(item_type) { - match parsed.get("output_index").and_then(|v| v.as_u64()) { - Some(idx) => { - let output_index = idx as usize; - let assigned_index = self.ensure_output_index(output_index); - let call_id = - item.get("call_id").and_then(|v| v.as_str()).unwrap_or(""); - let name = - item.get("name").and_then(|v| v.as_str()).unwrap_or(""); - - // Create or update the function call - let call = self.get_or_create_call(output_index, item); - call.call_id = call_id.to_string(); - call.name = name.to_string(); - call.assigned_output_index = Some(assigned_index); - - self.in_function_call = true; - } - None => { - warn!( - "Missing output_index in function_call added event, \ - forwarding without processing for tool execution" - ); - } - } - } - } - } - StreamAction::Forward - } - FunctionCallEvent::ARGUMENTS_DELTA => { - // Accumulate arguments for the function call - if let Some(output_index) = parsed - .get("output_index") - .and_then(|v| v.as_u64()) - .map(|v| v as usize) - { - let assigned_index = self.ensure_output_index(output_index); - if let Some(delta) = parsed.get("delta").and_then(|v| v.as_str()) { - if let Some(call) = self - .pending_calls - .iter_mut() - .find(|c| c.output_index == output_index) - { - call.arguments_buffer.push_str(delta); - if let Some(obfuscation) = - parsed.get("obfuscation").and_then(|v| v.as_str()) - { - call.last_obfuscation = Some(obfuscation.to_string()); - } - if call.assigned_output_index.is_none() { - call.assigned_output_index = Some(assigned_index); - } - } - } - } - StreamAction::Forward - } - FunctionCallEvent::ARGUMENTS_DONE => { - // Function call arguments complete - check if ready to execute - if let Some(output_index) = parsed - .get("output_index") - .and_then(|v| v.as_u64()) - .map(|v| v as usize) - { - let assigned_index = self.ensure_output_index(output_index); - if let Some(call) = self - .pending_calls - .iter_mut() - .find(|c| c.output_index == output_index) - { - if call.assigned_output_index.is_none() { - call.assigned_output_index = Some(assigned_index); - } - } - } - - if self.has_complete_calls() { - StreamAction::ExecuteTools - } else { - StreamAction::Forward - } - } + OutputItemEvent::ADDED => self.handle_output_item_added(&parsed), + FunctionCallEvent::ARGUMENTS_DELTA => self.handle_arguments_delta(&parsed), + FunctionCallEvent::ARGUMENTS_DONE => self.handle_arguments_done(&parsed), OutputItemEvent::DELTA => self.process_output_delta(&parsed), OutputItemEvent::DONE => { - // Check if we have complete function calls ready to execute - if let Some(output_index) = parsed - .get("output_index") - .and_then(|v| v.as_u64()) - .map(|v| v as usize) - { + if let Some(output_index) = extract_output_index(&parsed) { self.ensure_output_index(output_index); } - if self.has_complete_calls() { StreamAction::ExecuteTools } else { @@ -421,14 +431,91 @@ impl StreamingToolHandler { } } + fn handle_output_item_added(&mut self, parsed: &Value) -> StreamAction { + if let Some(output_index) = extract_output_index(parsed) { + self.ensure_output_index(output_index); + } + + // Check if this is a function_call item being added + let Some(item) = parsed.get("item") else { + return StreamAction::Forward; + }; + let Some(item_type) = item.get("type").and_then(|v| v.as_str()) else { + return StreamAction::Forward; + }; + + if !is_function_call_type(item_type) { + return StreamAction::Forward; + } + + let Some(output_index) = extract_output_index(parsed) else { + warn!( + "Missing output_index in function_call added event, \ + forwarding without processing for tool execution" + ); + return StreamAction::Forward; + }; + + let assigned_index = self.ensure_output_index(output_index); + let call_id = item.get("call_id").and_then(|v| v.as_str()).unwrap_or(""); + let name = item.get("name").and_then(|v| v.as_str()).unwrap_or(""); + + let call = self.get_or_create_call(output_index, item); + call.call_id = call_id.to_string(); + call.name = name.to_string(); + call.assigned_output_index = Some(assigned_index); + self.in_function_call = true; + + StreamAction::Forward + } + + fn handle_arguments_delta(&mut self, parsed: &Value) -> StreamAction { + let Some(output_index) = extract_output_index(parsed) else { + return StreamAction::Forward; + }; + + let assigned_index = self.ensure_output_index(output_index); + + if let Some(delta) = parsed.get("delta").and_then(|v| v.as_str()) { + if let Some(call) = self.find_call_mut(output_index) { + call.arguments_buffer.push_str(delta); + if let Some(obfuscation) = parsed.get("obfuscation").and_then(|v| v.as_str()) { + call.last_obfuscation = Some(obfuscation.to_string()); + } + if call.assigned_output_index.is_none() { + call.assigned_output_index = Some(assigned_index); + } + } + } + StreamAction::Forward + } + + fn handle_arguments_done(&mut self, parsed: &Value) -> StreamAction { + if let Some(output_index) = extract_output_index(parsed) { + let assigned_index = self.ensure_output_index(output_index); + if let Some(call) = self.find_call_mut(output_index) { + if call.assigned_output_index.is_none() { + call.assigned_output_index = Some(assigned_index); + } + } + } + + if self.has_complete_calls() { + StreamAction::ExecuteTools + } else { + StreamAction::Forward + } + } + + fn find_call_mut(&mut self, output_index: usize) -> Option<&mut FunctionCallInProgress> { + self.pending_calls + .iter_mut() + .find(|c| c.output_index == output_index) + } + /// Process output delta events to detect and accumulate function calls fn process_output_delta(&mut self, event: &Value) -> StreamAction { - let output_index = event - .get("output_index") - .and_then(|v| v.as_u64()) - .map(|v| v as usize) - .unwrap_or(0); - + let output_index = extract_output_index(event).unwrap_or(0); let assigned_index = self.ensure_output_index(output_index); let delta = match event.get("delta") { @@ -557,13 +644,14 @@ pub(super) fn apply_event_transformations_inplace( let mut changed = false; // 1. Apply rewrite_streaming_block logic (store, previous_response_id, tools masking) + // Get event_type as owned String to avoid borrow conflict with mutable operations below let event_type = parsed_data .get("type") .and_then(|v| v.as_str()) - .map(|s| s.to_string()) - .unwrap_or_default(); - - let should_patch = is_response_event(event_type.as_str()); + .unwrap_or(""); + let should_patch = is_response_event(event_type); + // Need owned copy for the match below since we mutate parsed_data + let event_type = event_type.to_string(); if should_patch { if let Some(response_obj) = parsed_data @@ -674,6 +762,109 @@ fn build_mcp_tools_value(original_body: &ResponsesRequest) -> Option { Some(Value::Array(tools_array)) } +/// Send an SSE event to the client channel +/// Returns false if client disconnected +#[inline] +fn send_sse_event( + tx: &mpsc::UnboundedSender>, + event_name: &str, + data: &Value, +) -> bool { + let block = format!("event: {}\ndata: {}\n\n", event_name, data); + tx.send(Ok(Bytes::from(block))).is_ok() +} + +/// Transform fc_* item IDs to mcp_* format +#[inline] +fn transform_fc_to_mcp_id(item_id: &str) -> String { + item_id + .strip_prefix("fc_") + .map(|stripped| format!("mcp_{}", stripped)) + .unwrap_or_else(|| item_id.to_string()) +} + +/// Map function_call event names to mcp_call event names +#[inline] +fn map_event_name(event_name: &str) -> &str { + match event_name { + FunctionCallEvent::ARGUMENTS_DELTA => McpEvent::CALL_ARGUMENTS_DELTA, + FunctionCallEvent::ARGUMENTS_DONE => McpEvent::CALL_ARGUMENTS_DONE, + other => other, + } +} + +/// Send buffered function call arguments as a synthetic delta event. +/// Returns false if client disconnected. +fn send_buffered_arguments( + parsed_data: &mut Value, + handler: &StreamingToolHandler, + tx: &mpsc::UnboundedSender>, + sequence_number: &mut u64, + mapped_output_index: &mut Option, +) -> bool { + let Some(output_index) = extract_output_index(parsed_data) else { + return true; + }; + + let assigned_index = handler + .mapped_output_index(output_index) + .unwrap_or(output_index); + *mapped_output_index = Some(assigned_index); + + let Some(call) = handler + .pending_calls + .iter() + .find(|c| c.output_index == output_index) + else { + return true; + }; + + let arguments_value = if call.arguments_buffer.is_empty() { + "{}".to_string() + } else { + call.arguments_buffer.clone() + }; + + // Update the done event with full arguments + parsed_data["arguments"] = Value::String(arguments_value.clone()); + + // Transform item_id + let item_id = parsed_data + .get("item_id") + .and_then(|v| v.as_str()) + .unwrap_or(""); + let mcp_item_id = transform_fc_to_mcp_id(item_id); + + // Build synthetic delta event + let mut delta_event = json!({ + "type": McpEvent::CALL_ARGUMENTS_DELTA, + "sequence_number": *sequence_number, + "output_index": assigned_index, + "item_id": mcp_item_id, + "delta": arguments_value, + }); + + // Add obfuscation if present + let obfuscation = call + .last_obfuscation + .as_ref() + .map(|s| Value::String(s.clone())) + .or_else(|| parsed_data.get("obfuscation").cloned()); + + if let Some(obf) = obfuscation { + if let Some(obj) = delta_event.as_object_mut() { + obj.insert("obfuscation".to_string(), obf); + } + } + + if !send_sse_event(tx, McpEvent::CALL_ARGUMENTS_DELTA, &delta_event) { + return false; + } + + *sequence_number += 1; + true +} + /// Forward and transform a streaming event to the client /// Returns false if client disconnected pub(super) fn forward_streaming_event( @@ -690,117 +881,48 @@ pub(super) fn forward_streaming_event( return true; } - // Parse JSON data once (optimized!) + // Parse JSON data once let mut parsed_data: Value = match serde_json::from_str(data) { Ok(v) => v, Err(_) => { - // If parsing fails, forward raw block as-is - let chunk_to_send = format!("{}\n\n", raw_block); - return tx.send(Ok(Bytes::from(chunk_to_send))).is_ok(); + let chunk = format!("{}\n\n", raw_block); + return tx.send(Ok(Bytes::from(chunk))).is_ok(); } }; - let event_type = event_name - .or_else(|| parsed_data.get("type").and_then(|v| v.as_str())) - .unwrap_or(""); - + let event_type = get_event_type(event_name, &parsed_data); if event_type == ResponseEvent::COMPLETED { return true; } - // Check if this is function_call_arguments.done - need to send buffered args first + // Handle function_call_arguments.done - send buffered args first let mut mapped_output_index: Option = None; - - if event_name == Some(FunctionCallEvent::ARGUMENTS_DONE) { - if let Some(output_index) = parsed_data - .get("output_index") - .and_then(|v| v.as_u64()) - .map(|v| v as usize) - { - let assigned_index = handler - .mapped_output_index(output_index) - .unwrap_or(output_index); - mapped_output_index = Some(assigned_index); - - if let Some(call) = handler - .pending_calls - .iter() - .find(|c| c.output_index == output_index) - { - let arguments_value = if call.arguments_buffer.is_empty() { - "{}".to_string() - } else { - call.arguments_buffer.clone() - }; - - // Make sure the done event carries full arguments - parsed_data["arguments"] = Value::String(arguments_value.clone()); - - // Get item_id and transform it - let item_id = parsed_data - .get("item_id") - .and_then(|v| v.as_str()) - .unwrap_or(""); - let mcp_item_id = if let Some(stripped) = item_id.strip_prefix("fc_") { - format!("mcp_{}", stripped) - } else { - item_id.to_string() - }; - - // Emit a synthetic MCP arguments delta event before the done event - let mut delta_event = json!({ - "type": McpEvent::CALL_ARGUMENTS_DELTA, - "sequence_number": *sequence_number, - "output_index": assigned_index, - "item_id": mcp_item_id, - "delta": arguments_value, - }); - - if let Some(obfuscation) = call.last_obfuscation.as_ref() { - if let Some(obj) = delta_event.as_object_mut() { - obj.insert( - "obfuscation".to_string(), - Value::String(obfuscation.clone()), - ); - } - } else if let Some(obfuscation) = parsed_data.get("obfuscation").cloned() { - if let Some(obj) = delta_event.as_object_mut() { - obj.insert("obfuscation".to_string(), obfuscation); - } - } - - let delta_block = format!( - "event: {}\ndata: {}\n\n", - McpEvent::CALL_ARGUMENTS_DELTA, - delta_event - ); - if tx.send(Ok(Bytes::from(delta_block))).is_err() { - return false; - } - - *sequence_number += 1; - } - } + if event_name == Some(FunctionCallEvent::ARGUMENTS_DONE) + && !send_buffered_arguments( + &mut parsed_data, + handler, + tx, + sequence_number, + &mut mapped_output_index, + ) + { + return false; } - // Remap output_index (if present) so downstream sees sequential indices + // Remap output_index for sequential downstream indices if mapped_output_index.is_none() { - if let Some(output_index) = parsed_data - .get("output_index") - .and_then(|v| v.as_u64()) - .map(|v| v as usize) - { - mapped_output_index = handler.mapped_output_index(output_index); + if let Some(idx) = extract_output_index(&parsed_data) { + mapped_output_index = handler.mapped_output_index(idx); } } - if let Some(mapped) = mapped_output_index { parsed_data["output_index"] = json!(mapped); } - // Apply all transformations in-place (single parse/serialize!) + // Apply transformations apply_event_transformations_inplace(&mut parsed_data, ctx); + // Restore original response ID if let Some(response_obj) = parsed_data .get_mut("response") .and_then(|v| v.as_object_mut()) @@ -810,73 +932,74 @@ pub(super) fn forward_streaming_event( } } - // Update sequence number if present in the event + // Update sequence number if parsed_data.get("sequence_number").is_some() { parsed_data["sequence_number"] = json!(*sequence_number); *sequence_number += 1; } - // Serialize once + // Serialize and send let final_data = match serde_json::to_string(&parsed_data) { Ok(s) => s, Err(_) => { - // Serialization failed, forward original - let chunk_to_send = format!("{}\n\n", raw_block); - return tx.send(Ok(Bytes::from(chunk_to_send))).is_ok(); + let chunk = format!("{}\n\n", raw_block); + return tx.send(Ok(Bytes::from(chunk))).is_ok(); } }; - // Rebuild SSE block with potentially transformed event name - let mut final_block = String::new(); - if let Some(evt) = event_name { - // Update event name for function_call_arguments events - if evt == FunctionCallEvent::ARGUMENTS_DELTA { - final_block.push_str(&format!("event: {}\n", McpEvent::CALL_ARGUMENTS_DELTA)); - } else if evt == FunctionCallEvent::ARGUMENTS_DONE { - final_block.push_str(&format!("event: {}\n", McpEvent::CALL_ARGUMENTS_DONE)); - } else { - final_block.push_str(&format!("event: {}\n", evt)); - } - } - final_block.push_str(&format!("data: {}", final_data)); + // Build SSE block with transformed event name + let final_block = match event_name { + Some(evt) => format!("event: {}\ndata: {}\n\n", map_event_name(evt), final_data), + None => format!("data: {}\n\n", final_data), + }; - let chunk_to_send = format!("{}\n\n", final_block); - if tx.send(Ok(Bytes::from(chunk_to_send))).is_err() { + if tx.send(Ok(Bytes::from(final_block))).is_err() { return false; } // After sending output_item.added for mcp_call, inject mcp_call.in_progress event - if event_name == Some(OutputItemEvent::ADDED) { - if let Some(item) = parsed_data.get("item") { - if item.get("type").and_then(|v| v.as_str()) == Some(ItemType::MCP_CALL) { - // Already transformed to mcp_call - if let (Some(item_id), Some(output_index)) = ( - item.get("id").and_then(|v| v.as_str()), - parsed_data.get("output_index").and_then(|v| v.as_u64()), - ) { - let in_progress_event = json!({ - "type": McpEvent::CALL_IN_PROGRESS, - "sequence_number": *sequence_number, - "output_index": output_index, - "item_id": item_id - }); - *sequence_number += 1; - let in_progress_block = format!( - "event: {}\ndata: {}\n\n", - McpEvent::CALL_IN_PROGRESS, - in_progress_event - ); - if tx.send(Ok(Bytes::from(in_progress_block))).is_err() { - return false; - } - } - } - } + if event_name == Some(OutputItemEvent::ADDED) + && !maybe_inject_mcp_in_progress(&parsed_data, tx, sequence_number) + { + return false; } true } +/// Inject mcp_call.in_progress event after an mcp_call item is added. +/// Returns false if client disconnected. +fn maybe_inject_mcp_in_progress( + parsed_data: &Value, + tx: &mpsc::UnboundedSender>, + sequence_number: &mut u64, +) -> bool { + let Some(item) = parsed_data.get("item") else { + return true; + }; + + if item.get("type").and_then(|v| v.as_str()) != Some(ItemType::MCP_CALL) { + return true; + } + + let Some(item_id) = item.get("id").and_then(|v| v.as_str()) else { + return true; + }; + let Some(output_index) = parsed_data.get("output_index").and_then(|v| v.as_u64()) else { + return true; + }; + + let event = json!({ + "type": McpEvent::CALL_IN_PROGRESS, + "sequence_number": *sequence_number, + "output_index": output_index, + "item_id": item_id + }); + *sequence_number += 1; + + send_sse_event(tx, McpEvent::CALL_IN_PROGRESS, &event) +} + /// Send final response.completed event to client /// Returns false if client disconnected pub(super) fn send_final_response_event( @@ -992,38 +1115,25 @@ pub(super) async fn handle_simple_streaming_passthrough( let mut accumulator = StreamingResponseAccumulator::new(); let mut upstream_failed = false; let mut receiver_connected = true; - let mut pending = String::new(); + let mut chunk_processor = ChunkProcessor::new(); while let Some(chunk_result) = upstream_stream.next().await { match chunk_result { Ok(chunk) => { - let chunk_text = match std::str::from_utf8(&chunk) { - Ok(text) => Cow::Borrowed(text), - Err(_) => Cow::Owned(String::from_utf8_lossy(&chunk).to_string()), - }; - - pending.push_str(&chunk_text.replace("\r\n", "\n")); + chunk_processor.push_chunk(&chunk); - while let Some(pos) = pending.find("\n\n") { - let raw_block = pending[..pos].to_string(); - pending.drain(..pos + 2); - - if raw_block.trim().is_empty() { - continue; - } - - let block_cow = if let Some(modified) = rewrite_streaming_block( - raw_block.as_str(), + while let Some(raw_block) = chunk_processor.next_block() { + let block_cow = match rewrite_streaming_block( + &raw_block, &original_request, previous_response_id.as_deref(), ) { - Cow::Owned(modified) - } else { - Cow::Borrowed(raw_block.as_str()) + Some(modified) => Cow::Owned(modified), + None => Cow::Borrowed(raw_block.as_str()), }; if should_store || persist_needed { - accumulator.ingest_block(block_cow.as_ref()); + accumulator.ingest_block(&block_cow); } if receiver_connected { @@ -1052,8 +1162,8 @@ pub(super) async fn handle_simple_streaming_passthrough( } if (should_store || persist_needed) && !upstream_failed { - if !pending.trim().is_empty() { - accumulator.ingest_block(&pending); + if chunk_processor.has_remaining() { + accumulator.ingest_block(&chunk_processor.take_remaining()); } let encountered_error = accumulator.encountered_error().cloned(); if let Some(mut response_json) = accumulator.into_final_response() { @@ -1189,28 +1299,16 @@ pub(super) async fn handle_streaming_with_tool_interception( if let Some(ref id) = preserved_response_id { handler.original_response_id = Some(id.clone()); } - let mut pending = String::new(); + let mut chunk_processor = ChunkProcessor::new(); let mut tool_calls_detected = false; let mut seen_in_progress = false; while let Some(chunk_result) = upstream_stream.next().await { match chunk_result { Ok(chunk) => { - let chunk_text = match std::str::from_utf8(&chunk) { - Ok(text) => Cow::Borrowed(text), - Err(_) => Cow::Owned(String::from_utf8_lossy(&chunk).to_string()), - }; - - pending.push_str(&chunk_text.replace("\r\n", "\n")); - - while let Some(pos) = pending.find("\n\n") { - let raw_block = pending[..pos].to_string(); - pending.drain(..pos + 2); - - if raw_block.trim().is_empty() { - continue; - } + chunk_processor.push_chunk(&chunk); + while let Some(raw_block) = chunk_processor.next_block() { // Parse event let (event_name, data) = parse_sse_block(&raw_block); diff --git a/sgl-model-gateway/src/routers/openai/utils.rs b/sgl-model-gateway/src/routers/openai/utils.rs deleted file mode 100644 index b9262c7728f9..000000000000 --- a/sgl-model-gateway/src/routers/openai/utils.rs +++ /dev/null @@ -1,63 +0,0 @@ -//! Utility types for OpenAI router - -use std::collections::HashMap; - -// ============================================================================ -// Stream Action Enum -// ============================================================================ - -/// Action to take based on streaming event processing -#[derive(Debug)] -pub(crate) enum StreamAction { - Forward, // Pass event to client - Buffer, // Accumulate for tool execution - ExecuteTools, // Function call complete, execute now -} - -// ============================================================================ -// Output Index Mapper -// ============================================================================ - -/// Maps upstream output indices to sequential downstream indices -#[derive(Debug, Default)] -pub(crate) struct OutputIndexMapper { - next_index: usize, - // Map upstream output_index -> remapped output_index - assigned: HashMap, -} - -impl OutputIndexMapper { - pub fn with_start(next_index: usize) -> Self { - Self { - next_index, - assigned: HashMap::new(), - } - } - - pub fn ensure_mapping(&mut self, upstream_index: usize) -> usize { - *self.assigned.entry(upstream_index).or_insert_with(|| { - let assigned = self.next_index; - self.next_index += 1; - assigned - }) - } - - pub fn lookup(&self, upstream_index: usize) -> Option { - self.assigned.get(&upstream_index).copied() - } - - pub fn allocate_synthetic(&mut self) -> usize { - let assigned = self.next_index; - self.next_index += 1; - assigned - } - - pub fn next_index(&self) -> usize { - self.next_index - } -} - -// ============================================================================ -// Re-export FunctionCallInProgress from mcp module -// ============================================================================ -pub(crate) use super::mcp::FunctionCallInProgress; From c347aac5caea694c483469bc6ce7fefbc75752a5 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Sat, 6 Dec 2025 01:00:15 -0800 Subject: [PATCH 2/2] [model-gateway] refactor oai streaming -2 --- .../src/routers/openai/accumulator.rs | 164 ++++++ sgl-model-gateway/src/routers/openai/mod.rs | 2 + .../src/routers/openai/streaming.rs | 527 +----------------- .../src/routers/openai/tool_handler.rs | 341 ++++++++++++ 4 files changed, 531 insertions(+), 503 deletions(-) create mode 100644 sgl-model-gateway/src/routers/openai/accumulator.rs create mode 100644 sgl-model-gateway/src/routers/openai/tool_handler.rs diff --git a/sgl-model-gateway/src/routers/openai/accumulator.rs b/sgl-model-gateway/src/routers/openai/accumulator.rs new file mode 100644 index 000000000000..b9ed755aa30a --- /dev/null +++ b/sgl-model-gateway/src/routers/openai/accumulator.rs @@ -0,0 +1,164 @@ +//! Streaming response accumulator for persisting responses. + +use serde_json::Value; +use tracing::warn; + +use super::streaming::{extract_output_index, get_event_type}; +use crate::protocols::event_types::{OutputItemEvent, ResponseEvent}; + +// ============================================================================ +// Streaming Response Accumulator +// ============================================================================ + +/// Helper that parses SSE frames from the OpenAI responses stream and +/// accumulates enough information to persist the final response locally. +pub(super) struct StreamingResponseAccumulator { + /// The initial `response.created` payload (if emitted). + initial_response: Option, + /// The final `response.completed` payload (if emitted). + completed_response: Option, + /// Collected output items keyed by the upstream output index, used when + /// a final response payload is absent and we need to synthesize one. + output_items: Vec<(usize, Value)>, + /// Captured error payload (if the upstream stream fails midway). + encountered_error: Option, +} + +impl StreamingResponseAccumulator { + pub fn new() -> Self { + Self { + initial_response: None, + completed_response: None, + output_items: Vec::new(), + encountered_error: None, + } + } + + /// Feed the accumulator with the next SSE chunk. + pub fn ingest_block(&mut self, block: &str) { + if block.trim().is_empty() { + return; + } + self.process_block(block); + } + + /// Consume the accumulator and produce the best-effort final response value. + pub fn into_final_response(mut self) -> Option { + if self.completed_response.is_some() { + return self.completed_response; + } + + self.build_fallback_response() + } + + pub fn encountered_error(&self) -> Option<&Value> { + self.encountered_error.as_ref() + } + + pub fn original_response_id(&self) -> Option<&str> { + self.initial_response + .as_ref() + .and_then(|response| response.get("id")) + .and_then(|id| id.as_str()) + } + + pub fn snapshot_final_response(&self) -> Option { + if let Some(resp) = &self.completed_response { + return Some(resp.clone()); + } + self.build_fallback_response_snapshot() + } + + fn build_fallback_response_snapshot(&self) -> Option { + let mut response = self.initial_response.clone()?; + + if let Some(obj) = response.as_object_mut() { + obj.insert("status".to_string(), Value::String("completed".to_string())); + + let mut output_items = self.output_items.clone(); + output_items.sort_by_key(|(index, _)| *index); + let outputs: Vec = output_items.into_iter().map(|(_, item)| item).collect(); + obj.insert("output".to_string(), Value::Array(outputs)); + } + + Some(response) + } + + fn process_block(&mut self, block: &str) { + let trimmed = block.trim(); + if trimmed.is_empty() { + return; + } + + let mut event_name: Option = None; + let mut data_lines: Vec = Vec::new(); + + for line in trimmed.lines() { + if let Some(rest) = line.strip_prefix("event:") { + event_name = Some(rest.trim().to_string()); + } else if let Some(rest) = line.strip_prefix("data:") { + data_lines.push(rest.trim_start().to_string()); + } + } + + let data_payload = data_lines.join("\n"); + if data_payload.is_empty() { + return; + } + + self.handle_event(event_name.as_deref(), &data_payload); + } + + fn handle_event(&mut self, event_name: Option<&str>, data_payload: &str) { + let parsed: Value = match serde_json::from_str(data_payload) { + Ok(value) => value, + Err(err) => { + warn!("Failed to parse streaming event JSON: {}", err); + return; + } + }; + + match get_event_type(event_name, &parsed) { + ResponseEvent::CREATED => { + if self.initial_response.is_none() { + if let Some(response) = parsed.get("response") { + self.initial_response = Some(response.clone()); + } + } + } + ResponseEvent::COMPLETED => { + if let Some(response) = parsed.get("response") { + self.completed_response = Some(response.clone()); + } + } + OutputItemEvent::DONE => { + if let (Some(index), Some(item)) = + (extract_output_index(&parsed), parsed.get("item")) + { + self.output_items.push((index, item.clone())); + } + } + "response.error" => { + self.encountered_error = Some(parsed); + } + _ => {} + } + } + + fn build_fallback_response(&mut self) -> Option { + let mut response = self.initial_response.clone()?; + + if let Some(obj) = response.as_object_mut() { + obj.insert("status".to_string(), Value::String("completed".to_string())); + + self.output_items.sort_by_key(|(index, _)| *index); + let outputs: Vec = std::mem::take(&mut self.output_items) + .into_iter() + .map(|(_, item)| item) + .collect(); + obj.insert("output".to_string(), Value::Array(outputs)); + } + + Some(response) + } +} diff --git a/sgl-model-gateway/src/routers/openai/mod.rs b/sgl-model-gateway/src/routers/openai/mod.rs index 3288657afdb7..5d8d024431b8 100644 --- a/sgl-model-gateway/src/routers/openai/mod.rs +++ b/sgl-model-gateway/src/routers/openai/mod.rs @@ -7,6 +7,7 @@ //! - Multi-turn tool execution loops //! - SSE (Server-Sent Events) streaming +mod accumulator; mod context; pub mod conversations; pub mod mcp; @@ -14,6 +15,7 @@ pub mod provider; mod responses; mod router; mod streaming; +mod tool_handler; // Re-export the main types for external use pub use provider::{Provider, ProviderError, ProviderRegistry}; diff --git a/sgl-model-gateway/src/routers/openai/streaming.rs b/sgl-model-gateway/src/routers/openai/streaming.rs index b61484950753..d1e174f37436 100644 --- a/sgl-model-gateway/src/routers/openai/streaming.rs +++ b/sgl-model-gateway/src/routers/openai/streaming.rs @@ -7,7 +7,7 @@ //! - MCP tool execution loops within streaming responses //! - Event transformation and output index remapping -use std::{borrow::Cow, collections::HashMap, io, sync::Arc}; +use std::{borrow::Cow, io, sync::Arc}; use axum::{ body::Body, @@ -22,15 +22,17 @@ use tokio_stream::wrappers::UnboundedReceiverStream; use tracing::warn; // Import from sibling modules -use super::context::{RequestContext, StreamingEventContext, StreamingRequest}; +use super::accumulator::StreamingResponseAccumulator; use super::{ + context::{RequestContext, StreamingEventContext, StreamingRequest}, conversations::persist_conversation_items, mcp::{ build_resume_payload, ensure_request_mcp_client, execute_streaming_tool_calls, inject_mcp_metadata_streaming, prepare_mcp_payload_for_streaming, - send_mcp_list_tools_events, FunctionCallInProgress, McpLoopConfig, ToolLoopState, + send_mcp_list_tools_events, McpLoopConfig, ToolLoopState, }, responses::{mask_tools_as_mcp, patch_streaming_response_json, rewrite_streaming_block}, + tool_handler::{StreamAction, StreamingToolHandler}, }; use crate::{ protocols::{ @@ -43,74 +45,19 @@ use crate::{ routers::header_utils::{apply_request_headers, preserve_response_headers}, }; -// ============================================================================ -// Stream Action Enum -// ============================================================================ - -/// Action to take based on streaming event processing -#[derive(Debug)] -pub(crate) enum StreamAction { - Forward, // Pass event to client - Buffer, // Accumulate for tool execution - ExecuteTools, // Function call complete, execute now -} - -// ============================================================================ -// Output Index Mapper -// ============================================================================ - -/// Maps upstream output indices to sequential downstream indices -#[derive(Debug, Default)] -pub(crate) struct OutputIndexMapper { - next_index: usize, - // Map upstream output_index -> remapped output_index - assigned: HashMap, -} - -impl OutputIndexMapper { - pub fn with_start(next_index: usize) -> Self { - Self { - next_index, - assigned: HashMap::new(), - } - } - - pub fn ensure_mapping(&mut self, upstream_index: usize) -> usize { - *self.assigned.entry(upstream_index).or_insert_with(|| { - let assigned = self.next_index; - self.next_index += 1; - assigned - }) - } - - pub fn lookup(&self, upstream_index: usize) -> Option { - self.assigned.get(&upstream_index).copied() - } - - pub fn allocate_synthetic(&mut self) -> usize { - let assigned = self.next_index; - self.next_index += 1; - assigned - } - - pub fn next_index(&self) -> usize { - self.next_index - } -} - // ============================================================================ // Helper Functions // ============================================================================ /// Extract output_index from a JSON value #[inline] -fn extract_output_index(value: &Value) -> Option { +pub(super) fn extract_output_index(value: &Value) -> Option { value.get("output_index")?.as_u64().map(|v| v as usize) } /// Get event type from event name or parsed JSON, returning a reference to avoid allocation #[inline] -fn get_event_type<'a>(event_name: Option<&'a str>, parsed: &'a Value) -> &'a str { +pub(super) fn get_event_type<'a>(event_name: Option<&'a str>, parsed: &'a Value) -> &'a str { event_name .or_else(|| parsed.get("type").and_then(|v| v.as_str())) .unwrap_or("") @@ -139,25 +86,28 @@ impl ChunkProcessor { Ok(s) => Cow::Borrowed(s), Err(_) => Cow::Owned(String::from_utf8_lossy(chunk).into_owned()), }; - // Normalize CRLF to LF - if chunk_str.contains("\r\n") { - self.pending.push_str(&chunk_str.replace("\r\n", "\n")); - } else { - self.pending.push_str(&chunk_str); + // Normalize CRLF to LF without extra allocation + let mut chars = chunk_str.chars().peekable(); + while let Some(c) = chars.next() { + if c == '\r' && chars.peek() == Some(&'\n') { + // Skip \r when followed by \n + continue; + } + self.pending.push(c); } } /// Extract the next complete SSE block from the buffer, if available pub fn next_block(&mut self) -> Option { - let pos = self.pending.find("\n\n")?; - let block = self.pending[..pos].to_string(); - self.pending.drain(..pos + 2); - - if block.trim().is_empty() { - // Skip empty blocks, try next - self.next_block() - } else { - Some(block) + loop { + let pos = self.pending.find("\n\n")?; + let block = self.pending[..pos].to_string(); + self.pending.drain(..pos + 2); + + if !block.trim().is_empty() { + return Some(block); + } + // If block is empty, loop again to find the next one } } @@ -172,435 +122,6 @@ impl ChunkProcessor { } } -// ============================================================================ -// Streaming Response Accumulator -// ============================================================================ - -/// Helper that parses SSE frames from the OpenAI responses stream and -/// accumulates enough information to persist the final response locally. -pub(super) struct StreamingResponseAccumulator { - /// The initial `response.created` payload (if emitted). - initial_response: Option, - /// The final `response.completed` payload (if emitted). - completed_response: Option, - /// Collected output items keyed by the upstream output index, used when - /// a final response payload is absent and we need to synthesize one. - output_items: Vec<(usize, Value)>, - /// Captured error payload (if the upstream stream fails midway). - encountered_error: Option, -} - -impl StreamingResponseAccumulator { - pub fn new() -> Self { - Self { - initial_response: None, - completed_response: None, - output_items: Vec::new(), - encountered_error: None, - } - } - - /// Feed the accumulator with the next SSE chunk. - pub fn ingest_block(&mut self, block: &str) { - if block.trim().is_empty() { - return; - } - self.process_block(block); - } - - /// Consume the accumulator and produce the best-effort final response value. - pub fn into_final_response(mut self) -> Option { - if self.completed_response.is_some() { - return self.completed_response; - } - - self.build_fallback_response() - } - - pub fn encountered_error(&self) -> Option<&Value> { - self.encountered_error.as_ref() - } - - pub fn original_response_id(&self) -> Option<&str> { - self.initial_response - .as_ref() - .and_then(|response| response.get("id")) - .and_then(|id| id.as_str()) - } - - pub fn snapshot_final_response(&self) -> Option { - if let Some(resp) = &self.completed_response { - return Some(resp.clone()); - } - self.build_fallback_response_snapshot() - } - - fn build_fallback_response_snapshot(&self) -> Option { - let mut response = self.initial_response.clone()?; - - if let Some(obj) = response.as_object_mut() { - obj.insert("status".to_string(), Value::String("completed".to_string())); - - let mut output_items = self.output_items.clone(); - output_items.sort_by_key(|(index, _)| *index); - let outputs: Vec = output_items.into_iter().map(|(_, item)| item).collect(); - obj.insert("output".to_string(), Value::Array(outputs)); - } - - Some(response) - } - - fn process_block(&mut self, block: &str) { - let trimmed = block.trim(); - if trimmed.is_empty() { - return; - } - - let mut event_name: Option = None; - let mut data_lines: Vec = Vec::new(); - - for line in trimmed.lines() { - if let Some(rest) = line.strip_prefix("event:") { - event_name = Some(rest.trim().to_string()); - } else if let Some(rest) = line.strip_prefix("data:") { - data_lines.push(rest.trim_start().to_string()); - } - } - - let data_payload = data_lines.join("\n"); - if data_payload.is_empty() { - return; - } - - self.handle_event(event_name.as_deref(), &data_payload); - } - - fn handle_event(&mut self, event_name: Option<&str>, data_payload: &str) { - let parsed: Value = match serde_json::from_str(data_payload) { - Ok(value) => value, - Err(err) => { - warn!("Failed to parse streaming event JSON: {}", err); - return; - } - }; - - match get_event_type(event_name, &parsed) { - ResponseEvent::CREATED => { - if self.initial_response.is_none() { - if let Some(response) = parsed.get("response") { - self.initial_response = Some(response.clone()); - } - } - } - ResponseEvent::COMPLETED => { - if let Some(response) = parsed.get("response") { - self.completed_response = Some(response.clone()); - } - } - OutputItemEvent::DONE => { - if let (Some(index), Some(item)) = - (extract_output_index(&parsed), parsed.get("item")) - { - self.output_items.push((index, item.clone())); - } - } - "response.error" => { - self.encountered_error = Some(parsed); - } - _ => {} - } - } - - fn build_fallback_response(&mut self) -> Option { - let mut response = self.initial_response.clone()?; - - if let Some(obj) = response.as_object_mut() { - obj.insert("status".to_string(), Value::String("completed".to_string())); - - self.output_items.sort_by_key(|(index, _)| *index); - let outputs: Vec = self - .output_items - .iter() - .map(|(_, item)| item.clone()) - .collect(); - obj.insert("output".to_string(), Value::Array(outputs)); - } - - Some(response) - } -} - -// ============================================================================ -// Streaming Tool Handler -// ============================================================================ - -/// Handles streaming responses with MCP tool call interception -pub(super) struct StreamingToolHandler { - /// Accumulator for response persistence - pub accumulator: StreamingResponseAccumulator, - /// Function calls being built from deltas - pub pending_calls: Vec, - /// Track if we're currently in a function call - in_function_call: bool, - /// Manage output_index remapping so they increment per item - output_index_mapper: OutputIndexMapper, - /// Original response id captured from the first response.created event - pub original_response_id: Option, -} - -impl StreamingToolHandler { - pub fn with_starting_index(start: usize) -> Self { - Self { - accumulator: StreamingResponseAccumulator::new(), - pending_calls: Vec::new(), - in_function_call: false, - output_index_mapper: OutputIndexMapper::with_start(start), - original_response_id: None, - } - } - - pub fn ensure_output_index(&mut self, upstream_index: usize) -> usize { - self.output_index_mapper.ensure_mapping(upstream_index) - } - - pub fn mapped_output_index(&self, upstream_index: usize) -> Option { - self.output_index_mapper.lookup(upstream_index) - } - - pub fn allocate_synthetic_output_index(&mut self) -> usize { - self.output_index_mapper.allocate_synthetic() - } - - pub fn next_output_index(&self) -> usize { - self.output_index_mapper.next_index() - } - - pub fn original_response_id(&self) -> Option<&str> { - self.original_response_id - .as_deref() - .or_else(|| self.accumulator.original_response_id()) - } - - pub fn snapshot_final_response(&self) -> Option { - self.accumulator.snapshot_final_response() - } - - /// Process an SSE event and determine what action to take - pub fn process_event(&mut self, event_name: Option<&str>, data: &str) -> StreamAction { - // Always feed to accumulator for storage - self.accumulator.ingest_block(&format!( - "{}data: {}", - event_name - .map(|n| format!("event: {}\n", n)) - .unwrap_or_default(), - data - )); - - let parsed: Value = match serde_json::from_str(data) { - Ok(v) => v, - Err(_) => return StreamAction::Forward, - }; - - match get_event_type(event_name, &parsed) { - ResponseEvent::CREATED => { - if self.original_response_id.is_none() { - self.original_response_id = parsed - .get("response") - .and_then(|v| v.get("id")) - .and_then(|v| v.as_str()) - .map(|s| s.to_string()); - } - StreamAction::Forward - } - ResponseEvent::COMPLETED => StreamAction::Forward, - OutputItemEvent::ADDED => self.handle_output_item_added(&parsed), - FunctionCallEvent::ARGUMENTS_DELTA => self.handle_arguments_delta(&parsed), - FunctionCallEvent::ARGUMENTS_DONE => self.handle_arguments_done(&parsed), - OutputItemEvent::DELTA => self.process_output_delta(&parsed), - OutputItemEvent::DONE => { - if let Some(output_index) = extract_output_index(&parsed) { - self.ensure_output_index(output_index); - } - if self.has_complete_calls() { - StreamAction::ExecuteTools - } else { - StreamAction::Forward - } - } - _ => StreamAction::Forward, - } - } - - fn handle_output_item_added(&mut self, parsed: &Value) -> StreamAction { - if let Some(output_index) = extract_output_index(parsed) { - self.ensure_output_index(output_index); - } - - // Check if this is a function_call item being added - let Some(item) = parsed.get("item") else { - return StreamAction::Forward; - }; - let Some(item_type) = item.get("type").and_then(|v| v.as_str()) else { - return StreamAction::Forward; - }; - - if !is_function_call_type(item_type) { - return StreamAction::Forward; - } - - let Some(output_index) = extract_output_index(parsed) else { - warn!( - "Missing output_index in function_call added event, \ - forwarding without processing for tool execution" - ); - return StreamAction::Forward; - }; - - let assigned_index = self.ensure_output_index(output_index); - let call_id = item.get("call_id").and_then(|v| v.as_str()).unwrap_or(""); - let name = item.get("name").and_then(|v| v.as_str()).unwrap_or(""); - - let call = self.get_or_create_call(output_index, item); - call.call_id = call_id.to_string(); - call.name = name.to_string(); - call.assigned_output_index = Some(assigned_index); - self.in_function_call = true; - - StreamAction::Forward - } - - fn handle_arguments_delta(&mut self, parsed: &Value) -> StreamAction { - let Some(output_index) = extract_output_index(parsed) else { - return StreamAction::Forward; - }; - - let assigned_index = self.ensure_output_index(output_index); - - if let Some(delta) = parsed.get("delta").and_then(|v| v.as_str()) { - if let Some(call) = self.find_call_mut(output_index) { - call.arguments_buffer.push_str(delta); - if let Some(obfuscation) = parsed.get("obfuscation").and_then(|v| v.as_str()) { - call.last_obfuscation = Some(obfuscation.to_string()); - } - if call.assigned_output_index.is_none() { - call.assigned_output_index = Some(assigned_index); - } - } - } - StreamAction::Forward - } - - fn handle_arguments_done(&mut self, parsed: &Value) -> StreamAction { - if let Some(output_index) = extract_output_index(parsed) { - let assigned_index = self.ensure_output_index(output_index); - if let Some(call) = self.find_call_mut(output_index) { - if call.assigned_output_index.is_none() { - call.assigned_output_index = Some(assigned_index); - } - } - } - - if self.has_complete_calls() { - StreamAction::ExecuteTools - } else { - StreamAction::Forward - } - } - - fn find_call_mut(&mut self, output_index: usize) -> Option<&mut FunctionCallInProgress> { - self.pending_calls - .iter_mut() - .find(|c| c.output_index == output_index) - } - - /// Process output delta events to detect and accumulate function calls - fn process_output_delta(&mut self, event: &Value) -> StreamAction { - let output_index = extract_output_index(event).unwrap_or(0); - let assigned_index = self.ensure_output_index(output_index); - - let delta = match event.get("delta") { - Some(d) => d, - None => return StreamAction::Forward, - }; - - // Check if this is a function call delta - let item_type = delta.get("type").and_then(|v| v.as_str()); - - if item_type.is_some_and(is_function_call_type) { - self.in_function_call = true; - - // Get or create function call for this output index - let call = self.get_or_create_call(output_index, delta); - call.assigned_output_index = Some(assigned_index); - - // Accumulate call_id if present - if let Some(call_id) = delta.get("call_id").and_then(|v| v.as_str()) { - call.call_id = call_id.to_string(); - } - - // Accumulate name if present - if let Some(name) = delta.get("name").and_then(|v| v.as_str()) { - call.name.push_str(name); - } - - // Accumulate arguments if present - if let Some(args) = delta.get("arguments").and_then(|v| v.as_str()) { - call.arguments_buffer.push_str(args); - } - - if let Some(obfuscation) = delta.get("obfuscation").and_then(|v| v.as_str()) { - call.last_obfuscation = Some(obfuscation.to_string()); - } - - // Buffer this event, don't forward to client - return StreamAction::Buffer; - } - - // Forward non-function-call events - StreamAction::Forward - } - - fn get_or_create_call( - &mut self, - output_index: usize, - delta: &Value, - ) -> &mut FunctionCallInProgress { - // Find existing call for this output index - if let Some(pos) = self - .pending_calls - .iter() - .position(|c| c.output_index == output_index) - { - return &mut self.pending_calls[pos]; - } - - // Create new call - let call_id = delta - .get("call_id") - .and_then(|v| v.as_str()) - .unwrap_or("") - .to_string(); - - let mut call = FunctionCallInProgress::new(call_id, output_index); - if let Some(obfuscation) = delta.get("obfuscation").and_then(|v| v.as_str()) { - call.last_obfuscation = Some(obfuscation.to_string()); - } - - self.pending_calls.push(call); - self.pending_calls - .last_mut() - .expect("Just pushed to pending_calls, must have at least one element") - } - - fn has_complete_calls(&self) -> bool { - !self.pending_calls.is_empty() && self.pending_calls.iter().all(|c| c.is_complete()) - } - - pub fn take_pending_calls(&mut self) -> Vec { - std::mem::take(&mut self.pending_calls) - } -} - // ============================================================================ // SSE Parsing // ============================================================================ diff --git a/sgl-model-gateway/src/routers/openai/tool_handler.rs b/sgl-model-gateway/src/routers/openai/tool_handler.rs new file mode 100644 index 000000000000..5c912846175a --- /dev/null +++ b/sgl-model-gateway/src/routers/openai/tool_handler.rs @@ -0,0 +1,341 @@ +//! Streaming tool call handling for MCP interception. + +use std::collections::HashMap; + +use serde_json::Value; +use tracing::warn; + +use super::{ + accumulator::StreamingResponseAccumulator, + mcp::FunctionCallInProgress, + streaming::{extract_output_index, get_event_type}, +}; +use crate::protocols::event_types::{ + is_function_call_type, FunctionCallEvent, OutputItemEvent, ResponseEvent, +}; + +// ============================================================================ +// Stream Action Enum +// ============================================================================ + +/// Action to take based on streaming event processing +#[derive(Debug)] +pub(crate) enum StreamAction { + Forward, // Pass event to client + Buffer, // Accumulate for tool execution + ExecuteTools, // Function call complete, execute now +} + +// ============================================================================ +// Output Index Mapper +// ============================================================================ + +/// Maps upstream output indices to sequential downstream indices +#[derive(Debug, Default)] +pub(crate) struct OutputIndexMapper { + next_index: usize, + // Map upstream output_index -> remapped output_index + assigned: HashMap, +} + +impl OutputIndexMapper { + pub fn with_start(next_index: usize) -> Self { + Self { + next_index, + assigned: HashMap::new(), + } + } + + pub fn ensure_mapping(&mut self, upstream_index: usize) -> usize { + *self.assigned.entry(upstream_index).or_insert_with(|| { + let assigned = self.next_index; + self.next_index += 1; + assigned + }) + } + + pub fn lookup(&self, upstream_index: usize) -> Option { + self.assigned.get(&upstream_index).copied() + } + + pub fn allocate_synthetic(&mut self) -> usize { + let assigned = self.next_index; + self.next_index += 1; + assigned + } + + pub fn next_index(&self) -> usize { + self.next_index + } +} + +// ============================================================================ +// Streaming Tool Handler +// ============================================================================ + +/// Handles streaming responses with MCP tool call interception +pub(super) struct StreamingToolHandler { + /// Accumulator for response persistence + pub accumulator: StreamingResponseAccumulator, + /// Function calls being built from deltas + pub pending_calls: Vec, + /// Track if we're currently in a function call + in_function_call: bool, + /// Manage output_index remapping so they increment per item + output_index_mapper: OutputIndexMapper, + /// Original response id captured from the first response.created event + pub original_response_id: Option, +} + +impl StreamingToolHandler { + pub fn with_starting_index(start: usize) -> Self { + Self { + accumulator: StreamingResponseAccumulator::new(), + pending_calls: Vec::new(), + in_function_call: false, + output_index_mapper: OutputIndexMapper::with_start(start), + original_response_id: None, + } + } + + pub fn ensure_output_index(&mut self, upstream_index: usize) -> usize { + self.output_index_mapper.ensure_mapping(upstream_index) + } + + pub fn mapped_output_index(&self, upstream_index: usize) -> Option { + self.output_index_mapper.lookup(upstream_index) + } + + pub fn allocate_synthetic_output_index(&mut self) -> usize { + self.output_index_mapper.allocate_synthetic() + } + + pub fn next_output_index(&self) -> usize { + self.output_index_mapper.next_index() + } + + pub fn original_response_id(&self) -> Option<&str> { + self.original_response_id + .as_deref() + .or_else(|| self.accumulator.original_response_id()) + } + + pub fn snapshot_final_response(&self) -> Option { + self.accumulator.snapshot_final_response() + } + + /// Process an SSE event and determine what action to take + pub fn process_event(&mut self, event_name: Option<&str>, data: &str) -> StreamAction { + // Always feed to accumulator for storage + self.accumulator.ingest_block(&format!( + "{}data: {}", + event_name + .map(|n| format!("event: {}\n", n)) + .unwrap_or_default(), + data + )); + + let parsed: Value = match serde_json::from_str(data) { + Ok(v) => v, + Err(_) => return StreamAction::Forward, + }; + + match get_event_type(event_name, &parsed) { + ResponseEvent::CREATED => { + if self.original_response_id.is_none() { + self.original_response_id = parsed + .get("response") + .and_then(|v| v.get("id")) + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + } + StreamAction::Forward + } + ResponseEvent::COMPLETED => StreamAction::Forward, + OutputItemEvent::ADDED => self.handle_output_item_added(&parsed), + FunctionCallEvent::ARGUMENTS_DELTA => self.handle_arguments_delta(&parsed), + FunctionCallEvent::ARGUMENTS_DONE => self.handle_arguments_done(&parsed), + OutputItemEvent::DELTA => self.process_output_delta(&parsed), + OutputItemEvent::DONE => { + if let Some(output_index) = extract_output_index(&parsed) { + self.ensure_output_index(output_index); + } + if self.has_complete_calls() { + StreamAction::ExecuteTools + } else { + StreamAction::Forward + } + } + _ => StreamAction::Forward, + } + } + + fn handle_output_item_added(&mut self, parsed: &Value) -> StreamAction { + if let Some(output_index) = extract_output_index(parsed) { + self.ensure_output_index(output_index); + } + + // Check if this is a function_call item being added + let Some(item) = parsed.get("item") else { + return StreamAction::Forward; + }; + let Some(item_type) = item.get("type").and_then(|v| v.as_str()) else { + return StreamAction::Forward; + }; + + if !is_function_call_type(item_type) { + return StreamAction::Forward; + } + + let Some(output_index) = extract_output_index(parsed) else { + warn!( + "Missing output_index in function_call added event, \ + forwarding without processing for tool execution" + ); + return StreamAction::Forward; + }; + + let assigned_index = self.ensure_output_index(output_index); + let call_id = item.get("call_id").and_then(|v| v.as_str()).unwrap_or(""); + let name = item.get("name").and_then(|v| v.as_str()).unwrap_or(""); + + let call = self.get_or_create_call(output_index, item); + call.call_id = call_id.to_string(); + call.name = name.to_string(); + call.assigned_output_index = Some(assigned_index); + self.in_function_call = true; + + StreamAction::Forward + } + + fn handle_arguments_delta(&mut self, parsed: &Value) -> StreamAction { + let Some(output_index) = extract_output_index(parsed) else { + return StreamAction::Forward; + }; + + let assigned_index = self.ensure_output_index(output_index); + + if let Some(delta) = parsed.get("delta").and_then(|v| v.as_str()) { + if let Some(call) = self.find_call_mut(output_index) { + call.arguments_buffer.push_str(delta); + if let Some(obfuscation) = parsed.get("obfuscation").and_then(|v| v.as_str()) { + call.last_obfuscation = Some(obfuscation.to_string()); + } + if call.assigned_output_index.is_none() { + call.assigned_output_index = Some(assigned_index); + } + } + } + StreamAction::Forward + } + + fn handle_arguments_done(&mut self, parsed: &Value) -> StreamAction { + if let Some(output_index) = extract_output_index(parsed) { + let assigned_index = self.ensure_output_index(output_index); + if let Some(call) = self.find_call_mut(output_index) { + if call.assigned_output_index.is_none() { + call.assigned_output_index = Some(assigned_index); + } + } + } + + if self.has_complete_calls() { + StreamAction::ExecuteTools + } else { + StreamAction::Forward + } + } + + fn find_call_mut(&mut self, output_index: usize) -> Option<&mut FunctionCallInProgress> { + self.pending_calls + .iter_mut() + .find(|c| c.output_index == output_index) + } + + /// Process output delta events to detect and accumulate function calls + fn process_output_delta(&mut self, event: &Value) -> StreamAction { + let output_index = extract_output_index(event).unwrap_or(0); + let assigned_index = self.ensure_output_index(output_index); + + let delta = match event.get("delta") { + Some(d) => d, + None => return StreamAction::Forward, + }; + + // Check if this is a function call delta + let item_type = delta.get("type").and_then(|v| v.as_str()); + + if item_type.is_some_and(is_function_call_type) { + self.in_function_call = true; + + // Get or create function call for this output index + let call = self.get_or_create_call(output_index, delta); + call.assigned_output_index = Some(assigned_index); + + // Accumulate call_id if present + if let Some(call_id) = delta.get("call_id").and_then(|v| v.as_str()) { + call.call_id = call_id.to_string(); + } + + // Accumulate name if present + if let Some(name) = delta.get("name").and_then(|v| v.as_str()) { + call.name.push_str(name); + } + + // Accumulate arguments if present + if let Some(args) = delta.get("arguments").and_then(|v| v.as_str()) { + call.arguments_buffer.push_str(args); + } + + if let Some(obfuscation) = delta.get("obfuscation").and_then(|v| v.as_str()) { + call.last_obfuscation = Some(obfuscation.to_string()); + } + + // Buffer this event, don't forward to client + return StreamAction::Buffer; + } + + // Forward non-function-call events + StreamAction::Forward + } + + fn get_or_create_call( + &mut self, + output_index: usize, + delta: &Value, + ) -> &mut FunctionCallInProgress { + // Find existing call for this output index + if let Some(pos) = self + .pending_calls + .iter() + .position(|c| c.output_index == output_index) + { + return &mut self.pending_calls[pos]; + } + + // Create new call + let call_id = delta + .get("call_id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + + let mut call = FunctionCallInProgress::new(call_id, output_index); + if let Some(obfuscation) = delta.get("obfuscation").and_then(|v| v.as_str()) { + call.last_obfuscation = Some(obfuscation.to_string()); + } + + self.pending_calls.push(call); + self.pending_calls + .last_mut() + .expect("Just pushed to pending_calls, must have at least one element") + } + + fn has_complete_calls(&self) -> bool { + !self.pending_calls.is_empty() && self.pending_calls.iter().all(|c| c.is_complete()) + } + + pub fn take_pending_calls(&mut self) -> Vec { + std::mem::take(&mut self.pending_calls) + } +}