diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index e9c8ce1c3ae8..62f1647d954b 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -257,6 +257,165 @@ fn extract_auth_error( } } +/// Merge environment variables from direct envs and keychain-stored env_keys +async fn merge_environments( + envs: &Envs, + env_keys: &[String], + ext_name: &str, +) -> Result, ExtensionError> { + let mut all_envs = envs.get_env(); + let config_instance = Config::global(); + + for key in env_keys { + if all_envs.contains_key(key) { + continue; + } + + match config_instance.get(key, true) { + Ok(value) => { + if value.is_null() { + warn!( + key = %key, + ext_name = %ext_name, + "Secret key not found in config (returned null)." + ); + continue; + } + + if let Some(str_val) = value.as_str() { + all_envs.insert(key.clone(), str_val.to_string()); + } else { + warn!( + key = %key, + ext_name = %ext_name, + value_type = %value.get("type").and_then(|t| t.as_str()).unwrap_or("unknown"), + "Secret value is not a string; skipping." + ); + } + } + Err(e) => { + error!( + key = %key, + ext_name = %ext_name, + error = %e, + "Failed to fetch secret from config." + ); + return Err(ExtensionError::ConfigError(format!( + "Failed to fetch secret '{}' from config: {}", + key, e + ))); + } + } + } + + Ok(all_envs) +} + +/// Substitute environment variables in a string. Supports both ${VAR} and $VAR syntax. +fn substitute_env_vars(value: &str, env_map: &HashMap) -> String { + let mut result = value.to_string(); + + let re_braces = + regex::Regex::new(r"\$\{\s*([A-Za-z_][A-Za-z0-9_]*)\s*\}").expect("valid regex"); + for cap in re_braces.captures_iter(value) { + if let Some(var_name) = cap.get(1) { + if let Some(env_value) = env_map.get(var_name.as_str()) { + result = result.replace(&cap[0], env_value); + } + } + } + + let re_simple = regex::Regex::new(r"\$([A-Za-z_][A-Za-z0-9_]*)").expect("valid regex"); + for cap in re_simple.captures_iter(&result.clone()) { + if let Some(var_name) = cap.get(1) { + if !value.contains(&format!("${{{}}}", var_name.as_str())) { + if let Some(env_value) = env_map.get(var_name.as_str()) { + result = result.replace(&cap[0], env_value); + } + } + } + } + + result +} + +async fn create_streamable_http_client( + uri: &str, + timeout: Option, + headers: &HashMap, + name: &str, + all_envs: &HashMap, + provider: SharedProvider, +) -> ExtensionResult> { + let mut default_headers = HeaderMap::new(); + for (key, value) in headers { + let substituted_value = substitute_env_vars(value, all_envs); + default_headers.insert( + HeaderName::try_from(key) + .map_err(|_| ExtensionError::ConfigError(format!("invalid header: {}", key)))?, + substituted_value.parse().map_err(|_| { + ExtensionError::ConfigError(format!("invalid header value: {}", key)) + })?, + ); + } + + let http_client = reqwest::Client::builder() + .default_headers(default_headers) + .build() + .map_err(|_| ExtensionError::ConfigError("could not construct http client".to_string()))?; + + let transport = StreamableHttpClientTransport::with_client( + http_client, + StreamableHttpClientTransportConfig { + uri: uri.into(), + ..Default::default() + }, + ); + + let timeout_duration = + Duration::from_secs(timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT)); + + let client_res = McpClient::connect(transport, timeout_duration, provider.clone()).await; + + if extract_auth_error(&client_res).is_some() { + let am = oauth_flow(&uri.to_string(), &name.to_string()) + .await + .map_err(|_| ExtensionError::SetupError("auth error".to_string()))?; + let auth_client = AuthClient::new(reqwest::Client::default(), am); + let transport = StreamableHttpClientTransport::with_client( + auth_client, + StreamableHttpClientTransportConfig { + uri: uri.into(), + ..Default::default() + }, + ); + Ok(Box::new( + McpClient::connect(transport, timeout_duration, provider).await?, + )) + } else { + Ok(Box::new(client_res?)) + } +} + +async fn create_stdio_client( + cmd: &str, + args: &[String], + all_envs: HashMap, + timeout: &Option, + provider: SharedProvider, +) -> ExtensionResult> { + extension_malware_check::deny_if_malicious_cmd_args(cmd, args).await?; + + let resolved_cmd = resolve_command(cmd); + let command = Command::new(resolved_cmd).configure(|command| { + command.args(args).envs(all_envs); + }); + + Ok(Box::new( + child_process_client(command, timeout, provider).await?, + )) +} + impl ExtensionManager { pub fn new(provider: SharedProvider) -> Self { Self { @@ -295,63 +454,6 @@ impl ExtensionManager { let sanitized_name = normalize(config_name.clone()); let mut temp_dir = None; - /// Helper function to merge environment variables from direct envs and keychain-stored env_keys - async fn merge_environments( - envs: &Envs, - env_keys: &[String], - ext_name: &str, - ) -> Result, ExtensionError> { - let mut all_envs = envs.get_env(); - let config_instance = Config::global(); - - for key in env_keys { - // If the Envs payload already contains the key, prefer that value - // over looking into the keychain/secret store - if all_envs.contains_key(key) { - continue; - } - - match config_instance.get(key, true) { - Ok(value) => { - if value.is_null() { - warn!( - key = %key, - ext_name = %ext_name, - "Secret key not found in config (returned null)." - ); - continue; - } - - // Try to get string value - if let Some(str_val) = value.as_str() { - all_envs.insert(key.clone(), str_val.to_string()); - } else { - warn!( - key = %key, - ext_name = %ext_name, - value_type = %value.get("type").and_then(|t| t.as_str()).unwrap_or("unknown"), - "Secret value is not a string; skipping." - ); - } - } - Err(e) => { - error!( - key = %key, - ext_name = %ext_name, - error = %e, - "Failed to fetch secret from config." - ); - return Err(ExtensionError::ConfigError(format!( - "Failed to fetch secret '{}' from config: {}", - key, e - ))); - } - } - } - - Ok(all_envs) - } - let client: Box = match &config { ExtensionConfig::Sse { uri, timeout, .. } => { let transport = SseClientTransport::start(uri.to_string()).await.map_err( @@ -382,101 +484,16 @@ impl ExtensionManager { env_keys, .. } => { - // Merge environment variables from direct envs and keychain-stored env_keys let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; - - // Helper function to substitute environment variables in a string - // Supports both ${VAR} and $VAR syntax - fn substitute_env_vars(value: &str, env_map: &HashMap) -> String { - let mut result = value.to_string(); - - // First handle ${VAR} syntax (with optional whitespace) - let re_braces = regex::Regex::new(r"\$\{\s*([A-Za-z_][A-Za-z0-9_]*)\s*\}") - .expect("valid regex"); - for cap in re_braces.captures_iter(value) { - if let Some(var_name) = cap.get(1) { - if let Some(env_value) = env_map.get(var_name.as_str()) { - result = result.replace(&cap[0], env_value); - } - } - } - - // Then handle $VAR syntax (simple variable without braces) - let re_simple = - regex::Regex::new(r"\$([A-Za-z_][A-Za-z0-9_]*)").expect("valid regex"); - for cap in re_simple.captures_iter(&result.clone()) { - if let Some(var_name) = cap.get(1) { - // Only substitute if it wasn't already part of ${VAR} syntax - if !value.contains(&format!("${{{}}}", var_name.as_str())) { - if let Some(env_value) = env_map.get(var_name.as_str()) { - result = result.replace(&cap[0], env_value); - } - } - } - } - - result - } - - let mut default_headers = HeaderMap::new(); - for (key, value) in headers { - // Substitute environment variables in header values - let substituted_value = substitute_env_vars(value, &all_envs); - - default_headers.insert( - HeaderName::try_from(key).map_err(|_| { - ExtensionError::ConfigError(format!("invalid header: {}", key)) - })?, - substituted_value.parse().map_err(|_| { - ExtensionError::ConfigError(format!("invalid header value: {}", key)) - })?, - ); - } - let client = reqwest::Client::builder() - .default_headers(default_headers) - .build() - .map_err(|_| { - ExtensionError::ConfigError("could not construct http client".to_string()) - })?; - let transport = StreamableHttpClientTransport::with_client( - client, - StreamableHttpClientTransportConfig { - uri: uri.clone().into(), - ..Default::default() - }, - ); - let client_res = McpClient::connect( - transport, - Duration::from_secs( - timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), - ), + create_streamable_http_client( + uri, + *timeout, + headers, + name, + &all_envs, self.provider.clone(), ) - .await; - let client = if let Some(_auth_error) = extract_auth_error(&client_res) { - let am = oauth_flow(uri, name) - .await - .map_err(|_| ExtensionError::SetupError("auth error".to_string()))?; - let client = AuthClient::new(reqwest::Client::default(), am); - let transport = StreamableHttpClientTransport::with_client( - client, - StreamableHttpClientTransportConfig { - uri: uri.clone().into(), - ..Default::default() - }, - ); - McpClient::connect( - transport, - Duration::from_secs( - timeout.unwrap_or(crate::config::DEFAULT_EXTENSION_TIMEOUT), - ), - self.provider.clone(), - ) - .await? - } else { - client_res? - }; - Box::new(client) + .await? } ExtensionConfig::Stdio { cmd, @@ -487,27 +504,9 @@ impl ExtensionManager { .. } => { let all_envs = merge_environments(envs, env_keys, &sanitized_name).await?; - - // Check for malicious packages before launching the process - extension_malware_check::deny_if_malicious_cmd_args(cmd, args).await?; - - let cmd = resolve_command(cmd); - - let command = Command::new(cmd).configure(|command| { - command.args(args).envs(all_envs); - }); - - let client = child_process_client(command, timeout, self.provider.clone()).await?; - Box::new(client) + create_stdio_client(cmd, args, all_envs, timeout, self.provider.clone()).await? } - ExtensionConfig::Builtin { - name, - display_name: _, - description: _, - timeout, - bundled: _, - available_tools: _, - } => { + ExtensionConfig::Builtin { name, timeout, .. } => { let cmd = std::env::current_exe() .and_then(|path| { path.to_str().map(|s| s.to_string()).ok_or_else(|| { @@ -526,11 +525,9 @@ impl ExtensionManager { let command = Command::new(cmd).configure(|command| { command.arg("mcp").arg(name); }); - let client = child_process_client(command, timeout, self.provider.clone()).await?; - Box::new(client) + Box::new(child_process_client(command, timeout, self.provider.clone()).await?) } ExtensionConfig::Platform { name, .. } => { - // Normalize the name to match the key used in PLATFORM_EXTENSIONS let normalized_key = normalize(name.clone()); let def = PLATFORM_EXTENSIONS .get(normalized_key.as_str()) @@ -554,17 +551,13 @@ impl ExtensionManager { let command = Command::new("uvx").configure(|command| { command.arg("--with").arg("mcp"); - dependencies.iter().flatten().for_each(|dep| { command.arg("--with").arg(dep); }); - command.arg("python").arg(file_path.to_str().unwrap()); }); - let client = child_process_client(command, timeout, self.provider.clone()).await?; - - Box::new(client) + Box::new(child_process_client(command, timeout, self.provider.clone()).await?) } ExtensionConfig::Frontend { .. } => { return Err(ExtensionError::ConfigError( @@ -1716,40 +1709,6 @@ mod tests { #[tokio::test] async fn test_streamable_http_header_env_substitution() { - use std::collections::HashMap; - - // Test the substitute_env_vars helper function (which is defined inside add_extension) - // We'll recreate it here for testing purposes - fn substitute_env_vars(value: &str, env_map: &HashMap) -> String { - let mut result = value.to_string(); - - // First handle ${VAR} syntax (with optional whitespace) - let re_braces = - regex::Regex::new(r"\$\{\s*([A-Za-z_][A-Za-z0-9_]*)\s*\}").expect("valid regex"); - for cap in re_braces.captures_iter(value) { - if let Some(var_name) = cap.get(1) { - if let Some(env_value) = env_map.get(var_name.as_str()) { - result = result.replace(&cap[0], env_value); - } - } - } - - // Then handle $VAR syntax (simple variable without braces) - let re_simple = regex::Regex::new(r"\$([A-Za-z_][A-Za-z0-9_]*)").expect("valid regex"); - for cap in re_simple.captures_iter(&result.clone()) { - if let Some(var_name) = cap.get(1) { - // Only substitute if it wasn't already part of ${VAR} syntax - if !value.contains(&format!("${{{}}}", var_name.as_str())) { - if let Some(env_value) = env_map.get(var_name.as_str()) { - result = result.replace(&cap[0], env_value); - } - } - } - } - - result - } - let mut env_map = HashMap::new(); env_map.insert("AUTH_TOKEN".to_string(), "secret123".to_string()); env_map.insert("API_KEY".to_string(), "key456".to_string());