Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bindings] Hide ffi types + basic debug info #3279

Merged
merged 2 commits into from
Apr 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion bindings/rust/s2n-tls-tokio/examples/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ async fn run_client(trust_pem: &[u8], addr: &str) -> Result<(), Box<dyn Error>>

// Connect to the server.
let stream = TcpStream::connect(addr).await?;
client.connect("localhost", stream).await?;
let tls = client.connect("localhost", stream).await?;
println!("{:#?}", tls);

// TODO: echo

Expand Down
3 changes: 2 additions & 1 deletion bindings/rust/s2n-tls-tokio/examples/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ async fn run_server(cert_pem: &[u8], key_pem: &[u8], addr: &str) -> Result<(), B
// Wait for a client to connect.
let (stream, peer_addr) = listener.accept().await?;
println!("Connection from {:?}", peer_addr);
server.accept(stream).await?;
let tls = server.accept(stream).await?;
println!("{:#?}", tls);

// TODO: echo
}
Expand Down
24 changes: 16 additions & 8 deletions bindings/rust/s2n-tls-tokio/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
use errno::{set_errno, Errno};
use s2n_tls::raw::{
config::Config,
connection::Connection,
connection::{CallbackResult, Connection, Mode},
error::Error,
ffi::{s2n_mode, s2n_status_code},
};
use std::{
fmt,
future::Future,
os::raw::{c_int, c_void},
pin::Pin,
Expand All @@ -29,7 +29,7 @@ impl TlsAcceptor {
where
S: AsyncRead + AsyncWrite + Unpin,
{
TlsStream::open(self.config.clone(), s2n_mode::SERVER, stream).await
TlsStream::open(self.config.clone(), Mode::Server, stream).await
}
}

Expand All @@ -46,7 +46,7 @@ impl TlsConnector {
where
S: AsyncRead + AsyncWrite + Unpin,
{
TlsStream::open(self.config.clone(), s2n_mode::CLIENT, stream).await
TlsStream::open(self.config.clone(), Mode::Client, stream).await
}
}

Expand All @@ -69,15 +69,15 @@ where
}

pub struct TlsStream<S> {
conn: Connection,
pub conn: Connection,
stream: S,
}

impl<S> TlsStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
{
async fn open(config: Config, mode: s2n_mode::Type, stream: S) -> Result<Self, Error> {
async fn open(config: Config, mode: Mode, stream: S) -> Result<Self, Error> {
let mut conn = Connection::new(mode);
conn.set_config(config)?;

Expand Down Expand Up @@ -126,9 +126,9 @@ where
Poll::Ready(Ok(len)) => len as c_int,
Poll::Pending => {
set_errno(Errno(libc::EWOULDBLOCK));
s2n_status_code::FAILURE
CallbackResult::Failure.into()
Copy link
Contributor

@toidiu toidiu Apr 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just an observation: this seems alittle odd since we dont use the Success type and instead return len as c_int. Also it would be nice to have poll_io return CallbackResult here and have the downstream function handle the conversion.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also wondering if CallbackResult should instead be called StatusCode so that its reusable for non callback types

Copy link
Contributor Author

@lrstewart lrstewart Apr 26, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not trying to solve callbacks here-- that's for later ;) I'm just trying to hide the C types.

just an observation: this seems alittle odd since we dont use the Success type and instead return len as c_int

The interface of some s2n-tls callbacks is a little odd :( We can't use the Success value because success for this callback means a valid length, not a success status code.

it would be nice to have poll_io return CallbackResult here and have the downstream function handle the conversion.

There isn't really a downstream function here to handle the conversion-- this is a callback that will return to C code that doesn't understand Rust enums.

also wondering if CallbackResult should instead be called StatusCode so that its reusable for non callback types

I choose "CallbackResult" because a customer's only interaction with this enum should be writing callbacks, since we handle the status codes returned by s2n-tls APIs. And I would love to remove that interaction too xD

}
_ => s2n_status_code::FAILURE,
_ => CallbackResult::Failure.into(),
}
}

Expand All @@ -148,3 +148,11 @@ where
})
}
}

impl<S> fmt::Debug for TlsStream<S> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("TlsStream")
.field("connection", &self.conn)
.finish()
}
}
72 changes: 45 additions & 27 deletions bindings/rust/s2n-tls-tokio/tests/handshake.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use s2n_tls::raw::{config::Config, error::Error, security::DEFAULT_TLS13};
use s2n_tls_tokio::{TlsAcceptor, TlsConnector};
use s2n_tls::raw::{config::Config, connection::Version, error::Error, security::DEFAULT_TLS13};
use s2n_tls_tokio::{TlsAcceptor, TlsConnector, TlsStream};
use tokio::net::{TcpListener, TcpStream};

/// NOTE: this certificate and key are used for testing purposes only!
Expand All @@ -15,37 +15,55 @@ pub static KEY_PEM: &[u8] = include_bytes!(concat!(
"/examples/certs/key.pem"
));

async fn run_client(stream: TcpStream) -> Result<(), Error> {
let mut config = Config::builder();
config.set_security_policy(&DEFAULT_TLS13)?;
config.trust_pem(CERT_PEM)?;
unsafe {
config.disable_x509_verification()?;
}

let client = TlsConnector::new(config.build()?);
client.connect("localhost", stream).await?;
Ok(())
async fn get_streams() -> Result<(TcpStream, TcpStream), tokio::io::Error> {
let localhost = "127.0.0.1".to_owned();
let listener = TcpListener::bind(format!("{}:0", localhost)).await?;
let addr = listener.local_addr()?;
let client_stream = TcpStream::connect(&addr).await?;
let (server_stream, _) = listener.accept().await?;
Ok((server_stream, client_stream))
}

async fn run_server(stream: TcpStream) -> Result<(), Error> {
let mut config = Config::builder();
config.set_security_policy(&DEFAULT_TLS13)?;
config.load_pem(CERT_PEM, KEY_PEM)?;
async fn run_client(config: Config, stream: TcpStream) -> Result<TlsStream<TcpStream>, Error> {
let client = TlsConnector::new(config);
client.connect("localhost", stream).await
}

let server = TlsAcceptor::new(config.build()?);
server.accept(stream).await?;
Ok(())
async fn run_server(config: Config, stream: TcpStream) -> Result<TlsStream<TcpStream>, Error> {
let server = TlsAcceptor::new(config);
server.accept(stream).await
}

#[tokio::test]
async fn handshake_basic() -> Result<(), Error> {
let localhost = "127.0.0.1".to_owned();
let listener = TcpListener::bind(format!("{}:0", localhost)).await.unwrap();
let addr = listener.local_addr().unwrap();
let client_stream = TcpStream::connect(&addr).await.unwrap();
let (server_stream, _) = listener.accept().await.unwrap();
async fn handshake_basic() -> Result<(), Box<dyn std::error::Error>> {
let (server_stream, client_stream) = get_streams().await?;

let mut client_config = Config::builder();
client_config.set_security_policy(&DEFAULT_TLS13)?;
client_config.trust_pem(CERT_PEM)?;
unsafe {
client_config.disable_x509_verification()?;
}
let client_config = client_config.build()?;

let mut server_config = Config::builder();
server_config.set_security_policy(&DEFAULT_TLS13)?;
server_config.load_pem(CERT_PEM, KEY_PEM)?;
let server_config = server_config.build()?;

let (client_result, server_result) = tokio::try_join!(
run_client(client_config, client_stream),
run_server(server_config, server_stream)
)?;

for tls in [client_result, server_result] {
// Security policy ensures TLS1.3.
assert_eq!(tls.conn.actual_protocol_version()?, Version::TLS13);
// Handshake types may change, but will at least be negotiated.
assert!(tls.conn.handshake_type()?.contains("NEGOTIATED"));
// Cipher suite may change, so just makes sure we can retrieve it.
assert!(tls.conn.cipher_suite().is_ok());
}

tokio::try_join!(run_client(client_stream), run_server(server_stream))?;
Ok(())
}
118 changes: 108 additions & 10 deletions bindings/rust/s2n-tls/src/raw/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::raw::{
security,
};
use core::{
convert::TryInto,
convert::{TryFrom, TryInto},
fmt,
ptr::NonNull,
task::{Poll, Waker},
Expand All @@ -18,17 +18,91 @@ use libc::c_void;
use s2n_tls_sys::*;
use std::{ffi::CStr, mem};

pub use s2n_tls_sys::s2n_mode;
use s2n_tls_sys::s2n_mode;

macro_rules! static_const_str {
($c_chars:expr) => {
unsafe { CStr::from_ptr($c_chars) }
.to_str()
.map_err(|_| Error::InvalidInput)
};
}

#[derive(Debug, PartialEq)]
pub enum CallbackResult {
Success,
Failure,
}

impl From<CallbackResult> for s2n_status_code::Type {
fn from(input: CallbackResult) -> s2n_status_code::Type {
match input {
CallbackResult::Success => s2n_status_code::SUCCESS,
CallbackResult::Failure => s2n_status_code::FAILURE,
}
}
}

#[derive(Debug, PartialEq)]
pub enum Mode {
Server,
Client,
}

impl From<Mode> for s2n_mode::Type {
fn from(input: Mode) -> s2n_mode::Type {
match input {
Mode::Server => s2n_mode::SERVER,
Mode::Client => s2n_mode::CLIENT,
}
}
}

#[non_exhaustive]
#[derive(Debug, PartialEq)]
pub enum Version {
SSLV2,
SSLV3,
TLS10,
TLS11,
TLS12,
TLS13,
}

impl TryFrom<s2n_tls_version::Type> for Version {
type Error = Error;

fn try_from(input: s2n_tls_version::Type) -> Result<Self, Self::Error> {
let version = match input {
s2n_tls_version::SSLV2 => Self::SSLV2,
s2n_tls_version::SSLV3 => Self::SSLV3,
s2n_tls_version::TLS10 => Self::TLS10,
s2n_tls_version::TLS11 => Self::TLS11,
s2n_tls_version::TLS12 => Self::TLS12,
s2n_tls_version::TLS13 => Self::TLS13,
_ => return Err(Error::InvalidInput),
};
Ok(version)
}
}

pub struct Connection {
connection: NonNull<s2n_connection>,
}

impl fmt::Debug for Connection {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_struct("Connection")
// TODO add paths
.finish()
let mut debug = f.debug_struct("Connection");
if let Ok(handshake) = self.handshake_type() {
debug.field("handshake_type", &handshake);
}
if let Ok(cipher) = self.cipher_suite() {
debug.field("cipher_suite", &cipher);
}
if let Ok(version) = self.actual_protocol_version() {
debug.field("actual_protocol_version", &version);
}
debug.finish_non_exhaustive()
}
}

Expand All @@ -38,9 +112,10 @@ impl fmt::Debug for Connection {
unsafe impl Send for Connection {}

impl Connection {
pub fn new(mode: s2n_mode::Type) -> Self {
pub fn new(mode: Mode) -> Self {
crate::raw::init::init();
let connection = unsafe { s2n_connection_new(mode).into_result() }.unwrap();

let connection = unsafe { s2n_connection_new(mode.into()).into_result() }.unwrap();

unsafe {
debug_assert! {
Expand All @@ -62,11 +137,11 @@ impl Connection {
}

pub fn new_client() -> Self {
Self::new(s2n_mode::CLIENT)
Self::new(Mode::Client)
}

pub fn new_server() -> Self {
Self::new(s2n_mode::SERVER)
Self::new(Mode::Server)
}

/// # Safety
Expand Down Expand Up @@ -289,7 +364,7 @@ impl Connection {

match unsafe { s2n_negotiate(self.connection.as_ptr(), &mut blocked).into_result() } {
Ok(_) => Ok(self).into(),
Err(err) if err.kind() == s2n_error_type::BLOCKED => Poll::Pending,
Err(err) if err.is_retryable() => Poll::Pending,
Err(err) => Err(err).into(),
}
}
Expand Down Expand Up @@ -389,6 +464,29 @@ impl Connection {
};
Ok(())
}

pub fn actual_protocol_version(&self) -> Result<Version, Error> {
let version = unsafe {
s2n_connection_get_actual_protocol_version(self.connection.as_ptr()).into_result()?
};
version.try_into()
}

pub fn handshake_type(&self) -> Result<&str, Error> {
let handshake = unsafe {
s2n_connection_get_handshake_type_name(self.connection.as_ptr()).into_result()?
};
// The strings returned by s2n_connection_get_handshake_type_name
// are static and immutable after they are first calculated
static_const_str!(handshake)
}

pub fn cipher_suite(&self) -> Result<&str, Error> {
let cipher = unsafe { s2n_connection_get_cipher(self.connection.as_ptr()).into_result()? };
// The strings returned by s2n_connection_get_cipher
// are static and immutable since they are const fields on static const structs
static_const_str!(cipher)
}
}

#[derive(Default)]
Expand Down
Loading