diff --git a/crates/goose/src/providers/formats/mod.rs b/crates/goose/src/providers/formats/mod.rs
index e0a88288a49e..d28f06a54208 100644
--- a/crates/goose/src/providers/formats/mod.rs
+++ b/crates/goose/src/providers/formats/mod.rs
@@ -5,4 +5,5 @@ pub mod gcpvertexai;
 pub mod google;
 pub mod openai;
 pub mod openai_responses;
+pub mod openrouter;
 pub mod snowflake;
diff --git a/crates/goose/src/providers/formats/openai.rs b/crates/goose/src/providers/formats/openai.rs
index 34cf33fc84f8..f3a8a1ff68cd 100644
--- a/crates/goose/src/providers/formats/openai.rs
+++ b/crates/goose/src/providers/formats/openai.rs
@@ -1,4 +1,4 @@
-use crate::conversation::message::{Message, MessageContent};
+use crate::conversation::message::{Message, MessageContent, ProviderMetadata};
 use crate::model::ModelConfig;
 use crate::providers::base::{ProviderUsage, Usage};
 use crate::providers::utils::{
@@ -37,6 +37,7 @@ struct Delta {
     content: Option<String>,
     role: Option<String>,
     tool_calls: Option<Vec<DeltaToolCall>>,
+    reasoning_details: Option<Vec<Value>>,
 }
 
 #[derive(Serialize, Deserialize, Debug)]
@@ -449,6 +450,8 @@ where
     try_stream! {
         use futures::StreamExt;
 
+        let mut accumulated_reasoning: Vec<Value> = Vec::new();
+
         'outer: while let Some(response) = stream.next().await {
             if response.as_ref().is_ok_and(|s| s == "data: [DONE]") {
                 break 'outer;
@@ -464,6 +467,12 @@ where
                 .ok_or_else(|| anyhow!("unexpected stream format"))?)
                 .map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?;
 
+            if !chunk.choices.is_empty() {
+                if let Some(details) = &chunk.choices[0].delta.reasoning_details {
+                    accumulated_reasoning.extend(details.iter().cloned());
+                }
+            }
+
             let usage = chunk.usage.as_ref().and_then(|u| {
                 chunk.model.as_ref().map(|model| {
                     ProviderUsage {
@@ -486,7 +495,6 @@ where
                 }
             }
 
-            // Check if this chunk already has finish_reason "tool_calls"
             let is_complete = chunk.choices[0].finish_reason == Some("tool_calls".to_string());
 
             if !is_complete {
@@ -502,6 +510,9 @@
                 .map_err(|e| anyhow!("Failed to parse streaming chunk: {}: {:?}", e, &line))?;
 
             if !tool_chunk.choices.is_empty() {
+                if let Some(details) = &tool_chunk.choices[0].delta.reasoning_details {
+                    accumulated_reasoning.extend(details.iter().cloned());
+                }
                 if let Some(delta_tool_calls) = &tool_chunk.choices[0].delta.tool_calls {
                     for delta_call in delta_tool_calls {
                         if let Some(index) = delta_call.index {
@@ -526,6 +537,14 @@
                 }
             }
 
+        let metadata: Option<ProviderMetadata> = if !accumulated_reasoning.is_empty() {
+            let mut map = ProviderMetadata::new();
+            map.insert("reasoning_details".to_string(), json!(accumulated_reasoning));
+            Some(map)
+        } else {
+            None
+        };
+
         let mut contents = Vec::new();
         let mut sorted_indices: Vec<_> = tool_call_data.keys().cloned().collect();
         sorted_indices.sort();
@@ -540,9 +559,10 @@
 
                 let content = match parsed {
                     Ok(params) => {
-                        MessageContent::tool_request(
+                        MessageContent::tool_request_with_metadata(
                             id.clone(),
                             Ok(CallToolRequestParam { name: function_name.clone().into(), arguments: Some(object(params)) }),
+                            metadata.as_ref(),
                         )
                     },
                     Err(e) => {
@@ -554,7 +574,7 @@
                             )),
                             data: None,
                         };
-                        MessageContent::tool_request(id.clone(), Err(error))
+                        MessageContent::tool_request_with_metadata(id.clone(), Err(error), metadata.as_ref())
                     }
                 };
                 contents.push(content);
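A minimal standalone sketch (not part of the patch) of the streaming chunk shape whose delta.reasoning_details entries the handler above accumulates; the example values mirror the new module's tests, only serde_json is assumed, and the crate-private Delta type is not referenced.

// Sketch: an SSE chunk carrying reasoning_details alongside a tool-call delta.
use serde_json::json;

fn main() {
    let chunk = json!({
        "choices": [{
            "delta": {
                "role": "assistant",
                "tool_calls": [{
                    "index": 0,
                    "id": "call_123",
                    "function": { "name": "get_weather", "arguments": "" }
                }],
                "reasoning_details": [
                    { "type": "text", "text": "Let me think..." },
                    { "type": "encrypted", "data": "abc123signature" }
                ]
            },
            "finish_reason": null
        }]
    });

    // The stream handler appends each chunk's entries to one Vec<Value>.
    let details = chunk["choices"][0]["delta"]["reasoning_details"]
        .as_array()
        .cloned()
        .unwrap_or_default();
    assert_eq!(details.len(), 2);
}

Entries from every chunk land in a single accumulated vector and are attached once, as tool-request metadata, when the tool calls are finalized.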
diff --git a/crates/goose/src/providers/formats/openrouter.rs b/crates/goose/src/providers/formats/openrouter.rs
new file mode 100644
index 000000000000..f20d613cc075
--- /dev/null
+++ b/crates/goose/src/providers/formats/openrouter.rs
@@ -0,0 +1,152 @@
+use crate::conversation::message::{Message, MessageContent, ProviderMetadata};
+use crate::providers::formats::openai;
+use rmcp::model::Role;
+use serde_json::{json, Value};
+
+pub const REASONING_DETAILS_KEY: &str = "reasoning_details";
+
+fn has_assistant_content(message: &Message) -> bool {
+    message.content.iter().any(|c| match c {
+        MessageContent::Text(t) => !t.text.is_empty(),
+        MessageContent::Image(_) => true,
+        MessageContent::ToolRequest(req) => req.tool_call.is_ok(),
+        MessageContent::FrontendToolRequest(req) => req.tool_call.is_ok(),
+        _ => false,
+    })
+}
+
+pub fn extract_reasoning_details(response: &Value) -> Option<Vec<Value>> {
+    response
+        .get("choices")
+        .and_then(|c| c.get(0))
+        .and_then(|m| m.get("message"))
+        .and_then(|msg| msg.get("reasoning_details"))
+        .and_then(|d| d.as_array())
+        .cloned()
+}
+
+pub fn get_reasoning_details(metadata: &Option<ProviderMetadata>) -> Option<Vec<Value>> {
+    metadata
+        .as_ref()
+        .and_then(|m| m.get(REASONING_DETAILS_KEY))
+        .and_then(|v| v.as_array())
+        .cloned()
+}
+
+pub fn response_to_message(response: &Value) -> anyhow::Result<Message> {
+    let mut message = openai::response_to_message(response)?;
+
+    if let Some(details) = extract_reasoning_details(response) {
+        for content in &mut message.content {
+            if let MessageContent::ToolRequest(req) = content {
+                let mut meta = req.metadata.clone().unwrap_or_default();
+                meta.insert(REASONING_DETAILS_KEY.to_string(), json!(details));
+                req.metadata = Some(meta);
+            }
+        }
+    }
+
+    Ok(message)
+}
+
+pub fn add_reasoning_details_to_request(payload: &mut Value, messages: &[Message]) {
+    let mut assistant_reasoning: Vec<Option<Vec<Value>>> = messages
+        .iter()
+        .filter(|m| m.is_agent_visible())
+        .filter(|m| m.role == Role::Assistant)
+        .filter(|m| has_assistant_content(m))
+        .map(|message| {
+            message.content.iter().find_map(|c| match c {
+                MessageContent::ToolRequest(req) => get_reasoning_details(&req.metadata),
+                _ => None,
+            })
+        })
+        .collect();
+
+    if let Some(payload_messages) = payload
+        .as_object_mut()
+        .and_then(|obj| obj.get_mut("messages"))
+        .and_then(|m| m.as_array_mut())
+    {
+        let mut assistant_idx = 0;
+        for payload_msg in payload_messages.iter_mut() {
+            if payload_msg.get("role").and_then(|r| r.as_str()) == Some("assistant") {
+                if assistant_idx < assistant_reasoning.len() {
+                    if let Some(details) = assistant_reasoning
+                        .get_mut(assistant_idx)
+                        .and_then(|d| d.take())
+                    {
+                        if let Some(obj) = payload_msg.as_object_mut() {
+                            obj.insert("reasoning_details".to_string(), json!(details));
+                        }
+                    }
+                }
+                assistant_idx += 1;
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_extract_reasoning_details() {
+        let response = json!({
+            "choices": [{
+                "message": {
+                    "content": "Hello",
+                    "reasoning_details": [
+                        {"type": "text", "text": "Let me think..."},
+                        {"type": "encrypted", "data": "abc123signature"}
+                    ]
+                }
+            }]
+        });
+
+        let details = extract_reasoning_details(&response).unwrap();
+        assert_eq!(details.len(), 2);
+    }
+
+    #[test]
+    fn test_response_to_message_with_tool_calls() {
+        let response = json!({
+            "choices": [{
+                "message": {
+                    "content": null,
+                    "tool_calls": [{
+                        "id": "call_123",
+                        "type": "function",
+                        "function": {
+                            "name": "get_weather",
+                            "arguments": "{\"location\": \"NYC\"}"
+                        }
+                    }],
+                    "reasoning_details": [
+                        {"type": "encrypted", "data": "sig456"}
+                    ]
+                }
+            }]
+        });
+
+        let message = response_to_message(&response).unwrap();
+        assert!(!message.content.is_empty());
+
+        let tool_request = message
+            .content
+            .iter()
+            .find_map(|c| {
+                if let MessageContent::ToolRequest(req) = c {
+                    Some(req)
+                } else {
+                    None
+                }
+            })
+            .unwrap();
+
+        assert!(tool_request.metadata.is_some());
+        let details = get_reasoning_details(&tool_request.metadata).unwrap();
+        assert_eq!(details.len(), 1);
+    }
+}
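A hedged, JSON-level sketch (again not part of the patch) of the round trip add_reasoning_details_to_request implements: details captured from an earlier response are re-attached to the matching assistant entry of the next request payload. The model slug and message bodies below are illustrative only.

use serde_json::json;

fn main() {
    // Details previously pulled off a response (same shape as the tests above).
    let details = json!([{ "type": "encrypted", "data": "sig456" }]);

    let mut payload = json!({
        "model": "google/gemini-2.5-pro", // hypothetical slug
        "messages": [
            { "role": "user", "content": "What's the weather?" },
            { "role": "assistant", "content": null }
        ]
    });

    // Walk the outgoing messages and decorate each assistant entry, mirroring
    // the helper's positional bookkeeping in simplified form.
    if let Some(messages) = payload["messages"].as_array_mut() {
        for msg in messages.iter_mut().filter(|m| m["role"] == "assistant") {
            if let Some(obj) = msg.as_object_mut() {
                obj.insert("reasoning_details".to_string(), details.clone());
            }
        }
    }

    assert!(payload["messages"][1].get("reasoning_details").is_some());
}

The real helper additionally skips agent-invisible and contentless assistant messages and takes each stored details entry at most once, so repeated turns do not duplicate reasoning blocks.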
diff --git a/crates/goose/src/providers/openrouter.rs b/crates/goose/src/providers/openrouter.rs
index e3d4cd2ea5ea..967659b36f9e 100644
--- a/crates/goose/src/providers/openrouter.rs
+++ b/crates/goose/src/providers/openrouter.rs
@@ -1,4 +1,4 @@
-use anyhow::{Error, Result};
+use anyhow::Result;
 use async_trait::async_trait;
 use serde_json::{json, Value};
 
@@ -7,13 +7,14 @@ use super::base::{ConfigKey, MessageStream, Provider, ProviderMetadata, Provider
 use super::errors::ProviderError;
 use super::retry::ProviderRetry;
 use super::utils::{
-    get_model, handle_response_google_compat, handle_response_openai_compat,
-    handle_status_openai_compat, is_google_model, stream_openai_compat, RequestLog,
+    get_model, handle_response_openai_compat, handle_status_openai_compat, stream_openai_compat,
+    RequestLog,
 };
 use crate::conversation::message::Message;
 use crate::model::ModelConfig;
-use crate::providers::formats::openai::{create_request, get_usage, response_to_message};
+use crate::providers::formats::openai::{create_request, get_usage};
+use crate::providers::formats::openrouter as openrouter_format;
 use rmcp::model::Tool;
 
 pub const OPENROUTER_DEFAULT_MODEL: &str = "anthropic/claude-sonnet-4";
@@ -74,27 +75,11 @@ impl OpenRouterProvider {
             .response_post("api/v1/chat/completions", payload)
             .await?;
 
-        // Handle Google-compatible model responses differently
-        if is_google_model(payload) {
-            return handle_response_google_compat(response).await;
-        }
-
-        // For OpenAI-compatible models, parse the response body to JSON
        let response_body = handle_response_openai_compat(response)
             .await
             .map_err(|e| ProviderError::RequestFailed(format!("Failed to parse response: {e}")))?;
 
-        let _debug = format!(
-            "OpenRouter request with payload: {} and response: {}",
-            serde_json::to_string_pretty(payload).unwrap_or_else(|_| "Invalid JSON".to_string()),
-            serde_json::to_string_pretty(&response_body)
-                .unwrap_or_else(|_| "Invalid JSON".to_string())
-        );
-
-        // OpenRouter can return errors in 200 OK responses, so we have to check for errors explicitly
-        // https://openrouter.ai/docs/api-reference/errors
         if let Some(error_obj) = response_body.get("error") {
-            // If there's an error object, extract the error message and code
             let error_message = error_obj
                 .get("message")
                 .and_then(|m| m.as_str())
@@ -102,14 +87,12 @@ impl OpenRouterProvider {
 
             let error_code = error_obj.get("code").and_then(|c| c.as_u64()).unwrap_or(0);
 
-            // Check for context length errors in the error message
             if error_code == 400 && error_message.contains("maximum context length") {
                 return Err(ProviderError::ContextLengthExceeded(
                     error_message.to_string(),
                 ));
             }
 
-            // Return appropriate error based on the OpenRouter error code
             match error_code {
                 401 | 403 => return Err(ProviderError::Authentication(error_message.to_string())),
                 429 => {
@@ -123,7 +106,6 @@ impl OpenRouterProvider {
             }
         }
 
-        // No error detected, return the response body
         Ok(response_body)
     }
 }
@@ -201,12 +183,16 @@ fn update_request_for_anthropic(original_payload: &Value) -> Value {
     payload
 }
 
+fn is_gemini_model(model_name: &str) -> bool {
+    model_name.starts_with("google/")
+}
+
 async fn create_request_based_on_model(
     provider: &OpenRouterProvider,
     system: &str,
     messages: &[Message],
     tools: &[Tool],
-) -> anyhow::Result<Value> {
+) -> Result<Value> {
     let mut payload = create_request(
         &provider.model,
         system,
@@ -220,10 +206,13 @@ async fn create_request_based_on_model(
         payload = update_request_for_anthropic(&payload);
     }
 
-    payload
-        .as_object_mut()
-        .unwrap()
-        .insert("transforms".to_string(), json!(["middle-out"]));
+    if is_gemini_model(&provider.model.model_name) {
+        openrouter_format::add_reasoning_details_to_request(&mut payload, messages);
+    }
+
+    if let Some(obj) = payload.as_object_mut() {
+        obj.insert("transforms".to_string(), json!(["middle-out"]));
+    }
 
     Ok(payload)
 }
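One note on the gating introduced above, with a tiny sketch (not from the patch): the predicate keys on the "google/" slug prefix, so it matches any Google-routed OpenRouter model, not only Gemini. The Gemini slug below is hypothetical; the Anthropic slug is the provider default from this file.

fn is_gemini_model(model_name: &str) -> bool {
    model_name.starts_with("google/")
}

fn main() {
    assert!(is_gemini_model("google/gemini-2.5-pro")); // hypothetical slug
    assert!(!is_gemini_model("anthropic/claude-sonnet-4")); // OPENROUTER_DEFAULT_MODEL
}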
@@ -272,7 +261,6 @@ impl Provider for OpenRouterProvider {
         let payload = create_request_based_on_model(self, system, messages, tools).await?;
         let mut log = RequestLog::start(model_config, &payload)?;
 
-        // Make request
         let response = self
             .with_retry(|| async {
                 let payload_clone = payload.clone();
@@ -280,13 +268,17 @@ impl Provider for OpenRouterProvider {
             })
             .await?;
 
-        // Parse response
-        let message = response_to_message(&response)?;
+        let response_model = get_model(&response);
+        let message = if is_gemini_model(&self.model.model_name) {
+            openrouter_format::response_to_message(&response)?
+        } else {
+            crate::providers::formats::openai::response_to_message(&response)?
+        };
+
         let usage = response.get("usage").map(get_usage).unwrap_or_else(|| {
             tracing::debug!("Failed to get usage data");
             Usage::default()
         });
-        let response_model = get_model(&response);
         log.write(&response, Some(&usage))?;
         Ok((message, ProviderUsage::new(response_model, usage)))
     }
@@ -397,10 +389,13 @@ impl Provider for OpenRouterProvider {
             payload = update_request_for_anthropic(&payload);
         }
 
-        payload
-            .as_object_mut()
-            .unwrap()
-            .insert("transforms".to_string(), json!(["middle-out"]));
+        if is_gemini_model(&self.model.model_name) {
+            openrouter_format::add_reasoning_details_to_request(&mut payload, messages);
+        }
+
+        if let Some(obj) = payload.as_object_mut() {
+            obj.insert("transforms".to_string(), json!(["middle-out"]));
+        }
 
         let mut log = RequestLog::start(&self.model, &payload)?;