Skip to content

Commit

Permalink
feat(rt): replace IO traits with hyper::rt ones (hyperium#3230)
Browse files Browse the repository at this point in the history
This replaces the usage of `tokio::io::{AsyncRead, AsyncWrite}` in hyper's public API with new traits in the `hyper::rt` module.

Closes hyperium#3110

BREAKING CHANGE: Any IO transport type provided must not implement `hyper::rt::{Read, Write}` instead of
  `tokio::io` traits. You can grab a helper type from `hyper-util` to wrap Tokio types, or implement the traits yourself,
  if it's a custom type.

Signed-off-by: Sven Pfennig <[email protected]>
  • Loading branch information
seanmonstar authored and 0xE282B0 committed Jan 16, 2024
1 parent 068116c commit b6f1200
Show file tree
Hide file tree
Showing 43 changed files with 1,015 additions and 292 deletions.
14 changes: 8 additions & 6 deletions benches/end_to_end.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
extern crate test;
mod support;

// TODO: Reimplement Opts::bench using hyper::server::conn and hyper::client::conn
// (instead of Server and HttpClient).
// TODO: Reimplement parallel for HTTP/1

use std::convert::Infallible;
use std::net::SocketAddr;
Expand Down Expand Up @@ -315,7 +314,8 @@ impl Opts {

let mut client = rt.block_on(async {
if self.http2 {
let io = tokio::net::TcpStream::connect(&addr).await.unwrap();
let tcp = tokio::net::TcpStream::connect(&addr).await.unwrap();
let io = support::TokioIo::new(tcp);
let (tx, conn) = hyper::client::conn::http2::Builder::new(support::TokioExecutor)
.initial_stream_window_size(self.http2_stream_window)
.initial_connection_window_size(self.http2_conn_window)
Expand All @@ -328,7 +328,8 @@ impl Opts {
} else if self.parallel_cnt > 1 {
todo!("http/1 parallel >1");
} else {
let io = tokio::net::TcpStream::connect(&addr).await.unwrap();
let tcp = tokio::net::TcpStream::connect(&addr).await.unwrap();
let io = support::TokioIo::new(tcp);
let (tx, conn) = hyper::client::conn::http1::Builder::new()
.handshake(io)
.await
Expand Down Expand Up @@ -414,14 +415,15 @@ fn spawn_server(rt: &tokio::runtime::Runtime, opts: &Opts) -> SocketAddr {
let opts = opts.clone();
rt.spawn(async move {
while let Ok((sock, _)) = listener.accept().await {
let io = support::TokioIo::new(sock);
if opts.http2 {
tokio::spawn(
hyper::server::conn::http2::Builder::new(support::TokioExecutor)
.initial_stream_window_size(opts.http2_stream_window)
.initial_connection_window_size(opts.http2_conn_window)
.adaptive_window(opts.http2_adaptive_window)
.serve_connection(
sock,
io,
service_fn(move |req: Request<hyper::body::Incoming>| async move {
let mut req_body = req.into_body();
while let Some(_chunk) = req_body.frame().await {}
Expand All @@ -433,7 +435,7 @@ fn spawn_server(rt: &tokio::runtime::Runtime, opts: &Opts) -> SocketAddr {
);
} else {
tokio::spawn(hyper::server::conn::http1::Builder::new().serve_connection(
sock,
io,
service_fn(move |req: Request<hyper::body::Incoming>| async move {
let mut req_body = req.into_body();
while let Some(_chunk) = req_body.frame().await {}
Expand Down
5 changes: 4 additions & 1 deletion benches/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

extern crate test;

mod support;

use std::convert::Infallible;
use std::io::{Read, Write};
use std::net::{SocketAddr, TcpStream};
Expand Down Expand Up @@ -40,11 +42,12 @@ fn hello_world_16(b: &mut test::Bencher) {
rt.spawn(async move {
loop {
let (stream, _addr) = listener.accept().await.expect("accept");
let io = support::TokioIo::new(stream);

http1::Builder::new()
.pipeline_flush(true)
.serve_connection(
stream,
io,
service_fn(|_| async {
Ok::<_, Infallible>(Response::new(Full::new(Bytes::from(
"Hello, World!",
Expand Down
5 changes: 4 additions & 1 deletion benches/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

extern crate test;

mod support;

use std::io::{Read, Write};
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::sync::mpsc;
Expand Down Expand Up @@ -38,10 +40,11 @@ macro_rules! bench_server {
rt.spawn(async move {
loop {
let (stream, _) = listener.accept().await.expect("accept");
let io = support::TokioIo::new(stream);

http1::Builder::new()
.serve_connection(
stream,
io,
service_fn(|_| async {
Ok::<_, hyper::Error>(
Response::builder()
Expand Down
2 changes: 1 addition & 1 deletion benches/support/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
mod tokiort;
pub use tokiort::{TokioExecutor, TokioTimer};
pub use tokiort::{TokioExecutor, TokioIo, TokioTimer};
146 changes: 146 additions & 0 deletions benches/support/tokiort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,149 @@ impl TokioSleep {
self.project().inner.as_mut().reset(deadline.into());
}
}

pin_project! {
#[derive(Debug)]
pub struct TokioIo<T> {
#[pin]
inner: T,
}
}

impl<T> TokioIo<T> {
pub fn new(inner: T) -> Self {
Self { inner }
}

pub fn inner(self) -> T {
self.inner
}
}

impl<T> hyper::rt::Read for TokioIo<T>
where
T: tokio::io::AsyncRead,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
mut buf: hyper::rt::ReadBufCursor<'_>,
) -> Poll<Result<(), std::io::Error>> {
let n = unsafe {
let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut());
match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) {
Poll::Ready(Ok(())) => tbuf.filled().len(),
other => return other,
}
};

unsafe {
buf.advance(n);
}
Poll::Ready(Ok(()))
}
}

impl<T> hyper::rt::Write for TokioIo<T>
where
T: tokio::io::AsyncWrite,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
}

fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
}

fn is_write_vectored(&self) -> bool {
tokio::io::AsyncWrite::is_write_vectored(&self.inner)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
}
}

impl<T> tokio::io::AsyncRead for TokioIo<T>
where
T: hyper::rt::Read,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
tbuf: &mut tokio::io::ReadBuf<'_>,
) -> Poll<Result<(), std::io::Error>> {
//let init = tbuf.initialized().len();
let filled = tbuf.filled().len();
let sub_filled = unsafe {
let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut());

match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) {
Poll::Ready(Ok(())) => buf.filled().len(),
other => return other,
}
};

let n_filled = filled + sub_filled;
// At least sub_filled bytes had to have been initialized.
let n_init = sub_filled;
unsafe {
tbuf.assume_init(n_init);
tbuf.set_filled(n_filled);
}

Poll::Ready(Ok(()))
}
}

impl<T> tokio::io::AsyncWrite for TokioIo<T>
where
T: hyper::rt::Write,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
hyper::rt::Write::poll_write(self.project().inner, cx, buf)
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
hyper::rt::Write::poll_flush(self.project().inner, cx)
}

fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
hyper::rt::Write::poll_shutdown(self.project().inner, cx)
}

fn is_write_vectored(&self) -> bool {
hyper::rt::Write::is_write_vectored(&self.inner)
}

fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs)
}
}
7 changes: 6 additions & 1 deletion examples/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ use http_body_util::{BodyExt, Empty};
use hyper::Request;
use tokio::net::TcpStream;

#[path = "../benches/support/mod.rs"]
mod support;
use support::TokioIo;

// A simple type alias so as to DRY.
type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

Expand Down Expand Up @@ -39,8 +43,9 @@ async fn fetch_url(url: hyper::Uri) -> Result<()> {
let port = url.port_u16().unwrap_or(80);
let addr = format!("{}:{}", host, port);
let stream = TcpStream::connect(addr).await?;
let io = TokioIo::new(stream);

let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await?;
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?;
tokio::task::spawn(async move {
if let Err(err) = conn.await {
println!("Connection failed: {:?}", err);
Expand Down
7 changes: 6 additions & 1 deletion examples/client_json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ use hyper::{body::Buf, Request};
use serde::Deserialize;
use tokio::net::TcpStream;

#[path = "../benches/support/mod.rs"]
mod support;
use support::TokioIo;

// A simple type alias so as to DRY.
type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;

Expand All @@ -29,8 +33,9 @@ async fn fetch_json(url: hyper::Uri) -> Result<Vec<User>> {
let addr = format!("{}:{}", host, port);

let stream = TcpStream::connect(addr).await?;
let io = TokioIo::new(stream);

let (mut sender, conn) = hyper::client::conn::http1::handshake(stream).await?;
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?;
tokio::task::spawn(async move {
if let Err(err) = conn.await {
println!("Connection failed: {:?}", err);
Expand Down
7 changes: 6 additions & 1 deletion examples/echo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ use hyper::service::service_fn;
use hyper::{body::Body, Method, Request, Response, StatusCode};
use tokio::net::TcpListener;

#[path = "../benches/support/mod.rs"]
mod support;
use support::TokioIo;

/// This is our service handler. It receives a Request, routes on its
/// path, and returns a Future of a Response.
async fn echo(
Expand Down Expand Up @@ -92,10 +96,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
println!("Listening on http://{}", addr);
loop {
let (stream, _) = listener.accept().await?;
let io = TokioIo::new(stream);

tokio::task::spawn(async move {
if let Err(err) = http1::Builder::new()
.serve_connection(stream, service_fn(echo))
.serve_connection(io, service_fn(echo))
.await
{
println!("Error serving connection: {:?}", err);
Expand Down
14 changes: 8 additions & 6 deletions examples/gateway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ use hyper::{server::conn::http1, service::service_fn};
use std::net::SocketAddr;
use tokio::net::{TcpListener, TcpStream};

#[path = "../benches/support/mod.rs"]
mod support;
use support::TokioIo;

#[tokio::main(flavor="current_thread")]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
pretty_env_logger::init();
Expand All @@ -20,6 +24,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

loop {
let (stream, _) = listener.accept().await?;
let io = TokioIo::new(stream);

// This is the `Service` that will handle the connection.
// `service_fn` is a helper to convert a function that
Expand All @@ -42,9 +47,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {

async move {
let client_stream = TcpStream::connect(addr).await.unwrap();
let io = TokioIo::new(client_stream);

let (mut sender, conn) =
hyper::client::conn::http1::handshake(client_stream).await?;
let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?;
tokio::task::spawn(async move {
if let Err(err) = conn.await {
println!("Connection failed: {:?}", err);
Expand All @@ -56,10 +61,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
});

tokio::task::spawn(async move {
if let Err(err) = http1::Builder::new()
.serve_connection(stream, service)
.await
{
if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
println!("Failed to serve the connection: {:?}", err);
}
});
Expand Down
11 changes: 9 additions & 2 deletions examples/hello.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ use hyper::service::service_fn;
use hyper::{Request, Response};
use tokio::net::TcpListener;

#[path = "../benches/support/mod.rs"]
mod support;
use support::TokioIo;

// An async function that consumes a request, does nothing with it and returns a
// response.
async fn hello(_: Request<hyper::body::Incoming>) -> Result<Response<Full<Bytes>>, Infallible> {
Expand All @@ -35,7 +39,10 @@ pub async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// has work to do. In this case, a connection arrives on the port we are listening on and
// the task is woken up, at which point the task is then put back on a thread, and is
// driven forward by the runtime, eventually yielding a TCP stream.
let (stream, _) = listener.accept().await?;
let (tcp, _) = listener.accept().await?;
// Use an adapter to access something implementing `tokio::io` traits as if they implement
// `hyper::rt` IO traits.
let io = TokioIo::new(tcp);

// Spin up a new task in Tokio so we can continue to listen for new TCP connection on the
// current task without waiting for the processing of the HTTP1 connection we just received
Expand All @@ -44,7 +51,7 @@ pub async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
// Handle the connection from the client using HTTP1 and pass any
// HTTP requests received on that connection to the `hello` function
if let Err(err) = http1::Builder::new()
.serve_connection(stream, service_fn(hello))
.serve_connection(io, service_fn(hello))
.await
{
println!("Error serving connection: {:?}", err);
Expand Down
Loading

0 comments on commit b6f1200

Please sign in to comment.