diff --git a/Cargo.lock b/Cargo.lock index 2f3c59dde0c1..d65e5b53e631 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -174,6 +174,12 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7d902e3d592a523def97af8f317b08ce16b7ab854c1985a0c671e6f15cebc236" +[[package]] +name = "arrayref" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76a2e8124351fda1ef8aaaa3bbd7ebbcb486bbcd4225aca0aa0d84bb2db8fecb" + [[package]] name = "arrayvec" version = "0.7.6" @@ -887,6 +893,19 @@ version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6099cdc01846bc367c4e7dd630dc5966dccf36b652fae7a74e17b640411a91b2" +[[package]] +name = "blake3" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "675f87afced0413c9bb02843499dbbd3882a237645883f71a2b59644a6d2f753" +dependencies = [ + "arrayref", + "arrayvec", + "cc", + "cfg-if", + "constant_time_eq", +] + [[package]] name = "block-buffer" version = "0.10.4" @@ -1248,6 +1267,12 @@ dependencies = [ "tiny-keccak", ] +[[package]] +name = "constant_time_eq" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" + [[package]] name = "content_inspector" version = "0.2.4" @@ -2114,6 +2139,7 @@ dependencies = [ "aws-smithy-types", "axum 0.7.9", "base64 0.21.7", + "blake3", "chrono", "criterion", "ctor", diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml index 0e02205994cd..8702ad8ef0a7 100644 --- a/crates/goose/Cargo.toml +++ b/crates/goose/Cargo.toml @@ -69,6 +69,9 @@ aws-sdk-bedrockruntime = "1.72.0" # For GCP Vertex AI provider auth jsonwebtoken = "9.3.1" +# Added blake3 hashing library as a dependency +blake3 = "1.5" + [target.'cfg(target_os = "windows")'.dependencies] winapi = { version = "0.3", features = ["wincred"] } diff --git a/crates/goose/src/agents/mod.rs b/crates/goose/src/agents/mod.rs index 4aa0659c976b..e970ca620cdc 100644 --- a/crates/goose/src/agents/mod.rs +++ b/crates/goose/src/agents/mod.rs @@ -3,6 +3,7 @@ mod capabilities; pub mod extension; mod factory; mod permission_judge; +mod permission_store; mod reference; mod truncate; @@ -11,3 +12,4 @@ pub use capabilities::Capabilities; pub use extension::ExtensionConfig; pub use factory::{register_agent, AgentFactory}; pub use permission_judge::detect_read_only_tools; +pub use permission_store::ToolPermissionStore; diff --git a/crates/goose/src/agents/permission_store.rs b/crates/goose/src/agents/permission_store.rs new file mode 100644 index 000000000000..d9b1d363068c --- /dev/null +++ b/crates/goose/src/agents/permission_store.rs @@ -0,0 +1,149 @@ +use crate::message::ToolRequest; +use anyhow::Result; +use blake3::Hasher; +use chrono::Utc; +use etcetera::{choose_app_strategy, AppStrategy}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::time::Duration; +use std::{fs::File, path::PathBuf}; + +#[derive(Debug, Serialize, Deserialize, Clone)] +pub struct ToolPermissionRecord { + tool_name: String, + allowed: bool, + context_hash: String, // Hash of the tool's arguments/context to differentiate similar calls + #[serde(skip_serializing_if = "Option::is_none")] // Don't serialize if None + readable_context: Option, // Add this field + timestamp: i64, + expiry: Option, // Optional expiry timestamp +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ToolPermissionStore { + permissions: HashMap>, + version: u32, // For future schema migrations + #[serde(skip)] // Don't serialize this field + permissions_dir: PathBuf, +} + +impl Default for ToolPermissionStore { + fn default() -> Self { + Self::new() + } +} + +impl ToolPermissionStore { + pub fn new() -> Self { + let permissions_dir = choose_app_strategy(crate::config::APP_STRATEGY.clone()) + .map(|strategy| strategy.config_dir()) + .unwrap_or_else(|_| PathBuf::from(".config/goose")); + + Self { + permissions: HashMap::new(), + version: 1, + permissions_dir, + } + } + + pub fn load() -> Result { + let store = Self::new(); + let file_path = store.permissions_dir.join("tool_permissions.json"); + + if !file_path.exists() { + return Ok(store); + } + + let file = File::open(file_path)?; + let mut permissions: ToolPermissionStore = serde_json::from_reader(file)?; + permissions.permissions_dir = store.permissions_dir; + + // Clean up expired entries on load + permissions.cleanup_expired()?; + + Ok(permissions) + } + + pub fn save(&self) -> anyhow::Result<()> { + std::fs::create_dir_all(&self.permissions_dir)?; + + let path = self.permissions_dir.join("tool_permissions.json"); + let temp_path = path.with_extension("tmp"); + + // Write complete content to temporary file + let content = serde_json::to_string_pretty(self)?; + std::fs::write(&temp_path, &content)?; + + // Atomically rename temp file to target file + std::fs::rename(temp_path, path)?; + + Ok(()) + } + + pub fn check_permission(&self, tool_request: &ToolRequest) -> Option { + let context_hash = self.hash_tool_context(tool_request); + let tool_call = tool_request.tool_call.as_ref().unwrap(); + let key = format!("{}:{}", tool_call.name, context_hash); + + self.permissions.get(&key).and_then(|records| { + records + .iter() + .filter(|record| record.expiry.is_none_or(|exp| exp > Utc::now().timestamp())) + .last() + .map(|record| record.allowed) + }) + } + + pub fn record_permission( + &mut self, + tool_request: &ToolRequest, + allowed: bool, + expiry_duration: Option, + ) -> anyhow::Result<()> { + let context_hash = self.hash_tool_context(tool_request); + let tool_call = tool_request.tool_call.as_ref().unwrap(); + let key = format!("{}:{}", tool_call.name, context_hash); + + let record = ToolPermissionRecord { + tool_name: tool_call.name.clone(), + allowed, + context_hash, + readable_context: Some(tool_request.to_readable_string()), + timestamp: Utc::now().timestamp(), + expiry: expiry_duration.map(|d| Utc::now().timestamp() + d.as_secs() as i64), + }; + + self.permissions.entry(key).or_default().push(record); + + self.save()?; + Ok(()) + } + + fn hash_tool_context(&self, tool_request: &ToolRequest) -> String { + // Create a hash of the tool's arguments to differentiate similar calls + // This helps identify when the same tool is being used in a different context + let mut hasher = Hasher::new(); + hasher.update( + serde_json::to_string(&tool_request.tool_call.as_ref().unwrap().arguments) + .unwrap_or_default() + .as_bytes(), + ); + hasher.finalize().to_hex().to_string() + } + + pub fn cleanup_expired(&mut self) -> anyhow::Result<()> { + let now = Utc::now().timestamp(); + let mut changed = false; + + self.permissions.retain(|_, records| { + records.retain(|record| record.expiry.is_none_or(|exp| exp > now)); + changed = changed || records.is_empty(); + !records.is_empty() + }); + + if changed { + self.save()?; + } + Ok(()) + } +} diff --git a/crates/goose/src/agents/truncate.rs b/crates/goose/src/agents/truncate.rs index fc3414f84c71..e4a1f562779e 100644 --- a/crates/goose/src/agents/truncate.rs +++ b/crates/goose/src/agents/truncate.rs @@ -12,6 +12,7 @@ use super::detect_read_only_tools; use super::Agent; use crate::agents::capabilities::Capabilities; use crate::agents::extension::{ExtensionConfig, ExtensionResult}; +use crate::agents::ToolPermissionStore; use crate::config::Config; use crate::config::ExperimentManager; use crate::message::{Message, ToolRequest}; @@ -28,6 +29,7 @@ use mcp_core::prompt::Prompt; use mcp_core::protocol::GetPromptResult; use mcp_core::{tool::Tool, Content}; use serde_json::{json, Value}; +use std::time::Duration; const MAX_TRUNCATION_ATTEMPTS: usize = 3; const ESTIMATE_FACTOR_DECAY: f32 = 0.9; @@ -265,19 +267,43 @@ impl Agent for TruncateAgent { match mode.as_str() { "approve" => { let mut read_only_tools = Vec::new(); - // Process each tool request sequentially with confirmation - if ExperimentManager::is_enabled("GOOSE_SMART_APPROVE")? { - read_only_tools = detect_read_only_tools(&capabilities, tool_requests.clone()).await; + let mut needs_confirmation = Vec::<&ToolRequest>::new(); + + // First check permissions for all tools + let store = ToolPermissionStore::load()?; + for request in tool_requests.iter() { + if let Ok(tool_call) = request.tool_call.clone() { + if let Some(allowed) = store.check_permission(request) { + if allowed { + let output = capabilities.dispatch_tool_call(tool_call).await; + message_tool_response = message_tool_response.with_tool_response( + request.id.clone(), + output, + ); + } else { + needs_confirmation.push(request); + } + } else { + needs_confirmation.push(request); + } + } } - for request in &tool_requests { + + // Only check read-only status for tools needing confirmation + if !needs_confirmation.is_empty() && ExperimentManager::is_enabled("GOOSE_SMART_APPROVE")? { + read_only_tools = detect_read_only_tools(&capabilities, needs_confirmation.clone()).await; + } + + // Process remaining tools that need confirmation + for request in &needs_confirmation { if let Ok(tool_call) = request.tool_call.clone() { // Skip confirmation if the tool_call.name is in the read_only_tools list if read_only_tools.contains(&tool_call.name) { let output = capabilities.dispatch_tool_call(tool_call).await; - message_tool_response = message_tool_response.with_tool_response( - request.id.clone(), - output, - ); + message_tool_response = message_tool_response.with_tool_response( + request.id.clone(), + output, + ); } else { let confirmation = Message::user().with_tool_confirmation_request( request.id.clone(), @@ -289,9 +315,12 @@ impl Agent for TruncateAgent { // Wait for confirmation response through the channel let mut rx = self.confirmation_rx.lock().await; - // Loop the recv until we have a matched req_id due to potential duplicate messages. while let Some((req_id, confirmed)) = rx.recv().await { if req_id == request.id { + // Store the user's response with 30-day expiration + let mut store = ToolPermissionStore::load()?; + store.record_permission(request, confirmed, Some(Duration::from_secs(30 * 24 * 60 * 60)))?; + if confirmed { // User approved - dispatch the tool call let output = capabilities.dispatch_tool_call(tool_call).await; diff --git a/crates/goose/src/message.rs b/crates/goose/src/message.rs index 0e5781e7cca6..c370c174aa8a 100644 --- a/crates/goose/src/message.rs +++ b/crates/goose/src/message.rs @@ -26,6 +26,22 @@ pub struct ToolRequest { pub tool_call: ToolResult, } +impl ToolRequest { + pub fn to_readable_string(&self) -> String { + match &self.tool_call { + Ok(tool_call) => { + format!( + "Tool: {}, Args: {}", + tool_call.name, + serde_json::to_string_pretty(&tool_call.arguments) + .unwrap_or_else(|_| "<>".to_string()) + ) + } + Err(e) => format!("Invalid tool call: {}", e), + } + } +} + #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] #[serde(rename_all = "camelCase")] pub struct ToolResponse {