-
Notifications
You must be signed in to change notification settings - Fork 5.4k
feat: Prompt injection detection #4021
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
dorien-koelemeijer
wants to merge
15
commits into
aaif-goose:main
from
dorien-koelemeijer:feat/prompt-injection
Closed
Changes from 6 commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
cc30637
initial version - integrating security scanning into check_tool_permi…
dorien-koelemeijer fe0c392
Fix some issues with how prompt injection detection on agent tool cal…
dorien-koelemeijer d641ea7
remove pattern based scanning - we should rely on BERT models, don't …
dorien-koelemeijer d6b1793
Model management updates - don't convert models to onnx if not necessary
dorien-koelemeijer 9177ea8
update model downloader - download onnx straight away from huggingfac…
dorien-koelemeijer 4ae158f
Re-use ToolCall for user input if prompt injection detected
dorien-koelemeijer e6aa6db
If ToolCall is verified for security finding, don't use 'Always Allow…
dorien-koelemeijer 15bc55f
Make user aware of model downloading for time being until we figure o…
dorien-koelemeijer af39972
Merge branch 'main' into feat/prompt-injection
dorien-koelemeijer b09d04e
Perform security scanning before 'check_tool_permissions' to make sur…
dorien-koelemeijer d7dfee9
fix: correctly fetch threshold values from goose config
dorien-koelemeijer 9b28c60
fix: update prompt injeciton logging to make analysis easier
dorien-koelemeijer 2847b80
fix
dorien-koelemeijer 8d9c190
temp fix
dorien-koelemeijer aeee29d
remove unused code
dorien-koelemeijer File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -57,6 +57,7 @@ use super::platform_tools; | |
| use super::tool_execution::{ToolCallResult, CHAT_MODE_TOOL_SKIPPED_RESPONSE, DECLINED_RESPONSE}; | ||
| use crate::agents::subagent_task_config::TaskConfig; | ||
| use crate::conversation::message::{Message, ToolRequest}; | ||
| use crate::security::SecurityManager; | ||
|
|
||
| const DEFAULT_MAX_TURNS: u32 = 1000; | ||
|
|
||
|
|
@@ -97,6 +98,7 @@ pub struct Agent { | |
| pub(super) tool_route_manager: ToolRouteManager, | ||
| pub(super) scheduler_service: Mutex<Option<Arc<dyn SchedulerTrait>>>, | ||
| pub(super) retry_manager: RetryManager, | ||
| pub(super) security_manager: SecurityManager, | ||
| } | ||
|
|
||
| #[derive(Clone, Debug)] | ||
|
|
@@ -173,6 +175,7 @@ impl Agent { | |
| tool_route_manager: ToolRouteManager::new(), | ||
| scheduler_service: Mutex::new(None), | ||
| retry_manager, | ||
| security_manager: SecurityManager::new(), | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -1011,21 +1014,76 @@ impl Agent { | |
| self.provider().await?, | ||
| ).await; | ||
|
|
||
| // DEBUG: Log tool categorization | ||
| println!("🔍 DEBUG: Tool categorization results:"); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, let's |
||
| println!(" - {} tools approved (pre-approved)", permission_check_result.approved.len()); | ||
| println!(" - {} tools need approval", permission_check_result.needs_approval.len()); | ||
| println!(" - {} tools denied", permission_check_result.denied.len()); | ||
| println!(" - {} readonly tools", readonly_tools.len()); | ||
| println!(" - {} regular tools", regular_tools.len()); | ||
|
|
||
| for (i, tool_req) in remaining_requests.iter().enumerate() { | ||
| if let Ok(tool_call) = &tool_req.tool_call { | ||
| println!(" - Tool {}: '{}' -> {}", i, tool_call.name, | ||
| if permission_check_result.approved.iter().any(|r| r.id == tool_req.id) { "APPROVED" } | ||
| else if permission_check_result.needs_approval.iter().any(|r| r.id == tool_req.id) { "NEEDS_APPROVAL" } | ||
| else if permission_check_result.denied.iter().any(|r| r.id == tool_req.id) { "DENIED" } | ||
| else { "UNKNOWN" } | ||
| ); | ||
| } | ||
| } | ||
|
|
||
| // Scan tools for prompt injection | ||
| let total_tools = permission_check_result.approved.len() + permission_check_result.needs_approval.len(); | ||
| println!("🔍 DEBUG: About to call security manager with {} total tools ({} approved + {} need approval)", | ||
| total_tools, permission_check_result.approved.len(), permission_check_result.needs_approval.len()); | ||
| let security_results = self.security_manager | ||
| .filter_malicious_tool_calls(messages.messages(), &permission_check_result) | ||
| .await | ||
| .unwrap_or_else(|e| { | ||
| tracing::warn!("Security scanning failed: {}", e); | ||
| vec![] | ||
| }); | ||
|
|
||
| // Apply security results to permission check result | ||
| let final_permission_result = self.apply_security_results_to_permissions( | ||
| permission_check_result, | ||
| &security_results | ||
| ).await; | ||
|
|
||
| println!("🔍 DEBUG: After security integration - {} approved, {} need approval, {} denied", | ||
| final_permission_result.approved.len(), | ||
| final_permission_result.needs_approval.len(), | ||
| final_permission_result.denied.len()); | ||
|
|
||
| let mut tool_futures = self.handle_approved_and_denied_tools( | ||
| &permission_check_result, | ||
| &final_permission_result, | ||
| message_tool_response.clone(), | ||
| cancel_token.clone() | ||
| ).await?; | ||
|
|
||
| let tool_futures_arc = Arc::new(Mutex::new(tool_futures)); | ||
|
|
||
| // Process tools requiring approval | ||
| let mut tool_approval_stream = self.handle_approval_tool_requests( | ||
| &permission_check_result.needs_approval, | ||
| // Process tools requiring approval (including security-flagged tools) | ||
| // Create a mapping of security results for tools that need approval | ||
| let mut security_results_for_approval: Vec<Option<&crate::security::SecurityResult>> = Vec::new(); | ||
| for approval_request in &final_permission_result.needs_approval { | ||
| // Find the corresponding security result for this tool request | ||
| let security_result = security_results.iter().find(|result| { | ||
| // Match by checking if this tool was flagged as malicious | ||
| // This is a simplified matching - ideally we'd have better tool request tracking | ||
| result.is_malicious | ||
| }); | ||
| security_results_for_approval.push(security_result); | ||
| } | ||
|
|
||
| let mut tool_approval_stream = self.handle_approval_tool_requests_with_security( | ||
| &final_permission_result.needs_approval, | ||
| tool_futures_arc.clone(), | ||
| &mut permission_manager, | ||
| message_tool_response.clone(), | ||
| cancel_token.clone(), | ||
| Some(&security_results_for_approval), | ||
| ); | ||
|
|
||
| while let Some(msg) = tool_approval_stream.try_next().await? { | ||
|
|
@@ -1152,6 +1210,82 @@ impl Agent { | |
| } | ||
| } | ||
|
|
||
| /// Apply security scan results to permission check results | ||
| /// This integrates security scanning with the existing tool approval system | ||
| async fn apply_security_results_to_permissions( | ||
| &self, | ||
| mut permission_result: PermissionCheckResult, | ||
| security_results: &[crate::security::SecurityResult], | ||
| ) -> PermissionCheckResult { | ||
| if security_results.is_empty() { | ||
| return permission_result; | ||
| } | ||
|
|
||
| // Create a map of tool requests by ID for easy lookup | ||
| let mut all_requests: std::collections::HashMap<String, ToolRequest> = std::collections::HashMap::new(); | ||
|
|
||
| // Collect all tool requests | ||
| for req in &permission_result.approved { | ||
| all_requests.insert(req.id.clone(), req.clone()); | ||
| } | ||
| for req in &permission_result.needs_approval { | ||
| all_requests.insert(req.id.clone(), req.clone()); | ||
| } | ||
| for req in &permission_result.denied { | ||
| all_requests.insert(req.id.clone(), req.clone()); | ||
| } | ||
|
|
||
| // Collect the combined requests first to avoid borrowing issues | ||
| let combined_requests: Vec<ToolRequest> = permission_result.approved.iter() | ||
| .chain(permission_result.needs_approval.iter()) | ||
| .cloned() | ||
| .collect(); | ||
|
|
||
| // Process security results | ||
| for (i, security_result) in security_results.iter().enumerate() { | ||
| if !security_result.is_malicious { | ||
| continue; | ||
| } | ||
|
|
||
| // Find the corresponding tool request by index | ||
| if let Some(tool_request) = combined_requests.get(i) { | ||
| let request_id = &tool_request.id; | ||
|
|
||
| tracing::warn!( | ||
| tool_request_id = %request_id, | ||
| confidence = security_result.confidence, | ||
| explanation = %security_result.explanation, | ||
| "Security threat detected - modifying tool approval status" | ||
| ); | ||
|
|
||
| // Remove from approved if present | ||
| permission_result.approved.retain(|req| req.id != *request_id); | ||
|
|
||
| if security_result.should_ask_user { | ||
| // Move to needs_approval with security context | ||
| if let Some(request) = all_requests.get(request_id) { | ||
| // Only add if not already in needs_approval | ||
| if !permission_result.needs_approval.iter().any(|req| req.id == *request_id) { | ||
| permission_result.needs_approval.push(request.clone()); | ||
| } | ||
| } | ||
| } else { | ||
| // High confidence threat - move to denied | ||
| permission_result.needs_approval.retain(|req| req.id != *request_id); | ||
|
|
||
| if let Some(request) = all_requests.get(request_id) { | ||
| // Only add if not already in denied | ||
| if !permission_result.denied.iter().any(|req| req.id == *request_id) { | ||
| permission_result.denied.push(request.clone()); | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| permission_result | ||
| } | ||
|
|
||
| /// Extend the system prompt with one line of additional instruction | ||
| pub async fn extend_system_prompt(&self, instruction: String) { | ||
| let mut prompt_manager = self.prompt_manager.lock().await; | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,133 @@ | ||
| pub mod scanner; | ||
| pub mod model_downloader; | ||
|
|
||
| use anyhow::Result; | ||
| use crate::conversation::message::Message; | ||
| use crate::permission::permission_judge::PermissionCheckResult; | ||
| use scanner::PromptInjectionScanner; | ||
|
|
||
| /// Simple security manager for the POC | ||
| /// Focuses on tool call analysis with conversation context | ||
| pub struct SecurityManager { | ||
| scanner: Option<PromptInjectionScanner>, | ||
| } | ||
|
|
||
| #[derive(Debug, Clone)] | ||
| pub struct SecurityResult { | ||
| pub is_malicious: bool, | ||
| pub confidence: f32, | ||
| pub explanation: String, | ||
| pub should_ask_user: bool, | ||
| } | ||
|
|
||
| impl SecurityManager { | ||
| pub fn new() -> Self { | ||
| println!("🔒 SecurityManager::new() called - checking if security should be enabled"); | ||
|
|
||
| // Initialize scanner based on config | ||
| let should_enable = Self::should_enable_security(); | ||
| println!("🔒 Security enabled check result: {}", should_enable); | ||
|
|
||
| let scanner = match should_enable { | ||
| true => { | ||
| println!("🔒 Initializing security scanner"); | ||
| tracing::info!("🔒 Initializing security scanner"); | ||
| Some(PromptInjectionScanner::new()) | ||
| } | ||
| false => { | ||
| println!("🔓 Security scanning disabled"); | ||
| tracing::info!("🔓 Security scanning disabled"); | ||
| None | ||
| } | ||
| }; | ||
|
|
||
| Self { scanner } | ||
| } | ||
|
|
||
| /// Check if security should be enabled based on config | ||
| fn should_enable_security() -> bool { | ||
| // Check config file for security settings | ||
| use crate::config::Config; | ||
| let config = Config::global(); | ||
|
|
||
| // Try to get security.enabled from config | ||
| let result = config.get_param::<serde_json::Value>("security") | ||
| .ok() | ||
| .and_then(|security_config| security_config.get("enabled")?.as_bool()) | ||
| .unwrap_or(false); | ||
|
|
||
| println!("🔒 Config check - security config result: {:?}", | ||
| config.get_param::<serde_json::Value>("security")); | ||
| println!("🔒 Final security enabled result: {}", result); | ||
|
|
||
| result | ||
| } | ||
|
|
||
| /// Main security check function - called from reply_internal | ||
| /// Uses the proper two-step security analysis process | ||
| /// Scans ALL tools (approved + needs_approval) for security threats | ||
| pub async fn filter_malicious_tool_calls( | ||
| &self, | ||
| messages: &[Message], | ||
| permission_check_result: &PermissionCheckResult, | ||
| ) -> Result<Vec<SecurityResult>> { | ||
| let Some(scanner) = &self.scanner else { | ||
| // Security disabled, return empty results | ||
| return Ok(vec![]); | ||
| }; | ||
|
|
||
| let mut results = Vec::new(); | ||
|
|
||
| // Collect ALL tool requests (approved + needs_approval) for security scanning | ||
| let mut all_tool_requests = Vec::new(); | ||
| all_tool_requests.extend(&permission_check_result.approved); | ||
| all_tool_requests.extend(&permission_check_result.needs_approval); | ||
|
|
||
| // Check ALL tools for potential security issues | ||
| for tool_request in &all_tool_requests { | ||
| if let Ok(tool_call) = &tool_request.tool_call { | ||
| tracing::info!( | ||
| tool_name = %tool_call.name, | ||
| "🔍 Starting two-step security analysis for tool call" | ||
| ); | ||
|
|
||
| // Use the new two-step analysis method | ||
| let analysis_result = scanner.analyze_tool_call_with_context( | ||
| tool_call, | ||
| messages, | ||
| ).await?; | ||
|
|
||
| if analysis_result.is_malicious { | ||
| tracing::warn!( | ||
| tool_name = %tool_call.name, | ||
| confidence = analysis_result.confidence, | ||
| explanation = %analysis_result.explanation, | ||
| "🚨 Tool call flagged as malicious after two-step analysis" | ||
| ); | ||
|
|
||
| results.push(SecurityResult { | ||
| is_malicious: analysis_result.is_malicious, | ||
| confidence: analysis_result.confidence, | ||
| explanation: analysis_result.explanation, | ||
| should_ask_user: analysis_result.confidence > 0.7, | ||
| }); | ||
| } else { | ||
| tracing::debug!( | ||
| tool_name = %tool_call.name, | ||
| confidence = analysis_result.confidence, | ||
| explanation = %analysis_result.explanation, | ||
| "✅ Tool call passed two-step security analysis" | ||
| ); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| Ok(results) | ||
| } | ||
| } | ||
|
|
||
| impl Default for SecurityManager { | ||
| fn default() -> Self { | ||
| Self::new() | ||
| } | ||
| } |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how much does this add to the executable? do we need the huggingface tokenizer? can we not get this done using tiktoken?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apologies, didn't realise you'd left feedback - having a look now. Thanks for the feedback 🙏