Skip to content

Commit

Permalink
balancer: fix memory leak caused by clients connecting (#1543)
Browse files: browse the repository at this point in the history
* avoid spawning unnecessary tasks

* use jemallocator instead of system allocator

* remove unused code

* add new load test for just websocket connections

* refactor http connection handling to properly clean up tasks

* refactor dispatch loop to also properly clean up tasks

* remove a log

* ensure that tasks spawned when serving http requests get properly cleaned up

* make console subscriber opt in with --console flag

* fix lints
  • Loading branch information
dyc3 authored Mar 21, 2024
1 parent 8d70847 commit 486700d
Show file tree
Hide file tree
Showing 8 changed files with 168 additions and 43 deletions.
21 changes: 21 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ humantime-serde = "1.1"
hyper = { version = "1.2.0", features = ["full"] }
hyper-util = { version = "0.1.3", features = ["full"] }
http-body-util = "0.1.0"
jemallocator = { version = "0.5.4" }
once_cell = "1.17.1"
ott-common = { path = "crates/ott-common" }
ott-balancer = { path = "crates/ott-balancer" }
Expand Down
1 change: 1 addition & 0 deletions crates/ott-balancer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ futures-util.workspace = true
hyper.workspace = true
hyper-util.workspace = true
http-body-util.workspace = true
jemallocator.workspace = true
rand.workspace = true
reqwest.workspace = true
serde.workspace = true
Expand Down
52 changes: 32 additions & 20 deletions crates/ott-balancer/src/balancer.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
use std::{collections::HashMap, sync::Arc};

use futures_util::stream::FuturesUnordered;
use futures_util::StreamExt;
use ott_balancer_protocol::collector::{BalancerState, MonolithState, RoomState};
use ott_balancer_protocol::monolith::{
B2MClientMsg, B2MJoin, B2MLeave, B2MUnload, MsgB2M, MsgM2B, RoomMetadata,
};
use ott_balancer_protocol::*;
use serde_json::value::RawValue;
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
use tokio_tungstenite::tungstenite::protocol::CloseFrame;
use tokio_tungstenite::tungstenite::Message;
Expand Down Expand Up @@ -75,48 +78,57 @@ impl Balancer {
}

pub async fn dispatch_loop(&mut self) {
let mut tasks = FuturesUnordered::new();
loop {
tokio::select! {
new_client = self.new_client_rx.recv() => {
if let Some((new_client, client_link_tx)) = new_client {
let ctx = self.ctx.clone();
let _ = tokio::task::Builder::new().name("join client").spawn(async move {
match join_client(ctx, new_client, client_link_tx).await {
Ok(_) => {},
Err(err) => error!("failed to join client: {:?}", err)
};
});
match join_client(&self.ctx, new_client, client_link_tx).await {
Ok(_) => {},
Err(err) => error!("failed to join client: {:?}", err)
};
} else {
warn!("new client channel closed")
}
}
new_monolith = self.new_monolith_rx.recv() => {
if let Some((new_monolith, receiver_tx)) = new_monolith {
let ctx = self.ctx.clone();
let _ = tokio::task::Builder::new().name("join monolith").spawn(async move {
match join_monolith(ctx, new_monolith, receiver_tx).await {
Ok(_) => {},
Err(err) => error!("failed to join monolith: {:?}", err)
}
});
match join_monolith(&self.ctx, new_monolith, receiver_tx).await {
Ok(handle) => tasks.push(handle),
Err(err) => error!("failed to join monolith: {:?}", err)
}
} else {
warn!("new monolith channel closed")
}
}
msg = self.monolith_msg_rx.recv() => {
if let Some(msg) = msg {
let ctx = self.ctx.clone();
let _ = tokio::task::Builder::new().name("dispatch monolith message").spawn(async move {
let handle = tokio::task::Builder::new().name("dispatch monolith message").spawn(async move {
let id = *msg.id();
match dispatch_monolith_message(ctx, msg).await {
Ok(_) => {},
Err(err) => error!("failed to dispatch monolith message {}: {:?}", id, err)
}
});
match handle {
Ok(handle) => {
tasks.push(handle);
}
Err(err) => {
error!("failed to spawn dispatch monolith message task: {:?}", err);
}
}
} else {
warn!("monolith message channel closed")
}
}
// process completed tasks
task_result = tasks.next() => {
if let Some(Err(err)) = task_result {
error!("error in task: {:?}", err);
}
}
}
}
}
Expand Down Expand Up @@ -432,7 +444,7 @@ impl BalancerContext {

#[instrument(skip_all, err, fields(client_id = %new_client.id, room = %new_client.room))]
pub async fn join_client(
ctx: Arc<RwLock<BalancerContext>>,
ctx: &Arc<RwLock<BalancerContext>>,
new_client: NewClient,
client_link_tx: tokio::sync::oneshot::Sender<ClientLink>,
) -> anyhow::Result<()> {
Expand Down Expand Up @@ -502,10 +514,10 @@ pub async fn leave_client(ctx: Arc<RwLock<BalancerContext>>, id: ClientId) -> an

#[instrument(skip_all, err, fields(monolith_id = %monolith.id))]
pub async fn join_monolith(
ctx: Arc<RwLock<BalancerContext>>,
ctx: &Arc<RwLock<BalancerContext>>,
monolith: NewMonolith,
receiver_tx: tokio::sync::oneshot::Sender<tokio::sync::mpsc::Receiver<SocketMessage>>,
) -> anyhow::Result<()> {
) -> anyhow::Result<JoinHandle<()>> {
info!("new monolith");
let mut b = ctx.write().await;
let (client_inbound_tx, mut client_inbound_rx) = tokio::sync::mpsc::channel(100);
Expand All @@ -521,7 +533,7 @@ pub async fn join_monolith(

let ctx = ctx.clone();
let monolith_outbound_tx = monolith_outbound_tx.clone();
tokio::task::Builder::new()
let handle = tokio::task::Builder::new()
.name(format!("monolith {}", monolith_id).as_ref())
.spawn(async move {
while let Some(msg) = client_inbound_rx.recv().await {
Expand All @@ -537,7 +549,7 @@ pub async fn join_monolith(
}
}
})?;
Ok(())
Ok(handle)
}

async fn handle_client_inbound(
Expand Down
6 changes: 5 additions & 1 deletion crates/ott-balancer/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,14 @@ pub struct Cli {
#[clap(short, long, default_value_t = LogLevel::Info, value_enum)]
pub log_level: LogLevel,

/// Enable the console-subscriber for debugging via tokio-console.
#[clap(long)]
pub console: bool,

/// Allow remote connections via tokio-console for debugging. By default, only local connections are allowed.
///
/// The default port for tokio-console is 6669.
#[clap(long)]
#[clap(long, requires("console"))]
pub remote_console: bool,

/// Validate the configuration file.
Expand Down
66 changes: 48 additions & 18 deletions crates/ott-balancer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use std::{net::SocketAddr, sync::Arc};
use anyhow::Context;
use balancer::{start_dispatcher, Balancer, BalancerContext};
use clap::Parser;
use futures_util::stream::FuturesUnordered;
use futures_util::{FutureExt, StreamExt};
use hyper::server::conn::http1;
use tokio::net::TcpListener;
use tokio::sync::RwLock;
Expand All @@ -30,6 +32,9 @@ pub mod selection;
pub mod service;
pub mod state_stream;

#[global_allocator]
static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;

pub async fn run() -> anyhow::Result<()> {
let args = config::Cli::parse();

Expand All @@ -51,17 +56,22 @@ pub async fn run() -> anyhow::Result<()> {

let config = BalancerConfig::get();

let console_layer = if args.remote_console {
console_subscriber::ConsoleLayer::builder()
.server_addr((
Ipv6Addr::UNSPECIFIED,
console_subscriber::Server::DEFAULT_PORT,
))
.spawn()
let console_layer = if args.console {
let console_layer = if args.remote_console {
console_subscriber::ConsoleLayer::builder()
.server_addr((
Ipv6Addr::UNSPECIFIED,
console_subscriber::Server::DEFAULT_PORT,
))
.spawn()
} else {
console_subscriber::ConsoleLayer::builder().spawn()
}
.with_filter(EnvFilter::try_new("tokio=trace,runtime=trace")?);
Some(console_layer)
} else {
console_subscriber::ConsoleLayer::builder().spawn()
None
};
let console_layer = console_layer.with_filter(EnvFilter::try_new("tokio=trace,runtime=trace")?);
let filter = args.build_tracing_filter();
let filter_layer = EnvFilter::try_from_default_env().or_else(|_| EnvFilter::try_new(filter))?;
let fmt_layer = tracing_subscriber::fmt::layer().with_filter(filter_layer);
Expand Down Expand Up @@ -128,10 +138,11 @@ pub async fn run() -> anyhow::Result<()> {
let bind_addr6: SocketAddr =
SocketAddr::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0).into(), config.port);

let (task_handle_tx, mut task_handle_rx) = tokio::sync::mpsc::channel(10);
let service = BalancerService {
ctx,
link: service_link,
addr: bind_addr6,
task_handle_tx,
};

// on linux, binding ipv6 will also bind ipv4
Expand All @@ -140,21 +151,38 @@ pub async fn run() -> anyhow::Result<()> {
.context("binding primary inbound socket")?;

info!("Serving on {}", bind_addr6);
let mut tasks = FuturesUnordered::new();
loop {
let (stream, addr) = tokio::select! {
stream = listener6.accept() => {
let accept_fut = Box::pin(listener6.accept());

let (stream, _addr) = tokio::select! {
stream = accept_fut.fuse() => {
let (stream, addr) = stream?;
(stream, addr)
}
};

let mut service = service.clone();
service.addr = addr;
// process completed tasks
result = tasks.next() => {
if let Some(Err(err)) = result {
error!("Error in http serving task: {:?}", err);
}
continue;
}

task_handle_rx = task_handle_rx.recv() => {
if let Some(task_handle) = task_handle_rx {
info!("Received task handle");
tasks.push(task_handle);
}
continue;
}
};

let service = service.clone();
let io = hyper_util::rt::TokioIo::new(stream);

// Spawn a tokio task to serve multiple connections concurrently
let result = tokio::task::Builder::new()
let task = tokio::task::Builder::new()
.name("serve http")
.spawn(async move {
let conn = http1::Builder::new()
Expand All @@ -164,8 +192,10 @@ pub async fn run() -> anyhow::Result<()> {
error!("Error serving connection: {:?}", err);
}
});
if let Err(err) = result {
error!("Error spawning task to serve http: {:?}", err);

match task {
Ok(task) => tasks.push(task),
Err(err) => error!("Error spawning task to serve http: {:?}", err),
}
}
}
19 changes: 15 additions & 4 deletions crates/ott-balancer/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use prometheus::{register_int_gauge, Encoder, IntGauge, TextEncoder};
use reqwest::Url;
use route_recognizer::Router;
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use tracing::{debug, error, field, info, span, warn, Level};

use crate::balancer::{BalancerContext, BalancerLink};
Expand Down Expand Up @@ -46,7 +47,7 @@ static ROUTER: Lazy<Router<&'static str>> = Lazy::new(|| {
pub struct BalancerService {
pub(crate) ctx: Arc<RwLock<BalancerContext>>,
pub(crate) link: BalancerLink,
pub(crate) addr: std::net::SocketAddr,
pub(crate) task_handle_tx: tokio::sync::mpsc::Sender<JoinHandle<()>>,
}

impl Service<Request<IncomingBody>> for BalancerService {
Expand All @@ -73,7 +74,7 @@ impl Service<Request<IncomingBody>> for BalancerService {

let ctx: Arc<RwLock<BalancerContext>> = self.ctx.clone();
let link = self.link.clone();
let _addr = self.addr;
let task_handle_tx = self.task_handle_tx.clone();

let Ok(route) = ROUTER.recognize(req.uri().path()) else {
warn!("no route found for {}", req.uri().path());
Expand Down Expand Up @@ -141,13 +142,14 @@ impl Service<Request<IncomingBody>> for BalancerService {
}
};

tokio::spawn(async move {
let handle = tokio::spawn(async move {
if let Err(err) =
crate::state_stream::handle_stream_websocket(websocket).await
{
error!("error handling event stream websocket: {}", err);
}
});
let _ = task_handle_tx.send(handle).await;

Ok(response)
} else {
Expand Down Expand Up @@ -208,7 +210,7 @@ impl Service<Request<IncomingBody>> for BalancerService {
};

// Spawn a task to handle the websocket connection.
let _ = tokio::task::Builder::new().name("client connection").spawn(
let handle = tokio::task::Builder::new().name("client connection").spawn(
async move {
GAUGE_CLIENTS.inc();
if let Err(e) = client_entry(room_name, websocket, link).await {
Expand All @@ -217,6 +219,15 @@ impl Service<Request<IncomingBody>> for BalancerService {
GAUGE_CLIENTS.dec();
},
);
match handle {
Ok(handle) => {
let _ = task_handle_tx.send(handle).await;
}
Err(err) => {
error!("Error spawning task to handle websocket: {}", err);
return Ok(interval_server_error());
}
}

// Return the response so the spawned future can continue.
Ok(response)
Expand Down
Loading

0 comments on commit 486700d

Please sign in to comment.