diff --git a/Cargo.toml b/Cargo.toml index b5019f0..fee1061 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,14 +13,18 @@ edition = "2018" [dependencies] futures-io = "0.3" -rustls = { version = "0.22", default-features = false, features = ["tls12"] } +rustls = { version = "0.23", default-features = false, features = ["std"] } pki-types = { package = "rustls-pki-types", version = "1" } [features] -default = ["ring"] +default = ["aws-lc-rs", "tls12", "logging"] +aws-lc-rs = ["rustls/aws_lc_rs"] +aws_lc_rs = ["aws-lc-rs"] early-data = [] +fips = ["rustls/fips"] +logging = ["rustls/logging"] ring = ["rustls/ring"] -aws-lc-rs = ["rustls/aws_lc_rs"] +tls12 = ["rustls/tls12"] [dev-dependencies] smol = "1" diff --git a/src/client.rs b/src/client.rs index 8696d00..1d6717f 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,9 +1,5 @@ use super::*; use crate::common::IoSession; -#[cfg(unix)] -use std::os::unix::io::{AsRawFd, RawFd}; -#[cfg(windows)] -use std::os::windows::io::{AsRawSocket, RawSocket}; /// A wrapper around an underlying raw stream which implements the TLS or SSL /// protocol. @@ -34,6 +30,72 @@ impl TlsStream { } } +#[cfg(feature = "early-data")] +fn poll_handle_early_data( + state: &mut TlsState, + stream: &mut Stream, + early_waker: &mut Option, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], +) -> Poll> +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + if let TlsState::EarlyData(pos, data) = state { + use std::io::Write; + + // write early data + if let Some(mut early_data) = stream.session.early_data() { + let mut written = 0; + + for buf in bufs { + if buf.is_empty() { + continue; + } + + let len = match early_data.write(buf) { + Ok(0) => break, + Ok(n) => n, + Err(err) => return Poll::Ready(Err(err)), + }; + + written += len; + data.extend_from_slice(&buf[..len]); + + if len < buf.len() { + break; + } + } + + if written != 0 { + return Poll::Ready(Ok(written)); + } + } + + // complete handshake + while stream.session.is_handshaking() { + ready!(stream.handshake(cx))?; + } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; + *pos += len; + } + } + + // end + *state = TlsState::Stream; + + if let Some(waker) = early_waker.take() { + waker.wake(); + } + } + + Poll::Ready(Ok(0)) +} + #[cfg(unix)] impl AsRawFd for TlsStream where @@ -145,48 +207,47 @@ where let mut stream = Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); - #[allow(clippy::match_single_binding)] - match this.state { - #[cfg(feature = "early-data")] - TlsState::EarlyData(ref mut pos, ref mut data) => { - use std::io::Write; - - // write early data - if let Some(mut early_data) = stream.session.early_data() { - let len = match early_data.write(buf) { - Ok(n) => n, - Err(err) => return Poll::Ready(Err(err)), - }; - if len != 0 { - data.extend_from_slice(&buf[..len]); - return Poll::Ready(Ok(len)); - } - } - - // complete handshake - while stream.session.is_handshaking() { - ready!(stream.handshake(cx))?; - } - - // write early data (fallback) - if !stream.session.is_early_data_accepted() { - while *pos < data.len() { - let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; - *pos += len; - } - } - - // end - this.state = TlsState::Stream; + #[cfg(feature = "early-data")] + { + let bufs = [io::IoSlice::new(buf)]; + let written = ready!(poll_handle_early_data( + &mut this.state, + &mut stream, + &mut this.early_waker, + cx, + &bufs + ))?; + if written != 0 { + return Poll::Ready(Ok(written)); + } + } + stream.as_mut_pin().poll_write(cx, buf) + } - if let Some(waker) = this.early_waker.take() { - waker.wake(); - } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + let this = self.get_mut(); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); - stream.as_mut_pin().poll_write(cx, buf) + #[cfg(feature = "early-data")] + { + let written = ready!(poll_handle_early_data( + &mut this.state, + &mut stream, + &mut this.early_waker, + cx, + bufs + ))?; + if written != 0 { + return Poll::Ready(Ok(written)); } - _ => stream.as_mut_pin().poll_write(cx, buf), } + + stream.as_mut_pin().poll_write_vectored(cx, bufs) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { diff --git a/src/common/handshake.rs b/src/common/handshake.rs index 0f285c0..69b9695 100644 --- a/src/common/handshake.rs +++ b/src/common/handshake.rs @@ -1,11 +1,12 @@ -use crate::common::{Stream, TlsState}; +use crate::common::{Stream, SyncWriteAdapter, TlsState}; +use futures_io::{AsyncRead, AsyncWrite}; +use rustls::server::AcceptedAlert; use rustls::{ConnectionCommon, SideData}; use std::future::Future; use std::ops::{Deref, DerefMut}; use std::pin::Pin; use std::task::{Context, Poll}; use std::{io, mem}; -use futures_io::{AsyncRead, AsyncWrite}; pub(crate) trait IoSession { type Io; @@ -19,7 +20,15 @@ pub(crate) trait IoSession { pub(crate) enum MidHandshake { Handshaking(IS), End, - Error { io: IS::Io, error: io::Error }, + SendAlert { + io: IS::Io, + alert: AcceptedAlert, + error: io::Error, + }, + Error { + io: IS::Io, + error: io::Error, + }, } impl Future for MidHandshake @@ -36,6 +45,20 @@ where let mut stream = match mem::replace(this, MidHandshake::End) { MidHandshake::Handshaking(stream) => stream, + MidHandshake::SendAlert { + mut io, + mut alert, + error, + } => loop { + match alert.write(&mut SyncWriteAdapter { io: &mut io, cx }) { + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + *this = MidHandshake::SendAlert { io, error, alert }; + return Poll::Pending; + } + Err(_) | Ok(0) => return Poll::Ready(Err((error, io))), + Ok(_) => {} + }; + }, // Starting the handshake returned an error; fail the future immediately. MidHandshake::Error { io, error } => return Poll::Ready(Err((error, io))), _ => panic!("unexpected polling after handshake"), diff --git a/src/common/mod.rs b/src/common/mod.rs index 2d6d0e1..5169f9d 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -89,23 +89,7 @@ where } pub fn read_io(&mut self, cx: &mut Context) -> Poll> { - struct Reader<'a, 'b, T> { - io: &'a mut T, - cx: &'a mut Context<'b>, - } - - impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> { - #[inline] - fn read(&mut self, buf: &mut [u8]) -> io::Result { - match Pin::new(&mut self.io).poll_read(self.cx, buf) { - Poll::Ready(Ok(n)) => Ok(n), - Poll::Ready(Err(err)) => Err(err), - Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), - } - } - } - - let mut reader = Reader { io: self.io, cx }; + let mut reader = SyncReadAdapter { io: self.io, cx }; let n = match self.session.read_tls(&mut reader) { Ok(n) => n, @@ -133,41 +117,7 @@ where } pub fn write_io(&mut self, cx: &mut Context) -> Poll> { - struct Writer<'a, 'b, T> { - io: &'a mut T, - cx: &'a mut Context<'b>, - } - - impl<'a, 'b, T: Unpin> Writer<'a, 'b, T> { - #[inline] - fn poll_with( - &mut self, - f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll>, - ) -> io::Result { - match f(Pin::new(&mut self.io), self.cx) { - Poll::Ready(result) => result, - Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), - } - } - } - - impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> { - #[inline] - fn write(&mut self, buf: &[u8]) -> io::Result { - self.poll_with(|io, cx| io.poll_write(cx, buf)) - } - - #[inline] - fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs)) - } - - fn flush(&mut self) -> io::Result<()> { - self.poll_with(|io, cx| io.poll_flush(cx)) - } - } - - let mut writer = Writer { io: self.io, cx }; + let mut writer = SyncWriteAdapter { io: self.io, cx }; match self.session.write_tls(&mut writer) { Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, @@ -347,7 +297,45 @@ where while self.session.wants_write() { ready!(self.write_io(cx))?; } - Pin::new(&mut self.io).poll_close(cx) + + Poll::Ready(match ready!(Pin::new(&mut self.io).poll_close(cx)) { + Ok(()) => Ok(()), + // When trying to shutdown, not being connected seems fine + Err(err) if err.kind() == io::ErrorKind::NotConnected => Ok(()), + Err(err) => Err(err), + }) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + if bufs.iter().all(|buf| buf.is_empty()) { + return Poll::Ready(Ok(0)); + } + + loop { + let mut would_block = false; + let written = self.session.writer().write_vectored(bufs)?; + + while self.session.wants_write() { + match self.write_io(cx) { + Poll::Ready(Ok(0)) | Poll::Pending => { + would_block = true; + break; + } + Poll::Ready(Ok(_)) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + } + } + + return match (written, would_block) { + (0, true) => Poll::Pending, + (0, false) => continue, + (n, _) => Poll::Ready(Ok(n)), + }; + } } } @@ -371,5 +359,39 @@ impl<'a, 'b, T: AsyncRead + Unpin> Read for SyncReadAdapter<'a, 'b, T> { } } +pub(crate) struct SyncWriteAdapter<'a, 'b, T> { + pub(crate) io: &'a mut T, + pub(crate) cx: &'a mut Context<'b>, +} + +impl<'a, 'b, T: Unpin> SyncWriteAdapter<'a, 'b, T> { + #[inline] + fn poll_with( + &mut self, + f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll>, + ) -> io::Result { + match f(Pin::new(&mut self.io), self.cx) { + Poll::Ready(result) => result, + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), + } + } +} + +impl<'a, 'b, T: AsyncWrite + Unpin> Write for SyncWriteAdapter<'a, 'b, T> { + #[inline] + fn write(&mut self, buf: &[u8]) -> io::Result { + self.poll_with(|io, cx| io.poll_write(cx, buf)) + } + + #[inline] + fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { + self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs)) + } + + fn flush(&mut self) -> io::Result<()> { + self.poll_with(|io, cx| io.poll_flush(cx)) + } +} + #[cfg(test)] mod test_stream; diff --git a/src/lib.rs b/src/lib.rs index ec4ac4f..eeef2a7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -//! Asynchronous TLS/SSL streams for futures using [Rustls](https://github.com/ctz/rustls). +//! Asynchronous TLS/SSL streams for futures using [Rustls](https://github.com/rustls/rustls). macro_rules! ready { ( $e:expr ) => { @@ -15,6 +15,7 @@ pub mod server; use common::{MidHandshake, Stream, TlsState}; use futures_io::{AsyncRead, AsyncWrite}; +use rustls::server::AcceptedAlert; use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection}; use std::future::Future; use std::io; @@ -26,8 +27,8 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -pub use rustls; pub use pki_types; +pub use rustls; /// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method. #[derive(Clone)] @@ -78,7 +79,12 @@ impl TlsConnector { self.connect_with(domain, stream, |_| ()) } - pub fn connect_with(&self, domain: pki_types::ServerName<'static>, stream: IO, f: F) -> Connect + pub fn connect_with( + &self, + domain: pki_types::ServerName<'static>, + stream: IO, + f: F, + ) -> Connect where IO: AsyncRead + AsyncWrite + Unpin, F: FnOnce(&mut ClientConnection), @@ -155,6 +161,7 @@ impl TlsAcceptor { pub struct LazyConfigAcceptor { acceptor: rustls::server::Acceptor, io: Option, + alert: Option<(rustls::Error, AcceptedAlert)>, } impl LazyConfigAcceptor @@ -166,6 +173,7 @@ where Self { acceptor, io: Some(io), + alert: None, } } } @@ -189,6 +197,22 @@ where } }; + if let Some((err, mut alert)) = this.alert.take() { + match alert.write(&mut common::SyncWriteAdapter { io, cx }) { + Err(e) if e.kind() == io::ErrorKind::WouldBlock => { + this.alert = Some((err, alert)); + return Poll::Pending; + } + Ok(0) | Err(_) => { + return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err))) + } + Ok(_) => { + this.alert = Some((err, alert)); + continue; + } + }; + } + let mut reader = common::SyncReadAdapter { io, cx }; match this.acceptor.read_tls(&mut reader) { Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(), @@ -202,9 +226,9 @@ where let io = this.io.take().unwrap(); return Poll::Ready(Ok(StartHandshake { accepted, io })); } - Ok(None) => continue, - Err(err) => { - return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err))) + Ok(None) => {} + Err((err, alert)) => { + this.alert = Some((err, alert)); } } } @@ -234,12 +258,13 @@ where { let mut conn = match self.accepted.into_connection(config) { Ok(conn) => conn, - Err(error) => { - return Accept(MidHandshake::Error { + Err((error, alert)) => { + return Accept(MidHandshake::SendAlert { io: self.io, // TODO(eliza): should this really return an `io::Error`? // Probably not... error: io::Error::new(io::ErrorKind::Other, error), + alert, }); } }; @@ -333,11 +358,11 @@ impl TlsStream { match self { Client(io) => { let (io, session) = io.get_ref(); - (io, &*session) + (io, session) } Server(io) => { let (io, session) = io.get_ref(); - (io, &*session) + (io, session) } } } @@ -437,4 +462,16 @@ where TlsStream::Server(x) => Pin::new(x).poll_close(cx), } } + + #[inline] + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + match self.get_mut() { + TlsStream::Client(x) => Pin::new(x).poll_write_vectored(cx, bufs), + TlsStream::Server(x) => Pin::new(x).poll_write_vectored(cx, bufs), + } + } } diff --git a/src/server.rs b/src/server.rs index 818236b..5f36cae 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,8 +1,3 @@ -#[cfg(unix)] -use std::os::unix::io::{AsRawFd, RawFd}; -#[cfg(windows)] -use std::os::windows::io::{AsRawSocket, RawSocket}; - use super::*; use crate::common::IoSession; @@ -124,6 +119,17 @@ where Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); stream.as_mut_pin().poll_close(cx) } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + let this = self.get_mut(); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + stream.as_mut_pin().poll_write_vectored(cx, bufs) + } } #[cfg(unix)]