Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
349 changes: 338 additions & 11 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions crates/goose/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ tokio-util = "0.7.15"
lancedb = "0.13"
arrow = "52.2"

# ML inference backends for security scanning
ort = "2.0.0-rc.10" # ONNX Runtime - use latest RC
tokenizers = { version = "0.20.4", default-features = false, features = ["onig"] } # HuggingFace tokenizers

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how much does this add to the executable? do we need the huggingface tokenizer? can we not get this done using tiktoken?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies, didn't realise you'd left feedback - having a look now. Thanks for the feedback 🙏

[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = ["wincred"] }

Expand Down
142 changes: 138 additions & 4 deletions crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ use super::platform_tools;
use super::tool_execution::{ToolCallResult, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE};
use crate::agents::subagent_task_config::TaskConfig;
use crate::conversation::message::{Message, ToolRequest};
use crate::security::SecurityManager;

const DEFAULT_MAX_TURNS: u32 = 1000;

Expand Down Expand Up @@ -97,6 +98,7 @@ pub struct Agent {
pub(super) tool_route_manager: ToolRouteManager,
pub(super) scheduler_service: Mutex<Option<Arc<dyn SchedulerTrait>>>,
pub(super) retry_manager: RetryManager,
pub(super) security_manager: SecurityManager,
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -173,6 +175,7 @@ impl Agent {
tool_route_manager: ToolRouteManager::new(),
scheduler_service: Mutex::new(None),
retry_manager,
security_manager: SecurityManager::new(),
}
}

Expand Down Expand Up @@ -1011,21 +1014,76 @@ impl Agent {
self.provider().await?,
).await;

// DEBUG: Log tool categorization
println!("🔍 DEBUG: Tool categorization results:");

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, let's

println!(" - {} tools approved (pre-approved)", permission_check_result.approved.len());
println!(" - {} tools need approval", permission_check_result.needs_approval.len());
println!(" - {} tools denied", permission_check_result.denied.len());
println!(" - {} readonly tools", readonly_tools.len());
println!(" - {} regular tools", regular_tools.len());

for (i, tool_req) in remaining_requests.iter().enumerate() {
if let Ok(tool_call) = &tool_req.tool_call {
println!(" - Tool {}: '{}' -> {}", i, tool_call.name,
if permission_check_result.approved.iter().any(|r| r.id == tool_req.id) { "APPROVED" }
else if permission_check_result.needs_approval.iter().any(|r| r.id == tool_req.id) { "NEEDS_APPROVAL" }
else if permission_check_result.denied.iter().any(|r| r.id == tool_req.id) { "DENIED" }
else { "UNKNOWN" }
);
}
}

// Scan tools for prompt injection
let total_tools = permission_check_result.approved.len() + permission_check_result.needs_approval.len();
println!("🔍 DEBUG: About to call security manager with {} total tools ({} approved + {} need approval)",
total_tools, permission_check_result.approved.len(), permission_check_result.needs_approval.len());
let security_results = self.security_manager
.filter_malicious_tool_calls(messages.messages(), &permission_check_result)
.await
.unwrap_or_else(|e| {
tracing::warn!("Security scanning failed: {}", e);
vec![]
});

// Apply security results to permission check result
let final_permission_result = self.apply_security_results_to_permissions(
permission_check_result,
&security_results
).await;

println!("🔍 DEBUG: After security integration - {} approved, {} need approval, {} denied",
final_permission_result.approved.len(),
final_permission_result.needs_approval.len(),
final_permission_result.denied.len());

let mut tool_futures = self.handle_approved_and_denied_tools(
&permission_check_result,
&final_permission_result,
message_tool_response.clone(),
cancel_token.clone()
).await?;

let tool_futures_arc = Arc::new(Mutex::new(tool_futures));

// Process tools requiring approval
let mut tool_approval_stream = self.handle_approval_tool_requests(
&permission_check_result.needs_approval,
// Process tools requiring approval (including security-flagged tools)
// Create a mapping of security results for tools that need approval
let mut security_results_for_approval: Vec<Option<&crate::security::SecurityResult>> = Vec::new();
for approval_request in &final_permission_result.needs_approval {
// Find the corresponding security result for this tool request
let security_result = security_results.iter().find(|result| {
// Match by checking if this tool was flagged as malicious
// This is a simplified matching - ideally we'd have better tool request tracking
result.is_malicious
});
security_results_for_approval.push(security_result);
}

let mut tool_approval_stream = self.handle_approval_tool_requests_with_security(
&final_permission_result.needs_approval,
tool_futures_arc.clone(),
&mut permission_manager,
message_tool_response.clone(),
cancel_token.clone(),
Some(&security_results_for_approval),
);

while let Some(msg) = tool_approval_stream.try_next().await? {
Expand Down Expand Up @@ -1152,6 +1210,82 @@ impl Agent {
}
}

/// Apply security scan results to permission check results
/// This integrates security scanning with the existing tool approval system
async fn apply_security_results_to_permissions(
&self,
mut permission_result: PermissionCheckResult,
security_results: &[crate::security::SecurityResult],
) -> PermissionCheckResult {
if security_results.is_empty() {
return permission_result;
}

// Create a map of tool requests by ID for easy lookup
let mut all_requests: std::collections::HashMap<String, ToolRequest> = std::collections::HashMap::new();

// Collect all tool requests
for req in &permission_result.approved {
all_requests.insert(req.id.clone(), req.clone());
}
for req in &permission_result.needs_approval {
all_requests.insert(req.id.clone(), req.clone());
}
for req in &permission_result.denied {
all_requests.insert(req.id.clone(), req.clone());
}

// Collect the combined requests first to avoid borrowing issues
let combined_requests: Vec<ToolRequest> = permission_result.approved.iter()
.chain(permission_result.needs_approval.iter())
.cloned()
.collect();

// Process security results
for (i, security_result) in security_results.iter().enumerate() {
if !security_result.is_malicious {
continue;
}

// Find the corresponding tool request by index
if let Some(tool_request) = combined_requests.get(i) {
let request_id = &tool_request.id;

tracing::warn!(
tool_request_id = %request_id,
confidence = security_result.confidence,
explanation = %security_result.explanation,
"Security threat detected - modifying tool approval status"
);

// Remove from approved if present
permission_result.approved.retain(|req| req.id != *request_id);

if security_result.should_ask_user {
// Move to needs_approval with security context
if let Some(request) = all_requests.get(request_id) {
// Only add if not already in needs_approval
if !permission_result.needs_approval.iter().any(|req| req.id == *request_id) {
permission_result.needs_approval.push(request.clone());
}
}
} else {
// High confidence threat - move to denied
permission_result.needs_approval.retain(|req| req.id != *request_id);

if let Some(request) = all_requests.get(request_id) {
// Only add if not already in denied
if !permission_result.denied.iter().any(|req| req.id == *request_id) {
permission_result.denied.push(request.clone());
}
}
}
}
}

permission_result
}

/// Extend the system prompt with one line of additional instruction
pub async fn extend_system_prompt(&self, instruction: String) {
let mut prompt_manager = self.prompt_manager.lock().await;
Expand Down
44 changes: 42 additions & 2 deletions crates/goose/src/agents/tool_execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,55 @@ impl Agent {
permission_manager: &'a mut PermissionManager,
message_tool_response: Arc<Mutex<Message>>,
cancellation_token: Option<CancellationToken>,
) -> BoxStream<'a, anyhow::Result<Message>> {
self.handle_approval_tool_requests_with_security(
tool_requests,
tool_futures,
permission_manager,
message_tool_response,
cancellation_token,
None, // No security context by default
)
}

pub(crate) fn handle_approval_tool_requests_with_security<'a>(
&'a self,
tool_requests: &'a [ToolRequest],
tool_futures: Arc<Mutex<Vec<(String, ToolStream)>>>,
permission_manager: &'a mut PermissionManager,
message_tool_response: Arc<Mutex<Message>>,
cancellation_token: Option<CancellationToken>,
security_results: Option<&'a [Option<&'a crate::security::SecurityResult>]>,
) -> BoxStream<'a, anyhow::Result<Message>> {
try_stream! {
for request in tool_requests {
for (i, request) in tool_requests.iter().enumerate() {
if let Ok(tool_call) = request.tool_call.clone() {
// Check if this tool has security concerns
// Match by index since security results are provided in the same order as tool requests
let security_context = security_results
.and_then(|results| results.get(i))
.and_then(|result| *result)
.filter(|result| result.is_malicious);

let confirmation_prompt = if let Some(security_result) = security_context {
format!(
"🚨 SECURITY WARNING: This tool call has been flagged as potentially malicious.\n\
Confidence: {:.1}%\n\
Reason: {}\n\n\
Goose would still like to call the above tool. \n\
Please review carefully. Allow? (y/n):",
security_result.confidence * 100.0,
security_result.explanation
)
} else {
"Goose would like to call the above tool. Allow? (y/n):".to_string()
};

let confirmation = Message::user().with_tool_confirmation_request(
request.id.clone(),
tool_call.name.clone(),
tool_call.arguments.clone(),
Some("Goose would like to call the above tool. Allow? (y/n):".to_string()),
Some(confirmation_prompt),
);
yield confirmation;

Expand Down
1 change: 1 addition & 0 deletions crates/goose/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub mod recipe_deeplink;
pub mod scheduler;
pub mod scheduler_factory;
pub mod scheduler_trait;
pub mod security;
pub mod session;
pub mod temporal_scheduler;
pub mod token_counter;
Expand Down
133 changes: 133 additions & 0 deletions crates/goose/src/security/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
pub mod scanner;
pub mod model_downloader;

use anyhow::Result;
use crate::conversation::message::Message;
use crate::permission::permission_judge::PermissionCheckResult;
use scanner::PromptInjectionScanner;

/// Simple security manager for the POC
/// Focuses on tool call analysis with conversation context
pub struct SecurityManager {
scanner: Option<PromptInjectionScanner>,
}

#[derive(Debug, Clone)]
pub struct SecurityResult {
pub is_malicious: bool,
pub confidence: f32,
pub explanation: String,
pub should_ask_user: bool,
}

impl SecurityManager {
pub fn new() -> Self {
println!("🔒 SecurityManager::new() called - checking if security should be enabled");

// Initialize scanner based on config
let should_enable = Self::should_enable_security();
println!("🔒 Security enabled check result: {}", should_enable);

let scanner = match should_enable {
true => {
println!("🔒 Initializing security scanner");
tracing::info!("🔒 Initializing security scanner");
Some(PromptInjectionScanner::new())
}
false => {
println!("🔓 Security scanning disabled");
tracing::info!("🔓 Security scanning disabled");
None
}
};

Self { scanner }
}

/// Check if security should be enabled based on config
fn should_enable_security() -> bool {
// Check config file for security settings
use crate::config::Config;
let config = Config::global();

// Try to get security.enabled from config
let result = config.get_param::<serde_json::Value>("security")
.ok()
.and_then(|security_config| security_config.get("enabled")?.as_bool())
.unwrap_or(false);

println!("🔒 Config check - security config result: {:?}",
config.get_param::<serde_json::Value>("security"));
println!("🔒 Final security enabled result: {}", result);

result
}

/// Main security check function - called from reply_internal
/// Uses the proper two-step security analysis process
/// Scans ALL tools (approved + needs_approval) for security threats
pub async fn filter_malicious_tool_calls(
&self,
messages: &[Message],
permission_check_result: &PermissionCheckResult,
) -> Result<Vec<SecurityResult>> {
let Some(scanner) = &self.scanner else {
// Security disabled, return empty results
return Ok(vec![]);
};

let mut results = Vec::new();

// Collect ALL tool requests (approved + needs_approval) for security scanning
let mut all_tool_requests = Vec::new();
all_tool_requests.extend(&permission_check_result.approved);
all_tool_requests.extend(&permission_check_result.needs_approval);

// Check ALL tools for potential security issues
for tool_request in &all_tool_requests {
if let Ok(tool_call) = &tool_request.tool_call {
tracing::info!(
tool_name = %tool_call.name,
"🔍 Starting two-step security analysis for tool call"
);

// Use the new two-step analysis method
let analysis_result = scanner.analyze_tool_call_with_context(
tool_call,
messages,
).await?;

if analysis_result.is_malicious {
tracing::warn!(
tool_name = %tool_call.name,
confidence = analysis_result.confidence,
explanation = %analysis_result.explanation,
"🚨 Tool call flagged as malicious after two-step analysis"
);

results.push(SecurityResult {
is_malicious: analysis_result.is_malicious,
confidence: analysis_result.confidence,
explanation: analysis_result.explanation,
should_ask_user: analysis_result.confidence > 0.7,
});
} else {
tracing::debug!(
tool_name = %tool_call.name,
confidence = analysis_result.confidence,
explanation = %analysis_result.explanation,
"✅ Tool call passed two-step security analysis"
);
}
}
}

Ok(results)
}
}

impl Default for SecurityManager {
fn default() -> Self {
Self::new()
}
}
Loading