Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 27 additions & 21 deletions src/proxy/h2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
// limitations under the License.

use crate::copy;
use bytes::{BufMut, Bytes};
use bytes::Bytes;
use futures_core::ready;
use h2::Reason;
use std::io::Error;
Expand Down Expand Up @@ -85,7 +85,10 @@ pub struct H2StreamWriteHalf {
_dropped: Option<DropCounter>,
}

pub struct TokioH2Stream(H2Stream);
pub struct TokioH2Stream {
stream: H2Stream,
buf: Bytes,
}

struct DropCounter {
// Whether the other end of this shared counter has already dropped.
Expand Down Expand Up @@ -144,7 +147,10 @@ impl Drop for DropCounter {
// then the specific implementation will conflict with the generic one.
impl TokioH2Stream {
pub fn new(stream: H2Stream) -> Self {
Self(stream)
Self {
stream,
buf: Bytes::new(),
}
}
}

Expand All @@ -154,21 +160,21 @@ impl tokio::io::AsyncRead for TokioH2Stream {
cx: &mut Context<'_>,
buf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let pinned = std::pin::Pin::new(&mut self.0.read);
copy::ResizeBufRead::poll_bytes(pinned, cx).map(|r| match r {
Ok(bytes) => {
if buf.remaining() < bytes.len() {
Err(Error::other(format!(
"kould overflow buffer of with {} remaining",
buf.remaining()
)))
} else {
buf.put(bytes);
Ok(())
}
}
Err(e) => Err(e),
})
// Just return the bytes we have left over and don't poll the stream because
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be simplified (but give equivilent) like so:

        let this = &mut self;
        if this.buf.is_empty() {
            let res = ready!(Pin::new(&mut this.stream.read).poll_bytes(cx))?;
            this.buf = res;
        }
        let cnt = std::cmp::min(this.buf.len(), read_buf.remaining());
        read_buf.put_slice(&this.buf[..cnt]);
        this.buf = this.buf.split_off(cnt);
        Poll::Ready(Ok(()))

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mostly agree. What's the point of let this = &mut self;

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not needed and was just copy and pasted from others and not cleaned up. Its common to need to do things like that for projections

// its unclear what to do if there are bytes left over from the previous read, and when we
// poll, we get an error.
if self.buf.is_empty() {
// If we have no unread bytes, we can poll the stream
// and fill self.buf with the bytes we read.
let pinned = std::pin::Pin::new(&mut self.stream.read);
let res = ready!(copy::ResizeBufRead::poll_bytes(pinned, cx))?;
self.buf = res;
}
// Copy as many bytes as we can from self.buf.
let cnt = Ord::min(buf.remaining(), self.buf.len());
buf.put_slice(&self.buf[..cnt]);
self.buf = self.buf.split_off(cnt);
Poll::Ready(Ok(()))
}
}

Expand All @@ -178,7 +184,7 @@ impl tokio::io::AsyncWrite for TokioH2Stream {
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, tokio::io::Error>> {
let pinned = std::pin::Pin::new(&mut self.0.write);
let pinned = std::pin::Pin::new(&mut self.stream.write);
let buf = Bytes::copy_from_slice(buf);
copy::AsyncWriteBuf::poll_write_buf(pinned, cx, buf)
}
Expand All @@ -187,15 +193,15 @@ impl tokio::io::AsyncWrite for TokioH2Stream {
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let pinned = std::pin::Pin::new(&mut self.0.write);
let pinned = std::pin::Pin::new(&mut self.stream.write);
copy::AsyncWriteBuf::poll_flush(pinned, cx)
}

fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let pinned = std::pin::Pin::new(&mut self.0.write);
let pinned = std::pin::Pin::new(&mut self.stream.write);
copy::AsyncWriteBuf::poll_shutdown(pinned, cx)
}
}
Expand Down
34 changes: 25 additions & 9 deletions src/proxy/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,9 +594,10 @@ mod test {
}

/// This is really a test for TokioH2Stream, but its nicer here because we have access to
/// streams
/// streams.
/// Most important, we make sure there are no panics.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn small_reads() {
async fn read_buffering() {
let (mut pool, srv) = setup_test(3).await;

let key = key(&srv, 2);
Expand All @@ -612,13 +613,28 @@ mod test {
let c = pool.send_request_pooled(&key.clone(), req()).await.unwrap();
let mut c = TokioH2Stream::new(c);
c.write_all(b"abcde").await.unwrap();
let mut b = [0u8; 0];
// Crucially, this should error rather than panic.
if let Err(e) = c.read(&mut b).await {
assert_eq!(e.kind(), io::ErrorKind::Other);
} else {
panic!("Should have errored");
}
let mut b = [0u8; 100];
// Properly buffer reads and don't error
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WeirdIO may be useful here

assert_eq!(c.read(&mut b).await.unwrap(), 8);
assert_eq!(&b[..8], b"poolsrv\n"); // this is added by itself
assert_eq!(c.read(&mut b[..1]).await.unwrap(), 1);
assert_eq!(&b[..1], b"a");
assert_eq!(c.read(&mut b[..1]).await.unwrap(), 1);
assert_eq!(&b[..1], b"b");
assert_eq!(c.read(&mut b[..1]).await.unwrap(), 1);
assert_eq!(&b[..1], b"c");
assert_eq!(c.read(&mut b).await.unwrap(), 2); // there are only two bytes left
assert_eq!(&b[..2], b"de");

// Once we drop the pool, we should still retained the buffered data,
// but then we should error.
c.write_all(b"abcde").await.unwrap();
assert_eq!(c.read(&mut b[..3]).await.unwrap(), 3);
assert_eq!(&b[..3], b"abc");
drop(pool);
assert_eq!(c.read(&mut b[..2]).await.unwrap(), 2);
assert_eq!(&b[..2], b"de");
assert!(c.read(&mut b).await.is_err());
}

#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
Expand Down