Skip to content

Add ML-based prompt injection detection #5623

Merged
dorien-koelemeijer merged 44 commits into
mainfrom
feat/ml-based-prompt-injection-detection
Jan 8, 2026
Merged

Add ML-based prompt injection detection #5623
dorien-koelemeijer merged 44 commits into
mainfrom
feat/ml-based-prompt-injection-detection

Conversation

@dorien-koelemeijer
Copy link
Copy Markdown
Collaborator

@dorien-koelemeijer dorien-koelemeijer commented Nov 7, 2025

Summary

This PR adds BERT model prompt injection evaluation alongside the existing pattern-based approach.

Context

https://docs.google.com/document/d/1GNvriNWLAaJUMpWE1heBapxFshQEED5ZYfn1ixJvGCE/edit?tab=t.0#heading=h.rj4p6llxqvrq

Key Changes

  • Updated PromptInjectionScanner to support both pattern-based + ML-based scanning, using the highest confidence score from either method to determine threats. ML-based scanning also looks at recent user messages in addition to tool call content.
  • Added ClassificationClient - Generic HTTP client that is compatible with HuggingFace's text classification API, supporting both internal users (Gondola hosted BERT models) and external/oss users (direct endpoints, such as using the hugging face API). The ML inference team is creating an additional wrapper endpoint that allows us to use their BatchInfer API with the same inputs and outputs as the Hugging Face text classification API.
  • Updated UI settings to allow users to configure what models to use if they decide to enable ML-based prompt injection detection. Have included BERT in the name of the variables so we can create a follow-up PR to also allow LLM-as-a-judge type prompt injection detection by re-using people's providers, and hopefully prevent confusing names of variables.

Planned follow-up PRs / Related PRs

  • Add a bunch of variables to goose-releases (will open PR shortly)
  • Re-use provider to do ML-based prompt injection detection rather than users requiring setup of the BERT models (mostly for oss/external users)
  • Provide reference implementation for users who want to run their BERT models locally (I've already prepared this but didn't want to make this PR super huge + I am not sure where best to store this info in the repo)
  • Waiting for this PR to get merged: https://github.com/squareup/gondola/pull/1085

Type of Change

  • Feature
  • Bug fix
  • Refactor / Code quality
  • Performance improvement
  • Documentation
  • Tests
  • Security fix
  • Build / Release
  • Other (specify below)

Screenshots of changes

Internal:
Screenshot 2025-11-26 at 11 51 36 am

Screenshot 2025-11-26 at 11 54 04 am

External:
Screenshot 2025-11-26 at 11 57 58 am

Screenshot 2025-11-26 at 11 58 46 am

@dorien-koelemeijer dorien-koelemeijer requested a review from a team as a code owner November 7, 2025 01:21
Copilot AI review requested due to automatic review settings November 7, 2025 01:21
@dorien-koelemeijer dorien-koelemeijer marked this pull request as draft November 7, 2025 01:21
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR standardizes security configuration keys from snake_case to UPPER_CASE convention and introduces ML-based prompt injection detection using BERT models through a new Gondola provider. The changes enable more accurate security threat detection while maintaining backward compatibility through pattern-based scanning as a fallback.

Key changes:

  • Configuration key standardization: security_prompt_*SECURITY_PROMPT_*
  • New ML detection infrastructure with Gondola provider for BERT-based model inference
  • Enhanced scanner to combine pattern-based and ML-based detection results

Reviewed Changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
ui/desktop/src/utils/configUtils.ts Updated config labels to use new UPPER_CASE keys and added ML detection settings
ui/desktop/src/components/settings/security/SecurityToggle.tsx Added ML detection toggle UI with model selection dropdown
documentation/docs/guides/security/prompt-injection-detection.md Updated config examples to use new UPPER_CASE keys
documentation/docs/guides/config-files.md Updated configuration reference table with new ML detection settings and standardized keys
crates/goose/src/security/scanner.rs Refactored to support optional ML detection, enhanced conversation context scanning, and simplified tool content extraction
crates/goose/src/security/prompt_ml_detector.rs New ML detector implementation with model registry and Gondola provider integration
crates/goose/src/security/mod.rs Updated to initialize scanner with ML detection when enabled and handle fallback scenarios
crates/goose/src/providers/mod.rs Added gondola module to provider list
crates/goose/src/providers/gondola.rs New Gondola provider implementation for batch inference with BERT models

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread documentation/docs/guides/config-files.md Outdated
Comment thread crates/goose/src/security/scanner.rs Outdated
Comment thread crates/goose/src/security/scanner.rs Outdated
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Nov 7, 2025

PR Preview Action v1.6.3
Preview removed because the pull request was closed.
2026-01-08 01:59 UTC

@dorien-koelemeijer dorien-koelemeijer force-pushed the feat/ml-based-prompt-injection-detection branch from 1f5fd06 to 9fe268c Compare November 7, 2025 01:29
@dorien-koelemeijer dorien-koelemeijer force-pushed the feat/ml-based-prompt-injection-detection branch from 459eb56 to a0d6d5c Compare November 7, 2025 06:07
@dorien-koelemeijer dorien-koelemeijer changed the title Add ML-based prompt injection detection using gondola-hosted BERT model [WIP] Add ML-based prompt injection detection using gondola-hosted BERT model Nov 11, 2025
@dorien-koelemeijer dorien-koelemeijer force-pushed the feat/ml-based-prompt-injection-detection branch from eff6fbe to ba5feee Compare November 11, 2025 05:34
@dorien-koelemeijer dorien-koelemeijer marked this pull request as ready for review November 11, 2025 05:35
Copilot AI review requested due to automatic review settings November 11, 2025 05:35
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated 6 comments.

Comment thread ui/desktop/src/components/settings/security/SecurityToggle.tsx Outdated
Comment thread ui/desktop/src/components/settings/security/SecurityToggle.tsx Outdated
Comment thread crates/goose/src/security/scanner.rs
Comment thread crates/goose/src/security/prompt_ml_detector.rs Outdated
Comment thread ui/desktop/src/components/settings/security/SecurityToggle.tsx Outdated
@aaif-goose aaif-goose deleted a comment from Copilot AI Nov 11, 2025
Copilot AI review requested due to automatic review settings November 11, 2025 05:57
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated 3 comments.

Comment thread ui/desktop/src/components/settings/security/SecurityToggle.tsx Outdated
Comment thread crates/goose/src/providers/gondola.rs Outdated
Comment thread crates/goose/src/security/scanner.rs
@dorien-koelemeijer dorien-koelemeijer force-pushed the feat/ml-based-prompt-injection-detection branch from ef65a16 to fce20ae Compare November 11, 2025 06:10
@aaif-goose aaif-goose deleted a comment from Copilot AI Dec 15, 2025
@aaif-goose aaif-goose deleted a comment from Copilot AI Dec 15, 2025
@aaif-goose aaif-goose deleted a comment from Copilot AI Dec 15, 2025
@aaif-goose aaif-goose deleted a comment from Copilot AI Dec 15, 2025
Copilot AI review requested due to automatic review settings December 15, 2025 03:10
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.

Comment on lines +1 to +240
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
use url::Url;

/// Request format following HuggingFace Inference Text Classification API specification
#[derive(Debug, Serialize)]
struct ClassificationRequest {
inputs: String,
#[serde(skip_serializing_if = "Option::is_none")]
parameters: Option<serde_json::Value>,
}

#[derive(Debug, Deserialize, Clone)]
struct ClassificationLabel {
label: String,
score: f32,
}

type ClassificationResponse = Vec<Vec<ClassificationLabel>>;

#[derive(Debug, Deserialize, Clone)]
pub struct ModelEndpointInfo {
pub endpoint: String,
#[serde(flatten)]
pub extra_params: HashMap<String, serde_json::Value>,
}

#[derive(Debug, Deserialize, Clone)]
pub struct ModelMappingConfig {
#[serde(flatten)]
pub models: HashMap<String, ModelEndpointInfo>,
}

#[derive(Debug)]
pub struct ClassificationClient {
endpoint_url: String,
client: reqwest::Client,
auth_token: Option<String>,
extra_params: Option<HashMap<String, serde_json::Value>>,
}

impl ClassificationClient {
pub fn new(
endpoint_url: String,
timeout_ms: Option<u64>,
auth_token: Option<String>,
extra_params: Option<HashMap<String, serde_json::Value>>,
) -> Result<Self> {
let timeout = Duration::from_millis(timeout_ms.unwrap_or(5000));

let client = reqwest::Client::builder()
.timeout(timeout)
.build()
.context("Failed to create HTTP client")?;

Ok(Self {
endpoint_url,
client,
auth_token,
extra_params,
})
}

pub fn from_model_name(model_name: &str, timeout_ms: Option<u64>) -> Result<Self> {
let mapping_json = std::env::var("SECURITY_ML_MODEL_MAPPING")
.context("SECURITY_ML_MODEL_MAPPING environment variable not set")?;

let mapping = serde_json::from_str::<ModelMappingConfig>(&mapping_json)
.context("Failed to parse SECURITY_ML_MODEL_MAPPING JSON")?;

let model_info = mapping.models.get(model_name).context(format!(
"Model '{}' not found in SECURITY_ML_MODEL_MAPPING",
model_name
))?;

tracing::info!(
model_name = %model_name,
endpoint = %model_info.endpoint,
extra_params = ?model_info.extra_params,
"Creating classification client from model mapping"
);

Self::new(
model_info.endpoint.clone(),
timeout_ms,
None,
Some(model_info.extra_params.clone()),
)
}

pub fn from_endpoint(
endpoint_url: String,
timeout_ms: Option<u64>,
auth_token: Option<String>,
) -> Result<Self> {
let endpoint_url = endpoint_url.trim().to_string();

Url::parse(&endpoint_url)
.context("Invalid endpoint URL format. Must be a valid HTTP/HTTPS URL")?;

let auth_token = auth_token
.map(|t| t.trim().to_string())
.filter(|t| !t.is_empty());

tracing::info!(
endpoint = %endpoint_url,
has_token = auth_token.is_some(),
"Creating classification client from endpoint"
);

Self::new(endpoint_url, timeout_ms, auth_token, None)
}

pub async fn classify(&self, text: &str) -> Result<f32> {
tracing::debug!(
endpoint = %self.endpoint_url,
text_length = text.len(),
"Sending classification request"
);

let parameters = self
.extra_params
.as_ref()
.map(serde_json::to_value)
.transpose()?;

let request = ClassificationRequest {
inputs: text.to_string(),
parameters,
};

let mut request_builder = self.client.post(&self.endpoint_url).json(&request);

if let Some(token) = &self.auth_token {
request_builder = request_builder.header("Authorization", format!("Bearer {}", token));
}

let response = request_builder
.send()
.await
.context("Failed to send classification request")?;

let status = response.status();
let response = if !status.is_success() {
let error_body = response.text().await.unwrap_or_default();
return Err(anyhow::anyhow!(
"Classification API returned error status {}: {}",
status,
error_body
));
} else {
response
};

let classification_response: ClassificationResponse = response
.json()
.await
.context("Failed to parse classification response")?;

let batch_result = classification_response
.first()
.context("Classification API returned empty response")?;

let sum: f32 = batch_result.iter().map(|l| l.score).sum();
let is_probabilities = batch_result
.iter()
.all(|label| label.score >= 0.0 && label.score <= 1.0)
&& (sum - 1.0).abs() < 0.1;

let normalized_results: Vec<ClassificationLabel> = if is_probabilities {
batch_result.to_vec()
} else {
self.apply_softmax(batch_result)?
};

let top_label = normalized_results
.iter()
.max_by(|a, b| {
a.score
.partial_cmp(&b.score)
.unwrap_or(std::cmp::Ordering::Equal)
})
.context("Classification API returned no labels")?;

let injection_score = match top_label.label.as_str() {
"INJECTION" | "LABEL_1" => top_label.score,
"SAFE" | "LABEL_0" => 1.0 - top_label.score,
_ => {
tracing::warn!(
label = %top_label.label,
score = %top_label.score,
"Unknown classification label, defaulting to safe"
);
0.0
}
};

tracing::info!(
injection_score = %injection_score,
top_label = %top_label.label,
top_score = %top_label.score,
normalized = !is_probabilities,
"Classification complete"
);

Ok(injection_score)
}

fn apply_softmax(&self, labels: &[ClassificationLabel]) -> Result<Vec<ClassificationLabel>> {
if labels.is_empty() {
return Ok(Vec::new());
}

let max_score = labels
.iter()
.map(|l| l.score)
.fold(f32::NEG_INFINITY, f32::max);

let exp_scores: Vec<f32> = labels.iter().map(|l| (l.score - max_score).exp()).collect();

let sum_exp: f32 = exp_scores.iter().sum();

if sum_exp == 0.0 || !sum_exp.is_finite() {
anyhow::bail!("Softmax normalization failed: invalid sum");
}

let normalized: Vec<ClassificationLabel> = labels
.iter()
.zip(exp_scores.iter())
.map(|(label, &exp_score)| ClassificationLabel {
label: label.label.clone(),
score: exp_score / sum_exp,
})
.collect();

Ok(normalized)
}
}
Copy link

Copilot AI Dec 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The classification_client module has no test coverage. Consider adding unit tests for the ClassificationClient, especially for critical paths like softmax normalization, error handling for malformed responses, and label interpretation logic.

Copilot uses AI. Check for mistakes.
Comment thread ui/desktop/src/components/settings/security/SecurityToggle.tsx Outdated

## Overview

Goose requires a classification endpoint that can analyze text and return a score indicating the likelihood of prompt injection. This API follows the **HuggingFace Inference API format** for text classification, making it compatible with HuggingFace Inference Endpoints to allow for easy usage in OSS goose.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, I understand, but we should write the documentation for the OSS people. I would just drop anything after ", making"

Comment thread crates/goose/src/security/scanner.rs Outdated
@aaif-goose aaif-goose deleted a comment from Copilot AI Jan 5, 2026
Copilot AI review requested due to automatic review settings January 5, 2026 01:55
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.

Comment on lines 53 to 56
:::info Automatic Multi-Model Configuration
The experimental [AutoPilot](/docs/guides/multi-model/autopilot) feature provides intelligent, context-aware model switching. Configure models for different roles using the `x-advanced-models` setting.
:::

Copy link

Copilot AI Jan 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The info box about AutoPilot configuration appears to be incorrectly placed here, as it's unrelated to the security configuration settings being documented. This content should either be moved to the appropriate section about multi-model configuration or removed from this location.

Suggested change
:::info Automatic Multi-Model Configuration
The experimental [AutoPilot](/docs/guides/multi-model/autopilot) feature provides intelligent, context-aware model switching. Configure models for different roles using the `x-advanced-models` setting.
:::

Copilot uses AI. Check for mistakes.
Comment on lines +157 to +163
let max_confidence = stream::iter(user_messages)
.map(|msg| async move { self.scan_with_classifier(&msg).await })
.buffer_unordered(ML_SCAN_CONCURRENCY)
.fold(0.0_f32, |acc, result| async move {
result.unwrap_or(0.0).max(acc)
})
.await;
Copy link

Copilot AI Jan 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The concurrent message scanning could create a performance bottleneck when processing 10 user messages with concurrent HTTP requests. The ML_SCAN_CONCURRENCY of 3 means up to 3 simultaneous HTTP requests, but if each request takes 5 seconds (the default timeout), this could add up to 17 seconds in the worst case for tool execution. Consider adding a circuit breaker or reducing the timeout for conversation scans specifically, or making the scan asynchronous to avoid blocking tool execution.

Copilot uses AI. Check for mistakes.
Copilot AI review requested due to automatic review settings January 5, 2026 02:17
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated 3 comments.

@aaif-goose aaif-goose deleted a comment from Copilot AI Jan 5, 2026
@aaif-goose aaif-goose deleted a comment from Copilot AI Jan 5, 2026
| `otel_exporter_otlp_timeout` | Export timeout in milliseconds for [observability](/docs/guides/environment-variables#opentelemetry-protocol-otlp) | Integer (ms) | 10000 | No |
| `SECURITY_PROMPT_ENABLED` | Enable [prompt injection detection](/docs/guides/security/prompt-injection-detection) to identify potentially harmful commands | true/false | false | No |
| `SECURITY_PROMPT_THRESHOLD` | Sensitivity threshold for [prompt injection detection](/docs/guides/security/prompt-injection-detection) (higher = stricter) | Float between 0.01 and 1.0 | 0.7 | No |
| `SECURITY_PROMPT_ENABLED` | Enable prompt injection detection to identify potentially harmful commands | true/false | false | No |
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dorien-koelemeijer Can you please revert the changes to this existing topic for now? Otherwise the docs will go out before this feature is available in a release.

We'll add them back in after the feature is released (except the note below because autopilot has been removed)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the comments! Do you mean it's best to revert all changes to this file for now?

@@ -0,0 +1,89 @@
# Classification API Specification
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Classification API Specification
---
title: Classification API Specification
unlisted: true
---

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated 3 comments.

unlisted: true
---

This document defines the API that Goose uses for ML-based prompt injection detection.
Copy link

Copilot AI Jan 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reference to "Goose" should be lowercase "goose" according to the project naming convention documented in HOWTOAI.md.

Copilot generated this review using guidance from repository custom instructions.
Comment on lines +71 to +72
**Goose's Usage:**
- Goose looks for the label with the highest score
Copy link

Copilot AI Jan 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reference to "Goose" should be lowercase "goose" according to the project naming convention documented in HOWTOAI.md.

Copilot generated this review using guidance from repository custom instructions.
**Goose's Usage:**
- Goose looks for the label with the highest score
- If the top label is "INJECTION" (or "LABEL_1"), the score is used as the injection confidence
- If the top label is "SAFE" (or "LABEL_0"), Goose uses `1.0 - score` as the injection confidence
Copy link

Copilot AI Jan 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reference to "Goose" should be lowercase "goose" according to the project naming convention documented in HOWTOAI.md.

Copilot generated this review using guidance from repository custom instructions.
Copy link
Copy Markdown
Contributor

@dianed-square dianed-square left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants