Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 48 additions & 11 deletions crates/goose/src/security/scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ impl PromptInjectionScanner {
fn create_classifier_from_config() -> Result<ClassificationClient> {
let config = Config::global();

let model_name = config
let mut model_name = config
.get_param::<String>("SECURITY_PROMPT_CLASSIFIER_MODEL")
.ok()
.filter(|s| !s.trim().is_empty());
Expand All @@ -59,6 +59,23 @@ impl PromptInjectionScanner {
.ok()
.filter(|s| !s.trim().is_empty());

if model_name.is_none() {
if let Ok(mapping_json) = std::env::var("SECURITY_ML_MODEL_MAPPING") {
if let Ok(mapping) = serde_json::from_str::<
crate::security::classification_client::ModelMappingConfig,
>(&mapping_json)
{
if let Some(first_model) = mapping.models.keys().next() {
tracing::info!(
default_model = %first_model,
"SECURITY_ML_MODEL_MAPPING available but no model selected - using first available model as default"
Comment on lines +68 to +71
Copy link

Copilot AI Jan 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The iteration order of Rust's standard `HashMap` is unspecified and, because the default hasher is randomly seeded, can differ between runs, which makes the default model selection non-deterministic. For a security feature, consider sorting the keys and choosing the lexicographically smallest one so the behavior is consistent.

Suggested change
if let Some(first_model) = mapping.models.keys().next() {
tracing::info!(
default_model = %first_model,
"SECURITY_ML_MODEL_MAPPING available but no model selected - using first available model as default"
let mut model_names: Vec<_> = mapping.models.keys().cloned().collect();
model_names.sort();
if let Some(first_model) = model_names.first() {
tracing::info!(
default_model = %first_model,
"SECURITY_ML_MODEL_MAPPING available but no model selected - using first available model as default (lexicographically smallest)"

Copilot uses AI. Check for mistakes.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really a critical comment

);
model_name = Some(first_model.clone());
}
}
}
}

tracing::debug!(
model_name = ?model_name,
has_endpoint = endpoint.is_some(),
Expand Down Expand Up @@ -106,20 +123,23 @@ impl PromptInjectionScanner {
self.scan_conversation(messages)
);

let highest_confidence_result =
self.select_highest_confidence_result(tool_result?, context_result?);
let tool_result = tool_result?;
let context_result = context_result?;
let threshold = self.get_threshold_from_config();

let final_result =
self.select_result_with_context_awareness(tool_result, context_result, threshold);

tracing::info!(
"Security analysis complete: confidence={:.3}, malicious={}",
highest_confidence_result.confidence,
highest_confidence_result.confidence >= threshold
"Security analysis complete: confidence={:.3}, malicious={}",
final_result.confidence,
final_result.confidence >= threshold
);

Ok(ScanResult {
is_malicious: highest_confidence_result.confidence >= threshold,
confidence: highest_confidence_result.confidence,
explanation: self.build_explanation(&highest_confidence_result, threshold),
is_malicious: final_result.confidence >= threshold,
confidence: final_result.confidence,
explanation: self.build_explanation(&final_result, threshold),
})
}

Expand Down Expand Up @@ -169,12 +189,29 @@ impl PromptInjectionScanner {
})
}

fn select_highest_confidence_result(
fn select_result_with_context_awareness(
&self,
tool_result: DetailedScanResult,
context_result: DetailedScanResult,
threshold: f32,
) -> DetailedScanResult {
if tool_result.confidence >= context_result.confidence {
let context_is_safe = context_result
.ml_confidence
.is_some_and(|conf| conf < threshold);

let tool_has_only_non_critical = !tool_result.pattern_matches.is_empty()
&& tool_result
.pattern_matches
.iter()
.all(|m| m.threat.risk_level != crate::security::patterns::RiskLevel::Critical);

if context_is_safe && tool_has_only_non_critical {
DetailedScanResult {
confidence: 0.0,
pattern_matches: Vec::new(),
ml_confidence: context_result.ml_confidence,
}
} else if tool_result.confidence >= context_result.confidence {
tool_result
} else {
context_result
Expand Down
Loading