Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bindings] Apply async blinding #3356

Merged
merged 3 commits into from
Jun 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions bindings/rust/s2n-tls-tokio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ default = []
errno = { version = "0.2" }
libc = { version = "0.2" }
s2n-tls = { version = "=0.0.8", path = "../s2n-tls" }
tokio = { version = "1", features = ["net"] }
tokio = { version = "1", features = ["net", "time"] }

[dev-dependencies]
clap = { version = "3.1", features = ["derive"] }
rand = { version = "0.8" }
tokio = { version = "1", features = [ "io-std", "io-util", "macros", "net", "rt-multi-thread", "time"] }
tokio = { version = "1", features = [ "io-std", "io-util", "macros", "net", "rt-multi-thread", "test-util", "time"] }
150 changes: 107 additions & 43 deletions bindings/rust/s2n-tls-tokio/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use errno::{set_errno, Errno};
use s2n_tls::raw::{
config::Config,
connection::{Builder, Connection},
enums::{CallbackResult, Mode},
enums::{Blinding, CallbackResult, Mode},
error::Error,
};
use std::{
Expand All @@ -14,9 +14,24 @@ use std::{
io,
os::raw::{c_int, c_void},
pin::Pin,
task::{Context, Poll},
task::{
Context, Poll,
Poll::{Pending, Ready},
},
};
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
time::{sleep, Duration, Sleep},
};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

macro_rules! ready {
($x:expr) => {
match $x {
Ready(r) => r,
Pending => return Pending,
}
};
}

#[derive(Clone)]
pub struct TlsAcceptor<B: Builder = Config>
Expand Down Expand Up @@ -79,6 +94,7 @@ where
S: AsyncRead + AsyncWrite + Unpin,
{
tls: &'a mut TlsStream<S, C>,
error: Option<Error>,
}

impl<S, C> Future for TlsHandshake<'_, S, C>
Expand All @@ -89,10 +105,36 @@ where
type Output = Result<(), Error>;

fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
self.tls.with_io(ctx, |context| {
let conn = context.get_mut().as_mut();
conn.negotiate().map(|r| r.map(|_| ()))
})
// Retrieve a result, either from the stored error
// or by polling Connection::negotiate().
// Connection::negotiate() only completes once,
// regardless of how often this method is polled.
let result = match self.error.take() {
Some(err) => Err(err),
None => {
ready!(self.tls.with_io(ctx, |context| {
let conn = context.get_mut().as_mut();
conn.negotiate().map(|r| r.map(|_| ()))
}))
}
};
// If the result isn't a fatal error, return it immediately.
// Otherwise, poll Connection::shutdown().
//
// Shutdown is only best-effort.
// When Connection::shutdown() completes, even with an error,
// we return the original Connection::negotiate() error.
match result {
Ok(r) => Ok(r).into(),
Err(e) if e.is_retryable() => Err(e).into(),
Err(e) => match Pin::new(&mut self.tls).poll_shutdown(ctx) {
Pending => {
self.error = Some(e);
Pending
}
Ready(_) => Err(e).into(),
},
}
}
}

Expand All @@ -103,16 +145,26 @@ where
{
conn: C,
stream: S,
blinding: Option<Pin<Box<Sleep>>>,
}

impl<S, C> TlsStream<S, C>
where
C: AsRef<Connection> + AsMut<Connection> + Unpin,
S: AsyncRead + AsyncWrite + Unpin,
{
async fn open(conn: C, stream: S) -> Result<Self, Error> {
let mut tls = TlsStream { conn, stream };
TlsHandshake { tls: &mut tls }.await?;
async fn open(mut conn: C, stream: S) -> Result<Self, Error> {
conn.as_mut().set_blinding(Blinding::SelfService)?;
let mut tls = TlsStream {
conn,
stream,
blinding: None,
};
TlsHandshake {
tls: &mut tls,
error: None,
}
.await?;
Ok(tls)
}

Expand Down Expand Up @@ -211,17 +263,17 @@ where
ctx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.get_mut()
.with_io(ctx, |mut context| {
context
.conn
.as_mut()
.recv(buf.initialize_unfilled())
.map_ok(|size| {
buf.advance(size);
})
})
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
let tls = self.get_mut();
tls.with_io(ctx, |mut context| {
context
.conn
.as_mut()
.recv(buf.initialize_unfilled())
.map_ok(|size| {
buf.advance(size);
})
})
.map_err(io::Error::from)
}
}

Expand All @@ -235,37 +287,49 @@ where
ctx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.get_mut()
.with_io(ctx, |mut context| context.conn.as_mut().send(buf))
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
let tls = self.get_mut();
tls.with_io(ctx, |mut context| context.conn.as_mut().send(buf))
.map_err(io::Error::from)
}

fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
let tls = self.get_mut();
let tls_flush = tls
.with_io(ctx, |mut context| {
context.conn.as_mut().flush().map(|r| r.map(|_| ()))
})
.map_err(|e| io::Error::new(io::ErrorKind::Other, e));
if tls_flush.is_ready() {
Pin::new(&mut tls.stream).poll_flush(ctx)
} else {
tls_flush
}

ready!(tls.with_io(ctx, |mut context| {
context.conn.as_mut().flush().map(|r| r.map(|_| ()))
}))
.map_err(io::Error::from)?;

Pin::new(&mut tls.stream).poll_flush(ctx)
}

fn poll_shutdown(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
let tls = self.get_mut();
let tls_shutdown = tls
.with_io(ctx, |mut context| {
context.conn.as_mut().shutdown().map(|r| r.map(|_| ()))
})
.map_err(|e| io::Error::new(io::ErrorKind::Other, e));
if tls_shutdown.is_ready() {
Pin::new(&mut tls.stream).poll_shutdown(ctx)
} else {
tls_shutdown

if tls.blinding.is_none() {
let delay = tls
.as_ref()
.remaining_blinding_delay()
.map_err(io::Error::from)?;
if !delay.is_zero() {
// Sleep operates at the milisecond resolution, so add an extra
// millisecond to account for any stray nanoseconds.
let safety = Duration::from_millis(1);
tls.blinding = Some(Box::pin(sleep(delay.saturating_add(safety))));
}
};

if let Some(timer) = tls.blinding.as_mut() {
ready!(timer.as_mut().poll(ctx));
tls.blinding = None;
}

ready!(tls.with_io(ctx, |mut context| {
context.conn.as_mut().shutdown().map(|r| r.map(|_| ()))
}))
.map_err(io::Error::from)?;

Pin::new(&mut tls.stream).poll_shutdown(ctx)
}
}

Expand Down
33 changes: 20 additions & 13 deletions bindings/rust/s2n-tls-tokio/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@

use s2n_tls::raw::{config, connection::Builder, error::Error, security::DEFAULT_TLS13};
use s2n_tls_tokio::{TlsAcceptor, TlsConnector, TlsStream};
use tokio::net::{TcpListener, TcpStream};
use std::time::Duration;
use tokio::{
io::{AsyncRead, AsyncWrite},
net::{TcpListener, TcpStream},
};

mod stream;
pub use stream::*;

/// NOTE: this certificate and key are used for testing purposes only!
pub static CERT_PEM: &[u8] = include_bytes!(concat!(
Expand All @@ -15,6 +22,9 @@ pub static KEY_PEM: &[u8] = include_bytes!(concat!(
"/examples/certs/key.pem"
));

pub const MIN_BLINDING_SECS: Duration = Duration::from_secs(10);
pub const MAX_BLINDING_SECS: Duration = Duration::from_secs(30);

pub async fn get_streams() -> Result<(TcpStream, TcpStream), tokio::io::Error> {
let localhost = "127.0.0.1".to_owned();
let listener = TcpListener::bind(format!("{}:0", localhost)).await?;
Expand All @@ -38,24 +48,21 @@ pub fn server_config() -> Result<config::Builder, Error> {
Ok(builder)
}

pub async fn run_negotiate<A: Builder, B: Builder>(
pub async fn run_negotiate<A: Builder, B: Builder, C, D>(
client: &TlsConnector<A>,
client_stream: TcpStream,
client_stream: C,
server: &TlsAcceptor<B>,
server_stream: TcpStream,
) -> Result<
(
TlsStream<TcpStream, A::Output>,
TlsStream<TcpStream, B::Output>,
),
Error,
>
server_stream: D,
) -> Result<(TlsStream<C, A::Output>, TlsStream<D, B::Output>), Error>
where
<A as Builder>::Output: Unpin,
<B as Builder>::Output: Unpin,
C: AsyncRead + AsyncWrite + Unpin,
D: AsyncRead + AsyncWrite + Unpin,
{
tokio::try_join!(
let (client, server) = tokio::join!(
client.connect("localhost", client_stream),
server.accept(server_stream)
)
);
Ok((client?, server?))
}
103 changes: 103 additions & 0 deletions bindings/rust/s2n-tls-tokio/tests/common/stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use std::{
io,
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll},
};
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
net::TcpStream,
};

type ReadFn = Box<dyn Fn(Pin<&mut TcpStream>, &mut Context, &mut ReadBuf) -> Poll<io::Result<()>>>;
type WriteFn = Box<dyn Fn(Pin<&mut TcpStream>, &mut Context, &[u8]) -> Poll<io::Result<usize>>>;

#[derive(Default)]
struct OverrideMethods {
next_read: Option<ReadFn>,
next_write: Option<WriteFn>,
}

#[derive(Default)]
pub struct Overrides(Mutex<OverrideMethods>);

impl Overrides {
pub fn next_read(&self, input: Option<ReadFn>) {
if let Ok(mut overrides) = self.0.lock() {
overrides.next_read = input;
}
}

pub fn next_write(&self, input: Option<WriteFn>) {
if let Ok(mut overrides) = self.0.lock() {
overrides.next_write = input;
}
}
}

pub struct TestStream {
stream: TcpStream,
overrides: Arc<Overrides>,
}

impl TestStream {
pub fn new(stream: TcpStream) -> Self {
let overrides = Arc::new(Overrides::default());
Self { stream, overrides }
}

pub fn overrides(&self) -> Arc<Overrides> {
self.overrides.clone()
}
}

impl AsyncRead for TestStream {
fn poll_read(
self: Pin<&mut Self>,
ctx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let s = self.get_mut();
let stream = Pin::new(&mut s.stream);
let action = match s.overrides.0.lock() {
Ok(mut overrides) => overrides.next_read.take(),
_ => None,
};
if let Some(f) = action {
(f)(stream, ctx, buf)
} else {
stream.poll_read(ctx, buf)
}
}
}

impl AsyncWrite for TestStream {
fn poll_write(
self: Pin<&mut Self>,
ctx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let s = self.get_mut();
let stream = Pin::new(&mut s.stream);
let action = match s.overrides.0.lock() {
Ok(mut overrides) => overrides.next_write.take(),
_ => None,
};
if let Some(f) = action {
(f)(stream, ctx, buf)
} else {
stream.poll_write(ctx, buf)
}
}

fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.stream).poll_flush(ctx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.stream).poll_shutdown(ctx)
}
}
Loading