Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions sgl-model-gateway/src/routers/openai/accumulator.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
//! Streaming response accumulator for persisting responses.

use serde_json::Value;
use tracing::warn;

use super::streaming::{extract_output_index, get_event_type};
use crate::protocols::event_types::{OutputItemEvent, ResponseEvent};

// ============================================================================
// Streaming Response Accumulator
// ============================================================================

/// Helper that parses SSE frames from the OpenAI responses stream and
/// accumulates enough information to persist the final response locally.
///
/// Frames are fed in one block at a time via `ingest_block`; the final
/// response is recovered with `into_final_response` (consuming) or
/// `snapshot_final_response` (non-consuming).
pub(super) struct StreamingResponseAccumulator {
    /// The initial `response.created` payload (if emitted).
    initial_response: Option<Value>,
    /// The final `response.completed` payload (if emitted).
    completed_response: Option<Value>,
    /// Collected output items keyed by the upstream output index, used when
    /// a final response payload is absent and we need to synthesize one.
    output_items: Vec<(usize, Value)>,
    /// Captured error payload (if the upstream stream fails midway).
    encountered_error: Option<Value>,
}

impl StreamingResponseAccumulator {
pub fn new() -> Self {
Self {
initial_response: None,
completed_response: None,
output_items: Vec::new(),
encountered_error: None,
}
}

/// Feed the accumulator with the next SSE chunk.
///
/// Empty or whitespace-only chunks are ignored.
pub fn ingest_block(&mut self, block: &str) {
    // `process_block` performs the identical trim-and-skip guard, so
    // delegate directly instead of duplicating the emptiness check here.
    self.process_block(block);
}

/// Consume the accumulator and produce the best-effort final response value.
///
/// Prefers the upstream `response.completed` payload; otherwise a response
/// is synthesized from the initial payload plus the collected output items.
pub fn into_final_response(mut self) -> Option<Value> {
    match self.completed_response.take() {
        Some(completed) => Some(completed),
        None => self.build_fallback_response(),
    }
}

/// Error payload captured from an upstream `response.error` event, if the
/// stream failed midway; `None` when no error frame was seen.
pub fn encountered_error(&self) -> Option<&Value> {
    self.encountered_error.as_ref()
}

/// The `id` string of the initial `response.created` payload, if both the
/// payload and a string-valued `id` field are present.
pub fn original_response_id(&self) -> Option<&str> {
    self.initial_response.as_ref()?.get("id")?.as_str()
}

/// Non-consuming counterpart of `into_final_response`: clone the completed
/// payload when present, otherwise synthesize one from cloned state.
pub fn snapshot_final_response(&self) -> Option<Value> {
    self.completed_response
        .clone()
        .or_else(|| self.build_fallback_response_snapshot())
}

/// Synthesize a completed response from the initial payload plus the output
/// items collected so far, without mutating the accumulator.
///
/// Returns `None` when no `response.created` payload was captured. If the
/// initial payload is not a JSON object it is returned unmodified.
fn build_fallback_response_snapshot(&self) -> Option<Value> {
    let mut response = self.initial_response.clone()?;

    if let Some(obj) = response.as_object_mut() {
        obj.insert("status".to_string(), Value::String("completed".to_string()));

        // Emit outputs ordered by their upstream output index.
        let mut items = self.output_items.clone();
        items.sort_by_key(|&(index, _)| index);
        let outputs = items.into_iter().map(|(_, item)| item).collect();
        obj.insert("output".to_string(), Value::Array(outputs));
    }

    Some(response)
}

/// Parse one SSE frame: collect the optional `event:` name and all `data:`
/// lines (rejoined with newlines, per the SSE spec), then dispatch the
/// payload to `handle_event`.
fn process_block(&mut self, block: &str) {
    let frame = block.trim();
    if frame.is_empty() {
        return;
    }

    let mut event_name: Option<String> = None;
    let mut data_lines: Vec<String> = Vec::new();

    for line in frame.lines() {
        if let Some(name) = line.strip_prefix("event:") {
            event_name = Some(name.trim().to_string());
        } else if let Some(data) = line.strip_prefix("data:") {
            data_lines.push(data.trim_start().to_string());
        }
    }

    // Frames without any data payload (e.g. comment-only) carry nothing
    // for the accumulator.
    let payload = data_lines.join("\n");
    if payload.is_empty() {
        return;
    }

    self.handle_event(event_name.as_deref(), &payload);
}

/// Apply a single parsed SSE event to the accumulator state.
///
/// Malformed JSON payloads are logged and dropped; unrecognized event
/// types are ignored.
fn handle_event(&mut self, event_name: Option<&str>, data_payload: &str) {
    let parsed: Value = match serde_json::from_str(data_payload) {
        Ok(value) => value,
        Err(err) => {
            warn!("Failed to parse streaming event JSON: {}", err);
            return;
        }
    };

    match get_event_type(event_name, &parsed) {
        ResponseEvent::CREATED => {
            // Only the first created payload wins; `cloned()` is a no-op
            // assignment of None when the event lacks a `response` field.
            if self.initial_response.is_none() {
                self.initial_response = parsed.get("response").cloned();
            }
        }
        ResponseEvent::COMPLETED => {
            if let Some(response) = parsed.get("response") {
                self.completed_response = Some(response.clone());
            }
        }
        OutputItemEvent::DONE => {
            match (extract_output_index(&parsed), parsed.get("item")) {
                (Some(index), Some(item)) => {
                    self.output_items.push((index, item.clone()));
                }
                _ => {}
            }
        }
        "response.error" => {
            self.encountered_error = Some(parsed);
        }
        _ => {}
    }
}

/// Synthesize a completed response when no `response.completed` event was
/// received: clone the initial payload, mark it completed, and attach the
/// accumulated output items sorted by upstream output index.
///
/// Only called from `into_final_response`, which consumes the accumulator,
/// so the collected items are moved out via `std::mem::take` instead of
/// being cloned per item.
//
// NOTE: the pasted review-comment residue that was embedded in the middle
// of this method (and broke compilation) has been removed; the method keeps
// the applied `mem::take` optimization.
fn build_fallback_response(&mut self) -> Option<Value> {
    let mut response = self.initial_response.clone()?;

    if let Some(obj) = response.as_object_mut() {
        obj.insert("status".to_string(), Value::String("completed".to_string()));

        self.output_items.sort_by_key(|(index, _)| *index);
        let outputs: Vec<Value> = std::mem::take(&mut self.output_items)
            .into_iter()
            .map(|(_, item)| item)
            .collect();
        obj.insert("output".to_string(), Value::Array(outputs));
    }

    Some(response)
}
}
3 changes: 2 additions & 1 deletion sgl-model-gateway/src/routers/openai/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
//! - Multi-turn tool execution loops
//! - SSE (Server-Sent Events) streaming

// Module list kept in alphabetical order (tool_handler was appended out of
// order after utils).
mod accumulator;
mod context;
pub mod conversations;
pub mod mcp;
pub mod provider;
mod responses;
mod router;
mod streaming;
mod tool_handler;
mod utils;

// Re-export the main types for external use
pub use provider::{Provider, ProviderError, ProviderRegistry};
Expand Down
Loading
Loading