Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 111 additions & 5 deletions crates/goose/src/agents/prompt_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::agents::extension::ExtensionInfo;
use crate::agents::router_tool_selector::RouterToolSelectionStrategy;
use crate::agents::router_tools::{llm_search_tool_prompt, vector_search_tool_prompt};
use crate::providers::base::get_current_model;
use crate::{config::Config, prompt_template};
use crate::{config::Config, prompt_template, utils::sanitize_unicode_tags};

pub struct PromptManager {
system_prompt_override: Option<String>,
Expand Down Expand Up @@ -83,7 +83,18 @@ impl PromptManager {
));
}

context.insert("extensions", serde_json::to_value(extensions_info).unwrap());
let sanitized_extensions_info: Vec<ExtensionInfo> = extensions_info
.into_iter()
.map(|mut ext_info| {
ext_info.instructions = sanitize_unicode_tags(&ext_info.instructions);
ext_info
})
.collect();

context.insert(
"extensions",
serde_json::to_value(sanitized_extensions_info).unwrap(),
);

match tool_selection_strategy {
Some(RouterToolSelectionStrategy::Vector) => {
Expand Down Expand Up @@ -118,7 +129,8 @@ impl PromptManager {

// Conditionally load the override prompt or the global system prompt
let base_prompt = if let Some(override_prompt) = &self.system_prompt_override {
prompt_template::render_inline_once(override_prompt, &context)
let sanitized_override_prompt = sanitize_unicode_tags(override_prompt);
prompt_template::render_inline_once(&sanitized_override_prompt, &context)
.expect("Prompt should render")
} else if let Some(model) = &model_to_use {
// Use the fuzzy mapping to determine the prompt file, or fall back to legacy logic
Expand Down Expand Up @@ -149,13 +161,18 @@ impl PromptManager {
.push("Right now you are *NOT* in the chat only mode and have access to tool use and system.".to_string());
}

if system_prompt_extras.is_empty() {
let sanitized_system_prompt_extras: Vec<String> = system_prompt_extras
.into_iter()
.map(|extra| sanitize_unicode_tags(&extra))
.collect();

if sanitized_system_prompt_extras.is_empty() {
base_prompt
} else {
format!(
"{}\n\n# Additional Instructions:\n\n{}",
base_prompt,
system_prompt_extras.join("\n\n")
sanitized_system_prompt_extras.join("\n\n")
)
}
}
Expand Down Expand Up @@ -221,4 +238,93 @@ mod tests {
"system.md"
);
}

/// A system prompt override containing Unicode Tags Block characters must have
/// them stripped from the rendered prompt while the visible text is kept.
#[test]
fn test_build_system_prompt_sanitizes_override() {
    let mut manager = PromptManager::new();
    manager.set_system_prompt_override(
        "System prompt\u{E0041}\u{E0042}\u{E0043}with hidden text".to_string(),
    );

    let prompt =
        manager.build_system_prompt(vec![], None, Value::String("".to_string()), None, None);

    // None of the injected tag characters may survive.
    for &tag in &['\u{E0041}', '\u{E0042}', '\u{E0043}'] {
        assert!(!prompt.contains(tag));
    }
    // The visible text on both sides of the tags is still present.
    for &visible in &["System prompt", "with hidden text"] {
        assert!(prompt.contains(visible));
    }
}

/// A single "extra" instruction containing Unicode Tags Block characters must
/// appear in the final prompt with those characters removed.
#[test]
fn test_build_system_prompt_sanitizes_extras() {
    let mut manager = PromptManager::new();
    manager.add_system_prompt_extra(
        "Extra instruction\u{E0041}\u{E0042}\u{E0043}hidden".to_string(),
    );

    let prompt =
        manager.build_system_prompt(vec![], None, Value::String("".to_string()), None, None);

    // None of the injected tag characters may survive.
    for &tag in &['\u{E0041}', '\u{E0042}', '\u{E0043}'] {
        assert!(!prompt.contains(tag));
    }
    // The visible parts of the extra are still present.
    for &visible in &["Extra instruction", "hidden"] {
        assert!(prompt.contains(visible));
    }
}

/// Every extra instruction is sanitized, not just the first one: each of the
/// three extras carries a different tag character between two visible words.
#[test]
fn test_build_system_prompt_sanitizes_multiple_extras() {
    let mut manager = PromptManager::new();
    for &extra in &[
        "First\u{E0041}instruction",
        "Second\u{E0042}instruction",
        "Third\u{E0043}instruction",
    ] {
        manager.add_system_prompt_extra(extra.to_string());
    }

    let prompt =
        manager.build_system_prompt(vec![], None, Value::String("".to_string()), None, None);

    for &tag in &['\u{E0041}', '\u{E0042}', '\u{E0043}'] {
        assert!(!prompt.contains(tag));
    }
    // With the tag removed, the two words of each extra become adjacent.
    for &joined in &["Firstinstruction", "Secondinstruction", "Thirdinstruction"] {
        assert!(prompt.contains(joined));
    }
}

/// Sanitizing extras must not remove ordinary non-ASCII text such as CJK
/// characters or emoji.
#[test]
fn test_build_system_prompt_preserves_legitimate_unicode_in_extras() {
    let mut manager = PromptManager::new();
    manager.add_system_prompt_extra("Instruction with 世界 and 🌍 emojis".to_string());

    let prompt =
        manager.build_system_prompt(vec![], None, Value::String("".to_string()), None, None);

    for &expected in &["世界", "🌍", "Instruction with", "emojis"] {
        assert!(prompt.contains(expected));
    }
}

/// Extension instructions passed to `build_system_prompt` must have Unicode
/// Tags Block characters stripped before they reach the rendered prompt.
#[test]
fn test_build_system_prompt_sanitizes_extension_instructions() {
    let manager = PromptManager::new();
    let extensions = vec![ExtensionInfo::new(
        "test_extension",
        "Extension help\u{E0041}\u{E0042}\u{E0043}hidden instructions",
        false,
    )];

    let prompt = manager.build_system_prompt(
        extensions,
        None,
        Value::String("".to_string()),
        None,
        None,
    );

    // None of the injected tag characters may survive.
    for &tag in &['\u{E0041}', '\u{E0042}', '\u{E0043}'] {
        assert!(!prompt.contains(tag));
    }
    // The visible instruction text is still rendered.
    for &visible in &["Extension help", "hidden instructions"] {
        assert!(prompt.contains(visible));
    }
}
}
26 changes: 1 addition & 25 deletions crates/goose/src/conversation/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,10 @@ use serde::{Deserialize, Deserializer, Serialize};
use serde_json::Value;
use std::collections::HashSet;
use std::fmt;
use unicode_normalization::UnicodeNormalization;
use utoipa::ToSchema;

use crate::conversation::tool_result_serde;

/// Return `text` NFC-normalized with every Unicode Tags Block character
/// (U+E0000..=U+E007F) removed.
fn sanitize_unicode_tags(text: &str) -> String {
    let mut cleaned = String::new();
    for ch in text.nfc() {
        // Keep everything except characters from the Tags block.
        if !('\u{E0000}'..='\u{E007F}').contains(&ch) {
            cleaned.push(ch);
        }
    }
    cleaned
}
use crate::utils::sanitize_unicode_tags;

/// Custom deserializer for MessageContent that sanitizes Unicode Tags in text content
fn deserialize_sanitized_content<'de, D>(deserializer: D) -> Result<Vec<MessageContent>, D::Error>
Expand Down Expand Up @@ -611,20 +601,6 @@ mod tests {
use rmcp::model::{ErrorCode, ErrorData};
use serde_json::{json, Value};

/// Tag characters embedded between two words are removed, joining the words.
#[test]
fn test_sanitize_unicode_tags() {
    // U+E0041..U+E0043 are the invisible tag counterparts of "ABC".
    let input = "Hello\u{E0041}\u{E0042}\u{E0043}world";
    assert_eq!(super::sanitize_unicode_tags(input), "Helloworld");
}

/// Text without tag characters (including CJK and emoji) comes back unchanged.
#[test]
fn test_no_sanitize_unicode_tags() {
    let input = "Hello world 世界 🌍";
    assert_eq!(super::sanitize_unicode_tags(input), input);
}

#[test]
fn test_sanitize_with_text() {
let malicious = "Hello\u{E0041}\u{E0042}\u{E0043}world"; // Invisible "ABC"
Expand Down
54 changes: 54 additions & 0 deletions crates/goose/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,19 @@
use tokio_util::sync::CancellationToken;
use unicode_normalization::UnicodeNormalization;

/// Sanitize Unicode Tags Block characters from text.
///
/// Used to prevent Unicode-based prompt injection attacks: characters in the
/// Unicode Tags Block (U+E0000-U+E007F) are invisible and can be used to hide
/// instructions inside otherwise innocuous text, so they are removed.
///
/// The input is also NFC-normalized. All characters outside the Tags Block are
/// kept, but the returned string may therefore differ from the input in
/// canonically-equivalent ways (for example `e` followed by U+0301 is returned
/// as the single precomposed character `é`), even when no tag characters are
/// present.
///
/// Returns a newly allocated `String`; an empty input yields an empty string.
pub fn sanitize_unicode_tags(text: &str) -> String {
    // Filter directly on the normalization iterator instead of first
    // collecting the normalized text into a temporary `String`.
    text.nfc()
        .filter(|&c| !matches!(c, '\u{E0000}'..='\u{E007F}'))
        .collect()
}

/// Safely truncate a string at character boundaries, not byte boundaries
///
Expand Down Expand Up @@ -30,6 +45,45 @@ pub fn is_token_cancelled(cancellation_token: &Option<CancellationToken>) -> boo
mod tests {
use super::*;

/// Unicode Tags Block characters embedded in text are removed.
#[test]
fn test_sanitize_unicode_tags() {
    // U+E0041..U+E0043 are the invisible tag counterparts of "ABC".
    let input = "Hello\u{E0041}\u{E0042}\u{E0043}world";
    assert_eq!(sanitize_unicode_tags(input), "Helloworld");
}

/// Ordinary non-ASCII text (CJK, emoji) passes through unchanged.
#[test]
fn test_sanitize_unicode_tags_preserves_legitimate_unicode() {
    let input = "Hello world 世界 🌍";
    assert_eq!(sanitize_unicode_tags(input), input);
}

/// An empty input produces an empty output.
#[test]
fn test_sanitize_unicode_tags_empty_string() {
    assert_eq!(sanitize_unicode_tags(""), "");
}

/// A string made up solely of Unicode Tags characters sanitizes to nothing.
#[test]
fn test_sanitize_unicode_tags_only_malicious() {
    let input = "\u{E0041}\u{E0042}\u{E0043}";
    assert_eq!(sanitize_unicode_tags(input), "");
}

/// Tag characters interleaved with legitimate Unicode are removed while the
/// surrounding text, including spaces and punctuation, is left as is.
#[test]
fn test_sanitize_unicode_tags_mixed_content() {
    let input = "Hello\u{E0041} 世界\u{E0042} 🌍\u{E0043}!";
    assert_eq!(sanitize_unicode_tags(input), "Hello 世界 🌍!");
}

#[test]
fn test_safe_truncate_ascii() {
assert_eq!(safe_truncate("hello world", 20), "hello world");
Expand Down
Loading