diff --git a/Cargo.toml b/Cargo.toml index a6da142500..35a7468d7d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -193,12 +193,15 @@ name = "upgrades" path = "examples/upgrades.rs" required-features = ["full"] - [[example]] name = "web_api" path = "examples/web_api.rs" required-features = ["full"] +[[example]] +name = "routing" +path = "examples/routing.rs" +required-features = ["full"] [[bench]] name = "body" diff --git a/examples/routing.rs b/examples/routing.rs new file mode 100644 index 0000000000..78fc2ce015 --- /dev/null +++ b/examples/routing.rs @@ -0,0 +1,227 @@ +#![deny(warnings)] + +use std::collections::HashMap; +use std::future::Future; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::atomic::AtomicU64; +use std::sync::Arc; + +use bytes::{Buf, Bytes}; +use http_body_util::{BodyExt, Full}; +use hyper::server::conn::http1; +use hyper::service::service_fn; +use hyper::{body::Incoming as IncomingBody, header, Method, Request, Response, StatusCode}; +use tokio::net::{TcpListener, TcpStream}; + +#[path = "../benches/support/mod.rs"] +mod support; +use support::TokioIo; + +type GenericError = Box; +type Result = std::result::Result; +type BoxBody = http_body_util::combinators::BoxBody; + +static INDEX: &[u8] = b"test.html"; +static INTERNAL_SERVER_ERROR: &[u8] = b"Internal Server Error"; +static NOTFOUND: &[u8] = b"Not Found"; +static POST_DATA: &str = r#"{"original": "data"}"#; +static URL: &str = "http://127.0.0.1:1337/json_api"; + +async fn client_request_response( + _state: Arc, + _req: Request, +) -> Result> { + let req = Request::builder() + .method(Method::POST) + .uri(URL) + .header(header::CONTENT_TYPE, "application/json") + .body(Full::new(Bytes::from(POST_DATA))) + .unwrap(); + + let host = req.uri().host().expect("uri has no host"); + let port = req.uri().port_u16().expect("uri has no port"); + let stream = TcpStream::connect(format!("{}:{}", host, port)).await?; + let io = TokioIo::new(stream); + + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await?; + + tokio::task::spawn(async move { + if let Err(err) = conn.await { + println!("Connection error: {:?}", err); + } + }); + + let web_res = sender.send_request(req).await?; + + let res_body = web_res.into_body().boxed(); + + Ok(Response::new(res_body)) +} + +async fn api_post_response( + state: Arc, + req: Request, +) -> Result> { + // Aggregate the body... + let whole_body = req.collect().await?.aggregate(); + // Decode as JSON... + let mut data: serde_json::Value = serde_json::from_reader(whole_body.reader())?; + // Change the JSON... + data["test"] = serde_json::Value::from("test_value"); + // And respond with the new JSON. + let json = serde_json::to_string(&data)?; + let response = Response::builder() + .status(StatusCode::OK) + .header(header::CONTENT_TYPE, "application/json") + .body(full(json))?; + state.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + Ok(response) +} + +async fn api_get_response( + state: Arc, + _req: Request, +) -> Result> { + let responses = state.load(std::sync::atomic::Ordering::Relaxed).to_string(); + let data = vec!["foo", "bar", responses.as_str()]; + let res = match serde_json::to_string(&data) { + Ok(json) => Response::builder() + .header(header::CONTENT_TYPE, "application/json") + .body(full(json)) + .unwrap(), + Err(_) => Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(full(INTERNAL_SERVER_ERROR)) + .unwrap(), + }; + Ok(res) +} + +fn full>(chunk: T) -> BoxBody { + Full::new(chunk.into()) + .map_err(|never| match never {}) + .boxed() +} + +type Routes = HashMap<(Method, String), Box>>; + +#[derive(Clone)] +struct Router { + routes: Arc>, + fallback: Pin>>, + state: S, +} + +impl Router { + async fn handle_request(self, req: Request) -> Result> { + // https://stackoverflow.com/questions/45786717 for a saner impl that is outside of the scope of this example + if let Some(handler) = self + .routes + .get(&(req.method().clone(), req.uri().path().to_owned())) + { + handler.handle(self.state, req).await + } else { + self.fallback.handle(self.state, req).await + } + } +} + +struct RouterBuilder { + routes: Routes, + fallback: Pin>>, + state: S, +} + +trait Handler: Send + Sync { + fn handle( + &self, + state: S, + req: Request, + ) -> Pin>> + Send>>; +} + +impl Handler for F +where + F: Fn(S, Request) -> Fut + Send + Sync, + Fut: Future>> + Send + 'static, +{ + fn handle( + &self, + state: S, + req: Request, + ) -> Pin>> + Send>> { + Box::pin(self(state, req)) + } +} + +impl RouterBuilder { + fn new(state: S) -> Self { + Self { + routes: HashMap::new(), + fallback: Arc::pin(|_, _| { + Box::pin(async { + Ok(Response::builder() + .status(StatusCode::NOT_FOUND) + .body(full(NOTFOUND)) + .unwrap()) + }) + }), + state, + } + } + + fn get(self, path: impl Into, handler: impl Handler + 'static) -> Self { + self.route(Method::GET, path.into(), handler) + } + + fn post(self, path: impl Into, handler: impl Handler + 'static) -> Self { + self.route(Method::POST, path.into(), handler) + } + + fn route(mut self, method: Method, path: String, handler: impl Handler + 'static) -> Self { + self.routes.insert((method, path), Box::new(handler)); + self + } + + fn build(self) -> Router { + Router { + routes: Arc::new(self.routes), + fallback: self.fallback, + state: self.state, + } + } +} + +#[tokio::main] +async fn main() -> Result<()> { + pretty_env_logger::init(); + + let router = RouterBuilder::new(Arc::new(AtomicU64::new(0))) + .get("/", |_, _| async { Ok(Response::new(full(INDEX))) }) + .get("/index.html", |_, _| async { + Ok(Response::new(full(INDEX))) + }) + .get("/test.html", client_request_response) + .post("/json_api", api_post_response) + .get("/json_api", api_get_response) + .build(); + + let addr = SocketAddr::from(([127, 0, 0, 1], 1337)); + + let listener = TcpListener::bind(&addr).await?; + println!("Listening on http://{}", addr); + loop { + let (stream, _) = listener.accept().await?; + let io = TokioIo::new(stream); + let tmp_router = router.clone(); + + tokio::task::spawn(async move { + let service = service_fn(|req| tmp_router.clone().handle_request(req)); + + if let Err(err) = http1::Builder::new().serve_connection(io, service).await { + println!("Failed to serve connection: {:?}", err); + } + }); + } +}