110 changes: 58 additions & 52 deletions sgl-router/src/protocols/spec.rs
@@ -1073,8 +1073,8 @@ fn generate_request_id() -> String {
 #[derive(Debug, Clone, Deserialize, Serialize)]
 pub struct ResponsesRequest {
     /// Run the request in the background
-    #[serde(default)]
-    pub background: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub background: Option<bool>,
 
     /// Fields to include in the response
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -1108,8 +1108,8 @@ pub struct ResponsesRequest {
     pub conversation: Option<String>,
 
     /// Whether to enable parallel tool calls
-    #[serde(default = "default_true")]
-    pub parallel_tool_calls: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub parallel_tool_calls: Option<bool>,
 
     /// ID of previous response to continue from
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -1120,40 +1120,40 @@
     pub reasoning: Option<ResponseReasoningParam>,
 
     /// Service tier
-    #[serde(default)]
-    pub service_tier: ServiceTier,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub service_tier: Option<ServiceTier>,
 
     /// Whether to store the response
-    #[serde(default = "default_true")]
-    pub store: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub store: Option<bool>,
 
     /// Whether to stream the response
-    #[serde(default)]
-    pub stream: bool,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stream: Option<bool>,
 
     /// Temperature for sampling
     #[serde(skip_serializing_if = "Option::is_none")]
     pub temperature: Option<f32>,
 
     /// Tool choice behavior
-    #[serde(default)]
-    pub tool_choice: ToolChoice,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<ToolChoice>,
 
     /// Available tools
-    #[serde(default)]
-    pub tools: Vec<ResponseTool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tools: Option<Vec<ResponseTool>>,
 
     /// Number of top logprobs to return
-    #[serde(default)]
-    pub top_logprobs: u32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_logprobs: Option<u32>,
 
     /// Top-p sampling parameter
     #[serde(skip_serializing_if = "Option::is_none")]
     pub top_p: Option<f32>,
 
     /// Truncation behavior
-    #[serde(default)]
-    pub truncation: Truncation,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub truncation: Option<Truncation>,
 
     /// User identifier
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -1168,12 +1168,12 @@ pub struct ResponsesRequest {
     pub priority: i32,
 
     /// Frequency penalty
-    #[serde(default)]
-    pub frequency_penalty: f32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub frequency_penalty: Option<f32>,
 
     /// Presence penalty
-    #[serde(default)]
-    pub presence_penalty: f32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub presence_penalty: Option<f32>,
 
     /// Stop sequences
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -1210,7 +1210,7 @@ fn default_repetition_penalty() -> f32 {
 impl Default for ResponsesRequest {
     fn default() -> Self {
         Self {
-            background: false,
+            background: None,
             include: None,
             input: ResponseInput::Text(String::new()),
             instructions: None,
@@ -1219,23 +1219,23 @@ impl Default for ResponsesRequest {
             metadata: None,
             model: None,
             conversation: None,
-            parallel_tool_calls: true,
+            parallel_tool_calls: None,
             previous_response_id: None,
             reasoning: None,
-            service_tier: ServiceTier::default(),
-            store: true,
-            stream: false,
+            service_tier: None,
+            store: None,
+            stream: None,
             temperature: None,
-            tool_choice: ToolChoice::default(),
-            tools: Vec::new(),
-            top_logprobs: 0,
+            tool_choice: None,
+            tools: None,
+            top_logprobs: None,
             top_p: None,
-            truncation: Truncation::default(),
+            truncation: None,
             user: None,
             request_id: generate_request_id(),
             priority: 0,
-            frequency_penalty: 0.0,
-            presence_penalty: 0.0,
+            frequency_penalty: None,
+            presence_penalty: None,
             stop: None,
             top_k: default_top_k(),
             min_p: 0.0,
@@ -1299,14 +1299,18 @@ impl ResponsesRequest {
            "top_p".to_string(),
            Value::Number(Number::from_f64(top_p as f64).unwrap()),
        );
-        params.insert(
-            "frequency_penalty".to_string(),
-            Value::Number(Number::from_f64(self.frequency_penalty as f64).unwrap()),
-        );
-        params.insert(
-            "presence_penalty".to_string(),
-            Value::Number(Number::from_f64(self.presence_penalty as f64).unwrap()),
-        );
+        if let Some(fp) = self.frequency_penalty {
+            params.insert(
+                "frequency_penalty".to_string(),
+                Value::Number(Number::from_f64(fp as f64).unwrap()),
+            );
+        }
+        if let Some(pp) = self.presence_penalty {
+            params.insert(
+                "presence_penalty".to_string(),
+                Value::Number(Number::from_f64(pp as f64).unwrap()),
+            );
+        }
         params.insert("top_k".to_string(), Value::Number(Number::from(self.top_k)));
         params.insert(
             "min_p".to_string(),
@@ -1337,7 +1341,7 @@ impl ResponsesRequest {
 
 impl GenerationRequest for ResponsesRequest {
     fn is_stream(&self) -> bool {
-        self.stream
+        self.stream.unwrap_or(false)
     }
 
     fn get_model(&self) -> Option<&str> {
@@ -1523,31 +1527,33 @@ impl ResponsesResponse {
             max_output_tokens: request.max_output_tokens,
             model: model_name,
             output,
-            parallel_tool_calls: request.parallel_tool_calls,
+            parallel_tool_calls: request.parallel_tool_calls.unwrap_or(true),
             previous_response_id: request.previous_response_id.clone(),
             reasoning: request.reasoning.as_ref().map(|r| ReasoningInfo {
                 effort: r.effort.as_ref().map(|e| format!("{:?}", e)),
                 summary: None,
             }),
-            store: request.store,
+            store: request.store.unwrap_or(false),
             temperature: request.temperature,
             text: Some(ResponseTextFormat {
                 format: TextFormatType {
                     format_type: "text".to_string(),
                 },
             }),
             tool_choice: match &request.tool_choice {
-                ToolChoice::Value(ToolChoiceValue::Auto) => "auto".to_string(),
-                ToolChoice::Value(ToolChoiceValue::Required) => "required".to_string(),
-                ToolChoice::Value(ToolChoiceValue::None) => "none".to_string(),
-                ToolChoice::Function { .. } => "function".to_string(),
-                ToolChoice::AllowedTools { mode, .. } => mode.clone(),
+                Some(ToolChoice::Value(ToolChoiceValue::Auto)) => "auto".to_string(),
+                Some(ToolChoice::Value(ToolChoiceValue::Required)) => "required".to_string(),
+                Some(ToolChoice::Value(ToolChoiceValue::None)) => "none".to_string(),
+                Some(ToolChoice::Function { .. }) => "function".to_string(),
+                Some(ToolChoice::AllowedTools { mode, .. }) => mode.clone(),
+                None => "auto".to_string(),
             },
-            tools: request.tools.clone(),
+            tools: request.tools.clone().unwrap_or_default(),
             top_p: request.top_p,
             truncation: match &request.truncation {
-                Truncation::Auto => Some("auto".to_string()),
-                Truncation::Disabled => Some("disabled".to_string()),
+                Some(Truncation::Auto) => Some("auto".to_string()),
+                Some(Truncation::Disabled) => Some("disabled".to_string()),
+                None => None,
             },
             usage: usage.map(ResponsesUsage::Classic),
             user: request.user.clone(),
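The theme of this file's changes: every field that previously carried a serde default becomes `Option<T>` with `skip_serializing_if`, so the router can distinguish "client omitted the field" from "client chose the default" and can forward requests upstream without injecting defaults of its own. A minimal round-trip sketch of the serde behavior the change relies on, using a hypothetical two-field stand-in for the real struct:

```rust
use serde::{Deserialize, Serialize};

// Trimmed-down, hypothetical stand-in for ResponsesRequest, showing only
// the changed pattern.
#[derive(Debug, Serialize, Deserialize)]
struct RequestSketch {
    #[serde(skip_serializing_if = "Option::is_none")]
    store: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
}

fn main() {
    // A request that never mentions `store` now deserializes to None
    // instead of a router-chosen default such as `true`.
    let req: RequestSketch = serde_json::from_str(r#"{"stream": false}"#).unwrap();
    assert_eq!(req.store, None);
    assert_eq!(req.stream, Some(false));

    // Re-serializing omits the unset field, so the upstream provider can
    // apply its own default rather than one invented by the router.
    assert_eq!(serde_json::to_string(&req).unwrap(), r#"{"stream":false}"#);

    // Reading sites now supply defaults explicitly, mirroring
    // request.store.unwrap_or(false) in spec.rs above.
    assert!(!req.store.unwrap_or(false));
}
```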
20 changes: 14 additions & 6 deletions sgl-router/src/routers/openai/mcp.rs
@@ -689,9 +689,13 @@ pub(super) async fn execute_tool_loop(
     if state.total_calls > 0 {
         let server_label = original_body
             .tools
-            .iter()
-            .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
-            .and_then(|t| t.server_label.as_deref())
+            .as_ref()
+            .and_then(|tools| {
+                tools
+                    .iter()
+                    .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
+                    .and_then(|t| t.server_label.as_deref())
+            })
             .unwrap_or("mcp");
 
         // Build mcp_list_tools item
@@ -747,9 +751,13 @@ pub(super) fn build_incomplete_response(
     if let Some(output_array) = obj.get_mut("output").and_then(|v| v.as_array_mut()) {
         let server_label = original_body
             .tools
-            .iter()
-            .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
-            .and_then(|t| t.server_label.as_deref())
+            .as_ref()
+            .and_then(|tools| {
+                tools
+                    .iter()
+                    .find(|t| matches!(t.r#type, ResponseToolType::Mcp))
+                    .and_then(|t| t.server_label.as_deref())
+            })
             .unwrap_or("mcp");
 
         // Find any function_call items and convert them to mcp_call (incomplete)
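Both hunks in this file repeat the same lookup now that `tools` is `Option<Vec<ResponseTool>>`: borrow the vector if present, find the first MCP tool, and fall back to the `"mcp"` label otherwise. A compact sketch of the pattern with simplified stand-in types (the real ones live in `protocols::spec`):

```rust
// Simplified stand-ins for ResponseTool/ResponseToolType.
enum ToolType {
    Mcp,
    Function,
}

struct Tool {
    r#type: ToolType,
    server_label: Option<String>,
}

// Same shape as the patched lookup: borrow the Option<Vec<_>>, search it,
// and fall back to "mcp" when there are no tools or no label.
fn server_label(tools: &Option<Vec<Tool>>) -> &str {
    tools
        .as_ref()
        .and_then(|tools| {
            tools
                .iter()
                .find(|t| matches!(t.r#type, ToolType::Mcp))
                .and_then(|t| t.server_label.as_deref())
        })
        .unwrap_or("mcp")
}

fn main() {
    assert_eq!(server_label(&None), "mcp");

    let tools = Some(vec![
        Tool { r#type: ToolType::Function, server_label: None },
        Tool { r#type: ToolType::Mcp, server_label: Some("github".to_string()) },
    ]);
    assert_eq!(server_label(&tools), "github");
}
```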
16 changes: 10 additions & 6 deletions sgl-router/src/routers/openai/responses.rs
@@ -129,7 +129,10 @@ pub(super) fn patch_streaming_response_json(
        }
    }
 
-    obj.insert("store".to_string(), Value::Bool(original_body.store));
+    obj.insert(
+        "store".to_string(),
+        Value::Bool(original_body.store.unwrap_or(false)),
+    );
 
     if obj
         .get("model")
@@ -205,7 +208,7 @@ pub(super) fn rewrite_streaming_block(
 
     let mut changed = false;
     if let Some(response_obj) = parsed.get_mut("response").and_then(|v| v.as_object_mut()) {
-        let desired_store = Value::Bool(original_body.store);
+        let desired_store = Value::Bool(original_body.store.unwrap_or(false));
         if response_obj.get("store") != Some(&desired_store) {
             response_obj.insert("store".to_string(), desired_store);
             changed = true;
@@ -267,10 +270,11 @@
 
 /// Mask function tools as MCP tools in response for client
 pub(super) fn mask_tools_as_mcp(resp: &mut Value, original_body: &ResponsesRequest) {
-    let mcp_tool = original_body
-        .tools
-        .iter()
-        .find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some());
+    let mcp_tool = original_body.tools.as_ref().and_then(|tools| {
+        tools
+            .iter()
+            .find(|t| matches!(t.r#type, ResponseToolType::Mcp) && t.server_url.is_some())
+    });
     let Some(t) = mcp_tool else {
         return;
     };
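The first two hunks here patch upstream JSON so the client sees its own `store` setting, with an unset value reported as `false`. A small sketch of the in-place serde_json mutation, using a hypothetical response body:

```rust
use serde_json::{json, Value};

fn main() {
    // Hypothetical: the client omitted `store`, the upstream reported true.
    let client_store: Option<bool> = None;
    let mut resp: Value = json!({ "id": "resp_123", "store": true });

    // Overwrite with the client's view; unset falls back to false, matching
    // original_body.store.unwrap_or(false) in the patched code.
    if let Some(obj) = resp.as_object_mut() {
        obj.insert(
            "store".to_string(),
            Value::Bool(client_store.unwrap_or(false)),
        );
    }

    assert_eq!(resp["store"], Value::Bool(false));
}
```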
47 changes: 40 additions & 7 deletions sgl-router/src/routers/openai/router.rs
@@ -148,7 +148,11 @@ impl OpenAIRouter {
         original_previous_response_id: Option<String>,
     ) -> Response {
         // Check if MCP is active for this request
-        let req_mcp_manager = mcp_manager_from_request_tools(&original_body.tools).await;
+        let req_mcp_manager = if let Some(ref tools) = original_body.tools {
+            mcp_manager_from_request_tools(tools.as_slice()).await
+        } else {
+            None
+        };
         let active_mcp = req_mcp_manager.as_ref().or(self.mcp_manager.as_ref());
 
         let mut response_json: Value;
@@ -183,6 +187,7 @@ impl OpenAIRouter {
             }
         } else {
             // No MCP - simple request
+
             let mut request_builder = self.client.post(&url).json(&payload);
             if let Some(h) = headers {
                 request_builder = apply_request_headers(h, request_builder, true);
@@ -385,6 +390,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
             }
         };
         if let Some(obj) = payload.as_object_mut() {
+            // Always remove SGLang-specific fields (unsupported by OpenAI)
             for key in [
                 "top_k",
                 "min_p",
@@ -535,7 +541,7 @@
                 .into_response();
         }
 
-        // Clone the body and override model if needed
+        // Clone the body for validation and logic, but we'll build payload differently
         let mut request_body = body.clone();
         if let Some(model) = model_id {
             request_body.model = Some(model.to_string());
@@ -690,7 +696,7 @@ impl crate::routers::RouterTrait for OpenAIRouter {
         }
 
         // Always set store=false for upstream (we store internally)
-        request_body.store = false;
+        request_body.store = Some(false);
 
         // Convert to JSON and strip SGLang-specific fields
         let mut payload = match to_value(&request_body) {
@@ -704,14 +710,13 @@
             }
         };
 
-        // Remove SGLang-specific fields
+        // Remove SGLang-specific fields only
         if let Some(obj) = payload.as_object_mut() {
+            // Remove SGLang-specific fields (not part of OpenAI API)
             for key in [
                 "request_id",
                 "priority",
                 "top_k",
-                "frequency_penalty",
-                "presence_penalty",
                 "min_p",
                 "min_tokens",
                 "regex",
@@ -732,10 +737,38 @@
             ] {
                 obj.remove(key);
             }
+            // xAI doesn't support the OpenAI item-type input: https://platform.openai.com/docs/api-reference/responses/create#responses-create-input-input-item-list-item
+            // To achieve xAI compatibility, strip extra fields from input messages (id, status).
+            // xAI also doesn't support output_text as the content type for assistant-role
+            // messages, so normalize content types: output_text -> input_text
+            if let Some(input_arr) = obj.get_mut("input").and_then(Value::as_array_mut) {
+                for item_obj in input_arr.iter_mut().filter_map(Value::as_object_mut) {
+                    // Remove fields not universally supported
+                    item_obj.remove("id");
+                    item_obj.remove("status");
+
+                    // Normalize content types to input_text (xAI compatibility)
+                    if let Some(content_arr) =
+                        item_obj.get_mut("content").and_then(Value::as_array_mut)
+                    {
+                        for content_obj in content_arr.iter_mut().filter_map(Value::as_object_mut) {
+                            // Change output_text to input_text
+                            if content_obj.get("type").and_then(Value::as_str)
+                                == Some("output_text")
+                            {
+                                content_obj.insert(
+                                    "type".to_string(),
+                                    Value::String("input_text".to_string()),
+                                );
+                            }
+                        }
+                    }
+                }
+            }
         }
 
         // Delegate to streaming or non-streaming handler
-        if body.stream {
+        if body.stream.unwrap_or(false) {
             handle_streaming_response(
                 &self.client,
                 &self.circuit_breaker,
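The new normalization block rewrites the outgoing `input` array for providers such as xAI. A self-contained sketch of the same transformation on a hypothetical payload, runnable with serde_json alone:

```rust
use serde_json::{json, Value};

fn main() {
    // Hypothetical Responses API payload replaying a prior assistant turn.
    let mut payload = json!({
        "input": [{
            "id": "msg_1",
            "status": "completed",
            "role": "assistant",
            "content": [{ "type": "output_text", "text": "hi" }]
        }]
    });

    // Same normalization as the router hunk above: drop fields some
    // providers reject and rewrite output_text -> input_text.
    if let Some(input_arr) = payload.get_mut("input").and_then(Value::as_array_mut) {
        for item_obj in input_arr.iter_mut().filter_map(Value::as_object_mut) {
            item_obj.remove("id");
            item_obj.remove("status");
            if let Some(content_arr) = item_obj.get_mut("content").and_then(Value::as_array_mut) {
                for content_obj in content_arr.iter_mut().filter_map(Value::as_object_mut) {
                    if content_obj.get("type").and_then(Value::as_str) == Some("output_text") {
                        content_obj.insert(
                            "type".to_string(),
                            Value::String("input_text".to_string()),
                        );
                    }
                }
            }
        }
    }

    assert_eq!(
        payload,
        json!({
            "input": [{
                "role": "assistant",
                "content": [{ "type": "input_text", "text": "hi" }]
            }]
        })
    );
}
```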