Skip to content

Commit c6a68b4

Browse files
authored
Merge pull request #37 from tmccombs/timeout-errors
feat!: Add a new error type for handshake timeouts
2 parents 8eed9e9 + 23ca7ff commit c6a68b4

13 files changed

+226
-60
lines changed

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ hyper-h2 = ["hyper", "hyper/http2"]
2121
[dependencies]
2222
futures-util = "0.3.8"
2323
hyper = { version = "0.14.1", features = ["server", "tcp"], optional = true }
24-
pin-project-lite = "0.2.8"
24+
pin-project-lite = "0.2.13"
2525
thiserror = "1.0.30"
2626
tokio = { version = "1.0", features = ["time"] }
2727
tokio-native-tls = { version = "0.3.0", optional = true }

examples/echo-threads.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ mod tls_config;
1313
use tls_config::tls_acceptor;
1414

1515
#[inline]
16-
async fn handle_stream(stream: TlsStream<TcpStream>) {
16+
async fn handle_stream(stream: TlsStream<TcpStream>, _remote_addr: SocketAddr) {
1717
let (mut reader, mut writer) = split(stream);
1818
match copy(&mut reader, &mut writer).await {
1919
Ok(cnt) => eprintln!("Processed {} bytes", cnt),
@@ -32,8 +32,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
3232
TlsListener::new(SpawningHandshakes(tls_acceptor()), listener)
3333
.for_each_concurrent(None, |s| async {
3434
match s {
35-
Ok(stream) => {
36-
handle_stream(stream).await;
35+
Ok((stream, remote_addr)) => {
36+
handle_stream(stream, remote_addr).await;
3737
}
3838
Err(e) => {
3939
eprintln!("Error: {:?}", e);

examples/echo.rs

+7-3
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ mod tls_config;
2222
use tls_config::tls_acceptor;
2323

2424
#[inline]
25-
async fn handle_stream(stream: TlsStream<TcpStream>) {
25+
async fn handle_stream(stream: TlsStream<TcpStream>, _remote_addr: SocketAddr) {
2626
let (mut reader, mut writer) = split(stream);
2727
match copy(&mut reader, &mut writer).await {
2828
Ok(cnt) => eprintln!("Processed {} bytes", cnt),
@@ -41,10 +41,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
4141
TlsListener::new(tls_acceptor(), listener)
4242
.for_each_concurrent(None, |s| async {
4343
match s {
44-
Ok(stream) => {
45-
handle_stream(stream).await;
44+
Ok((stream, remote_addr)) => {
45+
handle_stream(stream, remote_addr).await;
4646
}
4747
Err(e) => {
48+
if let Some(remote_addr) = e.peer_addr() {
49+
eprint!("[client {remote_addr}] ");
50+
}
51+
4852
eprintln!("Error accepting connection: {:?}", e);
4953
}
5054
}

examples/http-change-certificate.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,22 @@ async fn main() {
3030
tokio::select! {
3131
conn = listener.accept() => {
3232
match conn.expect("Tls listener stream should be infinite") {
33-
Ok(conn) => {
33+
Ok((conn, remote_addr)) => {
3434
let http = http.clone();
3535
let tx = tx.clone();
3636
let counter = counter.clone();
3737
tokio::spawn(async move {
3838
let svc = service_fn(move |request| handle_request(tx.clone(), counter.clone(), request));
3939
if let Err(err) = http.serve_connection(conn, svc).await {
40-
eprintln!("Application error: {}", err);
40+
eprintln!("Application error (client address: {remote_addr}): {err}");
4141
}
4242
});
4343
},
4444
Err(e) => {
45+
if let Some(remote_addr) = e.peer_addr() {
46+
eprint!("[client {remote_addr}] ");
47+
}
48+
4549
eprintln!("Bad connection: {}", e);
4650
}
4751
}

examples/http-low-level.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,19 @@ async fn main() {
2727
listener
2828
.for_each(|r| async {
2929
match r {
30-
Ok(conn) => {
30+
Ok((conn, remote_addr)) => {
3131
let http = http.clone();
3232
tokio::spawn(async move {
3333
if let Err(err) = http.serve_connection(conn, svc).await {
34-
eprintln!("Application error: {}", err);
34+
eprintln!("[client {remote_addr}] Application error: {}", err);
3535
}
3636
});
3737
}
3838
Err(err) => {
39+
if let Some(remote_addr) = err.peer_addr() {
40+
eprint!("[client {remote_addr}] ");
41+
}
42+
3943
eprintln!("Error accepting connection: {}", err);
4044
}
4145
}

examples/http-stream.rs

+10-8
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
2222
});
2323

2424
// This uses a filter to handle errors with connecting
25-
let incoming = TlsListener::new(tls_acceptor(), AddrIncoming::bind(&addr)?).filter(|conn| {
26-
if let Err(err) = conn {
27-
eprintln!("Error: {:?}", err);
28-
ready(false)
29-
} else {
30-
ready(true)
31-
}
32-
});
25+
let incoming = TlsListener::new(tls_acceptor(), AddrIncoming::bind(&addr)?)
26+
.connections()
27+
.filter(|conn| {
28+
if let Err(err) = conn {
29+
eprintln!("Error: {:?}", err);
30+
ready(false)
31+
} else {
32+
ready(true)
33+
}
34+
});
3335

3436
let server = Server::builder(accept::from_stream(incoming)).serve(new_svc);
3537
server.await?;

examples/tls_config/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ mod config {
55

66
const CERT: &[u8] = include_bytes!("local.cert");
77
const PKEY: &[u8] = include_bytes!("local.key");
8+
#[allow(dead_code)]
89
const CERT2: &[u8] = include_bytes!("local2.cert");
10+
#[allow(dead_code)]
911
const PKEY2: &[u8] = include_bytes!("local2.key");
1012

1113
pub type Acceptor = tokio_rustls::TlsAcceptor;
@@ -27,6 +29,7 @@ mod config {
2729
tls_acceptor_impl(PKEY, CERT)
2830
}
2931

32+
#[allow(dead_code)]
3033
pub fn tls_acceptor2() -> Acceptor {
3134
tls_acceptor_impl(PKEY2, CERT2)
3235
}

src/hyper.rs

+21-6
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,16 @@ use std::ops::{Deref, DerefMut};
77
impl AsyncAccept for AddrIncoming {
88
type Connection = AddrStream;
99
type Error = std::io::Error;
10+
type Address = std::net::SocketAddr;
1011

1112
fn poll_accept(
1213
self: Pin<&mut Self>,
1314
cx: &mut Context<'_>,
14-
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
15-
<AddrIncoming as HyperAccept>::poll_accept(self, cx)
15+
) -> Poll<Option<Result<(Self::Connection, Self::Address), Self::Error>>> {
16+
<AddrIncoming as HyperAccept>::poll_accept(self, cx).map_ok(|conn| {
17+
let peer_addr = conn.remote_addr();
18+
(conn, peer_addr)
19+
})
1620
}
1721
}
1822

@@ -22,6 +26,11 @@ pin_project! {
2226
/// Unfortunately, it isn't possible to use a blanket impl, due to coherence rules.
2327
/// At least until [RFC 1210](https://rust-lang.github.io/rfcs/1210-impl-specialization.html)
2428
/// (specialization) is stabilized.
29+
///
30+
/// Note that, because `hyper::server::accept::Accept` does not expose the
31+
/// remote address, the implementation of `AsyncAccept` for `WrappedAccept`
32+
/// doesn't expose it either. That is, [`AsyncAccept::Address`] is `()` in
33+
/// this case.
2534
//#[cfg_attr(docsrs, doc(cfg(any(feature = "hyper-h1", feature = "hyper-h2"))))]
2635
pub struct WrappedAccept<A> {
2736
// sadly, pin-project-lite doesn't suport tuple structs :(
@@ -43,15 +52,20 @@ pub fn wrap<A: HyperAccept>(acceptor: A) -> WrappedAccept<A> {
4352
impl<A: HyperAccept> AsyncAccept for WrappedAccept<A>
4453
where
4554
A::Conn: AsyncRead + AsyncWrite,
55+
A::Error: std::error::Error,
4656
{
4757
type Connection = A::Conn;
4858
type Error = A::Error;
59+
type Address = ();
4960

5061
fn poll_accept(
5162
self: Pin<&mut Self>,
5263
cx: &mut Context<'_>,
53-
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
54-
self.project().inner.poll_accept(cx)
64+
) -> Poll<Option<Result<(Self::Connection, ()), Self::Error>>> {
65+
self.project()
66+
.inner
67+
.poll_accept(cx)
68+
.map_ok(|conn| (conn, ()))
5569
}
5670
}
5771

@@ -78,6 +92,7 @@ impl<A: HyperAccept> WrappedAccept<A> {
7892
impl<A: HyperAccept, T> TlsListener<WrappedAccept<A>, T>
7993
where
8094
A::Conn: AsyncWrite + AsyncRead,
95+
A::Error: std::error::Error,
8196
T: AsyncTls<A::Conn>,
8297
{
8398
/// Create a `TlsListener` from a hyper [`Accept`](::hyper::server::accept::Accept) and tls
@@ -95,12 +110,12 @@ where
95110
T: AsyncTls<A::Connection>,
96111
{
97112
type Conn = T::Stream;
98-
type Error = Error<A::Error, T::Error>;
113+
type Error = Error<A::Error, T::Error, A::Address>;
99114

100115
fn poll_accept(
101116
self: Pin<&mut Self>,
102117
cx: &mut Context<'_>,
103118
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
104-
self.poll_next(cx)
119+
self.poll_next(cx).map_ok(|(conn, _)| conn)
105120
}
106121
}

0 commit comments

Comments
 (0)