Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(bindings): Associate an application context with a Connection #4563

Merged
merged 5 commits into from
May 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 109 additions & 1 deletion bindings/rust/s2n-tls/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use core::{
};
use libc::c_void;
use s2n_tls_sys::*;
use std::ffi::CStr;
use std::{any::Any, ffi::CStr};

mod builder;
pub use builder::*;
Expand Down Expand Up @@ -1049,6 +1049,45 @@ impl Connection {
pub fn resumed(&self) -> bool {
unsafe { s2n_connection_is_session_resumed(self.connection.as_ptr()) == 1 }
}

/// Associates an arbitrary application context with the Connection to be later retrieved via
/// the [`Self::application_context()`] and [`Self::application_context_mut()`] APIs.
///
/// This API will override an existing application context set on the Connection.
pub fn set_application_context<T: Send + Sync + 'static>(&mut self, app_context: T) {
self.context_mut().app_context = Some(Box::new(app_context));
}

/// Retrieves a reference to the application context associated with the Connection.
///
/// If an application context hasn't already been set on the Connection, or if the set
/// application context isn't of type T, None will be returned.
///
/// To set a context on the connection, use [`Self::set_application_context()`]. To retrieve a
/// mutable reference to the context, use [`Self::application_context_mut()`].
pub fn application_context<T: Send + Sync + 'static>(&self) -> Option<&T> {
match self.context().app_context.as_ref() {
None => None,
// The Any trait keeps track of the application context's type. downcast_ref() returns
// Some only if the correct type is provided:
// https://doc.rust-lang.org/std/any/trait.Any.html#method.downcast_ref
Some(app_context) => app_context.downcast_ref::<T>(),
}
}

/// Retrieves a mutable reference to the application context associated with the Connection.
///
/// If an application context hasn't already been set on the Connection, or if the set
/// application context isn't of type T, None will be returned.
///
/// To set a context on the connection, use [`Self::set_application_context()`]. To retrieve an
/// immutable reference to the context, use [`Self::application_context()`].
pub fn application_context_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
match self.context_mut().app_context.as_mut() {
None => None,
Some(app_context) => app_context.downcast_mut::<T>(),
}
}
}

struct Context {
Expand All @@ -1057,6 +1096,7 @@ struct Context {
async_callback: Option<AsyncCallback>,
verify_host_callback: Option<Box<dyn VerifyHostNameCallback>>,
connection_initialized: bool,
app_context: Option<Box<dyn Any + Send + Sync>>,
}

impl Context {
Expand All @@ -1067,6 +1107,7 @@ impl Context {
async_callback: None,
verify_host_callback: None,
connection_initialized: false,
app_context: None,
}
}
}
Expand Down Expand Up @@ -1181,4 +1222,71 @@ mod tests {
fn assert_sync<T: 'static + Sync>() {}
assert_sync::<Context>();
}

/// Test that an application context can be set and retrieved.
#[test]
fn test_app_context_set_and_retrieve() {
let mut connection = Connection::new_server();

// Before a context is set, None is returned.
assert!(connection.application_context::<u32>().is_none());

let test_value: u32 = 1142;
connection.set_application_context(test_value);

// After a context is set, the application data is returned.
assert_eq!(*connection.application_context::<u32>().unwrap(), 1142);
}

/// Test that an application context can be modified.
#[test]
fn test_app_context_modify() {
let test_value: u64 = 0;

let mut connection = Connection::new_server();
connection.set_application_context(test_value);

let context_value = connection.application_context_mut::<u64>().unwrap();
*context_value += 1;

assert_eq!(*connection.application_context::<u64>().unwrap(), 1);
}

/// Test that an application context can be overridden.
#[test]
fn test_app_context_override() {
let mut connection = Connection::new_server();

let test_value: u16 = 1142;
connection.set_application_context(test_value);

assert_eq!(*connection.application_context::<u16>().unwrap(), 1142);

// Override the context with a new value.
let test_value: u16 = 10;
connection.set_application_context(test_value);

assert_eq!(*connection.application_context::<u16>().unwrap(), 10);

// Override the context with a new type.
let test_value: i16 = -20;
connection.set_application_context(test_value);

assert_eq!(*connection.application_context::<i16>().unwrap(), -20);
}

/// Test that a context of another type can't be retrieved.
#[test]
fn test_app_context_invalid_type() {
let mut connection = Connection::new_server();

let test_value: u32 = 0;
connection.set_application_context(test_value);

// A context type that wasn't set shouldn't be returned.
assert!(connection.application_context::<i16>().is_none());

// Retrieving the correct type succeeds.
assert!(connection.application_context::<u32>().is_some());
}
}
62 changes: 61 additions & 1 deletion bindings/rust/s2n-tls/src/testing/s2n_tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ impl<'a, T: 'a + Context> Callback<'a, T> {
#[cfg(test)]
mod tests {
use crate::{
callbacks::{ClientHelloCallback, ConnectionFuture},
callbacks::{ClientHelloCallback, ConnectionFuture, ConnectionFutureResult},
enums::ClientAuthType,
error::ErrorType,
testing::{client_hello::*, s2n_tls::*, *},
Expand Down Expand Up @@ -970,4 +970,64 @@ mod tests {
init::init();
assert!(init::fips_mode().unwrap().is_enabled());
}

/// Test that a context can be used from within a callback.
#[test]
fn test_app_context_callback() {
struct TestApplicationContext {
invoked_count: u32,
}

struct TestClientHelloHandler {}

impl ClientHelloCallback for TestClientHelloHandler {
fn on_client_hello(
&self,
connection: &mut connection::Connection,
) -> ConnectionFutureResult {
let app_context = connection
.application_context_mut::<TestApplicationContext>()
.unwrap();
app_context.invoked_count += 1;
Ok(None)
}
}

let config = {
let keypair = CertKeyPair::default();
let mut builder = Builder::new();
builder
.set_verify_host_callback(InsecureAcceptAllCertificatesHandler {})
.unwrap();
builder
.set_client_hello_callback(TestClientHelloHandler {})
.unwrap();
builder.load_pem(keypair.cert, keypair.key).unwrap();
builder.trust_pem(keypair.cert).unwrap();
builder.build().unwrap()
};

let mut pair = tls_pair(config);
pair.server
.0
.connection_mut()
.set_waker(Some(&noop_waker()))
.unwrap();

let context = TestApplicationContext { invoked_count: 0 };
pair.server
.0
.connection_mut()
.set_application_context(context);

assert!(poll_tls_pair_result(&mut pair).is_ok());

let context = pair
.server
.0
.connection()
.application_context::<TestApplicationContext>()
.unwrap();
assert_eq!(context.invoked_count, 1);
}
}
Loading