diff --git a/crates/goose/src/config/extensions.rs b/crates/goose/src/config/extensions.rs index b03415752306..3019f81ca537 100644 --- a/crates/goose/src/config/extensions.rs +++ b/crates/goose/src/config/extensions.rs @@ -9,6 +9,7 @@ pub const DEFAULT_EXTENSION: &str = "developer"; pub const DEFAULT_EXTENSION_TIMEOUT: u64 = 300; pub const DEFAULT_EXTENSION_DESCRIPTION: &str = ""; pub const DEFAULT_DISPLAY_NAME: &str = "Developer"; +const EXTENSIONS_CONFIG_KEY: &str = "extensions"; #[derive(Debug, Deserialize, Serialize, Clone, ToSchema)] pub struct ExtensionEntry { @@ -24,137 +25,64 @@ pub fn name_to_key(name: &str) -> String { .to_lowercase() } -/// Extension configuration management pub struct ExtensionConfigManager; impl ExtensionConfigManager { - /// Get the extension configuration if enabled -- uses key - pub fn get_config(key: &str) -> Result> { + fn get_extensions_map() -> Result> { let config = Config::global(); - - // Try to get the extension entry - let extensions: HashMap = match config.get_param("extensions") { - Ok(exts) => exts, - Err(super::ConfigError::NotFound(_)) => { - // Initialize with default developer extension - let defaults = HashMap::from([( - name_to_key(DEFAULT_EXTENSION), // Use key format for top-level key in config - ExtensionEntry { - enabled: true, - config: ExtensionConfig::Builtin { - name: DEFAULT_EXTENSION.to_string(), - display_name: Some(DEFAULT_DISPLAY_NAME.to_string()), - timeout: Some(DEFAULT_EXTENSION_TIMEOUT), - bundled: Some(true), - description: Some(DEFAULT_EXTENSION_DESCRIPTION.to_string()), - available_tools: Vec::new(), - }, - }, - )]); - config.set_param("extensions", serde_json::to_value(&defaults)?)?; - defaults - } - Err(e) => return Err(e.into()), - }; - - Ok(extensions.get(key).and_then(|entry| { - if entry.enabled { - Some(entry.config.clone()) - } else { - None - } - })) + Ok(config + .get_param(EXTENSIONS_CONFIG_KEY) + .unwrap_or_else(|_| HashMap::new())) } - pub fn get_config_by_name(name: &str) -> Result> { + fn save_extensions_map(extensions: HashMap) -> Result<()> { let config = Config::global(); + config.set_param(EXTENSIONS_CONFIG_KEY, serde_json::to_value(extensions)?)?; + Ok(()) + } - // Try to get the extension entry - let extensions: HashMap = match config.get_param("extensions") { - Ok(exts) => exts, - Err(super::ConfigError::NotFound(_)) => HashMap::new(), - Err(_) => HashMap::new(), - }; - + pub fn get_config_by_name(name: &str) -> Result> { + let extensions = Self::get_extensions_map()?; Ok(extensions .values() .find(|entry| entry.config.name() == name) .map(|entry| entry.config.clone())) } - /// Set or update an extension configuration pub fn set(entry: ExtensionEntry) -> Result<()> { - let config = Config::global(); - - let mut extensions: HashMap = config - .get_param("extensions") - .unwrap_or_else(|_| HashMap::new()); - + let mut extensions = Self::get_extensions_map()?; let key = entry.config.key(); - extensions.insert(key, entry); - config.set_param("extensions", serde_json::to_value(extensions)?)?; - Ok(()) + Self::save_extensions_map(extensions) } - /// Remove an extension configuration -- uses the key pub fn remove(key: &str) -> Result<()> { - let config = Config::global(); - - let mut extensions: HashMap = config - .get_param("extensions") - .unwrap_or_else(|_| HashMap::new()); - + let mut extensions = Self::get_extensions_map()?; extensions.remove(key); - config.set_param("extensions", serde_json::to_value(extensions)?)?; - Ok(()) + Self::save_extensions_map(extensions) } - /// Enable or disable an extension -- uses key pub fn set_enabled(key: &str, enabled: bool) -> Result<()> { - let config = Config::global(); - - let mut extensions: HashMap = config - .get_param("extensions") - .unwrap_or_else(|_| HashMap::new()); - + let mut extensions = Self::get_extensions_map()?; if let Some(entry) = extensions.get_mut(key) { entry.enabled = enabled; - config.set_param("extensions", serde_json::to_value(extensions)?)?; + Self::save_extensions_map(extensions)?; } Ok(()) } - /// Get all extensions and their configurations pub fn get_all() -> Result> { - let config = Config::global(); - let extensions: HashMap = match config.get_param("extensions") { - Ok(exts) => exts, - Err(super::ConfigError::NotFound(_)) => HashMap::new(), - Err(e) => return Err(e.into()), - }; - Ok(Vec::from_iter(extensions.values().cloned())) + let extensions = Self::get_extensions_map()?; + Ok(extensions.into_values().collect()) } - /// Get all extension names pub fn get_all_names() -> Result> { - let config = Config::global(); - Ok(config - .get_param("extensions") - .unwrap_or_else(|_| get_keys(Default::default()))) + let extensions = Self::get_extensions_map()?; + Ok(extensions.keys().cloned().collect()) } - /// Check if an extension is enabled - FIXED to use key pub fn is_enabled(key: &str) -> Result { - let config = Config::global(); - let extensions: HashMap = config - .get_param("extensions") - .unwrap_or_else(|_| HashMap::new()); - + let extensions = Self::get_extensions_map()?; Ok(extensions.get(key).map(|e| e.enabled).unwrap_or(false)) } } - -fn get_keys(entries: HashMap) -> Vec { - entries.into_keys().collect() -}