From 486700d683ce5b0c921cf3c47100f0593b753c11 Mon Sep 17 00:00:00 2001 From: Carson McManus Date: Thu, 21 Mar 2024 11:19:37 -0400 Subject: [PATCH] balancer: fix memory leak caused by clients connecting (#1543) * avoid spawning unnecessary tasks * use jemallocator instead of system allocator * remove unused code * add new load test for just websocket connections * refactor http connection handling to properly clean up tasks * refactor dispatch loop to also properly clean up tasks * remove a log * ensure that tasks spawned when serving http requests get properly cleaned up * make console subscriber opt in with --console flag * fix lints --- Cargo.lock | 21 +++++++++ Cargo.toml | 1 + crates/ott-balancer/Cargo.toml | 1 + crates/ott-balancer/src/balancer.rs | 52 ++++++++++++++--------- crates/ott-balancer/src/config.rs | 6 ++- crates/ott-balancer/src/lib.rs | 66 +++++++++++++++++++++-------- crates/ott-balancer/src/service.rs | 19 +++++++-- tests/load/extreme-connects.js | 45 ++++++++++++++++++++ 8 files changed, 168 insertions(+), 43 deletions(-) create mode 100644 tests/load/extreme-connects.js diff --git a/Cargo.lock b/Cargo.lock index a7013ee16..1dd013cfc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1541,6 +1541,26 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" +[[package]] +name = "jemalloc-sys" +version = "0.5.4+5.3.0-patched" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac6c1946e1cea1788cbfde01c993b52a10e2da07f4bac608228d1bed20bfebf2" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "jemallocator" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0de374a9f8e63150e6f5e8a60cc14c668226d7a347d8aee1a45766e3c4dd3bc" +dependencies = [ + "jemalloc-sys", + "libc", +] + [[package]] name = "js-sys" version = "0.3.69" @@ -1949,6 +1969,7 @@ dependencies = [ "http-body-util", "hyper 1.2.0", "hyper-util", + "jemallocator", "once_cell", "ott-balancer-protocol", "ott-common", diff --git a/Cargo.toml b/Cargo.toml index daa94673f..c4c13631a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ humantime-serde = "1.1" hyper = { version = "1.2.0", features = ["full"] } hyper-util = { version = "0.1.3", features = ["full"] } http-body-util = "0.1.0" +jemallocator = { version = "0.5.4" } once_cell = "1.17.1" ott-common = { path = "crates/ott-common" } ott-balancer = { path = "crates/ott-balancer" } diff --git a/crates/ott-balancer/Cargo.toml b/crates/ott-balancer/Cargo.toml index f2af887fd..927ef2f54 100644 --- a/crates/ott-balancer/Cargo.toml +++ b/crates/ott-balancer/Cargo.toml @@ -14,6 +14,7 @@ futures-util.workspace = true hyper.workspace = true hyper-util.workspace = true http-body-util.workspace = true +jemallocator.workspace = true rand.workspace = true reqwest.workspace = true serde.workspace = true diff --git a/crates/ott-balancer/src/balancer.rs b/crates/ott-balancer/src/balancer.rs index 43706c453..6bfeb4b4f 100644 --- a/crates/ott-balancer/src/balancer.rs +++ b/crates/ott-balancer/src/balancer.rs @@ -1,5 +1,7 @@ use std::{collections::HashMap, sync::Arc}; +use futures_util::stream::FuturesUnordered; +use futures_util::StreamExt; use ott_balancer_protocol::collector::{BalancerState, MonolithState, RoomState}; use ott_balancer_protocol::monolith::{ B2MClientMsg, B2MJoin, B2MLeave, B2MUnload, MsgB2M, MsgM2B, RoomMetadata, @@ -7,6 +9,7 @@ use ott_balancer_protocol::monolith::{ use ott_balancer_protocol::*; use serde_json::value::RawValue; use tokio::sync::RwLock; +use tokio::task::JoinHandle; use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode; use tokio_tungstenite::tungstenite::protocol::CloseFrame; use tokio_tungstenite::tungstenite::Message; @@ -75,30 +78,25 @@ impl Balancer { } pub async fn dispatch_loop(&mut self) { + let mut tasks = FuturesUnordered::new(); loop { tokio::select! { new_client = self.new_client_rx.recv() => { if let Some((new_client, client_link_tx)) = new_client { - let ctx = self.ctx.clone(); - let _ = tokio::task::Builder::new().name("join client").spawn(async move { - match join_client(ctx, new_client, client_link_tx).await { - Ok(_) => {}, - Err(err) => error!("failed to join client: {:?}", err) - }; - }); + match join_client(&self.ctx, new_client, client_link_tx).await { + Ok(_) => {}, + Err(err) => error!("failed to join client: {:?}", err) + }; } else { warn!("new client channel closed") } } new_monolith = self.new_monolith_rx.recv() => { if let Some((new_monolith, receiver_tx)) = new_monolith { - let ctx = self.ctx.clone(); - let _ = tokio::task::Builder::new().name("join monolith").spawn(async move { - match join_monolith(ctx, new_monolith, receiver_tx).await { - Ok(_) => {}, - Err(err) => error!("failed to join monolith: {:?}", err) - } - }); + match join_monolith(&self.ctx, new_monolith, receiver_tx).await { + Ok(handle) => tasks.push(handle), + Err(err) => error!("failed to join monolith: {:?}", err) + } } else { warn!("new monolith channel closed") } @@ -106,17 +104,31 @@ impl Balancer { msg = self.monolith_msg_rx.recv() => { if let Some(msg) = msg { let ctx = self.ctx.clone(); - let _ = tokio::task::Builder::new().name("dispatch monolith message").spawn(async move { + let handle = tokio::task::Builder::new().name("dispatch monolith message").spawn(async move { let id = *msg.id(); match dispatch_monolith_message(ctx, msg).await { Ok(_) => {}, Err(err) => error!("failed to dispatch monolith message {}: {:?}", id, err) } }); + match handle { + Ok(handle) => { + tasks.push(handle); + } + Err(err) => { + error!("failed to spawn dispatch monolith message task: {:?}", err); + } + } } else { warn!("monolith message channel closed") } } + // process completed tasks + task_result = tasks.next() => { + if let Some(Err(err)) = task_result { + error!("error in task: {:?}", err); + } + } } } } @@ -432,7 +444,7 @@ impl BalancerContext { #[instrument(skip_all, err, fields(client_id = %new_client.id, room = %new_client.room))] pub async fn join_client( - ctx: Arc>, + ctx: &Arc>, new_client: NewClient, client_link_tx: tokio::sync::oneshot::Sender, ) -> anyhow::Result<()> { @@ -502,10 +514,10 @@ pub async fn leave_client(ctx: Arc>, id: ClientId) -> an #[instrument(skip_all, err, fields(monolith_id = %monolith.id))] pub async fn join_monolith( - ctx: Arc>, + ctx: &Arc>, monolith: NewMonolith, receiver_tx: tokio::sync::oneshot::Sender>, -) -> anyhow::Result<()> { +) -> anyhow::Result> { info!("new monolith"); let mut b = ctx.write().await; let (client_inbound_tx, mut client_inbound_rx) = tokio::sync::mpsc::channel(100); @@ -521,7 +533,7 @@ pub async fn join_monolith( let ctx = ctx.clone(); let monolith_outbound_tx = monolith_outbound_tx.clone(); - tokio::task::Builder::new() + let handle = tokio::task::Builder::new() .name(format!("monolith {}", monolith_id).as_ref()) .spawn(async move { while let Some(msg) = client_inbound_rx.recv().await { @@ -537,7 +549,7 @@ pub async fn join_monolith( } } })?; - Ok(()) + Ok(handle) } async fn handle_client_inbound( diff --git a/crates/ott-balancer/src/config.rs b/crates/ott-balancer/src/config.rs index b804580ae..5ee1a225e 100644 --- a/crates/ott-balancer/src/config.rs +++ b/crates/ott-balancer/src/config.rs @@ -79,10 +79,14 @@ pub struct Cli { #[clap(short, long, default_value_t = LogLevel::Info, value_enum)] pub log_level: LogLevel, + /// Enable the console-subscriber for debugging via tokio-console. + #[clap(long)] + pub console: bool, + /// Allow remote connections via tokio-console for debugging. By default, only local connections are allowed. /// /// The default port for tokio-console is 6669. - #[clap(long)] + #[clap(long, requires("console"))] pub remote_console: bool, /// Validate the configuration file. diff --git a/crates/ott-balancer/src/lib.rs b/crates/ott-balancer/src/lib.rs index 55c45a3a3..86b4c62b3 100644 --- a/crates/ott-balancer/src/lib.rs +++ b/crates/ott-balancer/src/lib.rs @@ -4,6 +4,8 @@ use std::{net::SocketAddr, sync::Arc}; use anyhow::Context; use balancer::{start_dispatcher, Balancer, BalancerContext}; use clap::Parser; +use futures_util::stream::FuturesUnordered; +use futures_util::{FutureExt, StreamExt}; use hyper::server::conn::http1; use tokio::net::TcpListener; use tokio::sync::RwLock; @@ -30,6 +32,9 @@ pub mod selection; pub mod service; pub mod state_stream; +#[global_allocator] +static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc; + pub async fn run() -> anyhow::Result<()> { let args = config::Cli::parse(); @@ -51,17 +56,22 @@ pub async fn run() -> anyhow::Result<()> { let config = BalancerConfig::get(); - let console_layer = if args.remote_console { - console_subscriber::ConsoleLayer::builder() - .server_addr(( - Ipv6Addr::UNSPECIFIED, - console_subscriber::Server::DEFAULT_PORT, - )) - .spawn() + let console_layer = if args.console { + let console_layer = if args.remote_console { + console_subscriber::ConsoleLayer::builder() + .server_addr(( + Ipv6Addr::UNSPECIFIED, + console_subscriber::Server::DEFAULT_PORT, + )) + .spawn() + } else { + console_subscriber::ConsoleLayer::builder().spawn() + } + .with_filter(EnvFilter::try_new("tokio=trace,runtime=trace")?); + Some(console_layer) } else { - console_subscriber::ConsoleLayer::builder().spawn() + None }; - let console_layer = console_layer.with_filter(EnvFilter::try_new("tokio=trace,runtime=trace")?); let filter = args.build_tracing_filter(); let filter_layer = EnvFilter::try_from_default_env().or_else(|_| EnvFilter::try_new(filter))?; let fmt_layer = tracing_subscriber::fmt::layer().with_filter(filter_layer); @@ -128,10 +138,11 @@ pub async fn run() -> anyhow::Result<()> { let bind_addr6: SocketAddr = SocketAddr::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0).into(), config.port); + let (task_handle_tx, mut task_handle_rx) = tokio::sync::mpsc::channel(10); let service = BalancerService { ctx, link: service_link, - addr: bind_addr6, + task_handle_tx, }; // on linux, binding ipv6 will also bind ipv4 @@ -140,21 +151,38 @@ pub async fn run() -> anyhow::Result<()> { .context("binding primary inbound socket")?; info!("Serving on {}", bind_addr6); + let mut tasks = FuturesUnordered::new(); loop { - let (stream, addr) = tokio::select! { - stream = listener6.accept() => { + let accept_fut = Box::pin(listener6.accept()); + + let (stream, _addr) = tokio::select! { + stream = accept_fut.fuse() => { let (stream, addr) = stream?; (stream, addr) } - }; - let mut service = service.clone(); - service.addr = addr; + // process completed tasks + result = tasks.next() => { + if let Some(Err(err)) = result { + error!("Error in http serving task: {:?}", err); + } + continue; + } + task_handle_rx = task_handle_rx.recv() => { + if let Some(task_handle) = task_handle_rx { + info!("Received task handle"); + tasks.push(task_handle); + } + continue; + } + }; + + let service = service.clone(); let io = hyper_util::rt::TokioIo::new(stream); // Spawn a tokio task to serve multiple connections concurrently - let result = tokio::task::Builder::new() + let task = tokio::task::Builder::new() .name("serve http") .spawn(async move { let conn = http1::Builder::new() @@ -164,8 +192,10 @@ pub async fn run() -> anyhow::Result<()> { error!("Error serving connection: {:?}", err); } }); - if let Err(err) = result { - error!("Error spawning task to serve http: {:?}", err); + + match task { + Ok(task) => tasks.push(task), + Err(err) => error!("Error spawning task to serve http: {:?}", err), } } } diff --git a/crates/ott-balancer/src/service.rs b/crates/ott-balancer/src/service.rs index a0ff6564c..6d8d41f8b 100644 --- a/crates/ott-balancer/src/service.rs +++ b/crates/ott-balancer/src/service.rs @@ -15,6 +15,7 @@ use prometheus::{register_int_gauge, Encoder, IntGauge, TextEncoder}; use reqwest::Url; use route_recognizer::Router; use tokio::sync::RwLock; +use tokio::task::JoinHandle; use tracing::{debug, error, field, info, span, warn, Level}; use crate::balancer::{BalancerContext, BalancerLink}; @@ -46,7 +47,7 @@ static ROUTER: Lazy> = Lazy::new(|| { pub struct BalancerService { pub(crate) ctx: Arc>, pub(crate) link: BalancerLink, - pub(crate) addr: std::net::SocketAddr, + pub(crate) task_handle_tx: tokio::sync::mpsc::Sender>, } impl Service> for BalancerService { @@ -73,7 +74,7 @@ impl Service> for BalancerService { let ctx: Arc> = self.ctx.clone(); let link = self.link.clone(); - let _addr = self.addr; + let task_handle_tx = self.task_handle_tx.clone(); let Ok(route) = ROUTER.recognize(req.uri().path()) else { warn!("no route found for {}", req.uri().path()); @@ -141,13 +142,14 @@ impl Service> for BalancerService { } }; - tokio::spawn(async move { + let handle = tokio::spawn(async move { if let Err(err) = crate::state_stream::handle_stream_websocket(websocket).await { error!("error handling event stream websocket: {}", err); } }); + let _ = task_handle_tx.send(handle).await; Ok(response) } else { @@ -208,7 +210,7 @@ impl Service> for BalancerService { }; // Spawn a task to handle the websocket connection. - let _ = tokio::task::Builder::new().name("client connection").spawn( + let handle = tokio::task::Builder::new().name("client connection").spawn( async move { GAUGE_CLIENTS.inc(); if let Err(e) = client_entry(room_name, websocket, link).await { @@ -217,6 +219,15 @@ impl Service> for BalancerService { GAUGE_CLIENTS.dec(); }, ); + match handle { + Ok(handle) => { + let _ = task_handle_tx.send(handle).await; + } + Err(err) => { + error!("Error spawning task to handle websocket: {}", err); + return Ok(interval_server_error()); + } + } // Return the response so the spawned future can continue. Ok(response) diff --git a/tests/load/extreme-connects.js b/tests/load/extreme-connects.js new file mode 100644 index 000000000..2530dc990 --- /dev/null +++ b/tests/load/extreme-connects.js @@ -0,0 +1,45 @@ +import ws from "k6/ws"; +import { sleep, check } from "k6"; +import { randomItem } from "https://jslib.k6.io/k6-utils/1.4.0/index.js"; +import { getAuthToken, createRoom, HOSTNAME } from "./utils.js"; + +export const options = { + rate: 50, + duration: "1h", +}; + +export function setup() { + const rooms = []; + for (let i = 0; i < 50; i++) { + rooms.push(`load-test-${i}`); + } + + const token = getAuthToken(); + for (let room of rooms) { + createRoom(room, token, { visibility: "public", isTemporary: true }); + } + sleep(1); + + const tokens = [token]; + for (let i = 0; i < 100; i++) { + tokens.push(getAuthToken()); + } + + return { rooms, tokens }; +} + +export default function (data) { + const { rooms, tokens } = data; + const room = randomItem(rooms); + const token = randomItem(tokens); + console.log(`User is joining room ${room}`); + const url = `ws://${HOSTNAME}/api/room/${room}`; + + const res = ws.connect(url, null, socket => { + socket.on("open", () => { + socket.send(JSON.stringify({ action: "auth", token: token })); + socket.close(1000); + }); + }); + check(res, { "ws status is 101": r => r && r.status === 101 }); +}