diff --git a/crates/goose/src/providers/api_client.rs b/crates/goose/src/providers/api_client.rs index 434451e8770a..3e6ced8cee92 100644 --- a/crates/goose/src/providers/api_client.rs +++ b/crates/goose/src/providers/api_client.rs @@ -2,10 +2,12 @@ use anyhow::Result; use async_trait::async_trait; use reqwest::{ header::{HeaderMap, HeaderName, HeaderValue}, - Client, Response, StatusCode, + Certificate, Client, Identity, Response, StatusCode, }; use serde_json::Value; use std::fmt; +use std::fs::read_to_string; +use std::path::PathBuf; use std::time::Duration; pub struct ApiClient { @@ -14,6 +16,7 @@ pub struct ApiClient { auth: AuthMethod, default_headers: HeaderMap, timeout: Duration, + tls_config: Option, } pub enum AuthMethod { @@ -27,6 +30,127 @@ pub enum AuthMethod { Custom(Box), } +#[derive(Debug, Clone)] +pub struct TlsCertKeyPair { + pub cert_path: PathBuf, + pub key_path: PathBuf, +} + +#[derive(Debug, Clone)] +pub struct TlsConfig { + pub client_identity: Option, + pub ca_cert_path: Option, +} + +impl TlsConfig { + pub fn new() -> Self { + Self { + client_identity: None, + ca_cert_path: None, + } + } + + pub fn from_config() -> Result> { + let config = crate::config::Config::global(); + let mut tls_config = TlsConfig::new(); + let mut has_tls_config = false; + + let client_cert_path = config.get_param::("GOOSE_CLIENT_CERT_PATH").ok(); + let client_key_path = config.get_param::("GOOSE_CLIENT_KEY_PATH").ok(); + + // Validate that both cert and key are provided if either is provided + match (client_cert_path, client_key_path) { + (Some(cert_path), Some(key_path)) => { + tls_config = tls_config.with_client_cert_and_key( + std::path::PathBuf::from(cert_path), + std::path::PathBuf::from(key_path), + ); + has_tls_config = true; + } + (Some(_), None) => { + return Err(anyhow::anyhow!( + "Client certificate provided (GOOSE_CLIENT_CERT_PATH) but no private key (GOOSE_CLIENT_KEY_PATH)" + )); + } + (None, Some(_)) => { + return Err(anyhow::anyhow!( + "Client private key provided (GOOSE_CLIENT_KEY_PATH) but no certificate (GOOSE_CLIENT_CERT_PATH)" + )); + } + (None, None) => {} + } + + if let Ok(ca_cert_path) = config.get_param::("GOOSE_CA_CERT_PATH") { + tls_config = tls_config.with_ca_cert(std::path::PathBuf::from(ca_cert_path)); + has_tls_config = true; + } + + if has_tls_config { + Ok(Some(tls_config)) + } else { + Ok(None) + } + } + + pub fn with_client_cert_and_key(mut self, cert_path: PathBuf, key_path: PathBuf) -> Self { + self.client_identity = Some(TlsCertKeyPair { + cert_path, + key_path, + }); + self + } + + pub fn with_ca_cert(mut self, path: PathBuf) -> Self { + self.ca_cert_path = Some(path); + self + } + + pub fn is_configured(&self) -> bool { + self.client_identity.is_some() || self.ca_cert_path.is_some() + } + + pub fn load_identity(&self) -> Result> { + if let Some(cert_key_pair) = &self.client_identity { + let cert_pem = read_to_string(&cert_key_pair.cert_path) + .map_err(|e| anyhow::anyhow!("Failed to read client certificate: {}", e))?; + let key_pem = read_to_string(&cert_key_pair.key_path) + .map_err(|e| anyhow::anyhow!("Failed to read client private key: {}", e))?; + + // Create a combined PEM file with certificate and private key + let combined_pem = format!("{}\n{}", cert_pem, key_pem); + + let identity = Identity::from_pem(combined_pem.as_bytes()).map_err(|e| { + anyhow::anyhow!("Failed to create identity from cert and key: {}", e) + })?; + + Ok(Some(identity)) + } else { + Ok(None) + } + } + + pub fn load_ca_certificates(&self) -> Result> { + match &self.ca_cert_path { + Some(ca_path) => { + let ca_pem = read_to_string(ca_path) + .map_err(|e| anyhow::anyhow!("Failed to read CA certificate: {}", e))?; + + let certs = Certificate::from_pem_bundle(ca_pem.as_bytes()) + .map_err(|e| anyhow::anyhow!("Failed to parse CA certificate bundle: {}", e))?; + + Ok(certs) + } + None => Ok(Vec::new()), + } + } +} + +impl Default for TlsConfig { + fn default() -> Self { + Self::new() + } +} + pub struct OAuthConfig { pub host: String, pub client_id: String, @@ -79,21 +203,63 @@ impl ApiClient { } pub fn with_timeout(host: String, auth: AuthMethod, timeout: Duration) -> Result { + let mut client_builder = Client::builder().timeout(timeout); + + // Configure TLS if needed + let tls_config = TlsConfig::from_config()?; + if let Some(ref config) = tls_config { + client_builder = Self::configure_tls(client_builder, config)?; + } + + let client = client_builder.build()?; + Ok(Self { - client: Client::builder().timeout(timeout).build()?, + client, host, auth, default_headers: HeaderMap::new(), timeout, + tls_config, }) } + fn rebuild_client(&mut self) -> Result<()> { + let mut client_builder = Client::builder() + .timeout(self.timeout) + .default_headers(self.default_headers.clone()); + + // Configure TLS if needed + if let Some(ref tls_config) = self.tls_config { + client_builder = Self::configure_tls(client_builder, tls_config)?; + } + + self.client = client_builder.build()?; + Ok(()) + } + + /// Configure TLS settings on a reqwest ClientBuilder + fn configure_tls( + mut client_builder: reqwest::ClientBuilder, + tls_config: &TlsConfig, + ) -> Result { + if tls_config.is_configured() { + // Load client identity (certificate + private key) + if let Some(identity) = tls_config.load_identity()? { + client_builder = client_builder.identity(identity); + } + + // Load CA certificates + let ca_certs = tls_config.load_ca_certificates()?; + for ca_cert in ca_certs { + client_builder = client_builder.add_root_certificate(ca_cert); + } + } + Ok(client_builder) + } + pub fn with_headers(mut self, headers: HeaderMap) -> Result { self.default_headers = headers; - self.client = Client::builder() - .timeout(self.timeout) - .default_headers(self.default_headers.clone()) - .build()?; + self.rebuild_client()?; Ok(self) } @@ -101,10 +267,7 @@ impl ApiClient { let header_name = HeaderName::from_bytes(key.as_bytes())?; let header_value = HeaderValue::from_str(value)?; self.default_headers.insert(header_name, header_value); - self.client = Client::builder() - .timeout(self.timeout) - .default_headers(self.default_headers.clone()) - .build()?; + self.rebuild_client()?; Ok(self) }