Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 22 additions & 94 deletions crates/goose/src/config/extensions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub const DEFAULT_EXTENSION: &str = "developer";
pub const DEFAULT_EXTENSION_TIMEOUT: u64 = 300;
pub const DEFAULT_EXTENSION_DESCRIPTION: &str = "";
pub const DEFAULT_DISPLAY_NAME: &str = "Developer";
const EXTENSIONS_CONFIG_KEY: &str = "extensions";

#[derive(Debug, Deserialize, Serialize, Clone, ToSchema)]
pub struct ExtensionEntry {
Expand All @@ -24,137 +25,64 @@ pub fn name_to_key(name: &str) -> String {
.to_lowercase()
}

/// Extension configuration management
pub struct ExtensionConfigManager;

impl ExtensionConfigManager {
/// Get the extension configuration if enabled -- uses key
pub fn get_config(key: &str) -> Result<Option<ExtensionConfig>> {
fn get_extensions_map() -> Result<HashMap<String, ExtensionEntry>> {
let config = Config::global();

// Try to get the extension entry
let extensions: HashMap<String, ExtensionEntry> = match config.get_param("extensions") {
Ok(exts) => exts,
Err(super::ConfigError::NotFound(_)) => {
// Initialize with default developer extension
let defaults = HashMap::from([(
name_to_key(DEFAULT_EXTENSION), // Use key format for top-level key in config
ExtensionEntry {
enabled: true,
config: ExtensionConfig::Builtin {
name: DEFAULT_EXTENSION.to_string(),
display_name: Some(DEFAULT_DISPLAY_NAME.to_string()),
timeout: Some(DEFAULT_EXTENSION_TIMEOUT),
bundled: Some(true),
description: Some(DEFAULT_EXTENSION_DESCRIPTION.to_string()),
available_tools: Vec::new(),
},
},
)]);
config.set_param("extensions", serde_json::to_value(&defaults)?)?;
defaults
}
Err(e) => return Err(e.into()),
};

Ok(extensions.get(key).and_then(|entry| {
if entry.enabled {
Some(entry.config.clone())
} else {
None
}
}))
Ok(config
.get_param(EXTENSIONS_CONFIG_KEY)
.unwrap_or_else(|_| HashMap::new()))
}

pub fn get_config_by_name(name: &str) -> Result<Option<ExtensionConfig>> {
fn save_extensions_map(extensions: HashMap<String, ExtensionEntry>) -> Result<()> {
let config = Config::global();
config.set_param(EXTENSIONS_CONFIG_KEY, serde_json::to_value(extensions)?)?;
Ok(())
}

// Try to get the extension entry
let extensions: HashMap<String, ExtensionEntry> = match config.get_param("extensions") {
Ok(exts) => exts,
Err(super::ConfigError::NotFound(_)) => HashMap::new(),
Err(_) => HashMap::new(),
};

pub fn get_config_by_name(name: &str) -> Result<Option<ExtensionConfig>> {
let extensions = Self::get_extensions_map()?;
Ok(extensions
.values()
.find(|entry| entry.config.name() == name)
.map(|entry| entry.config.clone()))
}

/// Set or update an extension configuration
pub fn set(entry: ExtensionEntry) -> Result<()> {
let config = Config::global();

let mut extensions: HashMap<String, ExtensionEntry> = config
.get_param("extensions")
.unwrap_or_else(|_| HashMap::new());

let mut extensions = Self::get_extensions_map()?;
let key = entry.config.key();

extensions.insert(key, entry);
config.set_param("extensions", serde_json::to_value(extensions)?)?;
Ok(())
Self::save_extensions_map(extensions)
}

/// Remove an extension configuration -- uses the key
pub fn remove(key: &str) -> Result<()> {
let config = Config::global();

let mut extensions: HashMap<String, ExtensionEntry> = config
.get_param("extensions")
.unwrap_or_else(|_| HashMap::new());

let mut extensions = Self::get_extensions_map()?;
extensions.remove(key);
config.set_param("extensions", serde_json::to_value(extensions)?)?;
Ok(())
Self::save_extensions_map(extensions)
}

/// Enable or disable an extension -- uses key
pub fn set_enabled(key: &str, enabled: bool) -> Result<()> {
let config = Config::global();

let mut extensions: HashMap<String, ExtensionEntry> = config
.get_param("extensions")
.unwrap_or_else(|_| HashMap::new());

let mut extensions = Self::get_extensions_map()?;
if let Some(entry) = extensions.get_mut(key) {
entry.enabled = enabled;
config.set_param("extensions", serde_json::to_value(extensions)?)?;
Self::save_extensions_map(extensions)?;
}
Ok(())
}

/// Get all extensions and their configurations
pub fn get_all() -> Result<Vec<ExtensionEntry>> {
let config = Config::global();
let extensions: HashMap<String, ExtensionEntry> = match config.get_param("extensions") {
Ok(exts) => exts,
Err(super::ConfigError::NotFound(_)) => HashMap::new(),
Err(e) => return Err(e.into()),
};
Ok(Vec::from_iter(extensions.values().cloned()))
let extensions = Self::get_extensions_map()?;
Ok(extensions.into_values().collect())
}

/// Get all extension names
pub fn get_all_names() -> Result<Vec<String>> {
let config = Config::global();
Ok(config
.get_param("extensions")
.unwrap_or_else(|_| get_keys(Default::default())))
let extensions = Self::get_extensions_map()?;
Ok(extensions.keys().cloned().collect())
}

/// Check if an extension is enabled - FIXED to use key
pub fn is_enabled(key: &str) -> Result<bool> {
let config = Config::global();
let extensions: HashMap<String, ExtensionEntry> = config
.get_param("extensions")
.unwrap_or_else(|_| HashMap::new());

let extensions = Self::get_extensions_map()?;
Ok(extensions.get(key).map(|e| e.enabled).unwrap_or(false))
}
}

fn get_keys(entries: HashMap<String, ExtensionEntry>) -> Vec<String> {
entries.into_keys().collect()
}
Loading