diff --git a/quic/s2n-quic-core/Cargo.toml b/quic/s2n-quic-core/Cargo.toml index c80609fdad..83132bb69d 100644 --- a/quic/s2n-quic-core/Cargo.toml +++ b/quic/s2n-quic-core/Cargo.toml @@ -13,7 +13,7 @@ exclude = ["corpus.tar.gz"] [features] default = ["alloc", "std"] alloc = ["atomic-waker", "bytes", "crossbeam-utils", "s2n-codec/alloc"] -std = ["alloc", "once_cell"] +std = ["alloc", "once_cell", "futures-channel", "futures"] testing = ["std", "generator", "s2n-codec/testing", "checked-counters", "insta", "futures-test"] generator = ["bolero-generator"] checked-counters = [] @@ -47,6 +47,8 @@ tracing = { version = "0.1", default-features = false, optional = true } zerocopy = { version = "0.8", features = ["derive"] } futures-test = { version = "0.3", optional = true } # For testing Waker interactions once_cell = { version = "1", optional = true } +futures-channel = { version = "0.3", default-features = false, optional=true, features = ["std", "alloc"]} +futures = { version = "0.3", default-features = false, optional=true, features = ["std", "alloc"]} [dev-dependencies] bolero = "0.13" diff --git a/quic/s2n-quic-core/src/crypto/tls.rs b/quic/s2n-quic-core/src/crypto/tls.rs index 3e93065e21..7aec40e27c 100644 --- a/quic/s2n-quic-core/src/crypto/tls.rs +++ b/quic/s2n-quic-core/src/crypto/tls.rs @@ -20,6 +20,9 @@ pub mod null; #[cfg(feature = "alloc")] pub mod slow_tls; +#[cfg(feature = "std")] +pub mod offload; + /// Holds all application parameters which are exchanged within the TLS handshake. #[derive(Debug)] pub struct ApplicationParameters<'a> { @@ -177,13 +180,13 @@ pub trait Context { /// is willing to buffer. fn receive_application(&mut self, max_len: Option) -> Option; - fn can_send_initial(&self) -> bool; + fn can_send_initial(&mut self) -> bool; fn send_initial(&mut self, transmission: Bytes); - fn can_send_handshake(&self) -> bool; + fn can_send_handshake(&mut self) -> bool; fn send_handshake(&mut self, transmission: Bytes); - fn can_send_application(&self) -> bool; + fn can_send_application(&mut self) -> bool; fn send_application(&mut self, transmission: Bytes); fn waker(&self) -> &core::task::Waker; diff --git a/quic/s2n-quic-core/src/crypto/tls/offload.rs b/quic/s2n-quic-core/src/crypto/tls/offload.rs new file mode 100644 index 0000000000..30f16e1174 --- /dev/null +++ b/quic/s2n-quic-core/src/crypto/tls/offload.rs @@ -0,0 +1,567 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +use crate::{ + application, + crypto::{ + tls::{self, NamedGroup, Session}, + CryptoSuite, + }, + transport, +}; +use alloc::{sync::Arc, task::Wake, vec, vec::Vec}; +use core::{ + any::Any, + future::Future, + pin::Pin, + task::{Context, Poll, Waker}, +}; +use futures::{prelude::Stream, task}; +use futures_channel::{ + mpsc::{UnboundedReceiver, UnboundedSender}, + oneshot::{Receiver, Sender}, +}; +use std::thread; + +type SessionProducer = ( + ::Session, + UnboundedSender::Session>>, +); +pub struct OffloadEndpoint { + new_session: UnboundedSender>, + _thread: thread::JoinHandle<()>, + inner: E, + remote_thread_waker: Waker, +} + +impl OffloadEndpoint { + pub fn new(inner: E) -> Self { + let (tx, mut rx) = futures_channel::mpsc::unbounded::>(); + + let handle = thread::spawn(move || { + let mut sessions = vec![]; + let waker = Waker::from(Arc::new(ThreadWaker(thread::current()))); + + loop { + let mut cx = Context::from_waker(&waker); + + // Add incoming sessions to queue + while let Poll::Ready(Some((new_session, tx))) = + Pin::new(&mut rx).poll_next(&mut cx) + { + sessions.push(( + new_session, + RemoteContext { + tx, + waker: waker.clone(), + receive_initial: AsyncRequest::empty(), + receive_handshake: AsyncRequest::empty(), + receive_application: AsyncRequest::empty(), + + can_send_initial: AsyncRequest::empty(), + can_send_handshake: AsyncRequest::empty(), + can_send_application: AsyncRequest::empty(), + }, + )) + } + + let mut next_sessions = vec![]; + + // Make progress on all stored sessions, prioritizing existing sessions over incoming ones + for (mut session, mut ctx) in sessions { + match session.poll(&mut ctx) { + Poll::Ready(res) => { + let _ = ctx.tx.unbounded_send(Request::Done(session, res)); + } + Poll::Pending => { + next_sessions.push((session, ctx)); + } + } + } + sessions = next_sessions; + + thread::park(); + } + }); + + Self { + inner, + remote_thread_waker: task::Waker::from(Arc::new(ThreadWaker(handle.thread().clone()))), + _thread: handle, + new_session: tx, + } + } +} + +struct ThreadWaker(thread::Thread); + +impl Wake for ThreadWaker { + fn wake(self: Arc) { + self.0.unpark(); + } +} + +impl tls::Endpoint for OffloadEndpoint { + type Session = OffloadSession<::Session>; + + fn new_server_session( + &mut self, + transport_parameters: &Params, + ) -> Self::Session { + OffloadSession::new( + self.inner.new_server_session(transport_parameters), + &mut self.new_session, + self.remote_thread_waker.clone(), + ) + } + + fn new_client_session( + &mut self, + transport_parameters: &Params, + server_name: application::ServerName, + ) -> Self::Session { + OffloadSession::new( + self.inner + .new_client_session(transport_parameters, server_name), + &mut self.new_session, + self.remote_thread_waker.clone(), + ) + } + + fn max_tag_length(&self) -> usize { + self.inner.max_tag_length() + } +} + +#[derive(Debug)] +pub struct OffloadSession { + // Inner is none while remote thread has the session + inner: Option, + is_poll_done: Option>, + pending_requests: UnboundedReceiver>, + waker: Waker, +} + +impl OffloadSession { + fn new( + inner: S, + new_session: &mut UnboundedSender<(S, UnboundedSender>)>, + remote_thread: Waker, + ) -> Self { + // Channel to pass requests from remote TLS thread to main thread + let (tx, rx) = futures_channel::mpsc::unbounded::>(); + + // Send the session to the TLS thread. It will pass it back when the handshake has finished. + let _ = new_session.unbounded_send((inner, tx)); + + Self { + pending_requests: rx, + waker: remote_thread, + is_poll_done: None, + inner: None, + } + } +} + +impl tls::Session for OffloadSession { + #[inline] + fn poll(&mut self, context: &mut W) -> Poll> + where + W: tls::Context, + { + if let Some(finished) = self.is_poll_done { + return Poll::Ready(finished); + } + // This will wake up the TLS remote thread + self.waker.wake_by_ref(); + + loop { + let mut cx = Context::from_waker(context.waker()); + + let req = match Pin::new(&mut self.pending_requests).poll_next(&mut cx) { + Poll::Ready(Some(request)) => request, + Poll::Ready(None) => { + return Poll::Ready(Err(crate::transport::Error::INTERNAL_ERROR + .with_reason("offloaded crypto session finished without sending Done"))) + } + Poll::Pending => break, + }; + + match req { + Request::HandshakeKeys(key, header_key) => { + context.on_handshake_keys(key, header_key)?; + } + Request::ZeroRttKeys(key, header_key, transport_parameters) => { + context.on_zero_rtt_keys( + key, + header_key, + tls::ApplicationParameters { + transport_parameters: &transport_parameters, + }, + )?; + } + Request::ClientParams(client_params, mut server_params) => context + .on_client_application_params( + tls::ApplicationParameters { + transport_parameters: &client_params, + }, + &mut server_params, + )?, + Request::OneRttKeys(key, header_key, transport_parameters) => { + context.on_one_rtt_keys( + key, + header_key, + tls::ApplicationParameters { + transport_parameters: &transport_parameters, + }, + )?; + } + Request::Done(session, res) => { + self.inner = Some(session); + self.is_poll_done = Some(res); + + return Poll::Ready(res); + } + Request::ServerName(server_name) => { + context.on_server_name(server_name)?; + } + Request::ApplicationProtocol(application_protocol) => { + context.on_application_protocol(application_protocol)?; + } + Request::HandshakeComplete => { + context.on_handshake_complete()?; + } + Request::CanSendInitial(sender) => { + let _ = sender.send(context.can_send_initial()); + } + Request::ReceiveInitial(max_len, sender) => { + let resp = context.receive_initial(max_len); + let _ = sender.send(resp); + } + Request::ReceiveApplication(max_len, sender) => { + let resp = context.receive_application(max_len); + let _ = sender.send(resp); + } + Request::ReceiveHandshake(max_len, sender) => { + let resp = context.receive_handshake(max_len); + if resp.is_some() { + // We need to wake up the s2n-quic endpoint after providing + // handshake packets to the TLS provider as there may now be + // handshake data that needs to be sent in response. + context.waker().wake_by_ref(); + } + let _ = sender.send(resp); + } + Request::CanSendHandshake(sender) => { + let _ = sender.send(context.can_send_handshake()); + } + Request::CanSendApplication(sender) => { + let _ = sender.send(context.can_send_application()); + } + Request::SendApplication(bytes) => { + context.send_application(bytes); + } + Request::SendHandshake(bytes) => { + context.send_handshake(bytes); + } + Request::SendInitial(bytes) => context.send_initial(bytes), + Request::KeyExchangeGroup(named_group) => { + context.on_key_exchange_group(named_group)?; + } + Request::TlsContext(ctx) => context.on_tls_context(ctx), + } + } + Poll::Pending + } +} + +impl CryptoSuite for OffloadSession { + type HandshakeKey = ::HandshakeKey; + type HandshakeHeaderKey = ::HandshakeHeaderKey; + type InitialKey = ::InitialKey; + type InitialHeaderKey = ::InitialHeaderKey; + type ZeroRttKey = ::ZeroRttKey; + type ZeroRttHeaderKey = ::ZeroRttHeaderKey; + type OneRttKey = ::OneRttKey; + type OneRttHeaderKey = ::OneRttHeaderKey; + type RetryKey = ::RetryKey; +} + +struct AsyncRequest { + rx: Option>, +} + +impl AsyncRequest { + fn empty() -> Self { + AsyncRequest { rx: None } + } + + fn poll_request( + &mut self, + cx: &mut core::task::Context<'_>, + issue: impl FnOnce(Sender), + ) -> Poll { + loop { + if let Some(mut receiver) = self.rx.as_mut() { + match Pin::new(&mut receiver).poll(cx) { + Poll::Ready(Ok(value)) => { + receiver.close(); + self.rx = None; + return Poll::Ready(value); + } + Poll::Ready(Err(_)) => { + // treat cancellation as reason to ask again. + // FIXME: this probably means that the parent thread is no longer interested + // in this connection and we should instead tear it down. + receiver.close(); + self.rx = None; + // loop around to next loop iteration + } + Poll::Pending => return Poll::Pending, + } + } else { + let (tx, rx) = futures_channel::oneshot::channel(); + self.rx = Some(rx); + issue(tx); + return Poll::Pending; + } + } + } +} + +/// Context used on the remote thread. This must delegate all methods via a channel to the calling +/// thread, using `Request` to send parameters (and optionally receive results). +struct RemoteContext { + tx: UnboundedSender>, + waker: Waker, + + receive_initial: AsyncRequest>, + receive_handshake: AsyncRequest>, + receive_application: AsyncRequest>, + + can_send_initial: AsyncRequest, + can_send_handshake: AsyncRequest, + can_send_application: AsyncRequest, +} + +impl tls::Context for RemoteContext { + fn on_client_application_params( + &mut self, + client_params: tls::ApplicationParameters, + server_params: &mut alloc::vec::Vec, + ) -> Result<(), crate::transport::Error> { + let _ = self.tx.unbounded_send(Request::ClientParams( + client_params.transport_parameters.to_vec(), + server_params.to_vec(), + )); + Ok(()) + } + + fn on_handshake_keys( + &mut self, + key: ::HandshakeKey, + header_key: ::HandshakeHeaderKey, + ) -> Result<(), crate::transport::Error> { + let _ = self + .tx + .unbounded_send(Request::HandshakeKeys(key, header_key)); + Ok(()) + } + + fn on_zero_rtt_keys( + &mut self, + key: ::ZeroRttKey, + header_key: ::ZeroRttHeaderKey, + application_parameters: tls::ApplicationParameters, + ) -> Result<(), crate::transport::Error> { + let _ = self.tx.unbounded_send(Request::ZeroRttKeys( + key, + header_key, + application_parameters.transport_parameters.to_vec(), + )); + Ok(()) + } + + fn on_one_rtt_keys( + &mut self, + key: ::OneRttKey, + header_key: ::OneRttHeaderKey, + application_parameters: tls::ApplicationParameters, + ) -> Result<(), crate::transport::Error> { + let _ = self.tx.unbounded_send(Request::OneRttKeys( + key, + header_key, + application_parameters.transport_parameters.to_vec(), + )); + Ok(()) + } + + fn on_server_name( + &mut self, + server_name: crate::application::ServerName, + ) -> Result<(), crate::transport::Error> { + let _ = self.tx.unbounded_send(Request::ServerName(server_name)); + Ok(()) + } + + fn on_application_protocol( + &mut self, + application_protocol: bytes::Bytes, + ) -> Result<(), crate::transport::Error> { + let _ = self + .tx + .unbounded_send(Request::ApplicationProtocol(application_protocol)); + Ok(()) + } + + fn on_key_exchange_group( + &mut self, + named_group: tls::NamedGroup, + ) -> Result<(), crate::transport::Error> { + let _ = self + .tx + .unbounded_send(Request::KeyExchangeGroup(named_group)); + Ok(()) + } + + fn on_handshake_complete(&mut self) -> Result<(), crate::transport::Error> { + let _ = self.tx.unbounded_send(Request::HandshakeComplete); + Ok(()) + } + + fn on_tls_context(&mut self, context: alloc::boxed::Box) { + let _ = self.tx.unbounded_send(Request::TlsContext(context)); + } + + fn on_tls_exporter_ready( + &mut self, + _session: &impl tls::TlsSession, + ) -> Result<(), crate::transport::Error> { + // FIXME: needs some form of async callback, or maybe never gets called during remote phase? + Ok(()) + } + + fn receive_initial(&mut self, max_len: Option) -> Option { + let mut cx = Context::from_waker(&self.waker); + if let Poll::Ready(resp) = self.receive_initial.poll_request(&mut cx, |tx| { + let _ = self.tx.unbounded_send(Request::ReceiveInitial(max_len, tx)); + }) { + resp + } else { + None + } + } + + fn receive_handshake(&mut self, max_len: Option) -> Option { + let mut cx = Context::from_waker(&self.waker); + if let Poll::Ready(resp) = self.receive_handshake.poll_request(&mut cx, |tx| { + let _ = self + .tx + .unbounded_send(Request::ReceiveHandshake(max_len, tx)); + }) { + resp + } else { + None + } + } + + fn receive_application(&mut self, max_len: Option) -> Option { + let mut cx = Context::from_waker(&self.waker); + if let Poll::Ready(resp) = self.receive_application.poll_request(&mut cx, |tx| { + let _ = self + .tx + .unbounded_send(Request::ReceiveApplication(max_len, tx)); + }) { + resp + } else { + None + } + } + + fn can_send_initial(&mut self) -> bool { + let mut cx = Context::from_waker(&self.waker); + if let Poll::Ready(resp) = self.can_send_initial.poll_request(&mut cx, |tx| { + let _ = self.tx.unbounded_send(Request::CanSendInitial(tx)); + }) { + resp + } else { + // FIXME: either async-ify, remove, or figure out what the Pending value should be. + false + } + } + + fn send_initial(&mut self, transmission: bytes::Bytes) { + let _ = self.tx.unbounded_send(Request::SendInitial(transmission)); + } + + fn can_send_handshake(&mut self) -> bool { + let mut cx = Context::from_waker(&self.waker); + if let Poll::Ready(resp) = self.can_send_handshake.poll_request(&mut cx, |tx| { + let _ = self.tx.unbounded_send(Request::CanSendHandshake(tx)); + }) { + resp + } else { + // FIXME: either async-ify, remove, or figure out what the Pending value should be. + false + } + } + + fn send_handshake(&mut self, transmission: bytes::Bytes) { + let _ = self.tx.unbounded_send(Request::SendHandshake(transmission)); + } + + fn can_send_application(&mut self) -> bool { + let mut cx = Context::from_waker(&self.waker); + if let Poll::Ready(resp) = self.can_send_application.poll_request(&mut cx, |tx| { + let _ = self.tx.unbounded_send(Request::CanSendApplication(tx)); + }) { + resp + } else { + // FIXME: either async-ify, remove, or figure out what the Pending value should be. + false + } + } + + fn send_application(&mut self, transmission: bytes::Bytes) { + let _ = self + .tx + .unbounded_send(Request::SendApplication(transmission)); + } + + fn waker(&self) -> &core::task::Waker { + &self.waker + } +} + +enum Request { + ClientParams(Vec, Vec), + HandshakeKeys( + ::HandshakeKey, + ::HandshakeHeaderKey, + ), + ZeroRttKeys( + ::ZeroRttKey, + ::ZeroRttHeaderKey, + Vec, + ), + OneRttKeys( + ::OneRttKey, + ::OneRttHeaderKey, + Vec, + ), + TlsContext(alloc::boxed::Box), + ServerName(crate::application::ServerName), + ApplicationProtocol(bytes::Bytes), + KeyExchangeGroup(NamedGroup), + HandshakeComplete, + + ReceiveInitial(Option, Sender>), + ReceiveApplication(Option, Sender>), + ReceiveHandshake(Option, Sender>), + CanSendInitial(Sender), + CanSendHandshake(Sender), + CanSendApplication(Sender), + SendApplication(bytes::Bytes), + SendHandshake(bytes::Bytes), + SendInitial(bytes::Bytes), + Done(S, Result<(), crate::transport::Error>), +} diff --git a/quic/s2n-quic-core/src/crypto/tls/slow_tls.rs b/quic/s2n-quic-core/src/crypto/tls/slow_tls.rs index 773d0befc8..647549753e 100644 --- a/quic/s2n-quic-core/src/crypto/tls/slow_tls.rs +++ b/quic/s2n-quic-core/src/crypto/tls/slow_tls.rs @@ -176,7 +176,7 @@ where self.0.receive_application(max_len) } - fn can_send_initial(&self) -> bool { + fn can_send_initial(&mut self) -> bool { self.0.can_send_initial() } @@ -184,7 +184,7 @@ where self.0.send_initial(transmission); } - fn can_send_handshake(&self) -> bool { + fn can_send_handshake(&mut self) -> bool { self.0.can_send_handshake() } @@ -192,7 +192,7 @@ where self.0.send_handshake(transmission); } - fn can_send_application(&self) -> bool { + fn can_send_application(&mut self) -> bool { self.0.can_send_application() } diff --git a/quic/s2n-quic-core/src/crypto/tls/testing.rs b/quic/s2n-quic-core/src/crypto/tls/testing.rs index 406b8cc5eb..61e884fae6 100644 --- a/quic/s2n-quic-core/src/crypto/tls/testing.rs +++ b/quic/s2n-quic-core/src/crypto/tls/testing.rs @@ -791,7 +791,7 @@ where self.application.rx(max_len) } - fn can_send_initial(&self) -> bool { + fn can_send_initial(&mut self) -> bool { true } @@ -800,7 +800,7 @@ where self.initial.tx(transmission) } - fn can_send_handshake(&self) -> bool { + fn can_send_handshake(&mut self) -> bool { self.handshake.crypto.is_some() } @@ -813,7 +813,7 @@ where self.handshake.tx(transmission) } - fn can_send_application(&self) -> bool { + fn can_send_application(&mut self) -> bool { self.application.crypto.is_some() } diff --git a/quic/s2n-quic-transport/src/space/session_context.rs b/quic/s2n-quic-transport/src/space/session_context.rs index 3dc9f50bfd..c716bb8ecf 100644 --- a/quic/s2n-quic-transport/src/space/session_context.rs +++ b/quic/s2n-quic-transport/src/space/session_context.rs @@ -647,7 +647,7 @@ impl .map(|bytes| bytes.freeze()) } - fn can_send_initial(&self) -> bool { + fn can_send_initial(&mut self) -> bool { self.initial .as_ref() .map(|space| space.crypto_stream.can_send()) @@ -663,7 +663,7 @@ impl .push(transmission); } - fn can_send_handshake(&self) -> bool { + fn can_send_handshake(&mut self) -> bool { self.handshake .as_ref() .map(|space| space.crypto_stream.can_send()) @@ -679,7 +679,7 @@ impl .push(transmission); } - fn can_send_application(&self) -> bool { + fn can_send_application(&mut self) -> bool { self.application .as_ref() .map(|space| space.crypto_stream.can_send()) diff --git a/quic/s2n-quic/Cargo.toml b/quic/s2n-quic/Cargo.toml index eae86e2850..2e0177c27d 100644 --- a/quic/s2n-quic/Cargo.toml +++ b/quic/s2n-quic/Cargo.toml @@ -61,6 +61,8 @@ unstable-congestion-controller = ["s2n-quic-core/unstable-congestion-controller" unstable-limits = ["s2n-quic-core/unstable-limits"] # The feature enables the close formatter provider unstable-provider-connection-close-formatter = [] +# This feature enables the use of the offloaded TLS feature +unstable-offload-tls = [] [dependencies] bytes = { version = "1", default-features = false } diff --git a/quic/s2n-quic/src/provider/tls.rs b/quic/s2n-quic/src/provider/tls.rs index 978fc26021..0b04a870ec 100644 --- a/quic/s2n-quic/src/provider/tls.rs +++ b/quic/s2n-quic/src/provider/tls.rs @@ -309,3 +309,25 @@ pub mod s2n_tls { } } } + +#[cfg(feature = "unstable-offload-tls")] +pub mod offload { + use super::Provider; + use s2n_quic_core::crypto::tls::offload::OffloadEndpoint; + + pub struct Offload(pub E); + + impl Provider for Offload { + type Server = OffloadEndpoint<::Server>; + type Client = OffloadEndpoint<::Client>; + type Error = E::Error; + + fn start_server(self) -> Result { + Ok(OffloadEndpoint::new(E::start_server(self.0)?)) + } + + fn start_client(self) -> Result { + Ok(OffloadEndpoint::new(E::start_client(self.0)?)) + } + } +} diff --git a/quic/s2n-quic/src/tests.rs b/quic/s2n-quic/src/tests.rs index ba12f6f1f1..1dc024f8f2 100644 --- a/quic/s2n-quic/src/tests.rs +++ b/quic/s2n-quic/src/tests.rs @@ -25,6 +25,7 @@ use std::{ mod recorder; mod connection_limits; +mod offload_tls; mod resumption; mod setup; mod slow_tls; diff --git a/quic/s2n-quic/src/tests/offload_tls.rs b/quic/s2n-quic/src/tests/offload_tls.rs new file mode 100644 index 0000000000..a40014db44 --- /dev/null +++ b/quic/s2n-quic/src/tests/offload_tls.rs @@ -0,0 +1,43 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +#[test] +#[cfg(feature = "unstable-offload-tls")] +fn offload_tls() { + use super::*; + use crate::provider::tls::{default, offload::Offload}; + use s2n_quic_core::crypto::tls::testing::certificates::{CERT_PEM, KEY_PEM}; + + let model = Model::default(); + + let server_endpoint = default::Server::builder() + .with_certificate(CERT_PEM, KEY_PEM) + .unwrap() + .build() + .unwrap(); + let client_endpoint = default::Client::builder() + .with_certificate(CERT_PEM) + .unwrap() + .build() + .unwrap(); + let server_endpoint = Offload(server_endpoint); + let client_endpoint = Offload(client_endpoint); + test(model, |handle| { + let server = Server::builder() + .with_io(handle.builder().build()?)? + .with_event(tracing_events())? + .with_tls(server_endpoint)? + .start()?; + + let client = Client::builder() + .with_io(handle.builder().build()?)? + .with_tls(client_endpoint)? + .with_event(tracing_events())? + .start()?; + let addr = start_server(server)?; + start_client(client, addr, Data::new(1000))?; + + Ok(addr) + }) + .unwrap(); +}