diff --git a/examples/tls-offloading/Cargo.toml b/examples/tls-offloading/Cargo.toml new file mode 100644 index 0000000000..17d8e334ac --- /dev/null +++ b/examples/tls-offloading/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "tls-offloading" +version = "0.1.0" +edition = "2024" + +[dependencies] +s2n-quic = { version = "1", path = "../../quic/s2n-quic", features = ["unstable-offload-tls"]} +tokio = { version = "1", features = ["full"] } diff --git a/examples/tls-offloading/README.md b/examples/tls-offloading/README.md new file mode 100644 index 0000000000..74538d0873 --- /dev/null +++ b/examples/tls-offloading/README.md @@ -0,0 +1,25 @@ +# TLS offload + +s2n-quic is single-threaded by default. This can cause performance issues in the instance where many clients try to connect to a single s2n-quic server at the same time. Each incoming Client Hello will cause the s2n-quic server event loop to be blocked while the TLS provider completes the expensive cryptographic operations necessary to process the Client Hello. This has the potential to slow down all existing connections in favor of new ones. The TLS offloading feature attempts to alleviate this problem by moving each TLS connection to a separate async task, which can then be spawned by the runtime the user provides. + +To do this, implement the `offload::Executor` trait with the runtime of your choice. In this example, we use the `tokio::spawn` function as our executor: +``` +struct TokioExecutor; +impl Executor for TokioExecutor { + fn spawn(&self, task: impl core::future::Future + Send + 'static) { + tokio::spawn(task); + } +} +``` + +# Warning +The default offloading feature as-is may result in packet loss in the handshake, depending on how packets arrive to the offload endpoint. This packet loss will slow down the handshake, as QUIC has to detect the loss and the peer has to resend the lost packets. We have an upcoming feature that will combat this packet loss and will probably be required to achieve the fastest handshakes possible with s2n-quic: https://github.com/aws/s2n-quic/pull/2668. However, it is still in the PR process. + +# Set-up + +Currently offloading is disabled by default as it is still in development. It can be enabled by adding this line to your Cargo.toml file: + +```toml +[dependencies] +s2n-quic = { version = "1", features = ["unstable-offload-tls"]} +``` diff --git a/examples/tls-offloading/src/bin/client.rs b/examples/tls-offloading/src/bin/client.rs new file mode 100644 index 0000000000..674283607e --- /dev/null +++ b/examples/tls-offloading/src/bin/client.rs @@ -0,0 +1,64 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use s2n_quic::{ + Client, + client::Connect, + provider::tls::{ + default, + offload::{Executor, OffloadBuilder}, + }, +}; +use std::{error::Error, net::SocketAddr}; + +/// NOTE: this certificate is to be used for demonstration purposes only! +pub static CERT_PEM: &str = include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/../../quic/s2n-quic-core/certs/cert.pem" +)); + +struct TokioExecutor; +impl Executor for TokioExecutor { + fn spawn(&self, task: impl core::future::Future + Send + 'static) { + tokio::spawn(task); + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let tls = default::Client::builder() + .with_certificate(CERT_PEM)? + .build()?; + let tls_endpoint = OffloadBuilder::new() + .with_endpoint(tls) + .with_executor(TokioExecutor) + .build(); + + let client = Client::builder() + .with_tls(tls_endpoint)? + .with_io("0.0.0.0:0")? + .start()?; + + let addr: SocketAddr = "127.0.0.1:4433".parse()?; + let connect = Connect::new(addr).with_server_name("localhost"); + let mut connection = client.connect(connect).await?; + + // ensure the connection doesn't time out with inactivity + connection.keep_alive(true)?; + + // open a new stream and split the receiving and sending sides + let stream = connection.open_bidirectional_stream().await?; + let (mut receive_stream, mut send_stream) = stream.split(); + + // spawn a task that copies responses from the server to stdout + tokio::spawn(async move { + let mut stdout = tokio::io::stdout(); + let _ = tokio::io::copy(&mut receive_stream, &mut stdout).await; + }); + + // copy data from stdin and send it to the server + let mut stdin = tokio::io::stdin(); + tokio::io::copy(&mut stdin, &mut send_stream).await?; + + Ok(()) +} diff --git a/examples/tls-offloading/src/bin/server.rs b/examples/tls-offloading/src/bin/server.rs new file mode 100644 index 0000000000..dd61a38657 --- /dev/null +++ b/examples/tls-offloading/src/bin/server.rs @@ -0,0 +1,67 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use s2n_quic::{ + Server, + provider::tls::{ + default, + offload::{Executor, OffloadBuilder}, + }, +}; +use std::error::Error; + +/// NOTE: this certificate is to be used for demonstration purposes only! +pub static CERT_PEM: &str = include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/../../quic/s2n-quic-core/certs/cert.pem" +)); +/// NOTE: this certificate is to be used for demonstration purposes only! +pub static KEY_PEM: &str = include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/../../quic/s2n-quic-core/certs/key.pem" +)); + +struct TokioExecutor; +impl Executor for TokioExecutor { + fn spawn(&self, task: impl core::future::Future + Send + 'static) { + tokio::spawn(task); + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let tls = default::Server::builder() + .with_certificate(CERT_PEM, KEY_PEM)? + .build()?; + + let tls_endpoint = OffloadBuilder::new() + .with_endpoint(tls) + .with_executor(TokioExecutor) + .build(); + + let mut server = Server::builder() + .with_tls(tls_endpoint)? + .with_io("127.0.0.1:4433")? + .start()?; + + while let Some(mut connection) = server.accept().await { + // spawn a new task for the connection + tokio::spawn(async move { + eprintln!("Connection accepted from {:?}", connection.remote_addr()); + + while let Ok(Some(mut stream)) = connection.accept_bidirectional_stream().await { + // spawn a new task for the stream + tokio::spawn(async move { + eprintln!("Stream opened from {:?}", stream.connection().remote_addr()); + + // echo any data back to the stream + while let Ok(Some(data)) = stream.receive().await { + stream.send(data).await.expect("stream should be open"); + } + }); + } + }); + } + + Ok(()) +} diff --git a/quic/s2n-quic-core/src/crypto/tls.rs b/quic/s2n-quic-core/src/crypto/tls.rs index 1fdbf8dc23..6a5b916fce 100644 --- a/quic/s2n-quic-core/src/crypto/tls.rs +++ b/quic/s2n-quic-core/src/crypto/tls.rs @@ -20,6 +20,9 @@ pub mod null; #[cfg(feature = "alloc")] pub mod slow_tls; +#[cfg(feature = "std")] +pub mod offload; + /// Holds all application parameters which are exchanged within the TLS handshake. #[derive(Debug)] pub struct ApplicationParameters<'a> { diff --git a/quic/s2n-quic-core/src/crypto/tls/offload.rs b/quic/s2n-quic-core/src/crypto/tls/offload.rs new file mode 100644 index 0000000000..f524e88158 --- /dev/null +++ b/quic/s2n-quic-core/src/crypto/tls/offload.rs @@ -0,0 +1,652 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +use crate::{ + application, + crypto::{ + tls::{self, NamedGroup, TlsSession}, + CryptoSuite, + }, + sync::spsc::{channel, Receiver, SendSlice, Sender}, + transport, +}; +use alloc::{boxed::Box, collections::vec_deque::VecDeque, sync::Arc, vec::Vec}; +use core::{any::Any, future::Future, task::Poll}; +use std::sync::Mutex; + +/// Trait used for spawning async tasks corresponding to TLS operations. Each task will signify TLS work +/// that needs to be done per QUIC connection. +pub trait Executor { + fn spawn(&self, task: impl Future + Send + 'static); +} + +/// Allows access to the TlsSession on handshake failure and when the exporter secret is ready. +pub trait ExporterHandler { + fn on_tls_handshake_failed(&self, session: &impl TlsSession) -> Option>; + fn on_tls_exporter_ready(&self, session: &impl TlsSession) -> Option>; +} + +// Most people don't need the TlsSession so we ignore these callbacks by default +impl ExporterHandler for () { + fn on_tls_handshake_failed( + &self, + _session: &impl TlsSession, + ) -> Option> { + None + } + + fn on_tls_exporter_ready( + &self, + _session: &impl TlsSession, + ) -> Option> { + None + } +} + +pub struct OffloadEndpoint { + inner: E, + executor: X, + exporter: H, + channel_capacity: usize, +} + +impl OffloadEndpoint { + pub fn new(inner: E, executor: X, exporter: H, channel_capacity: usize) -> Self { + Self { + inner, + executor, + exporter, + channel_capacity, + } + } +} + +impl tls::Endpoint for OffloadEndpoint +where + E: tls::Endpoint, + X: Executor + Send + 'static, + H: ExporterHandler + Send + 'static + Sync + Clone, +{ + type Session = OffloadSession<::Session>; + + fn new_server_session( + &mut self, + transport_parameters: &Params, + ) -> Self::Session { + OffloadSession::new( + self.inner.new_server_session(transport_parameters), + &self.executor, + self.exporter.clone(), + self.channel_capacity, + ) + } + + fn new_client_session( + &mut self, + transport_parameters: &Params, + server_name: application::ServerName, + ) -> Self::Session { + OffloadSession::new( + self.inner + .new_client_session(transport_parameters, server_name), + &self.executor, + self.exporter.clone(), + self.channel_capacity, + ) + } + + fn max_tag_length(&self) -> usize { + self.inner.max_tag_length() + } +} + +#[derive(Debug)] +pub struct OffloadSession { + recv_from_tls: Receiver>, + send_to_tls: Sender, + allowed_to_send: Arc>, +} + +impl OffloadSession { + fn new( + mut inner: S, + executor: &impl Executor, + exporter: impl ExporterHandler + Sync + Send + 'static + Clone, + channel_capacity: usize, + ) -> Self { + let (mut send_to_quic, recv_from_tls): (Sender>, Receiver>) = + channel(channel_capacity); + let (send_to_tls, mut recv_from_quic): (Sender, Receiver) = + channel(channel_capacity); + let allowed_to_send = Arc::new(Mutex::new(AllowedToSend::default())); + let clone = allowed_to_send.clone(); + + let future = async move { + let mut initial_data = VecDeque::default(); + let mut handshake_data = VecDeque::default(); + let mut application_data = VecDeque::default(); + + core::future::poll_fn(|ctx| { + match send_to_quic.poll_slice(ctx) { + Poll::Ready(res) => match res { + Ok(send_slice) => { + let allowed_to_send = *allowed_to_send.lock().unwrap(); + + let mut context = RemoteContext { + send_to_quic: send_slice, + waker: ctx.waker().clone(), + initial_data: &mut initial_data, + handshake_data: &mut handshake_data, + application_data: &mut application_data, + exporter_handler: exporter.clone(), + allowed_to_send, + error: None, + }; + + while let Poll::Ready(res) = recv_from_quic.poll_slice(ctx) { + match res { + Ok(mut recv_slice) => { + while let Some(response) = recv_slice.pop() { + match response { + Response::Initial(data) => { + context.initial_data.push_back(data); + } + Response::Handshake(data) => { + context.handshake_data.push_back(data); + } + Response::Application(data) => { + context.application_data.push_back(data) + } + Response::SendStatusChanged => (), + } + } + } + Err(_) => { + // For whatever reason the QUIC side decided to drop this channel. In this case + // we complete the future. + return Poll::Ready(()); + } + } + } + + let res = inner.poll(&mut context); + // Either there was an error or the handshake has finished if TLS returned Poll::Ready. + // Notify the QUIC side accordingly. + if let Poll::Ready(res) = res { + let request = match res { + Ok(_) => Request::TlsDone, + Err(e) => Request::TlsError(e), + }; + let _ = context.send_to_quic.push(request); + } + + // We also need to notify the QUIC side of any stored errors that we have. + if let Some(error) = context.error { + let _ = context.send_to_quic.push(Request::TlsError(error)); + } + + // We've already sent the Result to the QUIC side so we can just map it out here. + res.map(|_| ()) + } + Err(_) => { + // For whatever reason the QUIC side decided to drop this channel. In this case + // we complete the future. + Poll::Ready(()) + } + }, + Poll::Pending => Poll::Pending, + } + }) + .await; + }; + executor.spawn(future); + + Self { + recv_from_tls, + send_to_tls, + allowed_to_send: clone, + } + } +} + +impl tls::Session for OffloadSession { + #[inline] + fn poll(&mut self, context: &mut W) -> Poll> + where + W: tls::Context, + { + let cloned_waker = &context.waker().clone(); + let mut ctx = core::task::Context::from_waker(cloned_waker); + + match self.recv_from_tls.poll_slice(&mut ctx) { + Poll::Ready(res) => match res { + Ok(mut slice) => { + while let Some(request) = slice.pop() { + match request { + Request::HandshakeKeys(key, header_key) => { + context.on_handshake_keys(key, header_key)?; + } + Request::ServerName(server_name) => { + context.on_server_name(server_name)? + } + Request::SendInitial(bytes) => context.send_initial(bytes), + Request::ClientParams(client_params, mut server_params) => context + .on_client_application_params( + tls::ApplicationParameters { + transport_parameters: &client_params, + }, + &mut server_params, + )?, + Request::ApplicationProtocol(bytes) => { + context.on_application_protocol(bytes)?; + } + Request::KeyExchangeGroup(named_group) => { + context.on_key_exchange_group(named_group)?; + } + Request::OneRttKeys(key, header_key, transport_parameters) => context + .on_one_rtt_keys( + key, + header_key, + tls::ApplicationParameters { + transport_parameters: &transport_parameters, + }, + )?, + Request::SendHandshake(bytes) => { + context.send_handshake(bytes); + } + Request::HandshakeComplete => { + context.on_handshake_complete()?; + } + Request::TlsDone => { + return Poll::Ready(Ok(())); + } + Request::ZeroRtt(key, header_key, transport_parameters) => { + context.on_zero_rtt_keys( + key, + header_key, + tls::ApplicationParameters { + transport_parameters: &transport_parameters, + }, + )?; + } + Request::TlsContext(ctx) => { + context.on_tls_context(ctx); + } + Request::SendApplication(transmission) => { + context.send_application(transmission); + } + Request::TlsError(e) => return Poll::Ready(Err(e)), + } + } + } + Err(_) => { + // For whatever reason the TLS task was cancelled. We cannot continue the handshake. + return Poll::Ready(Err(transport::Error::from(tls::Error::HANDSHAKE_FAILURE))); + } + }, + Poll::Pending => (), + } + + let mut allowed_to_send = self.allowed_to_send.lock().unwrap(); + let mut state_change = false; + if allowed_to_send.can_send_initial != context.can_send_initial() + || allowed_to_send.can_send_handshake != context.can_send_handshake() + || allowed_to_send.can_send_application != context.can_send_application() + { + *allowed_to_send = AllowedToSend { + can_send_initial: context.can_send_initial(), + can_send_handshake: context.can_send_handshake(), + can_send_application: context.can_send_application(), + }; + state_change = true; + } + // Drop the lock ASAP + drop(allowed_to_send); + + match self.send_to_tls.poll_slice(&mut ctx) { + Poll::Ready(res) => match res { + Ok(mut slice) => { + if let Some(resp) = context.receive_initial(None) { + let _ = slice.push(Response::Initial(resp)); + } + + if let Some(resp) = context.receive_handshake(None) { + let _ = slice.push(Response::Handshake(resp)); + } + + if let Some(resp) = context.receive_application(None) { + let _ = slice.push(Response::Application(resp)); + } + + if state_change { + let _ = slice.push(Response::SendStatusChanged); + } + } + Err(_) => { + // For whatever reason the TLS task was cancelled. We cannot continue the handshake. + return Poll::Ready(Err(transport::Error::from(tls::Error::HANDSHAKE_FAILURE))); + } + }, + Poll::Pending => (), + } + + Poll::Pending + } +} + +impl CryptoSuite for OffloadSession { + type HandshakeKey = ::HandshakeKey; + type HandshakeHeaderKey = ::HandshakeHeaderKey; + type InitialKey = ::InitialKey; + type InitialHeaderKey = ::InitialHeaderKey; + type ZeroRttKey = ::ZeroRttKey; + type ZeroRttHeaderKey = ::ZeroRttHeaderKey; + type OneRttKey = ::OneRttKey; + type OneRttHeaderKey = ::OneRttHeaderKey; + type RetryKey = ::RetryKey; +} + +#[derive(Debug, Default, Copy, Clone)] +struct AllowedToSend { + can_send_initial: bool, + can_send_handshake: bool, + can_send_application: bool, +} + +const SLICE_ERROR: crate::transport::Error = + crate::transport::Error::INTERNAL_ERROR.with_reason("Slice is full"); + +#[derive(Debug)] +struct RemoteContext<'a, Request, H> { + send_to_quic: SendSlice<'a, Request>, + initial_data: &'a mut VecDeque, + handshake_data: &'a mut VecDeque, + application_data: &'a mut VecDeque, + waker: core::task::Waker, + allowed_to_send: AllowedToSend, + exporter_handler: H, + error: Option, +} + +impl<'a, S: CryptoSuite, H: ExporterHandler> tls::Context for RemoteContext<'a, Request, H> { + fn on_client_application_params( + &mut self, + client_params: tls::ApplicationParameters, + server_params: &mut alloc::vec::Vec, + ) -> Result<(), crate::transport::Error> { + match self.send_to_quic.push(Request::ClientParams( + client_params.transport_parameters.to_vec(), + server_params.to_vec(), + )) { + Ok(_) => return Ok(()), + Err(_) => self.error = Some(SLICE_ERROR), + } + Ok(()) + } + + fn on_handshake_keys( + &mut self, + key: ::HandshakeKey, + header_key: ::HandshakeHeaderKey, + ) -> Result<(), crate::transport::Error> { + match self + .send_to_quic + .push(Request::HandshakeKeys(key, header_key)) + { + Ok(_) => return Ok(()), + Err(_) => self.error = Some(SLICE_ERROR), + } + Ok(()) + } + + fn on_zero_rtt_keys( + &mut self, + key: ::ZeroRttKey, + header_key: ::ZeroRttHeaderKey, + application_parameters: tls::ApplicationParameters, + ) -> Result<(), crate::transport::Error> { + match self.send_to_quic.push(Request::ZeroRtt( + key, + header_key, + application_parameters.transport_parameters.to_vec(), + )) { + Ok(_) => (), + Err(_) => self.error = Some(SLICE_ERROR), + } + Ok(()) + } + + fn on_one_rtt_keys( + &mut self, + key: ::OneRttKey, + header_key: ::OneRttHeaderKey, + application_parameters: tls::ApplicationParameters, + ) -> Result<(), crate::transport::Error> { + match self.send_to_quic.push(Request::OneRttKeys( + key, + header_key, + application_parameters.transport_parameters.to_vec(), + )) { + Ok(_) => (), + Err(_) => self.error = Some(SLICE_ERROR), + } + Ok(()) + } + + fn on_server_name( + &mut self, + server_name: crate::application::ServerName, + ) -> Result<(), crate::transport::Error> { + match self.send_to_quic.push(Request::ServerName(server_name)) { + Ok(_) => (), + Err(_) => self.error = Some(SLICE_ERROR), + } + Ok(()) + } + + fn on_application_protocol( + &mut self, + application_protocol: bytes::Bytes, + ) -> Result<(), crate::transport::Error> { + match self + .send_to_quic + .push(Request::ApplicationProtocol(application_protocol)) + { + Ok(_) => (), + Err(_) => self.error = Some(SLICE_ERROR), + } + Ok(()) + } + + fn on_key_exchange_group( + &mut self, + named_group: tls::NamedGroup, + ) -> Result<(), crate::transport::Error> { + match self + .send_to_quic + .push(Request::KeyExchangeGroup(named_group)) + { + Ok(_) => (), + Err(_) => self.error = Some(SLICE_ERROR), + } + Ok(()) + } + + fn on_handshake_complete(&mut self) -> Result<(), crate::transport::Error> { + match self.send_to_quic.push(Request::HandshakeComplete) { + Ok(_) => (), + Err(_) => self.error = Some(SLICE_ERROR), + } + + Ok(()) + } + + fn on_tls_context(&mut self, _context: Box) { + unimplemented!("TLS Context is not supported in Offload implementation"); + } + + fn on_tls_exporter_ready( + &mut self, + session: &impl TlsSession, + ) -> Result<(), crate::transport::Error> { + if let Some(context) = self.exporter_handler.on_tls_exporter_ready(session) { + match self.send_to_quic.push(Request::TlsContext(context)) { + Ok(_) => (), + Err(_) => self.error = Some(SLICE_ERROR), + } + } + + Ok(()) + } + + fn receive_initial(&mut self, max_len: Option) -> Option { + gimme_bytes(max_len, self.initial_data) + } + + fn receive_handshake(&mut self, max_len: Option) -> Option { + gimme_bytes(max_len, self.handshake_data) + } + + fn receive_application(&mut self, max_len: Option) -> Option { + gimme_bytes(max_len, self.application_data) + } + + fn can_send_initial(&self) -> bool { + self.allowed_to_send.can_send_initial + } + + fn send_initial(&mut self, transmission: bytes::Bytes) { + if self + .send_to_quic + .push(Request::SendInitial(transmission)) + .is_err() + { + self.error = Some(SLICE_ERROR); + } + } + + fn can_send_handshake(&self) -> bool { + self.allowed_to_send.can_send_handshake + } + + fn send_handshake(&mut self, transmission: bytes::Bytes) { + if self + .send_to_quic + .push(Request::SendHandshake(transmission)) + .is_err() + { + self.error = Some(SLICE_ERROR); + } + } + + fn can_send_application(&self) -> bool { + self.allowed_to_send.can_send_application + } + + fn send_application(&mut self, transmission: bytes::Bytes) { + if self + .send_to_quic + .push(Request::SendApplication(transmission)) + .is_err() + { + self.error = Some(SLICE_ERROR); + } + } + + fn waker(&self) -> &core::task::Waker { + &self.waker + } + + fn on_tls_handshake_failed( + &mut self, + session: &impl tls::TlsSession, + ) -> Result<(), crate::transport::Error> { + if let Some(context) = self.exporter_handler.on_tls_handshake_failed(session) { + match self.send_to_quic.push(Request::TlsContext(context)) { + Ok(_) => (), + Err(_) => self.error = Some(SLICE_ERROR), + } + } + Ok(()) + } +} + +fn gimme_bytes(max_len: Option, vec: &mut VecDeque) -> Option { + let bytes = vec.pop_front(); + if let Some(mut bytes) = bytes { + if let Some(max_len) = max_len { + if bytes.len() > max_len { + let remainder = bytes.split_off(max_len); + vec.push_front(remainder); + } + } + return Some(bytes); + } + None +} + +enum Request { + ZeroRtt( + ::ZeroRttKey, + ::ZeroRttHeaderKey, + Vec, + ), + ServerName(crate::application::ServerName), + SendInitial(bytes::Bytes), + ClientParams(Vec, Vec), + HandshakeKeys( + ::HandshakeKey, + ::HandshakeHeaderKey, + ), + SendHandshake(bytes::Bytes), + ApplicationProtocol(bytes::Bytes), + KeyExchangeGroup(NamedGroup), + OneRttKeys( + ::OneRttKey, + ::OneRttHeaderKey, + Vec, + ), + HandshakeComplete, + TlsDone, + TlsContext(Box), + SendApplication(bytes::Bytes), + TlsError(transport::Error), +} + +enum Response { + Initial(bytes::Bytes), + Handshake(bytes::Bytes), + Application(bytes::Bytes), + SendStatusChanged, +} + +impl alloc::fmt::Debug for Request { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Request::ServerName(_) => write!(f, "ServerName"), + Request::SendInitial(_) => write!(f, "SendInitial"), + Request::ClientParams(_, _) => write!(f, "ClientParams"), + Request::HandshakeKeys(_, _) => write!(f, "HandshakeKeys"), + Request::SendHandshake(_) => write!(f, "SendHandshake"), + Request::ApplicationProtocol(_) => write!(f, "ApplicationProtocol"), + Request::KeyExchangeGroup(_) => write!(f, "KeyExchangeGroup"), + Request::OneRttKeys(_, _, _) => write!(f, "OneRttKeys"), + Request::HandshakeComplete => write!(f, "HandshakeComplete"), + Request::TlsDone => write!(f, "TlsDone"), + Request::ZeroRtt(_, _, _) => write!(f, "ZeroRtt"), + Request::TlsContext(_) => write!(f, "TlsContext"), + Request::SendApplication(_) => write!(f, "SendApplication"), + Request::TlsError(_) => write!(f, "TlsError"), + } + } +} + +impl alloc::fmt::Debug for Response { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Response::Initial(_) => write!(f, "ResponseInitial"), + Response::Handshake(_) => write!(f, "ResponseHandshake"), + Response::Application(_) => write!(f, "ResponseApplication"), + Response::SendStatusChanged => write!(f, "SendStatusChanged"), + } + } +} diff --git a/quic/s2n-quic-core/src/sync/spsc/recv.rs b/quic/s2n-quic-core/src/sync/spsc/recv.rs index 9764af4d05..0ac01539f4 100644 --- a/quic/s2n-quic-core/src/sync/spsc/recv.rs +++ b/quic/s2n-quic-core/src/sync/spsc/recv.rs @@ -8,6 +8,7 @@ use core::{ task::{Context, Poll}, }; +#[derive(Debug)] pub struct Receiver(pub(super) State); impl Receiver { diff --git a/quic/s2n-quic-core/src/sync/spsc/send.rs b/quic/s2n-quic-core/src/sync/spsc/send.rs index 18a3843a45..0e2bd8effe 100644 --- a/quic/s2n-quic-core/src/sync/spsc/send.rs +++ b/quic/s2n-quic-core/src/sync/spsc/send.rs @@ -8,6 +8,7 @@ use core::{ task::{Context, Poll}, }; +#[derive(Debug)] pub struct Sender(pub(super) State); impl Sender { @@ -81,6 +82,7 @@ impl Drop for Sender { } } +#[derive(Debug)] pub struct SendSlice<'a, T>(&'a mut State, Cursor); impl SendSlice<'_, T> { diff --git a/quic/s2n-quic-tests/Cargo.toml b/quic/s2n-quic-tests/Cargo.toml index 0fbd91e992..984ce1bfa5 100644 --- a/quic/s2n-quic-tests/Cargo.toml +++ b/quic/s2n-quic-tests/Cargo.toml @@ -11,13 +11,14 @@ license = "Apache-2.0" publish = false [dependencies] +bach = "0.1.0" bytes = { version = "1", default-features = false } futures = { version = "0.3", default-features = false, features = ["std"] } rand = "0.9" rand_chacha = "0.9" s2n-codec = { path = "../../common/s2n-codec" } s2n-quic-core = { path = "../s2n-quic-core", features = ["branch-tracing", "event-tracing", "probe-tracing", "testing"] } -s2n-quic = { path = "../s2n-quic", features = ["provider-event-tracing", "unstable-provider-io-testing", "unstable-provider-dc", "unstable-provider-packet-interceptor", "unstable-provider-random"] } +s2n-quic = { path = "../s2n-quic", features = ["provider-event-tracing", "unstable-provider-io-testing", "unstable-provider-dc", "unstable-provider-packet-interceptor", "unstable-provider-random", "unstable-offload-tls"] } s2n-quic-platform = { path = "../s2n-quic-platform", features = ["tokio-runtime"] } s2n-quic-transport = { path = "../s2n-quic-transport", features = ["unstable_resumption", "unstable-provider-dc"] } tokio = { version = "1", features = ["full"] } @@ -31,4 +32,4 @@ zerocopy = { version = "0.8", features = ["derive"] } quiche = "0.24" [target.'cfg(unix)'.dependencies] -s2n-quic = { path = "../s2n-quic", features = ["provider-event-tracing", "provider-tls-s2n", "unstable-provider-io-testing", "unstable-provider-dc", "unstable-provider-packet-interceptor", "unstable-provider-random"] } +s2n-quic = { path = "../s2n-quic", features = ["provider-event-tracing", "provider-tls-s2n", "unstable-provider-io-testing", "unstable-provider-dc", "unstable-provider-packet-interceptor", "unstable-provider-random", "unstable-offload-tls", "unstable_client_hello"] } diff --git a/quic/s2n-quic-tests/src/tests.rs b/quic/s2n-quic-tests/src/tests.rs index 71f55d9cd5..073a908d92 100644 --- a/quic/s2n-quic-tests/src/tests.rs +++ b/quic/s2n-quic-tests/src/tests.rs @@ -39,6 +39,7 @@ mod issue_1717; mod issue_954; mod mtu; mod no_tls; +mod offload; mod platform_events; mod pto; mod resumption; diff --git a/quic/s2n-quic-tests/src/tests/offload.rs b/quic/s2n-quic-tests/src/tests/offload.rs new file mode 100644 index 0000000000..a2a7a699c3 --- /dev/null +++ b/quic/s2n-quic-tests/src/tests/offload.rs @@ -0,0 +1,269 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +use super::*; +use s2n_quic::provider::tls::{ + default, + offload::{Executor, ExporterHandler, OffloadBuilder}, +}; +struct BachExecutor; +impl Executor for BachExecutor { + fn spawn(&self, task: impl core::future::Future + Send + 'static) { + bach::spawn(task); + } +} + +#[derive(Clone)] +struct Exporter; +impl ExporterHandler for Exporter { + fn on_tls_handshake_failed( + &self, + _session: &impl s2n_quic_core::crypto::tls::TlsSession, + ) -> Option> { + None + } + + fn on_tls_exporter_ready( + &self, + _session: &impl s2n_quic_core::crypto::tls::TlsSession, + ) -> Option> { + None + } +} + +#[test] +fn tls() { + let model = Model::default(); + test(model, |handle| { + let server_endpoint = default::Server::builder() + .with_certificate(certificates::CERT_PEM, certificates::KEY_PEM) + .unwrap() + .build() + .unwrap(); + let client_endpoint = default::Client::builder() + .with_certificate(certificates::CERT_PEM) + .unwrap() + .build() + .unwrap(); + + let server_endpoint = OffloadBuilder::new() + .with_endpoint(server_endpoint) + .with_executor(BachExecutor) + .with_exporter(Exporter) + .build(); + let client_endpoint = OffloadBuilder::new() + .with_endpoint(client_endpoint) + .with_executor(BachExecutor) + .with_exporter(Exporter) + .build(); + + let server = Server::builder() + .with_io(handle.builder().build()?)? + .with_event(tracing_events())? + .with_tls(server_endpoint)? + .start()?; + + let client = Client::builder() + .with_io(handle.builder().build()?)? + .with_tls(client_endpoint)? + .with_event(tracing_events())? + .start()?; + let addr = start_server(server)?; + start_client(client, addr, Data::new(1000))?; + + Ok(addr) + }) + .unwrap(); +} + +#[test] +fn failed_tls_handshake() { + use s2n_quic::connection::Error; + use s2n_quic_core::{crypto::tls::Error as TlsError, transport}; + let connection_closed_subscriber = recorder::ConnectionClosed::new(); + let connection_closed_event = connection_closed_subscriber.events(); + + let model = Model::default(); + test(model, |handle| { + let server_endpoint = default::Server::builder() + .with_certificate( + certificates::UNTRUSTED_CERT_PEM, + certificates::UNTRUSTED_KEY_PEM, + ) + .unwrap() + .build() + .unwrap(); + + let client_endpoint = default::Client::builder() + .with_certificate(certificates::CERT_PEM) + .unwrap() + .build() + .unwrap(); + + let server_endpoint = OffloadBuilder::new() + .with_endpoint(server_endpoint) + .with_executor(BachExecutor) + .with_exporter(Exporter) + .build(); + let client_endpoint = OffloadBuilder::new() + .with_endpoint(client_endpoint) + .with_executor(BachExecutor) + .with_exporter(Exporter) + .build(); + + let server = Server::builder() + .with_io(handle.builder().build()?)? + .with_event((tracing_events(), connection_closed_subscriber))? + .with_tls(server_endpoint)? + .start()?; + + let client = Client::builder() + .with_io(handle.builder().build()?)? + .with_tls(client_endpoint)? + .with_event(tracing_events())? + .start()?; + let addr = start_server(server)?; + primary::spawn(async move { + let connect = Connect::new(addr).with_server_name("localhost"); + client.connect(connect).await.unwrap_err(); + }); + + Ok(addr) + }) + .unwrap(); + + let connection_closed_handle = connection_closed_event.lock().unwrap(); + let Error::Transport { code, .. } = connection_closed_handle[0] else { + panic!("Unexpected error type") + }; + let expected_error = TlsError::HANDSHAKE_FAILURE; + assert_eq!(code, transport::Error::from(expected_error).code); +} + +#[test] +#[cfg(unix)] +fn mtls() { + let model = Model::default(); + test(model, |handle| { + let server_endpoint = build_server_mtls_provider(certificates::MTLS_CA_CERT)?; + let client_endpoint = build_client_mtls_provider(certificates::MTLS_CA_CERT)?; + + let server_endpoint = OffloadBuilder::new() + .with_endpoint(server_endpoint) + .with_executor(BachExecutor) + .with_exporter(Exporter) + .build(); + let client_endpoint = OffloadBuilder::new() + .with_endpoint(client_endpoint) + .with_executor(BachExecutor) + .with_exporter(Exporter) + .build(); + + let server = Server::builder() + .with_io(handle.builder().build()?)? + .with_event(tracing_events())? + .with_tls(server_endpoint)? + .start()?; + + let client = Client::builder() + .with_io(handle.builder().build()?)? + .with_tls(client_endpoint)? + .with_event(tracing_events())? + .start()?; + let addr = start_server(server)?; + start_client(client, addr, Data::new(1000))?; + + Ok(addr) + }) + .unwrap(); +} + +#[test] +#[cfg(unix)] +fn async_client_hello() { + use futures::{ready, FutureExt}; + use s2n_quic::provider::tls::s2n_tls::{ + self, callbacks::ClientHelloCallback, connection::Connection, error::Error, + }; + use std::task::Poll; + + let model = Model::default(); + + struct MyCallbackHandler; + struct MyConnectionFuture { + output: Option>, + } + + impl ClientHelloCallback for MyCallbackHandler { + fn on_client_hello( + &self, + _connection: &mut Connection, + ) -> Result>>, Error> + { + let fut = MyConnectionFuture { output: None }; + Ok(Some(Box::pin(fut))) + } + } + + impl s2n_tls::callbacks::ConnectionFuture for MyConnectionFuture { + fn poll( + mut self: std::pin::Pin<&mut Self>, + _connection: &mut Connection, + ctx: &mut core::task::Context, + ) -> Poll> { + loop { + if let Some(handle) = &mut self.output { + let _ = ready!(handle.poll_unpin(ctx)); + return Poll::Ready(Ok(())); + } else { + let future = async move { + let timer = bach::time::sleep(Duration::from_secs(3)); + timer.await; + }; + self.output = Some(bach::spawn(future)); + } + } + } + } + test(model, |handle| { + let server_endpoint = default::Server::builder() + .with_certificate(certificates::CERT_PEM, certificates::KEY_PEM) + .unwrap() + .with_client_hello_handler(MyCallbackHandler) + .unwrap() + .build() + .unwrap(); + let client_endpoint = default::Client::builder() + .with_certificate(certificates::CERT_PEM) + .unwrap() + .build() + .unwrap(); + + let server_endpoint = OffloadBuilder::new() + .with_endpoint(server_endpoint) + .with_executor(BachExecutor) + .with_exporter(Exporter) + .build(); + let client_endpoint = OffloadBuilder::new() + .with_endpoint(client_endpoint) + .with_executor(BachExecutor) + .with_exporter(Exporter) + .build(); + + let server = Server::builder() + .with_io(handle.builder().build()?)? + .with_event(tracing_events())? + .with_tls(server_endpoint)? + .start()?; + + let client = Client::builder() + .with_io(handle.builder().build()?)? + .with_tls(client_endpoint)? + .with_event(tracing_events())? + .start()?; + let addr = start_server(server)?; + start_client(client, addr, Data::new(1000))?; + + Ok(addr) + }) + .unwrap(); +} diff --git a/quic/s2n-quic/Cargo.toml b/quic/s2n-quic/Cargo.toml index d8e3031604..6116c3d863 100644 --- a/quic/s2n-quic/Cargo.toml +++ b/quic/s2n-quic/Cargo.toml @@ -61,6 +61,8 @@ unstable-congestion-controller = ["s2n-quic-core/unstable-congestion-controller" unstable-limits = ["s2n-quic-core/unstable-limits"] # The feature enables the close formatter provider unstable-provider-connection-close-formatter = [] +# This feature enables the use of the offloaded TLS feature +unstable-offload-tls = [] [dependencies] bytes = { version = "1", default-features = false } diff --git a/quic/s2n-quic/src/provider/tls.rs b/quic/s2n-quic/src/provider/tls.rs index 978fc26021..3170d1da02 100644 --- a/quic/s2n-quic/src/provider/tls.rs +++ b/quic/s2n-quic/src/provider/tls.rs @@ -309,3 +309,128 @@ pub mod s2n_tls { } } } + +#[cfg(feature = "unstable-offload-tls")] +pub mod offload { + use super::Provider; + use s2n_quic_core::crypto::tls::{offload::OffloadEndpoint, Endpoint}; + pub use s2n_quic_core::crypto::tls::{ + offload::{Executor, ExporterHandler}, + TlsSession, + }; + + pub struct Offload { + endpoint: E, + executor: X, + exporter: H, + channel_capacity: usize, + } + + pub struct OffloadBuilder { + endpoint: E, + executor: X, + exporter: H, + channel_capacity: usize, + } + + impl OffloadBuilder<(), (), ()> { + pub fn new() -> Self { + Self { + endpoint: (), + executor: (), + exporter: (), + channel_capacity: 10, + } + } + } + + impl Default for OffloadBuilder<(), (), ()> { + fn default() -> Self { + Self::new() + } + } + + impl OffloadBuilder<(), X, H> { + pub fn with_endpoint(self, endpoint: E) -> OffloadBuilder { + OffloadBuilder:: { + endpoint, + executor: self.executor, + exporter: self.exporter, + channel_capacity: self.channel_capacity, + } + } + } + + impl OffloadBuilder { + pub fn with_executor(self, executor: X) -> OffloadBuilder { + OffloadBuilder:: { + endpoint: self.endpoint, + executor, + exporter: self.exporter, + channel_capacity: self.channel_capacity, + } + } + } + + impl OffloadBuilder { + pub fn with_exporter(self, exporter: H) -> OffloadBuilder { + OffloadBuilder:: { + endpoint: self.endpoint, + executor: self.executor, + exporter, + channel_capacity: self.channel_capacity, + } + } + } + + impl OffloadBuilder { + pub fn with_channel_capacity(self, channel_capacity: usize) -> OffloadBuilder { + OffloadBuilder:: { + endpoint: self.endpoint, + executor: self.executor, + exporter: self.exporter, + channel_capacity, + } + } + } + + impl OffloadBuilder { + pub fn build(self) -> Offload { + Offload { + endpoint: self.endpoint, + executor: self.executor, + exporter: self.exporter, + channel_capacity: self.channel_capacity, + } + } + } + + impl Provider for Offload + where + E: Provider, + X: Executor + Send + 'static, + H: ExporterHandler + Send + 'static + Sync + Clone, + { + type Server = OffloadEndpoint<::Server, X, H>; + type Client = OffloadEndpoint<::Client, X, H>; + type Error = E::Error; + + fn start_server(self) -> Result { + Ok(OffloadEndpoint::new( + E::start_server(self.endpoint)?, + self.executor, + self.exporter, + self.channel_capacity, + )) + } + + fn start_client(self) -> Result { + Ok(OffloadEndpoint::new( + E::start_client(self.endpoint)?, + self.executor, + self.exporter, + self.channel_capacity, + )) + } + } +}