Add GOOSE_MAX_TOKENS support to ModelConfig.

The constructor hunk now fills `max_tokens` from a new `parse_max_tokens()`
helper instead of hard-coding `None`. The helper returns `Ok(None)` when the
parameter is not found, `ConfigError::InvalidRange` for values <= 0, and
`ConfigError::InvalidValue` for any other lookup error. Unit tests cover each
of those paths plus `ModelConfig::new` with and without the variable set.

NOTE(review): the generic arguments (`Result<Self, ConfigError>`,
`Result<Option<i32>, ConfigError>`, `get_param::<i32>`,
`Result<bool, ConfigError>`) and the leading whitespace were reconstructed
from a copy of this patch that had lost them; confirm against the repository
before applying.

diff --git a/crates/goose/src/model.rs b/crates/goose/src/model.rs
index 9a9f8eb546dc..452ce71ae297 100644
--- a/crates/goose/src/model.rs
+++ b/crates/goose/src/model.rs
@@ -96,6 +96,7 @@ impl ModelConfig {
     ) -> Result<Self, ConfigError> {
         let context_limit = Self::parse_context_limit(&model_name, None, context_env_var)?;
         let temperature = Self::parse_temperature()?;
+        let max_tokens = Self::parse_max_tokens()?;
         let toolshim = Self::parse_toolshim()?;
         let toolshim_model = Self::parse_toolshim_model()?;
 
@@ -103,7 +104,7 @@ impl ModelConfig {
             model_name,
             context_limit,
             temperature,
-            max_tokens: None,
+            max_tokens,
             toolshim,
             toolshim_model,
             fast_model: None,
@@ -184,6 +185,26 @@ impl ModelConfig {
         }
     }
 
+    fn parse_max_tokens() -> Result<Option<i32>, ConfigError> {
+        match crate::config::Config::global().get_param::<i32>("GOOSE_MAX_TOKENS") {
+            Ok(tokens) => {
+                if tokens <= 0 {
+                    return Err(ConfigError::InvalidRange(
+                        "goose_max_tokens".to_string(),
+                        "must be greater than 0".to_string(),
+                    ));
+                }
+                Ok(Some(tokens))
+            }
+            Err(crate::config::ConfigError::NotFound(_)) => Ok(None),
+            Err(e) => Err(ConfigError::InvalidValue(
+                "goose_max_tokens".to_string(),
+                String::new(),
+                e.to_string(),
+            )),
+        }
+    }
+
     fn parse_toolshim() -> Result<bool, ConfigError> {
         if let Ok(val) = std::env::var("GOOSE_TOOLSHIM") {
             match val.to_lowercase().as_str() {
@@ -296,3 +317,72 @@ impl ModelConfig {
             .unwrap_or_else(|_| panic!("Failed to create model config for {}", model_name))
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_parse_max_tokens_valid() {
+        let _guard = env_lock::lock_env([("GOOSE_MAX_TOKENS", Some("4096"))]);
+        let result = ModelConfig::parse_max_tokens().unwrap();
+        assert_eq!(result, Some(4096));
+    }
+
+    #[test]
+    fn test_parse_max_tokens_not_set() {
+        let _guard = env_lock::lock_env([("GOOSE_MAX_TOKENS", None::<&str>)]);
+        let result = ModelConfig::parse_max_tokens().unwrap();
+        assert_eq!(result, None);
+    }
+
+    #[test]
+    fn test_parse_max_tokens_invalid_string() {
+        let _guard = env_lock::lock_env([("GOOSE_MAX_TOKENS", Some("not_a_number"))]);
+        let result = ModelConfig::parse_max_tokens();
+        assert!(result.is_err());
+        assert!(matches!(result.unwrap_err(), ConfigError::InvalidValue(..)));
+    }
+
+    #[test]
+    fn test_parse_max_tokens_zero() {
+        let _guard = env_lock::lock_env([("GOOSE_MAX_TOKENS", Some("0"))]);
+        let result = ModelConfig::parse_max_tokens();
+        assert!(result.is_err());
+        assert!(matches!(result.unwrap_err(), ConfigError::InvalidRange(..)));
+    }
+
+    #[test]
+    fn test_parse_max_tokens_negative() {
+        let _guard = env_lock::lock_env([("GOOSE_MAX_TOKENS", Some("-100"))]);
+        let result = ModelConfig::parse_max_tokens();
+        assert!(result.is_err());
+        assert!(matches!(result.unwrap_err(), ConfigError::InvalidRange(..)));
+    }
+
+    #[test]
+    fn test_model_config_with_max_tokens_env() {
+        let _guard = env_lock::lock_env([
+            ("GOOSE_MAX_TOKENS", Some("8192")),
+            ("GOOSE_TEMPERATURE", None::<&str>),
+            ("GOOSE_CONTEXT_LIMIT", None::<&str>),
+            ("GOOSE_TOOLSHIM", None::<&str>),
+            ("GOOSE_TOOLSHIM_OLLAMA_MODEL", None::<&str>),
+        ]);
+        let config = ModelConfig::new("test-model").unwrap();
+        assert_eq!(config.max_tokens, Some(8192));
+    }
+
+    #[test]
+    fn test_model_config_without_max_tokens_env() {
+        let _guard = env_lock::lock_env([
+            ("GOOSE_MAX_TOKENS", None::<&str>),
+            ("GOOSE_TEMPERATURE", None::<&str>),
+            ("GOOSE_CONTEXT_LIMIT", None::<&str>),
+            ("GOOSE_TOOLSHIM", None::<&str>),
+            ("GOOSE_TOOLSHIM_OLLAMA_MODEL", None::<&str>),
+        ]);
+        let config = ModelConfig::new("test-model").unwrap();
+        assert_eq!(config.max_tokens, None);
+    }
+}