Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 173 additions & 10 deletions crates/goose/src/providers/api_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ use anyhow::Result;
use async_trait::async_trait;
use reqwest::{
header::{HeaderMap, HeaderName, HeaderValue},
Client, Response, StatusCode,
Certificate, Client, Identity, Response, StatusCode,
};
use serde_json::Value;
use std::fmt;
use std::fs::read_to_string;
use std::path::PathBuf;
use std::time::Duration;

pub struct ApiClient {
Expand All @@ -14,6 +16,7 @@ pub struct ApiClient {
auth: AuthMethod,
default_headers: HeaderMap,
timeout: Duration,
tls_config: Option<TlsConfig>,
}

pub enum AuthMethod {
Expand All @@ -27,6 +30,127 @@ pub enum AuthMethod {
Custom(Box<dyn AuthProvider>),
}

#[derive(Debug, Clone)]
pub struct TlsCertKeyPair {
pub cert_path: PathBuf,
pub key_path: PathBuf,
}

#[derive(Debug, Clone)]
pub struct TlsConfig {
pub client_identity: Option<TlsCertKeyPair>,
pub ca_cert_path: Option<PathBuf>,
}

impl TlsConfig {
pub fn new() -> Self {
Self {
client_identity: None,
ca_cert_path: None,
}
}

pub fn from_config() -> Result<Option<Self>> {
let config = crate::config::Config::global();
let mut tls_config = TlsConfig::new();
let mut has_tls_config = false;

let client_cert_path = config.get_param::<String>("GOOSE_CLIENT_CERT_PATH").ok();
let client_key_path = config.get_param::<String>("GOOSE_CLIENT_KEY_PATH").ok();

// Validate that both cert and key are provided if either is provided
match (client_cert_path, client_key_path) {
(Some(cert_path), Some(key_path)) => {
tls_config = tls_config.with_client_cert_and_key(
std::path::PathBuf::from(cert_path),
std::path::PathBuf::from(key_path),
);
has_tls_config = true;
}
(Some(_), None) => {
return Err(anyhow::anyhow!(
"Client certificate provided (GOOSE_CLIENT_CERT_PATH) but no private key (GOOSE_CLIENT_KEY_PATH)"
));
}
(None, Some(_)) => {
return Err(anyhow::anyhow!(
"Client private key provided (GOOSE_CLIENT_KEY_PATH) but no certificate (GOOSE_CLIENT_CERT_PATH)"
));
}
(None, None) => {}
}

if let Ok(ca_cert_path) = config.get_param::<String>("GOOSE_CA_CERT_PATH") {
tls_config = tls_config.with_ca_cert(std::path::PathBuf::from(ca_cert_path));
has_tls_config = true;
}

if has_tls_config {
Ok(Some(tls_config))
} else {
Ok(None)
}
}

pub fn with_client_cert_and_key(mut self, cert_path: PathBuf, key_path: PathBuf) -> Self {
self.client_identity = Some(TlsCertKeyPair {
cert_path,
key_path,
});
self
}

pub fn with_ca_cert(mut self, path: PathBuf) -> Self {
self.ca_cert_path = Some(path);
self
}

pub fn is_configured(&self) -> bool {
self.client_identity.is_some() || self.ca_cert_path.is_some()
}

pub fn load_identity(&self) -> Result<Option<Identity>> {
if let Some(cert_key_pair) = &self.client_identity {
let cert_pem = read_to_string(&cert_key_pair.cert_path)
.map_err(|e| anyhow::anyhow!("Failed to read client certificate: {}", e))?;
let key_pem = read_to_string(&cert_key_pair.key_path)
.map_err(|e| anyhow::anyhow!("Failed to read client private key: {}", e))?;

// Create a combined PEM file with certificate and private key
let combined_pem = format!("{}\n{}", cert_pem, key_pem);

let identity = Identity::from_pem(combined_pem.as_bytes()).map_err(|e| {
anyhow::anyhow!("Failed to create identity from cert and key: {}", e)
})?;

Ok(Some(identity))
} else {
Ok(None)
}
}

pub fn load_ca_certificates(&self) -> Result<Vec<Certificate>> {
match &self.ca_cert_path {
Some(ca_path) => {
let ca_pem = read_to_string(ca_path)
.map_err(|e| anyhow::anyhow!("Failed to read CA certificate: {}", e))?;

let certs = Certificate::from_pem_bundle(ca_pem.as_bytes())
.map_err(|e| anyhow::anyhow!("Failed to parse CA certificate bundle: {}", e))?;

Ok(certs)
}
None => Ok(Vec::new()),
}
}
}

impl Default for TlsConfig {
fn default() -> Self {
Self::new()
}
}

pub struct OAuthConfig {
pub host: String,
pub client_id: String,
Expand Down Expand Up @@ -79,32 +203,71 @@ impl ApiClient {
}

pub fn with_timeout(host: String, auth: AuthMethod, timeout: Duration) -> Result<Self> {
let mut client_builder = Client::builder().timeout(timeout);

// Configure TLS if needed
let tls_config = TlsConfig::from_config()?;
if let Some(ref config) = tls_config {
client_builder = Self::configure_tls(client_builder, config)?;
}

let client = client_builder.build()?;

Ok(Self {
client: Client::builder().timeout(timeout).build()?,
client,
host,
auth,
default_headers: HeaderMap::new(),
timeout,
tls_config,
})
}

fn rebuild_client(&mut self) -> Result<()> {
let mut client_builder = Client::builder()
.timeout(self.timeout)
.default_headers(self.default_headers.clone());

// Configure TLS if needed
if let Some(ref tls_config) = self.tls_config {
client_builder = Self::configure_tls(client_builder, tls_config)?;
}

self.client = client_builder.build()?;
Ok(())
}

/// Configure TLS settings on a reqwest ClientBuilder
fn configure_tls(
mut client_builder: reqwest::ClientBuilder,
tls_config: &TlsConfig,
) -> Result<reqwest::ClientBuilder> {
if tls_config.is_configured() {
// Load client identity (certificate + private key)
if let Some(identity) = tls_config.load_identity()? {
client_builder = client_builder.identity(identity);
}

// Load CA certificates
let ca_certs = tls_config.load_ca_certificates()?;
for ca_cert in ca_certs {
client_builder = client_builder.add_root_certificate(ca_cert);
}
}
Ok(client_builder)
}

pub fn with_headers(mut self, headers: HeaderMap) -> Result<Self> {
self.default_headers = headers;
self.client = Client::builder()
.timeout(self.timeout)
.default_headers(self.default_headers.clone())
.build()?;
self.rebuild_client()?;
Ok(self)
}

pub fn with_header(mut self, key: &str, value: &str) -> Result<Self> {
let header_name = HeaderName::from_bytes(key.as_bytes())?;
let header_value = HeaderValue::from_str(value)?;
self.default_headers.insert(header_name, header_value);
self.client = Client::builder()
.timeout(self.timeout)
.default_headers(self.default_headers.clone())
.build()?;
self.rebuild_client()?;
Ok(self)
}

Expand Down
Loading