From 824c8a3e35a6f851bb59e53b052366e8e458207c Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Mon, 10 Mar 2025 14:02:59 -0700 Subject: [PATCH 01/18] feat(google_drive): add keychain support for token storage --- crates/goose-mcp/src/google_drive/auth.rs | 317 ++++++++++++++++++++++ crates/goose-mcp/src/google_drive/mod.rs | 63 +++-- 2 files changed, 358 insertions(+), 22 deletions(-) create mode 100644 crates/goose-mcp/src/google_drive/auth.rs diff --git a/crates/goose-mcp/src/google_drive/auth.rs b/crates/goose-mcp/src/google_drive/auth.rs new file mode 100644 index 000000000000..fa5065baa850 --- /dev/null +++ b/crates/goose-mcp/src/google_drive/auth.rs @@ -0,0 +1,317 @@ +use anyhow::Result; +use google_drive3::yup_oauth2::storage::{TokenInfo, TokenStorage}; +use keyring::Entry; +use std::fs; +use std::path::Path; +use std::sync::Arc; +use thiserror::Error; +use tracing::{debug, error, warn}; + +const KEYCHAIN_SERVICE: &str = "mcp_google_drive"; +const KEYCHAIN_USERNAME: &str = "oauth_credentials"; + +#[derive(Error, Debug)] +pub enum AuthError { + #[error("Failed to access keychain: {0}")] + KeyringError(#[from] keyring::Error), + #[error("Failed to access file system: {0}")] + FileSystemError(#[from] std::io::Error), + #[error("No credentials found")] + NotFound, + #[error("Critical error: {0}")] + Critical(String), + #[error("Failed to serialize/deserialize: {0}")] + SerializationError(#[from] serde_json::Error), +} + +/// CredentialsManager handles secure storage of OAuth credentials. +/// It attempts to store credentials in the system keychain first, +/// with fallback to file system storage if keychain access fails. +pub struct CredentialsManager { + credentials_path: String, +} + +impl CredentialsManager { + pub fn new(credentials_path: String) -> Self { + Self { credentials_path } + } + + pub fn read_credentials(&self) -> Result { + // First try to read from keychain + let entry = match Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME) { + Ok(entry) => entry, + Err(e) => { + warn!("Failed to create keychain entry: {}", e); + return self.read_from_file(); + } + }; + + match entry.get_password() { + Ok(content) => { + debug!("Successfully read credentials from keychain"); + Ok(content) + } + Err(keyring::Error::NoEntry) => { + debug!("No credentials found in keychain, falling back to file system"); + self.read_from_file() + } + Err(e) => { + // Categorize errors - some might be critical and should not trigger fallback + warn!( + "Non-critical keychain error: {}, falling back to file system", + e + ); + self.read_from_file() + } + } + } + + fn read_from_file(&self) -> Result { + let path = Path::new(&self.credentials_path); + if path.exists() { + match fs::read_to_string(path) { + Ok(content) => { + debug!("Successfully read credentials from file system"); + Ok(content) + } + Err(e) => { + error!("Failed to read credentials file: {}", e); + Err(AuthError::FileSystemError(e)) + } + } + } else { + debug!("No credentials found in file system"); + Err(AuthError::NotFound) + } + } + + pub fn write_credentials(&self, content: &str) -> Result<(), AuthError> { + // Try to write to keychain first + let entry = match Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME) { + Ok(entry) => entry, + Err(e) => { + warn!("Failed to create keychain entry: {}", e); + return self.write_to_file(content); + } + }; + + match entry.set_password(content) { + Ok(_) => { + debug!("Successfully wrote credentials to keychain"); + Ok(()) + } + Err(e) => { + // Categorize errors - some might be critical and should not trigger fallback + warn!( + "Non-critical keychain error: {}, falling back to file system", + e + ); + self.write_to_file(content) + } + } + } + + fn write_to_file(&self, content: &str) -> Result<(), AuthError> { + let path = Path::new(&self.credentials_path); + if let Some(parent) = path.parent() { + if !parent.exists() { + match fs::create_dir_all(parent) { + Ok(_) => debug!("Created parent directories for credentials file"), + Err(e) => { + error!("Failed to create directories for credentials file: {}", e); + return Err(AuthError::FileSystemError(e)); + } + } + } + } + + match fs::write(path, content) { + Ok(_) => { + debug!("Successfully wrote credentials to file system"); + Ok(()) + } + Err(e) => { + error!("Failed to write credentials to file system: {}", e); + Err(AuthError::FileSystemError(e)) + } + } + } +} + +/// Storage entry that includes both the token and the scopes it's valid for +#[derive(serde::Serialize, serde::Deserialize)] +struct StorageEntry { + token: TokenInfo, + scopes: String, +} + +/// KeychainTokenStorage implements the TokenStorage trait from yup_oauth2 +/// to enable secure storage of OAuth tokens in the system keychain. +pub struct KeychainTokenStorage { + credentials_manager: Arc, +} + +impl KeychainTokenStorage { + /// Create a new KeychainTokenStorage with the given CredentialsManager + pub fn new(credentials_manager: Arc) -> Self { + Self { + credentials_manager, + } + } +} + +#[async_trait::async_trait] +impl TokenStorage for KeychainTokenStorage { + /// Store a token in the keychain + async fn set(&self, scopes: &[&str], token_info: TokenInfo) -> Result<()> { + debug!("Storing OAuth token in keychain for scopes: {:?}", scopes); + + // Create a key based on the scopes + let scope_key = scopes.join(" "); + + // Create a storage entry that includes the scopes + let storage_entry = StorageEntry { + token: token_info, + scopes: scope_key.clone(), + }; + + let json = serde_json::to_string(&storage_entry)?; + self.credentials_manager + .write_credentials(&json) + .map_err(|e| { + error!("Failed to write token to keychain: {}", e); + anyhow::anyhow!("Failed to write token to keychain: {}", e) + }) + } + + /// Retrieve a token from the keychain + async fn get(&self, scopes: &[&str]) -> Option { + let scope_key = scopes.join(" "); + debug!( + "Retrieving OAuth token from keychain for scopes: {:?}", + scopes + ); + + match self.credentials_manager.read_credentials() { + Ok(json) => { + debug!("Successfully read credentials from storage"); + match serde_json::from_str::(&json) { + Ok(entry) => { + // Check if the stored token has the requested scopes + if entry.scopes == scope_key { + debug!("Successfully retrieved OAuth token from storage"); + Some(entry.token) + } else { + debug!( + "Found token but scopes don't match. Stored: {}, Requested: {}", + entry.scopes, scope_key + ); + None + } + } + Err(e) => { + warn!("Failed to deserialize token from storage: {}", e); + None + } + } + } + Err(AuthError::NotFound) => { + debug!("No OAuth token found in storage"); + None + } + Err(e) => { + warn!("Error reading OAuth token from storage: {}", e); + None + } + } + } +} + +impl Clone for CredentialsManager { + fn clone(&self) -> Self { + Self { + credentials_path: self.credentials_path.clone(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::NamedTempFile; + + #[test] + fn test_write_read_credentials() { + let temp_file = NamedTempFile::new().unwrap(); + let manager = CredentialsManager::new(temp_file.path().to_string_lossy().to_string()); + + // Write test credentials + let test_content = r#"{"access_token":"test_token","token_type":"Bearer"}"#; + manager.write_credentials(test_content).unwrap(); + + // Read back and verify + let read_content = manager.read_credentials().unwrap(); + assert_eq!(read_content, test_content); + } + + #[tokio::test] + async fn test_token_storage_set_get() { + // Create a temporary file for testing + let temp_file = NamedTempFile::new().unwrap(); + let credentials_manager = Arc::new(CredentialsManager::new( + temp_file.path().to_string_lossy().to_string(), + )); + + let storage = KeychainTokenStorage::new(credentials_manager); + + // Create a test token + let token_info = TokenInfo { + access_token: Some("test_access_token".to_string()), + refresh_token: Some("test_refresh_token".to_string()), + expires_at: None, + id_token: None, + }; + + let scopes = &["https://www.googleapis.com/auth/drive.readonly"]; + + // Store the token + storage.set(scopes, token_info.clone()).await.unwrap(); + + // Retrieve the token + let retrieved = storage.get(scopes).await.unwrap(); + + // Verify the token matches + assert_eq!(retrieved.access_token, token_info.access_token); + assert_eq!(retrieved.refresh_token, token_info.refresh_token); + } + + #[tokio::test] + async fn test_token_storage_scope_mismatch() { + // Create a temporary file for testing + let temp_file = NamedTempFile::new().unwrap(); + let credentials_manager = Arc::new(CredentialsManager::new( + temp_file.path().to_string_lossy().to_string(), + )); + + let storage = KeychainTokenStorage::new(credentials_manager); + + // Create a test token + let token_info = TokenInfo { + access_token: Some("test_access_token".to_string()), + refresh_token: Some("test_refresh_token".to_string()), + expires_at: None, + id_token: None, + }; + + let scopes1 = &["https://www.googleapis.com/auth/drive.readonly"]; + let scopes2 = &["https://www.googleapis.com/auth/drive.file"]; + + // Store the token with scopes1 + storage.set(scopes1, token_info).await.unwrap(); + + // Try to retrieve with different scopes + let result = storage.get(scopes2).await; + assert!(result.is_none()); + } +} + diff --git a/crates/goose-mcp/src/google_drive/mod.rs b/crates/goose-mcp/src/google_drive/mod.rs index 61028d7a5595..2cfa53ddd86c 100644 --- a/crates/goose-mcp/src/google_drive/mod.rs +++ b/crates/goose-mcp/src/google_drive/mod.rs @@ -1,9 +1,12 @@ +mod auth; + +use auth::{CredentialsManager, KeychainTokenStorage}; use indoc::indoc; use regex::Regex; use serde_json::{json, Value}; +use std::{env, fs, future::Future, io::Write, path::Path, pin::Pin, sync::Arc}; -use std::{env, fs, future::Future, io::Write, path::Path, pin::Pin}; - +use mcp_core::content::Content; use mcp_core::{ handler::{PromptError, ResourceError, ToolError}, prompt::Prompt, @@ -14,8 +17,6 @@ use mcp_core::{ use mcp_server::router::CapabilitiesBuilder; use mcp_server::Router; -use mcp_core::content::Content; - use google_drive3::{ self, api::{File, Scope}, @@ -28,9 +29,7 @@ use google_drive3::{ }, DriveHub, }; - use google_sheets4::{self, Sheets}; - use http_body_util::BodyExt; /// async function to be pinned by the `present_user_url` method of the trait @@ -70,14 +69,15 @@ pub struct GoogleDriveRouter { instructions: String, drive: DriveHub>, sheets: Sheets>, + credentials_manager: CredentialsManager, } impl GoogleDriveRouter { async fn google_auth() -> ( DriveHub>, Sheets>, + CredentialsManager, ) { - let oauth_config = env::var("GOOGLE_DRIVE_OAUTH_CONFIG"); let keyfile_path_str = env::var("GOOGLE_DRIVE_OAUTH_PATH") .unwrap_or_else(|_| "./gcp-oauth.keys.json".to_string()); let credentials_path_str = env::var("GOOGLE_DRIVE_CREDENTIALS_PATH") @@ -87,7 +87,7 @@ impl GoogleDriveRouter { let keyfile_path = Path::new(expanded_keyfile.as_ref()); let expanded_credentials = shellexpand::tilde(credentials_path_str.as_str()); - let credentials_path = Path::new(expanded_credentials.as_ref()); + let credentials_path = expanded_credentials.to_string(); tracing::info!( credentials_path = credentials_path_str, @@ -95,35 +95,47 @@ impl GoogleDriveRouter { "Google Drive MCP server authentication config paths" ); - if !keyfile_path.exists() && oauth_config.is_ok() { - // attempt to create the path - if let Some(parent_dir) = keyfile_path.parent() { - let _ = fs::create_dir_all(parent_dir); - } + // Handle OAuth config from environment variable + if let Ok(oauth_config) = env::var("GOOGLE_DRIVE_OAUTH_CONFIG") { + if !keyfile_path.exists() { + // attempt to create the path + if let Some(parent_dir) = keyfile_path.parent() { + let _ = fs::create_dir_all(parent_dir); + } - if let Ok(mut file) = fs::File::create(keyfile_path) { - let _ = file.write_all(oauth_config.unwrap().as_bytes()); - tracing::debug!( - "Wrote Google Drive MCP server OAuth config to {}", - keyfile_path.display() - ); + if let Ok(mut file) = fs::File::create(keyfile_path) { + let _ = file.write_all(oauth_config.as_bytes()); + tracing::debug!( + "Wrote Google Drive MCP server OAuth config to {}", + keyfile_path.display() + ); + } } } + // Create a credentials manager for storing tokens securely + let credentials_manager = Arc::new(CredentialsManager::new(credentials_path.clone())); + + // Create custom token storage using our credentials manager + let token_storage = KeychainTokenStorage::new(credentials_manager.clone()); + + // Read the application secret from the OAuth keyfile let secret = yup_oauth2::read_application_secret(keyfile_path) .await .expect("expected keyfile for google auth"); + // Create the authenticator with the installed flow let auth = InstalledFlowAuthenticator::builder( secret, yup_oauth2::InstalledFlowReturnMethod::HTTPRedirect, ) - .persist_tokens_to_disk(credentials_path) + .with_storage(Box::new(token_storage)) // Use our custom storage .flow_delegate(Box::new(LocalhostBrowserDelegate)) .build() .await .expect("expected successful authentication"); + // Create the HTTP client let client = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()) .build( @@ -138,11 +150,16 @@ impl GoogleDriveRouter { let drive_hub = DriveHub::new(client.clone(), auth.clone()); let sheets_hub = Sheets::new(client, auth); - (drive_hub, sheets_hub) + // Create and return the DriveHub + ( + drive_hub, + sheets_hub, + Arc::try_unwrap(credentials_manager).unwrap_or_else(|arc| arc.as_ref().clone()), + ) } pub async fn new() -> Self { - let (drive, sheets) = Self::google_auth().await; + let (drive, sheets, credentials_manager) = Self::google_auth().await; // handle auth let search_tool = Tool::new( @@ -302,6 +319,7 @@ impl GoogleDriveRouter { instructions, drive, sheets, + credentials_manager, } } @@ -851,6 +869,7 @@ impl Clone for GoogleDriveRouter { instructions: self.instructions.clone(), drive: self.drive.clone(), sheets: self.sheets.clone(), + credentials_manager: self.credentials_manager.clone(), } } } From ce7065524d249eabcc57b51cd1aa22482bc1aef1 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Mon, 10 Mar 2025 14:03:38 -0700 Subject: [PATCH 02/18] feat(google_drive): validate matching oauth config, use project_id to key secrets --- crates/goose-mcp/src/google_drive/auth.rs | 43 ++++++++++------- crates/goose-mcp/src/google_drive/mod.rs | 57 ++++++++++++++++------- 2 files changed, 65 insertions(+), 35 deletions(-) diff --git a/crates/goose-mcp/src/google_drive/auth.rs b/crates/goose-mcp/src/google_drive/auth.rs index fa5065baa850..37e8ecdaae9f 100644 --- a/crates/goose-mcp/src/google_drive/auth.rs +++ b/crates/goose-mcp/src/google_drive/auth.rs @@ -95,13 +95,13 @@ impl CredentialsManager { } }; + // Fallback to writing on disk if we can't write to the keychain match entry.set_password(content) { Ok(_) => { debug!("Successfully wrote credentials to keychain"); Ok(()) } Err(e) => { - // Categorize errors - some might be critical and should not trigger fallback warn!( "Non-critical keychain error: {}, falling back to file system", e @@ -143,36 +143,46 @@ impl CredentialsManager { struct StorageEntry { token: TokenInfo, scopes: String, + project_id: String, } /// KeychainTokenStorage implements the TokenStorage trait from yup_oauth2 /// to enable secure storage of OAuth tokens in the system keychain. pub struct KeychainTokenStorage { + project_id: String, credentials_manager: Arc, } impl KeychainTokenStorage { /// Create a new KeychainTokenStorage with the given CredentialsManager - pub fn new(credentials_manager: Arc) -> Self { + pub fn new(project_id: String, credentials_manager: Arc) -> Self { Self { + project_id, credentials_manager, } } + + fn generate_scoped_key(&self, scopes: &[&str]) -> String { + // Create a key based on the scopes and project_id + let mut sorted_scopes = scopes.to_vec(); + sorted_scopes.sort(); + + sorted_scopes.join(" ") + } } #[async_trait::async_trait] impl TokenStorage for KeychainTokenStorage { /// Store a token in the keychain async fn set(&self, scopes: &[&str], token_info: TokenInfo) -> Result<()> { - debug!("Storing OAuth token in keychain for scopes: {:?}", scopes); - - // Create a key based on the scopes - let scope_key = scopes.join(" "); + let key = self.generate_scoped_key(scopes); + debug!("Storing OAuth token in keychain for scopes: {:?}", key); // Create a storage entry that includes the scopes let storage_entry = StorageEntry { token: token_info, - scopes: scope_key.clone(), + scopes: key, + project_id: self.project_id.clone(), }; let json = serde_json::to_string(&storage_entry)?; @@ -186,11 +196,8 @@ impl TokenStorage for KeychainTokenStorage { /// Retrieve a token from the keychain async fn get(&self, scopes: &[&str]) -> Option { - let scope_key = scopes.join(" "); - debug!( - "Retrieving OAuth token from keychain for scopes: {:?}", - scopes - ); + let key = self.generate_scoped_key(scopes); + debug!("Retrieving OAuth token from keychain for key: {:?}", key); match self.credentials_manager.read_credentials() { Ok(json) => { @@ -198,13 +205,14 @@ impl TokenStorage for KeychainTokenStorage { match serde_json::from_str::(&json) { Ok(entry) => { // Check if the stored token has the requested scopes - if entry.scopes == scope_key { + debug!("{} == {}", entry.project_id, self.project_id); + if entry.project_id == self.project_id && entry.scopes == key { debug!("Successfully retrieved OAuth token from storage"); Some(entry.token) } else { debug!( "Found token but scopes don't match. Stored: {}, Requested: {}", - entry.scopes, scope_key + entry.scopes, key ); None } @@ -258,11 +266,12 @@ mod tests { async fn test_token_storage_set_get() { // Create a temporary file for testing let temp_file = NamedTempFile::new().unwrap(); + let project_id = "test_project".to_string(); let credentials_manager = Arc::new(CredentialsManager::new( temp_file.path().to_string_lossy().to_string(), )); - let storage = KeychainTokenStorage::new(credentials_manager); + let storage = KeychainTokenStorage::new(project_id, credentials_manager); // Create a test token let token_info = TokenInfo { @@ -289,11 +298,12 @@ mod tests { async fn test_token_storage_scope_mismatch() { // Create a temporary file for testing let temp_file = NamedTempFile::new().unwrap(); + let project_id = "test_project".to_string(); let credentials_manager = Arc::new(CredentialsManager::new( temp_file.path().to_string_lossy().to_string(), )); - let storage = KeychainTokenStorage::new(credentials_manager); + let storage = KeychainTokenStorage::new(project_id, credentials_manager); // Create a test token let token_info = TokenInfo { @@ -314,4 +324,3 @@ mod tests { assert!(result.is_none()); } } - diff --git a/crates/goose-mcp/src/google_drive/mod.rs b/crates/goose-mcp/src/google_drive/mod.rs index 2cfa53ddd86c..952777502d3a 100644 --- a/crates/goose-mcp/src/google_drive/mod.rs +++ b/crates/goose-mcp/src/google_drive/mod.rs @@ -4,7 +4,7 @@ use auth::{CredentialsManager, KeychainTokenStorage}; use indoc::indoc; use regex::Regex; use serde_json::{json, Value}; -use std::{env, fs, future::Future, io::Write, path::Path, pin::Pin, sync::Arc}; +use std::{env, fs, future::Future, path::Path, pin::Pin, sync::Arc}; use mcp_core::content::Content; use mcp_core::{ @@ -69,14 +69,14 @@ pub struct GoogleDriveRouter { instructions: String, drive: DriveHub>, sheets: Sheets>, - credentials_manager: CredentialsManager, + credentials_manager: Arc, } impl GoogleDriveRouter { async fn google_auth() -> ( DriveHub>, Sheets>, - CredentialsManager, + Arc, ) { let keyfile_path_str = env::var("GOOGLE_DRIVE_OAUTH_PATH") .unwrap_or_else(|_| "./gcp-oauth.keys.json".to_string()); @@ -95,16 +95,34 @@ impl GoogleDriveRouter { "Google Drive MCP server authentication config paths" ); - // Handle OAuth config from environment variable if let Ok(oauth_config) = env::var("GOOGLE_DRIVE_OAUTH_CONFIG") { - if !keyfile_path.exists() { - // attempt to create the path - if let Some(parent_dir) = keyfile_path.parent() { - let _ = fs::create_dir_all(parent_dir); + // Ensure the parent directory exists (create_dir_all is idempotent) + if let Some(parent) = keyfile_path.parent() { + if let Err(e) = fs::create_dir_all(parent) { + tracing::error!( + "Failed to create parent directories for {}: {}", + keyfile_path.display(), + e + ); } + } - if let Ok(mut file) = fs::File::create(keyfile_path) { - let _ = file.write_all(oauth_config.as_bytes()); + // Check if the file exists and whether its content matches + // in every other case we attempt to overwrite + let need_to_write = match fs::read_to_string(keyfile_path) { + Ok(existing) if existing == oauth_config => false, + Ok(_) | Err(_) => true, + }; + + // Overwrite the file if needed + if need_to_write { + if let Err(e) = fs::write(keyfile_path, &oauth_config) { + tracing::error!( + "Failed to write OAuth config to {}: {}", + keyfile_path.display(), + e + ); + } else { tracing::debug!( "Wrote Google Drive MCP server OAuth config to {}", keyfile_path.display() @@ -116,14 +134,21 @@ impl GoogleDriveRouter { // Create a credentials manager for storing tokens securely let credentials_manager = Arc::new(CredentialsManager::new(credentials_path.clone())); - // Create custom token storage using our credentials manager - let token_storage = KeychainTokenStorage::new(credentials_manager.clone()); - // Read the application secret from the OAuth keyfile let secret = yup_oauth2::read_application_secret(keyfile_path) .await .expect("expected keyfile for google auth"); + // Create custom token storage using our credentials manager + let token_storage = KeychainTokenStorage::new( + secret + .project_id + .clone() + .unwrap_or("unknown-project-id".to_string()) + .to_string(), + credentials_manager.clone(), + ); + // Create the authenticator with the installed flow let auth = InstalledFlowAuthenticator::builder( secret, @@ -151,11 +176,7 @@ impl GoogleDriveRouter { let sheets_hub = Sheets::new(client, auth); // Create and return the DriveHub - ( - drive_hub, - sheets_hub, - Arc::try_unwrap(credentials_manager).unwrap_or_else(|arc| arc.as_ref().clone()), - ) + (drive_hub, sheets_hub, credentials_manager) } pub async fn new() -> Self { From 5218285445871f05c22c2077417656152eaec331 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Mon, 10 Mar 2025 14:03:50 -0700 Subject: [PATCH 03/18] feat(google_drive): add GOOGLE_DRIVE_DISK_FALLBACK flag --- crates/goose-mcp/src/google_drive/mod.rs | 4 +- .../{auth.rs => token_storage.rs} | 127 +++++++----------- 2 files changed, 53 insertions(+), 78 deletions(-) rename crates/goose-mcp/src/google_drive/{auth.rs => token_storage.rs} (74%) diff --git a/crates/goose-mcp/src/google_drive/mod.rs b/crates/goose-mcp/src/google_drive/mod.rs index 952777502d3a..1bd545ff2e25 100644 --- a/crates/goose-mcp/src/google_drive/mod.rs +++ b/crates/goose-mcp/src/google_drive/mod.rs @@ -1,10 +1,10 @@ -mod auth; +mod token_storage; -use auth::{CredentialsManager, KeychainTokenStorage}; use indoc::indoc; use regex::Regex; use serde_json::{json, Value}; use std::{env, fs, future::Future, path::Path, pin::Pin, sync::Arc}; +use token_storage::{CredentialsManager, KeychainTokenStorage}; use mcp_core::content::Content; use mcp_core::{ diff --git a/crates/goose-mcp/src/google_drive/auth.rs b/crates/goose-mcp/src/google_drive/token_storage.rs similarity index 74% rename from crates/goose-mcp/src/google_drive/auth.rs rename to crates/goose-mcp/src/google_drive/token_storage.rs index 37e8ecdaae9f..f40ab4baab76 100644 --- a/crates/goose-mcp/src/google_drive/auth.rs +++ b/crates/goose-mcp/src/google_drive/token_storage.rs @@ -1,6 +1,7 @@ use anyhow::Result; use google_drive3::yup_oauth2::storage::{TokenInfo, TokenStorage}; use keyring::Entry; +use std::env; use std::fs; use std::path::Path; use std::sync::Arc; @@ -9,7 +10,9 @@ use tracing::{debug, error, warn}; const KEYCHAIN_SERVICE: &str = "mcp_google_drive"; const KEYCHAIN_USERNAME: &str = "oauth_credentials"; +const KEYCHAIN_DISK_FALLBACK_ENV: &str = "GOOGLE_DRIVE_DISK_FALLBACK"; +#[allow(dead_code)] #[derive(Error, Debug)] pub enum AuthError { #[error("Failed to access keychain: {0}")] @@ -26,44 +29,43 @@ pub enum AuthError { /// CredentialsManager handles secure storage of OAuth credentials. /// It attempts to store credentials in the system keychain first, -/// with fallback to file system storage if keychain access fails. +/// with fallback to file system storage if keychain access fails and fallback is enabled. pub struct CredentialsManager { credentials_path: String, + fallback_to_disk: bool, } impl CredentialsManager { pub fn new(credentials_path: String) -> Self { - Self { credentials_path } + // Check if we should fall back to disk, must be explicitly enabled + let fallback_to_disk = match env::var(KEYCHAIN_DISK_FALLBACK_ENV) { + Ok(value) => value.to_lowercase() == "true", + Err(_) => false, + }; + + Self { + credentials_path, + fallback_to_disk, + } } pub fn read_credentials(&self) -> Result { - // First try to read from keychain - let entry = match Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME) { - Ok(entry) => entry, - Err(e) => { - warn!("Failed to create keychain entry: {}", e); - return self.read_from_file(); - } - }; - - match entry.get_password() { - Ok(content) => { + Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME) + .and_then(|entry| entry.get_password()) + .inspect(|_| { debug!("Successfully read credentials from keychain"); - Ok(content) - } - Err(keyring::Error::NoEntry) => { - debug!("No credentials found in keychain, falling back to file system"); - self.read_from_file() - } - Err(e) => { - // Categorize errors - some might be critical and should not trigger fallback - warn!( - "Non-critical keychain error: {}, falling back to file system", - e - ); - self.read_from_file() - } - } + }) + .or_else(|e| { + if self.fallback_to_disk { + debug!("Falling back to file system due to keyring error: {}", e); + self.read_from_file() + } else { + match e { + keyring::Error::NoEntry => Err(AuthError::NotFound), + _ => Err(AuthError::KeyringError(e)), + } + } + }) } fn read_from_file(&self) -> Result { @@ -86,29 +88,19 @@ impl CredentialsManager { } pub fn write_credentials(&self, content: &str) -> Result<(), AuthError> { - // Try to write to keychain first - let entry = match Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME) { - Ok(entry) => entry, - Err(e) => { - warn!("Failed to create keychain entry: {}", e); - return self.write_to_file(content); - } - }; - - // Fallback to writing on disk if we can't write to the keychain - match entry.set_password(content) { - Ok(_) => { + Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME) + .and_then(|entry| entry.set_password(content)) + .inspect(|_| { debug!("Successfully wrote credentials to keychain"); - Ok(()) - } - Err(e) => { - warn!( - "Non-critical keychain error: {}, falling back to file system", - e - ); - self.write_to_file(content) - } - } + }) + .or_else(|e| { + if self.fallback_to_disk { + warn!("Falling back to file system due to keyring error: {}", e); + self.write_to_file(content) + } else { + Err(AuthError::KeyringError(e)) + } + }) } fn write_to_file(&self, content: &str) -> Result<(), AuthError> { @@ -138,7 +130,7 @@ impl CredentialsManager { } } -/// Storage entry that includes both the token and the scopes it's valid for +/// Storage entry that includes the token, scopes and project it's valid for #[derive(serde::Serialize, serde::Deserialize)] struct StorageEntry { token: TokenInfo, @@ -164,9 +156,9 @@ impl KeychainTokenStorage { fn generate_scoped_key(&self, scopes: &[&str]) -> String { // Create a key based on the scopes and project_id + // Sort so we can be consistent using scopes as the key let mut sorted_scopes = scopes.to_vec(); sorted_scopes.sort(); - sorted_scopes.join(" ") } } @@ -176,7 +168,6 @@ impl TokenStorage for KeychainTokenStorage { /// Store a token in the keychain async fn set(&self, scopes: &[&str], token_info: TokenInfo) -> Result<()> { let key = self.generate_scoped_key(scopes); - debug!("Storing OAuth token in keychain for scopes: {:?}", key); // Create a storage entry that includes the scopes let storage_entry = StorageEntry { @@ -197,23 +188,17 @@ impl TokenStorage for KeychainTokenStorage { /// Retrieve a token from the keychain async fn get(&self, scopes: &[&str]) -> Option { let key = self.generate_scoped_key(scopes); - debug!("Retrieving OAuth token from keychain for key: {:?}", key); match self.credentials_manager.read_credentials() { Ok(json) => { debug!("Successfully read credentials from storage"); match serde_json::from_str::(&json) { Ok(entry) => { - // Check if the stored token has the requested scopes - debug!("{} == {}", entry.project_id, self.project_id); + // Check if token has the requested scopes and matches the project_id if entry.project_id == self.project_id && entry.scopes == key { debug!("Successfully retrieved OAuth token from storage"); Some(entry.token) } else { - debug!( - "Found token but scopes don't match. Stored: {}, Requested: {}", - entry.scopes, key - ); None } } @@ -239,6 +224,7 @@ impl Clone for CredentialsManager { fn clone(&self) -> Self { Self { credentials_path: self.credentials_path.clone(), + fallback_to_disk: self.fallback_to_disk, } } } @@ -246,27 +232,15 @@ impl Clone for CredentialsManager { #[cfg(test)] mod tests { use super::*; + use serial_test::serial; use tempfile::NamedTempFile; - #[test] - fn test_write_read_credentials() { - let temp_file = NamedTempFile::new().unwrap(); - let manager = CredentialsManager::new(temp_file.path().to_string_lossy().to_string()); - - // Write test credentials - let test_content = r#"{"access_token":"test_token","token_type":"Bearer"}"#; - manager.write_credentials(test_content).unwrap(); - - // Read back and verify - let read_content = manager.read_credentials().unwrap(); - assert_eq!(read_content, test_content); - } - #[tokio::test] + #[serial] async fn test_token_storage_set_get() { // Create a temporary file for testing let temp_file = NamedTempFile::new().unwrap(); - let project_id = "test_project".to_string(); + let project_id = "test_project_1".to_string(); let credentials_manager = Arc::new(CredentialsManager::new( temp_file.path().to_string_lossy().to_string(), )); @@ -295,10 +269,11 @@ mod tests { } #[tokio::test] + #[serial] async fn test_token_storage_scope_mismatch() { // Create a temporary file for testing let temp_file = NamedTempFile::new().unwrap(); - let project_id = "test_project".to_string(); + let project_id = "test_project_2".to_string(); let credentials_manager = Arc::new(CredentialsManager::new( temp_file.path().to_string_lossy().to_string(), )); From 9a6d89e0a20b1343b9b4993f087d29e9a0eef930 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Mon, 10 Mar 2025 14:41:28 -0700 Subject: [PATCH 04/18] feat(goose-mcp): add keyring crate, same version as goose crate --- Cargo.lock | 1 + crates/goose-mcp/Cargo.toml | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 13b491a198ca..fbb41a55d0d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2414,6 +2414,7 @@ dependencies = [ "image 0.24.9", "include_dir", "indoc", + "keyring", "kill_tree", "lazy_static", "lopdf", diff --git a/crates/goose-mcp/Cargo.toml b/crates/goose-mcp/Cargo.toml index 602b071fefaf..cd718134f9d9 100644 --- a/crates/goose-mcp/Cargo.toml +++ b/crates/goose-mcp/Cargo.toml @@ -43,7 +43,8 @@ lopdf = "0.35.0" docx-rs = "0.4.7" image = "0.24.9" umya-spreadsheet = "2.2.3" +keyring = { version = "3.6.1", features = ["apple-native", "windows-native", "sync-secret-service"] } [dev-dependencies] serial_test = "3.0.0" -sysinfo = "0.32.1" \ No newline at end of file +sysinfo = "0.32.1" From 96d9aa7f68b2e6ad5e8ff275e0082d80b97a89b7 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Mon, 10 Mar 2025 14:02:59 -0700 Subject: [PATCH 05/18] feat(google_drive): add keychain support for token storage --- crates/goose-mcp/src/google_drive/auth.rs | 317 ++++++++++++++++++++++ crates/goose-mcp/src/google_drive/mod.rs | 18 +- 2 files changed, 322 insertions(+), 13 deletions(-) create mode 100644 crates/goose-mcp/src/google_drive/auth.rs diff --git a/crates/goose-mcp/src/google_drive/auth.rs b/crates/goose-mcp/src/google_drive/auth.rs new file mode 100644 index 000000000000..fa5065baa850 --- /dev/null +++ b/crates/goose-mcp/src/google_drive/auth.rs @@ -0,0 +1,317 @@ +use anyhow::Result; +use google_drive3::yup_oauth2::storage::{TokenInfo, TokenStorage}; +use keyring::Entry; +use std::fs; +use std::path::Path; +use std::sync::Arc; +use thiserror::Error; +use tracing::{debug, error, warn}; + +const KEYCHAIN_SERVICE: &str = "mcp_google_drive"; +const KEYCHAIN_USERNAME: &str = "oauth_credentials"; + +#[derive(Error, Debug)] +pub enum AuthError { + #[error("Failed to access keychain: {0}")] + KeyringError(#[from] keyring::Error), + #[error("Failed to access file system: {0}")] + FileSystemError(#[from] std::io::Error), + #[error("No credentials found")] + NotFound, + #[error("Critical error: {0}")] + Critical(String), + #[error("Failed to serialize/deserialize: {0}")] + SerializationError(#[from] serde_json::Error), +} + +/// CredentialsManager handles secure storage of OAuth credentials. +/// It attempts to store credentials in the system keychain first, +/// with fallback to file system storage if keychain access fails. +pub struct CredentialsManager { + credentials_path: String, +} + +impl CredentialsManager { + pub fn new(credentials_path: String) -> Self { + Self { credentials_path } + } + + pub fn read_credentials(&self) -> Result { + // First try to read from keychain + let entry = match Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME) { + Ok(entry) => entry, + Err(e) => { + warn!("Failed to create keychain entry: {}", e); + return self.read_from_file(); + } + }; + + match entry.get_password() { + Ok(content) => { + debug!("Successfully read credentials from keychain"); + Ok(content) + } + Err(keyring::Error::NoEntry) => { + debug!("No credentials found in keychain, falling back to file system"); + self.read_from_file() + } + Err(e) => { + // Categorize errors - some might be critical and should not trigger fallback + warn!( + "Non-critical keychain error: {}, falling back to file system", + e + ); + self.read_from_file() + } + } + } + + fn read_from_file(&self) -> Result { + let path = Path::new(&self.credentials_path); + if path.exists() { + match fs::read_to_string(path) { + Ok(content) => { + debug!("Successfully read credentials from file system"); + Ok(content) + } + Err(e) => { + error!("Failed to read credentials file: {}", e); + Err(AuthError::FileSystemError(e)) + } + } + } else { + debug!("No credentials found in file system"); + Err(AuthError::NotFound) + } + } + + pub fn write_credentials(&self, content: &str) -> Result<(), AuthError> { + // Try to write to keychain first + let entry = match Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME) { + Ok(entry) => entry, + Err(e) => { + warn!("Failed to create keychain entry: {}", e); + return self.write_to_file(content); + } + }; + + match entry.set_password(content) { + Ok(_) => { + debug!("Successfully wrote credentials to keychain"); + Ok(()) + } + Err(e) => { + // Categorize errors - some might be critical and should not trigger fallback + warn!( + "Non-critical keychain error: {}, falling back to file system", + e + ); + self.write_to_file(content) + } + } + } + + fn write_to_file(&self, content: &str) -> Result<(), AuthError> { + let path = Path::new(&self.credentials_path); + if let Some(parent) = path.parent() { + if !parent.exists() { + match fs::create_dir_all(parent) { + Ok(_) => debug!("Created parent directories for credentials file"), + Err(e) => { + error!("Failed to create directories for credentials file: {}", e); + return Err(AuthError::FileSystemError(e)); + } + } + } + } + + match fs::write(path, content) { + Ok(_) => { + debug!("Successfully wrote credentials to file system"); + Ok(()) + } + Err(e) => { + error!("Failed to write credentials to file system: {}", e); + Err(AuthError::FileSystemError(e)) + } + } + } +} + +/// Storage entry that includes both the token and the scopes it's valid for +#[derive(serde::Serialize, serde::Deserialize)] +struct StorageEntry { + token: TokenInfo, + scopes: String, +} + +/// KeychainTokenStorage implements the TokenStorage trait from yup_oauth2 +/// to enable secure storage of OAuth tokens in the system keychain. +pub struct KeychainTokenStorage { + credentials_manager: Arc, +} + +impl KeychainTokenStorage { + /// Create a new KeychainTokenStorage with the given CredentialsManager + pub fn new(credentials_manager: Arc) -> Self { + Self { + credentials_manager, + } + } +} + +#[async_trait::async_trait] +impl TokenStorage for KeychainTokenStorage { + /// Store a token in the keychain + async fn set(&self, scopes: &[&str], token_info: TokenInfo) -> Result<()> { + debug!("Storing OAuth token in keychain for scopes: {:?}", scopes); + + // Create a key based on the scopes + let scope_key = scopes.join(" "); + + // Create a storage entry that includes the scopes + let storage_entry = StorageEntry { + token: token_info, + scopes: scope_key.clone(), + }; + + let json = serde_json::to_string(&storage_entry)?; + self.credentials_manager + .write_credentials(&json) + .map_err(|e| { + error!("Failed to write token to keychain: {}", e); + anyhow::anyhow!("Failed to write token to keychain: {}", e) + }) + } + + /// Retrieve a token from the keychain + async fn get(&self, scopes: &[&str]) -> Option { + let scope_key = scopes.join(" "); + debug!( + "Retrieving OAuth token from keychain for scopes: {:?}", + scopes + ); + + match self.credentials_manager.read_credentials() { + Ok(json) => { + debug!("Successfully read credentials from storage"); + match serde_json::from_str::(&json) { + Ok(entry) => { + // Check if the stored token has the requested scopes + if entry.scopes == scope_key { + debug!("Successfully retrieved OAuth token from storage"); + Some(entry.token) + } else { + debug!( + "Found token but scopes don't match. Stored: {}, Requested: {}", + entry.scopes, scope_key + ); + None + } + } + Err(e) => { + warn!("Failed to deserialize token from storage: {}", e); + None + } + } + } + Err(AuthError::NotFound) => { + debug!("No OAuth token found in storage"); + None + } + Err(e) => { + warn!("Error reading OAuth token from storage: {}", e); + None + } + } + } +} + +impl Clone for CredentialsManager { + fn clone(&self) -> Self { + Self { + credentials_path: self.credentials_path.clone(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::NamedTempFile; + + #[test] + fn test_write_read_credentials() { + let temp_file = NamedTempFile::new().unwrap(); + let manager = CredentialsManager::new(temp_file.path().to_string_lossy().to_string()); + + // Write test credentials + let test_content = r#"{"access_token":"test_token","token_type":"Bearer"}"#; + manager.write_credentials(test_content).unwrap(); + + // Read back and verify + let read_content = manager.read_credentials().unwrap(); + assert_eq!(read_content, test_content); + } + + #[tokio::test] + async fn test_token_storage_set_get() { + // Create a temporary file for testing + let temp_file = NamedTempFile::new().unwrap(); + let credentials_manager = Arc::new(CredentialsManager::new( + temp_file.path().to_string_lossy().to_string(), + )); + + let storage = KeychainTokenStorage::new(credentials_manager); + + // Create a test token + let token_info = TokenInfo { + access_token: Some("test_access_token".to_string()), + refresh_token: Some("test_refresh_token".to_string()), + expires_at: None, + id_token: None, + }; + + let scopes = &["https://www.googleapis.com/auth/drive.readonly"]; + + // Store the token + storage.set(scopes, token_info.clone()).await.unwrap(); + + // Retrieve the token + let retrieved = storage.get(scopes).await.unwrap(); + + // Verify the token matches + assert_eq!(retrieved.access_token, token_info.access_token); + assert_eq!(retrieved.refresh_token, token_info.refresh_token); + } + + #[tokio::test] + async fn test_token_storage_scope_mismatch() { + // Create a temporary file for testing + let temp_file = NamedTempFile::new().unwrap(); + let credentials_manager = Arc::new(CredentialsManager::new( + temp_file.path().to_string_lossy().to_string(), + )); + + let storage = KeychainTokenStorage::new(credentials_manager); + + // Create a test token + let token_info = TokenInfo { + access_token: Some("test_access_token".to_string()), + refresh_token: Some("test_refresh_token".to_string()), + expires_at: None, + id_token: None, + }; + + let scopes1 = &["https://www.googleapis.com/auth/drive.readonly"]; + let scopes2 = &["https://www.googleapis.com/auth/drive.file"]; + + // Store the token with scopes1 + storage.set(scopes1, token_info).await.unwrap(); + + // Try to retrieve with different scopes + let result = storage.get(scopes2).await; + assert!(result.is_none()); + } +} + diff --git a/crates/goose-mcp/src/google_drive/mod.rs b/crates/goose-mcp/src/google_drive/mod.rs index 1bd545ff2e25..f8ce75b6c3b9 100644 --- a/crates/goose-mcp/src/google_drive/mod.rs +++ b/crates/goose-mcp/src/google_drive/mod.rs @@ -1,10 +1,11 @@ +mod auth; mod token_storage; +use auth::{CredentialsManager, KeychainTokenStorage}; use indoc::indoc; use regex::Regex; use serde_json::{json, Value}; use std::{env, fs, future::Future, path::Path, pin::Pin, sync::Arc}; -use token_storage::{CredentialsManager, KeychainTokenStorage}; use mcp_core::content::Content; use mcp_core::{ @@ -134,22 +135,13 @@ impl GoogleDriveRouter { // Create a credentials manager for storing tokens securely let credentials_manager = Arc::new(CredentialsManager::new(credentials_path.clone())); - // Read the application secret from the OAuth keyfile + // Create custom token storage using our credentials manager + let token_storage = KeychainTokenStorage::new(credentials_manager.clone()); + let secret = yup_oauth2::read_application_secret(keyfile_path) .await .expect("expected keyfile for google auth"); - // Create custom token storage using our credentials manager - let token_storage = KeychainTokenStorage::new( - secret - .project_id - .clone() - .unwrap_or("unknown-project-id".to_string()) - .to_string(), - credentials_manager.clone(), - ); - - // Create the authenticator with the installed flow let auth = InstalledFlowAuthenticator::builder( secret, yup_oauth2::InstalledFlowReturnMethod::HTTPRedirect, From 3cdec02b3b623d07ead5648aeb2eb96e17440c98 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Mon, 10 Mar 2025 14:03:38 -0700 Subject: [PATCH 06/18] feat(google_drive): validate matching oauth config, use project_id to key secrets --- crates/goose-mcp/src/google_drive/auth.rs | 43 ++++++++++++++--------- crates/goose-mcp/src/google_drive/mod.rs | 13 +++++-- 2 files changed, 36 insertions(+), 20 deletions(-) diff --git a/crates/goose-mcp/src/google_drive/auth.rs b/crates/goose-mcp/src/google_drive/auth.rs index fa5065baa850..37e8ecdaae9f 100644 --- a/crates/goose-mcp/src/google_drive/auth.rs +++ b/crates/goose-mcp/src/google_drive/auth.rs @@ -95,13 +95,13 @@ impl CredentialsManager { } }; + // Fallback to writing on disk if we can't write to the keychain match entry.set_password(content) { Ok(_) => { debug!("Successfully wrote credentials to keychain"); Ok(()) } Err(e) => { - // Categorize errors - some might be critical and should not trigger fallback warn!( "Non-critical keychain error: {}, falling back to file system", e @@ -143,36 +143,46 @@ impl CredentialsManager { struct StorageEntry { token: TokenInfo, scopes: String, + project_id: String, } /// KeychainTokenStorage implements the TokenStorage trait from yup_oauth2 /// to enable secure storage of OAuth tokens in the system keychain. pub struct KeychainTokenStorage { + project_id: String, credentials_manager: Arc, } impl KeychainTokenStorage { /// Create a new KeychainTokenStorage with the given CredentialsManager - pub fn new(credentials_manager: Arc) -> Self { + pub fn new(project_id: String, credentials_manager: Arc) -> Self { Self { + project_id, credentials_manager, } } + + fn generate_scoped_key(&self, scopes: &[&str]) -> String { + // Create a key based on the scopes and project_id + let mut sorted_scopes = scopes.to_vec(); + sorted_scopes.sort(); + + sorted_scopes.join(" ") + } } #[async_trait::async_trait] impl TokenStorage for KeychainTokenStorage { /// Store a token in the keychain async fn set(&self, scopes: &[&str], token_info: TokenInfo) -> Result<()> { - debug!("Storing OAuth token in keychain for scopes: {:?}", scopes); - - // Create a key based on the scopes - let scope_key = scopes.join(" "); + let key = self.generate_scoped_key(scopes); + debug!("Storing OAuth token in keychain for scopes: {:?}", key); // Create a storage entry that includes the scopes let storage_entry = StorageEntry { token: token_info, - scopes: scope_key.clone(), + scopes: key, + project_id: self.project_id.clone(), }; let json = serde_json::to_string(&storage_entry)?; @@ -186,11 +196,8 @@ impl TokenStorage for KeychainTokenStorage { /// Retrieve a token from the keychain async fn get(&self, scopes: &[&str]) -> Option { - let scope_key = scopes.join(" "); - debug!( - "Retrieving OAuth token from keychain for scopes: {:?}", - scopes - ); + let key = self.generate_scoped_key(scopes); + debug!("Retrieving OAuth token from keychain for key: {:?}", key); match self.credentials_manager.read_credentials() { Ok(json) => { @@ -198,13 +205,14 @@ impl TokenStorage for KeychainTokenStorage { match serde_json::from_str::(&json) { Ok(entry) => { // Check if the stored token has the requested scopes - if entry.scopes == scope_key { + debug!("{} == {}", entry.project_id, self.project_id); + if entry.project_id == self.project_id && entry.scopes == key { debug!("Successfully retrieved OAuth token from storage"); Some(entry.token) } else { debug!( "Found token but scopes don't match. Stored: {}, Requested: {}", - entry.scopes, scope_key + entry.scopes, key ); None } @@ -258,11 +266,12 @@ mod tests { async fn test_token_storage_set_get() { // Create a temporary file for testing let temp_file = NamedTempFile::new().unwrap(); + let project_id = "test_project".to_string(); let credentials_manager = Arc::new(CredentialsManager::new( temp_file.path().to_string_lossy().to_string(), )); - let storage = KeychainTokenStorage::new(credentials_manager); + let storage = KeychainTokenStorage::new(project_id, credentials_manager); // Create a test token let token_info = TokenInfo { @@ -289,11 +298,12 @@ mod tests { async fn test_token_storage_scope_mismatch() { // Create a temporary file for testing let temp_file = NamedTempFile::new().unwrap(); + let project_id = "test_project".to_string(); let credentials_manager = Arc::new(CredentialsManager::new( temp_file.path().to_string_lossy().to_string(), )); - let storage = KeychainTokenStorage::new(credentials_manager); + let storage = KeychainTokenStorage::new(project_id, credentials_manager); // Create a test token let token_info = TokenInfo { @@ -314,4 +324,3 @@ mod tests { assert!(result.is_none()); } } - diff --git a/crates/goose-mcp/src/google_drive/mod.rs b/crates/goose-mcp/src/google_drive/mod.rs index f8ce75b6c3b9..b619c10a4de9 100644 --- a/crates/goose-mcp/src/google_drive/mod.rs +++ b/crates/goose-mcp/src/google_drive/mod.rs @@ -135,13 +135,20 @@ impl GoogleDriveRouter { // Create a credentials manager for storing tokens securely let credentials_manager = Arc::new(CredentialsManager::new(credentials_path.clone())); - // Create custom token storage using our credentials manager - let token_storage = KeychainTokenStorage::new(credentials_manager.clone()); - let secret = yup_oauth2::read_application_secret(keyfile_path) .await .expect("expected keyfile for google auth"); + let token_storage = KeychainTokenStorage::new( + secret + .project_id + .clone() + .unwrap_or("unknown-project-id".to_string()) + .to_string(), + credentials_manager.clone(), + ); + + // Create the authenticator with the installed flow let auth = InstalledFlowAuthenticator::builder( secret, yup_oauth2::InstalledFlowReturnMethod::HTTPRedirect, From 194bb23f2f5f04a759b9d8fc08cc566a610a305a Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Mon, 10 Mar 2025 19:34:11 -0700 Subject: [PATCH 07/18] feat(google_drive): add pkce support using oauth2 crate feat: use exchange_refresh_token method instead of manual impl --- Cargo.lock | 22 ++ crates/goose-mcp/Cargo.toml | 1 + crates/goose-mcp/src/google_drive/mod.rs | 98 ++---- .../goose-mcp/src/google_drive/oauth_pkce.rs | 309 ++++++++++++++++++ 4 files changed, 360 insertions(+), 70 deletions(-) create mode 100644 crates/goose-mcp/src/google_drive/oauth_pkce.rs diff --git a/Cargo.lock b/Cargo.lock index fbb41a55d0d6..3b3eb10d6ed0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2420,6 +2420,7 @@ dependencies = [ "lopdf", "mcp-core", "mcp-server", + "oauth2", "once_cell", "regex", "reqwest 0.11.27", @@ -3920,6 +3921,26 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" +[[package]] +name = "oauth2" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c38841cdd844847e3e7c8d29cef9dcfed8877f8f56f9071f77843ecf3baf937f" +dependencies = [ + "base64 0.13.1", + "chrono", + "getrandom 0.2.15", + "http 0.2.12", + "rand", + "reqwest 0.11.27", + "serde", + "serde_json", + "serde_path_to_error", + "sha2", + "thiserror 1.0.69", + "url", +] + [[package]] name = "objc" version = "0.2.7" @@ -6200,6 +6221,7 @@ dependencies = [ "form_urlencoded", "idna", "percent-encoding", + "serde", ] [[package]] diff --git a/crates/goose-mcp/Cargo.toml b/crates/goose-mcp/Cargo.toml index cd718134f9d9..5870c336ae76 100644 --- a/crates/goose-mcp/Cargo.toml +++ b/crates/goose-mcp/Cargo.toml @@ -44,6 +44,7 @@ docx-rs = "0.4.7" image = "0.24.9" umya-spreadsheet = "2.2.3" keyring = { version = "3.6.1", features = ["apple-native", "windows-native", "sync-secret-service"] } +oauth2 = { version = "4.4.2", features = ["reqwest"] } [dev-dependencies] serial_test = "3.0.0" diff --git a/crates/goose-mcp/src/google_drive/mod.rs b/crates/goose-mcp/src/google_drive/mod.rs index b619c10a4de9..00efa864d96b 100644 --- a/crates/goose-mcp/src/google_drive/mod.rs +++ b/crates/goose-mcp/src/google_drive/mod.rs @@ -1,11 +1,12 @@ -mod auth; +mod oauth_pkce; mod token_storage; -use auth::{CredentialsManager, KeychainTokenStorage}; use indoc::indoc; +use oauth_pkce::PkceOAuth2Client; use regex::Regex; use serde_json::{json, Value}; use std::{env, fs, future::Future, path::Path, pin::Pin, sync::Arc}; +use token_storage::CredentialsManager; use mcp_core::content::Content; use mcp_core::{ @@ -23,48 +24,11 @@ use google_drive3::{ api::{File, Scope}, hyper_rustls::{self, HttpsConnector}, hyper_util::{self, client::legacy::connect::HttpConnector}, - yup_oauth2::{ - self, - authenticator_delegate::{DefaultInstalledFlowDelegate, InstalledFlowDelegate}, - InstalledFlowAuthenticator, - }, DriveHub, }; use google_sheets4::{self, Sheets}; use http_body_util::BodyExt; -/// async function to be pinned by the `present_user_url` method of the trait -/// we use the existing `DefaultInstalledFlowDelegate::present_user_url` method as a fallback for -/// when the browser did not open for example, the user still see's the URL. -async fn browser_user_url(url: &str, need_code: bool) -> Result { - tracing::info!(oauth_url = url, "Attempting OAuth login flow"); - if let Err(e) = webbrowser::open(url) { - tracing::debug!(oauth_url = url, error = ?e, "Failed to open OAuth flow"); - println!("Please open this URL in your browser:\n{}", url); - } - let def_delegate = DefaultInstalledFlowDelegate; - def_delegate.present_user_url(url, need_code).await -} - -/// our custom delegate struct we will implement a flow delegate trait for: -/// in this case we will implement the `InstalledFlowDelegated` trait -#[derive(Copy, Clone)] -struct LocalhostBrowserDelegate; - -/// here we implement only the present_user_url method with the added webbrowser opening -/// the other behaviour of the trait does not need to be changed. -impl InstalledFlowDelegate for LocalhostBrowserDelegate { - /// the actual presenting of URL and browser opening happens in the function defined above here - /// we only pin it - fn present_user_url<'a>( - &'a self, - url: &'a str, - need_code: bool, - ) -> Pin> + Send + 'a>> { - Box::pin(browser_user_url(url, need_code)) - } -} - pub struct GoogleDriveRouter { tools: Vec, instructions: String, @@ -135,33 +99,17 @@ impl GoogleDriveRouter { // Create a credentials manager for storing tokens securely let credentials_manager = Arc::new(CredentialsManager::new(credentials_path.clone())); - let secret = yup_oauth2::read_application_secret(keyfile_path) - .await - .expect("expected keyfile for google auth"); - - let token_storage = KeychainTokenStorage::new( - secret - .project_id - .clone() - .unwrap_or("unknown-project-id".to_string()) - .to_string(), - credentials_manager.clone(), - ); - - // Create the authenticator with the installed flow - let auth = InstalledFlowAuthenticator::builder( - secret, - yup_oauth2::InstalledFlowReturnMethod::HTTPRedirect, - ) - .with_storage(Box::new(token_storage)) // Use our custom storage - .flow_delegate(Box::new(LocalhostBrowserDelegate)) - .build() - .await - .expect("expected successful authentication"); - - // Create the HTTP client - let client = - hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()) + // Read the OAuth credentials from the keyfile + match fs::read_to_string(keyfile_path) { + Ok(_) => { + // Create the PKCE OAuth2 client + let auth = PkceOAuth2Client::new(keyfile_path, credentials_manager.clone()) + .expect("Failed to create OAuth2 client"); + + // Create the HTTP client + let client = hyper_util::client::legacy::Client::builder( + hyper_util::rt::TokioExecutor::new(), + ) .build( hyper_rustls::HttpsConnectorBuilder::new() .with_native_roots() @@ -171,11 +119,21 @@ impl GoogleDriveRouter { .build(), ); - let drive_hub = DriveHub::new(client.clone(), auth.clone()); - let sheets_hub = Sheets::new(client, auth); + let drive_hub = DriveHub::new(client.clone(), auth.clone()); + let sheets_hub = Sheets::new(client, auth); - // Create and return the DriveHub - (drive_hub, sheets_hub, credentials_manager) + // Create and return the DriveHub + (drive_hub, sheets_hub, credentials_manager) + } + Err(e) => { + tracing::error!( + "Failed to read OAuth config from {}: {}", + keyfile_path.display(), + e + ); + panic!("Failed to read OAuth config: {}", e); + } + } } pub async fn new() -> Self { diff --git a/crates/goose-mcp/src/google_drive/oauth_pkce.rs b/crates/goose-mcp/src/google_drive/oauth_pkce.rs new file mode 100644 index 000000000000..2af5ad996fdb --- /dev/null +++ b/crates/goose-mcp/src/google_drive/oauth_pkce.rs @@ -0,0 +1,309 @@ +use std::error::Error; +use std::fs; +use std::future::Future; +use std::io::{BufRead, BufReader, Write}; +use std::net::TcpListener; +use std::path::Path; +use std::pin::Pin; +use std::sync::Arc; + +use google_drive3::common::GetToken; +use oauth2::basic::BasicClient; +use oauth2::{ + AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, RedirectUrl, + RefreshToken, Scope, TokenResponse, TokenUrl, +}; +use serde::{Deserialize, Serialize}; +use tracing::{debug, error, info}; +use url::Url; + +use crate::google_drive::token_storage::CredentialsManager; + +/// Structure representing the OAuth2 configuration file format +#[derive(Debug, Deserialize, Serialize)] +struct OAuth2Config { + installed: InstalledConfig, +} + +#[derive(Debug, Deserialize, Serialize)] +struct InstalledConfig { + client_id: String, + project_id: String, + auth_uri: String, + token_uri: String, + auth_provider_x509_cert_url: String, + client_secret: String, + redirect_uris: Vec, +} + +/// Structure for token storage +#[derive(Debug, Deserialize, Serialize)] +struct TokenData { + access_token: String, + refresh_token: String, + #[serde(skip_serializing_if = "Option::is_none")] + expires_at: Option, +} + +/// PkceOAuth2Client implements the GetToken trait required by DriveHub +/// It uses the oauth2 crate to implement a PKCE-enabled OAuth2 flow +#[derive(Clone)] +pub struct PkceOAuth2Client { + client: BasicClient, + credentials_manager: Arc, + refresh_token: Option, +} + +impl PkceOAuth2Client { + pub fn new( + config_path: impl AsRef, + credentials_manager: Arc, + ) -> Result> { + // Load and parse the config file + let config_content = fs::read_to_string(config_path)?; + let config: OAuth2Config = serde_json::from_str(&config_content)?; + + // Create OAuth URLs + let auth_url = + AuthUrl::new(config.installed.auth_uri).expect("Invalid authorization endpoint URL"); + let token_url = + TokenUrl::new(config.installed.token_uri).expect("Invalid token endpoint URL"); + + // Set up the OAuth2 client + let client = BasicClient::new( + ClientId::new(config.installed.client_id), + Some(ClientSecret::new(config.installed.client_secret)), + auth_url, + Some(token_url), + ) + .set_redirect_uri( + RedirectUrl::new("http://localhost:8080".to_string()).expect("Invalid redirect URL"), + ); + + // Try to load a refresh token from storage + let refresh_token = match credentials_manager.read_credentials() { + Ok(json) => match serde_json::from_str::(&json) { + Ok(token_data) => Some(token_data.refresh_token), + Err(e) => { + error!("Failed to parse stored credentials: {}", e); + None + } + }, + Err(e) => { + debug!("No stored credentials found or error reading them: {}", e); + None + } + }; + + Ok(Self { + client, + credentials_manager, + refresh_token, + }) + } + + async fn perform_oauth_flow( + &self, + scopes: &[&str], + ) -> Result> { + // Create a PKCE code verifier and challenge + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + + // Generate the authorization URL + let (auth_url, csrf_token) = self + .client + .authorize_url(CsrfToken::new_random) + .add_scopes(scopes.iter().map(|&s| Scope::new(s.to_string()))) + .set_pkce_challenge(pkce_challenge) + .url(); + + info!("Opening browser for OAuth2 authentication"); + if let Err(e) = webbrowser::open(auth_url.as_str()) { + error!("Failed to open browser: {}", e); + println!("Please open this URL in your browser:\n{}\n", auth_url); + } else { + println!( + "A browser window should have opened. If not, please open this URL:\n{}\n", + auth_url + ); + } + + // Start a local server to receive the authorization code + // We'll spawn this in a separate thread since it's blocking + let (tx, rx) = tokio::sync::oneshot::channel(); + std::thread::spawn(move || match Self::start_redirect_server() { + Ok(result) => { + let _ = tx.send(Ok(result)); + } + Err(e) => { + let _ = tx.send(Err(e)); + } + }); + + // Wait for the code from the redirect server + let (code, received_state) = rx.await??; + + // Verify the CSRF state + if received_state.secret() != csrf_token.secret() { + return Err("CSRF token mismatch".into()); + } + + // Use the built-in exchange_code method with PKCE verifier + let token_result = self + .client + .exchange_code(code) + .set_pkce_verifier(pkce_verifier) + .request_async(oauth2::reqwest::async_http_client) + .await + .map_err(|e| Box::new(e) as Box)?; + + let access_token = token_result.access_token().secret().clone(); + + // Store the refresh token for future use if available + if let Some(refresh_token) = token_result.refresh_token() { + let token_data = TokenData { + access_token: access_token.clone(), + refresh_token: refresh_token.secret().clone(), + expires_at: token_result.expires_in().map(|d| d.as_secs()), + }; + + if let Err(e) = self + .credentials_manager + .write_credentials(&serde_json::to_string(&token_data)?) + { + error!("Failed to store refresh token: {}", e); + } else { + debug!("Successfully stored refresh token"); + } + } + + Ok(access_token) + } + + async fn refresh_token( + &self, + refresh_token: &str, + ) -> Result> { + debug!("Attempting to refresh access token"); + + // Create a RefreshToken from the string + let refresh_token = RefreshToken::new(refresh_token.to_string()); + + // Use the built-in exchange_refresh_token method + let token_result = self + .client + .exchange_refresh_token(&refresh_token) + .request_async(oauth2::reqwest::async_http_client) + .await + .map_err(|e| Box::new(e) as Box)?; + + let access_token = token_result.access_token().secret().clone(); + + // Update the stored refresh token if a new one was provided + if let Some(new_refresh_token) = token_result.refresh_token() { + let token_data = TokenData { + access_token: access_token.clone(), + refresh_token: new_refresh_token.secret().clone(), + expires_at: token_result.expires_in().map(|d| d.as_secs()), + }; + + if let Err(e) = self + .credentials_manager + .write_credentials(&serde_json::to_string(&token_data)?) + { + error!("Failed to update refresh token: {}", e); + } else { + debug!("Successfully updated refresh token"); + } + } + + Ok(access_token) + } + + fn start_redirect_server( + ) -> Result<(AuthorizationCode, CsrfToken), Box> { + let listener = TcpListener::bind("127.0.0.1:8080")?; + println!("Listening for the authorization code on http://localhost:8080"); + + for stream in listener.incoming() { + match stream { + Ok(mut stream) => { + let mut reader = BufReader::new(&stream); + let mut request_line = String::new(); + reader.read_line(&mut request_line)?; + + let redirect_url = request_line + .split_whitespace() + .nth(1) + .ok_or("Invalid request")?; + + let url = Url::parse(&format!("http://localhost{}", redirect_url))?; + + let code = url + .query_pairs() + .find(|(key, _)| key == "code") + .map(|(_, value)| AuthorizationCode::new(value.into_owned())) + .ok_or("No code found in the response")?; + + let state = url + .query_pairs() + .find(|(key, _)| key == "state") + .map(|(_, value)| CsrfToken::new(value.into_owned())) + .ok_or("No state found in the response")?; + + // Send a success response to the browser + let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n\ +

Authentication successful!

\ +

You can now close this window and return to the application.

"; + + stream.write_all(response.as_bytes())?; + stream.flush()?; + + return Ok((code, state)); + } + Err(e) => { + error!("Failed to accept connection: {}", e); + } + } + } + + Err("Failed to receive authorization code".into()) + } +} + +impl GetToken for PkceOAuth2Client { + fn get_token<'a>( + &'a self, + scopes: &'a [&str], + ) -> Pin< + Box, Box>> + Send + 'a>, + > { + Box::pin(async move { + // Try to use refresh token if available + if let Some(refresh_token) = &self.refresh_token { + match self.refresh_token(refresh_token).await { + Ok(access_token) => { + debug!("Successfully refreshed access token"); + return Ok(Some(access_token)); + } + Err(e) => { + error!("Failed to refresh token: {}", e); + // Fall through to interactive flow + } + } + } + + // If refresh failed or no refresh token, do interactive flow + match self.perform_oauth_flow(scopes).await { + Ok(token) => { + debug!("Successfully obtained new access token"); + Ok(Some(token)) + } + Err(e) => { + error!("OAuth flow failed: {}", e); + Err(e) + } + } + }) + } +} From 303c9b1c5ac29f03ab59818c49ccc68ad9802689 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Mon, 10 Mar 2025 21:48:31 -0700 Subject: [PATCH 08/18] fix(google_drive): fix OAuth token persistence use interior mutability using Arc> for refresh_token ensure credentials persist between requests and preventing repeated authentication prompts after successful authorization --- .../goose-mcp/src/google_drive/oauth_pkce.rs | 74 +++++++++++++++++-- 1 file changed, 69 insertions(+), 5 deletions(-) diff --git a/crates/goose-mcp/src/google_drive/oauth_pkce.rs b/crates/goose-mcp/src/google_drive/oauth_pkce.rs index 2af5ad996fdb..a949494bd725 100644 --- a/crates/goose-mcp/src/google_drive/oauth_pkce.rs +++ b/crates/goose-mcp/src/google_drive/oauth_pkce.rs @@ -45,13 +45,15 @@ struct TokenData { expires_at: Option, } +use std::sync::Mutex; + /// PkceOAuth2Client implements the GetToken trait required by DriveHub /// It uses the oauth2 crate to implement a PKCE-enabled OAuth2 flow #[derive(Clone)] pub struct PkceOAuth2Client { client: BasicClient, credentials_manager: Arc, - refresh_token: Option, + refresh_token: Arc>>, } impl PkceOAuth2Client { @@ -98,7 +100,7 @@ impl PkceOAuth2Client { Ok(Self { client, credentials_manager, - refresh_token, + refresh_token: Arc::new(Mutex::new(refresh_token)), }) } @@ -167,6 +169,14 @@ impl PkceOAuth2Client { expires_at: token_result.expires_in().map(|d| d.as_secs()), }; + // Update the in-memory refresh token using the Mutex + if let Ok(mut token_guard) = self.refresh_token.lock() { + *token_guard = Some(refresh_token.secret().clone()); + debug!("Successfully updated in-memory refresh token"); + } else { + error!("Failed to acquire lock on refresh token"); + } + if let Err(e) = self .credentials_manager .write_credentials(&serde_json::to_string(&token_data)?) @@ -207,6 +217,14 @@ impl PkceOAuth2Client { expires_at: token_result.expires_in().map(|d| d.as_secs()), }; + // Update the in-memory refresh token using the Mutex + if let Ok(mut token_guard) = self.refresh_token.lock() { + *token_guard = Some(new_refresh_token.secret().clone()); + debug!("Successfully updated in-memory refresh token during refresh"); + } else { + error!("Failed to acquire lock on refresh token during refresh"); + } + if let Err(e) = self .credentials_manager .write_credentials(&serde_json::to_string(&token_data)?) @@ -279,17 +297,63 @@ impl GetToken for PkceOAuth2Client { Box, Box>> + Send + 'a>, > { Box::pin(async move { - // Try to use refresh token if available - if let Some(refresh_token) = &self.refresh_token { - match self.refresh_token(refresh_token).await { + // Try to use refresh token if available in memory + let refresh_token_option = if let Ok(token_guard) = self.refresh_token.lock() { + token_guard.clone() + } else { + error!("Failed to acquire lock on refresh token"); + None + }; + + if let Some(refresh_token) = refresh_token_option { + debug!("Found refresh token in memory, attempting to use it"); + match self.refresh_token(&refresh_token).await { Ok(access_token) => { debug!("Successfully refreshed access token"); return Ok(Some(access_token)); } Err(e) => { error!("Failed to refresh token: {}", e); + // Fall through to check storage + } + } + } else { + debug!("No refresh token available in memory, checking storage"); + } + + // Try to load from storage as a fallback if in-memory token failed or wasn't available + match self.credentials_manager.read_credentials() { + Ok(json) => match serde_json::from_str::(&json) { + Ok(token_data) => { + debug!("Found token in storage, attempting to use it"); + + // Update the in-memory refresh token + if let Ok(mut token_guard) = self.refresh_token.lock() { + *token_guard = Some(token_data.refresh_token.clone()); + debug!("Updated in-memory refresh token from storage"); + } else { + error!("Failed to acquire lock to update refresh token from storage"); + } + + match self.refresh_token(&token_data.refresh_token).await { + Ok(access_token) => { + debug!("Successfully refreshed access token from storage"); + return Ok(Some(access_token)); + } + Err(e) => { + error!("Failed to refresh token from storage: {}", e); + // Fall through to interactive flow + } + } + } + Err(e) => { + error!("Failed to parse stored credentials: {}", e); // Fall through to interactive flow } + }, + Err(e) => { + debug!("No stored credentials found or error reading them: {}", e); + // Fall through to interactive flow } } From 2af02fcbca46af3f33e4b24e4cf5fd93a64c1378 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Tue, 11 Mar 2025 10:08:14 -0700 Subject: [PATCH 09/18] feat(google_drive): update oauth2 crate to 5.0.0 style: refactor for readability and comments --- Cargo.lock | 10 +- crates/goose-mcp/Cargo.toml | 2 +- .../goose-mcp/src/google_drive/oauth_pkce.rs | 217 ++++++++---------- 3 files changed, 101 insertions(+), 128 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3b3eb10d6ed0..0f5194c654a1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3923,16 +3923,16 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" [[package]] name = "oauth2" -version = "4.4.2" +version = "5.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c38841cdd844847e3e7c8d29cef9dcfed8877f8f56f9071f77843ecf3baf937f" +checksum = "51e219e79014df21a225b1860a479e2dcd7cbd9130f4defd4bd0e191ea31d67d" dependencies = [ - "base64 0.13.1", + "base64 0.22.1", "chrono", "getrandom 0.2.15", - "http 0.2.12", + "http 1.2.0", "rand", - "reqwest 0.11.27", + "reqwest 0.12.12", "serde", "serde_json", "serde_path_to_error", diff --git a/crates/goose-mcp/Cargo.toml b/crates/goose-mcp/Cargo.toml index 5870c336ae76..7e667578af8e 100644 --- a/crates/goose-mcp/Cargo.toml +++ b/crates/goose-mcp/Cargo.toml @@ -44,7 +44,7 @@ docx-rs = "0.4.7" image = "0.24.9" umya-spreadsheet = "2.2.3" keyring = { version = "3.6.1", features = ["apple-native", "windows-native", "sync-secret-service"] } -oauth2 = { version = "4.4.2", features = ["reqwest"] } +oauth2 = { version = "5.0.0", features = ["reqwest"] } [dev-dependencies] serial_test = "3.0.0" diff --git a/crates/goose-mcp/src/google_drive/oauth_pkce.rs b/crates/goose-mcp/src/google_drive/oauth_pkce.rs index a949494bd725..ab8f361f66d6 100644 --- a/crates/goose-mcp/src/google_drive/oauth_pkce.rs +++ b/crates/goose-mcp/src/google_drive/oauth_pkce.rs @@ -9,9 +9,10 @@ use std::sync::Arc; use google_drive3::common::GetToken; use oauth2::basic::BasicClient; +use oauth2::reqwest; use oauth2::{ - AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, RedirectUrl, - RefreshToken, Scope, TokenResponse, TokenUrl, + AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EndpointNotSet, EndpointSet, + PkceCodeChallenge, RedirectUrl, RefreshToken, Scope, TokenResponse, TokenUrl, }; use serde::{Deserialize, Serialize}; use tracing::{debug, error, info}; @@ -51,9 +52,10 @@ use std::sync::Mutex; /// It uses the oauth2 crate to implement a PKCE-enabled OAuth2 flow #[derive(Clone)] pub struct PkceOAuth2Client { - client: BasicClient, + client: BasicClient, credentials_manager: Arc, refresh_token: Arc>>, + http_client: reqwest::Client, } impl PkceOAuth2Client { @@ -72,35 +74,38 @@ impl PkceOAuth2Client { TokenUrl::new(config.installed.token_uri).expect("Invalid token endpoint URL"); // Set up the OAuth2 client - let client = BasicClient::new( - ClientId::new(config.installed.client_id), - Some(ClientSecret::new(config.installed.client_secret)), - auth_url, - Some(token_url), - ) - .set_redirect_uri( - RedirectUrl::new("http://localhost:8080".to_string()).expect("Invalid redirect URL"), - ); + let client = BasicClient::new(ClientId::new(config.installed.client_id)) + .set_client_secret(ClientSecret::new(config.installed.client_secret)) + .set_auth_uri(auth_url) + .set_token_uri(token_url) + .set_redirect_uri( + RedirectUrl::new("http://localhost:8080".to_string()) + .expect("Invalid redirect URL"), + ); // Try to load a refresh token from storage - let refresh_token = match credentials_manager.read_credentials() { - Ok(json) => match serde_json::from_str::(&json) { - Ok(token_data) => Some(token_data.refresh_token), - Err(e) => { - error!("Failed to parse stored credentials: {}", e); - None - } - }, - Err(e) => { - debug!("No stored credentials found or error reading them: {}", e); - None - } - }; + let refresh_token = credentials_manager + .read_credentials() + .inspect_err(|e| debug!("No stored credentials found or error reading them: {}", e)) + .ok() + .and_then(|json| { + serde_json::from_str::(&json) + .inspect_err(|e| error!("Failed to parse stored credentials: {}", e)) + .ok() + .map(|token_data| token_data.refresh_token) + }); + + let http_client = reqwest::ClientBuilder::new() + // Following redirects opens the client up to SSRF vulnerabilities. + .redirect(reqwest::redirect::Policy::none()) + .build() + .expect("Oauth2 HTTP Client should build"); Ok(Self { client, credentials_manager, refresh_token: Arc::new(Mutex::new(refresh_token)), + http_client, }) } @@ -123,11 +128,6 @@ impl PkceOAuth2Client { if let Err(e) = webbrowser::open(auth_url.as_str()) { error!("Failed to open browser: {}", e); println!("Please open this URL in your browser:\n{}\n", auth_url); - } else { - println!( - "A browser window should have opened. If not, please open this URL:\n{}\n", - auth_url - ); } // Start a local server to receive the authorization code @@ -155,13 +155,14 @@ impl PkceOAuth2Client { .client .exchange_code(code) .set_pkce_verifier(pkce_verifier) - .request_async(oauth2::reqwest::async_http_client) + .request_async(&self.http_client) .await .map_err(|e| Box::new(e) as Box)?; let access_token = token_result.access_token().secret().clone(); - // Store the refresh token for future use if available + // Update the stored refresh token if a new one was provided + // not all authorization servers return a new refresh token if let Some(refresh_token) = token_result.refresh_token() { let token_data = TokenData { access_token: access_token.clone(), @@ -169,22 +170,19 @@ impl PkceOAuth2Client { expires_at: token_result.expires_in().map(|d| d.as_secs()), }; - // Update the in-memory refresh token using the Mutex - if let Ok(mut token_guard) = self.refresh_token.lock() { - *token_guard = Some(refresh_token.secret().clone()); - debug!("Successfully updated in-memory refresh token"); - } else { - error!("Failed to acquire lock on refresh token"); - } - - if let Err(e) = self - .credentials_manager - .write_credentials(&serde_json::to_string(&token_data)?) - { - error!("Failed to store refresh token: {}", e); - } else { - debug!("Successfully stored refresh token"); - } + self.refresh_token + .lock() + .map(|mut token_guard| { + *token_guard = Some(refresh_token.secret().clone()); + debug!("Successfully updated in-memory refresh token"); + }) + .unwrap_or_else(|_| error!("Failed to acquire lock on refresh token")); + + serde_json::to_string(&token_data) + .map_err(Into::into) // Convert serde_json::Error to AuthError + .and_then(|data| self.credentials_manager.write_credentials(&data)) + .map(|_| debug!("Successfully stored refresh token")) + .unwrap_or_else(|e| error!("Failed to store refresh token: {}", e)); } Ok(access_token) @@ -203,36 +201,34 @@ impl PkceOAuth2Client { let token_result = self .client .exchange_refresh_token(&refresh_token) - .request_async(oauth2::reqwest::async_http_client) + .request_async(&self.http_client) .await .map_err(|e| Box::new(e) as Box)?; let access_token = token_result.access_token().secret().clone(); // Update the stored refresh token if a new one was provided - if let Some(new_refresh_token) = token_result.refresh_token() { + // not all authorization servers return a new refresh token + if let Some(refresh_token) = token_result.refresh_token() { let token_data = TokenData { access_token: access_token.clone(), - refresh_token: new_refresh_token.secret().clone(), + refresh_token: refresh_token.secret().clone(), expires_at: token_result.expires_in().map(|d| d.as_secs()), }; - // Update the in-memory refresh token using the Mutex - if let Ok(mut token_guard) = self.refresh_token.lock() { - *token_guard = Some(new_refresh_token.secret().clone()); - debug!("Successfully updated in-memory refresh token during refresh"); - } else { - error!("Failed to acquire lock on refresh token during refresh"); - } - - if let Err(e) = self - .credentials_manager - .write_credentials(&serde_json::to_string(&token_data)?) - { - error!("Failed to update refresh token: {}", e); - } else { - debug!("Successfully updated refresh token"); - } + self.refresh_token + .lock() + .map(|mut token_guard| { + *token_guard = Some(refresh_token.secret().clone()); + debug!("Successfully updated in-memory refresh token"); + }) + .unwrap_or_else(|_| error!("Failed to acquire lock on refresh token")); + + serde_json::to_string(&token_data) + .map_err(Into::into) // Convert serde_json::Error to AuthError + .and_then(|data| self.credentials_manager.write_credentials(&data)) + .map(|_| debug!("Successfully stored refresh token")) + .unwrap_or_else(|e| error!("Failed to store refresh token: {}", e)); } Ok(access_token) @@ -297,70 +293,47 @@ impl GetToken for PkceOAuth2Client { Box, Box>> + Send + 'a>, > { Box::pin(async move { - // Try to use refresh token if available in memory - let refresh_token_option = if let Ok(token_guard) = self.refresh_token.lock() { - token_guard.clone() - } else { - error!("Failed to acquire lock on refresh token"); - None - }; - - if let Some(refresh_token) = refresh_token_option { - debug!("Found refresh token in memory, attempting to use it"); - match self.refresh_token(&refresh_token).await { - Ok(access_token) => { - debug!("Successfully refreshed access token"); - return Ok(Some(access_token)); - } - Err(e) => { - error!("Failed to refresh token: {}", e); - // Fall through to check storage - } + // Attempt to get token from memory + let token_from_memory = self + .refresh_token + .lock() + .ok() + .and_then(|guard| guard.clone()); + + // In error cases we just fall through to checking storage + if let Some(ref token) = token_from_memory { + if let Ok(access_token) = self.refresh_token(token).await { + debug!("Successfully refreshed access token from memory"); + return Ok(Some(access_token)); } - } else { - debug!("No refresh token available in memory, checking storage"); } - // Try to load from storage as a fallback if in-memory token failed or wasn't available - match self.credentials_manager.read_credentials() { - Ok(json) => match serde_json::from_str::(&json) { - Ok(token_data) => { - debug!("Found token in storage, attempting to use it"); - - // Update the in-memory refresh token - if let Ok(mut token_guard) = self.refresh_token.lock() { - *token_guard = Some(token_data.refresh_token.clone()); - debug!("Updated in-memory refresh token from storage"); - } else { - error!("Failed to acquire lock to update refresh token from storage"); - } - - match self.refresh_token(&token_data.refresh_token).await { - Ok(access_token) => { - debug!("Successfully refreshed access token from storage"); - return Ok(Some(access_token)); - } - Err(e) => { - error!("Failed to refresh token from storage: {}", e); - // Fall through to interactive flow - } - } - } - Err(e) => { - error!("Failed to parse stored credentials: {}", e); - // Fall through to interactive flow + // Attempt to read token from storage and update in-memory cache + let token_from_storage = self + .credentials_manager + .read_credentials() + .ok() + .and_then(|json| serde_json::from_str::(&json).ok()) + .map(|token_data| { + if let Ok(mut token_guard) = self.refresh_token.lock() { + *token_guard = Some(token_data.refresh_token.clone()); + debug!("Updated in-memory refresh token from storage"); } - }, - Err(e) => { - debug!("No stored credentials found or error reading them: {}", e); - // Fall through to interactive flow + token_data.refresh_token + }); + + // If we fail to use the refresh token here, fall through to full OAuth flow + if let Some(ref token) = token_from_storage { + if let Ok(access_token) = self.refresh_token(token).await { + debug!("Successfully refreshed access token from storage"); + return Ok(Some(access_token)); } } - // If refresh failed or no refresh token, do interactive flow + // Fallback: perform interactive OAuth flow match self.perform_oauth_flow(scopes).await { Ok(token) => { - debug!("Successfully obtained new access token"); + debug!("Successfully obtained new access token through OAuth flow"); Ok(Some(token)) } Err(e) => { From aad940cb4236ffbefbac95656cc4a30ec29fb8d9 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Tue, 11 Mar 2025 11:28:30 -0700 Subject: [PATCH 10/18] feat: make token_storage generic --- .../goose-mcp/src/google_drive/oauth_pkce.rs | 23 +- .../src/google_drive/token_storage.rs | 281 +++++++----------- 2 files changed, 110 insertions(+), 194 deletions(-) diff --git a/crates/goose-mcp/src/google_drive/oauth_pkce.rs b/crates/goose-mcp/src/google_drive/oauth_pkce.rs index ab8f361f66d6..60ea1f1d5f2d 100644 --- a/crates/goose-mcp/src/google_drive/oauth_pkce.rs +++ b/crates/goose-mcp/src/google_drive/oauth_pkce.rs @@ -85,15 +85,10 @@ impl PkceOAuth2Client { // Try to load a refresh token from storage let refresh_token = credentials_manager - .read_credentials() + .read_credentials::() .inspect_err(|e| debug!("No stored credentials found or error reading them: {}", e)) .ok() - .and_then(|json| { - serde_json::from_str::(&json) - .inspect_err(|e| error!("Failed to parse stored credentials: {}", e)) - .ok() - .map(|token_data| token_data.refresh_token) - }); + .map(|token_data| token_data.refresh_token); let http_client = reqwest::ClientBuilder::new() // Following redirects opens the client up to SSRF vulnerabilities. @@ -178,9 +173,8 @@ impl PkceOAuth2Client { }) .unwrap_or_else(|_| error!("Failed to acquire lock on refresh token")); - serde_json::to_string(&token_data) - .map_err(Into::into) // Convert serde_json::Error to AuthError - .and_then(|data| self.credentials_manager.write_credentials(&data)) + self.credentials_manager + .write_credentials(&token_data) .map(|_| debug!("Successfully stored refresh token")) .unwrap_or_else(|e| error!("Failed to store refresh token: {}", e)); } @@ -224,9 +218,8 @@ impl PkceOAuth2Client { }) .unwrap_or_else(|_| error!("Failed to acquire lock on refresh token")); - serde_json::to_string(&token_data) - .map_err(Into::into) // Convert serde_json::Error to AuthError - .and_then(|data| self.credentials_manager.write_credentials(&data)) + self.credentials_manager + .write_credentials(&token_data) .map(|_| debug!("Successfully stored refresh token")) .unwrap_or_else(|e| error!("Failed to store refresh token: {}", e)); } @@ -311,9 +304,8 @@ impl GetToken for PkceOAuth2Client { // Attempt to read token from storage and update in-memory cache let token_from_storage = self .credentials_manager - .read_credentials() + .read_credentials::() .ok() - .and_then(|json| serde_json::from_str::(&json).ok()) .map(|token_data| { if let Ok(mut token_guard) = self.refresh_token.lock() { *token_guard = Some(token_data.refresh_token.clone()); @@ -344,3 +336,4 @@ impl GetToken for PkceOAuth2Client { }) } } + diff --git a/crates/goose-mcp/src/google_drive/token_storage.rs b/crates/goose-mcp/src/google_drive/token_storage.rs index f40ab4baab76..e5a6408d7a6b 100644 --- a/crates/goose-mcp/src/google_drive/token_storage.rs +++ b/crates/goose-mcp/src/google_drive/token_storage.rs @@ -1,10 +1,9 @@ use anyhow::Result; -use google_drive3::yup_oauth2::storage::{TokenInfo, TokenStorage}; use keyring::Entry; +use serde::{de::DeserializeOwned, Serialize}; use std::env; use std::fs; use std::path::Path; -use std::sync::Arc; use thiserror::Error; use tracing::{debug, error, warn}; @@ -14,7 +13,7 @@ const KEYCHAIN_DISK_FALLBACK_ENV: &str = "GOOGLE_DRIVE_DISK_FALLBACK"; #[allow(dead_code)] #[derive(Error, Debug)] -pub enum AuthError { +pub enum StorageError { #[error("Failed to access keychain: {0}")] KeyringError(#[from] keyring::Error), #[error("Failed to access file system: {0}")] @@ -49,8 +48,44 @@ impl CredentialsManager { } } - pub fn read_credentials(&self) -> Result { - Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME) + /// Reads and deserializes credentials from secure storage. + /// + /// This method attempts to read credentials from the system keychain first. + /// If keychain access fails and fallback is enabled, it will try to read from the file system. + /// + /// # Type Parameters + /// + /// * `T` - The type to deserialize the credentials into. Must implement `serde::de::DeserializeOwned`. + /// + /// # Returns + /// + /// * `Ok(T)` - The deserialized credentials + /// * `Err(StorageError)` - If reading or deserialization fails + /// + /// # Examples + /// + /// ``` + /// # use goose_mcp::google_drive::token_storage::CredentialsManager; + /// use serde::{Serialize, Deserialize}; + /// + /// #[derive(Serialize, Deserialize)] + /// struct OAuthToken { + /// access_token: String, + /// refresh_token: String, + /// expiry: u64, + /// } + /// + /// let manager = CredentialsManager::new(String::from("/path/to/credentials.json")); + /// match manager.read_credentials::() { + /// Ok(token) => println!("Token expires at: {}", token.expiry), + /// Err(e) => eprintln!("Failed to read token: {}", e), + /// } + /// ``` + pub fn read_credentials(&self) -> Result + where + T: DeserializeOwned, + { + let json_str = Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME) .and_then(|entry| entry.get_password()) .inspect(|_| { debug!("Successfully read credentials from keychain"); @@ -61,14 +96,16 @@ impl CredentialsManager { self.read_from_file() } else { match e { - keyring::Error::NoEntry => Err(AuthError::NotFound), - _ => Err(AuthError::KeyringError(e)), + keyring::Error::NoEntry => Err(StorageError::NotFound), + _ => Err(StorageError::KeyringError(e)), } } - }) + })?; + + serde_json::from_str(&json_str).map_err(StorageError::SerializationError) } - fn read_from_file(&self) -> Result { + fn read_from_file(&self) -> Result { let path = Path::new(&self.credentials_path); if path.exists() { match fs::read_to_string(path) { @@ -78,32 +115,79 @@ impl CredentialsManager { } Err(e) => { error!("Failed to read credentials file: {}", e); - Err(AuthError::FileSystemError(e)) + Err(StorageError::FileSystemError(e)) } } } else { debug!("No credentials found in file system"); - Err(AuthError::NotFound) + Err(StorageError::NotFound) } } - pub fn write_credentials(&self, content: &str) -> Result<(), AuthError> { + /// Serializes and writes credentials to secure storage. + /// + /// This method attempts to write credentials to the system keychain first. + /// If keychain access fails and fallback is enabled, it will try to write to the file system. + /// + /// # Type Parameters + /// + /// * `T` - The type to serialize. Must implement `serde::Serialize`. + /// + /// # Parameters + /// + /// * `content` - The data to serialize and store + /// + /// # Returns + /// + /// * `Ok(())` - If writing succeeds + /// * `Err(StorageError)` - If serialization or writing fails + /// + /// # Examples + /// + /// ``` + /// # use goose_mcp::google_drive::token_storage::CredentialsManager; + /// use serde::{Serialize, Deserialize}; + /// + /// #[derive(Serialize, Deserialize)] + /// struct OAuthToken { + /// access_token: String, + /// refresh_token: String, + /// expiry: u64, + /// } + /// + /// let token = OAuthToken { + /// access_token: String::from("access_token_value"), + /// refresh_token: String::from("refresh_token_value"), + /// expiry: 1672531200, // Unix timestamp + /// }; + /// + /// let manager = CredentialsManager::new(String::from("/path/to/credentials.json")); + /// if let Err(e) = manager.write_credentials(&token) { + /// eprintln!("Failed to write token: {}", e); + /// } + /// ``` + pub fn write_credentials(&self, content: &T) -> Result<(), StorageError> + where + T: Serialize, + { + let json_str = serde_json::to_string(content).map_err(StorageError::SerializationError)?; + Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME) - .and_then(|entry| entry.set_password(content)) + .and_then(|entry| entry.set_password(&json_str)) .inspect(|_| { debug!("Successfully wrote credentials to keychain"); }) .or_else(|e| { if self.fallback_to_disk { warn!("Falling back to file system due to keyring error: {}", e); - self.write_to_file(content) + self.write_to_file(&json_str) } else { - Err(AuthError::KeyringError(e)) + Err(StorageError::KeyringError(e)) } }) } - fn write_to_file(&self, content: &str) -> Result<(), AuthError> { + fn write_to_file(&self, content: &str) -> Result<(), StorageError> { let path = Path::new(&self.credentials_path); if let Some(parent) = path.parent() { if !parent.exists() { @@ -111,7 +195,7 @@ impl CredentialsManager { Ok(_) => debug!("Created parent directories for credentials file"), Err(e) => { error!("Failed to create directories for credentials file: {}", e); - return Err(AuthError::FileSystemError(e)); + return Err(StorageError::FileSystemError(e)); } } } @@ -124,97 +208,7 @@ impl CredentialsManager { } Err(e) => { error!("Failed to write credentials to file system: {}", e); - Err(AuthError::FileSystemError(e)) - } - } - } -} - -/// Storage entry that includes the token, scopes and project it's valid for -#[derive(serde::Serialize, serde::Deserialize)] -struct StorageEntry { - token: TokenInfo, - scopes: String, - project_id: String, -} - -/// KeychainTokenStorage implements the TokenStorage trait from yup_oauth2 -/// to enable secure storage of OAuth tokens in the system keychain. -pub struct KeychainTokenStorage { - project_id: String, - credentials_manager: Arc, -} - -impl KeychainTokenStorage { - /// Create a new KeychainTokenStorage with the given CredentialsManager - pub fn new(project_id: String, credentials_manager: Arc) -> Self { - Self { - project_id, - credentials_manager, - } - } - - fn generate_scoped_key(&self, scopes: &[&str]) -> String { - // Create a key based on the scopes and project_id - // Sort so we can be consistent using scopes as the key - let mut sorted_scopes = scopes.to_vec(); - sorted_scopes.sort(); - sorted_scopes.join(" ") - } -} - -#[async_trait::async_trait] -impl TokenStorage for KeychainTokenStorage { - /// Store a token in the keychain - async fn set(&self, scopes: &[&str], token_info: TokenInfo) -> Result<()> { - let key = self.generate_scoped_key(scopes); - - // Create a storage entry that includes the scopes - let storage_entry = StorageEntry { - token: token_info, - scopes: key, - project_id: self.project_id.clone(), - }; - - let json = serde_json::to_string(&storage_entry)?; - self.credentials_manager - .write_credentials(&json) - .map_err(|e| { - error!("Failed to write token to keychain: {}", e); - anyhow::anyhow!("Failed to write token to keychain: {}", e) - }) - } - - /// Retrieve a token from the keychain - async fn get(&self, scopes: &[&str]) -> Option { - let key = self.generate_scoped_key(scopes); - - match self.credentials_manager.read_credentials() { - Ok(json) => { - debug!("Successfully read credentials from storage"); - match serde_json::from_str::(&json) { - Ok(entry) => { - // Check if token has the requested scopes and matches the project_id - if entry.project_id == self.project_id && entry.scopes == key { - debug!("Successfully retrieved OAuth token from storage"); - Some(entry.token) - } else { - None - } - } - Err(e) => { - warn!("Failed to deserialize token from storage: {}", e); - None - } - } - } - Err(AuthError::NotFound) => { - debug!("No OAuth token found in storage"); - None - } - Err(e) => { - warn!("Error reading OAuth token from storage: {}", e); - None + Err(StorageError::FileSystemError(e)) } } } @@ -228,74 +222,3 @@ impl Clone for CredentialsManager { } } } - -#[cfg(test)] -mod tests { - use super::*; - use serial_test::serial; - use tempfile::NamedTempFile; - - #[tokio::test] - #[serial] - async fn test_token_storage_set_get() { - // Create a temporary file for testing - let temp_file = NamedTempFile::new().unwrap(); - let project_id = "test_project_1".to_string(); - let credentials_manager = Arc::new(CredentialsManager::new( - temp_file.path().to_string_lossy().to_string(), - )); - - let storage = KeychainTokenStorage::new(project_id, credentials_manager); - - // Create a test token - let token_info = TokenInfo { - access_token: Some("test_access_token".to_string()), - refresh_token: Some("test_refresh_token".to_string()), - expires_at: None, - id_token: None, - }; - - let scopes = &["https://www.googleapis.com/auth/drive.readonly"]; - - // Store the token - storage.set(scopes, token_info.clone()).await.unwrap(); - - // Retrieve the token - let retrieved = storage.get(scopes).await.unwrap(); - - // Verify the token matches - assert_eq!(retrieved.access_token, token_info.access_token); - assert_eq!(retrieved.refresh_token, token_info.refresh_token); - } - - #[tokio::test] - #[serial] - async fn test_token_storage_scope_mismatch() { - // Create a temporary file for testing - let temp_file = NamedTempFile::new().unwrap(); - let project_id = "test_project_2".to_string(); - let credentials_manager = Arc::new(CredentialsManager::new( - temp_file.path().to_string_lossy().to_string(), - )); - - let storage = KeychainTokenStorage::new(project_id, credentials_manager); - - // Create a test token - let token_info = TokenInfo { - access_token: Some("test_access_token".to_string()), - refresh_token: Some("test_refresh_token".to_string()), - expires_at: None, - id_token: None, - }; - - let scopes1 = &["https://www.googleapis.com/auth/drive.readonly"]; - let scopes2 = &["https://www.googleapis.com/auth/drive.file"]; - - // Store the token with scopes1 - storage.set(scopes1, token_info).await.unwrap(); - - // Try to retrieve with different scopes - let result = storage.get(scopes2).await; - assert!(result.is_none()); - } -} From 673f92b6ac9ad22d82acf96b31154248b6101a48 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Tue, 11 Mar 2025 11:53:11 -0700 Subject: [PATCH 11/18] feat(google_drive): make CredentialsManager generic over T: Serialize rename token_storage to storage since it is generic enough --- crates/goose-mcp/src/google_drive/mod.rs | 4 +- .../goose-mcp/src/google_drive/oauth_pkce.rs | 53 ++++++++++++++----- .../{token_storage.rs => storage.rs} | 0 3 files changed, 43 insertions(+), 14 deletions(-) rename crates/goose-mcp/src/google_drive/{token_storage.rs => storage.rs} (100%) diff --git a/crates/goose-mcp/src/google_drive/mod.rs b/crates/goose-mcp/src/google_drive/mod.rs index 00efa864d96b..1af202f1452f 100644 --- a/crates/goose-mcp/src/google_drive/mod.rs +++ b/crates/goose-mcp/src/google_drive/mod.rs @@ -1,12 +1,12 @@ mod oauth_pkce; -mod token_storage; +mod storage; use indoc::indoc; use oauth_pkce::PkceOAuth2Client; use regex::Regex; use serde_json::{json, Value}; use std::{env, fs, future::Future, path::Path, pin::Pin, sync::Arc}; -use token_storage::CredentialsManager; +use storage::CredentialsManager; use mcp_core::content::Content; use mcp_core::{ diff --git a/crates/goose-mcp/src/google_drive/oauth_pkce.rs b/crates/goose-mcp/src/google_drive/oauth_pkce.rs index 60ea1f1d5f2d..64e746bf27a2 100644 --- a/crates/goose-mcp/src/google_drive/oauth_pkce.rs +++ b/crates/goose-mcp/src/google_drive/oauth_pkce.rs @@ -18,7 +18,7 @@ use serde::{Deserialize, Serialize}; use tracing::{debug, error, info}; use url::Url; -use crate::google_drive::token_storage::CredentialsManager; +use super::storage::CredentialsManager; /// Structure representing the OAuth2 configuration file format #[derive(Debug, Deserialize, Serialize)] @@ -44,6 +44,7 @@ struct TokenData { refresh_token: String, #[serde(skip_serializing_if = "Option::is_none")] expires_at: Option, + project_id: String, } use std::sync::Mutex; @@ -56,6 +57,7 @@ pub struct PkceOAuth2Client { credentials_manager: Arc, refresh_token: Arc>>, http_client: reqwest::Client, + project_id: String, } impl PkceOAuth2Client { @@ -67,6 +69,9 @@ impl PkceOAuth2Client { let config_content = fs::read_to_string(config_path)?; let config: OAuth2Config = serde_json::from_str(&config_content)?; + // Extract the project_id from the config + let project_id = config.installed.project_id.clone(); + // Create OAuth URLs let auth_url = AuthUrl::new(config.installed.auth_uri).expect("Invalid authorization endpoint URL"); @@ -84,11 +89,24 @@ impl PkceOAuth2Client { ); // Try to load a refresh token from storage - let refresh_token = credentials_manager - .read_credentials::() - .inspect_err(|e| debug!("No stored credentials found or error reading them: {}", e)) - .ok() - .map(|token_data| token_data.refresh_token); + let refresh_token = match credentials_manager.read_credentials::() { + Ok(token_data) => { + // Verify the project_id matches + if token_data.project_id != project_id { + debug!( + "Project ID mismatch: stored={}, current={}. Discarding stored credentials.", + token_data.project_id, project_id + ); + None // Don't use these credentials if project_id doesn't match + } else { + Some(token_data.refresh_token) + } + } + Err(e) => { + debug!("No stored credentials found or error reading them: {}", e); + None + } + }; let http_client = reqwest::ClientBuilder::new() // Following redirects opens the client up to SSRF vulnerabilities. @@ -101,6 +119,7 @@ impl PkceOAuth2Client { credentials_manager, refresh_token: Arc::new(Mutex::new(refresh_token)), http_client, + project_id, }) } @@ -163,6 +182,7 @@ impl PkceOAuth2Client { access_token: access_token.clone(), refresh_token: refresh_token.secret().clone(), expires_at: token_result.expires_in().map(|d| d.as_secs()), + project_id: self.project_id.clone(), }; self.refresh_token @@ -208,6 +228,7 @@ impl PkceOAuth2Client { access_token: access_token.clone(), refresh_token: refresh_token.secret().clone(), expires_at: token_result.expires_in().map(|d| d.as_secs()), + project_id: self.project_id.clone(), }; self.refresh_token @@ -306,12 +327,21 @@ impl GetToken for PkceOAuth2Client { .credentials_manager .read_credentials::() .ok() - .map(|token_data| { - if let Ok(mut token_guard) = self.refresh_token.lock() { - *token_guard = Some(token_data.refresh_token.clone()); - debug!("Updated in-memory refresh token from storage"); + .and_then(|token_data| { + // Verify the project_id matches + if token_data.project_id != self.project_id { + debug!( + "Project ID mismatch: stored={}, current={}. Discarding stored credentials.", + token_data.project_id, self.project_id + ); + None // Don't use these credentials if project_id doesn't match + } else { + if let Ok(mut token_guard) = self.refresh_token.lock() { + *token_guard = Some(token_data.refresh_token.clone()); + debug!("Updated in-memory refresh token from storage"); + } + Some(token_data.refresh_token) } - token_data.refresh_token }); // If we fail to use the refresh token here, fall through to full OAuth flow @@ -336,4 +366,3 @@ impl GetToken for PkceOAuth2Client { }) } } - diff --git a/crates/goose-mcp/src/google_drive/token_storage.rs b/crates/goose-mcp/src/google_drive/storage.rs similarity index 100% rename from crates/goose-mcp/src/google_drive/token_storage.rs rename to crates/goose-mcp/src/google_drive/storage.rs From 7308705a349a070e1bea1726c15e78f11b6985b5 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Tue, 11 Mar 2025 12:09:21 -0700 Subject: [PATCH 12/18] feat(google_drive): update TokenData to use epoch for expiration time --- .../goose-mcp/src/google_drive/oauth_pkce.rs | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/crates/goose-mcp/src/google_drive/oauth_pkce.rs b/crates/goose-mcp/src/google_drive/oauth_pkce.rs index 64e746bf27a2..9c42f4a7d515 100644 --- a/crates/goose-mcp/src/google_drive/oauth_pkce.rs +++ b/crates/goose-mcp/src/google_drive/oauth_pkce.rs @@ -6,6 +6,7 @@ use std::net::TcpListener; use std::path::Path; use std::pin::Pin; use std::sync::Arc; +use std::time::{SystemTime, UNIX_EPOCH}; use google_drive3::common::GetToken; use oauth2::basic::BasicClient; @@ -178,10 +179,19 @@ impl PkceOAuth2Client { // Update the stored refresh token if a new one was provided // not all authorization servers return a new refresh token if let Some(refresh_token) = token_result.refresh_token() { + // Calculate expires_at as a Unix timestamp by adding expires_in to current time + let expires_at = token_result.expires_in().map(|duration| { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_secs(); + now + duration.as_secs() + }); + let token_data = TokenData { access_token: access_token.clone(), refresh_token: refresh_token.secret().clone(), - expires_at: token_result.expires_in().map(|d| d.as_secs()), + expires_at, project_id: self.project_id.clone(), }; @@ -224,10 +234,19 @@ impl PkceOAuth2Client { // Update the stored refresh token if a new one was provided // not all authorization servers return a new refresh token if let Some(refresh_token) = token_result.refresh_token() { + // Calculate expires_at as a Unix timestamp by adding expires_in to current time + let expires_at = token_result.expires_in().map(|duration| { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_secs(); + now + duration.as_secs() + }); + let token_data = TokenData { access_token: access_token.clone(), refresh_token: refresh_token.secret().clone(), - expires_at: token_result.expires_in().map(|d| d.as_secs()), + expires_at, project_id: self.project_id.clone(), }; From 927f94bc5e133ef495b582ce16663a07bfbfa81c Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Tue, 11 Mar 2025 13:42:01 -0700 Subject: [PATCH 13/18] feat(google_drive): add token expiration check, remove Arc simplify logic and remove in memory refresh token, add expiration check to reuse any valid tokens --- .../goose-mcp/src/google_drive/oauth_pkce.rs | 200 +++++++----------- 1 file changed, 82 insertions(+), 118 deletions(-) diff --git a/crates/goose-mcp/src/google_drive/oauth_pkce.rs b/crates/goose-mcp/src/google_drive/oauth_pkce.rs index 9c42f4a7d515..1da7380aa252 100644 --- a/crates/goose-mcp/src/google_drive/oauth_pkce.rs +++ b/crates/goose-mcp/src/google_drive/oauth_pkce.rs @@ -48,15 +48,12 @@ struct TokenData { project_id: String, } -use std::sync::Mutex; - /// PkceOAuth2Client implements the GetToken trait required by DriveHub /// It uses the oauth2 crate to implement a PKCE-enabled OAuth2 flow #[derive(Clone)] pub struct PkceOAuth2Client { client: BasicClient, credentials_manager: Arc, - refresh_token: Arc>>, http_client: reqwest::Client, project_id: String, } @@ -85,30 +82,10 @@ impl PkceOAuth2Client { .set_auth_uri(auth_url) .set_token_uri(token_url) .set_redirect_uri( - RedirectUrl::new("http://localhost:8080".to_string()) + RedirectUrl::new("http://localhost:18080".to_string()) .expect("Invalid redirect URL"), ); - // Try to load a refresh token from storage - let refresh_token = match credentials_manager.read_credentials::() { - Ok(token_data) => { - // Verify the project_id matches - if token_data.project_id != project_id { - debug!( - "Project ID mismatch: stored={}, current={}. Discarding stored credentials.", - token_data.project_id, project_id - ); - None // Don't use these credentials if project_id doesn't match - } else { - Some(token_data.refresh_token) - } - } - Err(e) => { - debug!("No stored credentials found or error reading them: {}", e); - None - } - }; - let http_client = reqwest::ClientBuilder::new() // Following redirects opens the client up to SSRF vulnerabilities. .redirect(reqwest::redirect::Policy::none()) @@ -118,12 +95,25 @@ impl PkceOAuth2Client { Ok(Self { client, credentials_manager, - refresh_token: Arc::new(Mutex::new(refresh_token)), http_client, project_id, }) } + /// Check if a token is expired or about to expire within the buffer period + fn is_token_expired(&self, expires_at: Option, buffer_seconds: u64) -> bool { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_secs(); + + // Consider the token expired if it's within buffer_seconds of expiring + // This gives us a safety margin to avoid using tokens right before expiration + expires_at + .map(|expiry_time| now + buffer_seconds >= expiry_time) + .unwrap_or(true) // If we don't know when it expires, assume it's expired to be safe + } + async fn perform_oauth_flow( &self, scopes: &[&str], @@ -176,37 +166,34 @@ impl PkceOAuth2Client { let access_token = token_result.access_token().secret().clone(); - // Update the stored refresh token if a new one was provided - // not all authorization servers return a new refresh token + // Calculate expires_at as a Unix timestamp by adding expires_in to current time + let expires_at = token_result.expires_in().map(|duration| { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_secs(); + now + duration.as_secs() + }); + + // Get the refresh token if provided if let Some(refresh_token) = token_result.refresh_token() { - // Calculate expires_at as a Unix timestamp by adding expires_in to current time - let expires_at = token_result.expires_in().map(|duration| { - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("Time went backwards") - .as_secs(); - now + duration.as_secs() - }); + let refresh_token_str = refresh_token.secret().clone(); + // Store token data let token_data = TokenData { access_token: access_token.clone(), - refresh_token: refresh_token.secret().clone(), + refresh_token: refresh_token_str.clone(), expires_at, project_id: self.project_id.clone(), }; - self.refresh_token - .lock() - .map(|mut token_guard| { - *token_guard = Some(refresh_token.secret().clone()); - debug!("Successfully updated in-memory refresh token"); - }) - .unwrap_or_else(|_| error!("Failed to acquire lock on refresh token")); - + // Store updated token data self.credentials_manager .write_credentials(&token_data) - .map(|_| debug!("Successfully stored refresh token")) - .unwrap_or_else(|e| error!("Failed to store refresh token: {}", e)); + .map(|_| debug!("Successfully stored token data")) + .unwrap_or_else(|e| error!("Failed to store token data: {}", e)); + } else { + debug!("No refresh token provided in OAuth flow response"); } Ok(access_token) @@ -231,46 +218,42 @@ impl PkceOAuth2Client { let access_token = token_result.access_token().secret().clone(); - // Update the stored refresh token if a new one was provided - // not all authorization servers return a new refresh token - if let Some(refresh_token) = token_result.refresh_token() { - // Calculate expires_at as a Unix timestamp by adding expires_in to current time - let expires_at = token_result.expires_in().map(|duration| { - let now = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("Time went backwards") - .as_secs(); - now + duration.as_secs() - }); - - let token_data = TokenData { - access_token: access_token.clone(), - refresh_token: refresh_token.secret().clone(), - expires_at, - project_id: self.project_id.clone(), - }; + // Calculate expires_at as a Unix timestamp by adding expires_in to current time + let expires_at = token_result.expires_in().map(|duration| { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time went backwards") + .as_secs(); + now + duration.as_secs() + }); - self.refresh_token - .lock() - .map(|mut token_guard| { - *token_guard = Some(refresh_token.secret().clone()); - debug!("Successfully updated in-memory refresh token"); - }) - .unwrap_or_else(|_| error!("Failed to acquire lock on refresh token")); + // Get the refresh token - either the new one or reuse the existing one + let new_refresh_token = token_result + .refresh_token() + .map(|token| token.secret().clone()) + .unwrap_or_else(|| refresh_token.secret().to_string()); + + // Always update the token data with the new access token and expiration + let token_data = TokenData { + access_token: access_token.clone(), + refresh_token: new_refresh_token.clone(), + expires_at, + project_id: self.project_id.clone(), + }; - self.credentials_manager - .write_credentials(&token_data) - .map(|_| debug!("Successfully stored refresh token")) - .unwrap_or_else(|e| error!("Failed to store refresh token: {}", e)); - } + // Store updated token data + self.credentials_manager + .write_credentials(&token_data) + .map(|_| debug!("Successfully stored token data")) + .unwrap_or_else(|e| error!("Failed to store token data: {}", e)); Ok(access_token) } fn start_redirect_server( ) -> Result<(AuthorizationCode, CsrfToken), Box> { - let listener = TcpListener::bind("127.0.0.1:8080")?; - println!("Listening for the authorization code on http://localhost:8080"); + let listener = TcpListener::bind("127.0.0.1:18080")?; + println!("Listening for the authorization code on http://localhost:18080"); for stream in listener.incoming() { match stream { @@ -318,6 +301,8 @@ impl PkceOAuth2Client { } } +// impl GetToken for use with DriveHub directly +// see google_drive3::common::GetToken impl GetToken for PkceOAuth2Client { fn get_token<'a>( &'a self, @@ -326,51 +311,30 @@ impl GetToken for PkceOAuth2Client { Box, Box>> + Send + 'a>, > { Box::pin(async move { - // Attempt to get token from memory - let token_from_memory = self - .refresh_token - .lock() - .ok() - .and_then(|guard| guard.clone()); - - // In error cases we just fall through to checking storage - if let Some(ref token) = token_from_memory { - if let Ok(access_token) = self.refresh_token(token).await { - debug!("Successfully refreshed access token from memory"); - return Ok(Some(access_token)); - } - } - - // Attempt to read token from storage and update in-memory cache - let token_from_storage = self - .credentials_manager - .read_credentials::() - .ok() - .and_then(|token_data| { - // Verify the project_id matches - if token_data.project_id != self.project_id { - debug!( - "Project ID mismatch: stored={}, current={}. Discarding stored credentials.", - token_data.project_id, self.project_id - ); - None // Don't use these credentials if project_id doesn't match - } else { - if let Ok(mut token_guard) = self.refresh_token.lock() { - *token_guard = Some(token_data.refresh_token.clone()); - debug!("Updated in-memory refresh token from storage"); - } - Some(token_data.refresh_token) + // Try to read token data from storage to check if we have a valid token + if let Ok(token_data) = self.credentials_manager.read_credentials::() { + // Verify the project_id matches + if token_data.project_id == self.project_id { + // Check if the token is expired or expiring within a 5-min buffer + if !self.is_token_expired(token_data.expires_at, 300) { + return Ok(Some(token_data.access_token)); } - }); - // If we fail to use the refresh token here, fall through to full OAuth flow - if let Some(ref token) = token_from_storage { - if let Ok(access_token) = self.refresh_token(token).await { - debug!("Successfully refreshed access token from storage"); - return Ok(Some(access_token)); + // Token is expired or will expire soon, try to refresh it + debug!("Token is expired or will expire soon, refreshing..."); + + // Try to refresh the token + if let Ok(access_token) = self.refresh_token(&token_data.refresh_token).await { + debug!("Successfully refreshed access token"); + return Ok(Some(access_token)); + } } } + // If we get here, either: + // 1. The project ID didn't match + // 2. Token refresh failed + // 3. There are no valid tokens yet // Fallback: perform interactive OAuth flow match self.perform_oauth_flow(scopes).await { Ok(token) => { From 3739135a15d09ececb7e425d31f6ec5e342ef97c Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Wed, 12 Mar 2025 09:00:57 -0700 Subject: [PATCH 14/18] fix: add priority to search response --- crates/goose-mcp/src/google_drive/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/goose-mcp/src/google_drive/mod.rs b/crates/goose-mcp/src/google_drive/mod.rs index 1af202f1452f..f6d9b74229db 100644 --- a/crates/goose-mcp/src/google_drive/mod.rs +++ b/crates/goose-mcp/src/google_drive/mod.rs @@ -387,7 +387,7 @@ impl GoogleDriveRouter { .collect::>() .join("\n"); - Ok(vec![Content::text(content.to_string())]) + Ok(vec![Content::text(content.to_string()).with_priority(0.3)]) } } } From f0f24b16a87f88c8eeff92559d3b41b2357e11f5 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Wed, 12 Mar 2025 09:38:45 -0700 Subject: [PATCH 15/18] chore: remove files from rebase --- crates/goose-mcp/src/google_drive/auth.rs | 326 ------------------ .../src/google_drive/token_storage.rs | 301 ---------------- 2 files changed, 627 deletions(-) delete mode 100644 crates/goose-mcp/src/google_drive/auth.rs delete mode 100644 crates/goose-mcp/src/google_drive/token_storage.rs diff --git a/crates/goose-mcp/src/google_drive/auth.rs b/crates/goose-mcp/src/google_drive/auth.rs deleted file mode 100644 index 37e8ecdaae9f..000000000000 --- a/crates/goose-mcp/src/google_drive/auth.rs +++ /dev/null @@ -1,326 +0,0 @@ -use anyhow::Result; -use google_drive3::yup_oauth2::storage::{TokenInfo, TokenStorage}; -use keyring::Entry; -use std::fs; -use std::path::Path; -use std::sync::Arc; -use thiserror::Error; -use tracing::{debug, error, warn}; - -const KEYCHAIN_SERVICE: &str = "mcp_google_drive"; -const KEYCHAIN_USERNAME: &str = "oauth_credentials"; - -#[derive(Error, Debug)] -pub enum AuthError { - #[error("Failed to access keychain: {0}")] - KeyringError(#[from] keyring::Error), - #[error("Failed to access file system: {0}")] - FileSystemError(#[from] std::io::Error), - #[error("No credentials found")] - NotFound, - #[error("Critical error: {0}")] - Critical(String), - #[error("Failed to serialize/deserialize: {0}")] - SerializationError(#[from] serde_json::Error), -} - -/// CredentialsManager handles secure storage of OAuth credentials. -/// It attempts to store credentials in the system keychain first, -/// with fallback to file system storage if keychain access fails. -pub struct CredentialsManager { - credentials_path: String, -} - -impl CredentialsManager { - pub fn new(credentials_path: String) -> Self { - Self { credentials_path } - } - - pub fn read_credentials(&self) -> Result { - // First try to read from keychain - let entry = match Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME) { - Ok(entry) => entry, - Err(e) => { - warn!("Failed to create keychain entry: {}", e); - return self.read_from_file(); - } - }; - - match entry.get_password() { - Ok(content) => { - debug!("Successfully read credentials from keychain"); - Ok(content) - } - Err(keyring::Error::NoEntry) => { - debug!("No credentials found in keychain, falling back to file system"); - self.read_from_file() - } - Err(e) => { - // Categorize errors - some might be critical and should not trigger fallback - warn!( - "Non-critical keychain error: {}, falling back to file system", - e - ); - self.read_from_file() - } - } - } - - fn read_from_file(&self) -> Result { - let path = Path::new(&self.credentials_path); - if path.exists() { - match fs::read_to_string(path) { - Ok(content) => { - debug!("Successfully read credentials from file system"); - Ok(content) - } - Err(e) => { - error!("Failed to read credentials file: {}", e); - Err(AuthError::FileSystemError(e)) - } - } - } else { - debug!("No credentials found in file system"); - Err(AuthError::NotFound) - } - } - - pub fn write_credentials(&self, content: &str) -> Result<(), AuthError> { - // Try to write to keychain first - let entry = match Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME) { - Ok(entry) => entry, - Err(e) => { - warn!("Failed to create keychain entry: {}", e); - return self.write_to_file(content); - } - }; - - // Fallback to writing on disk if we can't write to the keychain - match entry.set_password(content) { - Ok(_) => { - debug!("Successfully wrote credentials to keychain"); - Ok(()) - } - Err(e) => { - warn!( - "Non-critical keychain error: {}, falling back to file system", - e - ); - self.write_to_file(content) - } - } - } - - fn write_to_file(&self, content: &str) -> Result<(), AuthError> { - let path = Path::new(&self.credentials_path); - if let Some(parent) = path.parent() { - if !parent.exists() { - match fs::create_dir_all(parent) { - Ok(_) => debug!("Created parent directories for credentials file"), - Err(e) => { - error!("Failed to create directories for credentials file: {}", e); - return Err(AuthError::FileSystemError(e)); - } - } - } - } - - match fs::write(path, content) { - Ok(_) => { - debug!("Successfully wrote credentials to file system"); - Ok(()) - } - Err(e) => { - error!("Failed to write credentials to file system: {}", e); - Err(AuthError::FileSystemError(e)) - } - } - } -} - -/// Storage entry that includes both the token and the scopes it's valid for -#[derive(serde::Serialize, serde::Deserialize)] -struct StorageEntry { - token: TokenInfo, - scopes: String, - project_id: String, -} - -/// KeychainTokenStorage implements the TokenStorage trait from yup_oauth2 -/// to enable secure storage of OAuth tokens in the system keychain. -pub struct KeychainTokenStorage { - project_id: String, - credentials_manager: Arc, -} - -impl KeychainTokenStorage { - /// Create a new KeychainTokenStorage with the given CredentialsManager - pub fn new(project_id: String, credentials_manager: Arc) -> Self { - Self { - project_id, - credentials_manager, - } - } - - fn generate_scoped_key(&self, scopes: &[&str]) -> String { - // Create a key based on the scopes and project_id - let mut sorted_scopes = scopes.to_vec(); - sorted_scopes.sort(); - - sorted_scopes.join(" ") - } -} - -#[async_trait::async_trait] -impl TokenStorage for KeychainTokenStorage { - /// Store a token in the keychain - async fn set(&self, scopes: &[&str], token_info: TokenInfo) -> Result<()> { - let key = self.generate_scoped_key(scopes); - debug!("Storing OAuth token in keychain for scopes: {:?}", key); - - // Create a storage entry that includes the scopes - let storage_entry = StorageEntry { - token: token_info, - scopes: key, - project_id: self.project_id.clone(), - }; - - let json = serde_json::to_string(&storage_entry)?; - self.credentials_manager - .write_credentials(&json) - .map_err(|e| { - error!("Failed to write token to keychain: {}", e); - anyhow::anyhow!("Failed to write token to keychain: {}", e) - }) - } - - /// Retrieve a token from the keychain - async fn get(&self, scopes: &[&str]) -> Option { - let key = self.generate_scoped_key(scopes); - debug!("Retrieving OAuth token from keychain for key: {:?}", key); - - match self.credentials_manager.read_credentials() { - Ok(json) => { - debug!("Successfully read credentials from storage"); - match serde_json::from_str::(&json) { - Ok(entry) => { - // Check if the stored token has the requested scopes - debug!("{} == {}", entry.project_id, self.project_id); - if entry.project_id == self.project_id && entry.scopes == key { - debug!("Successfully retrieved OAuth token from storage"); - Some(entry.token) - } else { - debug!( - "Found token but scopes don't match. Stored: {}, Requested: {}", - entry.scopes, key - ); - None - } - } - Err(e) => { - warn!("Failed to deserialize token from storage: {}", e); - None - } - } - } - Err(AuthError::NotFound) => { - debug!("No OAuth token found in storage"); - None - } - Err(e) => { - warn!("Error reading OAuth token from storage: {}", e); - None - } - } - } -} - -impl Clone for CredentialsManager { - fn clone(&self) -> Self { - Self { - credentials_path: self.credentials_path.clone(), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use tempfile::NamedTempFile; - - #[test] - fn test_write_read_credentials() { - let temp_file = NamedTempFile::new().unwrap(); - let manager = CredentialsManager::new(temp_file.path().to_string_lossy().to_string()); - - // Write test credentials - let test_content = r#"{"access_token":"test_token","token_type":"Bearer"}"#; - manager.write_credentials(test_content).unwrap(); - - // Read back and verify - let read_content = manager.read_credentials().unwrap(); - assert_eq!(read_content, test_content); - } - - #[tokio::test] - async fn test_token_storage_set_get() { - // Create a temporary file for testing - let temp_file = NamedTempFile::new().unwrap(); - let project_id = "test_project".to_string(); - let credentials_manager = Arc::new(CredentialsManager::new( - temp_file.path().to_string_lossy().to_string(), - )); - - let storage = KeychainTokenStorage::new(project_id, credentials_manager); - - // Create a test token - let token_info = TokenInfo { - access_token: Some("test_access_token".to_string()), - refresh_token: Some("test_refresh_token".to_string()), - expires_at: None, - id_token: None, - }; - - let scopes = &["https://www.googleapis.com/auth/drive.readonly"]; - - // Store the token - storage.set(scopes, token_info.clone()).await.unwrap(); - - // Retrieve the token - let retrieved = storage.get(scopes).await.unwrap(); - - // Verify the token matches - assert_eq!(retrieved.access_token, token_info.access_token); - assert_eq!(retrieved.refresh_token, token_info.refresh_token); - } - - #[tokio::test] - async fn test_token_storage_scope_mismatch() { - // Create a temporary file for testing - let temp_file = NamedTempFile::new().unwrap(); - let project_id = "test_project".to_string(); - let credentials_manager = Arc::new(CredentialsManager::new( - temp_file.path().to_string_lossy().to_string(), - )); - - let storage = KeychainTokenStorage::new(project_id, credentials_manager); - - // Create a test token - let token_info = TokenInfo { - access_token: Some("test_access_token".to_string()), - refresh_token: Some("test_refresh_token".to_string()), - expires_at: None, - id_token: None, - }; - - let scopes1 = &["https://www.googleapis.com/auth/drive.readonly"]; - let scopes2 = &["https://www.googleapis.com/auth/drive.file"]; - - // Store the token with scopes1 - storage.set(scopes1, token_info).await.unwrap(); - - // Try to retrieve with different scopes - let result = storage.get(scopes2).await; - assert!(result.is_none()); - } -} diff --git a/crates/goose-mcp/src/google_drive/token_storage.rs b/crates/goose-mcp/src/google_drive/token_storage.rs deleted file mode 100644 index f40ab4baab76..000000000000 --- a/crates/goose-mcp/src/google_drive/token_storage.rs +++ /dev/null @@ -1,301 +0,0 @@ -use anyhow::Result; -use google_drive3::yup_oauth2::storage::{TokenInfo, TokenStorage}; -use keyring::Entry; -use std::env; -use std::fs; -use std::path::Path; -use std::sync::Arc; -use thiserror::Error; -use tracing::{debug, error, warn}; - -const KEYCHAIN_SERVICE: &str = "mcp_google_drive"; -const KEYCHAIN_USERNAME: &str = "oauth_credentials"; -const KEYCHAIN_DISK_FALLBACK_ENV: &str = "GOOGLE_DRIVE_DISK_FALLBACK"; - -#[allow(dead_code)] -#[derive(Error, Debug)] -pub enum AuthError { - #[error("Failed to access keychain: {0}")] - KeyringError(#[from] keyring::Error), - #[error("Failed to access file system: {0}")] - FileSystemError(#[from] std::io::Error), - #[error("No credentials found")] - NotFound, - #[error("Critical error: {0}")] - Critical(String), - #[error("Failed to serialize/deserialize: {0}")] - SerializationError(#[from] serde_json::Error), -} - -/// CredentialsManager handles secure storage of OAuth credentials. -/// It attempts to store credentials in the system keychain first, -/// with fallback to file system storage if keychain access fails and fallback is enabled. -pub struct CredentialsManager { - credentials_path: String, - fallback_to_disk: bool, -} - -impl CredentialsManager { - pub fn new(credentials_path: String) -> Self { - // Check if we should fall back to disk, must be explicitly enabled - let fallback_to_disk = match env::var(KEYCHAIN_DISK_FALLBACK_ENV) { - Ok(value) => value.to_lowercase() == "true", - Err(_) => false, - }; - - Self { - credentials_path, - fallback_to_disk, - } - } - - pub fn read_credentials(&self) -> Result { - Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME) - .and_then(|entry| entry.get_password()) - .inspect(|_| { - debug!("Successfully read credentials from keychain"); - }) - .or_else(|e| { - if self.fallback_to_disk { - debug!("Falling back to file system due to keyring error: {}", e); - self.read_from_file() - } else { - match e { - keyring::Error::NoEntry => Err(AuthError::NotFound), - _ => Err(AuthError::KeyringError(e)), - } - } - }) - } - - fn read_from_file(&self) -> Result { - let path = Path::new(&self.credentials_path); - if path.exists() { - match fs::read_to_string(path) { - Ok(content) => { - debug!("Successfully read credentials from file system"); - Ok(content) - } - Err(e) => { - error!("Failed to read credentials file: {}", e); - Err(AuthError::FileSystemError(e)) - } - } - } else { - debug!("No credentials found in file system"); - Err(AuthError::NotFound) - } - } - - pub fn write_credentials(&self, content: &str) -> Result<(), AuthError> { - Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME) - .and_then(|entry| entry.set_password(content)) - .inspect(|_| { - debug!("Successfully wrote credentials to keychain"); - }) - .or_else(|e| { - if self.fallback_to_disk { - warn!("Falling back to file system due to keyring error: {}", e); - self.write_to_file(content) - } else { - Err(AuthError::KeyringError(e)) - } - }) - } - - fn write_to_file(&self, content: &str) -> Result<(), AuthError> { - let path = Path::new(&self.credentials_path); - if let Some(parent) = path.parent() { - if !parent.exists() { - match fs::create_dir_all(parent) { - Ok(_) => debug!("Created parent directories for credentials file"), - Err(e) => { - error!("Failed to create directories for credentials file: {}", e); - return Err(AuthError::FileSystemError(e)); - } - } - } - } - - match fs::write(path, content) { - Ok(_) => { - debug!("Successfully wrote credentials to file system"); - Ok(()) - } - Err(e) => { - error!("Failed to write credentials to file system: {}", e); - Err(AuthError::FileSystemError(e)) - } - } - } -} - -/// Storage entry that includes the token, scopes and project it's valid for -#[derive(serde::Serialize, serde::Deserialize)] -struct StorageEntry { - token: TokenInfo, - scopes: String, - project_id: String, -} - -/// KeychainTokenStorage implements the TokenStorage trait from yup_oauth2 -/// to enable secure storage of OAuth tokens in the system keychain. -pub struct KeychainTokenStorage { - project_id: String, - credentials_manager: Arc, -} - -impl KeychainTokenStorage { - /// Create a new KeychainTokenStorage with the given CredentialsManager - pub fn new(project_id: String, credentials_manager: Arc) -> Self { - Self { - project_id, - credentials_manager, - } - } - - fn generate_scoped_key(&self, scopes: &[&str]) -> String { - // Create a key based on the scopes and project_id - // Sort so we can be consistent using scopes as the key - let mut sorted_scopes = scopes.to_vec(); - sorted_scopes.sort(); - sorted_scopes.join(" ") - } -} - -#[async_trait::async_trait] -impl TokenStorage for KeychainTokenStorage { - /// Store a token in the keychain - async fn set(&self, scopes: &[&str], token_info: TokenInfo) -> Result<()> { - let key = self.generate_scoped_key(scopes); - - // Create a storage entry that includes the scopes - let storage_entry = StorageEntry { - token: token_info, - scopes: key, - project_id: self.project_id.clone(), - }; - - let json = serde_json::to_string(&storage_entry)?; - self.credentials_manager - .write_credentials(&json) - .map_err(|e| { - error!("Failed to write token to keychain: {}", e); - anyhow::anyhow!("Failed to write token to keychain: {}", e) - }) - } - - /// Retrieve a token from the keychain - async fn get(&self, scopes: &[&str]) -> Option { - let key = self.generate_scoped_key(scopes); - - match self.credentials_manager.read_credentials() { - Ok(json) => { - debug!("Successfully read credentials from storage"); - match serde_json::from_str::(&json) { - Ok(entry) => { - // Check if token has the requested scopes and matches the project_id - if entry.project_id == self.project_id && entry.scopes == key { - debug!("Successfully retrieved OAuth token from storage"); - Some(entry.token) - } else { - None - } - } - Err(e) => { - warn!("Failed to deserialize token from storage: {}", e); - None - } - } - } - Err(AuthError::NotFound) => { - debug!("No OAuth token found in storage"); - None - } - Err(e) => { - warn!("Error reading OAuth token from storage: {}", e); - None - } - } - } -} - -impl Clone for CredentialsManager { - fn clone(&self) -> Self { - Self { - credentials_path: self.credentials_path.clone(), - fallback_to_disk: self.fallback_to_disk, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use serial_test::serial; - use tempfile::NamedTempFile; - - #[tokio::test] - #[serial] - async fn test_token_storage_set_get() { - // Create a temporary file for testing - let temp_file = NamedTempFile::new().unwrap(); - let project_id = "test_project_1".to_string(); - let credentials_manager = Arc::new(CredentialsManager::new( - temp_file.path().to_string_lossy().to_string(), - )); - - let storage = KeychainTokenStorage::new(project_id, credentials_manager); - - // Create a test token - let token_info = TokenInfo { - access_token: Some("test_access_token".to_string()), - refresh_token: Some("test_refresh_token".to_string()), - expires_at: None, - id_token: None, - }; - - let scopes = &["https://www.googleapis.com/auth/drive.readonly"]; - - // Store the token - storage.set(scopes, token_info.clone()).await.unwrap(); - - // Retrieve the token - let retrieved = storage.get(scopes).await.unwrap(); - - // Verify the token matches - assert_eq!(retrieved.access_token, token_info.access_token); - assert_eq!(retrieved.refresh_token, token_info.refresh_token); - } - - #[tokio::test] - #[serial] - async fn test_token_storage_scope_mismatch() { - // Create a temporary file for testing - let temp_file = NamedTempFile::new().unwrap(); - let project_id = "test_project_2".to_string(); - let credentials_manager = Arc::new(CredentialsManager::new( - temp_file.path().to_string_lossy().to_string(), - )); - - let storage = KeychainTokenStorage::new(project_id, credentials_manager); - - // Create a test token - let token_info = TokenInfo { - access_token: Some("test_access_token".to_string()), - refresh_token: Some("test_refresh_token".to_string()), - expires_at: None, - id_token: None, - }; - - let scopes1 = &["https://www.googleapis.com/auth/drive.readonly"]; - let scopes2 = &["https://www.googleapis.com/auth/drive.file"]; - - // Store the token with scopes1 - storage.set(scopes1, token_info).await.unwrap(); - - // Try to retrieve with different scopes - let result = storage.get(scopes2).await; - assert!(result.is_none()); - } -} From 30ab2a5fc26a3df837935d8b6b0fa9accd0f9d35 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Wed, 12 Mar 2025 09:58:48 -0700 Subject: [PATCH 16/18] feat(google_drive): make storage more generic, move google_drive vars to mod.rs this allows easy re-use of storage.rs, and keeps all google_drive specific variables in the mod.rs file test: add basic storage tests --- crates/goose-mcp/src/google_drive/mod.rs | 24 ++- crates/goose-mcp/src/google_drive/storage.rs | 160 ++++++++++++++++--- crates/goose-mcp/src/lib.rs | 2 +- 3 files changed, 162 insertions(+), 24 deletions(-) diff --git a/crates/goose-mcp/src/google_drive/mod.rs b/crates/goose-mcp/src/google_drive/mod.rs index ab36ad74e43e..82cfe9775bd4 100644 --- a/crates/goose-mcp/src/google_drive/mod.rs +++ b/crates/goose-mcp/src/google_drive/mod.rs @@ -1,5 +1,5 @@ mod oauth_pkce; -mod storage; +pub mod storage; use indoc::indoc; use oauth_pkce::PkceOAuth2Client; @@ -29,6 +29,11 @@ use google_drive3::{ use google_sheets4::{self, Sheets}; use http_body_util::BodyExt; +// Constants for credential storage +pub const KEYCHAIN_SERVICE: &str = "mcp_google_drive"; +pub const KEYCHAIN_USERNAME: &str = "oauth_credentials"; +pub const KEYCHAIN_DISK_FALLBACK_ENV: &str = "GOOGLE_DRIVE_DISK_FALLBACK"; + pub struct GoogleDriveRouter { tools: Vec, instructions: String, @@ -95,8 +100,20 @@ impl GoogleDriveRouter { } } } + + // Check if we should fall back to disk, must be explicitly enabled + let fallback_to_disk = match env::var(KEYCHAIN_DISK_FALLBACK_ENV) { + Ok(value) => value.to_lowercase() == "true", + Err(_) => false, + }; + // Create a credentials manager for storing tokens securely - let credentials_manager = Arc::new(CredentialsManager::new(credentials_path.clone())); + let credentials_manager = Arc::new(CredentialsManager::new( + credentials_path.clone(), + fallback_to_disk, + KEYCHAIN_SERVICE.to_string(), + KEYCHAIN_USERNAME.to_string(), + )); // Read the OAuth credentials from the keyfile match fs::read_to_string(keyfile_path) { @@ -104,6 +121,7 @@ impl GoogleDriveRouter { // Create the PKCE OAuth2 client let auth = PkceOAuth2Client::new(keyfile_path, credentials_manager.clone()) .expect("Failed to create OAuth2 client"); + // Create the HTTP client let client = hyper_util::client::legacy::Client::builder( hyper_util::rt::TokioExecutor::new(), @@ -120,7 +138,7 @@ impl GoogleDriveRouter { let drive_hub = DriveHub::new(client.clone(), auth.clone()); let sheets_hub = Sheets::new(client, auth); - // Create and return the DriveHub + // Create and return the DriveHub, Sheets and our PKCE OAuth2 client (drive_hub, sheets_hub, credentials_manager) } Err(e) => { diff --git a/crates/goose-mcp/src/google_drive/storage.rs b/crates/goose-mcp/src/google_drive/storage.rs index e5a6408d7a6b..8e8f3c08dec3 100644 --- a/crates/goose-mcp/src/google_drive/storage.rs +++ b/crates/goose-mcp/src/google_drive/storage.rs @@ -1,16 +1,11 @@ use anyhow::Result; use keyring::Entry; use serde::{de::DeserializeOwned, Serialize}; -use std::env; use std::fs; use std::path::Path; use thiserror::Error; use tracing::{debug, error, warn}; -const KEYCHAIN_SERVICE: &str = "mcp_google_drive"; -const KEYCHAIN_USERNAME: &str = "oauth_credentials"; -const KEYCHAIN_DISK_FALLBACK_ENV: &str = "GOOGLE_DRIVE_DISK_FALLBACK"; - #[allow(dead_code)] #[derive(Error, Debug)] pub enum StorageError { @@ -32,19 +27,22 @@ pub enum StorageError { pub struct CredentialsManager { credentials_path: String, fallback_to_disk: bool, + keychain_service: String, + keychain_username: String, } impl CredentialsManager { - pub fn new(credentials_path: String) -> Self { - // Check if we should fall back to disk, must be explicitly enabled - let fallback_to_disk = match env::var(KEYCHAIN_DISK_FALLBACK_ENV) { - Ok(value) => value.to_lowercase() == "true", - Err(_) => false, - }; - + pub fn new( + credentials_path: String, + fallback_to_disk: bool, + keychain_service: String, + keychain_username: String, + ) -> Self { Self { credentials_path, fallback_to_disk, + keychain_service, + keychain_username, } } @@ -64,8 +62,8 @@ impl CredentialsManager { /// /// # Examples /// - /// ``` - /// # use goose_mcp::google_drive::token_storage::CredentialsManager; + /// ```no_run + /// # use goose_mcp::google_drive::storage::CredentialsManager; /// use serde::{Serialize, Deserialize}; /// /// #[derive(Serialize, Deserialize)] @@ -75,7 +73,12 @@ impl CredentialsManager { /// expiry: u64, /// } /// - /// let manager = CredentialsManager::new(String::from("/path/to/credentials.json")); + /// let manager = CredentialsManager::new( + /// String::from("/path/to/credentials.json"), + /// true, // fallback to disk if keychain fails + /// String::from("test_service"), + /// String::from("test_user") + /// ); /// match manager.read_credentials::() { /// Ok(token) => println!("Token expires at: {}", token.expiry), /// Err(e) => eprintln!("Failed to read token: {}", e), @@ -85,7 +88,7 @@ impl CredentialsManager { where T: DeserializeOwned, { - let json_str = Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME) + let json_str = Entry::new(&self.keychain_service, &self.keychain_username) .and_then(|entry| entry.get_password()) .inspect(|_| { debug!("Successfully read credentials from keychain"); @@ -144,8 +147,8 @@ impl CredentialsManager { /// /// # Examples /// - /// ``` - /// # use goose_mcp::google_drive::token_storage::CredentialsManager; + /// ```no_run + /// # use goose_mcp::google_drive::storage::CredentialsManager; /// use serde::{Serialize, Deserialize}; /// /// #[derive(Serialize, Deserialize)] @@ -161,7 +164,12 @@ impl CredentialsManager { /// expiry: 1672531200, // Unix timestamp /// }; /// - /// let manager = CredentialsManager::new(String::from("/path/to/credentials.json")); + /// let manager = CredentialsManager::new( + /// String::from("/path/to/credentials.json"), + /// true, // fallback to disk if keychain fails + /// String::from("test_service"), + /// String::from("test_user") + /// ); /// if let Err(e) = manager.write_credentials(&token) { /// eprintln!("Failed to write token: {}", e); /// } @@ -172,7 +180,7 @@ impl CredentialsManager { { let json_str = serde_json::to_string(content).map_err(StorageError::SerializationError)?; - Entry::new(KEYCHAIN_SERVICE, KEYCHAIN_USERNAME) + Entry::new(&self.keychain_service, &self.keychain_username) .and_then(|entry| entry.set_password(&json_str)) .inspect(|_| { debug!("Successfully wrote credentials to keychain"); @@ -219,6 +227,118 @@ impl Clone for CredentialsManager { Self { credentials_path: self.credentials_path.clone(), fallback_to_disk: self.fallback_to_disk, + keychain_service: self.keychain_service.clone(), + keychain_username: self.keychain_username.clone(), } } } + +#[cfg(test)] +mod tests { + use super::*; + use serde::{Deserialize, Serialize}; + use tempfile::tempdir; + + #[derive(Debug, Serialize, Deserialize, PartialEq)] + struct TestCredentials { + access_token: String, + refresh_token: String, + expiry: u64, + } + + impl TestCredentials { + fn new() -> Self { + Self { + access_token: "test_access_token".to_string(), + refresh_token: "test_refresh_token".to_string(), + expiry: 1672531200, + } + } + } + + #[test] + fn test_read_write_from_keychain() { + // Create a temporary directory for test files + let temp_dir = tempdir().expect("Failed to create temp dir"); + let cred_path = temp_dir.path().join("test_credentials.json"); + let cred_path_str = cred_path.to_str().unwrap().to_string(); + + // Create a credentials manager with fallback enabled + // Using a unique service name to ensure keychain operation fails + let manager = CredentialsManager::new( + cred_path_str, + true, // fallback to disk + "test_service".to_string(), + "test_user".to_string(), + ); + + // Test credentials to store + let creds = TestCredentials::new(); + + // Write should write to keychain + let write_result = manager.write_credentials(&creds); + assert!(write_result.is_ok(), "Write should succeed with fallback"); + + // Read should read from keychain + let read_result = manager.read_credentials::(); + assert!(read_result.is_ok(), "Read should succeed with fallback"); + + // Verify the read credentials match what we wrote + assert_eq!( + read_result.unwrap(), + creds, + "Read credentials should match written credentials" + ); + } + + #[test] + fn test_no_fallback_not_found() { + // Create a temporary directory for test files + let temp_dir = tempdir().expect("Failed to create temp dir"); + let cred_path = temp_dir.path().join("nonexistent_credentials.json"); + let cred_path_str = cred_path.to_str().unwrap().to_string(); + + // Create a credentials manager with fallback disabled + let manager = CredentialsManager::new( + cred_path_str, + false, // no fallback to disk + "test_service_that_should_not_exist".to_string(), + "test_user_no_fallback".to_string(), + ); + + // Read should fail with NotFound or KeyringError depending on the system + let read_result = manager.read_credentials::(); + println!("{:?}", read_result); + assert!( + read_result.is_err(), + "Read should fail when credentials don't exist" + ); + } + + #[test] + fn test_serialization_error() { + // This test verifies that serialization errors are properly handled + let error = serde_json::from_str::("invalid json").unwrap_err(); + let storage_error = StorageError::SerializationError(error); + assert!(matches!(storage_error, StorageError::SerializationError(_))); + } + + #[test] + fn test_file_system_error_handling() { + // Test handling of file system errors by using an invalid path + let invalid_path = String::from("/nonexistent_directory/credentials.json"); + let manager = CredentialsManager::new( + invalid_path, + true, + "test_service".to_string(), + "test_user".to_string(), + ); + + // Create test credentials + let creds = TestCredentials::new(); + + // Attempt to write to an invalid path should result in FileSystemError + let result = manager.write_to_file(&serde_json::to_string(&creds).unwrap()); + assert!(matches!(result, Err(StorageError::FileSystemError(_)))); + } +} diff --git a/crates/goose-mcp/src/lib.rs b/crates/goose-mcp/src/lib.rs index 1345d2dd0951..472349f571fd 100644 --- a/crates/goose-mcp/src/lib.rs +++ b/crates/goose-mcp/src/lib.rs @@ -9,7 +9,7 @@ pub static APP_STRATEGY: Lazy = Lazy::new(|| AppStrategyArgs { pub mod computercontroller; mod developer; -mod google_drive; +pub mod google_drive; mod jetbrains; mod memory; mod tutorial; From 1c2d6f7b5e9347f89a998a89801f8117302cf545 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Fri, 14 Mar 2025 17:13:18 -0700 Subject: [PATCH 17/18] ci: try cleanup disk space before build too --- .github/workflows/ci.yml | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1e183a46a9b4..a0e7f73e0e58 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -70,6 +70,24 @@ jobs: restore-keys: | ${{ runner.os }}-cargo-build- + # Add disk space cleanup before linting + - name: Check disk space before build + run: df -h + + - name: Clean up disk space + run: | + echo "Cleaning up disk space..." + # Clean npm cache if it exists + npm cache clean --force || true + # Clean apt cache + sudo apt-get clean + # Remove unnecessary large directories + rm -rf ~/.cargo/registry/index || true + # Remove docker images if any + docker system prune -af || true + # Remove unused packages + sudo apt-get autoremove -y || true + - name: Build and Test run: | gnome-keyring-daemon --components=secrets --daemonize --unlock <<< 'foobar' @@ -129,4 +147,4 @@ jobs: uses: ./.github/workflows/bundle-desktop.yml if: github.event_name == 'pull_request' with: - signing: false \ No newline at end of file + signing: false From 3a188c791254d664323084b219960d164a987018 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Fri, 14 Mar 2025 17:23:30 -0700 Subject: [PATCH 18/18] ci: cleanup from gh issue --- .github/workflows/ci.yml | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a0e7f73e0e58..eac44cb86386 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -74,19 +74,23 @@ jobs: - name: Check disk space before build run: df -h + #https://github.com/actions/runner-images/issues/2840 - name: Clean up disk space run: | echo "Cleaning up disk space..." - # Clean npm cache if it exists - npm cache clean --force || true - # Clean apt cache - sudo apt-get clean - # Remove unnecessary large directories - rm -rf ~/.cargo/registry/index || true - # Remove docker images if any - docker system prune -af || true - # Remove unused packages - sudo apt-get autoremove -y || true + sudo rm -rf \ + /opt/google/chrome \ + /opt/microsoft/msedge \ + /opt/microsoft/powershell \ + /usr/lib/mono \ + /usr/local/lib/android \ + /usr/local/lib/node_modules \ + /usr/local/share/chromium \ + /usr/local/share/powershell \ + /usr/share/dotnet \ + /usr/share/swift + + df -h - name: Build and Test run: |