Skip to content
Merged
14 changes: 10 additions & 4 deletions quic/s2n-quic-core/src/crypto/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,6 @@ pub mod testing;
/// Holds all application parameters which are exchanged within the TLS handshake.
#[derive(Debug)]
pub struct ApplicationParameters<'a> {
/// The negotiated Application Layer Protocol
pub application_protocol: &'a [u8],
/// Server Name Indication
pub server_name: Option<crate::application::ServerName>,
/// Encoded transport parameters
pub transport_parameters: &'a [u8],
}
Expand Down Expand Up @@ -53,6 +49,16 @@ pub trait Context<Crypto: CryptoSuite> {
application_parameters: ApplicationParameters,
) -> Result<(), transport::Error>;

fn on_server_name(
&mut self,
server_name: crate::application::ServerName,
) -> Result<(), transport::Error>;

fn on_application_protocol(
&mut self,
application_protocol: Bytes,
) -> Result<(), transport::Error>;

//= https://www.rfc-editor.org/rfc/rfc9001#section-4.1.1
//# The TLS handshake is considered complete when the
//# TLS stack has reported that the handshake is complete. This happens
Expand Down
90 changes: 60 additions & 30 deletions quic/s2n-quic-core/src/crypto/tls/testing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
header_crypto::{LONG_HEADER_MASK, SHORT_HEADER_MASK},
tls, CryptoSuite, HeaderKey, Key,
},
transport,
endpoint, transport,
};
use bytes::Bytes;
use core::{
Expand Down Expand Up @@ -113,12 +113,12 @@ impl<S: tls::Session, C: tls::Session> Pair<S, C> {
use crate::crypto::InitialKey;

let server = server_endpoint.new_server_session(&TEST_SERVER_TRANSPORT_PARAMS);
let mut server_context = Context::default();
let mut server_context = Context::new(endpoint::Type::Server);
server_context.initial.crypto = Some(S::InitialKey::new_server(server_name.as_bytes()));

let client =
client_endpoint.new_client_session(&TEST_CLIENT_TRANSPORT_PARAMS, server_name.clone());
let mut client_context = Context::default();
let mut client_context = Context::new(endpoint::Type::Client);
client_context.initial.crypto = Some(C::InitialKey::new_client(server_name.as_bytes()));

Self {
Expand Down Expand Up @@ -215,8 +215,14 @@ impl<S: tls::Session, C: tls::Session> Pair<S, C> {
TEST_CLIENT_TRANSPORT_PARAMS,
"server did not receive the client transport parameters"
);
// TODO fix sni bug in s2n-quic-rustls
// assert_eq!(self.client.1.server_name.as_ref().expect("missing SNI on client"), &self.server_name[..]);
assert_eq!(
self.client
.1
.server_name
.as_ref()
.expect("missing SNI on client"),
&self.server_name[..]
);
assert_eq!(
self.server
.1
Expand All @@ -240,26 +246,10 @@ pub struct Context<C: CryptoSuite> {
pub server_name: Option<Bytes>,
pub application_protocol: Option<Bytes>,
pub transport_parameters: Option<Bytes>,
endpoint: endpoint::Type,
waker: Waker,
}

impl<C: CryptoSuite> Default for Context<C> {
fn default() -> Self {
let (waker, _wake_counter) = new_count_waker();
Self {
initial: Space::default(),
handshake: Space::default(),
application: Space::default(),
zero_rtt_crypto: None,
handshake_complete: false,
server_name: None,
application_protocol: None,
transport_parameters: None,
waker,
}
}
}

impl<C: CryptoSuite> fmt::Debug for Context<C> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("Context")
Expand All @@ -276,6 +266,22 @@ impl<C: CryptoSuite> fmt::Debug for Context<C> {
}

impl<C: CryptoSuite> Context<C> {
fn new(endpoint: endpoint::Type) -> Self {
let (waker, _wake_counter) = new_count_waker();
Self {
initial: Space::default(),
handshake: Space::default(),
application: Space::default(),
zero_rtt_crypto: None,
handshake_complete: false,
server_name: None,
application_protocol: None,
transport_parameters: None,
endpoint,
waker,
}
}

/// Transfers incoming and outgoing buffers between two contexts
pub fn transfer<O: CryptoSuite>(&mut self, other: &mut Context<O>) {
self.initial.transfer(&mut other.initial);
Expand All @@ -288,11 +294,10 @@ impl<C: CryptoSuite> Context<C> {
self.assert_done();
other.assert_done();

// TODO fix sni bug in s2n-quic-rustls
//assert_eq!(
// self.sni, other.sni,
// "sni is not consistent between endpoints"
//);
assert_eq!(
self.server_name, other.server_name,
"sni is not consistent between endpoints"
);
assert_eq!(
self.application_protocol, other.application_protocol,
"application_protocol is not consistent between endpoints"
Expand Down Expand Up @@ -322,13 +327,16 @@ impl<C: CryptoSuite> Context<C> {
}

fn on_application_params(&mut self, params: tls::ApplicationParameters) {
self.application_protocol = Some(Bytes::copy_from_slice(params.application_protocol));
self.server_name = params.server_name.map(|sni| sni.into_bytes());
self.transport_parameters = Some(Bytes::copy_from_slice(params.transport_parameters));
}

fn log(&self, event: &str) {
eprintln!("{}: {}", core::any::type_name::<C>(), event);
eprintln!(
"{:?}: {}: {}",
self.endpoint,
core::any::type_name::<C>(),
event,
);
}
}

Expand Down Expand Up @@ -507,11 +515,33 @@ impl<C: CryptoSuite> tls::Context<C> for Context<C> {
Ok(())
}

fn on_server_name(
&mut self,
server_name: crate::application::ServerName,
) -> Result<(), transport::Error> {
self.log("server name");
self.server_name = Some(server_name.into_bytes());
Ok(())
}

fn on_application_protocol(
&mut self,
application_protocol: Bytes,
) -> Result<(), transport::Error> {
self.log("application protocol");
self.application_protocol = Some(application_protocol);
Ok(())
}

fn on_handshake_complete(&mut self) -> Result<(), transport::Error> {
assert!(
!self.handshake_complete,
"handshake complete called multiple times"
);
assert!(
!self.application_protocol.as_ref().unwrap().is_empty(),
"application_protocol is empty at handshake complete"
);
self.handshake_complete = true;
self.log("handshake complete");
Ok(())
Expand Down
1 change: 1 addition & 0 deletions quic/s2n-quic-rustls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ license = "Apache-2.0"
exclude = ["corpus.tar.gz"]

[dependencies]
bytes = { version = "1", default-features = false }
rustls = { version = "0.20", features = ["quic"] }
rustls-pemfile = "0.3"
s2n-codec = { version = "=0.1.0", path = "../../common/s2n-codec", default-features = false }
Expand Down
6 changes: 3 additions & 3 deletions quic/s2n-quic-rustls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,18 @@ impl tls::Endpoint for Client {
//# Endpoints MUST send the quic_transport_parameters extension;
let transport_parameters = encode_transport_parameters(transport_parameters);

let server_name =
let rustls_server_name =
rustls::ServerName::try_from(server_name.as_ref()).expect("invalid server name");

let session = rustls::ClientConnection::new_quic(
self.config.clone(),
crate::QUIC_VERSION,
server_name,
rustls_server_name,
transport_parameters,
)
.expect("could not create rustls client session");

Session::new(session.into())
Session::new(session.into(), Some(server_name))
}

fn max_tag_length(&self) -> usize {
Expand Down
2 changes: 1 addition & 1 deletion quic/s2n-quic-rustls/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl tls::Endpoint for Server {
)
.expect("could not create rustls server session");

Session::new(session.into())
Session::new(session.into(), None)
}

fn new_client_session<Params: EncoderValue>(
Expand Down
Loading