Skip to content
Merged
53 changes: 34 additions & 19 deletions crates/goose/src/providers/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ fn check_context_length_exceeded(text: &str) -> bool {
.any(|phrase| text_lower.contains(phrase))
}

fn format_server_error_message(status_code: StatusCode, payload: Option<&Value>) -> String {
match payload {
Some(Value::Null) | None => format!(
"HTTP {}: No response body received from server",
status_code.as_u16()
),
Some(p) => format!("HTTP {}: {}", status_code.as_u16(), p),
}
}

pub fn map_http_error_to_provider_error(
status: StatusCode,
payload: Option<Value>,
Expand All @@ -79,7 +89,7 @@ pub fn map_http_error_to_provider_error(
"Authentication failed. Please ensure your API keys are valid and have the required permissions. \
Status: {}{}",
status,
payload.as_ref().map(|p| format!(". Response: {:?}", p)).unwrap_or_default()
payload.as_ref().map(|p| format!(". Response: {}", p)).unwrap_or_default()
);
ProviderError::Authentication(message)
}
Expand Down Expand Up @@ -116,7 +126,9 @@ pub fn map_http_error_to_provider_error(
details: format!("{:?}", payload),
retry_delay: None,
},
_ if status.is_server_error() => ProviderError::ServerError(format!("{:?}", payload)),
_ if status.is_server_error() => {
ProviderError::ServerError(format_server_error_message(status, payload.as_ref()))
}
_ => ProviderError::RequestFailed(format!("Request failed with status: {}", status)),
};

Expand Down Expand Up @@ -295,12 +307,9 @@ pub async fn handle_response_google_compat(response: Response) -> Result<Value,
retry_delay,
})
}
_ if final_status.is_server_error() => {
Err(ProviderError::ServerError(format!("{:?}", payload)))
}
StatusCode::INTERNAL_SERVER_ERROR | StatusCode::SERVICE_UNAVAILABLE => {
Err(ProviderError::ServerError(format!("{:?}", payload)))
}
_ if final_status.is_server_error() => Err(ProviderError::ServerError(
format_server_error_message(final_status, payload.as_ref()),
)),
_ => {
tracing::debug!(
"{}", format!("Provider request failed with status: {}. Payload: {:?}", final_status, payload)
Expand Down Expand Up @@ -1059,47 +1068,41 @@ mod tests {
#[test]
fn test_map_http_error_to_provider_error() {
let test_cases = vec![
// UNAUTHORIZED/FORBIDDEN - with payload
(
StatusCode::UNAUTHORIZED,
Some(json!({"error": "auth failed"})),
ProviderError::Authentication(
"Authentication failed. Please ensure your API keys are valid and have the required permissions. Status: 401 Unauthorized. Response: Object {\"error\": String(\"auth failed\")}".to_string(),
"Authentication failed. Please ensure your API keys are valid and have the required permissions. Status: 401 Unauthorized. Response: {\"error\":\"auth failed\"}".to_string(),
),
),
// UNAUTHORIZED/FORBIDDEN - without payload
(
StatusCode::FORBIDDEN,
None,
ProviderError::Authentication(
"Authentication failed. Please ensure your API keys are valid and have the required permissions. Status: 403 Forbidden".to_string(),
),
),
// BAD_REQUEST - with context_length_exceeded detection
(
StatusCode::BAD_REQUEST,
Some(json!({"error": {"message": "context_length_exceeded"}})),
ProviderError::ContextLengthExceeded(
"{\"error\":{\"message\":\"context_length_exceeded\"}}".to_string(),
),
),
// BAD_REQUEST - with error.message extraction
(
StatusCode::BAD_REQUEST,
Some(json!({"error": {"message": "Custom error"}})),
ProviderError::RequestFailed(
"Request failed with status: 400 Bad Request. Message: Custom error".to_string(),
),
),
// BAD_REQUEST - without payload
(
StatusCode::BAD_REQUEST,
None,
ProviderError::RequestFailed(
"Request failed with status: 400 Bad Request".to_string(),
),
),
// TOO_MANY_REQUESTS
(
StatusCode::TOO_MANY_REQUESTS,
Some(json!({"retry_after": 60})),
Expand All @@ -1108,17 +1111,29 @@ mod tests {
retry_delay: None,
},
),
// is_server_error() without payload
(
StatusCode::INTERNAL_SERVER_ERROR,
None,
ProviderError::ServerError("None".to_string()),
ProviderError::ServerError(format_server_error_message(
StatusCode::INTERNAL_SERVER_ERROR,
None,
)),
),
(
StatusCode::INTERNAL_SERVER_ERROR,
Some(Value::Null),
ProviderError::ServerError(format_server_error_message(
StatusCode::INTERNAL_SERVER_ERROR,
Some(&Value::Null),
)),
),
// is_server_error() with payload
(
StatusCode::BAD_GATEWAY,
Some(json!({"error": "upstream error"})),
ProviderError::ServerError("Some(Object {\"error\": String(\"upstream error\")})".to_string()),
ProviderError::ServerError(format_server_error_message(
StatusCode::BAD_GATEWAY,
Some(&json!({"error": "upstream error"})),
)),
),
// Default - any other status code
(
Expand Down