Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 59 additions & 24 deletions crates/goose/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use super::base::{ConfigKey, ModelInfo, Provider, ProviderMetadata, ProviderUsag
use super::embedding::{EmbeddingCapable, EmbeddingRequest, EmbeddingResponse};
use super::errors::ProviderError;
use super::formats::openai::{create_request, get_usage, response_to_message};
use super::retry::ProviderRetry;
use super::utils::{
get_model, handle_response_openai_compat, handle_status_openai_compat, ImageFormat,
};
Expand Down Expand Up @@ -240,9 +241,15 @@ impl Provider for OpenAiProvider {
let payload = create_request(model_config, system, messages, tools, &ImageFormat::OpenAi)?;

let mut log = RequestLog::start(&self.model, &payload)?;
let json_response = self.post(&payload).await.inspect_err(|e| {
let _ = log.error(e);
})?;
let json_response = self
.with_retry(|| async {
let payload_clone = payload.clone();
self.post(&payload_clone).await
})
.await
.inspect_err(|e| {
let _ = log.error(e);
})?;

let message = response_to_message(&json_response)?;
let usage = json_response
Expand All @@ -260,19 +267,30 @@ impl Provider for OpenAiProvider {

async fn fetch_supported_models(&self) -> Result<Option<Vec<String>>, ProviderError> {
let models_path = self.base_path.replace("v1/chat/completions", "v1/models");
let response = self.api_client.response_get(&models_path).await?;
let json = handle_response_openai_compat(response).await?;
if let Some(err_obj) = json.get("error") {
let msg = err_obj
.get("message")
.and_then(|v| v.as_str())
.unwrap_or("unknown error");
return Err(ProviderError::Authentication(msg.to_string()));
}
let response = self
.with_retry(|| async {
let response = self.api_client.response_get(&models_path).await?;
let json = handle_response_openai_compat(response).await?;
if let Some(err_obj) = json.get("error") {
let msg = err_obj
.get("message")
.and_then(|v| v.as_str())
.unwrap_or("unknown error");
return Err(ProviderError::Authentication(msg.to_string()));
}
Ok(json)
})
.await
.inspect_err(|e| {
tracing::warn!("Failed to fetch supported models from OpenAI: {:?}", e);
})?;

let data = json.get("data").and_then(|v| v.as_array()).ok_or_else(|| {
ProviderError::UsageError("Missing data field in JSON response".into())
})?;
let data = response
.get("data")
.and_then(|v| v.as_array())
.ok_or_else(|| {
ProviderError::UsageError("Missing data field in JSON response".into())
})?;
let mut models: Vec<String> = data
.iter()
.filter_map(|m| m.get("id").and_then(|v| v.as_str()).map(str::to_string))
Expand Down Expand Up @@ -310,17 +328,24 @@ impl Provider for OpenAiProvider {
let mut log = RequestLog::start(&self.model, &payload)?;

let response = self
.api_client
.response_post(&self.base_path, &payload)
.await
.inspect_err(|e| {
let _ = log.error(e);
})?;
let response = handle_status_openai_compat(response)
.with_retry(|| async {
let resp = self
.api_client
.response_post(&self.base_path, &payload)
.await?;
let status = resp.status();
if !status.is_success() {
return Err(super::utils::map_http_error_to_provider_error(
status, None, // We'll let handle_status_openai_compat parse the error
));
}
Ok(resp)
})
.await
.inspect_err(|e| {
let _ = log.error(e);
})?;
let response = handle_status_openai_compat(response).await?;

let stream = response.bytes_stream().map_err(io::Error::other);

Expand Down Expand Up @@ -366,8 +391,18 @@ impl EmbeddingCapable for OpenAiProvider {
};

let response = self
.api_client
.api_post("v1/embeddings", &serde_json::to_value(request)?)
.with_retry(|| async {
let request_clone = EmbeddingRequest {
input: request.input.clone(),
model: request.model.clone(),
};
let request_value = serde_json::to_value(request_clone)
.map_err(|e| ProviderError::ExecutionError(e.to_string()))?;
self.api_client
.api_post("v1/embeddings", &request_value)
.await
.map_err(|e| ProviderError::ExecutionError(e.to_string()))
})
.await?;

if response.status != StatusCode::OK {
Expand Down