From 2ff3888736ec21d0e0f4ce408535b7cfa486ff49 Mon Sep 17 00:00:00 2001 From: Matt Yaple Date: Fri, 31 Oct 2025 21:55:58 +0000 Subject: [PATCH] fix: adds ProviderRetry to openai provider Signed-off-by: Matt Yaple --- crates/goose/src/providers/openai.rs | 83 ++++++++++++++++++++-------- 1 file changed, 59 insertions(+), 24 deletions(-) diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index 8bb9f1a9e780..dca3c0c6bdd7 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -16,6 +16,7 @@ use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsag use super::embedding::{EmbeddingCapable, EmbeddingRequest, EmbeddingResponse}; use super::errors::ProviderError; use super::formats::openai::{create_request, get_usage, response_to_message}; +use super::retry::ProviderRetry; use super::utils::{ get_model, handle_response_openai_compat, handle_status_openai_compat, ImageFormat, }; @@ -240,9 +241,15 @@ impl Provider for OpenAiProvider { let payload = create_request(model_config, system, messages, tools, &ImageFormat::OpenAi)?; let mut log = RequestLog::start(&self.model, &payload)?; - let json_response = self.post(&payload).await.inspect_err(|e| { - let _ = log.error(e); - })?; + let json_response = self + .with_retry(|| async { + let payload_clone = payload.clone(); + self.post(&payload_clone).await + }) + .await + .inspect_err(|e| { + let _ = log.error(e); + })?; let message = response_to_message(&json_response)?; let usage = json_response @@ -260,19 +267,30 @@ impl Provider for OpenAiProvider { async fn fetch_supported_models(&self) -> Result>, ProviderError> { let models_path = self.base_path.replace("v1/chat/completions", "v1/models"); - let response = self.api_client.response_get(&models_path).await?; - let json = handle_response_openai_compat(response).await?; - if let Some(err_obj) = json.get("error") { - let msg = err_obj - .get("message") - .and_then(|v| v.as_str()) - .unwrap_or("unknown error"); - return Err(ProviderError::Authentication(msg.to_string())); - } + let response = self + .with_retry(|| async { + let response = self.api_client.response_get(&models_path).await?; + let json = handle_response_openai_compat(response).await?; + if let Some(err_obj) = json.get("error") { + let msg = err_obj + .get("message") + .and_then(|v| v.as_str()) + .unwrap_or("unknown error"); + return Err(ProviderError::Authentication(msg.to_string())); + } + Ok(json) + }) + .await + .inspect_err(|e| { + tracing::warn!("Failed to fetch supported models from OpenAI: {:?}", e); + })?; - let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| { - ProviderError::UsageError("Missing data field in JSON response".into()) - })?; + let data = response + .get("data") + .and_then(|v| v.as_array()) + .ok_or_else(|| { + ProviderError::UsageError("Missing data field in JSON response".into()) + })?; let mut models: Vec = data .iter() .filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string)) @@ -310,17 +328,24 @@ impl Provider for OpenAiProvider { let mut log = RequestLog::start(&self.model, &payload)?; let response = self - .api_client - .response_post(&self.base_path, &payload) - .await - .inspect_err(|e| { - let _ = log.error(e); - })?; - let response = handle_status_openai_compat(response) + .with_retry(|| async { + let resp = self + .api_client + .response_post(&self.base_path, &payload) + .await?; + let status = resp.status(); + if !status.is_success() { + return Err(super::utils::map_http_error_to_provider_error( + status, None, // We'll let handle_status_openai_compat parse the error + )); + } + Ok(resp) + }) .await .inspect_err(|e| { let _ = log.error(e); })?; + let response = handle_status_openai_compat(response).await?; let stream = response.bytes_stream().map_err(io::Error::other); @@ -366,8 +391,18 @@ impl EmbeddingCapable for OpenAiProvider { }; let response = self - .api_client - .api_post("v1/embeddings", &serde_json::to_value(request)?) + .with_retry(|| async { + let request_clone = EmbeddingRequest { + input: request.input.clone(), + model: request.model.clone(), + }; + let request_value = serde_json::to_value(request_clone) + .map_err(|e| ProviderError::ExecutionError(e.to_string()))?; + self.api_client + .api_post("v1/embeddings", &request_value) + .await + .map_err(|e| ProviderError::ExecutionError(e.to_string())) + }) .await?; if response.status != StatusCode::OK {