-
Notifications
You must be signed in to change notification settings - Fork 242
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add an example of using tungstenite with a custom accept.
1 parent
61f5926
commit 4e1559a
Showing
2 changed files
with
170 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
//! A chat server that broadcasts a message to all connections. | ||
//! | ||
//! This is a simple line-based server which accepts WebSocket connections, | ||
//! reads lines from those connections, and broadcasts the lines to all other | ||
//! connected clients. | ||
//! | ||
//! You can test this out by running: | ||
//! | ||
//! cargo run --example server 127.0.0.1:12345 | ||
//! | ||
//! And then in another window run: | ||
//! | ||
//! cargo run --example client ws://127.0.0.1:12345/socket | ||
//! | ||
//! You can run the second command in multiple windows and then chat between the | ||
//! two, seeing the messages from the other client as they're received. For all | ||
//! connected clients they'll all join the same room and see everyone else's | ||
//! messages. | ||
use std::{ | ||
collections::HashMap, | ||
convert::Infallible, | ||
env, | ||
net::SocketAddr, | ||
sync::{Arc, Mutex}, | ||
}; | ||
|
||
use hyper::{ | ||
header::{ | ||
HeaderValue, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_VERSION, | ||
UPGRADE, | ||
}, | ||
server::conn::AddrStream, | ||
service::{make_service_fn, service_fn}, | ||
upgrade::Upgraded, | ||
Body, Method, Request, Response, Server, StatusCode, Version, | ||
}; | ||
|
||
use futures_channel::mpsc::{unbounded, UnboundedSender}; | ||
use futures_util::{future, pin_mut, stream::TryStreamExt, StreamExt}; | ||
|
||
use tokio_tungstenite::WebSocketStream; | ||
use tungstenite::{ | ||
handshake::derive_accept_key, | ||
protocol::{Message, Role}, | ||
}; | ||
|
||
type Tx = UnboundedSender<Message>; | ||
type PeerMap = Arc<Mutex<HashMap<SocketAddr, Tx>>>; | ||
|
||
async fn handle_connection( | ||
peer_map: PeerMap, | ||
ws_stream: WebSocketStream<Upgraded>, | ||
addr: SocketAddr, | ||
) { | ||
println!("WebSocket connection established: {}", addr); | ||
|
||
// Insert the write part of this peer to the peer map. | ||
let (tx, rx) = unbounded(); | ||
peer_map.lock().unwrap().insert(addr, tx); | ||
|
||
let (outgoing, incoming) = ws_stream.split(); | ||
|
||
let broadcast_incoming = incoming.try_for_each(|msg| { | ||
println!("Received a message from {}: {}", addr, msg.to_text().unwrap()); | ||
let peers = peer_map.lock().unwrap(); | ||
|
||
// We want to broadcast the message to everyone except ourselves. | ||
let broadcast_recipients = | ||
peers.iter().filter(|(peer_addr, _)| peer_addr != &&addr).map(|(_, ws_sink)| ws_sink); | ||
|
||
for recp in broadcast_recipients { | ||
recp.unbounded_send(msg.clone()).unwrap(); | ||
} | ||
|
||
future::ok(()) | ||
}); | ||
|
||
let receive_from_others = rx.map(Ok).forward(outgoing); | ||
|
||
pin_mut!(broadcast_incoming, receive_from_others); | ||
future::select(broadcast_incoming, receive_from_others).await; | ||
|
||
println!("{} disconnected", &addr); | ||
peer_map.lock().unwrap().remove(&addr); | ||
} | ||
|
||
async fn handle_request( | ||
peer_map: PeerMap, | ||
mut req: Request<Body>, | ||
addr: SocketAddr, | ||
) -> Result<Response<Body>, Infallible> { | ||
println!("Received a new, potentially ws handshake"); | ||
println!("The request's path is: {}", req.uri().path()); | ||
println!("The request's headers are:"); | ||
for (ref header, _value) in req.headers() { | ||
println!("* {}", header); | ||
} | ||
let upgrade = HeaderValue::from_static("Upgrade"); | ||
let websocket = HeaderValue::from_static("websocket"); | ||
let headers = req.headers(); | ||
let key = headers.get(SEC_WEBSOCKET_KEY); | ||
let derived = key.map(|k| derive_accept_key(k.as_bytes())); | ||
if req.method() != Method::GET | ||
|| req.version() < Version::HTTP_11 | ||
|| !headers | ||
.get(CONNECTION) | ||
.and_then(|h| h.to_str().ok()) | ||
.map(|h| { | ||
h.split(|c| c == ' ' || c == ',') | ||
.any(|p| p.eq_ignore_ascii_case(upgrade.to_str().unwrap())) | ||
}) | ||
.unwrap_or(false) | ||
|| !headers | ||
.get(UPGRADE) | ||
.and_then(|h| h.to_str().ok()) | ||
.map(|h| h.eq_ignore_ascii_case("websocket")) | ||
.unwrap_or(false) | ||
|| !headers.get(SEC_WEBSOCKET_VERSION).map(|h| h == "13").unwrap_or(false) | ||
|| key.is_none() | ||
|| req.uri() != "/socket" | ||
{ | ||
return Ok(Response::new(Body::from("Hello World!"))); | ||
} | ||
let ver = req.version(); | ||
tokio::task::spawn(async move { | ||
match hyper::upgrade::on(&mut req).await { | ||
Ok(upgraded) => { | ||
handle_connection( | ||
peer_map, | ||
WebSocketStream::from_raw_socket(upgraded, Role::Server, None).await, | ||
addr, | ||
) | ||
.await; | ||
} | ||
Err(e) => println!("upgrade error: {}", e), | ||
} | ||
}); | ||
let mut res = Response::new(Body::empty()); | ||
*res.status_mut() = StatusCode::SWITCHING_PROTOCOLS; | ||
*res.version_mut() = ver; | ||
res.headers_mut().append(CONNECTION, upgrade); | ||
res.headers_mut().append(UPGRADE, websocket); | ||
res.headers_mut().append(SEC_WEBSOCKET_ACCEPT, derived.unwrap().parse().unwrap()); | ||
// Let's add an additional header to our response to the client. | ||
res.headers_mut().append("MyCustomHeader", ":)".parse().unwrap()); | ||
res.headers_mut().append("SOME_TUNGSTENITE_HEADER", "header_value".parse().unwrap()); | ||
Ok(res) | ||
} | ||
|
||
#[tokio::main] | ||
async fn main() -> Result<(), hyper::Error> { | ||
let state = PeerMap::new(Mutex::new(HashMap::new())); | ||
|
||
let addr = env::args().nth(1).unwrap_or_else(|| "127.0.0.1:8080".to_string()).parse().unwrap(); | ||
|
||
let make_svc = make_service_fn(move |conn: &AddrStream| { | ||
let remote_addr = conn.remote_addr(); | ||
let state = state.clone(); | ||
let service = service_fn(move |req| handle_request(state.clone(), req, remote_addr)); | ||
async { Ok::<_, Infallible>(service) } | ||
}); | ||
|
||
let server = Server::bind(&addr).serve(make_svc); | ||
|
||
server.await?; | ||
|
||
Ok::<_, hyper::Error>(()) | ||
} |