Skip to content

Commit

Permalink
[bindings] Apply async blinding (#3356)
Browse files Browse the repository at this point in the history
  • Loading branch information
lrstewart authored Jun 13, 2022
1 parent 664fef5 commit 0459b41
Show file tree
Hide file tree
Showing 10 changed files with 479 additions and 68 deletions.
4 changes: 2 additions & 2 deletions bindings/rust/s2n-tls-tokio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ default = []
errno = { version = "0.2" }
libc = { version = "0.2" }
s2n-tls = { version = "=0.0.8", path = "../s2n-tls" }
tokio = { version = "1", features = ["net"] }
tokio = { version = "1", features = ["net", "time"] }

[dev-dependencies]
clap = { version = "3.1", features = ["derive"] }
rand = { version = "0.8" }
tokio = { version = "1", features = [ "io-std", "io-util", "macros", "net", "rt-multi-thread", "time"] }
tokio = { version = "1", features = [ "io-std", "io-util", "macros", "net", "rt-multi-thread", "test-util", "time"] }
150 changes: 107 additions & 43 deletions bindings/rust/s2n-tls-tokio/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use errno::{set_errno, Errno};
use s2n_tls::raw::{
config::Config,
connection::{Builder, Connection},
enums::{CallbackResult, Mode},
enums::{Blinding, CallbackResult, Mode},
error::Error,
};
use std::{
Expand All @@ -14,9 +14,24 @@ use std::{
io,
os::raw::{c_int, c_void},
pin::Pin,
task::{Context, Poll},
task::{
Context, Poll,
Poll::{Pending, Ready},
},
};
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
time::{sleep, Duration, Sleep},
};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

macro_rules! ready {
($x:expr) => {
match $x {
Ready(r) => r,
Pending => return Pending,
}
};
}

#[derive(Clone)]
pub struct TlsAcceptor<B: Builder = Config>
Expand Down Expand Up @@ -79,6 +94,7 @@ where
S: AsyncRead + AsyncWrite + Unpin,
{
tls: &'a mut TlsStream<S, C>,
error: Option<Error>,
}

impl<S, C> Future for TlsHandshake<'_, S, C>
Expand All @@ -89,10 +105,36 @@ where
type Output = Result<(), Error>;

fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
self.tls.with_io(ctx, |context| {
let conn = context.get_mut().as_mut();
conn.negotiate().map(|r| r.map(|_| ()))
})
// Retrieve a result, either from the stored error
// or by polling Connection::negotiate().
// Connection::negotiate() only completes once,
// regardless of how often this method is polled.
let result = match self.error.take() {
Some(err) => Err(err),
None => {
ready!(self.tls.with_io(ctx, |context| {
let conn = context.get_mut().as_mut();
conn.negotiate().map(|r| r.map(|_| ()))
}))
}
};
// If the result isn't a fatal error, return it immediately.
// Otherwise, poll Connection::shutdown().
//
// Shutdown is only best-effort.
// When Connection::shutdown() completes, even with an error,
// we return the original Connection::negotiate() error.
match result {
Ok(r) => Ok(r).into(),
Err(e) if e.is_retryable() => Err(e).into(),
Err(e) => match Pin::new(&mut self.tls).poll_shutdown(ctx) {
Pending => {
self.error = Some(e);
Pending
}
Ready(_) => Err(e).into(),
},
}
}
}

Expand All @@ -103,16 +145,26 @@ where
{
conn: C,
stream: S,
blinding: Option<Pin<Box<Sleep>>>,
}

impl<S, C> TlsStream<S, C>
where
C: AsRef<Connection> + AsMut<Connection> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
async fn open(conn: C, stream: S) -> Result<Self, Error> {
let mut tls = TlsStream { conn, stream };
TlsHandshake { tls: &mut tls }.await?;
async fn open(mut conn: C, stream: S) -> Result<Self, Error> {
conn.as_mut().set_blinding(Blinding::SelfService)?;
let mut tls = TlsStream {
conn,
stream,
blinding: None,
};
TlsHandshake {
tls: &mut tls,
error: None,
}
.await?;
Ok(tls)
}

Expand Down Expand Up @@ -211,17 +263,17 @@ where
ctx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.get_mut()
.with_io(ctx, |mut context| {
context
.conn
.as_mut()
.recv(buf.initialize_unfilled())
.map_ok(|size| {
buf.advance(size);
})
})
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
let tls = self.get_mut();
tls.with_io(ctx, |mut context| {
context
.conn
.as_mut()
.recv(buf.initialize_unfilled())
.map_ok(|size| {
buf.advance(size);
})
})
.map_err(io::Error::from)
}
}

Expand All @@ -235,37 +287,49 @@ where
ctx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.get_mut()
.with_io(ctx, |mut context| context.conn.as_mut().send(buf))
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
let tls = self.get_mut();
tls.with_io(ctx, |mut context| context.conn.as_mut().send(buf))
.map_err(io::Error::from)
}

fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
let tls = self.get_mut();
let tls_flush = tls
.with_io(ctx, |mut context| {
context.conn.as_mut().flush().map(|r| r.map(|_| ()))
})
.map_err(|e| io::Error::new(io::ErrorKind::Other, e));
if tls_flush.is_ready() {
Pin::new(&mut tls.stream).poll_flush(ctx)
} else {
tls_flush
}

ready!(tls.with_io(ctx, |mut context| {
context.conn.as_mut().flush().map(|r| r.map(|_| ()))
}))
.map_err(io::Error::from)?;

Pin::new(&mut tls.stream).poll_flush(ctx)
}

fn poll_shutdown(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
let tls = self.get_mut();
let tls_shutdown = tls
.with_io(ctx, |mut context| {
context.conn.as_mut().shutdown().map(|r| r.map(|_| ()))
})
.map_err(|e| io::Error::new(io::ErrorKind::Other, e));
if tls_shutdown.is_ready() {
Pin::new(&mut tls.stream).poll_shutdown(ctx)
} else {
tls_shutdown

if tls.blinding.is_none() {
let delay = tls
.as_ref()
.remaining_blinding_delay()
.map_err(io::Error::from)?;
if !delay.is_zero() {
// Sleep operates at the milisecond resolution, so add an extra
// millisecond to account for any stray nanoseconds.
let safety = Duration::from_millis(1);
tls.blinding = Some(Box::pin(sleep(delay.saturating_add(safety))));
}
};

if let Some(timer) = tls.blinding.as_mut() {
ready!(timer.as_mut().poll(ctx));
tls.blinding = None;
}

ready!(tls.with_io(ctx, |mut context| {
context.conn.as_mut().shutdown().map(|r| r.map(|_| ()))
}))
.map_err(io::Error::from)?;

Pin::new(&mut tls.stream).poll_shutdown(ctx)
}
}

Expand Down
33 changes: 20 additions & 13 deletions bindings/rust/s2n-tls-tokio/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@

use s2n_tls::raw::{config, connection::Builder, error::Error, security::DEFAULT_TLS13};
use s2n_tls_tokio::{TlsAcceptor, TlsConnector, TlsStream};
use tokio::net::{TcpListener, TcpStream};
use std::time::Duration;
use tokio::{
io::{AsyncRead, AsyncWrite},
net::{TcpListener, TcpStream},
};

mod stream;
pub use stream::*;

/// NOTE: this certificate and key are used for testing purposes only!
pub static CERT_PEM: &[u8] = include_bytes!(concat!(
Expand All @@ -15,6 +22,9 @@ pub static KEY_PEM: &[u8] = include_bytes!(concat!(
"/examples/certs/key.pem"
));

pub const MIN_BLINDING_SECS: Duration = Duration::from_secs(10);
pub const MAX_BLINDING_SECS: Duration = Duration::from_secs(30);

pub async fn get_streams() -> Result<(TcpStream, TcpStream), tokio::io::Error> {
let localhost = "127.0.0.1".to_owned();
let listener = TcpListener::bind(format!("{}:0", localhost)).await?;
Expand All @@ -38,24 +48,21 @@ pub fn server_config() -> Result<config::Builder, Error> {
Ok(builder)
}

pub async fn run_negotiate<A: Builder, B: Builder>(
pub async fn run_negotiate<A: Builder, B: Builder, C, D>(
client: &TlsConnector<A>,
client_stream: TcpStream,
client_stream: C,
server: &TlsAcceptor<B>,
server_stream: TcpStream,
) -> Result<
(
TlsStream<TcpStream, A::Output>,
TlsStream<TcpStream, B::Output>,
),
Error,
>
server_stream: D,
) -> Result<(TlsStream<C, A::Output>, TlsStream<D, B::Output>), Error>
where
<A as Builder>::Output: Unpin,
<B as Builder>::Output: Unpin,
C: AsyncRead + AsyncWrite + Unpin,
D: AsyncRead + AsyncWrite + Unpin,
{
tokio::try_join!(
let (client, server) = tokio::join!(
client.connect("localhost", client_stream),
server.accept(server_stream)
)
);
Ok((client?, server?))
}
103 changes: 103 additions & 0 deletions bindings/rust/s2n-tls-tokio/tests/common/stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use std::{
io,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
};
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
net::TcpStream,
};

type ReadFn = Box<dyn Fn(Pin<&mut TcpStream>, &mut Context, &mut ReadBuf) -> Poll<io::Result<()>>>;
type WriteFn = Box<dyn Fn(Pin<&mut TcpStream>, &mut Context, &[u8]) -> Poll<io::Result<usize>>>;

#[derive(Default)]
struct OverrideMethods {
next_read: Option<ReadFn>,
next_write: Option<WriteFn>,
}

#[derive(Default)]
pub struct Overrides(Mutex<OverrideMethods>);

impl Overrides {
pub fn next_read(&self, input: Option<ReadFn>) {
if let Ok(mut overrides) = self.0.lock() {
overrides.next_read = input;
}
}

pub fn next_write(&self, input: Option<WriteFn>) {
if let Ok(mut overrides) = self.0.lock() {
overrides.next_write = input;
}
}
}

pub struct TestStream {
stream: TcpStream,
overrides: Arc<Overrides>,
}

impl TestStream {
pub fn new(stream: TcpStream) -> Self {
let overrides = Arc::new(Overrides::default());
Self { stream, overrides }
}

pub fn overrides(&self) -> Arc<Overrides> {
self.overrides.clone()
}
}

impl AsyncRead for TestStream {
fn poll_read(
self: Pin<&mut Self>,
ctx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let s = self.get_mut();
let stream = Pin::new(&mut s.stream);
let action = match s.overrides.0.lock() {
Ok(mut overrides) => overrides.next_read.take(),
_ => None,
};
if let Some(f) = action {
(f)(stream, ctx, buf)
} else {
stream.poll_read(ctx, buf)
}
}
}

impl AsyncWrite for TestStream {
fn poll_write(
self: Pin<&mut Self>,
ctx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let s = self.get_mut();
let stream = Pin::new(&mut s.stream);
let action = match s.overrides.0.lock() {
Ok(mut overrides) => overrides.next_write.take(),
_ => None,
};
if let Some(f) = action {
(f)(stream, ctx, buf)
} else {
stream.poll_write(ctx, buf)
}
}

fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.stream).poll_flush(ctx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.stream).poll_shutdown(ctx)
}
}
Loading

0 comments on commit 0459b41

Please sign in to comment.