Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 38 additions & 7 deletions crates/goose/src/providers/openai.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,12 @@ impl_provider_default!(OpenAiProvider);

impl OpenAiProvider {
pub fn from_env(model: ModelConfig) -> Result<Self> {
let model = model.with_fast(OPEN_AI_DEFAULT_FAST_MODEL.to_string());

let config = crate::config::Config::global();
let api_key: String = config.get_secret("OPENAI_API_KEY")?;
let host: String = config
.get_param("OPENAI_HOST")
.unwrap_or_else(|_| "https://api.openai.com".to_string());

let base_path: String = config
.get_param("OPENAI_BASE_PATH")
.unwrap_or_else(|_| "v1/chat/completions".to_string());
Expand All @@ -80,8 +79,11 @@ impl OpenAiProvider {
let timeout_secs: u64 = config.get_param("OPENAI_TIMEOUT").unwrap_or(600);

let auth = AuthMethod::BearerToken(api_key);
let mut api_client =
ApiClient::with_timeout(host, auth, std::time::Duration::from_secs(timeout_secs))?;
let mut api_client = ApiClient::with_timeout(
host.clone(),
auth,
std::time::Duration::from_secs(timeout_secs),
)?;

if let Some(org) = &organization {
api_client = api_client.with_header("OpenAI-Organization", org)?;
Expand All @@ -101,15 +103,44 @@ impl OpenAiProvider {
api_client = api_client.with_headers(header_map)?;
}

Ok(Self {
let mut provider = Self {
api_client,
base_path,
organization,
project,
model,
model: model.clone(),
custom_headers,
supports_streaming: true,
})
};

let model_with_fast = tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(async {
if let Ok(Some(models)) = provider.fetch_supported_models().await {
if models.contains(&OPEN_AI_DEFAULT_FAST_MODEL.to_string()) {
tracing::debug!(
"Found {} in OpenAI workspace, setting as fast model",
OPEN_AI_DEFAULT_FAST_MODEL
);
provider
.model
.clone()
.with_fast(OPEN_AI_DEFAULT_FAST_MODEL.to_string())
} else {
tracing::debug!(
"{} not found in OpenAI workspace, not setting fast model",
OPEN_AI_DEFAULT_FAST_MODEL
);
provider.model.clone()
}
} else {
tracing::debug!("Could not fetch OpenAI models, not setting fast model");
provider.model.clone()
}
})
});

provider.model = model_with_fast;
Ok(provider)
}

pub fn from_custom_config(model: ModelConfig, config: CustomProviderConfig) -> Result<Self> {
Expand Down
Loading