diff --git a/src/proxy/h2.rs b/src/proxy/h2.rs index 63a87a4695..6a82d437cb 100644 --- a/src/proxy/h2.rs +++ b/src/proxy/h2.rs @@ -13,7 +13,7 @@ // limitations under the License. use crate::copy; -use bytes::{BufMut, Bytes}; +use bytes::Bytes; use futures_core::ready; use h2::Reason; use std::io::Error; @@ -85,7 +85,10 @@ pub struct H2StreamWriteHalf { _dropped: Option, } -pub struct TokioH2Stream(H2Stream); +pub struct TokioH2Stream { + stream: H2Stream, + buf: Bytes, +} struct DropCounter { // Whether the other end of this shared counter has already dropped. @@ -144,7 +147,10 @@ impl Drop for DropCounter { // then the specific implementation will conflict with the generic one. impl TokioH2Stream { pub fn new(stream: H2Stream) -> Self { - Self(stream) + Self { + stream, + buf: Bytes::new(), + } } } @@ -154,21 +160,21 @@ impl tokio::io::AsyncRead for TokioH2Stream { cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll> { - let pinned = std::pin::Pin::new(&mut self.0.read); - copy::ResizeBufRead::poll_bytes(pinned, cx).map(|r| match r { - Ok(bytes) => { - if buf.remaining() < bytes.len() { - Err(Error::other(format!( - "kould overflow buffer of with {} remaining", - buf.remaining() - ))) - } else { - buf.put(bytes); - Ok(()) - } - } - Err(e) => Err(e), - }) + // Just return the bytes we have left over and don't poll the stream because + // its unclear what to do if there are bytes left over from the previous read, and when we + // poll, we get an error. + if self.buf.is_empty() { + // If we have no unread bytes, we can poll the stream + // and fill self.buf with the bytes we read. + let pinned = std::pin::Pin::new(&mut self.stream.read); + let res = ready!(copy::ResizeBufRead::poll_bytes(pinned, cx))?; + self.buf = res; + } + // Copy as many bytes as we can from self.buf. + let cnt = Ord::min(buf.remaining(), self.buf.len()); + buf.put_slice(&self.buf[..cnt]); + self.buf = self.buf.split_off(cnt); + Poll::Ready(Ok(())) } } @@ -178,7 +184,7 @@ impl tokio::io::AsyncWrite for TokioH2Stream { cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - let pinned = std::pin::Pin::new(&mut self.0.write); + let pinned = std::pin::Pin::new(&mut self.stream.write); let buf = Bytes::copy_from_slice(buf); copy::AsyncWriteBuf::poll_write_buf(pinned, cx, buf) } @@ -187,7 +193,7 @@ impl tokio::io::AsyncWrite for TokioH2Stream { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - let pinned = std::pin::Pin::new(&mut self.0.write); + let pinned = std::pin::Pin::new(&mut self.stream.write); copy::AsyncWriteBuf::poll_flush(pinned, cx) } @@ -195,7 +201,7 @@ impl tokio::io::AsyncWrite for TokioH2Stream { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { - let pinned = std::pin::Pin::new(&mut self.0.write); + let pinned = std::pin::Pin::new(&mut self.stream.write); copy::AsyncWriteBuf::poll_shutdown(pinned, cx) } } diff --git a/src/proxy/pool.rs b/src/proxy/pool.rs index 4b55b9dd6a..d19ef0b419 100644 --- a/src/proxy/pool.rs +++ b/src/proxy/pool.rs @@ -594,9 +594,10 @@ mod test { } /// This is really a test for TokioH2Stream, but its nicer here because we have access to - /// streams + /// streams. + /// Most important, we make sure there are no panics. #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn small_reads() { + async fn read_buffering() { let (mut pool, srv) = setup_test(3).await; let key = key(&srv, 2); @@ -612,13 +613,28 @@ mod test { let c = pool.send_request_pooled(&key.clone(), req()).await.unwrap(); let mut c = TokioH2Stream::new(c); c.write_all(b"abcde").await.unwrap(); - let mut b = [0u8; 0]; - // Crucially, this should error rather than panic. - if let Err(e) = c.read(&mut b).await { - assert_eq!(e.kind(), io::ErrorKind::Other); - } else { - panic!("Should have errored"); - } + let mut b = [0u8; 100]; + // Properly buffer reads and don't error + assert_eq!(c.read(&mut b).await.unwrap(), 8); + assert_eq!(&b[..8], b"poolsrv\n"); // this is added by itself + assert_eq!(c.read(&mut b[..1]).await.unwrap(), 1); + assert_eq!(&b[..1], b"a"); + assert_eq!(c.read(&mut b[..1]).await.unwrap(), 1); + assert_eq!(&b[..1], b"b"); + assert_eq!(c.read(&mut b[..1]).await.unwrap(), 1); + assert_eq!(&b[..1], b"c"); + assert_eq!(c.read(&mut b).await.unwrap(), 2); // there are only two bytes left + assert_eq!(&b[..2], b"de"); + + // Once we drop the pool, we should still retained the buffered data, + // but then we should error. + c.write_all(b"abcde").await.unwrap(); + assert_eq!(c.read(&mut b[..3]).await.unwrap(), 3); + assert_eq!(&b[..3], b"abc"); + drop(pool); + assert_eq!(c.read(&mut b[..2]).await.unwrap(), 2); + assert_eq!(&b[..2], b"de"); + assert!(c.read(&mut b).await.is_err()); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)]