diff --git a/crates/goose/src/security/scanner.rs b/crates/goose/src/security/scanner.rs
index 2e658b399909..3411b89a6a5a 100644
--- a/crates/goose/src/security/scanner.rs
+++ b/crates/goose/src/security/scanner.rs
@@ -46,7 +46,7 @@ impl PromptInjectionScanner {
     fn create_classifier_from_config() -> Result {
         let config = Config::global();
-        let model_name = config
+        let mut model_name = config
             .get_param::<String>("SECURITY_PROMPT_CLASSIFIER_MODEL")
             .ok()
             .filter(|s| !s.trim().is_empty());
@@ -59,6 +59,23 @@ impl PromptInjectionScanner {
             .ok()
             .filter(|s| !s.trim().is_empty());
 
+        if model_name.is_none() {
+            if let Ok(mapping_json) = std::env::var("SECURITY_ML_MODEL_MAPPING") {
+                if let Ok(mapping) = serde_json::from_str::<
+                    crate::security::classification_client::ModelMappingConfig,
+                >(&mapping_json)
+                {
+                    if let Some(first_model) = mapping.models.keys().next() {
+                        tracing::info!(
+                            default_model = %first_model,
+                            "SECURITY_ML_MODEL_MAPPING available but no model selected - using first available model as default"
+                        );
+                        model_name = Some(first_model.clone());
+                    }
+                }
+            }
+        }
+
         tracing::debug!(
             model_name = ?model_name,
             has_endpoint = endpoint.is_some(),
@@ -106,20 +123,23 @@ impl PromptInjectionScanner {
             self.scan_conversation(messages)
         );
 
-        let highest_confidence_result =
-            self.select_highest_confidence_result(tool_result?, context_result?);
+        let tool_result = tool_result?;
+        let context_result = context_result?;
 
         let threshold = self.get_threshold_from_config();
 
+        let final_result =
+            self.select_result_with_context_awareness(tool_result, context_result, threshold);
+
         tracing::info!(
-            "✅ Security analysis complete: confidence={:.3}, malicious={}",
-            highest_confidence_result.confidence,
-            highest_confidence_result.confidence >= threshold
+            "Security analysis complete: confidence={:.3}, malicious={}",
+            final_result.confidence,
+            final_result.confidence >= threshold
         );
 
         Ok(ScanResult {
-            is_malicious: highest_confidence_result.confidence >= threshold,
-            confidence: highest_confidence_result.confidence,
-            explanation: self.build_explanation(&highest_confidence_result, threshold),
+            is_malicious: final_result.confidence >= threshold,
+            confidence: final_result.confidence,
+            explanation: self.build_explanation(&final_result, threshold),
         })
     }
@@ -169,12 +189,29 @@ impl PromptInjectionScanner {
         })
     }
 
-    fn select_highest_confidence_result(
+    fn select_result_with_context_awareness(
         &self,
         tool_result: DetailedScanResult,
         context_result: DetailedScanResult,
+        threshold: f32,
     ) -> DetailedScanResult {
-        if tool_result.confidence >= context_result.confidence {
+        let context_is_safe = context_result
+            .ml_confidence
+            .is_some_and(|conf| conf < threshold);
+
+        let tool_has_only_non_critical = !tool_result.pattern_matches.is_empty()
+            && tool_result
+                .pattern_matches
+                .iter()
+                .all(|m| m.threat.risk_level != crate::security::patterns::RiskLevel::Critical);
+
+        if context_is_safe && tool_has_only_non_critical {
+            DetailedScanResult {
+                confidence: 0.0,
+                pattern_matches: Vec::new(),
+                ml_confidence: context_result.ml_confidence,
+            }
+        } else if tool_result.confidence >= context_result.confidence {
             tool_result
         } else {
             context_result
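
For reference, the fallback in the second hunk only consumes the model names from the `SECURITY_ML_MODEL_MAPPING` JSON. Below is a minimal, self-contained sketch of that selection logic. The `ModelMappingConfig` stand-in, the model name, and the `endpoint` field are illustrative assumptions, not the real struct from `crate::security::classification_client`, which may carry additional per-model settings:

```rust
use std::collections::BTreeMap;

use serde::Deserialize;

// Illustrative stand-in for classification_client::ModelMappingConfig.
// Only the keys of `models` matter for the fallback; what each model
// maps to is left opaque here.
#[derive(Deserialize)]
struct ModelMappingConfig {
    models: BTreeMap<String, serde_json::Value>,
}

/// Mirrors the new fallback: keep an explicitly configured model if
/// present, otherwise take the first model named in the mapping JSON.
fn resolve_model(explicit: Option<String>, mapping_json: Option<&str>) -> Option<String> {
    explicit.or_else(|| {
        let mapping: ModelMappingConfig = serde_json::from_str(mapping_json?).ok()?;
        mapping.models.keys().next().cloned()
    })
}

fn main() {
    // Hypothetical mapping payload; the model name and endpoint are made up.
    let raw = r#"{"models": {"prompt-injection-classifier": {"endpoint": "http://localhost:8080"}}}"#;
    assert_eq!(
        resolve_model(None, Some(raw)).as_deref(),
        Some("prompt-injection-classifier")
    );
}
```

One design note: "first available model" is only deterministic when `models` deserializes into an ordered map (the sketch uses `BTreeMap`, which sorts keys); if the real field were a `HashMap`, the chosen default could vary between runs.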