diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1e183a46a9b4..eac44cb86386 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -70,6 +70,28 @@ jobs: restore-keys: | ${{ runner.os }}-cargo-build- + # Add disk space cleanup before linting + - name: Check disk space before build + run: df -h + + #https://github.com/actions/runner-images/issues/2840 + - name: Clean up disk space + run: | + echo "Cleaning up disk space..." + sudo rm -rf \ + /opt/google/chrome \ + /opt/microsoft/msedge \ + /opt/microsoft/powershell \ + /usr/lib/mono \ + /usr/local/lib/android \ + /usr/local/lib/node_modules \ + /usr/local/share/chromium \ + /usr/local/share/powershell \ + /usr/share/dotnet \ + /usr/share/swift + + df -h + - name: Build and Test run: | gnome-keyring-daemon --components=secrets --daemonize --unlock <<< 'foobar' @@ -129,4 +151,4 @@ jobs: uses: ./.github/workflows/bundle-desktop.yml if: github.event_name == 'pull_request' with: - signing: false \ No newline at end of file + signing: false diff --git a/Cargo.lock b/Cargo.lock index 75b37e28e24d..81331964e336 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2420,6 +2420,7 @@ dependencies = [ "lopdf", "mcp-core", "mcp-server", + "oauth2", "once_cell", "regex", "reqwest 0.11.27", @@ -3920,6 +3921,26 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" +[[package]] +name = "oauth2" +version = "5.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51e219e79014df21a225b1860a479e2dcd7cbd9130f4defd4bd0e191ea31d67d" +dependencies = [ + "base64 0.22.1", + "chrono", + "getrandom 0.2.15", + "http 1.2.0", + "rand", + "reqwest 0.12.12", + "serde", + "serde_json", + "serde_path_to_error", + "sha2", + "thiserror 1.0.69", + "url", +] + [[package]] name = "objc" version = "0.2.7" @@ -6200,6 +6221,7 @@ dependencies = [ "form_urlencoded", "idna", "percent-encoding", + "serde", ] [[package]] diff --git a/crates/goose-mcp/Cargo.toml b/crates/goose-mcp/Cargo.toml index 66ab11dc6a72..c400451a64e0 100644 --- a/crates/goose-mcp/Cargo.toml +++ b/crates/goose-mcp/Cargo.toml @@ -47,6 +47,7 @@ docx-rs = "0.4.7" image = "0.24.9" umya-spreadsheet = "2.2.3" keyring = { version = "3.6.1", features = ["apple-native", "windows-native", "sync-secret-service"] } +oauth2 = { version = "5.0.0", features = ["reqwest"] } [dev-dependencies] serial_test = "3.0.0" diff --git a/crates/goose-mcp/src/google_drive/mod.rs b/crates/goose-mcp/src/google_drive/mod.rs index fa19fd47738c..3bdeaf1e0feb 100644 --- a/crates/goose-mcp/src/google_drive/mod.rs +++ b/crates/goose-mcp/src/google_drive/mod.rs @@ -1,13 +1,13 @@ -mod token_storage; +mod oauth_pkce; +pub mod storage; use indoc::indoc; +use oauth_pkce::PkceOAuth2Client; use regex::Regex; use serde_json::{json, Value}; -use token_storage::{CredentialsManager, KeychainTokenStorage}; - use std::io::Cursor; -use std::sync::Arc; -use std::{env, fs, future::Future, path::Path, pin::Pin}; +use std::{env, fs, future::Future, path::Path, pin::Pin, sync::Arc}; +use storage::CredentialsManager; use mcp_core::content::Content; use mcp_core::{ @@ -26,47 +26,15 @@ use google_drive3::{ api::{File, Scope}, hyper_rustls::{self, HttpsConnector}, hyper_util::{self, client::legacy::connect::HttpConnector}, - yup_oauth2::{ - self, - authenticator_delegate::{DefaultInstalledFlowDelegate, InstalledFlowDelegate}, - InstalledFlowAuthenticator, - }, DriveHub, }; use google_sheets4::{self, Sheets}; use http_body_util::BodyExt; -/// async function to be pinned by the `present_user_url` method of the trait -/// we use the existing `DefaultInstalledFlowDelegate::present_user_url` method as a fallback for -/// when the browser did not open for example, the user still see's the URL. -async fn browser_user_url(url: &str, need_code: bool) -> Result { - tracing::info!(oauth_url = url, "Attempting OAuth login flow"); - if let Err(e) = webbrowser::open(url) { - tracing::debug!(oauth_url = url, error = ?e, "Failed to open OAuth flow"); - println!("Please open this URL in your browser:\n{}", url); - } - let def_delegate = DefaultInstalledFlowDelegate; - def_delegate.present_user_url(url, need_code).await -} - -/// our custom delegate struct we will implement a flow delegate trait for: -/// in this case we will implement the `InstalledFlowDelegated` trait -#[derive(Copy, Clone)] -struct LocalhostBrowserDelegate; - -/// here we implement only the present_user_url method with the added webbrowser opening -/// the other behaviour of the trait does not need to be changed. -impl InstalledFlowDelegate for LocalhostBrowserDelegate { - /// the actual presenting of URL and browser opening happens in the function defined above here - /// we only pin it - fn present_user_url<'a>( - &'a self, - url: &'a str, - need_code: bool, - ) -> Pin> + Send + 'a>> { - Box::pin(browser_user_url(url, need_code)) - } -} +// Constants for credential storage +pub const KEYCHAIN_SERVICE: &str = "mcp_google_drive"; +pub const KEYCHAIN_USERNAME: &str = "oauth_credentials"; +pub const KEYCHAIN_DISK_FALLBACK_ENV: &str = "GOOGLE_DRIVE_DISK_FALLBACK"; #[derive(Debug)] enum FileOperation { @@ -141,38 +109,31 @@ impl GoogleDriveRouter { } } - // Create a credentials manager for storing tokens securely - let credentials_manager = Arc::new(CredentialsManager::new(credentials_path.clone())); - - // Read the application secret from the OAuth keyfile - let secret = yup_oauth2::read_application_secret(keyfile_path) - .await - .expect("expected keyfile for google auth"); - - // Create custom token storage using our credentials manager - let token_storage = KeychainTokenStorage::new( - secret - .project_id - .clone() - .unwrap_or("unknown-project-id".to_string()) - .to_string(), - credentials_manager.clone(), - ); - - // Create the authenticator with the installed flow - let auth = InstalledFlowAuthenticator::builder( - secret, - yup_oauth2::InstalledFlowReturnMethod::HTTPRedirect, - ) - .with_storage(Box::new(token_storage)) // Use our custom storage - .flow_delegate(Box::new(LocalhostBrowserDelegate)) - .build() - .await - .expect("expected successful authentication"); + // Check if we should fall back to disk, must be explicitly enabled + let fallback_to_disk = match env::var(KEYCHAIN_DISK_FALLBACK_ENV) { + Ok(value) => value.to_lowercase() == "true", + Err(_) => false, + }; - // Create the HTTP client - let client = - hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()) + // Create a credentials manager for storing tokens securely + let credentials_manager = Arc::new(CredentialsManager::new( + credentials_path.clone(), + fallback_to_disk, + KEYCHAIN_SERVICE.to_string(), + KEYCHAIN_USERNAME.to_string(), + )); + + // Read the OAuth credentials from the keyfile + match fs::read_to_string(keyfile_path) { + Ok(_) => { + // Create the PKCE OAuth2 client + let auth = PkceOAuth2Client::new(keyfile_path, credentials_manager.clone()) + .expect("Failed to create OAuth2 client"); + + // Create the HTTP client + let client = hyper_util::client::legacy::Client::builder( + hyper_util::rt::TokioExecutor::new(), + ) .build( hyper_rustls::HttpsConnectorBuilder::new() .with_native_roots() @@ -182,11 +143,21 @@ impl GoogleDriveRouter { .build(), ); - let drive_hub = DriveHub::new(client.clone(), auth.clone()); - let sheets_hub = Sheets::new(client, auth); + let drive_hub = DriveHub::new(client.clone(), auth.clone()); + let sheets_hub = Sheets::new(client, auth); - // Create and return the DriveHub - (drive_hub, sheets_hub, credentials_manager) + // Create and return the DriveHub, Sheets and our PKCE OAuth2 client + (drive_hub, sheets_hub, credentials_manager) + } + Err(e) => { + tracing::error!( + "Failed to read OAuth config from {}: {}", + keyfile_path.display(), + e + ); + panic!("Failed to read OAuth config: {}", e); + } + } } pub async fn new() -> Self { @@ -715,7 +686,7 @@ impl GoogleDriveRouter { .collect::>() .join("\n"); - Ok(vec![Content::text(content.to_string())]) + Ok(vec![Content::text(content.to_string()).with_priority(0.3)]) } } } diff --git a/crates/goose-mcp/src/google_drive/oauth_pkce.rs b/crates/goose-mcp/src/google_drive/oauth_pkce.rs new file mode 100644 index 000000000000..1da7380aa252 --- /dev/null +++ b/crates/goose-mcp/src/google_drive/oauth_pkce.rs @@ -0,0 +1,351 @@ +use std::error::Error; +use std::fs; +use std::future::Future; +use std::io::{BufRead, BufReader, Write}; +use std::net::TcpListener; +use std::path::Path; +use std::pin::Pin; +use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; + +use google_drive3::common::GetToken; +use oauth2::basic::BasicClient; +use oauth2::reqwest; +use oauth2::{ + AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EndpointNotSet, EndpointSet, + PkceCodeChallenge, RedirectUrl, RefreshToken, Scope, TokenResponse, TokenUrl, +}; +use serde::{Deserialize, Serialize}; +use tracing::{debug, error, info}; +use url::Url; + +use super::storage::CredentialsManager; + +/// Structure representing the OAuth2 configuration file format +#[derive(Debug, Deserialize, Serialize)] +struct OAuth2Config { + installed: InstalledConfig, +} + +#[derive(Debug, Deserialize, Serialize)] +struct InstalledConfig { + client_id: String, + project_id: String, + auth_uri: String, + token_uri: String, + auth_provider_x509_cert_url: String, + client_secret: String, + redirect_uris: Vec, +} + +/// Structure for token storage +#[derive(Debug, Deserialize, Serialize)] +struct TokenData { + access_token: String, + refresh_token: String, + #[serde(skip_serializing_if = "Option::is_none")] + expires_at: Option, + project_id: String, +} + +/// PkceOAuth2Client implements the GetToken trait required by DriveHub +/// It uses the oauth2 crate to implement a PKCE-enabled OAuth2 flow +#[derive(Clone)] +pub struct PkceOAuth2Client { + client: BasicClient, + credentials_manager: Arc, + http_client: reqwest::Client, + project_id: String, +} + +impl PkceOAuth2Client { + pub fn new( + config_path: impl AsRef, + credentials_manager: Arc, + ) -> Result> { + // Load and parse the config file + let config_content = fs::read_to_string(config_path)?; + let config: OAuth2Config = serde_json::from_str(&config_content)?; + + // Extract the project_id from the config + let project_id = config.installed.project_id.clone(); + + // Create OAuth URLs + let auth_url = + AuthUrl::new(config.installed.auth_uri).expect("Invalid authorization endpoint URL"); + let token_url = + TokenUrl::new(config.installed.token_uri).expect("Invalid token endpoint URL"); + + // Set up the OAuth2 client + let client = BasicClient::new(ClientId::new(config.installed.client_id)) + .set_client_secret(ClientSecret::new(config.installed.client_secret)) + .set_auth_uri(auth_url) + .set_token_uri(token_url) + .set_redirect_uri( + RedirectUrl::new("http://localhost:18080".to_string()) + .expect("Invalid redirect URL"), + ); + + let http_client = reqwest::ClientBuilder::new() + // Following redirects opens the client up to SSRF vulnerabilities. + .redirect(reqwest::redirect::Policy::none()) + .build() + .expect("Oauth2 HTTP Client should build"); + + Ok(Self { + client, + credentials_manager, + http_client, + project_id, + }) + } + + /// Check if a token is expired or about to expire within the buffer period + fn is_token_expired(&self, expires_at: Option, buffer_seconds: u64) -> bool { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_secs(); + + // Consider the token expired if it's within buffer_seconds of expiring + // This gives us a safety margin to avoid using tokens right before expiration + expires_at + .map(|expiry_time| now + buffer_seconds >= expiry_time) + .unwrap_or(true) // If we don't know when it expires, assume it's expired to be safe + } + + async fn perform_oauth_flow( + &self, + scopes: &[&str], + ) -> Result> { + // Create a PKCE code verifier and challenge + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + + // Generate the authorization URL + let (auth_url, csrf_token) = self + .client + .authorize_url(CsrfToken::new_random) + .add_scopes(scopes.iter().map(|&s| Scope::new(s.to_string()))) + .set_pkce_challenge(pkce_challenge) + .url(); + + info!("Opening browser for OAuth2 authentication"); + if let Err(e) = webbrowser::open(auth_url.as_str()) { + error!("Failed to open browser: {}", e); + println!("Please open this URL in your browser:\n{}\n", auth_url); + } + + // Start a local server to receive the authorization code + // We'll spawn this in a separate thread since it's blocking + let (tx, rx) = tokio::sync::oneshot::channel(); + std::thread::spawn(move || match Self::start_redirect_server() { + Ok(result) => { + let _ = tx.send(Ok(result)); + } + Err(e) => { + let _ = tx.send(Err(e)); + } + }); + + // Wait for the code from the redirect server + let (code, received_state) = rx.await??; + + // Verify the CSRF state + if received_state.secret() != csrf_token.secret() { + return Err("CSRF token mismatch".into()); + } + + // Use the built-in exchange_code method with PKCE verifier + let token_result = self + .client + .exchange_code(code) + .set_pkce_verifier(pkce_verifier) + .request_async(&self.http_client) + .await + .map_err(|e| Box::new(e) as Box)?; + + let access_token = token_result.access_token().secret().clone(); + + // Calculate expires_at as a Unix timestamp by adding expires_in to current time + let expires_at = token_result.expires_in().map(|duration| { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_secs(); + now + duration.as_secs() + }); + + // Get the refresh token if provided + if let Some(refresh_token) = token_result.refresh_token() { + let refresh_token_str = refresh_token.secret().clone(); + + // Store token data + let token_data = TokenData { + access_token: access_token.clone(), + refresh_token: refresh_token_str.clone(), + expires_at, + project_id: self.project_id.clone(), + }; + + // Store updated token data + self.credentials_manager + .write_credentials(&token_data) + .map(|_| debug!("Successfully stored token data")) + .unwrap_or_else(|e| error!("Failed to store token data: {}", e)); + } else { + debug!("No refresh token provided in OAuth flow response"); + } + + Ok(access_token) + } + + async fn refresh_token( + &self, + refresh_token: &str, + ) -> Result> { + debug!("Attempting to refresh access token"); + + // Create a RefreshToken from the string + let refresh_token = RefreshToken::new(refresh_token.to_string()); + + // Use the built-in exchange_refresh_token method + let token_result = self + .client + .exchange_refresh_token(&refresh_token) + .request_async(&self.http_client) + .await + .map_err(|e| Box::new(e) as Box)?; + + let access_token = token_result.access_token().secret().clone(); + + // Calculate expires_at as a Unix timestamp by adding expires_in to current time + let expires_at = token_result.expires_in().map(|duration| { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_secs(); + now + duration.as_secs() + }); + + // Get the refresh token - either the new one or reuse the existing one + let new_refresh_token = token_result + .refresh_token() + .map(|token| token.secret().clone()) + .unwrap_or_else(|| refresh_token.secret().to_string()); + + // Always update the token data with the new access token and expiration + let token_data = TokenData { + access_token: access_token.clone(), + refresh_token: new_refresh_token.clone(), + expires_at, + project_id: self.project_id.clone(), + }; + + // Store updated token data + self.credentials_manager + .write_credentials(&token_data) + .map(|_| debug!("Successfully stored token data")) + .unwrap_or_else(|e| error!("Failed to store token data: {}", e)); + + Ok(access_token) + } + + fn start_redirect_server( + ) -> Result<(AuthorizationCode, CsrfToken), Box> { + let listener = TcpListener::bind("127.0.0.1:18080")?; + println!("Listening for the authorization code on http://localhost:18080"); + + for stream in listener.incoming() { + match stream { + Ok(mut stream) => { + let mut reader = BufReader::new(&stream); + let mut request_line = String::new(); + reader.read_line(&mut request_line)?; + + let redirect_url = request_line + .split_whitespace() + .nth(1) + .ok_or("Invalid request")?; + + let url = Url::parse(&format!("http://localhost{}", redirect_url))?; + + let code = url + .query_pairs() + .find(|(key, _)| key == "code") + .map(|(_, value)| AuthorizationCode::new(value.into_owned())) + .ok_or("No code found in the response")?; + + let state = url + .query_pairs() + .find(|(key, _)| key == "state") + .map(|(_, value)| CsrfToken::new(value.into_owned())) + .ok_or("No state found in the response")?; + + // Send a success response to the browser + let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n\ +

Authentication successful!

\ +

You can now close this window and return to the application.

"; + + stream.write_all(response.as_bytes())?; + stream.flush()?; + + return Ok((code, state)); + } + Err(e) => { + error!("Failed to accept connection: {}", e); + } + } + } + + Err("Failed to receive authorization code".into()) + } +} + +// impl GetToken for use with DriveHub directly +// see google_drive3::common::GetToken +impl GetToken for PkceOAuth2Client { + fn get_token<'a>( + &'a self, + scopes: &'a [&str], + ) -> Pin< + Box, Box>> + Send + 'a>, + > { + Box::pin(async move { + // Try to read token data from storage to check if we have a valid token + if let Ok(token_data) = self.credentials_manager.read_credentials::() { + // Verify the project_id matches + if token_data.project_id == self.project_id { + // Check if the token is expired or expiring within a 5-min buffer + if !self.is_token_expired(token_data.expires_at, 300) { + return Ok(Some(token_data.access_token)); + } + + // Token is expired or will expire soon, try to refresh it + debug!("Token is expired or will expire soon, refreshing..."); + + // Try to refresh the token + if let Ok(access_token) = self.refresh_token(&token_data.refresh_token).await { + debug!("Successfully refreshed access token"); + return Ok(Some(access_token)); + } + } + } + + // If we get here, either: + // 1. The project ID didn't match + // 2. Token refresh failed + // 3. There are no valid tokens yet + // Fallback: perform interactive OAuth flow + match self.perform_oauth_flow(scopes).await { + Ok(token) => { + debug!("Successfully obtained new access token through OAuth flow"); + Ok(Some(token)) + } + Err(e) => { + error!("OAuth flow failed: {}", e); + Err(e) + } + } + }) + } +} diff --git a/crates/goose-mcp/src/google_drive/storage.rs b/crates/goose-mcp/src/google_drive/storage.rs new file mode 100644 index 000000000000..8e8f3c08dec3 --- /dev/null +++ b/crates/goose-mcp/src/google_drive/storage.rs @@ -0,0 +1,344 @@ +use anyhow::Result; +use keyring::Entry; +use serde::{de::DeserializeOwned, Serialize}; +use std::fs; +use std::path::Path; +use thiserror::Error; +use tracing::{debug, error, warn}; + +#[allow(dead_code)] +#[derive(Error, Debug)] +pub enum StorageError { + #[error("Failed to access keychain: {0}")] + KeyringError(#[from] keyring::Error), + #[error("Failed to access file system: {0}")] + FileSystemError(#[from] std::io::Error), + #[error("No credentials found")] + NotFound, + #[error("Critical error: {0}")] + Critical(String), + #[error("Failed to serialize/deserialize: {0}")] + SerializationError(#[from] serde_json::Error), +} + +/// CredentialsManager handles secure storage of OAuth credentials. +/// It attempts to store credentials in the system keychain first, +/// with fallback to file system storage if keychain access fails and fallback is enabled. +pub struct CredentialsManager { + credentials_path: String, + fallback_to_disk: bool, + keychain_service: String, + keychain_username: String, +} + +impl CredentialsManager { + pub fn new( + credentials_path: String, + fallback_to_disk: bool, + keychain_service: String, + keychain_username: String, + ) -> Self { + Self { + credentials_path, + fallback_to_disk, + keychain_service, + keychain_username, + } + } + + /// Reads and deserializes credentials from secure storage. + /// + /// This method attempts to read credentials from the system keychain first. + /// If keychain access fails and fallback is enabled, it will try to read from the file system. + /// + /// # Type Parameters + /// + /// * `T` - The type to deserialize the credentials into. Must implement `serde::de::DeserializeOwned`. + /// + /// # Returns + /// + /// * `Ok(T)` - The deserialized credentials + /// * `Err(StorageError)` - If reading or deserialization fails + /// + /// # Examples + /// + /// ```no_run + /// # use goose_mcp::google_drive::storage::CredentialsManager; + /// use serde::{Serialize, Deserialize}; + /// + /// #[derive(Serialize, Deserialize)] + /// struct OAuthToken { + /// access_token: String, + /// refresh_token: String, + /// expiry: u64, + /// } + /// + /// let manager = CredentialsManager::new( + /// String::from("/path/to/credentials.json"), + /// true, // fallback to disk if keychain fails + /// String::from("test_service"), + /// String::from("test_user") + /// ); + /// match manager.read_credentials::() { + /// Ok(token) => println!("Token expires at: {}", token.expiry), + /// Err(e) => eprintln!("Failed to read token: {}", e), + /// } + /// ``` + pub fn read_credentials(&self) -> Result + where + T: DeserializeOwned, + { + let json_str = Entry::new(&self.keychain_service, &self.keychain_username) + .and_then(|entry| entry.get_password()) + .inspect(|_| { + debug!("Successfully read credentials from keychain"); + }) + .or_else(|e| { + if self.fallback_to_disk { + debug!("Falling back to file system due to keyring error: {}", e); + self.read_from_file() + } else { + match e { + keyring::Error::NoEntry => Err(StorageError::NotFound), + _ => Err(StorageError::KeyringError(e)), + } + } + })?; + + serde_json::from_str(&json_str).map_err(StorageError::SerializationError) + } + + fn read_from_file(&self) -> Result { + let path = Path::new(&self.credentials_path); + if path.exists() { + match fs::read_to_string(path) { + Ok(content) => { + debug!("Successfully read credentials from file system"); + Ok(content) + } + Err(e) => { + error!("Failed to read credentials file: {}", e); + Err(StorageError::FileSystemError(e)) + } + } + } else { + debug!("No credentials found in file system"); + Err(StorageError::NotFound) + } + } + + /// Serializes and writes credentials to secure storage. + /// + /// This method attempts to write credentials to the system keychain first. + /// If keychain access fails and fallback is enabled, it will try to write to the file system. + /// + /// # Type Parameters + /// + /// * `T` - The type to serialize. Must implement `serde::Serialize`. + /// + /// # Parameters + /// + /// * `content` - The data to serialize and store + /// + /// # Returns + /// + /// * `Ok(())` - If writing succeeds + /// * `Err(StorageError)` - If serialization or writing fails + /// + /// # Examples + /// + /// ```no_run + /// # use goose_mcp::google_drive::storage::CredentialsManager; + /// use serde::{Serialize, Deserialize}; + /// + /// #[derive(Serialize, Deserialize)] + /// struct OAuthToken { + /// access_token: String, + /// refresh_token: String, + /// expiry: u64, + /// } + /// + /// let token = OAuthToken { + /// access_token: String::from("access_token_value"), + /// refresh_token: String::from("refresh_token_value"), + /// expiry: 1672531200, // Unix timestamp + /// }; + /// + /// let manager = CredentialsManager::new( + /// String::from("/path/to/credentials.json"), + /// true, // fallback to disk if keychain fails + /// String::from("test_service"), + /// String::from("test_user") + /// ); + /// if let Err(e) = manager.write_credentials(&token) { + /// eprintln!("Failed to write token: {}", e); + /// } + /// ``` + pub fn write_credentials(&self, content: &T) -> Result<(), StorageError> + where + T: Serialize, + { + let json_str = serde_json::to_string(content).map_err(StorageError::SerializationError)?; + + Entry::new(&self.keychain_service, &self.keychain_username) + .and_then(|entry| entry.set_password(&json_str)) + .inspect(|_| { + debug!("Successfully wrote credentials to keychain"); + }) + .or_else(|e| { + if self.fallback_to_disk { + warn!("Falling back to file system due to keyring error: {}", e); + self.write_to_file(&json_str) + } else { + Err(StorageError::KeyringError(e)) + } + }) + } + + fn write_to_file(&self, content: &str) -> Result<(), StorageError> { + let path = Path::new(&self.credentials_path); + if let Some(parent) = path.parent() { + if !parent.exists() { + match fs::create_dir_all(parent) { + Ok(_) => debug!("Created parent directories for credentials file"), + Err(e) => { + error!("Failed to create directories for credentials file: {}", e); + return Err(StorageError::FileSystemError(e)); + } + } + } + } + + match fs::write(path, content) { + Ok(_) => { + debug!("Successfully wrote credentials to file system"); + Ok(()) + } + Err(e) => { + error!("Failed to write credentials to file system: {}", e); + Err(StorageError::FileSystemError(e)) + } + } + } +} + +impl Clone for CredentialsManager { + fn clone(&self) -> Self { + Self { + credentials_path: self.credentials_path.clone(), + fallback_to_disk: self.fallback_to_disk, + keychain_service: self.keychain_service.clone(), + keychain_username: self.keychain_username.clone(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde::{Deserialize, Serialize}; + use tempfile::tempdir; + + #[derive(Debug, Serialize, Deserialize, PartialEq)] + struct TestCredentials { + access_token: String, + refresh_token: String, + expiry: u64, + } + + impl TestCredentials { + fn new() -> Self { + Self { + access_token: "test_access_token".to_string(), + refresh_token: "test_refresh_token".to_string(), + expiry: 1672531200, + } + } + } + + #[test] + fn test_read_write_from_keychain() { + // Create a temporary directory for test files + let temp_dir = tempdir().expect("Failed to create temp dir"); + let cred_path = temp_dir.path().join("test_credentials.json"); + let cred_path_str = cred_path.to_str().unwrap().to_string(); + + // Create a credentials manager with fallback enabled + // Using a unique service name to ensure keychain operation fails + let manager = CredentialsManager::new( + cred_path_str, + true, // fallback to disk + "test_service".to_string(), + "test_user".to_string(), + ); + + // Test credentials to store + let creds = TestCredentials::new(); + + // Write should write to keychain + let write_result = manager.write_credentials(&creds); + assert!(write_result.is_ok(), "Write should succeed with fallback"); + + // Read should read from keychain + let read_result = manager.read_credentials::(); + assert!(read_result.is_ok(), "Read should succeed with fallback"); + + // Verify the read credentials match what we wrote + assert_eq!( + read_result.unwrap(), + creds, + "Read credentials should match written credentials" + ); + } + + #[test] + fn test_no_fallback_not_found() { + // Create a temporary directory for test files + let temp_dir = tempdir().expect("Failed to create temp dir"); + let cred_path = temp_dir.path().join("nonexistent_credentials.json"); + let cred_path_str = cred_path.to_str().unwrap().to_string(); + + // Create a credentials manager with fallback disabled + let manager = CredentialsManager::new( + cred_path_str, + false, // no fallback to disk + "test_service_that_should_not_exist".to_string(), + "test_user_no_fallback".to_string(), + ); + + // Read should fail with NotFound or KeyringError depending on the system + let read_result = manager.read_credentials::(); + println!("{:?}", read_result); + assert!( + read_result.is_err(), + "Read should fail when credentials don't exist" + ); + } + + #[test] + fn test_serialization_error() { + // This test verifies that serialization errors are properly handled + let error = serde_json::from_str::("invalid json").unwrap_err(); + let storage_error = StorageError::SerializationError(error); + assert!(matches!(storage_error, StorageError::SerializationError(_))); + } + + #[test] + fn test_file_system_error_handling() { + // Test handling of file system errors by using an invalid path + let invalid_path = String::from("/nonexistent_directory/credentials.json"); + let manager = CredentialsManager::new( + invalid_path, + true, + "test_service".to_string(), + "test_user".to_string(), + ); + + // Create test credentials + let creds = TestCredentials::new(); + + // Attempt to write to an invalid path should result in FileSystemError + let result = manager.write_to_file(&serde_json::to_string(&creds).unwrap()); + assert!(matches!(result, Err(StorageError::FileSystemError(_)))); + } +} diff --git a/crates/goose-mcp/src/google_drive/token_storage.rs b/crates/goose-mcp/src/google_drive/token_storage.rs deleted file mode 100644 index f40ab4baab76..000000000000 --- a/crates/goose-mcp/src/google_drive/token_storage.rs +++ /dev/null @@ -1,301 +0,0 @@ -use anyhow::Result; -use google_drive3::yup_oauth2::storage::{TokenInfo, TokenStorage}; -use keyring::Entry; -use std::env; -use std::fs; -use std::path::Path; -use std::sync::Arc; -use thiserror::Error; -use tracing::{debug, error, warn}; - -const KEYCHAIN_SERVICE: &str = "mcp_google_drive"; -const KEYCHAIN_USERNAME: &str = "oauth_credentials"; -const KEYCHAIN_DISK_FALLBACK_ENV: &str = "GOOGLE_DRIVE_DISK_FALLBACK"; - -#[allow(dead_code)] -#[derive(Error, Debug)] -pub enum AuthError { - #[error("Failed to access keychain: {0}")] - KeyringError(#[from] keyring::Error), - #[error("Failed to access file system: {0}")] - FileSystemError(#[from] std::io::Error), - #[error("No credentials found")] - NotFound, - #[error("Critical error: {0}")] - Critical(String), - #[error("Failed to serialize/deserialize: {0}")] - SerializationError(#[from] serde_json::Error), -} - -/// CredentialsManager handles secure storage of OAuth credentials. -/// It attempts to store credentials in the system keychain first, -/// with fallback to file system storage if keychain access fails and fallback is enabled. -pub struct CredentialsManager { - credentials_path: String, - fallback_to_disk: bool, -} - -impl CredentialsManager { - pub fn new(credentials_path: String) -> Self { - // Check if we should fall back to disk, must be explicitly enabled - let fallback_to_disk = match env::var(KEYCHAIN_DISK_FALLBACK_ENV) { - Ok(value) => value.to_lowercase() == "true", - Err(_) => false, - }; - - Self { - credentials_path, - fallback_to_disk, - } - } - - pub fn read_credentials(&self) -> Result { - Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME) - .and_then(|entry| entry.get_password()) - .inspect(|_| { - debug!("Successfully read credentials from keychain"); - }) - .or_else(|e| { - if self.fallback_to_disk { - debug!("Falling back to file system due to keyring error: {}", e); - self.read_from_file() - } else { - match e { - keyring::Error::NoEntry => Err(AuthError::NotFound), - _ => Err(AuthError::KeyringError(e)), - } - } - }) - } - - fn read_from_file(&self) -> Result { - let path = Path::new(&self.credentials_path); - if path.exists() { - match fs::read_to_string(path) { - Ok(content) => { - debug!("Successfully read credentials from file system"); - Ok(content) - } - Err(e) => { - error!("Failed to read credentials file: {}", e); - Err(AuthError::FileSystemError(e)) - } - } - } else { - debug!("No credentials found in file system"); - Err(AuthError::NotFound) - } - } - - pub fn write_credentials(&self, content: &str) -> Result<(), AuthError> { - Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME) - .and_then(|entry| entry.set_password(content)) - .inspect(|_| { - debug!("Successfully wrote credentials to keychain"); - }) - .or_else(|e| { - if self.fallback_to_disk { - warn!("Falling back to file system due to keyring error: {}", e); - self.write_to_file(content) - } else { - Err(AuthError::KeyringError(e)) - } - }) - } - - fn write_to_file(&self, content: &str) -> Result<(), AuthError> { - let path = Path::new(&self.credentials_path); - if let Some(parent) = path.parent() { - if !parent.exists() { - match fs::create_dir_all(parent) { - Ok(_) => debug!("Created parent directories for credentials file"), - Err(e) => { - error!("Failed to create directories for credentials file: {}", e); - return Err(AuthError::FileSystemError(e)); - } - } - } - } - - match fs::write(path, content) { - Ok(_) => { - debug!("Successfully wrote credentials to file system"); - Ok(()) - } - Err(e) => { - error!("Failed to write credentials to file system: {}", e); - Err(AuthError::FileSystemError(e)) - } - } - } -} - -/// Storage entry that includes the token, scopes and project it's valid for -#[derive(serde::Serialize, serde::Deserialize)] -struct StorageEntry { - token: TokenInfo, - scopes: String, - project_id: String, -} - -/// KeychainTokenStorage implements the TokenStorage trait from yup_oauth2 -/// to enable secure storage of OAuth tokens in the system keychain. -pub struct KeychainTokenStorage { - project_id: String, - credentials_manager: Arc, -} - -impl KeychainTokenStorage { - /// Create a new KeychainTokenStorage with the given CredentialsManager - pub fn new(project_id: String, credentials_manager: Arc) -> Self { - Self { - project_id, - credentials_manager, - } - } - - fn generate_scoped_key(&self, scopes: &[&str]) -> String { - // Create a key based on the scopes and project_id - // Sort so we can be consistent using scopes as the key - let mut sorted_scopes = scopes.to_vec(); - sorted_scopes.sort(); - sorted_scopes.join(" ") - } -} - -#[async_trait::async_trait] -impl TokenStorage for KeychainTokenStorage { - /// Store a token in the keychain - async fn set(&self, scopes: &[&str], token_info: TokenInfo) -> Result<()> { - let key = self.generate_scoped_key(scopes); - - // Create a storage entry that includes the scopes - let storage_entry = StorageEntry { - token: token_info, - scopes: key, - project_id: self.project_id.clone(), - }; - - let json = serde_json::to_string(&storage_entry)?; - self.credentials_manager - .write_credentials(&json) - .map_err(|e| { - error!("Failed to write token to keychain: {}", e); - anyhow::anyhow!("Failed to write token to keychain: {}", e) - }) - } - - /// Retrieve a token from the keychain - async fn get(&self, scopes: &[&str]) -> Option { - let key = self.generate_scoped_key(scopes); - - match self.credentials_manager.read_credentials() { - Ok(json) => { - debug!("Successfully read credentials from storage"); - match serde_json::from_str::(&json) { - Ok(entry) => { - // Check if token has the requested scopes and matches the project_id - if entry.project_id == self.project_id && entry.scopes == key { - debug!("Successfully retrieved OAuth token from storage"); - Some(entry.token) - } else { - None - } - } - Err(e) => { - warn!("Failed to deserialize token from storage: {}", e); - None - } - } - } - Err(AuthError::NotFound) => { - debug!("No OAuth token found in storage"); - None - } - Err(e) => { - warn!("Error reading OAuth token from storage: {}", e); - None - } - } - } -} - -impl Clone for CredentialsManager { - fn clone(&self) -> Self { - Self { - credentials_path: self.credentials_path.clone(), - fallback_to_disk: self.fallback_to_disk, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use serial_test::serial; - use tempfile::NamedTempFile; - - #[tokio::test] - #[serial] - async fn test_token_storage_set_get() { - // Create a temporary file for testing - let temp_file = NamedTempFile::new().unwrap(); - let project_id = "test_project_1".to_string(); - let credentials_manager = Arc::new(CredentialsManager::new( - temp_file.path().to_string_lossy().to_string(), - )); - - let storage = KeychainTokenStorage::new(project_id, credentials_manager); - - // Create a test token - let token_info = TokenInfo { - access_token: Some("test_access_token".to_string()), - refresh_token: Some("test_refresh_token".to_string()), - expires_at: None, - id_token: None, - }; - - let scopes = &["https://www.googleapis.com/auth/drive.readonly"]; - - // Store the token - storage.set(scopes, token_info.clone()).await.unwrap(); - - // Retrieve the token - let retrieved = storage.get(scopes).await.unwrap(); - - // Verify the token matches - assert_eq!(retrieved.access_token, token_info.access_token); - assert_eq!(retrieved.refresh_token, token_info.refresh_token); - } - - #[tokio::test] - #[serial] - async fn test_token_storage_scope_mismatch() { - // Create a temporary file for testing - let temp_file = NamedTempFile::new().unwrap(); - let project_id = "test_project_2".to_string(); - let credentials_manager = Arc::new(CredentialsManager::new( - temp_file.path().to_string_lossy().to_string(), - )); - - let storage = KeychainTokenStorage::new(project_id, credentials_manager); - - // Create a test token - let token_info = TokenInfo { - access_token: Some("test_access_token".to_string()), - refresh_token: Some("test_refresh_token".to_string()), - expires_at: None, - id_token: None, - }; - - let scopes1 = &["https://www.googleapis.com/auth/drive.readonly"]; - let scopes2 = &["https://www.googleapis.com/auth/drive.file"]; - - // Store the token with scopes1 - storage.set(scopes1, token_info).await.unwrap(); - - // Try to retrieve with different scopes - let result = storage.get(scopes2).await; - assert!(result.is_none()); - } -} diff --git a/crates/goose-mcp/src/lib.rs b/crates/goose-mcp/src/lib.rs index 1345d2dd0951..472349f571fd 100644 --- a/crates/goose-mcp/src/lib.rs +++ b/crates/goose-mcp/src/lib.rs @@ -9,7 +9,7 @@ pub static APP_STRATEGY: Lazy = Lazy::new(|| AppStrategyArgs { pub mod computercontroller; mod developer; -mod google_drive; +pub mod google_drive; mod jetbrains; mod memory; mod tutorial;