diff --git a/crates/goose-cli/src/commands/configure.rs b/crates/goose-cli/src/commands/configure.rs index 67a5bc1f7b8d..c611e0595601 100644 --- a/crates/goose-cli/src/commands/configure.rs +++ b/crates/goose-cli/src/commands/configure.rs @@ -110,10 +110,14 @@ pub async fn handle_configure( provider: provider_name.to_string(), model: model.clone(), additional_systems, + temperature: None, + context_limit: None, + max_tokens: None, + estimate_factor: None, }; // Confirm everything is configured correctly by calling a model! - let provider_config = get_provider_config(&provider_name, model.clone()); + let provider_config = get_provider_config(&provider_name, profile.clone()); let spin = spinner(); spin.start("Checking your configuration..."); let provider = factory::get_provider(provider_config).unwrap(); diff --git a/crates/goose-cli/src/commands/session.rs b/crates/goose-cli/src/commands/session.rs index ea8c8fcb4733..7de8bf40edc5 100644 --- a/crates/goose-cli/src/commands/session.rs +++ b/crates/goose-cli/src/commands/session.rs @@ -42,8 +42,7 @@ pub fn build_session<'a>( let loaded_profile = load_profile(profile); - let provider_config = - get_provider_config(&loaded_profile.provider, loaded_profile.model.clone()); + let provider_config = get_provider_config(&loaded_profile.provider, (*loaded_profile).clone()); // TODO: Odd to be prepping the provider rather than having that done in the agent? let provider = factory::get_provider(provider_config).unwrap(); diff --git a/crates/goose-cli/src/profile.rs b/crates/goose-cli/src/profile.rs index c1bcb33e26c0..429932b0e054 100644 --- a/crates/goose-cli/src/profile.rs +++ b/crates/goose-cli/src/profile.rs @@ -1,8 +1,8 @@ use anyhow::Result; use goose::key_manager::{get_keyring_secret, KeyRetrievalStrategy}; use goose::providers::configs::{ - AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, OllamaProviderConfig, - OpenAiProviderConfig, ProviderConfig, + AnthropicProviderConfig, DatabricksAuth, DatabricksProviderConfig, ModelConfig, + OllamaProviderConfig, OpenAiProviderConfig, ProviderConfig, }; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -16,6 +16,10 @@ pub struct Profile { pub model: String, #[serde(default)] pub additional_systems: Vec, + pub temperature: Option, + pub context_limit: Option, + pub max_tokens: Option, + pub estimate_factor: Option, } #[derive(Serialize, Deserialize)] @@ -71,7 +75,13 @@ pub fn has_no_profiles() -> Result { load_profiles().map(|profiles| Ok(profiles.is_empty()))? 
} -pub fn get_provider_config(provider_name: &str, model: String) -> ProviderConfig { +pub fn get_provider_config(provider_name: &str, profile: Profile) -> ProviderConfig { + let model_config = ModelConfig::new(profile.model) + .with_context_limit(profile.context_limit) + .with_temperature(profile.temperature) + .with_max_tokens(profile.max_tokens) + .with_estimate_factor(profile.estimate_factor); + match provider_name.to_lowercase().as_str() { "openai" => { // TODO error propagation throughout the CLI @@ -81,9 +91,7 @@ pub fn get_provider_config(provider_name: &str, model: String) -> ProviderConfig ProviderConfig::OpenAi(OpenAiProviderConfig { host: "https://api.openai.com".to_string(), api_key, - model, - temperature: None, - max_tokens: None, + model: model_config, }) } "databricks" => { @@ -94,20 +102,17 @@ pub fn get_provider_config(provider_name: &str, model: String) -> ProviderConfig host: host.clone(), // TODO revisit configuration auth: DatabricksAuth::oauth(host), - model, - temperature: None, - max_tokens: None, + model: model_config, image_format: goose::providers::utils::ImageFormat::Anthropic, }) } "ollama" => { let host = get_keyring_secret("OLLAMA_HOST", KeyRetrievalStrategy::Both) .expect("OLLAMA_HOST not available in env or the keychain\nSet an env var or rerun `goose configure`"); + ProviderConfig::Ollama(OllamaProviderConfig { - host: host.clone(), - model, - temperature: None, - max_tokens: None, + host, + model: model_config, }) } "anthropic" => { @@ -115,13 +120,53 @@ pub fn get_provider_config(provider_name: &str, model: String) -> ProviderConfig .expect("ANTHROPIC_API_KEY not available in env or the keychain\nSet an env var or rerun `goose configure`"); ProviderConfig::Anthropic(AnthropicProviderConfig { - host: "https://api.anthropic.com".to_string(), // Default Anthropic API endpoint + host: "https://api.anthropic.com".to_string(), api_key, - model, - temperature: None, - max_tokens: None, + model: model_config, }) } _ => panic!("Invalid provider name"), } } + +#[cfg(test)] +mod tests { + use goose::providers::configs::ProviderModelConfig; + + use crate::test_helpers::run_profile_with_tmp_dir; + + use super::*; + + #[test] + fn test_partial_profile_config() -> Result<()> { + let profile = r#" +{ + "profile_items": { + "default": { + "provider": "databricks", + "model": "claude-3", + "temperature": 0.7, + "context_limit": 50000 + } + } +} +"#; + run_profile_with_tmp_dir(profile, || { + let profiles = load_profiles()?; + let profile = profiles.get("default").unwrap(); + + assert_eq!(profile.temperature, Some(0.7)); + assert_eq!(profile.context_limit, Some(50_000)); + assert_eq!(profile.max_tokens, None); + assert_eq!(profile.estimate_factor, None); + + let provider_config = get_provider_config(&profile.provider, profile.clone()); + + if let ProviderConfig::Databricks(config) = provider_config { + assert_eq!(config.model_config().estimate_factor(), 0.8); + assert_eq!(config.model_config().context_limit(), 50_000); + } + Ok(()) + }) + } +} diff --git a/crates/goose-cli/src/test_helpers.rs b/crates/goose-cli/src/test_helpers.rs index 8e611da0b15b..9ae3bc440af9 100644 --- a/crates/goose-cli/src/test_helpers.rs +++ b/crates/goose-cli/src/test_helpers.rs @@ -7,7 +7,25 @@ pub fn run_with_tmp_dir T, T>(func: F) -> T { let temp_dir = tempdir().unwrap(); let temp_dir_path = temp_dir.path().to_path_buf(); - setup_profile(&temp_dir_path); + setup_profile(&temp_dir_path, None); + + temp_env::with_vars( + [ + ("HOME", Some(temp_dir_path.as_os_str())), + ("DATABRICKS_HOST", 
Some(OsStr::new("tmp_host_url"))), + ], + func, + ) +} + +#[cfg(test)] +pub fn run_profile_with_tmp_dir T, T>(profile: &str, func: F) -> T { + use std::ffi::OsStr; + use tempfile::tempdir; + + let temp_dir = tempdir().unwrap(); + let temp_dir_path = temp_dir.path().to_path_buf(); + setup_profile(&temp_dir_path, Some(profile)); temp_env::with_vars( [ @@ -29,7 +47,7 @@ where let temp_dir = tempdir().unwrap(); let temp_dir_path = temp_dir.path().to_path_buf(); - setup_profile(&temp_dir_path); + setup_profile(&temp_dir_path, None); temp_env::async_with_vars( [ @@ -44,7 +62,8 @@ where #[cfg(test)] use std::path::PathBuf; #[cfg(test)] -fn setup_profile(temp_dir_path: &PathBuf) { +/// Setup a goose profile for testing, and an optional profile string +fn setup_profile(temp_dir_path: &PathBuf, profile_string: Option<&str>) { use std::fs; let profile_path = temp_dir_path @@ -52,7 +71,7 @@ fn setup_profile(temp_dir_path: &PathBuf) { .join("goose") .join("profiles.json"); fs::create_dir_all(profile_path.parent().unwrap()).unwrap(); - let profile = r#" + let default_profile = r#" { "profile_items": { "default": { @@ -62,5 +81,6 @@ fn setup_profile(temp_dir_path: &PathBuf) { } } }"#; - fs::write(&profile_path, profile).unwrap(); + + fs::write(&profile_path, profile_string.unwrap_or(default_profile)).unwrap(); } diff --git a/crates/goose-server/src/configuration.rs b/crates/goose-server/src/configuration.rs index c4e2d706c8fe..047cb379a36e 100644 --- a/crates/goose-server/src/configuration.rs +++ b/crates/goose-server/src/configuration.rs @@ -2,8 +2,8 @@ use crate::error::{to_env_var, ConfigError}; use config::{Config, Environment}; use goose::providers::{ configs::{ - DatabricksAuth, DatabricksProviderConfig, OllamaProviderConfig, OpenAiProviderConfig, - ProviderConfig, + DatabricksAuth, DatabricksProviderConfig, ModelConfig, OllamaProviderConfig, + OpenAiProviderConfig, ProviderConfig, }, factory::ProviderType, ollama, @@ -41,6 +41,10 @@ pub enum ProviderSettings { temperature: Option, #[serde(default)] max_tokens: Option, + #[serde(default)] + context_limit: Option, + #[serde(default)] + estimate_factor: Option, }, Databricks { #[serde(default = "default_databricks_host")] @@ -51,6 +55,10 @@ pub enum ProviderSettings { temperature: Option, #[serde(default)] max_tokens: Option, + #[serde(default)] + context_limit: Option, + #[serde(default)] + estimate_factor: Option, #[serde(default = "default_image_format")] image_format: ImageFormat, }, @@ -63,6 +71,10 @@ pub enum ProviderSettings { temperature: Option, #[serde(default)] max_tokens: Option, + #[serde(default)] + context_limit: Option, + #[serde(default)] + estimate_factor: Option, }, } @@ -86,25 +98,33 @@ impl ProviderSettings { model, temperature, max_tokens, + context_limit, + estimate_factor, } => ProviderConfig::OpenAi(OpenAiProviderConfig { host, api_key, - model, - temperature, - max_tokens, + model: ModelConfig::new(model) + .with_temperature(temperature) + .with_max_tokens(max_tokens) + .with_context_limit(context_limit) + .with_estimate_factor(estimate_factor), }), ProviderSettings::Databricks { host, model, temperature, max_tokens, + context_limit, image_format, + estimate_factor, } => ProviderConfig::Databricks(DatabricksProviderConfig { host: host.clone(), auth: DatabricksAuth::oauth(host), - model, - temperature, - max_tokens, + model: ModelConfig::new(model) + .with_temperature(temperature) + .with_max_tokens(max_tokens) + .with_context_limit(context_limit) + .with_estimate_factor(estimate_factor), image_format, }), 
ProviderSettings::Ollama { @@ -112,11 +132,15 @@ impl ProviderSettings { model, temperature, max_tokens, + context_limit, + estimate_factor, } => ProviderConfig::Ollama(OllamaProviderConfig { host, - model, - temperature, - max_tokens, + model: ModelConfig::new(model) + .with_temperature(temperature) + .with_max_tokens(max_tokens) + .with_context_limit(context_limit) + .with_estimate_factor(estimate_factor), }), } } @@ -246,6 +270,8 @@ mod tests { model, temperature, max_tokens, + context_limit, + estimate_factor, } = settings.provider { assert_eq!(host, "https://api.openai.com"); @@ -253,6 +279,8 @@ mod tests { assert_eq!(model, "gpt-4o"); assert_eq!(temperature, None); assert_eq!(max_tokens, None); + assert_eq!(context_limit, None); + assert_eq!(estimate_factor, None); } else { panic!("Expected OpenAI provider"); } @@ -262,6 +290,33 @@ mod tests { env::remove_var("GOOSE_PROVIDER__API_KEY"); } + #[test] + #[serial] + fn test_into_config_conversion() { + // Test OpenAI conversion + let settings = ProviderSettings::OpenAi { + host: "https://api.openai.com".to_string(), + api_key: "test-key".to_string(), + model: "gpt-4o".to_string(), + temperature: Some(0.7), + max_tokens: Some(1000), + context_limit: Some(150_000), + estimate_factor: Some(0.8), + }; + + if let ProviderConfig::OpenAi(config) = settings.into_config() { + assert_eq!(config.host, "https://api.openai.com"); + assert_eq!(config.api_key, "test-key"); + assert_eq!(config.model.model_name, "gpt-4o"); + assert_eq!(config.model.temperature, Some(0.7)); + assert_eq!(config.model.max_tokens, Some(1000)); + assert_eq!(config.model.context_limit, Some(150_000)); + assert_eq!(config.model.estimate_factor, Some(0.8)); + } else { + panic!("Expected OpenAI config"); + } + } + #[test] #[serial] fn test_databricks_settings() { @@ -271,6 +326,7 @@ mod tests { env::set_var("GOOSE_PROVIDER__MODEL", "llama-2-70b"); env::set_var("GOOSE_PROVIDER__TEMPERATURE", "0.7"); env::set_var("GOOSE_PROVIDER__MAX_TOKENS", "2000"); + env::set_var("GOOSE_PROVIDER__CONTEXT_LIMIT", "150000"); let settings = Settings::new().unwrap(); if let ProviderSettings::Databricks { @@ -278,6 +334,8 @@ mod tests { model, temperature, max_tokens, + context_limit, + estimate_factor, image_format: _, } = settings.provider { @@ -285,6 +343,8 @@ mod tests { assert_eq!(model, "llama-2-70b"); assert_eq!(temperature, Some(0.7)); assert_eq!(max_tokens, Some(2000)); + assert_eq!(context_limit, Some(150000)); + assert_eq!(estimate_factor, None); } else { panic!("Expected Databricks provider"); } @@ -295,6 +355,7 @@ mod tests { env::remove_var("GOOSE_PROVIDER__MODEL"); env::remove_var("GOOSE_PROVIDER__TEMPERATURE"); env::remove_var("GOOSE_PROVIDER__MAX_TOKENS"); + env::remove_var("GOOSE_PROVIDER__CONTEXT_LIMIT"); } #[test] @@ -306,6 +367,8 @@ mod tests { env::set_var("GOOSE_PROVIDER__MODEL", "llama2"); env::set_var("GOOSE_PROVIDER__TEMPERATURE", "0.7"); env::set_var("GOOSE_PROVIDER__MAX_TOKENS", "2000"); + env::set_var("GOOSE_PROVIDER__CONTEXT_LIMIT", "150000"); + env::set_var("GOOSE_PROVIDER__ESTIMATE_FACTOR", "0.7"); let settings = Settings::new().unwrap(); if let ProviderSettings::Ollama { @@ -313,12 +376,16 @@ mod tests { model, temperature, max_tokens, + context_limit, + estimate_factor, } = settings.provider { assert_eq!(host, "http://custom.ollama.host"); assert_eq!(model, "llama2"); assert_eq!(temperature, Some(0.7)); assert_eq!(max_tokens, Some(2000)); + assert_eq!(context_limit, Some(150000)); + assert_eq!(estimate_factor, Some(0.7)); } else { panic!("Expected Ollama provider"); 
} @@ -329,6 +396,8 @@ mod tests { env::remove_var("GOOSE_PROVIDER__MODEL"); env::remove_var("GOOSE_PROVIDER__TEMPERATURE"); env::remove_var("GOOSE_PROVIDER__MAX_TOKENS"); + env::remove_var("GOOSE_PROVIDER__CONTEXT_LIMIT"); + env::remove_var("GOOSE_PROVIDER__ESTIMATE_FACTOR"); } #[test] @@ -341,6 +410,7 @@ mod tests { env::set_var("GOOSE_PROVIDER__HOST", "https://custom.openai.com"); env::set_var("GOOSE_PROVIDER__MODEL", "gpt-3.5-turbo"); env::set_var("GOOSE_PROVIDER__TEMPERATURE", "0.8"); + env::set_var("GOOSE_PROVIDER__CONTEXT_LIMIT", "150000"); let settings = Settings::new().unwrap(); assert_eq!(settings.server.port, 8080); @@ -350,6 +420,7 @@ mod tests { api_key, model, temperature, + context_limit, .. } = settings.provider { @@ -357,6 +428,7 @@ mod tests { assert_eq!(api_key, "test-key"); assert_eq!(model, "gpt-3.5-turbo"); assert_eq!(temperature, Some(0.8)); + assert_eq!(context_limit, Some(150000)); } else { panic!("Expected OpenAI provider"); } @@ -368,6 +440,7 @@ mod tests { env::remove_var("GOOSE_PROVIDER__HOST"); env::remove_var("GOOSE_PROVIDER__MODEL"); env::remove_var("GOOSE_PROVIDER__TEMPERATURE"); + env::remove_var("GOOSE_PROVIDER__CONTEXT_LIMIT"); } #[test] diff --git a/crates/goose-server/src/routes/reply.rs b/crates/goose-server/src/routes/reply.rs index c5e07e2878d0..ae4f89de418b 100644 --- a/crates/goose-server/src/routes/reply.rs +++ b/crates/goose-server/src/routes/reply.rs @@ -393,14 +393,16 @@ mod tests { agent::Agent, providers::{ base::{Provider, ProviderUsage, Usage}, - configs::OpenAiProviderConfig, + configs::{ModelConfig, OpenAiProviderConfig}, }, }; use mcp_core::tool::Tool; // Mock Provider implementation for testing #[derive(Clone)] - struct MockProvider; + struct MockProvider { + model_config: ModelConfig, + } #[async_trait::async_trait] impl Provider for MockProvider { @@ -415,6 +417,10 @@ mod tests { ProviderUsage::new("mock".to_string(), Usage::default(), None), )) } + + fn get_model_config(&self) -> &ModelConfig { + &self.model_config + } } #[test] @@ -493,7 +499,7 @@ mod tests { mod integration_tests { use super::*; use axum::{body::Body, http::Request}; - use goose::providers::configs::ProviderConfig; + use goose::providers::configs::{ModelConfig, ProviderConfig}; use std::sync::Arc; use tokio::sync::Mutex; use tower::ServiceExt; @@ -502,16 +508,17 @@ mod tests { #[tokio::test] async fn test_ask_endpoint() { // Create a mock app state with mock provider - let mock_provider = Box::new(MockProvider); + let mock_model_config = ModelConfig::new("test-model".to_string()); + let mock_provider = Box::new(MockProvider { + model_config: mock_model_config, + }); let agent = Agent::new(mock_provider); let state = AppState { agent: Arc::new(Mutex::new(agent)), provider_config: ProviderConfig::OpenAi(OpenAiProviderConfig { host: "https://api.openai.com".to_string(), api_key: "test-key".to_string(), - model: "test-model".to_string(), - temperature: None, - max_tokens: None, + model: ModelConfig::new("test-model".to_string()), }), secret_key: "test-secret".to_string(), }; diff --git a/crates/goose-server/src/state.rs b/crates/goose-server/src/state.rs index 10ad3e4e2502..9ab997715f88 100644 --- a/crates/goose-server/src/state.rs +++ b/crates/goose-server/src/state.rs @@ -41,8 +41,6 @@ impl Clone for AppState { host: config.host.clone(), api_key: config.api_key.clone(), model: config.model.clone(), - temperature: config.temperature, - max_tokens: config.max_tokens, }) } ProviderConfig::Databricks(config) => ProviderConfig::Databricks( @@ -50,8 +48,6 @@ impl Clone for 
AppState { host: config.host.clone(), auth: config.auth.clone(), model: config.model.clone(), - temperature: config.temperature, - max_tokens: config.max_tokens, image_format: config.image_format, }, ), @@ -59,8 +55,6 @@ impl Clone for AppState { ProviderConfig::Ollama(goose::providers::configs::OllamaProviderConfig { host: config.host.clone(), model: config.model.clone(), - temperature: config.temperature, - max_tokens: config.max_tokens, }) } ProviderConfig::Anthropic(config) => { @@ -68,8 +62,6 @@ impl Clone for AppState { host: config.host.clone(), api_key: config.api_key.clone(), model: config.model.clone(), - temperature: config.temperature, - max_tokens: config.max_tokens, }) } }, diff --git a/crates/goose/examples/image_tool.rs b/crates/goose/examples/image_tool.rs index 25baacb744e8..ed86684b9d56 100644 --- a/crates/goose/examples/image_tool.rs +++ b/crates/goose/examples/image_tool.rs @@ -4,7 +4,7 @@ use dotenv::dotenv; use goose::{ message::Message, providers::{ - configs::{DatabricksProviderConfig, OpenAiProviderConfig, ProviderConfig}, + configs::{DatabricksProviderConfig, ModelConfig, OpenAiProviderConfig, ProviderConfig}, factory::get_provider, }, }; @@ -34,9 +34,7 @@ async fn main() -> Result<()> { let config2 = ProviderConfig::OpenAi(OpenAiProviderConfig { host: "https://api.openai.com".into(), api_key, - model: "gpt-4o".into(), - temperature: None, - max_tokens: None, + model: ModelConfig::new("gpt-4o".into()), }); for config in [config1, config2] { diff --git a/crates/goose/src/agent.rs b/crates/goose/src/agent.rs index bf82b9823cee..5c02bf3be6a8 100644 --- a/crates/goose/src/agent.rs +++ b/crates/goose/src/agent.rs @@ -15,10 +15,6 @@ use crate::token_counter::TokenCounter; use mcp_core::{Content, Resource, Tool, ToolCall}; use serde::Serialize; -const CONTEXT_LIMIT: usize = 200_000; // TODO: model's context limit should be in provider config -const ESTIMATE_FACTOR: f32 = 0.8; -const ESTIMATED_TOKEN_LIMIT: usize = (CONTEXT_LIMIT as f32 * ESTIMATE_FACTOR) as usize; - // used to sort resources by priority within error margin const PRIORITY_EPSILON: f32 = 0.001; @@ -77,6 +73,11 @@ impl Agent { self.systems.push(system); } + /// Get the context limit from the provider's configuration + fn get_context_limit(&self) -> usize { + self.provider.get_model_config().context_limit() + } + /// Get all tools from all systems with proper system prefixing fn get_prefixed_tools(&self) -> Vec { let mut tools = Vec::new(); @@ -206,7 +207,7 @@ impl Agent { messages, tools, &resources, - Some("gpt-4"), + Some(&self.provider.get_model_config().model_name), ); let mut status_content: Vec = Vec::new(); @@ -221,7 +222,9 @@ impl Agent { for (system_name, resources) in &resource_content { let mut resource_counts = HashMap::new(); for (uri, (_resource, content)) in resources { - let token_count = token_counter.count_tokens(content, Some("gpt-4")) as u32; + let token_count = token_counter + .count_tokens(&content, Some(&self.provider.get_model_config().model_name)) + as u32; resource_counts.insert(uri.clone(), token_count); } system_token_counts.insert(system_name.clone(), resource_counts); @@ -327,6 +330,7 @@ impl Agent { let mut messages = messages.to_vec(); let tools = self.get_prefixed_tools(); let system_prompt = self.get_system_prompt()?; + let estimated_limit = self.provider.get_model_config().get_estimated_limit(); // Update conversation history for the start of the reply messages = self @@ -335,13 +339,12 @@ impl Agent { &tools, &messages, &Vec::new(), - ESTIMATED_TOKEN_LIMIT, + 
estimated_limit, ) .await?; Ok(Box::pin(async_stream::try_stream! { loop { - // Get completion from provider let (response, usage) = self.provider.complete( &system_prompt, @@ -396,7 +399,7 @@ impl Agent { messages.pop(); let pending = vec![response, message_tool_response]; - messages = self.prepare_inference(&system_prompt, &tools, &messages, &pending, ESTIMATED_TOKEN_LIMIT).await?; + messages = self.prepare_inference(&system_prompt, &tools, &messages, &pending, estimated_limit).await?; } })) } @@ -434,6 +437,7 @@ impl Agent { mod tests { use super::*; use crate::message::MessageContent; + use crate::providers::configs::ModelConfig; use crate::providers::mock::MockProvider; use async_trait::async_trait; use chrono::Utc; @@ -725,4 +729,53 @@ mod tests { assert!(status_content.contains("low_priority")); Ok(()) } + + #[tokio::test] + async fn test_context_trimming_with_custom_model_config() -> Result<()> { + let provider = MockProvider::with_config( + vec![], + ModelConfig::new("test_model".to_string()).with_context_limit(Some(20)), + ); + let mut agent = Agent::new(Box::new(provider)); + + // Create a mock system with a resource that will exceed the context limit + let mut system = MockSystem::new("test"); + + // Add a resource that will exceed our tiny context limit + let hello_1_tokens = "hello ".repeat(1); // 1 tokens + let goodbye_10_tokens = "goodbye ".repeat(10); // 10 tokens + system.add_resource("test_resource_removed", &goodbye_10_tokens, 0.1); + system.add_resource("test_resource_expected", &hello_1_tokens, 0.5); + + agent.add_system(Box::new(system)); + + // Set up test parameters + // 18 tokens with system + user msg in chat format + let system_prompt = "This is a system prompt"; + let messages = vec![Message::user().with_text("Hi there")]; + let tools = vec![]; + let pending = vec![]; + + // Use the context limit from the model config + let target_limit = agent.get_context_limit(); + assert_eq!(target_limit, 20, "Context limit should be 20"); + + // Call prepare_inference + let result = agent + .prepare_inference(system_prompt, &tools, &messages, &pending, target_limit) + .await?; + + // Get the last message which should be the tool response containing status + let status_message = result.last().unwrap(); + let status_content = status_message + .content + .first() + .and_then(|content| content.as_tool_response_text()) + .unwrap_or_default(); + + // verify that "hello" is within the response, should be just under 20 tokens with "hello" + assert!(status_content.contains("hello")); + + Ok(()) + } } diff --git a/crates/goose/src/providers/anthropic.rs b/crates/goose/src/providers/anthropic.rs index 9648d6592a53..c769eef4f4e5 100644 --- a/crates/goose/src/providers/anthropic.rs +++ b/crates/goose/src/providers/anthropic.rs @@ -8,7 +8,7 @@ use std::time::Duration; use super::base::ProviderUsage; use super::base::{Provider, Usage}; -use super::configs::AnthropicProviderConfig; +use super::configs::{AnthropicProviderConfig, ModelConfig, ProviderModelConfig}; use super::model_pricing::cost; use super::model_pricing::model_pricing_for; use super::utils::get_model; @@ -225,9 +225,9 @@ impl Provider for AnthropicProvider { } let mut payload = json!({ - "model": self.config.model, + "model": self.config.model.model_name, "messages": anthropic_messages, - "max_tokens": self.config.max_tokens.unwrap_or(4096) + "max_tokens": self.config.model.max_tokens.unwrap_or(4096) }); // Add system message if present @@ -247,7 +247,7 @@ impl Provider for AnthropicProvider { } // Add temperature if 
specified - if let Some(temp) = self.config.temperature { + if let Some(temp) = self.config.model.temperature { payload .as_object_mut() .unwrap() @@ -265,10 +265,16 @@ impl Provider for AnthropicProvider { Ok((message, ProviderUsage::new(model, usage, cost))) } + + fn get_model_config(&self) -> &ModelConfig { + self.config.model_config() + } } #[cfg(test)] mod tests { + use crate::providers::configs::ModelConfig; + use super::*; use rust_decimal_macros::dec; use serde_json::json; @@ -287,9 +293,9 @@ mod tests { let config = AnthropicProviderConfig { host: mock_server.uri(), api_key: "test_api_key".to_string(), - model: "claude-3-sonnet-20241022".to_string(), - temperature: Some(0.7), - max_tokens: None, + model: ModelConfig::new("claude-3-sonnet-20241022".to_string()) + .with_temperature(Some(0.7)) + .with_context_limit(Some(200_000)), }; let provider = AnthropicProvider::new(config).unwrap(); diff --git a/crates/goose/src/providers/base.rs b/crates/goose/src/providers/base.rs index 84dbb78d35b5..a6ee06eabbc8 100644 --- a/crates/goose/src/providers/base.rs +++ b/crates/goose/src/providers/base.rs @@ -2,6 +2,7 @@ use anyhow::Result; use rust_decimal::Decimal; use serde::{Deserialize, Serialize}; +use super::configs::ModelConfig; use crate::message::Message; use mcp_core::tool::Tool; @@ -51,6 +52,9 @@ use async_trait::async_trait; /// Base trait for AI providers (OpenAI, Anthropic, etc) #[async_trait] pub trait Provider: Send + Sync { + /// Get the model configuration + fn get_model_config(&self) -> &ModelConfig; + /// Generate the next message using the configured model and other parameters /// /// # Arguments diff --git a/crates/goose/src/providers/configs.rs b/crates/goose/src/providers/configs.rs index 91a827909eda..346892924810 100644 --- a/crates/goose/src/providers/configs.rs +++ b/crates/goose/src/providers/configs.rs @@ -4,6 +4,8 @@ use serde::{Deserialize, Serialize}; const DEFAULT_CLIENT_ID: &str = "databricks-cli"; const DEFAULT_REDIRECT_URL: &str = "http://localhost:8020"; const DEFAULT_SCOPES: &[&str] = &["all-apis"]; +const DEFAULT_CONTEXT_LIMIT: usize = 200_000; +const DEFAULT_ESTIMATE_FACTOR: f32 = 0.8; #[derive(Debug, Clone, Serialize, Deserialize)] pub enum ProviderConfig { @@ -13,6 +15,117 @@ pub enum ProviderConfig { Anthropic(AnthropicProviderConfig), } +/// Configuration for model-specific settings and limits +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelConfig { + /// The name of the model to use + pub model_name: String, + /// Optional explicit context limit that overrides any defaults + pub context_limit: Option, + /// Optional temperature setting (0.0 - 1.0) + pub temperature: Option, + /// Optional maximum tokens to generate + pub max_tokens: Option, + /// Factor used to estimate safe context window size (0.0 - 1.0) + /// Defaults to 0.8 (80%) of the context limit to leave headroom for responses + pub estimate_factor: Option, +} + +impl ModelConfig { + /// Create a new ModelConfig with the specified model name + /// + /// The context limit is set with the following precedence: + /// 1. Explicit context_limit if provided in config + /// 2. Model-specific default based on model name + /// 3. 
Global default (128_000) (in get_context_limit) + pub fn new(model_name: String) -> Self { + let context_limit = Self::get_model_specific_limit(&model_name); + + Self { + model_name, + context_limit, + temperature: None, + max_tokens: None, + estimate_factor: None, + } + } + + /// Get model-specific context limit based on model name + fn get_model_specific_limit(model_name: &str) -> Option { + // Implement some sensible defaults + match model_name { + // OpenAI models, https://platform.openai.com/docs/models#models-overview + name if name.contains("gpt-4o") => Some(128_000), + name if name.contains("gpt-4-turbo") => Some(128_000), + + // Anthropic models, https://docs.anthropic.com/en/docs/about-claude/models + name if name.contains("claude-3") => Some(200_000), + + // Meta Llama models, https://github.com/meta-llama/llama-models/tree/main?tab=readme-ov-file#llama-models-1 + name if name.contains("llama3.2") => Some(128_000), + name if name.contains("llama3.3") => Some(128_000), + _ => None, + } + } + + /// Set an explicit context limit + pub fn with_context_limit(mut self, limit: Option) -> Self { + // Default is None and therefore DEFAULT_CONTEXT_LIMIT, only set + // if input is Some to allow passing through with_context_limit in + // configuration cases + if limit.is_some() { + self.context_limit = limit; + } + self + } + + /// Set the temperature + pub fn with_temperature(mut self, temp: Option) -> Self { + self.temperature = temp; + self + } + + /// Set the max tokens + pub fn with_max_tokens(mut self, tokens: Option) -> Self { + self.max_tokens = tokens; + self + } + + /// Set the estimate factor + pub fn with_estimate_factor(mut self, factor: Option) -> Self { + self.estimate_factor = factor; + self + } + + /// Get the context_limit for the current model + /// If none are defined, use the DEFAULT_CONTEXT_LIMIT + pub fn context_limit(&self) -> usize { + self.context_limit.unwrap_or(DEFAULT_CONTEXT_LIMIT) + } + + /// Get the estimate factor for the current model configuration + /// + /// # Returns + /// The estimate factor with the following precedence: + /// 1. Explicit estimate_factor if provided in config + /// 2. 
Default value (0.8) + pub fn estimate_factor(&self) -> f32 { + self.estimate_factor.unwrap_or(DEFAULT_ESTIMATE_FACTOR) + } + + /// Get the estimated limit of the context size, this is defined as + /// context_limit * estimate_factor + pub fn get_estimated_limit(&self) -> usize { + (self.context_limit() as f32 * self.estimate_factor()) as usize + } +} + +/// Base trait for provider configurations +pub trait ProviderModelConfig { + /// Get the model configuration + fn model_config(&self) -> &ModelConfig; +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub enum DatabricksAuth { Token(String), @@ -39,61 +152,195 @@ impl DatabricksAuth { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DatabricksProviderConfig { pub host: String, - pub model: String, pub auth: DatabricksAuth, - pub temperature: Option, - pub max_tokens: Option, + pub model: ModelConfig, pub image_format: ImageFormat, } impl DatabricksProviderConfig { /// Create a new configuration with token authentication - pub fn with_token(host: String, model: String, token: String) -> Self { + pub fn with_token(host: String, model_name: String, token: String) -> Self { Self { host, - model, auth: DatabricksAuth::Token(token), - temperature: None, - max_tokens: None, + model: ModelConfig::new(model_name), image_format: ImageFormat::Anthropic, } } /// Create a new configuration with OAuth authentication using default settings - pub fn with_oauth(host: String, model: String) -> Self { + pub fn with_oauth(host: String, model_name: String) -> Self { Self { host: host.clone(), - model, auth: DatabricksAuth::oauth(host), - temperature: None, - max_tokens: None, + model: ModelConfig::new(model_name), image_format: ImageFormat::Anthropic, } } } +impl ProviderModelConfig for DatabricksProviderConfig { + fn model_config(&self) -> &ModelConfig { + &self.model + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OpenAiProviderConfig { pub host: String, pub api_key: String, - pub model: String, - pub temperature: Option, - pub max_tokens: Option, + pub model: ModelConfig, +} + +impl OpenAiProviderConfig { + pub fn new(host: String, api_key: String, model_name: String) -> Self { + Self { + host, + api_key, + model: ModelConfig::new(model_name), + } + } +} + +impl ProviderModelConfig for OpenAiProviderConfig { + fn model_config(&self) -> &ModelConfig { + &self.model + } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct OllamaProviderConfig { pub host: String, - pub model: String, - pub temperature: Option, - pub max_tokens: Option, + pub model: ModelConfig, +} + +impl OllamaProviderConfig { + pub fn new(host: String, model_config: ModelConfig) -> Self { + Self { + host, + model: model_config, + } + } +} + +impl ProviderModelConfig for OllamaProviderConfig { + fn model_config(&self) -> &ModelConfig { + &self.model + } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AnthropicProviderConfig { pub host: String, pub api_key: String, - pub model: String, - pub temperature: Option, - pub max_tokens: Option, + pub model: ModelConfig, +} + +impl AnthropicProviderConfig { + pub fn new(host: String, api_key: String, model_name: String) -> Self { + Self { + host, + api_key, + model: ModelConfig::new(model_name), + } + } +} + +impl ProviderModelConfig for AnthropicProviderConfig { + fn model_config(&self) -> &ModelConfig { + &self.model + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_model_config_context_limits() { + // Test explicit limit + let config = + 
ModelConfig::new("claude-3-opus".to_string()).with_context_limit(Some(150_000)); + assert_eq!(config.context_limit(), 150_000); + + // Test model-specific defaults + let config = ModelConfig::new("claude-3-opus".to_string()); + assert_eq!(config.context_limit(), 200_000); + + let config = ModelConfig::new("gpt-4-turbo".to_string()); + assert_eq!(config.context_limit(), 128_000); + + // Test fallback to default + let config = ModelConfig::new("unknown-model".to_string()); + assert_eq!(config.context_limit(), DEFAULT_CONTEXT_LIMIT); + } + + #[test] + fn test_estimate_factor() { + // Test default value + let config = ModelConfig::new("test-model".to_string()); + assert_eq!(config.estimate_factor(), DEFAULT_ESTIMATE_FACTOR); + + // Test explicit value + let config = ModelConfig::new("test-model".to_string()).with_estimate_factor(Some(0.9)); + assert_eq!(config.estimate_factor(), 0.9); + } + + #[test] + fn test_anthropic_config() { + let config = AnthropicProviderConfig::new( + "https://api.anthropic.com".to_string(), + "test-key".to_string(), + "claude-3-opus".to_string(), + ); + + assert_eq!(config.model_config().context_limit(), 200_000); + + let config = AnthropicProviderConfig::new( + "https://api.anthropic.com".to_string(), + "test-key".to_string(), + "claude-3-opus".to_string(), + ); + let model_config = config + .model_config() + .clone() + .with_context_limit(Some(150_000)); + assert_eq!(model_config.context_limit(), 150_000); + } + + #[test] + fn test_openai_config() { + let config = OpenAiProviderConfig::new( + "https://api.openai.com".to_string(), + "test-key".to_string(), + "gpt-4-turbo".to_string(), + ); + + assert_eq!(config.model_config().context_limit(), 128_000); + + let config = OpenAiProviderConfig::new( + "https://api.openai.com".to_string(), + "test-key".to_string(), + "gpt-4-turbo".to_string(), + ); + let model_config = config + .model_config() + .clone() + .with_context_limit(Some(150_000)); + assert_eq!(model_config.context_limit(), 150_000); + } + + #[test] + fn test_model_config_settings() { + let config = ModelConfig::new("test-model".to_string()) + .with_temperature(Some(0.7)) + .with_max_tokens(Some(1000)) + .with_context_limit(Some(50_000)) + .with_estimate_factor(Some(0.9)); + + assert_eq!(config.temperature, Some(0.7)); + assert_eq!(config.max_tokens, Some(1000)); + assert_eq!(config.context_limit, Some(50_000)); + assert_eq!(config.estimate_factor, Some(0.9)); + } } diff --git a/crates/goose/src/providers/databricks.rs b/crates/goose/src/providers/databricks.rs index 0f385e6f02ef..460341d450e5 100644 --- a/crates/goose/src/providers/databricks.rs +++ b/crates/goose/src/providers/databricks.rs @@ -5,7 +5,7 @@ use serde_json::{json, Value}; use std::time::Duration; use super::base::{Provider, ProviderUsage, Usage}; -use super::configs::{DatabricksAuth, DatabricksProviderConfig}; +use super::configs::{DatabricksAuth, DatabricksProviderConfig, ModelConfig, ProviderModelConfig}; use super::model_pricing::{cost, model_pricing_for}; use super::oauth; use super::utils::{ @@ -76,7 +76,7 @@ impl DatabricksProvider { let url = format!( "{}/serving-endpoints/{}/invocations", self.config.host.trim_end_matches('/'), - self.config.model + self.config.model.model_name ); let auth_header = self.ensure_auth_header().await?; @@ -129,10 +129,10 @@ impl Provider for DatabricksProvider { if !tools_spec.is_empty() { payload["tools"] = json!(tools_spec); } - if let Some(temp) = self.config.temperature { + if let Some(temp) = self.config.model.temperature { payload["temperature"] = 
json!(temp); } - if let Some(tokens) = self.config.max_tokens { + if let Some(tokens) = self.config.model.max_tokens { payload["max_tokens"] = json!(tokens); } @@ -168,12 +168,17 @@ impl Provider for DatabricksProvider { Ok((message, ProviderUsage::new(model, usage, cost))) } + + fn get_model_config(&self) -> &ModelConfig { + self.config.model_config() + } } #[cfg(test)] mod tests { use super::*; use crate::message::MessageContent; + use crate::providers::configs::ModelConfig; use wiremock::matchers::{body_json, header, method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; @@ -219,10 +224,8 @@ mod tests { // Create the DatabricksProvider with the mock server's URL as the host let config = DatabricksProviderConfig { host: mock_server.uri(), - model: "my-databricks-model".to_string(), auth: DatabricksAuth::Token("test_token".to_string()), - temperature: None, - max_tokens: None, + model: ModelConfig::new("my-databricks-model".to_string()), image_format: crate::providers::utils::ImageFormat::Anthropic, }; diff --git a/crates/goose/src/providers/mock.rs b/crates/goose/src/providers/mock.rs index aedc3d67648c..830c20601a82 100644 --- a/crates/goose/src/providers/mock.rs +++ b/crates/goose/src/providers/mock.rs @@ -6,6 +6,7 @@ use std::sync::Mutex; use crate::message::Message; use crate::providers::base::{Provider, Usage}; +use crate::providers::configs::ModelConfig; use mcp_core::tool::Tool; use super::base::ProviderUsage; @@ -13,6 +14,7 @@ use super::base::ProviderUsage; /// A mock provider that returns pre-configured responses for testing pub struct MockProvider { responses: Arc>>, + model_config: ModelConfig, } impl MockProvider { @@ -20,12 +22,25 @@ impl MockProvider { pub fn new(responses: Vec) -> Self { Self { responses: Arc::new(Mutex::new(responses)), + model_config: ModelConfig::new("mock-model".to_string()), + } + } + + /// Create a new mock provider with specific responses and model config + pub fn with_config(responses: Vec, model_config: ModelConfig) -> Self { + Self { + responses: Arc::new(Mutex::new(responses)), + model_config, } } } #[async_trait] impl Provider for MockProvider { + fn get_model_config(&self) -> &ModelConfig { + &self.model_config + } + async fn complete( &self, _system_prompt: &str, diff --git a/crates/goose/src/providers/ollama.rs b/crates/goose/src/providers/ollama.rs index 36af2f6dd947..feee301cf16f 100644 --- a/crates/goose/src/providers/ollama.rs +++ b/crates/goose/src/providers/ollama.rs @@ -1,5 +1,5 @@ use super::base::{Provider, ProviderUsage, Usage}; -use super::configs::OllamaProviderConfig; +use super::configs::{ModelConfig, OllamaProviderConfig, ProviderModelConfig}; use super::utils::{ get_model, messages_to_openai_spec, openai_response_to_message, tools_to_openai_spec, ImageFormat, @@ -99,7 +99,7 @@ impl Provider for OllamaProvider { messages_array.extend(messages_spec); let mut payload = json!({ - "model": self.config.model, + "model": self.config.model.model_name, "messages": messages_array }); @@ -109,13 +109,13 @@ impl Provider for OllamaProvider { .unwrap() .insert("tools".to_string(), json!(tools_spec)); } - if let Some(temp) = self.config.temperature { + if let Some(temp) = self.config.model.temperature { payload .as_object_mut() .unwrap() .insert("temperature".to_string(), json!(temp)); } - if let Some(tokens) = self.config.max_tokens { + if let Some(tokens) = self.config.model.max_tokens { payload .as_object_mut() .unwrap() @@ -132,6 +132,10 @@ impl Provider for OllamaProvider { Ok((message, ProviderUsage::new(model, usage, 
cost))) } + + fn get_model_config(&self) -> &ModelConfig { + self.config.model_config() + } } #[cfg(test)] @@ -153,9 +157,7 @@ mod tests { // Create the OllamaProvider with the mock server's URL as the host let config = OllamaProviderConfig { host: mock_server.uri(), - model: OLLAMA_MODEL.to_string(), - temperature: None, - max_tokens: None, + model: ModelConfig::new(OLLAMA_MODEL.to_string()), }; let provider = OllamaProvider::new(config).unwrap(); @@ -289,9 +291,7 @@ mod tests { let config = OllamaProviderConfig { host: mock_server.uri(), - model: OLLAMA_MODEL.to_string(), - temperature: None, - max_tokens: None, + model: ModelConfig::new(OLLAMA_MODEL.to_string()), }; let provider = OllamaProvider::new(config)?; diff --git a/crates/goose/src/providers/openai.rs b/crates/goose/src/providers/openai.rs index 7816f12cd6dd..8b6a2748c9b7 100644 --- a/crates/goose/src/providers/openai.rs +++ b/crates/goose/src/providers/openai.rs @@ -8,6 +8,7 @@ use std::time::Duration; use super::base::ProviderUsage; use super::base::{Provider, Usage}; use super::configs::OpenAiProviderConfig; +use super::configs::{ModelConfig, ProviderModelConfig}; use super::model_pricing::cost; use super::model_pricing::model_pricing_for; use super::utils::get_model; @@ -116,7 +117,7 @@ impl Provider for OpenAiProvider { messages_array.extend(messages_spec); let mut payload = json!({ - "model": self.config.model, + "model": self.config.model.model_name, "messages": messages_array }); @@ -127,13 +128,13 @@ impl Provider for OpenAiProvider { .unwrap() .insert("tools".to_string(), json!(tools_spec)); } - if let Some(temp) = self.config.temperature { + if let Some(temp) = self.config.model.temperature { payload .as_object_mut() .unwrap() .insert("temperature".to_string(), json!(temp)); } - if let Some(tokens) = self.config.max_tokens { + if let Some(tokens) = self.config.model.max_tokens { payload .as_object_mut() .unwrap() @@ -159,12 +160,17 @@ impl Provider for OpenAiProvider { Ok((message, ProviderUsage::new(model, usage, cost))) } + + fn get_model_config(&self) -> &ModelConfig { + self.config.model_config() + } } #[cfg(test)] mod tests { use super::*; use crate::message::MessageContent; + use crate::providers::configs::ModelConfig; use rust_decimal_macros::dec; use serde_json::json; use wiremock::matchers::{method, path}; @@ -182,9 +188,7 @@ mod tests { let config = OpenAiProviderConfig { host: mock_server.uri(), api_key: "test_api_key".to_string(), - model: "gpt-3.5-turbo".to_string(), - temperature: Some(0.7), - max_tokens: None, + model: ModelConfig::new("gpt-3.5-turbo".to_string()).with_temperature(Some(0.7)), }; let provider = OpenAiProvider::new(config).unwrap(); diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index 42f0e500a2cc..1f2550b33ddd 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -7,7 +7,9 @@ use mcp_core::tool::Tool; use goose::providers::{ base::Provider, - configs::{DatabricksAuth, DatabricksProviderConfig, OpenAiProviderConfig, ProviderConfig}, + configs::{ + DatabricksAuth, DatabricksProviderConfig, ModelConfig, OpenAiProviderConfig, ProviderConfig, + }, factory::get_provider, }; @@ -113,9 +115,7 @@ async fn test_openai_provider() -> Result<()> { let config = ProviderConfig::OpenAi(OpenAiProviderConfig { host: "https://api.openai.com".to_string(), api_key: std::env::var("OPENAI_API_KEY")?, - model: std::env::var("OPENAI_MODEL")?, - temperature: None, - max_tokens: None, + model: ModelConfig::new(std::env::var("OPENAI_MODEL")?), }); let 
tester = ProviderTester::new(config)?; @@ -139,10 +139,8 @@ async fn test_databricks_provider() -> Result<()> { let config = ProviderConfig::Databricks(DatabricksProviderConfig { host: std::env::var("DATABRICKS_HOST")?, - model: std::env::var("DATABRICKS_MODEL")?, + model: ModelConfig::new(std::env::var("DATABRICKS_MODEL")?), auth: DatabricksAuth::Token(std::env::var("DATABRICKS_TOKEN")?), - temperature: None, - max_tokens: None, image_format: goose::providers::utils::ImageFormat::Anthropic, }); @@ -164,10 +162,8 @@ async fn test_databricks_provider_oauth() -> Result<()> { let config = ProviderConfig::Databricks(DatabricksProviderConfig { host: std::env::var("DATABRICKS_HOST")?, - model: std::env::var("DATABRICKS_MODEL")?, + model: ModelConfig::new(std::env::var("DATABRICKS_MODEL")?), auth: DatabricksAuth::oauth(std::env::var("DATABRICKS_HOST")?), - temperature: None, - max_tokens: None, image_format: goose::providers::utils::ImageFormat::Anthropic, }); @@ -190,9 +186,9 @@ async fn test_ollama_provider() -> Result<()> { let config = ProviderConfig::Ollama(OllamaProviderConfig { host: std::env::var("OLLAMA_HOST").unwrap_or_else(|_| String::from(OLLAMA_HOST)), - model: std::env::var("OLLAMA_MODEL").unwrap_or_else(|_| String::from(OLLAMA_MODEL)), - temperature: None, - max_tokens: None, + model: ModelConfig::new( + std::env::var("OLLAMA_MODEL").unwrap_or_else(|_| String::from(OLLAMA_MODEL)), + ), }); let tester = ProviderTester::new(config)?;
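
Reviewer note: a minimal usage sketch of the `ModelConfig` API this patch introduces, assuming the `goose::providers::configs::ModelConfig` path and the builder methods exactly as added above; the model names below are illustrative only.

```rust
use goose::providers::configs::ModelConfig;

fn main() {
    // Model-specific default: any "claude-3*" model name resolves to a 200_000-token limit.
    let claude = ModelConfig::new("claude-3-5-sonnet".to_string());
    assert_eq!(claude.context_limit(), 200_000);

    // Unknown models fall back to DEFAULT_CONTEXT_LIMIT (200_000 in this patch),
    // and estimate_factor() defaults to 0.8 to leave headroom for responses.
    let unknown = ModelConfig::new("some-local-model".to_string());
    assert_eq!(unknown.context_limit(), 200_000);
    assert_eq!(unknown.get_estimated_limit(), 160_000); // 200_000 * 0.8

    // Explicit values win over the model-name defaults; passing None through the
    // builder leaves defaults untouched, which is how the CLI profile and the
    // goose-server settings forward their optional fields.
    let tuned = ModelConfig::new("gpt-4o".to_string()) // would otherwise default to 128_000
        .with_temperature(Some(0.7))
        .with_max_tokens(Some(1_000))
        .with_context_limit(Some(50_000))
        .with_estimate_factor(Some(0.9));
    assert_eq!(tuned.context_limit(), 50_000);
    assert_eq!(tuned.get_estimated_limit(), 45_000); // 50_000 * 0.9
}
```

The resulting precedence is: explicit `with_context_limit` value, then the model-name lookup in `ModelConfig::new`, then `DEFAULT_CONTEXT_LIMIT`. With this in place, `Agent::prepare_inference` trims context against `provider.get_model_config().get_estimated_limit()` rather than the old hard-coded `ESTIMATED_TOKEN_LIMIT` constant.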