diff --git a/crates/goose/src/agents/prompt_manager.rs b/crates/goose/src/agents/prompt_manager.rs
index f7cdb01acb83..e2973d3b8989 100644
--- a/crates/goose/src/agents/prompt_manager.rs
+++ b/crates/goose/src/agents/prompt_manager.rs
@@ -6,7 +6,7 @@ use crate::agents::extension::ExtensionInfo;
 use crate::agents::router_tool_selector::RouterToolSelectionStrategy;
 use crate::agents::router_tools::{llm_search_tool_prompt, vector_search_tool_prompt};
 use crate::providers::base::get_current_model;
-use crate::{config::Config, prompt_template};
+use crate::{config::Config, prompt_template, utils::sanitize_unicode_tags};
 
 pub struct PromptManager {
     system_prompt_override: Option<String>,
@@ -83,7 +83,18 @@ impl PromptManager {
             ));
         }
 
-        context.insert("extensions", serde_json::to_value(extensions_info).unwrap());
+        let sanitized_extensions_info: Vec<ExtensionInfo> = extensions_info
+            .into_iter()
+            .map(|mut ext_info| {
+                ext_info.instructions = sanitize_unicode_tags(&ext_info.instructions);
+                ext_info
+            })
+            .collect();
+
+        context.insert(
+            "extensions",
+            serde_json::to_value(sanitized_extensions_info).unwrap(),
+        );
 
         match tool_selection_strategy {
             Some(RouterToolSelectionStrategy::Vector) => {
@@ -118,7 +129,8 @@ impl PromptManager {
 
         // Conditionally load the override prompt or the global system prompt
         let base_prompt = if let Some(override_prompt) = &self.system_prompt_override {
-            prompt_template::render_inline_once(override_prompt, &context)
+            let sanitized_override_prompt = sanitize_unicode_tags(override_prompt);
+            prompt_template::render_inline_once(&sanitized_override_prompt, &context)
                 .expect("Prompt should render")
         } else if let Some(model) = &model_to_use {
             // Use the fuzzy mapping to determine the prompt file, or fall back to legacy logic
@@ -149,13 +161,18 @@ impl PromptManager {
                 .push("Right now you are *NOT* in the chat only mode and have access to tool use and system.".to_string());
         }
 
-        if system_prompt_extras.is_empty() {
+        let sanitized_system_prompt_extras: Vec<String> = system_prompt_extras
+            .into_iter()
+            .map(|extra| sanitize_unicode_tags(&extra))
+            .collect();
+
+        if sanitized_system_prompt_extras.is_empty() {
             base_prompt
         } else {
             format!(
                 "{}\n\n# Additional Instructions:\n\n{}",
                 base_prompt,
-                system_prompt_extras.join("\n\n")
+                sanitized_system_prompt_extras.join("\n\n")
             )
         }
     }
@@ -221,4 +238,93 @@ mod tests {
             "system.md"
         );
     }
+
+    #[test]
+    fn test_build_system_prompt_sanitizes_override() {
+        let mut manager = PromptManager::new();
+        let malicious_override = "System prompt\u{E0041}\u{E0042}\u{E0043}with hidden text";
+        manager.set_system_prompt_override(malicious_override.to_string());
+
+        let result =
+            manager.build_system_prompt(vec![], None, Value::String("".to_string()), None, None);
+
+        assert!(!result.contains('\u{E0041}'));
+        assert!(!result.contains('\u{E0042}'));
+        assert!(!result.contains('\u{E0043}'));
+        assert!(result.contains("System prompt"));
+        assert!(result.contains("with hidden text"));
+    }
+
+    #[test]
+    fn test_build_system_prompt_sanitizes_extras() {
+        let mut manager = PromptManager::new();
+        let malicious_extra = "Extra instruction\u{E0041}\u{E0042}\u{E0043}hidden";
+        manager.add_system_prompt_extra(malicious_extra.to_string());
+
+        let result =
+            manager.build_system_prompt(vec![], None, Value::String("".to_string()), None, None);
+
+        assert!(!result.contains('\u{E0041}'));
+        assert!(!result.contains('\u{E0042}'));
+        assert!(!result.contains('\u{E0043}'));
+        assert!(result.contains("Extra instruction"));
+        assert!(result.contains("hidden"));
+    }
+
+    #[test]
+    fn test_build_system_prompt_sanitizes_multiple_extras() {
+        let mut manager = PromptManager::new();
+        manager.add_system_prompt_extra("First\u{E0041}instruction".to_string());
+        manager.add_system_prompt_extra("Second\u{E0042}instruction".to_string());
+        manager.add_system_prompt_extra("Third\u{E0043}instruction".to_string());
+
+        let result =
+            manager.build_system_prompt(vec![], None, Value::String("".to_string()), None, None);
+
+        assert!(!result.contains('\u{E0041}'));
+        assert!(!result.contains('\u{E0042}'));
+        assert!(!result.contains('\u{E0043}'));
+        assert!(result.contains("Firstinstruction"));
+        assert!(result.contains("Secondinstruction"));
+        assert!(result.contains("Thirdinstruction"));
+    }
+
+    #[test]
+    fn test_build_system_prompt_preserves_legitimate_unicode_in_extras() {
+        let mut manager = PromptManager::new();
+        let legitimate_unicode = "Instruction with 世界 and 🌍 emojis";
+        manager.add_system_prompt_extra(legitimate_unicode.to_string());
+
+        let result =
+            manager.build_system_prompt(vec![], None, Value::String("".to_string()), None, None);
+
+        assert!(result.contains("世界"));
+        assert!(result.contains("🌍"));
+        assert!(result.contains("Instruction with"));
+        assert!(result.contains("emojis"));
+    }
+
+    #[test]
+    fn test_build_system_prompt_sanitizes_extension_instructions() {
+        let manager = PromptManager::new();
+        let malicious_extension_info = ExtensionInfo::new(
+            "test_extension",
+            "Extension help\u{E0041}\u{E0042}\u{E0043}hidden instructions",
+            false,
+        );
+
+        let result = manager.build_system_prompt(
+            vec![malicious_extension_info],
+            None,
+            Value::String("".to_string()),
+            None,
+            None,
+        );
+
+        assert!(!result.contains('\u{E0041}'));
+        assert!(!result.contains('\u{E0042}'));
+        assert!(!result.contains('\u{E0043}'));
+        assert!(result.contains("Extension help"));
+        assert!(result.contains("hidden instructions"));
+    }
 }
diff --git a/crates/goose/src/conversation/message.rs b/crates/goose/src/conversation/message.rs
index 9bedd5370cf6..107b88b70e27 100644
--- a/crates/goose/src/conversation/message.rs
+++ b/crates/goose/src/conversation/message.rs
@@ -8,20 +8,10 @@ use serde::{Deserialize, Deserializer, Serialize};
 use serde_json::Value;
 use std::collections::HashSet;
 use std::fmt;
-use unicode_normalization::UnicodeNormalization;
 use utoipa::ToSchema;
 
 use crate::conversation::tool_result_serde;
-
-/// Sanitize Unicode Tags Block characters from text
-fn sanitize_unicode_tags(text: &str) -> String {
-    let normalized: String = text.nfc().collect();
-
-    normalized
-        .chars()
-        .filter(|&c| !matches!(c, '\u{E0000}'..='\u{E007F}'))
-        .collect()
-}
+use crate::utils::sanitize_unicode_tags;
 
 /// Custom deserializer for MessageContent that sanitizes Unicode Tags in text content
 fn deserialize_sanitized_content<'de, D>(deserializer: D) -> Result<Vec<MessageContent>, D::Error>
@@ -611,20 +601,6 @@ mod tests {
     use rmcp::model::{ErrorCode, ErrorData};
     use serde_json::{json, Value};
 
-    #[test]
-    fn test_sanitize_unicode_tags() {
-        let malicious = "Hello\u{E0041}\u{E0042}\u{E0043}world"; // Invisible "ABC"
-        let cleaned = super::sanitize_unicode_tags(malicious);
-        assert_eq!(cleaned, "Helloworld");
-    }
-
-    #[test]
-    fn test_no_sanitize_unicode_tags() {
-        let clean_text = "Hello world 世界 🌍";
-        let cleaned = super::sanitize_unicode_tags(clean_text);
-        assert_eq!(cleaned, clean_text);
-    }
-
     #[test]
     fn test_sanitize_with_text() {
         let malicious = "Hello\u{E0041}\u{E0042}\u{E0043}world"; // Invisible "ABC"
diff --git a/crates/goose/src/utils.rs b/crates/goose/src/utils.rs
index 1205953e83cc..32ea5c259df7 100644
--- a/crates/goose/src/utils.rs
+++ b/crates/goose/src/utils.rs
@@ -1,4 +1,19 @@
 use tokio_util::sync::CancellationToken;
+use unicode_normalization::UnicodeNormalization;
+
+/// Sanitize Unicode Tags Block characters from text
+/// Used to prevent Unicode-based prompt injection attacks
+///
+/// The text is NFC-normalized, then invisible Unicode Tags Block characters
+/// (U+E0000-U+E007F) are removed; all other characters are kept.
+pub fn sanitize_unicode_tags(text: &str) -> String {
+    let normalized: String = text.nfc().collect();
+
+    normalized
+        .chars()
+        .filter(|&c| !matches!(c, '\u{E0000}'..='\u{E007F}'))
+        .collect()
+}
 
 /// Safely truncate a string at character boundaries, not byte boundaries
 ///
@@ -30,6 +45,45 @@ pub fn is_token_cancelled(cancellation_token: &Option<CancellationToken>) -> boo
 mod tests {
     use super::*;
 
+    #[test]
+    fn test_sanitize_unicode_tags() {
+        // Test that Unicode Tags Block characters are removed
+        let malicious = "Hello\u{E0041}\u{E0042}\u{E0043}world"; // Invisible "ABC"
+        let cleaned = sanitize_unicode_tags(malicious);
+        assert_eq!(cleaned, "Helloworld");
+    }
+
+    #[test]
+    fn test_sanitize_unicode_tags_preserves_legitimate_unicode() {
+        // Test that legitimate Unicode characters are preserved
+        let clean_text = "Hello world 世界 🌍";
+        let cleaned = sanitize_unicode_tags(clean_text);
+        assert_eq!(cleaned, clean_text);
+    }
+
+    #[test]
+    fn test_sanitize_unicode_tags_empty_string() {
+        let empty = "";
+        let cleaned = sanitize_unicode_tags(empty);
+        assert_eq!(cleaned, "");
+    }
+
+    #[test]
+    fn test_sanitize_unicode_tags_only_malicious() {
+        // Test string containing only Unicode Tags characters
+        let only_malicious = "\u{E0041}\u{E0042}\u{E0043}";
+        let cleaned = sanitize_unicode_tags(only_malicious);
+        assert_eq!(cleaned, "");
+    }
+
+    #[test]
+    fn test_sanitize_unicode_tags_mixed_content() {
+        // Test mixed legitimate and malicious Unicode
+        let mixed = "Hello\u{E0041} 世界\u{E0042} 🌍\u{E0043}!";
+        let cleaned = sanitize_unicode_tags(mixed);
+        assert_eq!(cleaned, "Hello 世界 🌍!");
+    }
+
     #[test]
     fn test_safe_truncate_ascii() {
         assert_eq!(safe_truncate("hello world", 20), "hello world");