Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion crates/goose-cli/src/commands/configure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,14 @@ pub async fn handle_configure(
provider: provider_name.to_string(),
model: model.clone(),
additional_systems,
temperature: None,
context_limit: None,
max_tokens: None,
estimate_factor: None,
};

// Confirm everything is configured correctly by calling a model!
let provider_config = get_provider_config(&provider_name, model.clone());
let provider_config = get_provider_config(&provider_name, profile.clone());
let spin = spinner();
spin.start("Checking your configuration...");
let provider = factory::get_provider(provider_config).unwrap();
Expand Down
3 changes: 1 addition & 2 deletions crates/goose-cli/src/commands/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ pub fn build_session<'a>(

let loaded_profile = load_profile(profile);

let provider_config =
get_provider_config(&loaded_profile.provider, loaded_profile.model.clone());
let provider_config = get_provider_config(&loaded_profile.provider, (*loaded_profile).clone());

// TODO: Odd to be prepping the provider rather than having that done in the agent?
let provider = factory::get_provider(provider_config).unwrap();
Expand Down
79 changes: 62 additions & 17 deletions crates/goose-cli/src/profile.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use anyhow::Result;
use goose::key_manager::{get_keyring_secret, KeyRetrievalStrategy};
use goose::providers::configs::{
AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, OllamaProviderConfig,
OpenAiProviderConfig, ProviderConfig,
AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, ModelConfig,
OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
Expand All @@ -16,6 +16,10 @@ pub struct Profile {
pub model: String,
#[serde(default)]
pub additional_systems: Vec<AdditionalSystem>,
pub temperature: Option<f32>,
pub context_limit: Option<usize>,
pub max_tokens: Option<i32>,
pub estimate_factor: Option<f32>,
}

#[derive(Serialize, Deserialize)]
Expand Down Expand Up @@ -71,7 +75,13 @@ pub fn has_no_profiles() -> Result<bool> {
load_profiles().map(|profiles| Ok(profiles.is_empty()))?
}

pub fn get_provider_config(provider_name: &str, model: String) -> ProviderConfig {
pub fn get_provider_config(provider_name: &str, profile: Profile) -> ProviderConfig {
let model_config = ModelConfig::new(profile.model)
.with_context_limit(profile.context_limit)
.with_temperature(profile.temperature)
.with_max_tokens(profile.max_tokens)
.with_estimate_factor(profile.estimate_factor);

match provider_name.to_lowercase().as_str() {
"openai" => {
// TODO error propagation throughout the CLI
Expand All @@ -81,9 +91,7 @@ pub fn get_provider_config(provider_name: &str, model: String) -> ProviderConfig
ProviderConfig::OpenAi(OpenAiProviderConfig {
host: "https://api.openai.com".to_string(),
api_key,
model,
temperature: None,
max_tokens: None,
model: model_config,
})
}
"databricks" => {
Expand All @@ -94,34 +102,71 @@ pub fn get_provider_config(provider_name: &str, model: String) -> ProviderConfig
host: host.clone(),
// TODO revisit configuration
auth: DatabricksAuth::oauth(host),
model,
temperature: None,
max_tokens: None,
model: model_config,
image_format: goose::providers::utils::ImageFormat::Anthropic,
})
}
"ollama" => {
let host = get_keyring_secret("OLLAMA_HOST", KeyRetrievalStrategy::Both)
.expect("OLLAMA_HOST not available in env or the keychain\nSet an env var or rerun `goose configure`");

ProviderConfig::Ollama(OllamaProviderConfig {
host: host.clone(),
model,
temperature: None,
max_tokens: None,
host,
model: model_config,
})
}
"anthropic" => {
let api_key = get_keyring_secret("ANTHROPIC_API_KEY", KeyRetrievalStrategy::Both)
.expect("ANTHROPIC_API_KEY not available in env or the keychain\nSet an env var or rerun `goose configure`");

ProviderConfig::Anthropic(AnthropicProviderConfig {
host: "https://api.anthropic.com".to_string(), // Default Anthropic API endpoint
host: "https://api.anthropic.com".to_string(),
api_key,
model,
temperature: None,
max_tokens: None,
model: model_config,
})
}
_ => panic!("Invalid provider name"),
}
}

#[cfg(test)]
mod tests {
use goose::providers::configs::ProviderModelConfig;

use crate::test_helpers::run_profile_with_tmp_dir;

use super::*;

#[test]
fn test_partial_profile_config() -> Result<()> {
let profile = r#"
{
"profile_items": {
"default": {
"provider": "databricks",
"model": "claude-3",
"temperature": 0.7,
"context_limit": 50000
}
}
}
"#;
run_profile_with_tmp_dir(profile, || {
let profiles = load_profiles()?;
let profile = profiles.get("default").unwrap();

assert_eq!(profile.temperature, Some(0.7));
assert_eq!(profile.context_limit, Some(50_000));
assert_eq!(profile.max_tokens, None);
assert_eq!(profile.estimate_factor, None);

let provider_config = get_provider_config(&profile.provider, profile.clone());

if let ProviderConfig::Databricks(config) = provider_config {
assert_eq!(config.model_config().estimate_factor(), 0.8);
assert_eq!(config.model_config().context_limit(), 50_000);
}
Ok(())
})
}
}
30 changes: 25 additions & 5 deletions crates/goose-cli/src/test_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,25 @@ pub fn run_with_tmp_dir<F: FnOnce() -> T, T>(func: F) -> T {

let temp_dir = tempdir().unwrap();
let temp_dir_path = temp_dir.path().to_path_buf();
setup_profile(&temp_dir_path);
setup_profile(&temp_dir_path, None);

temp_env::with_vars(
[
("HOME", Some(temp_dir_path.as_os_str())),
("DATABRICKS_HOST", Some(OsStr::new("tmp_host_url"))),
],
func,
)
}

#[cfg(test)]
pub fn run_profile_with_tmp_dir<F: FnOnce() -> T, T>(profile: &str, func: F) -> T {
use std::ffi::OsStr;
use tempfile::tempdir;

let temp_dir = tempdir().unwrap();
let temp_dir_path = temp_dir.path().to_path_buf();
setup_profile(&temp_dir_path, Some(profile));

temp_env::with_vars(
[
Expand All @@ -29,7 +47,7 @@ where

let temp_dir = tempdir().unwrap();
let temp_dir_path = temp_dir.path().to_path_buf();
setup_profile(&temp_dir_path);
setup_profile(&temp_dir_path, None);

temp_env::async_with_vars(
[
Expand All @@ -44,15 +62,16 @@ where
#[cfg(test)]
use std::path::PathBuf;
#[cfg(test)]
fn setup_profile(temp_dir_path: &PathBuf) {
/// Setup a goose profile for testing, and an optional profile string
fn setup_profile(temp_dir_path: &PathBuf, profile_string: Option<&str>) {
use std::fs;

let profile_path = temp_dir_path
.join(".config")
.join("goose")
.join("profiles.json");
fs::create_dir_all(profile_path.parent().unwrap()).unwrap();
let profile = r#"
let default_profile = r#"
{
"profile_items": {
"default": {
Expand All @@ -62,5 +81,6 @@ fn setup_profile(temp_dir_path: &PathBuf) {
}
}
}"#;
fs::write(&profile_path, profile).unwrap();

fs::write(&profile_path, profile_string.unwrap_or(default_profile)).unwrap();
}
Loading
Loading