diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index 06dbd8fecfad..171836d167e7 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -14,7 +14,7 @@ use super::errors::ProviderError; use super::formats::anthropic::{ create_request, get_usage, response_to_message, response_to_streaming_message, }; -use super::utils::{get_model, map_http_error_to_provider_error}; +use super::utils::{get_model, handle_status_openai_compat, map_http_error_to_provider_error}; use crate::config::declarative_providers::DeclarativeProviderConfig; use crate::conversation::message::Message; use crate::model::ModelConfig; @@ -273,17 +273,12 @@ impl Provider for AnthropicProvider { request = request.header(key, value)?; } - let response = request.response_post(&payload).await.inspect_err(|e| { + let resp = request.response_post(&payload).await.inspect_err(|e| { + let _ = log.error(e); + })?; + let response = handle_status_openai_compat(resp).await.inspect_err(|e| { let _ = log.error(e); })?; - if !response.status().is_success() { - let status = response.status(); - let error_text = response.text().await.unwrap_or_default(); - let error_json = serde_json::from_str::(&error_text).ok(); - let error = map_http_error_to_provider_error(status, error_json); - let _ = log.error(&error); - return Err(error); - } let stream = response.bytes_stream().map_err(io::Error::other); diff --git a/crates/goose/src/providers/errors.rs b/crates/goose/src/providers/errors.rs index 5a3ce7341eac..b6ee4e7431fe 100644 --- a/crates/goose/src/providers/errors.rs +++ b/crates/goose/src/providers/errors.rs @@ -112,106 +112,3 @@ impl GoogleErrorCode { } } } - -#[derive(serde::Deserialize, Debug)] -pub struct OpenAIError { - #[serde(deserialize_with = "code_as_string")] - pub code: Option, - pub message: Option, - #[serde(rename = "type")] - pub error_type: Option, -} - -fn code_as_string<'de, D>(deserializer: D) -> Result, D::Error> -where - D: serde::Deserializer<'de>, -{ - use serde::de::{self, Visitor}; - use std::fmt; - - struct CodeVisitor; - - impl<'de> Visitor<'de> for CodeVisitor { - type Value = Option; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a string, a number, null, or none for the code field") - } - - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { - Ok(Some(value.to_string())) - } - - fn visit_u64(self, value: u64) -> Result - where - E: de::Error, - { - Ok(Some(value.to_string())) - } - - fn visit_none(self) -> Result - where - E: de::Error, - { - Ok(None) - } - - fn visit_unit(self) -> Result - where - E: de::Error, - { - Ok(None) - } - - fn visit_some(self, deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(CodeVisitor) - } - } - - deserializer.deserialize_option(CodeVisitor) -} - -impl OpenAIError { - pub fn is_context_length_exceeded(&self) -> bool { - if let Some(code) = &self.code { - code == "context_length_exceeded" || code == "string_above_max_length" - } else { - false - } - } -} - -impl std::fmt::Display for OpenAIError { - /// Format the error for display. - /// E.g. {"message": "Invalid API key", "code": "invalid_api_key", "type": "client_error"} - /// would be formatted as "Invalid API key (code: invalid_api_key, type: client_error)" - /// and {"message": "Foo"} as just "Foo", etc. - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if let Some(message) = &self.message { - write!(f, "{}", message)?; - } - let mut in_parenthesis = false; - if let Some(code) = &self.code { - write!(f, " (code: {}", code)?; - in_parenthesis = true; - } - if let Some(typ) = &self.error_type { - if in_parenthesis { - write!(f, ", type: {}", typ)?; - } else { - write!(f, " (type: {}", typ)?; - in_parenthesis = true; - } - } - if in_parenthesis { - write!(f, ")")?; - } - Ok(()) - } -} diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index f52a40aab9fb..b39b9036e5e1 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -273,18 +273,12 @@ impl Provider for OllamaProvider { .api_client .response_post("v1/chat/completions", &payload) .await?; - let status = resp.status(); - if !status.is_success() { - return Err(super::utils::map_http_error_to_provider_error(status, None)); - } - Ok(resp) + handle_status_openai_compat(resp).await }) .await .inspect_err(|e| { let _ = log.error(e); })?; - let response = handle_status_openai_compat(response).await?; - let stream = response.bytes_stream().map_err(io::Error::other); Ok(Box::pin(try_stream! { diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index dca3c0c6bdd7..07b0339bdf1c 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -333,20 +333,12 @@ impl Provider for OpenAiProvider { .api_client .response_post(&self.base_path, &payload) .await?; - let status = resp.status(); - if !status.is_success() { - return Err(super::utils::map_http_error_to_provider_error( - status, None, // We'll let handle_status_openai_compat parse the error - )); - } - Ok(resp) + handle_status_openai_compat(resp).await }) .await .inspect_err(|e| { let _ = log.error(e); })?; - let response = handle_status_openai_compat(response).await?; - let stream = response.bytes_stream().map_err(io::Error::other); Ok(Box::pin(try_stream! { diff --git a/crates/goose/src/providers/tetrate.rs b/crates/goose/src/providers/tetrate.rs index 709acb21e265..9547628d46aa 100644 --- a/crates/goose/src/providers/tetrate.rs +++ b/crates/goose/src/providers/tetrate.rs @@ -214,18 +214,18 @@ impl Provider for TetrateProvider { &super::utils::ImageFormat::OpenAi, )?; - // Enable streaming payload["stream"] = json!(true); payload["stream_options"] = json!({ "include_usage": true, }); - let response = self + let resp = self .api_client .response_post("v1/chat/completions", &payload) .await?; - let response = handle_status_openai_compat(response).await?; + let response = handle_status_openai_compat(resp).await?; + let stream = response.bytes_stream().map_err(io::Error::other); let mut log = RequestLog::start(&self.model, &payload)?; diff --git a/crates/goose/src/providers/utils.rs b/crates/goose/src/providers/utils.rs index 1839708fb51c..d059d6695ea5 100644 --- a/crates/goose/src/providers/utils.rs +++ b/crates/goose/src/providers/utils.rs @@ -2,7 +2,7 @@ use super::base::Usage; use super::errors::GoogleErrorCode; use crate::config::paths::Paths; use crate::model::ModelConfig; -use crate::providers::errors::{OpenAIError, ProviderError}; +use crate::providers::errors::ProviderError; use anyhow::{anyhow, Result}; use base64::Engine; use regex::Regex; @@ -17,11 +17,6 @@ use std::path::{Path, PathBuf}; use std::time::Duration; use uuid::Uuid; -#[derive(serde::Deserialize)] -struct OpenAIErrorResponse { - error: OpenAIError, -} - #[derive(Debug, Copy, Clone, Serialize, Deserialize)] pub enum ImageFormat { OpenAi, @@ -107,54 +102,50 @@ pub fn map_http_error_to_provider_error( status: StatusCode, payload: Option, ) -> ProviderError { + let extract_message = || -> String { + payload + .as_ref() + .and_then(|p| { + p.get("error") + .and_then(|e| e.get("message")) + .or_else(|| p.get("message")) + .and_then(|m| m.as_str()) + .map(String::from) + }) + .unwrap_or_else(|| payload.as_ref().map(|p| p.to_string()).unwrap_or_default()) + }; + let error = match status { StatusCode::OK => unreachable!("Should not call this function with OK status"), - StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { - let message = format!( - "Authentication failed. Please ensure your API keys are valid and have the required permissions. \ - Status: {}{}", - status, - payload.as_ref().map(|p| format!(". Response: {}", p)).unwrap_or_default() - ); - ProviderError::Authentication(message) - } - StatusCode::PAYLOAD_TOO_LARGE => { - let payload_str = if let Some(payload) = &payload { - payload.to_string() - } else { - "Payload is too large.".to_string() - }; - ProviderError::ContextLengthExceeded(payload_str) + StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => ProviderError::Authentication(format!( + "Authentication failed. Status: {}. Response: {}", + status, + extract_message() + )), + StatusCode::NOT_FOUND => { + ProviderError::RequestFailed(format!("Resource not found (404): {}", extract_message())) } + StatusCode::PAYLOAD_TOO_LARGE => ProviderError::ContextLengthExceeded(extract_message()), StatusCode::BAD_REQUEST => { - let base_msg = format!("Request failed with status: {}", status); - if let Some(payload) = &payload { - let payload_str = payload.to_string(); - if check_context_length_exceeded(&payload_str) { - ProviderError::ContextLengthExceeded(payload_str) - } else { - ProviderError::RequestFailed( - payload - .get("error") - .and_then(|e| e.get("message")) - .or_else(|| payload.get("message")) - .and_then(|m| m.as_str()) - .map(|msg| format!("{}. Message: {}", base_msg, msg)) - .unwrap_or(base_msg), - ) - } + let payload_str = extract_message(); + if check_context_length_exceeded(&payload_str) { + ProviderError::ContextLengthExceeded(payload_str) } else { - ProviderError::RequestFailed(base_msg) + ProviderError::RequestFailed(format!("Bad request (400): {}", payload_str)) } } StatusCode::TOO_MANY_REQUESTS => ProviderError::RateLimitExceeded { - details: format!("{:?}", payload), + details: extract_message(), retry_delay: None, }, _ if status.is_server_error() => { - ProviderError::ServerError(format_server_error_message(status, payload.as_ref())) + ProviderError::ServerError(format!("Server error ({}): {}", status, extract_message())) } - _ => ProviderError::RequestFailed(format!("Request failed with status: {}", status)), + _ => ProviderError::RequestFailed(format!( + "Request failed with status {}: {}", + status, + extract_message() + )), }; if !status.is_success() { @@ -169,51 +160,14 @@ pub fn map_http_error_to_provider_error( error } -/// Handles HTTP responses from OpenAI-compatible endpoints. -/// -/// Returns the response if status is OK; otherwise, reads the body and maps to a `ProviderError`, -/// with special handling for context length exceeded and other OpenAI-formatted errors. -/// -/// ### References -/// - Error Codes: https://platform.openai.com/docs/guides/error-codes -/// - Context Window Exceeded: https://community.openai.com/t/help-needed-tackling-context-length-limits-in-openai-models/617543 -/// -/// ### Arguments -/// - `response`: The HTTP response to process. -/// -/// ### Returns -/// - `Ok(Response)`: The original response on success. -/// - `Err(ProviderError)`: Describes the failure reason.``` pub async fn handle_status_openai_compat(response: Response) -> Result { let status = response.status(); - if status == StatusCode::OK { - return Ok(response); - } - - let body_str = response - .text() - .await - .map_err(|_| map_http_error_to_provider_error(status, None))?; - - if matches!(status, StatusCode::BAD_REQUEST | StatusCode::NOT_FOUND) { - if let Ok(err_resp) = serde_json::from_str::(&body_str) { - let err = err_resp.error; - if err.is_context_length_exceeded() { - return Err(ProviderError::ContextLengthExceeded( - err.message.unwrap_or("Unknown error".to_string()), - )); - } else { - return Err(ProviderError::RequestFailed(format!( - "{} (status {})", - err, - status.as_u16() - ))); - } - } + if !status.is_success() { + let body = response.text().await.unwrap_or_default(); + let payload = serde_json::from_str::(&body).ok(); + return Err(map_http_error_to_provider_error(status, payload)); } - - let payload = serde_json::from_str::(&body_str).ok(); - Err(map_http_error_to_provider_error(status, payload)) + Ok(response) } pub async fn handle_response_openai_compat(response: Response) -> Result { @@ -645,7 +599,6 @@ pub fn json_escape_control_chars_in_string(s: &str) -> String { mod tests { use super::*; use serde_json::json; - use wiremock::{matchers, Mock, MockServer, ResponseTemplate}; #[test] fn test_detect_image_path() { @@ -970,208 +923,4 @@ mod tests { Some(Duration::from_secs(42)) ); } - - #[tokio::test] - async fn test_handle_status_openai_compat() { - let test_cases = vec![ - // (status_code, body, expected_result) - // Success case - 200 OK returns response as-is - ( - 200, - Some(json!({ - "choices": [{ - "finish_reason": "stop", - "index": 0, - "message": { - "content": "Hi there! How can I help you today?", - "role": "assistant" - } - }], - "created": 1755133833, - "id": "chatcmpl-test", - "model": "gpt-5-nano", - "usage": { - "completion_tokens": 10, - "prompt_tokens": 8, - "total_tokens": 18 - } - })), - Ok(()), - ), - // 400 Bad Request with OpenAI-formatted error (directly handled) - ( - 400, - Some(json!({ - "error": { - "code": "unsupported_parameter", - "message": "Unsupported parameter: 'max_tokens' is not supported with this model. Use 'max_completion_tokens' instead.", - "param": "max_tokens", - "type": "invalid_request_error" - } - })), - Err(ProviderError::RequestFailed( - "Unsupported parameter: 'max_tokens' is not supported with this model. Use 'max_completion_tokens' instead. (code: unsupported_parameter, type: invalid_request_error) (status 400)".to_string(), - )), - ), - // 400 with context_length_exceeded in OpenAI format (directly handled) - ( - 400, - Some(json!({ - "error": { - "code": "context_length_exceeded", - "message": "This model's maximum context length is 4096 tokens.", - "type": "invalid_request_error" - } - })), - Err(ProviderError::ContextLengthExceeded( - "This model's maximum context length is 4096 tokens.".to_string(), - )), - ), - // 404 Not Found with OpenAI-formatted error (directly handled like 400) - ( - 404, - Some(json!({ - "error": { - "code": "model_not_found", - "message": "The model 'gpt-5' does not exist", - "type": "invalid_request_error" - } - })), - Err(ProviderError::RequestFailed( - "The model 'gpt-5' does not exist (code: model_not_found, type: invalid_request_error) (status 404)".to_string(), - )), - ), - // Non-JSON body error (tests 413 PAYLOAD_TOO_LARGE -> ContextLengthExceeded) - ( - 413, - Some(Value::String("Payload Too Large".to_string())), - Err(ProviderError::ContextLengthExceeded( - "Payload is too large.".to_string(), - )), - ), - ]; - - for (status_code, body, expected_result) in test_cases { - let mock_server = MockServer::start().await; - - let mut response_template = ResponseTemplate::new(status_code); - - // Set body based on test case - if let Some(body_value) = body { - if body_value.is_string() { - // For non-JSON bodies (like "Payload Too Large") - response_template = - response_template.set_body_string(body_value.as_str().unwrap().to_string()); - } else { - // For JSON bodies - response_template = response_template.set_body_json(&body_value); - } - } - - Mock::given(matchers::method("GET")) - .and(matchers::path("/test")) - .respond_with(response_template) - .mount(&mock_server) - .await; - - // Make request to mock server - let client = reqwest::Client::new(); - let response = client - .get(format!("{}/test", &mock_server.uri())) - .send() - .await - .unwrap(); - - // Test handle_status_openai_compat - let result = handle_status_openai_compat(response).await.map(|_| ()); - - assert_eq!(result, expected_result, "for status {}", status_code); - } - } - - #[test] - fn test_map_http_error_to_provider_error() { - let test_cases = vec![ - ( - StatusCode::UNAUTHORIZED, - Some(json!({"error": "auth failed"})), - ProviderError::Authentication( - "Authentication failed. Please ensure your API keys are valid and have the required permissions. Status: 401 Unauthorized. Response: {\"error\":\"auth failed\"}".to_string(), - ), - ), - ( - StatusCode::FORBIDDEN, - None, - ProviderError::Authentication( - "Authentication failed. Please ensure your API keys are valid and have the required permissions. Status: 403 Forbidden".to_string(), - ), - ), - ( - StatusCode::BAD_REQUEST, - Some(json!({"error": {"message": "context_length_exceeded"}})), - ProviderError::ContextLengthExceeded( - "{\"error\":{\"message\":\"context_length_exceeded\"}}".to_string(), - ), - ), - ( - StatusCode::BAD_REQUEST, - Some(json!({"error": {"message": "Custom error"}})), - ProviderError::RequestFailed( - "Request failed with status: 400 Bad Request. Message: Custom error".to_string(), - ), - ), - ( - StatusCode::BAD_REQUEST, - None, - ProviderError::RequestFailed( - "Request failed with status: 400 Bad Request".to_string(), - ), - ), - ( - StatusCode::TOO_MANY_REQUESTS, - Some(json!({"retry_after": 60})), - ProviderError::RateLimitExceeded{ - details: "Some(Object {\"retry_after\": Number(60)})".to_string(), - retry_delay: None, - }, - ), - ( - StatusCode::INTERNAL_SERVER_ERROR, - None, - ProviderError::ServerError(format_server_error_message( - StatusCode::INTERNAL_SERVER_ERROR, - None, - )), - ), - ( - StatusCode::INTERNAL_SERVER_ERROR, - Some(Value::Null), - ProviderError::ServerError(format_server_error_message( - StatusCode::INTERNAL_SERVER_ERROR, - Some(&Value::Null), - )), - ), - ( - StatusCode::BAD_GATEWAY, - Some(json!({"error": "upstream error"})), - ProviderError::ServerError(format_server_error_message( - StatusCode::BAD_GATEWAY, - Some(&json!({"error": "upstream error"})), - )), - ), - // Default - any other status code - ( - StatusCode::IM_A_TEAPOT, - Some(json!({"ignored": "payload"})), - ProviderError::RequestFailed( - "Request failed with status: 418 I'm a teapot".to_string(), - ), - ), - ]; - - for (status, payload, expected_error) in test_cases { - let result = map_http_error_to_provider_error(status, payload); - assert_eq!(result, expected_error); - } - } }