diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index 3f493dcd8f6e..d63f93199147 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -60,13 +60,12 @@ impl_provider_default!(OpenAiProvider); impl OpenAiProvider { pub fn from_env(model: ModelConfig) -> Result { - let model = model.with_fast(OPEN_AI_DEFAULT_FAST_MODEL.to_string()); - let config = crate::config::Config::global(); let api_key: String = config.get_secret("OPENAI_API_KEY")?; let host: String = config .get_param("OPENAI_HOST") .unwrap_or_else(|_| "https://api.openai.com".to_string()); + let base_path: String = config .get_param("OPENAI_BASE_PATH") .unwrap_or_else(|_| "v1/chat/completions".to_string()); @@ -80,8 +79,11 @@ impl OpenAiProvider { let timeout_secs: u64 = config.get_param("OPENAI_TIMEOUT").unwrap_or(600); let auth = AuthMethod::BearerToken(api_key); - let mut api_client = - ApiClient::with_timeout(host, auth, std::time::Duration::from_secs(timeout_secs))?; + let mut api_client = ApiClient::with_timeout( + host.clone(), + auth, + std::time::Duration::from_secs(timeout_secs), + )?; if let Some(org) = &organization { api_client = api_client.with_header("OpenAI-Organization", org)?; @@ -101,15 +103,44 @@ impl OpenAiProvider { api_client = api_client.with_headers(header_map)?; } - Ok(Self { + let mut provider = Self { api_client, base_path, organization, project, - model, + model: model.clone(), custom_headers, supports_streaming: true, - }) + }; + + let model_with_fast = tokio::task::block_in_place(|| { + tokio::runtime::Handle::current().block_on(async { + if let Ok(Some(models)) = provider.fetch_supported_models().await { + if models.contains(&OPEN_AI_DEFAULT_FAST_MODEL.to_string()) { + tracing::debug!( + "Found {} in OpenAI workspace, setting as fast model", + OPEN_AI_DEFAULT_FAST_MODEL + ); + provider + .model + .clone() + .with_fast(OPEN_AI_DEFAULT_FAST_MODEL.to_string()) + } else { + tracing::debug!( + "{} not found in OpenAI workspace, not setting fast model", + OPEN_AI_DEFAULT_FAST_MODEL + ); + provider.model.clone() + } + } else { + tracing::debug!("Could not fetch OpenAI models, not setting fast model"); + provider.model.clone() + } + }) + }); + + provider.model = model_with_fast; + Ok(provider) } pub fn from_custom_config(model: ModelConfig, config: CustomProviderConfig) -> Result {