Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 34 additions & 33 deletions crates/goose/src/security/scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,29 @@ impl PromptInjectionScanner {
threshold
);

let final_result =
self.select_result_with_context_awareness(tool_result, context_result, threshold);
let final_confidence =
self.combine_confidences(tool_result.confidence, context_result.confidence);

tracing::info!(
"Security analysis complete: final_confidence={:.3}, malicious={}",
final_result.confidence,
final_result.confidence >= threshold
tool_confidence = %tool_result.confidence,
context_confidence = %context_result.confidence,
final_confidence = %final_confidence,
has_ml = tool_result.ml_confidence.is_some(),
has_patterns = !tool_result.pattern_matches.is_empty(),
threshold = %threshold,
malicious = final_confidence >= threshold,
"Security analysis complete"
);

let final_result = DetailedScanResult {
confidence: final_confidence,
pattern_matches: tool_result.pattern_matches,
ml_confidence: tool_result.ml_confidence,
};

Ok(ScanResult {
is_malicious: final_result.confidence >= threshold,
confidence: final_result.confidence,
is_malicious: final_confidence >= threshold,
confidence: final_confidence,
explanation: self.build_explanation(&final_result, threshold, &tool_content),
})
}
Expand Down Expand Up @@ -228,33 +239,23 @@ impl PromptInjectionScanner {
})
}

fn select_result_with_context_awareness(
&self,
tool_result: DetailedScanResult,
context_result: DetailedScanResult,
threshold: f32,
) -> DetailedScanResult {
let context_is_safe = context_result
.ml_confidence
.is_some_and(|conf| conf < threshold);

let tool_has_only_non_critical = !tool_result.pattern_matches.is_empty()
&& tool_result
.pattern_matches
.iter()
.all(|m| m.threat.risk_level != crate::security::patterns::RiskLevel::Critical);

if context_is_safe && tool_has_only_non_critical {
DetailedScanResult {
confidence: 0.0,
pattern_matches: Vec::new(),
ml_confidence: context_result.ml_confidence,
}
} else if tool_result.confidence >= context_result.confidence {
tool_result
} else {
context_result
fn combine_confidences(&self, tool_confidence: f32, context_confidence: f32) -> f32 {
// If tool is safe, context is not taken into account
if tool_confidence < 0.3 {
return tool_confidence;
}

if context_confidence < 0.3 {
return tool_confidence * 0.9;
}

if tool_confidence > 0.8 && context_confidence > 0.8 {
let max_conf = tool_confidence.max(context_confidence);
return (max_conf * 1.05).min(1.0);
}

// Default: weighted average (tool is primary signal)
tool_confidence * 0.8 + context_confidence * 0.2
}

async fn scan_with_classifier(
Expand Down
Loading