Skip to content

Commit

Permalink
[bindings] Hide ffi types + basic debug info (#3279)
Browse files Browse the repository at this point in the history
  • Loading branch information
lrstewart authored Apr 26, 2022
1 parent 8314a96 commit a74a467
Show file tree
Hide file tree
Showing 6 changed files with 235 additions and 62 deletions.
3 changes: 2 additions & 1 deletion bindings/rust/s2n-tls-tokio/examples/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ async fn run_client(trust_pem: &[u8], addr: &str) -> Result<(), Box<dyn Error>>

// Connect to the server.
let stream = TcpStream::connect(addr).await?;
client.connect("localhost", stream).await?;
let tls = client.connect("localhost", stream).await?;
println!("{:#?}", tls);

// TODO: echo

Expand Down
3 changes: 2 additions & 1 deletion bindings/rust/s2n-tls-tokio/examples/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ async fn run_server(cert_pem: &[u8], key_pem: &[u8], addr: &str) -> Result<(), B
// Wait for a client to connect.
let (stream, peer_addr) = listener.accept().await?;
println!("Connection from {:?}", peer_addr);
server.accept(stream).await?;
let tls = server.accept(stream).await?;
println!("{:#?}", tls);

// TODO: echo
}
Expand Down
24 changes: 16 additions & 8 deletions bindings/rust/s2n-tls-tokio/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
use errno::{set_errno, Errno};
use s2n_tls::raw::{
config::Config,
connection::Connection,
connection::{CallbackResult, Connection, Mode},
error::Error,
ffi::{s2n_mode, s2n_status_code},
};
use std::{
fmt,
future::Future,
os::raw::{c_int, c_void},
pin::Pin,
Expand All @@ -29,7 +29,7 @@ impl TlsAcceptor {
where
S: AsyncRead + AsyncWrite + Unpin,
{
TlsStream::open(self.config.clone(), s2n_mode::SERVER, stream).await
TlsStream::open(self.config.clone(), Mode::Server, stream).await
}
}

Expand All @@ -46,7 +46,7 @@ impl TlsConnector {
where
S: AsyncRead + AsyncWrite + Unpin,
{
TlsStream::open(self.config.clone(), s2n_mode::CLIENT, stream).await
TlsStream::open(self.config.clone(), Mode::Client, stream).await
}
}

Expand All @@ -69,15 +69,15 @@ where
}

pub struct TlsStream<S> {
conn: Connection,
pub conn: Connection,
stream: S,
}

impl<S> TlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
async fn open(config: Config, mode: s2n_mode::Type, stream: S) -> Result<Self, Error> {
async fn open(config: Config, mode: Mode, stream: S) -> Result<Self, Error> {
let mut conn = Connection::new(mode);
conn.set_config(config)?;

Expand Down Expand Up @@ -126,9 +126,9 @@ where
Poll::Ready(Ok(len)) => len as c_int,
Poll::Pending => {
set_errno(Errno(libc::EWOULDBLOCK));
s2n_status_code::FAILURE
CallbackResult::Failure.into()
}
_ => s2n_status_code::FAILURE,
_ => CallbackResult::Failure.into(),
}
}

Expand All @@ -148,3 +148,11 @@ where
})
}
}

impl<S> fmt::Debug for TlsStream<S> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("TlsStream")
.field("connection", &self.conn)
.finish()
}
}
72 changes: 45 additions & 27 deletions bindings/rust/s2n-tls-tokio/tests/handshake.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use s2n_tls::raw::{config::Config, error::Error, security::DEFAULT_TLS13};
use s2n_tls_tokio::{TlsAcceptor, TlsConnector};
use s2n_tls::raw::{config::Config, connection::Version, error::Error, security::DEFAULT_TLS13};
use s2n_tls_tokio::{TlsAcceptor, TlsConnector, TlsStream};
use tokio::net::{TcpListener, TcpStream};

/// NOTE: this certificate and key are used for testing purposes only!
Expand All @@ -15,37 +15,55 @@ pub static KEY_PEM: &[u8] = include_bytes!(concat!(
"/examples/certs/key.pem"
));

async fn run_client(stream: TcpStream) -> Result<(), Error> {
let mut config = Config::builder();
config.set_security_policy(&DEFAULT_TLS13)?;
config.trust_pem(CERT_PEM)?;
unsafe {
config.disable_x509_verification()?;
}

let client = TlsConnector::new(config.build()?);
client.connect("localhost", stream).await?;
Ok(())
async fn get_streams() -> Result<(TcpStream, TcpStream), tokio::io::Error> {
let localhost = "127.0.0.1".to_owned();
let listener = TcpListener::bind(format!("{}:0", localhost)).await?;
let addr = listener.local_addr()?;
let client_stream = TcpStream::connect(&addr).await?;
let (server_stream, _) = listener.accept().await?;
Ok((server_stream, client_stream))
}

async fn run_server(stream: TcpStream) -> Result<(), Error> {
let mut config = Config::builder();
config.set_security_policy(&DEFAULT_TLS13)?;
config.load_pem(CERT_PEM, KEY_PEM)?;
async fn run_client(config: Config, stream: TcpStream) -> Result<TlsStream<TcpStream>, Error> {
let client = TlsConnector::new(config);
client.connect("localhost", stream).await
}

let server = TlsAcceptor::new(config.build()?);
server.accept(stream).await?;
Ok(())
async fn run_server(config: Config, stream: TcpStream) -> Result<TlsStream<TcpStream>, Error> {
let server = TlsAcceptor::new(config);
server.accept(stream).await
}

#[tokio::test]
async fn handshake_basic() -> Result<(), Error> {
let localhost = "127.0.0.1".to_owned();
let listener = TcpListener::bind(format!("{}:0", localhost)).await.unwrap();
let addr = listener.local_addr().unwrap();
let client_stream = TcpStream::connect(&addr).await.unwrap();
let (server_stream, _) = listener.accept().await.unwrap();
async fn handshake_basic() -> Result<(), Box<dyn std::error::Error>> {
let (server_stream, client_stream) = get_streams().await?;

let mut client_config = Config::builder();
client_config.set_security_policy(&DEFAULT_TLS13)?;
client_config.trust_pem(CERT_PEM)?;
unsafe {
client_config.disable_x509_verification()?;
}
let client_config = client_config.build()?;

let mut server_config = Config::builder();
server_config.set_security_policy(&DEFAULT_TLS13)?;
server_config.load_pem(CERT_PEM, KEY_PEM)?;
let server_config = server_config.build()?;

let (client_result, server_result) = tokio::try_join!(
run_client(client_config, client_stream),
run_server(server_config, server_stream)
)?;

for tls in [client_result, server_result] {
// Security policy ensures TLS1.3.
assert_eq!(tls.conn.actual_protocol_version()?, Version::TLS13);
// Handshake types may change, but will at least be negotiated.
assert!(tls.conn.handshake_type()?.contains("NEGOTIATED"));
// Cipher suite may change, so just makes sure we can retrieve it.
assert!(tls.conn.cipher_suite().is_ok());
}

tokio::try_join!(run_client(client_stream), run_server(server_stream))?;
Ok(())
}
118 changes: 108 additions & 10 deletions bindings/rust/s2n-tls/src/raw/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::raw::{
security,
};
use core::{
convert::TryInto,
convert::{TryFrom, TryInto},
fmt,
ptr::NonNull,
task::{Poll, Waker},
Expand All @@ -18,17 +18,91 @@ use libc::c_void;
use s2n_tls_sys::*;
use std::{ffi::CStr, mem};

pub use s2n_tls_sys::s2n_mode;
use s2n_tls_sys::s2n_mode;

macro_rules! static_const_str {
($c_chars:expr) => {
unsafe { CStr::from_ptr($c_chars) }
.to_str()
.map_err(|_| Error::InvalidInput)
};
}

#[derive(Debug, PartialEq)]
pub enum CallbackResult {
Success,
Failure,
}

impl From<CallbackResult> for s2n_status_code::Type {
fn from(input: CallbackResult) -> s2n_status_code::Type {
match input {
CallbackResult::Success => s2n_status_code::SUCCESS,
CallbackResult::Failure => s2n_status_code::FAILURE,
}
}
}

#[derive(Debug, PartialEq)]
pub enum Mode {
Server,
Client,
}

impl From<Mode> for s2n_mode::Type {
fn from(input: Mode) -> s2n_mode::Type {
match input {
Mode::Server => s2n_mode::SERVER,
Mode::Client => s2n_mode::CLIENT,
}
}
}

#[non_exhaustive]
#[derive(Debug, PartialEq)]
pub enum Version {
SSLV2,
SSLV3,
TLS10,
TLS11,
TLS12,
TLS13,
}

impl TryFrom<s2n_tls_version::Type> for Version {
type Error = Error;

fn try_from(input: s2n_tls_version::Type) -> Result<Self, Self::Error> {
let version = match input {
s2n_tls_version::SSLV2 => Self::SSLV2,
s2n_tls_version::SSLV3 => Self::SSLV3,
s2n_tls_version::TLS10 => Self::TLS10,
s2n_tls_version::TLS11 => Self::TLS11,
s2n_tls_version::TLS12 => Self::TLS12,
s2n_tls_version::TLS13 => Self::TLS13,
_ => return Err(Error::InvalidInput),
};
Ok(version)
}
}

pub struct Connection {
connection: NonNull<s2n_connection>,
}

impl fmt::Debug for Connection {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("Connection")
// TODO add paths
.finish()
let mut debug = f.debug_struct("Connection");
if let Ok(handshake) = self.handshake_type() {
debug.field("handshake_type", &handshake);
}
if let Ok(cipher) = self.cipher_suite() {
debug.field("cipher_suite", &cipher);
}
if let Ok(version) = self.actual_protocol_version() {
debug.field("actual_protocol_version", &version);
}
debug.finish_non_exhaustive()
}
}

Expand All @@ -38,9 +112,10 @@ impl fmt::Debug for Connection {
unsafe impl Send for Connection {}

impl Connection {
pub fn new(mode: s2n_mode::Type) -> Self {
pub fn new(mode: Mode) -> Self {
crate::raw::init::init();
let connection = unsafe { s2n_connection_new(mode).into_result() }.unwrap();

let connection = unsafe { s2n_connection_new(mode.into()).into_result() }.unwrap();

unsafe {
debug_assert! {
Expand All @@ -62,11 +137,11 @@ impl Connection {
}

pub fn new_client() -> Self {
Self::new(s2n_mode::CLIENT)
Self::new(Mode::Client)
}

pub fn new_server() -> Self {
Self::new(s2n_mode::SERVER)
Self::new(Mode::Server)
}

/// # Safety
Expand Down Expand Up @@ -289,7 +364,7 @@ impl Connection {

match unsafe { s2n_negotiate(self.connection.as_ptr(), &mut blocked).into_result() } {
Ok(_) => Ok(self).into(),
Err(err) if err.kind() == s2n_error_type::BLOCKED => Poll::Pending,
Err(err) if err.is_retryable() => Poll::Pending,
Err(err) => Err(err).into(),
}
}
Expand Down Expand Up @@ -389,6 +464,29 @@ impl Connection {
};
Ok(())
}

pub fn actual_protocol_version(&self) -> Result<Version, Error> {
let version = unsafe {
s2n_connection_get_actual_protocol_version(self.connection.as_ptr()).into_result()?
};
version.try_into()
}

pub fn handshake_type(&self) -> Result<&str, Error> {
let handshake = unsafe {
s2n_connection_get_handshake_type_name(self.connection.as_ptr()).into_result()?
};
// The strings returned by s2n_connection_get_handshake_type_name
// are static and immutable after they are first calculated
static_const_str!(handshake)
}

pub fn cipher_suite(&self) -> Result<&str, Error> {
let cipher = unsafe { s2n_connection_get_cipher(self.connection.as_ptr()).into_result()? };
// The strings returned by s2n_connection_get_cipher
// are static and immutable since they are const fields on static const structs
static_const_str!(cipher)
}
}

#[derive(Default)]
Expand Down
Loading

0 comments on commit a74a467

Please sign in to comment.