Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,18 @@ edition = "2018"

[dependencies]
futures-io = "0.3"
rustls = { version = "0.22", default-features = false, features = ["tls12"] }
rustls = { version = "0.23", default-features = false, features = ["std"] }
pki-types = { package = "rustls-pki-types", version = "1" }

[features]
default = ["ring"]
default = ["aws-lc-rs", "tls12", "logging"]
aws-lc-rs = ["rustls/aws_lc_rs"]
aws_lc_rs = ["aws-lc-rs"]
early-data = []
fips = ["rustls/fips"]
logging = ["rustls/logging"]
ring = ["rustls/ring"]
aws-lc-rs = ["rustls/aws_lc_rs"]
tls12 = ["rustls/tls12"]

[dev-dependencies]
smol = "1"
Expand Down
145 changes: 103 additions & 42 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
use super::*;
use crate::common::IoSession;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};

/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
Expand Down Expand Up @@ -34,6 +30,72 @@ impl<IO> TlsStream<IO> {
}
}

#[cfg(feature = "early-data")]
fn poll_handle_early_data<IO>(
state: &mut TlsState,
stream: &mut Stream<IO, ClientConnection>,
early_waker: &mut Option<std::task::Waker>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
if let TlsState::EarlyData(pos, data) = state {
use std::io::Write;

// write early data
if let Some(mut early_data) = stream.session.early_data() {
let mut written = 0;

for buf in bufs {
if buf.is_empty() {
continue;
}

let len = match early_data.write(buf) {
Ok(0) => break,
Ok(n) => n,
Err(err) => return Poll::Ready(Err(err)),
};

written += len;
data.extend_from_slice(&buf[..len]);

if len < buf.len() {
break;
}
}

if written != 0 {
return Poll::Ready(Ok(written));
}
}

// complete handshake
while stream.session.is_handshaking() {
ready!(stream.handshake(cx))?;
}

// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
*pos += len;
}
}

// end
*state = TlsState::Stream;

if let Some(waker) = early_waker.take() {
waker.wake();
}
}

Poll::Ready(Ok(0))
}

#[cfg(unix)]
impl<S> AsRawFd for TlsStream<S>
where
Expand Down Expand Up @@ -145,48 +207,47 @@ where
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());

#[allow(clippy::match_single_binding)]
match this.state {
#[cfg(feature = "early-data")]
TlsState::EarlyData(ref mut pos, ref mut data) => {
use std::io::Write;

// write early data
if let Some(mut early_data) = stream.session.early_data() {
let len = match early_data.write(buf) {
Ok(n) => n,
Err(err) => return Poll::Ready(Err(err)),
};
if len != 0 {
data.extend_from_slice(&buf[..len]);
return Poll::Ready(Ok(len));
}
}

// complete handshake
while stream.session.is_handshaking() {
ready!(stream.handshake(cx))?;
}

// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
*pos += len;
}
}

// end
this.state = TlsState::Stream;
#[cfg(feature = "early-data")]
{
let bufs = [io::IoSlice::new(buf)];
let written = ready!(poll_handle_early_data(
&mut this.state,
&mut stream,
&mut this.early_waker,
cx,
&bufs
))?;
if written != 0 {
return Poll::Ready(Ok(written));
}
}
stream.as_mut_pin().poll_write(cx, buf)
}

if let Some(waker) = this.early_waker.take() {
waker.wake();
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());

stream.as_mut_pin().poll_write(cx, buf)
#[cfg(feature = "early-data")]
{
let written = ready!(poll_handle_early_data(
&mut this.state,
&mut stream,
&mut this.early_waker,
cx,
bufs
))?;
if written != 0 {
return Poll::Ready(Ok(written));
}
_ => stream.as_mut_pin().poll_write(cx, buf),
}

stream.as_mut_pin().poll_write_vectored(cx, bufs)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Expand Down
29 changes: 26 additions & 3 deletions src/common/handshake.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use crate::common::{Stream, TlsState};
use crate::common::{Stream, SyncWriteAdapter, TlsState};
use futures_io::{AsyncRead, AsyncWrite};
use rustls::server::AcceptedAlert;
use rustls::{ConnectionCommon, SideData};
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{io, mem};
use futures_io::{AsyncRead, AsyncWrite};

pub(crate) trait IoSession {
type Io;
Expand All @@ -19,7 +20,15 @@ pub(crate) trait IoSession {
pub(crate) enum MidHandshake<IS: IoSession> {
Handshaking(IS),
End,
Error { io: IS::Io, error: io::Error },
SendAlert {
io: IS::Io,
alert: AcceptedAlert,
error: io::Error,
},
Error {
io: IS::Io,
error: io::Error,
},
}

impl<IS, SD> Future for MidHandshake<IS>
Expand All @@ -36,6 +45,20 @@ where

let mut stream = match mem::replace(this, MidHandshake::End) {
MidHandshake::Handshaking(stream) => stream,
MidHandshake::SendAlert {
mut io,
mut alert,
error,
} => loop {
match alert.write(&mut SyncWriteAdapter { io: &mut io, cx }) {
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
*this = MidHandshake::SendAlert { io, error, alert };
return Poll::Pending;
}
Err(_) | Ok(0) => return Poll::Ready(Err((error, io))),
Ok(_) => {}
};
},
// Starting the handshake returned an error; fail the future immediately.
MidHandshake::Error { io, error } => return Poll::Ready(Err((error, io))),
_ => panic!("unexpected polling after handshake"),
Expand Down
128 changes: 75 additions & 53 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,23 +89,7 @@ where
}

pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
struct Reader<'a, 'b, T> {
io: &'a mut T,
cx: &'a mut Context<'b>,
}

impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> {
#[inline]
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match Pin::new(&mut self.io).poll_read(self.cx, buf) {
Poll::Ready(Ok(n)) => Ok(n),
Poll::Ready(Err(err)) => Err(err),
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
}
}
}

let mut reader = Reader { io: self.io, cx };
let mut reader = SyncReadAdapter { io: self.io, cx };

let n = match self.session.read_tls(&mut reader) {
Ok(n) => n,
Expand Down Expand Up @@ -133,41 +117,7 @@ where
}

pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
struct Writer<'a, 'b, T> {
io: &'a mut T,
cx: &'a mut Context<'b>,
}

impl<'a, 'b, T: Unpin> Writer<'a, 'b, T> {
#[inline]
fn poll_with<U>(
&mut self,
f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>,
) -> io::Result<U> {
match f(Pin::new(&mut self.io), self.cx) {
Poll::Ready(result) => result,
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
}
}
}

impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> {
#[inline]
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.poll_with(|io, cx| io.poll_write(cx, buf))
}

#[inline]
fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs))
}

fn flush(&mut self) -> io::Result<()> {
self.poll_with(|io, cx| io.poll_flush(cx))
}
}

let mut writer = Writer { io: self.io, cx };
let mut writer = SyncWriteAdapter { io: self.io, cx };

match self.session.write_tls(&mut writer) {
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Expand Down Expand Up @@ -347,7 +297,45 @@ where
while self.session.wants_write() {
ready!(self.write_io(cx))?;
}
Pin::new(&mut self.io).poll_close(cx)

Poll::Ready(match ready!(Pin::new(&mut self.io).poll_close(cx)) {
Ok(()) => Ok(()),
// When trying to shutdown, not being connected seems fine
Err(err) if err.kind() == io::ErrorKind::NotConnected => Ok(()),
Err(err) => Err(err),
})
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<io::Result<usize>> {
if bufs.iter().all(|buf| buf.is_empty()) {
return Poll::Ready(Ok(0));
}

loop {
let mut would_block = false;
let written = self.session.writer().write_vectored(bufs)?;

while self.session.wants_write() {
match self.write_io(cx) {
Poll::Ready(Ok(0)) | Poll::Pending => {
would_block = true;
break;
}
Poll::Ready(Ok(_)) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
}
}

return match (written, would_block) {
(0, true) => Poll::Pending,
(0, false) => continue,
(n, _) => Poll::Ready(Ok(n)),
};
}
}
}

Expand All @@ -371,5 +359,39 @@ impl<'a, 'b, T: AsyncRead + Unpin> Read for SyncReadAdapter<'a, 'b, T> {
}
}

pub(crate) struct SyncWriteAdapter<'a, 'b, T> {
pub(crate) io: &'a mut T,
pub(crate) cx: &'a mut Context<'b>,
}

impl<'a, 'b, T: Unpin> SyncWriteAdapter<'a, 'b, T> {
#[inline]
fn poll_with<U>(
&mut self,
f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>,
) -> io::Result<U> {
match f(Pin::new(&mut self.io), self.cx) {
Poll::Ready(result) => result,
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
}
}
}

impl<'a, 'b, T: AsyncWrite + Unpin> Write for SyncWriteAdapter<'a, 'b, T> {
#[inline]
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.poll_with(|io, cx| io.poll_write(cx, buf))
}

#[inline]
fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs))
}

fn flush(&mut self) -> io::Result<()> {
self.poll_with(|io, cx| io.poll_flush(cx))
}
}

#[cfg(test)]
mod test_stream;
Loading