diff --git a/crates/goose/src/agents/extension.rs b/crates/goose/src/agents/extension.rs index 9f1daa6cd279..de63751c1e56 100644 --- a/crates/goose/src/agents/extension.rs +++ b/crates/goose/src/agents/extension.rs @@ -11,6 +11,7 @@ use once_cell::sync::Lazy; use rmcp::model::Tool; use rmcp::service::ClientInitializeError; use rmcp::ServiceError as ClientError; +use serde::Deserializer; use serde::{Deserialize, Serialize}; use thiserror::Error; use tracing::warn; @@ -216,6 +217,7 @@ pub enum ExtensionConfig { /// The name used to identify this extension name: String, #[serde(default)] + #[serde(deserialize_with = "deserialize_null_with_default")] #[schema(required)] description: String, uri: String, @@ -237,6 +239,7 @@ pub enum ExtensionConfig { /// The name used to identify this extension name: String, #[serde(default)] + #[serde(deserialize_with = "deserialize_null_with_default")] #[schema(required)] description: String, cmd: String, @@ -257,6 +260,7 @@ pub enum ExtensionConfig { /// The name used to identify this extension name: String, #[serde(default)] + #[serde(deserialize_with = "deserialize_null_with_default")] #[schema(required)] description: String, display_name: Option, // needed for the UI @@ -271,7 +275,7 @@ pub enum ExtensionConfig { Platform { /// The name used to identify this extension name: String, - #[serde(default)] + #[serde(deserialize_with = "deserialize_null_with_default")] #[schema(required)] description: String, #[serde(default)] @@ -284,7 +288,7 @@ pub enum ExtensionConfig { StreamableHttp { /// The name used to identify this extension name: String, - #[serde(default)] + #[serde(deserialize_with = "deserialize_null_with_default")] #[schema(required)] description: String, uri: String, @@ -307,7 +311,7 @@ pub enum ExtensionConfig { Frontend { /// The name used to identify this extension name: String, - #[serde(default)] + #[serde(deserialize_with = "deserialize_null_with_default")] #[schema(required)] description: String, /// The tools provided by the frontend @@ -324,7 +328,7 @@ pub enum ExtensionConfig { InlinePython { /// The name used to identify this extension name: String, - #[serde(default)] + #[serde(deserialize_with = "deserialize_null_with_default")] #[schema(required)] description: String, /// The Python code to execute @@ -544,6 +548,15 @@ impl ExtensionInfo { } } +fn deserialize_null_with_default<'de, D, T>(deserializer: D) -> Result +where + T: Default + Deserialize<'de>, + D: Deserializer<'de>, +{ + let opt = Option::deserialize(deserializer)?; + Ok(opt.unwrap_or_default()) +} + /// Information about the tool used for building prompts #[derive(Clone, Debug, Serialize, ToSchema)] pub struct ToolInfo { @@ -568,3 +581,69 @@ impl ToolInfo { } } } + +#[cfg(test)] +mod tests { + use crate::agents::*; + + #[test] + fn test_deserialize_missing_description() { + let config: ExtensionConfig = serde_yaml::from_str( + "enabled: true +type: builtin +name: developer +display_name: Developer +timeout: 300 +bundled: true +available_tools: []", + ) + .unwrap(); + if let ExtensionConfig::Builtin { description, .. } = config { + assert_eq!(description, "") + } else { + panic!("unexpected result of deserialization: {}", config) + } + } + + #[test] + fn test_deserialize_null_description() { + let config: ExtensionConfig = serde_yaml::from_str( + "enabled: true +type: builtin +name: developer +display_name: Developer +description: null +timeout: 300 +bundled: true +available_tools: [] +", + ) + .unwrap(); + if let ExtensionConfig::Builtin { description, .. } = config { + assert_eq!(description, "") + } else { + panic!("unexpected result of deserialization: {}", config) + } + } + + #[test] + fn test_deserialize_normal_description() { + let config: ExtensionConfig = serde_yaml::from_str( + "enabled: true +type: builtin +name: developer +display_name: Developer +description: description goes here +timeout: 300 +bundled: true +available_tools: [] + ", + ) + .unwrap(); + if let ExtensionConfig::Builtin { description, .. } = config { + assert_eq!(description, "description goes here") + } else { + panic!("unexpected result of deserialization: {}", config) + } + } +}