From ea9692965117ff6b401e0c614dec04c95325beae Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Wed, 10 Dec 2025 15:17:02 -0800 Subject: [PATCH] [model-gateway] code clean up on oai router --- .../src/routers/openai/router.rs | 406 ++++++++---------- 1 file changed, 171 insertions(+), 235 deletions(-) diff --git a/sgl-model-gateway/src/routers/openai/router.rs b/sgl-model-gateway/src/routers/openai/router.rs index b45ab7da3e79..6b765f69cca6 100644 --- a/sgl-model-gateway/src/routers/openai/router.rs +++ b/sgl-model-gateway/src/routers/openai/router.rs @@ -65,9 +65,79 @@ impl std::fmt::Debug for OpenAIRouter { } } +/// Error response helpers for consistent API error formatting +mod error_responses { + use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, + Json, + }; + use serde_json::json; + + pub fn bad_request(message: impl Into) -> Response { + (StatusCode::BAD_REQUEST, message.into()).into_response() + } + + pub fn not_found(resource: &str, id: &str) -> Response { + ( + StatusCode::NOT_FOUND, + Json(json!({ + "error": { + "message": format!("No {} found with id '{}'", resource, id), + "type": "invalid_request_error", + "param": null, + "code": "not_found" + } + })), + ) + .into_response() + } + + pub fn internal_error(message: impl Into) -> Response { + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ + "error": { + "message": message.into(), + "type": "internal_error", + "param": null, + "code": "storage_error" + } + })), + ) + .into_response() + } + + pub fn service_unavailable(message: impl Into) -> Response { + (StatusCode::SERVICE_UNAVAILABLE, message.into()).into_response() + } + + pub fn model_not_found(model: &str) -> Response { + ( + StatusCode::NOT_FOUND, + Json(json!({ + "error": { + "message": format!("No worker available for model '{}'", model), + "type": "model_not_found", + } + })), + ) + .into_response() + } +} + impl OpenAIRouter { const MAX_CONVERSATION_HISTORY_ITEMS: usize = 100; + /// Get all external workers from the registry + fn external_workers(&self) -> Vec> { + self.worker_registry + .get_all() + .into_iter() + .filter(|w| w.metadata().runtime_type == RuntimeType::External) + .collect() + } + fn shared_components(&self) -> Arc { Arc::clone(&self.shared_components) } @@ -204,53 +274,73 @@ impl OpenAIRouter { join_all(futures).await; } + /// Find workers that can handle the given model and select the least loaded one + fn find_best_worker_for_model(&self, model_id: &str) -> Option> { + self.worker_registry + .get_workers_filtered(None, None, None, Some(RuntimeType::External), true) + .into_iter() + .filter(|w| w.supports_model(model_id) && w.circuit_breaker().can_execute()) + .min_by_key(|w| w.load()) + } + async fn select_worker_for_model( &self, model_id: &str, auth_header: Option<&HeaderValue>, - ) -> Result, Box> { - let find_candidates = || { - self.worker_registry - .get_workers_filtered(None, None, None, Some(RuntimeType::External), true) - .into_iter() - .filter(|w| w.supports_model(model_id) && w.circuit_breaker().can_execute()) - .collect::>() - }; - - let candidates = find_candidates(); - if !candidates.is_empty() { - return Ok(candidates - .into_iter() - .min_by_key(|w| w.load()) - .expect("candidates is not empty")); + ) -> Result, Response> { + // Try to find a worker immediately + if let Some(worker) = self.find_best_worker_for_model(model_id) { + return Ok(worker); } + // Refresh external models and try again tracing::debug!( "No worker found for model '{}', refreshing external worker models", model_id ); self.refresh_external_models(auth_header).await; - let candidates = find_candidates(); - if !candidates.is_empty() { - return Ok(candidates - .into_iter() - .min_by_key(|w| w.load()) - .expect("candidates is not empty")); - } + self.find_best_worker_for_model(model_id) + .ok_or_else(|| error_responses::model_not_found(model_id)) + } - Err(Box::new( - ( - StatusCode::NOT_FOUND, - Json(json!({ - "error": { - "message": format!("No worker available for model '{}'", model_id), - "type": "model_not_found", - } - })), - ) - .into_response(), - )) + /// Deserialize ResponseInputOutputItems from a JSON array value + fn deserialize_items_from_array(array: &Value) -> Vec { + array + .as_array() + .map(|arr| { + arr.iter() + .filter_map(|item| { + serde_json::from_value::(item.clone()) + .map_err(|e| warn!("Failed to deserialize item: {}. Item: {}", e, item)) + .ok() + }) + .collect() + }) + .unwrap_or_default() + } + + /// Append current request input to items list, creating a user message if needed + fn append_current_input( + items: &mut Vec, + input: &ResponseInput, + id_suffix: &str, + ) { + match input { + ResponseInput::Text(text) => { + items.push(ResponseInputOutputItem::Message { + id: format!("msg_u_{}", id_suffix), + role: "user".to_string(), + content: vec![ResponseContentPart::InputText { text: text.clone() }], + status: Some("completed".to_string()), + }); + } + ResponseInput::Items(current_items) => { + for item in current_items { + items.push(crate::protocols::responses::normalize_input_item(item)); + } + } + } } async fn handle_non_streaming_response(&self, mut ctx: RequestContext) -> Response { @@ -383,65 +473,38 @@ impl crate::routers::RouterTrait for OpenAIRouter { } async fn health_generate(&self, _req: Request) -> Response { - let external_workers: Vec<_> = self - .worker_registry - .get_all() - .into_iter() - .filter(|w| w.metadata().runtime_type == RuntimeType::External) - .collect(); - + let external_workers = self.external_workers(); if external_workers.is_empty() { - return ( - StatusCode::SERVICE_UNAVAILABLE, - "No external workers registered", - ) - .into_response(); + return error_responses::service_unavailable("No external workers registered"); } - let mut healthy_count = 0; - let mut unhealthy_workers = Vec::new(); - - for worker in &external_workers { - if worker.is_healthy() { - healthy_count += 1; - } else { - unhealthy_workers.push(format!("{} ({})", worker.model_id(), worker.url())); - } - } + let (healthy, unhealthy): (Vec<_>, Vec<_>) = + external_workers.iter().partition(|w| w.is_healthy()); - if unhealthy_workers.is_empty() { + if unhealthy.is_empty() { ( StatusCode::OK, - format!("OK - {} workers healthy", healthy_count), + format!("OK - {} workers healthy", healthy.len()), ) .into_response() } else { - ( - StatusCode::SERVICE_UNAVAILABLE, - format!( - "{}/{} workers unhealthy: {}", - unhealthy_workers.len(), - external_workers.len(), - unhealthy_workers.join(", ") - ), - ) - .into_response() + let unhealthy_info: Vec<_> = unhealthy + .iter() + .map(|w| format!("{} ({})", w.model_id(), w.url())) + .collect(); + error_responses::service_unavailable(format!( + "{}/{} workers unhealthy: {}", + unhealthy.len(), + external_workers.len(), + unhealthy_info.join(", ") + )) } } async fn get_server_info(&self, _req: Request) -> Response { let stats = self.worker_registry.stats(); - let external_workers: Vec<_> = self - .worker_registry - .get_all() - .into_iter() - .filter(|w| w.metadata().runtime_type == RuntimeType::External) - .collect(); - - let worker_urls: Vec = external_workers - .iter() - .map(|w| w.url().to_string()) - .collect(); + let external_workers = self.external_workers(); + let worker_urls: Vec<_> = external_workers.iter().map(|w| w.url()).collect(); let info = json!({ "router_type": "openai", @@ -455,19 +518,9 @@ impl crate::routers::RouterTrait for OpenAIRouter { } async fn get_models(&self, req: Request) -> Response { - let external_workers: Vec<_> = self - .worker_registry - .get_all() - .into_iter() - .filter(|w| w.metadata().runtime_type == RuntimeType::External) - .collect(); - + let external_workers = self.external_workers(); if external_workers.is_empty() { - return ( - StatusCode::SERVICE_UNAVAILABLE, - "No external workers registered", - ) - .into_response(); + return error_responses::service_unavailable("No external workers registered"); } let auth_header = extract_auth_header(Some(req.headers()), &None); @@ -530,27 +583,19 @@ impl crate::routers::RouterTrait for OpenAIRouter { .await { Ok(w) => w, - Err(response) => return *response, + Err(response) => return response, }; let mut payload = match to_value(body) { Ok(v) => v, Err(e) => { - return ( - StatusCode::BAD_REQUEST, - format!("Failed to serialize request: {}", e), - ) - .into_response(); + return error_responses::bad_request(format!("Failed to serialize request: {}", e)) } }; let provider = self.get_provider_arc_for_worker(worker.as_ref(), model_id); if let Err(e) = provider.transform_request(&mut payload, Endpoint::Chat) { - return ( - StatusCode::BAD_REQUEST, - format!("Provider transform error: {}", e), - ) - .into_response(); + return error_responses::bad_request(format!("Provider transform error: {}", e)); } let mut ctx = RequestContext::for_chat( @@ -659,7 +704,7 @@ impl crate::routers::RouterTrait for OpenAIRouter { .await { Ok(w) => w, - Err(response) => return *response, + Err(response) => return response, }; let mut request_body = body.clone(); @@ -670,8 +715,9 @@ impl crate::routers::RouterTrait for OpenAIRouter { let original_previous_response_id = request_body.previous_response_id.clone(); + // Load items from previous response chain if specified let mut conversation_items: Option> = None; - if let Some(prev_id_str) = request_body.previous_response_id.clone() { + if let Some(prev_id_str) = request_body.previous_response_id.take() { let prev_id = ResponseId::from(prev_id_str.as_str()); match self .responses_components @@ -680,43 +726,16 @@ impl crate::routers::RouterTrait for OpenAIRouter { .await { Ok(chain) => { - let mut items = Vec::new(); - for stored in chain.responses.iter() { - if let Some(input_arr) = stored.input.as_array() { - for item in input_arr { - match serde_json::from_value::( - item.clone(), - ) { - Ok(input_item) => { - items.push(input_item); - } - Err(e) => { - warn!( - "Failed to deserialize stored input item: {}. Item: {}", - e, item - ); - } - } - } - } - - if let Some(output_arr) = stored.output.as_array() { - for item in output_arr { - match serde_json::from_value::( - item.clone(), - ) { - Ok(output_item) => { - items.push(output_item); - } - Err(e) => { - warn!("Failed to deserialize stored output item: {}. Item: {}", e, item); - } - } - } - } - } + let items: Vec = chain + .responses + .iter() + .flat_map(|stored| { + Self::deserialize_items_from_array(&stored.input) + .into_iter() + .chain(Self::deserialize_items_from_array(&stored.output)) + }) + .collect(); conversation_items = Some(items); - request_body.previous_response_id = None; } Err(e) => { warn!( @@ -736,11 +755,7 @@ impl crate::routers::RouterTrait for OpenAIRouter { .get_conversation(&conv_id) .await { - return ( - StatusCode::NOT_FOUND, - Json(json!({"error": "Conversation not found"})), - ) - .into_response(); + return error_responses::not_found("conversation", &conv_id.0); } let params = ListParams { @@ -825,26 +840,7 @@ impl crate::routers::RouterTrait for OpenAIRouter { } } - match &request_body.input { - ResponseInput::Text(text) => { - items.push(ResponseInputOutputItem::Message { - id: format!("msg_u_{}", conv_id.0), - role: "user".to_string(), - content: vec![ResponseContentPart::InputText { - text: text.clone(), - }], - status: Some("completed".to_string()), - }); - } - ResponseInput::Items(current_items) => { - for item in current_items.iter() { - let normalized = - crate::protocols::responses::normalize_input_item(item); - items.push(normalized); - } - } - } - + Self::append_current_input(&mut items, &request_body.input, &conv_id.0); request_body.input = ResponseInput::Items(items); } Err(e) => { @@ -853,29 +849,10 @@ impl crate::routers::RouterTrait for OpenAIRouter { } } + // Apply previous response chain items if loaded if let Some(mut items) = conversation_items { - match &request_body.input { - ResponseInput::Text(text) => { - items.push(ResponseInputOutputItem::Message { - id: format!( - "msg_u_{}", - original_previous_response_id - .as_ref() - .unwrap_or(&"new".to_string()) - ), - role: "user".to_string(), - content: vec![ResponseContentPart::InputText { text: text.clone() }], - status: Some("completed".to_string()), - }); - } - ResponseInput::Items(current_items) => { - for item in current_items.iter() { - let normalized = crate::protocols::responses::normalize_input_item(item); - items.push(normalized); - } - } - } - + let id_suffix = original_previous_response_id.as_deref().unwrap_or("new"); + Self::append_current_input(&mut items, &request_body.input, id_suffix); request_body.input = ResponseInput::Items(items); } @@ -887,21 +864,13 @@ impl crate::routers::RouterTrait for OpenAIRouter { let mut payload = match to_value(&request_body) { Ok(v) => v, Err(e) => { - return ( - StatusCode::BAD_REQUEST, - format!("Failed to serialize request: {}", e), - ) - .into_response(); + return error_responses::bad_request(format!("Failed to serialize request: {}", e)) } }; let provider = self.get_provider_arc_for_worker(worker.as_ref(), model_id); if let Err(e) = provider.transform_request(&mut payload, Endpoint::Responses) { - return ( - StatusCode::BAD_REQUEST, - format!("Provider transform error: {}", e), - ) - .into_response(); + return error_responses::bad_request(format!("Provider transform error: {}", e)); } let mut ctx = RequestContext::for_responses( @@ -949,16 +918,8 @@ impl crate::routers::RouterTrait for OpenAIRouter { } (StatusCode::OK, Json(response_json)).into_response() } - Ok(None) => ( - StatusCode::NOT_FOUND, - Json(json!({"error": "Response not found"})), - ) - .into_response(), - Err(e) => ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({ "error": format!("Failed to get response: {}", e) })), - ) - .into_response(), + Ok(None) => error_responses::not_found("response", response_id), + Err(e) => error_responses::internal_error(format!("Failed to get response: {}", e)), } } @@ -976,10 +937,7 @@ impl crate::routers::RouterTrait for OpenAIRouter { .await { Ok(Some(stored)) => { - let items = match &stored.input { - Value::Array(arr) => arr.clone(), - _ => vec![], - }; + let items = stored.input.as_array().cloned().unwrap_or_default(); let items_with_ids: Vec = items .into_iter() @@ -1003,32 +961,10 @@ impl crate::routers::RouterTrait for OpenAIRouter { (StatusCode::OK, Json(response_body)).into_response() } - Ok(None) => ( - StatusCode::NOT_FOUND, - Json(json!({ - "error": { - "message": format!("No response found with id '{}'", response_id), - "type": "invalid_request_error", - "param": Value::Null, - "code": "not_found" - } - })), - ) - .into_response(), + Ok(None) => error_responses::not_found("response", response_id), Err(e) => { warn!("Failed to retrieve input items for {}: {}", response_id, e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - Json(json!({ - "error": { - "message": format!("Failed to retrieve input items: {}", e), - "type": "internal_error", - "param": Value::Null, - "code": "storage_error" - } - })), - ) - .into_response() + error_responses::internal_error(format!("Failed to retrieve input items: {}", e)) } } }