Skip to content

Commit 156386e

Browse files
committed
Error handling
1 parent 746b1e4 commit 156386e

File tree

4 files changed

+95
-50
lines changed

4 files changed

+95
-50
lines changed

Cargo.lock

Lines changed: 7 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ thiserror = "2.0.17"
1313
clap = { version = "4.5.51", features = ["derive"] }
1414
webpki-roots = "1.0.4"
1515
rustls-pemfile = "2.2.0"
16+
anyhow = "1.0.100"
1617

1718
[dev-dependencies]
1819
rcgen = "0.14.5"

src/lib.rs

Lines changed: 51 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
mod attestation;
22

33
pub use attestation::{AttestationPlatform, MockAttestation, NoAttestation};
4-
use tokio_rustls::rustls::server::WebPkiClientVerifier;
4+
use thiserror::Error;
5+
use tokio_rustls::rustls::server::{VerifierBuilderError, WebPkiClientVerifier};
56

67
#[cfg(test)]
78
mod test_helpers;
@@ -60,26 +61,23 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyServer<L, R> {
6061
local_attestation_platform: L,
6162
remote_attestation_platform: R,
6263
client_auth: bool,
63-
) -> Self {
64+
) -> Result<Self, ProxyError> {
6465
if remote_attestation_platform.is_cvm() && !client_auth {
65-
panic!("Client auth is required when the client is running in a CVM");
66+
return Err(ProxyError::NoClientAuth);
6667
}
6768

6869
let server_config = if client_auth {
6970
let root_store =
7071
RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
71-
let verifier = WebPkiClientVerifier::builder(Arc::new(root_store))
72-
.build()
73-
.expect("invalid client verifier");
72+
let verifier = WebPkiClientVerifier::builder(Arc::new(root_store)).build()?;
73+
7474
ServerConfig::builder()
7575
.with_client_cert_verifier(verifier)
76-
.with_single_cert(cert_and_key.cert_chain.clone(), cert_and_key.key)
77-
.expect("Failed to create rustls server config")
76+
.with_single_cert(cert_and_key.cert_chain.clone(), cert_and_key.key)?
7877
} else {
7978
ServerConfig::builder()
8079
.with_no_client_auth()
81-
.with_single_cert(cert_and_key.cert_chain.clone(), cert_and_key.key)
82-
.expect("Failed to create rustls server config")
80+
.with_single_cert(cert_and_key.cert_chain.clone(), cert_and_key.key)?
8381
};
8482

8583
Self::new_with_tls_config(
@@ -103,26 +101,27 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyServer<L, R> {
103101
target: SocketAddr,
104102
local_attestation_platform: L,
105103
remote_attestation_platform: R,
106-
) -> Self {
104+
) -> Result<Self, ProxyError> {
107105
let acceptor = tokio_rustls::TlsAcceptor::from(server_config);
108-
let listener = TcpListener::bind(local).await.unwrap();
106+
let listener = TcpListener::bind(local).await?;
109107

110108
let inner = Proxy {
111109
listener,
112110
local_attestation_platform,
113111
remote_attestation_platform,
114112
};
115-
Self {
113+
114+
Ok(Self {
116115
acceptor,
117116
target,
118117
inner,
119118
cert_chain,
120-
}
119+
})
121120
}
122121

123122
/// Accept an incoming connection
124-
pub async fn accept(&self) -> io::Result<()> {
125-
let (inbound, _client_addr) = self.inner.listener.accept().await.unwrap();
123+
pub async fn accept(&self) -> Result<(), ProxyError> {
124+
let (inbound, _client_addr) = self.inner.listener.accept().await?;
126125

127126
let acceptor = self.acceptor.clone();
128127
let target = self.target;
@@ -215,9 +214,9 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
215214
server_name: ServerName<'static>,
216215
local_attestation_platform: L,
217216
remote_attestation_platform: R,
218-
) -> Self {
217+
) -> Result<Self, ProxyError> {
219218
if local_attestation_platform.is_cvm() && cert_and_key.is_none() {
220-
panic!("Client auth is required when the client is running in a CVM");
219+
return Err(ProxyError::NoClientAuth);
221220
}
222221

223222
let root_store = RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
@@ -228,8 +227,7 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
228227
.with_client_auth_cert(
229228
cert_and_key.cert_chain.clone(),
230229
cert_and_key.key.clone_key(),
231-
)
232-
.unwrap()
230+
)?
233231
} else {
234232
ClientConfig::builder()
235233
.with_root_certificates(root_store)
@@ -256,8 +254,8 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
256254
local_attestation_platform: L,
257255
remote_attestation_platform: R,
258256
cert_chain: Option<Vec<CertificateDer<'static>>>,
259-
) -> Self {
260-
let listener = TcpListener::bind(local).await.unwrap();
257+
) -> Result<Self, ProxyError> {
258+
let listener = TcpListener::bind(local).await?;
261259
let connector = TlsConnector::from(client_config.clone());
262260

263261
let inner = Proxy {
@@ -266,17 +264,17 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
266264
remote_attestation_platform,
267265
};
268266

269-
Self {
267+
Ok(Self {
270268
inner,
271269
connector,
272270
target,
273271
target_name,
274272
cert_chain,
275-
}
273+
})
276274
}
277275

278276
pub async fn accept(&self) -> io::Result<()> {
279-
let (inbound, _client_addr) = self.inner.listener.accept().await.unwrap();
277+
let (inbound, _client_addr) = self.inner.listener.accept().await?;
280278

281279
let connector = self.connector.clone();
282280
let target_name = self.target_name.clone();
@@ -348,6 +346,18 @@ impl<L: AttestationPlatform, R: AttestationPlatform> ProxyClient<L, R> {
348346
}
349347
}
350348

349+
#[derive(Error, Debug)]
350+
pub enum ProxyError {
351+
#[error("Client auth is required when the client is running in a CVM")]
352+
NoClientAuth,
353+
#[error("TLS: {0}")]
354+
Rustls(#[from] tokio_rustls::rustls::Error),
355+
#[error("Verifier builder: {0}")]
356+
VerifierBuilder(#[from] VerifierBuilderError),
357+
#[error("IO: {0}")]
358+
Io(#[from] std::io::Error),
359+
}
360+
351361
fn length_prefix(input: &[u8]) -> [u8; 4] {
352362
let len = input.len() as u32;
353363
len.to_be_bytes()
@@ -377,7 +387,9 @@ mod tests {
377387
MockAttestation,
378388
NoAttestation,
379389
)
380-
.await;
390+
.await
391+
.unwrap();
392+
381393
let proxy_addr = proxy_server.local_addr().unwrap();
382394

383395
tokio::spawn(async move {
@@ -393,7 +405,8 @@ mod tests {
393405
MockAttestation,
394406
None,
395407
)
396-
.await;
408+
.await
409+
.unwrap();
397410

398411
let proxy_client_addr = proxy_client.local_addr().unwrap();
399412

@@ -439,7 +452,9 @@ mod tests {
439452
MockAttestation,
440453
MockAttestation,
441454
)
442-
.await;
455+
.await
456+
.unwrap();
457+
443458
let proxy_addr = proxy_server.local_addr().unwrap();
444459

445460
tokio::spawn(async move {
@@ -455,7 +470,8 @@ mod tests {
455470
MockAttestation,
456471
Some(client_cert_chain),
457472
)
458-
.await;
473+
.await
474+
.unwrap();
459475

460476
let proxy_client_addr = proxy_client.local_addr().unwrap();
461477

@@ -491,7 +507,9 @@ mod tests {
491507
local_attestation_platform,
492508
NoAttestation,
493509
)
494-
.await;
510+
.await
511+
.unwrap();
512+
495513
let proxy_server_addr = proxy_server.local_addr().unwrap();
496514

497515
tokio::spawn(async move {
@@ -507,7 +525,9 @@ mod tests {
507525
MockAttestation,
508526
None,
509527
)
510-
.await;
528+
.await
529+
.unwrap();
530+
511531
let proxy_client_addr = proxy_client.local_addr().unwrap();
512532

513533
tokio::spawn(async move {

src/main.rs

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use anyhow::{anyhow, ensure};
12
use clap::{Parser, Subcommand};
23
use std::{fs::File, net::SocketAddr, path::PathBuf};
34
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
@@ -46,7 +47,7 @@ enum CliCommand {
4647
}
4748

4849
#[tokio::main]
49-
async fn main() {
50+
async fn main() -> anyhow::Result<()> {
5051
let cli = Cli::parse();
5152

5253
match cli.command {
@@ -56,21 +57,33 @@ async fn main() {
5657
private_key,
5758
cert_chain,
5859
} => {
59-
let tls_cert_and_chain = private_key
60-
.map(|private_key| load_tls_cert_and_key(cert_chain.unwrap(), private_key));
60+
let tls_cert_and_chain = if let Some(private_key) = private_key {
61+
Some(load_tls_cert_and_key(
62+
cert_chain.ok_or(anyhow!("Private key given but no certificate chain"))?,
63+
private_key,
64+
)?)
65+
} else {
66+
ensure!(
67+
cert_chain.is_none(),
68+
"Certificate chain given but no private key"
69+
);
70+
None
71+
};
6172

6273
let client = ProxyClient::new(
6374
tls_cert_and_chain,
6475
cli.address,
6576
server_address,
66-
server_name.try_into().unwrap(),
77+
server_name.try_into()?,
6778
NoAttestation,
6879
MockAttestation,
6980
)
70-
.await;
81+
.await?;
7182

7283
loop {
73-
client.accept().await.unwrap();
84+
if let Err(err) = client.accept().await {
85+
eprintln!("Failed to handle connection: {err}");
86+
}
7487
}
7588
}
7689
CliCommand::Server {
@@ -79,7 +92,7 @@ async fn main() {
7992
cert_chain,
8093
client_auth,
8194
} => {
82-
let tls_cert_and_chain = load_tls_cert_and_key(cert_chain, private_key);
95+
let tls_cert_and_chain = load_tls_cert_and_key(cert_chain, private_key)?;
8396
let local_attestation = MockAttestation;
8497
let remote_attestation = NoAttestation;
8598

@@ -91,37 +104,41 @@ async fn main() {
91104
remote_attestation,
92105
client_auth,
93106
)
94-
.await;
107+
.await?;
95108

96109
loop {
97-
server.accept().await.unwrap();
110+
if let Err(err) = server.accept().await {
111+
eprintln!("Failed to handle connection: {err}");
112+
}
98113
}
99114
}
100115
}
101116
}
102117

103-
fn load_tls_cert_and_key(cert_chain: PathBuf, private_key: PathBuf) -> TlsCertAndKey {
104-
let key = load_private_key_pem(private_key);
105-
let cert_chain = load_certs_pem(cert_chain).unwrap();
106-
TlsCertAndKey { key, cert_chain }
118+
fn load_tls_cert_and_key(
119+
cert_chain: PathBuf,
120+
private_key: PathBuf,
121+
) -> anyhow::Result<TlsCertAndKey> {
122+
let key = load_private_key_pem(private_key)?;
123+
let cert_chain = load_certs_pem(cert_chain)?;
124+
Ok(TlsCertAndKey { key, cert_chain })
107125
}
108126

109127
pub fn load_certs_pem(path: PathBuf) -> std::io::Result<Vec<CertificateDer<'static>>> {
110128
Ok(
111129
rustls_pemfile::certs(&mut std::io::BufReader::new(File::open(path)?))
112-
.map(|res| res.unwrap())
130+
.map(|res| res.unwrap()) //TODO
113131
.collect(),
114132
)
115133
}
116134

117-
pub fn load_private_key_pem(path: PathBuf) -> PrivateKeyDer<'static> {
118-
let mut reader = std::io::BufReader::new(File::open(path).unwrap());
135+
pub fn load_private_key_pem(path: PathBuf) -> anyhow::Result<PrivateKeyDer<'static>> {
136+
let mut reader = std::io::BufReader::new(File::open(path)?);
119137

120138
// Tries to read the key as PKCS#8, PKCS#1, or SEC1
121139
let pks8_key = rustls_pemfile::pkcs8_private_keys(&mut reader)
122140
.next()
123-
.unwrap()
124-
.unwrap();
141+
.ok_or(anyhow!("No PKS8 Key"))??;
125142

126-
PrivateKeyDer::Pkcs8(pks8_key)
143+
Ok(PrivateKeyDer::Pkcs8(pks8_key))
127144
}

0 commit comments

Comments
 (0)