From fad42acc79b54ce38adf99c58c894f29fa2665ad Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Wed, 23 Dec 2020 10:36:12 -0800 Subject: [PATCH] feat(lib): Upgrade to Tokio 1.0 (#2369) Closes #2370 --- Cargo.toml | 14 +++--- src/body/to_bytes.rs | 2 +- src/client/connect/http.rs | 5 +- src/client/dispatch.rs | 12 ++--- src/client/pool.rs | 3 +- src/common/buf.rs | 8 ++-- src/common/io/mod.rs | 1 - src/proto/h1/encode.rs | 22 ++++----- src/proto/h1/io.rs | 25 +++++----- src/proto/h2/mod.rs | 8 ++-- src/proto/h2/ping.rs | 10 ++-- src/server/tcp.rs | 8 ++-- src/upgrade.rs | 94 +++----------------------------------- tests/client.rs | 5 ++ 14 files changed, 71 insertions(+), 146 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c31c689422..9263cf31d6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,20 +22,20 @@ include = [ ] [dependencies] -bytes = "0.6" +bytes = "1" futures-core = { version = "0.3", default-features = false } futures-channel = "0.3" futures-util = { version = "0.3", default-features = false } http = "0.2" -http-body = { git = "https://github.com/hyperium/http-body" } +http-body = "0.4" httpdate = "0.3" httparse = "1.0" -h2 = { git = "https://github.com/hyperium/h2", optional = true } +h2 = { version = "0.3", optional = true } itoa = "0.4.1" tracing = { version = "0.1", default-features = false, features = ["std"] } pin-project = "1.0" tower-service = "0.3" -tokio = { version = "0.3.4", features = ["sync", "stream"] } +tokio = { version = "1", features = ["sync"] } want = "0.3" # Optional @@ -51,7 +51,7 @@ spmc = "0.3" serde = "1.0" serde_derive = "1.0" serde_json = "1.0" -tokio = { version = "0.3", features = [ +tokio = { version = "1", features = [ "fs", "macros", "io-std", @@ -62,8 +62,8 @@ tokio = { version = "0.3", features = [ "time", "test-util", ] } -tokio-test = "0.3" -tokio-util = { version = "0.5", features = ["codec"] } +tokio-test = "0.4" +tokio-util = { version = "0.6", features = ["codec"] } tower-util = "0.3" url = "1.0" diff --git a/src/body/to_bytes.rs b/src/body/to_bytes.rs index 8dfbe01cc3..7c0765f486 100644 --- a/src/body/to_bytes.rs +++ b/src/body/to_bytes.rs @@ -23,7 +23,7 @@ where let second = if let Some(buf) = body.data().await { buf? } else { - return Ok(first.copy_to_bytes(first.bytes().len())); + return Ok(first.copy_to_bytes(first.remaining())); }; // With more than 1 buf, we gotta flatten into a Vec first. diff --git a/src/client/connect/http.rs b/src/client/connect/http.rs index 5f05716814..3a9c7f708d 100644 --- a/src/client/connect/http.rs +++ b/src/client/connect/http.rs @@ -667,8 +667,11 @@ impl ConnectingTcp<'_> { let fallback_fut = fallback.remote.connect(self.config); futures_util::pin_mut!(fallback_fut); + let fallback_delay = fallback.delay; + futures_util::pin_mut!(fallback_delay); + let (result, future) = - match futures_util::future::select(preferred_fut, fallback.delay).await { + match futures_util::future::select(preferred_fut, fallback_delay).await { Either::Left((result, _fallback_delay)) => { (result, Either::Right(fallback_fut)) } diff --git a/src/client/dispatch.rs b/src/client/dispatch.rs index 08c1e01f42..a7e6311bad 100644 --- a/src/client/dispatch.rs +++ b/src/client/dispatch.rs @@ -1,7 +1,7 @@ #[cfg(feature = "http2")] use std::future::Future; -use tokio::stream::Stream; +use futures_util::FutureExt; use tokio::sync::{mpsc, oneshot}; use crate::common::{task, Pin, Poll}; @@ -150,8 +150,8 @@ impl Receiver { self: Pin<&mut Self>, cx: &mut task::Context<'_>, ) -> Poll)>> { - let this = self.project(); - match this.inner.poll_next(cx) { + let mut this = self.project(); + match this.inner.poll_recv(cx) { Poll::Ready(item) => { Poll::Ready(item.map(|mut env| env.0.take().expect("envelope not dropped"))) } @@ -170,9 +170,9 @@ impl Receiver { #[cfg(feature = "http1")] pub(crate) fn try_recv(&mut self) -> Option<(T, Callback)> { - match self.inner.try_recv() { - Ok(mut env) => env.0.take(), - Err(_) => None, + match self.inner.recv().now_or_never() { + Some(Some(mut env)) => env.0.take(), + _ => None, } } } diff --git a/src/client/pool.rs b/src/client/pool.rs index 777b4789b4..26be53544a 100644 --- a/src/client/pool.rs +++ b/src/client/pool.rs @@ -731,7 +731,6 @@ impl Future for IdleTask { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { - use tokio::stream::Stream; let mut this = self.project(); loop { match this.pool_drop_notifier.as_mut().poll(cx) { @@ -743,7 +742,7 @@ impl Future for IdleTask { } } - ready!(this.interval.as_mut().poll_next(cx)); + ready!(this.interval.as_mut().poll_tick(cx)); if let Some(inner) = this.pool.upgrade() { if let Ok(mut inner) = inner.lock() { diff --git a/src/common/buf.rs b/src/common/buf.rs index a882cc0f42..9c8feae617 100644 --- a/src/common/buf.rs +++ b/src/common/buf.rs @@ -34,8 +34,8 @@ impl Buf for BufList { } #[inline] - fn bytes(&self) -> &[u8] { - self.bufs.front().map(Buf::bytes).unwrap_or_default() + fn chunk(&self) -> &[u8] { + self.bufs.front().map(Buf::chunk).unwrap_or_default() } #[inline] @@ -57,13 +57,13 @@ impl Buf for BufList { } #[inline] - fn bytes_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { + fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { if dst.is_empty() { return 0; } let mut vecs = 0; for buf in &self.bufs { - vecs += buf.bytes_vectored(&mut dst[vecs..]); + vecs += buf.chunks_vectored(&mut dst[vecs..]); if vecs == dst.len() { break; } diff --git a/src/common/io/mod.rs b/src/common/io/mod.rs index 61dd038cc2..2e6d506153 100644 --- a/src/common/io/mod.rs +++ b/src/common/io/mod.rs @@ -1,4 +1,3 @@ mod rewind; pub(crate) use self::rewind::Rewind; -pub(crate) const MAX_WRITEV_BUFS: usize = 64; diff --git a/src/proto/h1/encode.rs b/src/proto/h1/encode.rs index 0f0ccca73e..c8ed99bbbd 100644 --- a/src/proto/h1/encode.rs +++ b/src/proto/h1/encode.rs @@ -229,12 +229,12 @@ where } #[inline] - fn bytes(&self) -> &[u8] { + fn chunk(&self) -> &[u8] { match self.kind { - BufKind::Exact(ref b) => b.bytes(), - BufKind::Limited(ref b) => b.bytes(), - BufKind::Chunked(ref b) => b.bytes(), - BufKind::ChunkedEnd(ref b) => b.bytes(), + BufKind::Exact(ref b) => b.chunk(), + BufKind::Limited(ref b) => b.chunk(), + BufKind::Chunked(ref b) => b.chunk(), + BufKind::ChunkedEnd(ref b) => b.chunk(), } } @@ -249,12 +249,12 @@ where } #[inline] - fn bytes_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { + fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { match self.kind { - BufKind::Exact(ref b) => b.bytes_vectored(dst), - BufKind::Limited(ref b) => b.bytes_vectored(dst), - BufKind::Chunked(ref b) => b.bytes_vectored(dst), - BufKind::ChunkedEnd(ref b) => b.bytes_vectored(dst), + BufKind::Exact(ref b) => b.chunks_vectored(dst), + BufKind::Limited(ref b) => b.chunks_vectored(dst), + BufKind::Chunked(ref b) => b.chunks_vectored(dst), + BufKind::ChunkedEnd(ref b) => b.chunks_vectored(dst), } } } @@ -295,7 +295,7 @@ impl Buf for ChunkSize { } #[inline] - fn bytes(&self) -> &[u8] { + fn chunk(&self) -> &[u8] { &self.bytes[self.pos.into()..self.len.into()] } diff --git a/src/proto/h1/io.rs b/src/proto/h1/io.rs index ed10374e18..b42fc81e3c 100644 --- a/src/proto/h1/io.rs +++ b/src/proto/h1/io.rs @@ -186,7 +186,7 @@ where self.read_buf.reserve(next); } - let dst = self.read_buf.bytes_mut(); + let dst = self.read_buf.chunk_mut(); let dst = unsafe { &mut *(dst as *mut _ as *mut [MaybeUninit]) }; let mut buf = ReadBuf::uninit(dst); match Pin::new(&mut self.io).poll_read(cx, &mut buf) { @@ -231,10 +231,11 @@ where return self.poll_flush_flattened(cx); } + const MAX_WRITEV_BUFS: usize = 64; loop { let n = { - let mut iovs = [IoSlice::new(&[]); crate::common::io::MAX_WRITEV_BUFS]; - let len = self.write_buf.bytes_vectored(&mut iovs); + let mut iovs = [IoSlice::new(&[]); MAX_WRITEV_BUFS]; + let len = self.write_buf.chunks_vectored(&mut iovs); ready!(Pin::new(&mut self.io).poll_write_vectored(cx, &iovs[..len]))? }; // TODO(eliza): we have to do this manually because @@ -262,7 +263,7 @@ where /// that skips some bookkeeping around using multiple buffers. fn poll_flush_flattened(&mut self, cx: &mut task::Context<'_>) -> Poll> { loop { - let n = ready!(Pin::new(&mut self.io).poll_write(cx, self.write_buf.headers.bytes()))?; + let n = ready!(Pin::new(&mut self.io).poll_write(cx, self.write_buf.headers.chunk()))?; debug!("flushed {} bytes", n); self.write_buf.headers.advance(n); if self.write_buf.headers.remaining() == 0 { @@ -433,7 +434,7 @@ impl> Buf for Cursor { } #[inline] - fn bytes(&self) -> &[u8] { + fn chunk(&self) -> &[u8] { &self.bytes.as_ref()[self.pos..] } @@ -487,7 +488,7 @@ where //but accomplishes the same result. loop { let adv = { - let slice = buf.bytes(); + let slice = buf.chunk(); if slice.is_empty() { return; } @@ -534,12 +535,12 @@ impl Buf for WriteBuf { } #[inline] - fn bytes(&self) -> &[u8] { - let headers = self.headers.bytes(); + fn chunk(&self) -> &[u8] { + let headers = self.headers.chunk(); if !headers.is_empty() { headers } else { - self.queue.bytes() + self.queue.chunk() } } @@ -559,9 +560,9 @@ impl Buf for WriteBuf { } #[inline] - fn bytes_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { - let n = self.headers.bytes_vectored(dst); - self.queue.bytes_vectored(&mut dst[n..]) + n + fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize { + let n = self.headers.chunks_vectored(dst); + self.queue.chunks_vectored(&mut dst[n..]) + n } } diff --git a/src/proto/h2/mod.rs b/src/proto/h2/mod.rs index 38b1ab350d..cf06592903 100644 --- a/src/proto/h2/mod.rs +++ b/src/proto/h2/mod.rs @@ -257,8 +257,8 @@ impl Buf for SendBuf { } #[inline] - fn bytes(&self) -> &[u8] { - self.0.as_ref().map(|b| b.bytes()).unwrap_or(&[]) + fn chunk(&self) -> &[u8] { + self.0.as_ref().map(|b| b.chunk()).unwrap_or(&[]) } #[inline] @@ -268,7 +268,7 @@ impl Buf for SendBuf { } } - fn bytes_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize { - self.0.as_ref().map(|b| b.bytes_vectored(dst)).unwrap_or(0) + fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize { + self.0.as_ref().map(|b| b.chunks_vectored(dst)).unwrap_or(0) } } diff --git a/src/proto/h2/ping.rs b/src/proto/h2/ping.rs index 7cbd23ed21..105fc69a39 100644 --- a/src/proto/h2/ping.rs +++ b/src/proto/h2/ping.rs @@ -60,7 +60,7 @@ pub(super) fn channel(ping_pong: PingPong, config: Config) -> (Recorder, Ponger) interval, timeout: config.keep_alive_timeout, while_idle: config.keep_alive_while_idle, - timer: tokio::time::sleep(interval), + timer: Box::pin(tokio::time::sleep(interval)), state: KeepAliveState::Init, }); @@ -156,7 +156,7 @@ struct KeepAlive { while_idle: bool, state: KeepAliveState, - timer: Sleep, + timer: Pin>, } #[cfg(feature = "runtime")] @@ -441,7 +441,7 @@ impl KeepAlive { self.state = KeepAliveState::Scheduled; let interval = shared.last_read_at() + self.interval; - self.timer.reset(interval); + self.timer.as_mut().reset(interval); } KeepAliveState::PingSent => { if shared.is_ping_sent() { @@ -450,7 +450,7 @@ impl KeepAlive { self.state = KeepAliveState::Scheduled; let interval = shared.last_read_at() + self.interval; - self.timer.reset(interval); + self.timer.as_mut().reset(interval); } KeepAliveState::Scheduled => (), } @@ -472,7 +472,7 @@ impl KeepAlive { shared.send_ping(); self.state = KeepAliveState::PingSent; let timeout = Instant::now() + self.timeout; - self.timer.reset(timeout); + self.timer.as_mut().reset(timeout); } KeepAliveState::Init | KeepAliveState::PingSent => (), } diff --git a/src/server/tcp.rs b/src/server/tcp.rs index 4111573671..52d68e62b4 100644 --- a/src/server/tcp.rs +++ b/src/server/tcp.rs @@ -19,7 +19,7 @@ pub struct AddrIncoming { sleep_on_errors: bool, tcp_keepalive_timeout: Option, tcp_nodelay: bool, - timeout: Option, + timeout: Option>>, } impl AddrIncoming { @@ -160,9 +160,9 @@ impl AddrIncoming { error!("accept error: {}", e); // Sleep 1s. - let mut timeout = tokio::time::sleep(Duration::from_secs(1)); + let mut timeout = Box::pin(tokio::time::sleep(Duration::from_secs(1))); - match Pin::new(&mut timeout).poll(cx) { + match timeout.as_mut().poll(cx) { Poll::Ready(()) => { // Wow, it's been a second already? Ok then... continue; @@ -263,7 +263,7 @@ mod addr_stream { pub fn poll_peek( &mut self, cx: &mut task::Context<'_>, - buf: &mut [u8], + buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll> { self.inner.poll_peek(cx, buf) } diff --git a/src/upgrade.rs b/src/upgrade.rs index 46ce37fcf8..a981b912ee 100644 --- a/src/upgrade.rs +++ b/src/upgrade.rs @@ -11,7 +11,7 @@ use std::fmt; use std::io; use std::marker::Unpin; -use bytes::{Buf, Bytes}; +use bytes::Bytes; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::sync::oneshot; @@ -82,7 +82,7 @@ impl Upgraded { T: AsyncRead + AsyncWrite + Unpin + Send + 'static, { Upgraded { - io: Rewind::new_buffered(Box::new(ForwardsWriteBuf(io)), read_buf), + io: Rewind::new_buffered(Box::new(io), read_buf), } } @@ -92,9 +92,9 @@ impl Upgraded { /// `Upgraded` back. pub fn downcast(self) -> Result, Self> { let (io, buf) = self.io.into_inner(); - match io.__hyper_downcast::>() { + match io.__hyper_downcast() { Ok(t) => Ok(Parts { - io: t.0, + io: *t, read_buf: buf, _inner: (), }), @@ -221,20 +221,14 @@ impl StdError for UpgradeExpected {} // ===== impl Io ===== -struct ForwardsWriteBuf(T); - pub(crate) trait Io: AsyncRead + AsyncWrite + Unpin + 'static { - fn poll_write_dyn_buf( - &mut self, - cx: &mut task::Context<'_>, - buf: &mut dyn Buf, - ) -> Poll>; - fn __hyper_type_id(&self) -> TypeId { TypeId::of::() } } +impl Io for T {} + impl dyn Io + Send { fn __hyper_is(&self) -> bool { let t = TypeId::of::(); @@ -254,61 +248,6 @@ impl dyn Io + Send { } } -impl AsyncRead for ForwardsWriteBuf { - fn poll_read( - mut self: Pin<&mut Self>, - cx: &mut task::Context<'_>, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - Pin::new(&mut self.0).poll_read(cx, buf) - } -} - -impl AsyncWrite for ForwardsWriteBuf { - fn poll_write( - mut self: Pin<&mut Self>, - cx: &mut task::Context<'_>, - buf: &[u8], - ) -> Poll> { - Pin::new(&mut self.0).poll_write(cx, buf) - } - - fn poll_write_vectored( - mut self: Pin<&mut Self>, - cx: &mut task::Context<'_>, - bufs: &[io::IoSlice<'_>], - ) -> Poll> { - Pin::new(&mut self.0).poll_write_vectored(cx, bufs) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_flush(cx) - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll> { - Pin::new(&mut self.0).poll_shutdown(cx) - } - - fn is_write_vectored(&self) -> bool { - self.0.is_write_vectored() - } -} - -impl Io for ForwardsWriteBuf { - fn poll_write_dyn_buf( - &mut self, - cx: &mut task::Context<'_>, - buf: &mut dyn Buf, - ) -> Poll> { - if self.0.is_write_vectored() { - let mut bufs = [io::IoSlice::new(&[]); crate::common::io::MAX_WRITEV_BUFS]; - let cnt = buf.bytes_vectored(&mut bufs); - return Pin::new(&mut self.0).poll_write_vectored(cx, &bufs[..cnt]); - } - Pin::new(&mut self.0).poll_write(cx, buf.bytes()) - } -} - mod sealed { use super::OnUpgrade; @@ -352,7 +291,6 @@ mod sealed { #[cfg(test)] mod tests { use super::*; - use tokio::io::AsyncWriteExt; #[test] fn upgraded_downcast() { @@ -363,15 +301,6 @@ mod tests { upgraded.downcast::().unwrap(); } - #[tokio::test] - async fn upgraded_forwards_write_buf() { - // sanity check that the underlying IO implements write_buf - Mock.write_buf(&mut "hello".as_bytes()).await.unwrap(); - - let mut upgraded = Upgraded::new(Mock, Bytes::new()); - upgraded.write_buf(&mut "hello".as_bytes()).await.unwrap(); - } - // TODO: replace with tokio_test::io when it can test write_buf struct Mock; @@ -395,17 +324,6 @@ mod tests { Poll::Ready(Ok(buf.len())) } - // TODO(eliza): :( - // fn poll_write_buf( - // self: Pin<&mut Self>, - // _cx: &mut task::Context<'_>, - // buf: &mut B, - // ) -> Poll> { - // let n = buf.remaining(); - // buf.advance(n); - // Poll::Ready(Ok(n)) - // } - fn poll_flush(self: Pin<&mut Self>, _cx: &mut task::Context<'_>) -> Poll> { unreachable!("Mock::poll_flush") } diff --git a/tests/client.rs b/tests/client.rs index d5092e3584..409a1622c0 100644 --- a/tests/client.rs +++ b/tests/client.rs @@ -1209,6 +1209,7 @@ mod dispatch_impl { // and wait a few ticks for the connections to close let t = tokio::time::sleep(Duration::from_millis(100)).map(|_| panic!("time out")); + futures_util::pin_mut!(t); let close = closes.into_future().map(|(opt, _)| opt.expect("closes")); future::select(t, close).await; } @@ -1257,6 +1258,7 @@ mod dispatch_impl { // res now dropped let t = tokio::time::sleep(Duration::from_millis(100)).map(|_| panic!("time out")); + futures_util::pin_mut!(t); let close = closes.into_future().map(|(opt, _)| opt.expect("closes")); future::select(t, close).await; } @@ -1312,6 +1314,7 @@ mod dispatch_impl { // and wait a few ticks to see the connection drop let t = tokio::time::sleep(Duration::from_millis(100)).map(|_| panic!("time out")); + futures_util::pin_mut!(t); let close = closes.into_future().map(|(opt, _)| opt.expect("closes")); future::select(t, close).await; } @@ -1362,6 +1365,7 @@ mod dispatch_impl { res.unwrap(); let t = tokio::time::sleep(Duration::from_millis(100)).map(|_| panic!("time out")); + futures_util::pin_mut!(t); let close = closes.into_future().map(|(opt, _)| opt.expect("closes")); future::select(t, close).await; } @@ -1408,6 +1412,7 @@ mod dispatch_impl { res.unwrap(); let t = tokio::time::sleep(Duration::from_millis(100)).map(|_| panic!("time out")); + futures_util::pin_mut!(t); let close = closes.into_future().map(|(opt, _)| opt.expect("closes")); future::select(t, close).await; }