Add ML-based prompt injection detection #5623
Conversation
There was a problem hiding this comment.
Pull Request Overview
This PR standardizes security configuration keys from snake_case to UPPER_CASE convention and introduces ML-based prompt injection detection using BERT models through a new Gondola provider. The changes enable more accurate security threat detection while maintaining backward compatibility through pattern-based scanning as a fallback.
Key changes:
- Configuration key standardization:
security_prompt_*→SECURITY_PROMPT_* - New ML detection infrastructure with Gondola provider for BERT-based model inference
- Enhanced scanner to combine pattern-based and ML-based detection results
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| ui/desktop/src/utils/configUtils.ts | Updated config labels to use new UPPER_CASE keys and added ML detection settings |
| ui/desktop/src/components/settings/security/SecurityToggle.tsx | Added ML detection toggle UI with model selection dropdown |
| documentation/docs/guides/security/prompt-injection-detection.md | Updated config examples to use new UPPER_CASE keys |
| documentation/docs/guides/config-files.md | Updated configuration reference table with new ML detection settings and standardized keys |
| crates/goose/src/security/scanner.rs | Refactored to support optional ML detection, enhanced conversation context scanning, and simplified tool content extraction |
| crates/goose/src/security/prompt_ml_detector.rs | New ML detector implementation with model registry and Gondola provider integration |
| crates/goose/src/security/mod.rs | Updated to initialize scanner with ML detection when enabled and handle fallback scenarios |
| crates/goose/src/providers/mod.rs | Added gondola module to provider list |
| crates/goose/src/providers/gondola.rs | New Gondola provider implementation for batch inference with BERT models |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
1f5fd06 to
9fe268c
Compare
459eb56 to
a0d6d5c
Compare
eff6fbe to
ba5feee
Compare
ef65a16 to
fce20ae
Compare
| use anyhow::{Context, Result}; | ||
| use serde::{Deserialize, Serialize}; | ||
| use std::collections::HashMap; | ||
| use std::time::Duration; | ||
| use url::Url; | ||
|
|
||
| /// Request format following HuggingFace Inference Text Classification API specification | ||
| #[derive(Debug, Serialize)] | ||
| struct ClassificationRequest { | ||
| inputs: String, | ||
| #[serde(skip_serializing_if = "Option::is_none")] | ||
| parameters: Option<serde_json::Value>, | ||
| } | ||
|
|
||
| #[derive(Debug, Deserialize, Clone)] | ||
| struct ClassificationLabel { | ||
| label: String, | ||
| score: f32, | ||
| } | ||
|
|
||
| type ClassificationResponse = Vec<Vec<ClassificationLabel>>; | ||
|
|
||
| #[derive(Debug, Deserialize, Clone)] | ||
| pub struct ModelEndpointInfo { | ||
| pub endpoint: String, | ||
| #[serde(flatten)] | ||
| pub extra_params: HashMap<String, serde_json::Value>, | ||
| } | ||
|
|
||
| #[derive(Debug, Deserialize, Clone)] | ||
| pub struct ModelMappingConfig { | ||
| #[serde(flatten)] | ||
| pub models: HashMap<String, ModelEndpointInfo>, | ||
| } | ||
|
|
||
| #[derive(Debug)] | ||
| pub struct ClassificationClient { | ||
| endpoint_url: String, | ||
| client: reqwest::Client, | ||
| auth_token: Option<String>, | ||
| extra_params: Option<HashMap<String, serde_json::Value>>, | ||
| } | ||
|
|
||
| impl ClassificationClient { | ||
| pub fn new( | ||
| endpoint_url: String, | ||
| timeout_ms: Option<u64>, | ||
| auth_token: Option<String>, | ||
| extra_params: Option<HashMap<String, serde_json::Value>>, | ||
| ) -> Result<Self> { | ||
| let timeout = Duration::from_millis(timeout_ms.unwrap_or(5000)); | ||
|
|
||
| let client = reqwest::Client::builder() | ||
| .timeout(timeout) | ||
| .build() | ||
| .context("Failed to create HTTP client")?; | ||
|
|
||
| Ok(Self { | ||
| endpoint_url, | ||
| client, | ||
| auth_token, | ||
| extra_params, | ||
| }) | ||
| } | ||
|
|
||
| pub fn from_model_name(model_name: &str, timeout_ms: Option<u64>) -> Result<Self> { | ||
| let mapping_json = std::env::var("SECURITY_ML_MODEL_MAPPING") | ||
| .context("SECURITY_ML_MODEL_MAPPING environment variable not set")?; | ||
|
|
||
| let mapping = serde_json::from_str::<ModelMappingConfig>(&mapping_json) | ||
| .context("Failed to parse SECURITY_ML_MODEL_MAPPING JSON")?; | ||
|
|
||
| let model_info = mapping.models.get(model_name).context(format!( | ||
| "Model '{}' not found in SECURITY_ML_MODEL_MAPPING", | ||
| model_name | ||
| ))?; | ||
|
|
||
| tracing::info!( | ||
| model_name = %model_name, | ||
| endpoint = %model_info.endpoint, | ||
| extra_params = ?model_info.extra_params, | ||
| "Creating classification client from model mapping" | ||
| ); | ||
|
|
||
| Self::new( | ||
| model_info.endpoint.clone(), | ||
| timeout_ms, | ||
| None, | ||
| Some(model_info.extra_params.clone()), | ||
| ) | ||
| } | ||
|
|
||
| pub fn from_endpoint( | ||
| endpoint_url: String, | ||
| timeout_ms: Option<u64>, | ||
| auth_token: Option<String>, | ||
| ) -> Result<Self> { | ||
| let endpoint_url = endpoint_url.trim().to_string(); | ||
|
|
||
| Url::parse(&endpoint_url) | ||
| .context("Invalid endpoint URL format. Must be a valid HTTP/HTTPS URL")?; | ||
|
|
||
| let auth_token = auth_token | ||
| .map(|t| t.trim().to_string()) | ||
| .filter(|t| !t.is_empty()); | ||
|
|
||
| tracing::info!( | ||
| endpoint = %endpoint_url, | ||
| has_token = auth_token.is_some(), | ||
| "Creating classification client from endpoint" | ||
| ); | ||
|
|
||
| Self::new(endpoint_url, timeout_ms, auth_token, None) | ||
| } | ||
|
|
||
| pub async fn classify(&self, text: &str) -> Result<f32> { | ||
| tracing::debug!( | ||
| endpoint = %self.endpoint_url, | ||
| text_length = text.len(), | ||
| "Sending classification request" | ||
| ); | ||
|
|
||
| let parameters = self | ||
| .extra_params | ||
| .as_ref() | ||
| .map(serde_json::to_value) | ||
| .transpose()?; | ||
|
|
||
| let request = ClassificationRequest { | ||
| inputs: text.to_string(), | ||
| parameters, | ||
| }; | ||
|
|
||
| let mut request_builder = self.client.post(&self.endpoint_url).json(&request); | ||
|
|
||
| if let Some(token) = &self.auth_token { | ||
| request_builder = request_builder.header("Authorization", format!("Bearer {}", token)); | ||
| } | ||
|
|
||
| let response = request_builder | ||
| .send() | ||
| .await | ||
| .context("Failed to send classification request")?; | ||
|
|
||
| let status = response.status(); | ||
| let response = if !status.is_success() { | ||
| let error_body = response.text().await.unwrap_or_default(); | ||
| return Err(anyhow::anyhow!( | ||
| "Classification API returned error status {}: {}", | ||
| status, | ||
| error_body | ||
| )); | ||
| } else { | ||
| response | ||
| }; | ||
|
|
||
| let classification_response: ClassificationResponse = response | ||
| .json() | ||
| .await | ||
| .context("Failed to parse classification response")?; | ||
|
|
||
| let batch_result = classification_response | ||
| .first() | ||
| .context("Classification API returned empty response")?; | ||
|
|
||
| let sum: f32 = batch_result.iter().map(|l| l.score).sum(); | ||
| let is_probabilities = batch_result | ||
| .iter() | ||
| .all(|label| label.score >= 0.0 && label.score <= 1.0) | ||
| && (sum - 1.0).abs() < 0.1; | ||
|
|
||
| let normalized_results: Vec<ClassificationLabel> = if is_probabilities { | ||
| batch_result.to_vec() | ||
| } else { | ||
| self.apply_softmax(batch_result)? | ||
| }; | ||
|
|
||
| let top_label = normalized_results | ||
| .iter() | ||
| .max_by(|a, b| { | ||
| a.score | ||
| .partial_cmp(&b.score) | ||
| .unwrap_or(std::cmp::Ordering::Equal) | ||
| }) | ||
| .context("Classification API returned no labels")?; | ||
|
|
||
| let injection_score = match top_label.label.as_str() { | ||
| "INJECTION" | "LABEL_1" => top_label.score, | ||
| "SAFE" | "LABEL_0" => 1.0 - top_label.score, | ||
| _ => { | ||
| tracing::warn!( | ||
| label = %top_label.label, | ||
| score = %top_label.score, | ||
| "Unknown classification label, defaulting to safe" | ||
| ); | ||
| 0.0 | ||
| } | ||
| }; | ||
|
|
||
| tracing::info!( | ||
| injection_score = %injection_score, | ||
| top_label = %top_label.label, | ||
| top_score = %top_label.score, | ||
| normalized = !is_probabilities, | ||
| "Classification complete" | ||
| ); | ||
|
|
||
| Ok(injection_score) | ||
| } | ||
|
|
||
| fn apply_softmax(&self, labels: &[ClassificationLabel]) -> Result<Vec<ClassificationLabel>> { | ||
| if labels.is_empty() { | ||
| return Ok(Vec::new()); | ||
| } | ||
|
|
||
| let max_score = labels | ||
| .iter() | ||
| .map(|l| l.score) | ||
| .fold(f32::NEG_INFINITY, f32::max); | ||
|
|
||
| let exp_scores: Vec<f32> = labels.iter().map(|l| (l.score - max_score).exp()).collect(); | ||
|
|
||
| let sum_exp: f32 = exp_scores.iter().sum(); | ||
|
|
||
| if sum_exp == 0.0 || !sum_exp.is_finite() { | ||
| anyhow::bail!("Softmax normalization failed: invalid sum"); | ||
| } | ||
|
|
||
| let normalized: Vec<ClassificationLabel> = labels | ||
| .iter() | ||
| .zip(exp_scores.iter()) | ||
| .map(|(label, &exp_score)| ClassificationLabel { | ||
| label: label.label.clone(), | ||
| score: exp_score / sum_exp, | ||
| }) | ||
| .collect(); | ||
|
|
||
| Ok(normalized) | ||
| } | ||
| } |
There was a problem hiding this comment.
The classification_client module has no test coverage. Consider adding unit tests for the ClassificationClient, especially for critical paths like softmax normalization, error handling for malformed responses, and label interpretation logic.
|
|
||
| ## Overview | ||
|
|
||
| Goose requires a classification endpoint that can analyze text and return a score indicating the likelihood of prompt injection. This API follows the **HuggingFace Inference API format** for text classification, making it compatible with HuggingFace Inference Endpoints to allow for easy usage in OSS goose. |
There was a problem hiding this comment.
yeah, I understand, but we should write the documentation for the OSS people. I would just drop anything after ", making"
| :::info Automatic Multi-Model Configuration | ||
| The experimental [AutoPilot](/docs/guides/multi-model/autopilot) feature provides intelligent, context-aware model switching. Configure models for different roles using the `x-advanced-models` setting. | ||
| ::: | ||
|
|
There was a problem hiding this comment.
The info box about AutoPilot configuration appears to be incorrectly placed here, as it's unrelated to the security configuration settings being documented. This content should either be moved to the appropriate section about multi-model configuration or removed from this location.
| :::info Automatic Multi-Model Configuration | |
| The experimental [AutoPilot](/docs/guides/multi-model/autopilot) feature provides intelligent, context-aware model switching. Configure models for different roles using the `x-advanced-models` setting. | |
| ::: |
| let max_confidence = stream::iter(user_messages) | ||
| .map(|msg| async move { self.scan_with_classifier(&msg).await }) | ||
| .buffer_unordered(ML_SCAN_CONCURRENCY) | ||
| .fold(0.0_f32, |acc, result| async move { | ||
| result.unwrap_or(0.0).max(acc) | ||
| }) | ||
| .await; |
There was a problem hiding this comment.
The concurrent message scanning could create a performance bottleneck when processing 10 user messages with concurrent HTTP requests. The ML_SCAN_CONCURRENCY of 3 means up to 3 simultaneous HTTP requests, but if each request takes 5 seconds (the default timeout), this could add up to 17 seconds in the worst case for tool execution. Consider adding a circuit breaker or reducing the timeout for conversation scans specifically, or making the scan asynchronous to avoid blocking tool execution.
…that is evaluated for potential prompt injection
| | `otel_exporter_otlp_timeout` | Export timeout in milliseconds for [observability](/docs/guides/environment-variables#opentelemetry-protocol-otlp) | Integer (ms) | 10000 | No | | ||
| | `SECURITY_PROMPT_ENABLED` | Enable [prompt injection detection](/docs/guides/security/prompt-injection-detection) to identify potentially harmful commands | true/false | false | No | | ||
| | `SECURITY_PROMPT_THRESHOLD` | Sensitivity threshold for [prompt injection detection](/docs/guides/security/prompt-injection-detection) (higher = stricter) | Float between 0.01 and 1.0 | 0.7 | No | | ||
| | `SECURITY_PROMPT_ENABLED` | Enable prompt injection detection to identify potentially harmful commands | true/false | false | No | |
There was a problem hiding this comment.
@dorien-koelemeijer Can you please revert the changes to this existing topic for now? Otherwise the docs will go out before this feature is available in a release.
We'll add them back in after the feature is released (except the note below because autopilot has been removed)
There was a problem hiding this comment.
Thanks for the comments! Do you mean it's best to revert all changes to this file for now?
| @@ -0,0 +1,89 @@ | |||
| # Classification API Specification | |||
There was a problem hiding this comment.
| # Classification API Specification | |
| --- | |
| title: Classification API Specification | |
| unlisted: true | |
| --- |
| unlisted: true | ||
| --- | ||
|
|
||
| This document defines the API that Goose uses for ML-based prompt injection detection. |
There was a problem hiding this comment.
The reference to "Goose" should be lowercase "goose" according to the project naming convention documented in HOWTOAI.md.
| **Goose's Usage:** | ||
| - Goose looks for the label with the highest score |
There was a problem hiding this comment.
The reference to "Goose" should be lowercase "goose" according to the project naming convention documented in HOWTOAI.md.
| **Goose's Usage:** | ||
| - Goose looks for the label with the highest score | ||
| - If the top label is "INJECTION" (or "LABEL_1"), the score is used as the injection confidence | ||
| - If the top label is "SAFE" (or "LABEL_0"), Goose uses `1.0 - score` as the injection confidence |
There was a problem hiding this comment.
The reference to "Goose" should be lowercase "goose" according to the project naming convention documented in HOWTOAI.md.
Summary
This PR adds BERT model prompt injection evaluation alongside the existing pattern-based approach.
Context
https://docs.google.com/document/d/1GNvriNWLAaJUMpWE1heBapxFshQEED5ZYfn1ixJvGCE/edit?tab=t.0#heading=h.rj4p6llxqvrq
Key Changes
BERTin the name of the variables so we can create a follow-up PR to also allow LLM-as-a-judge type prompt injection detection by re-using people's providers, and hopefully prevent confusing names of variables.Planned follow-up PRs / Related PRs
Type of Change
Screenshots of changes
Internal:

External:
