diff --git a/crates/goose/src/agents/extension_malware_check.rs b/crates/goose/src/agents/extension_malware_check.rs new file mode 100644 index 000000000000..29b8d566c516 --- /dev/null +++ b/crates/goose/src/agents/extension_malware_check.rs @@ -0,0 +1,511 @@ +use reqwest::header::{HeaderMap, HeaderValue, USER_AGENT}; +use reqwest::Url; +use serde::{Deserialize, Serialize}; +use tracing::{debug, error, trace}; + +use crate::agents::extension::ExtensionError; + +#[derive(Clone)] +pub struct OsvChecker { + client: reqwest::Client, + endpoint: Url, +} + +impl OsvChecker { + /// Constructs a checker. Honors OSV_ENDPOINT env var if present. + pub fn new() -> Result> { + let client = http_client().map_err(Box::new)?; + let endpoint = std::env::var("OSV_ENDPOINT") + .ok() + .and_then(|s| Url::parse(&s).ok()) + .unwrap_or_else(|| Url::parse(DEFAULT_OSV_ENDPOINT).expect("valid default OSV url")); + Ok(Self { client, endpoint }) + } + + /// Constructs with a custom endpoint (handy for tests). + pub fn with_endpoint(endpoint: Url) -> Result> { + let client = http_client().map_err(Box::new)?; + Ok(Self { client, endpoint }) + } + + /// Query OSV and **fail** if any MAL-* advisories are found. + /// - `ecosystem`: e.g., "npm", "PyPI" + /// - `version`: if `None`, checks by name only. + pub async fn deny_if_malicious( + &self, + name: &str, + ecosystem: &str, + version: Option<&str>, + ) -> Result<(), ExtensionError> { + deny_if_malicious_impl(&self.client, &self.endpoint, name, ecosystem, version).await + } +} + +/// Convenience: infer ecosystem from command token + parse first package arg. +/// - ends_with("npx") → npm +/// - ends_with("uvx") → PyPI +/// unknown commands → skip (fail open) +pub async fn deny_if_malicious_cmd_args(cmd: &str, args: &[String]) -> Result<(), ExtensionError> { + let ecosystem = if cmd.ends_with("uvx") { + "PyPI" + } else if cmd.ends_with("npx") { + "npm" + } else { + debug!(%cmd, ?args, "Unknown ecosystem for command; skipping OSV check (fail open)."); + return Ok(()); + }; + + if let Some((name, version)) = parse_first_package_arg(ecosystem, args) { + OsvChecker::new() + .map_err(|e| *e)? + .deny_if_malicious(&name, ecosystem, version.as_deref()) + .await?; + } else { + debug!(%cmd, ?args, "No package token found; skipping OSV check."); + } + + Ok(()) +} + +/// Direct call without command inference. +pub async fn deny_if_malicious( + name: &str, + ecosystem: &str, + version: Option<&str>, +) -> Result<(), ExtensionError> { + OsvChecker::new() + .map_err(|e| *e)? + .deny_if_malicious(name, ecosystem, version) + .await +} + +fn parse_first_package_arg(ecosystem: &str, args: &[String]) -> Option<(String, Option)> { + let is_flag = |s: &str| s.starts_with('-'); + let token = args + .iter() + .find(|a| !is_flag(a.as_str()))? + .trim() + .to_string(); + if token.is_empty() { + return None; + } + match ecosystem { + "npm" => parse_npm_token(&token), + "PyPI" => parse_pypi_token(&token), + _ => None, + } +} + +fn parse_npm_token(token: &str) -> Option<(String, Option)> { + // Handles: + // react@18.3.1 + // @scope/pkg@1.2.3 (split at the LAST '@') + // eslint (no version) + if token.starts_with('@') { + if let Some(idx) = token.rfind('@') { + if idx > 0 { + let (name, ver) = token.split_at(idx); + let ver = ver.trim_start_matches('@'); + if !ver.is_empty() && ver != "latest" { + return Some((name.to_string(), Some(ver.to_string()))); + } else { + return Some((name.to_string(), None)); + } + } + } + Some((token.to_string(), None)) + } else if let Some(idx) = token.find('@') { + let (name, ver) = token.split_at(idx); + let ver = ver.trim_start_matches('@'); + if !name.is_empty() { + if !ver.is_empty() && ver != "latest" { + return Some((name.to_string(), Some(ver.to_string()))); + } else { + return Some((name.to_string(), None)); + } + } + None + } else { + Some((token.to_string(), None)) + } +} + +fn parse_pypi_token(token: &str) -> Option<(String, Option)> { + // Accept exact pins: + // package==1.2.3 + // package[extra]==1.2.3 + // Treat "latest" as None. Ignore other specifiers (>=, <=, ~=, !=) for pinning. + let lowered = token.to_ascii_lowercase(); + if let Some(idx) = lowered.find("==") { + let (name, ver) = token.split_at(idx); + let ver = ver.trim_start_matches('=').trim_start_matches('='); + let name = name.trim(); + if name.is_empty() { + return None; + } + if ver.is_empty() || ver.eq_ignore_ascii_case("latest") { + return Some((name.to_string(), None)); + } + return Some((name.to_string(), Some(ver.to_string()))); + } + Some((token.to_string(), None)) +} + +const DEFAULT_OSV_ENDPOINT: &str = "https://api.osv.dev/v1/query"; + +#[derive(Serialize)] +struct QueryReq<'a> { + #[serde(skip_serializing_if = "Option::is_none")] + version: Option<&'a str>, + package: Package<'a>, + #[serde(skip_serializing_if = "Option::is_none")] + page_token: Option, +} + +#[derive(Serialize)] +struct Package<'a> { + name: &'a str, + ecosystem: &'a str, + #[serde(skip_serializing_if = "Option::is_none")] + purl: Option<&'a str>, +} + +#[derive(Deserialize)] +struct QueryResp { + #[serde(default)] + vulns: Vec, + #[serde(default)] + next_page_token: Option, +} + +#[derive(Deserialize)] +struct Vuln { + id: String, + #[serde(default)] + summary: String, +} + +async fn deny_if_malicious_impl( + client: &reqwest::Client, + endpoint: &Url, + name: &str, + ecosystem: &str, + version: Option<&str>, +) -> Result<(), ExtensionError> { + debug!(name, ecosystem, ?version, "OSV query starting"); + let mut page_token: Option = None; + let mut mal: Vec = Vec::new(); + + loop { + let body = QueryReq { + version, + package: Package { + name, + ecosystem, + purl: None, + }, + page_token: page_token.clone(), + }; + trace!(?body.page_token, "OSV page"); + + let resp = match client.post(endpoint.clone()).json(&body).send().await { + Ok(r) => r, + Err(e) => { + error!(%e, name, ecosystem, ?version, "OSV request failed; failing open."); + return Ok(()); + } + }; + + let resp = match resp.error_for_status() { + Ok(r) => r, + Err(e) => { + error!(%e, name, ecosystem, ?version, "OSV HTTP error; failing open."); + return Ok(()); + } + }; + + let payload: QueryResp = match resp.json().await { + Ok(p) => p, + Err(e) => { + error!(%e, name, ecosystem, ?version, "OSV JSON parse error; failing open."); + return Ok(()); + } + }; + + mal.extend( + payload + .vulns + .into_iter() + .filter(|v| v.id.starts_with("MAL-")), + ); + + match payload.next_page_token { + Some(tok) if !tok.is_empty() => page_token = Some(tok), + _ => break, + } + } + + if !mal.is_empty() { + let ver = version.unwrap_or(""); + let details = mal + .into_iter() + .map(|v| { + if v.summary.is_empty() { + v.id + } else { + format!("{} — {}", v.id, v.summary) + } + }) + .collect::>() + .join("; "); + error!(name, ecosystem, version=%ver, %details, "Blocked malicious package via OSV MAL-*."); + return Err(ExtensionError::ConfigError(format!( + "Blocked malicious package: {name}@{ver} ({ecosystem}). OSV MAL advisories: {details}" + ))); + } + + debug!(name, ecosystem, ?version, "OSV: no MAL advisories."); + Ok(()) +} + +#[allow(clippy::result_large_err)] +fn http_client() -> Result { + let mut headers = HeaderMap::new(); + headers.insert( + USER_AGENT, + HeaderValue::from_static("goose-osv-check/1.1 (+https://osv.dev)"), + ); + reqwest::Client::builder() + .default_headers(headers) + .timeout(std::time::Duration::from_secs(10)) + .build() + .map_err(|e| ExtensionError::SetupError(format!("failed to build HTTP client: {e}"))) +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + use serial_test; + use tokio; + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + fn checker_for(server: &MockServer) -> OsvChecker { + let url = Url::parse(&format!("{}/v1/query", server.uri())).unwrap(); + OsvChecker::with_endpoint(url).unwrap() + } + + // Helper to temporarily set an environment variable and restore it on drop + struct TempEnvVar { + key: String, + original: Option, + } + + impl TempEnvVar { + fn set(key: &str, value: &str) -> Self { + let original = std::env::var(key).ok(); + std::env::set_var(key, value); + Self { + key: key.to_string(), + original, + } + } + } + + impl Drop for TempEnvVar { + fn drop(&mut self) { + match &self.original { + Some(val) => std::env::set_var(&self.key, val), + None => std::env::remove_var(&self.key), + } + } + } + + #[tokio::test] + async fn allows_clean_package() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/query")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "vulns": [], + "next_page_token": null + }))) + .mount(&server) + .await; + + let c = checker_for(&server); + let res = c + .deny_if_malicious("some_clean_package", "PyPI", None) + .await; + assert!(res.is_ok()); + } + + #[tokio::test] + async fn blocks_malicious_package() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/query")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "vulns": [ { "id": "MAL-1234", "summary": "Malicious package" } ], + "next_page_token": null + }))) + .mount(&server) + .await; + + let c = checker_for(&server); + let res = c + .deny_if_malicious("bad_package", "PyPI", Some("1.0.0")) + .await; + assert!(res.is_err()); + let msg = format!("{:?}", res.unwrap_err()); + assert!(msg.contains("Blocked malicious package")); + assert!(msg.contains("MAL-1234")); + } + + #[tokio::test] + #[serial_test::serial] + async fn cmd_args_pypi_clean() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/query")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "vulns": [], + "next_page_token": null + }))) + .mount(&server) + .await; + + // Use env var so OsvChecker::new() picks it up + let _env = TempEnvVar::set("OSV_ENDPOINT", &format!("{}/v1/query", server.uri())); + let args = vec!["some_clean_package==1.2.3".to_string()]; + let res = deny_if_malicious_cmd_args("uvx", &args).await; + assert!(res.is_ok()); + } + + #[tokio::test] + #[serial_test::serial] + async fn cmd_args_npm_scoped_malicious() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/query")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "vulns": [ { "id": "MAL-9999", "summary": "Malicious npm package" } ], + "next_page_token": null + }))) + .mount(&server) + .await; + + let _env = TempEnvVar::set("OSV_ENDPOINT", &format!("{}/v1/query", server.uri())); + let args = vec!["@scope/pkg@2.0.0".to_string()]; + let res = deny_if_malicious_cmd_args("npx", &args).await; + assert!(res.is_err()); + let msg = format!("{:?}", res.unwrap_err()); + assert!(msg.contains("Blocked malicious package")); + assert!(msg.contains("MAL-9999")); + } + + #[tokio::test] + #[serial_test::serial] + async fn cmd_args_skip_flags_then_parse() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/query")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "vulns": [], + "next_page_token": null + }))) + .mount(&server) + .await; + + let _env = TempEnvVar::set("OSV_ENDPOINT", &format!("{}/v1/query", server.uri())); + let args = vec![ + "--dry-run".into(), + "-y".into(), + "some_clean_package@1.2.3".into(), + ]; + let res = deny_if_malicious_cmd_args("npx", &args).await; + assert!(res.is_ok()); + } + + #[tokio::test] + async fn pagination_works() { + let server = MockServer::start().await; + // 1st page: no vulns, but has next + Mock::given(method("POST")) + .and(path("/v1/query")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "vulns": [], + "next_page_token": "page-2" + }))) + .up_to_n_times(1) + .mount(&server) + .await; + + // 2nd page: MAL hit + Mock::given(method("POST")) + .and(path("/v1/query")) + .respond_with(ResponseTemplate::new(200).set_body_json(json!({ + "vulns": [ { "id": "MAL-4242", "summary": "Second page hit" } ], + "next_page_token": null + }))) + .mount(&server) + .await; + + let c = checker_for(&server); + let res = c.deny_if_malicious("pkg", "npm", None).await; + assert!(res.is_err()); + let msg = format!("{:?}", res.unwrap_err()); + assert!(msg.contains("MAL-4242")); + } + + #[tokio::test] + async fn fail_open_on_http_error() { + let server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/v1/query")) + .respond_with(ResponseTemplate::new(500)) + .mount(&server) + .await; + + let c = checker_for(&server); + let res = c.deny_if_malicious("pkg", "npm", None).await; + assert!(res.is_ok(), "should fail-open on HTTP errors"); + } + + #[tokio::test] + async fn unknown_command_is_skipped() { + let args = vec!["whatever@1.0.0".into()]; + // no mock server: we shouldn't call OSV at all + let res = deny_if_malicious_cmd_args("some-other-bin", &args).await; + assert!(res.is_ok()); + } + + #[test] + fn parse_npm_scoped_with_version() { + assert_eq!( + super::parse_npm_token("@scope/pkg@1.2.3"), + Some(("@scope/pkg".into(), Some("1.2.3".into()))) + ); + } + + #[test] + fn parse_npm_unscoped_latest_is_none() { + assert_eq!( + super::parse_npm_token("react@latest"), + Some(("react".into(), None)) + ); + } + + #[test] + fn parse_pypi_exact_pin_and_latest() { + assert_eq!( + super::parse_pypi_token("requests==2.32.3"), + Some(("requests".into(), Some("2.32.3".into()))) + ); + assert_eq!( + super::parse_pypi_token("requests==latest"), + Some(("requests".into(), None)) + ); + } +} diff --git a/crates/goose/src/agents/extension_manager.rs b/crates/goose/src/agents/extension_manager.rs index 657c9f62b534..b6576d8eb7fa 100644 --- a/crates/goose/src/agents/extension_manager.rs +++ b/crates/goose/src/agents/extension_manager.rs @@ -26,6 +26,7 @@ use tracing::{error, warn}; use super::extension::{ExtensionConfig, ExtensionError, ExtensionInfo, ExtensionResult, ToolInfo}; use super::tool_execution::ToolCallResult; use crate::agents::extension::{Envs, ProcessExit}; +use crate::agents::extension_malware_check; use crate::config::{Config, ExtensionConfigManager}; use crate::oauth::oauth_flow; use crate::prompt_template; @@ -363,6 +364,10 @@ impl ExtensionManager { let command = Command::new(cmd).configure(|command| { command.args(args).envs(all_envs); }); + + // Check for malicious packages before launching the process + extension_malware_check::deny_if_malicious_cmd_args(cmd, args).await?; + let client = child_process_client(command, timeout).await?; Box::new(client) } diff --git a/crates/goose/src/agents/mod.rs b/crates/goose/src/agents/mod.rs index c18cfd916075..ae668b714ea1 100644 --- a/crates/goose/src/agents/mod.rs +++ b/crates/goose/src/agents/mod.rs @@ -1,6 +1,7 @@ mod agent; mod context; pub mod extension; +pub mod extension_malware_check; pub mod extension_manager; pub mod final_output_tool; mod large_response_handler;