Skip to content

Commit

Permalink
balancer: reduce lock contention when processing inbound client messages (dyc3#1178)
Browse files Browse the repository at this point in the history

* set up clients so that they can have a unicast channel and a broadcast channel

* set up clients to have all the right channels

* rename some vars

* fix incorrect serialization

* fix lints

* fix lints

* uncomment a line

* rename some stuff to be clearer
  • Loading branch information
dyc3 authored and cjrkoa committed Jan 26, 2024
1 parent 11c1cb0 commit 043a3d9
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 148 deletions.
2 changes: 1 addition & 1 deletion crates/harness/src/monolith.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ impl Monolith {
let to_send = {
let mut state = state.lock().unwrap();
let parsed = match &msg {
Message::Text(msg) => serde_json::from_str(msg).unwrap(),
Message::Text(msg) => serde_json::from_str(msg).expect("failed to parse B2M message"),
_ => panic!("unexpected message type: {:?}", msg),
};

Expand Down
216 changes: 95 additions & 121 deletions crates/ott-balancer-bin/src/balancer.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::{collections::HashMap, sync::Arc};

use ott_balancer_protocol::monolith::{
B2MClientMsg, B2MJoin, B2MLeave, B2MUnload, MsgM2B, RoomMetadata,
B2MClientMsg, B2MJoin, B2MLeave, B2MUnload, MsgB2M, MsgM2B, RoomMetadata,
};
use ott_balancer_protocol::*;
use rand::seq::IteratorRandom;
Expand All @@ -12,6 +12,7 @@ use tokio_tungstenite::tungstenite::protocol::CloseFrame;
use tokio_tungstenite::tungstenite::Message;
use tracing::{debug, error, info, instrument, trace, warn};

use crate::client::ClientLink;
use crate::config::BalancerConfig;
use crate::monolith::Room;
use crate::room::RoomLocator;
Expand All @@ -24,17 +25,9 @@ use crate::{
pub struct Balancer {
pub(crate) ctx: Arc<RwLock<BalancerContext>>,

new_client_rx: tokio::sync::mpsc::Receiver<(
NewClient,
tokio::sync::oneshot::Sender<tokio::sync::mpsc::Receiver<SocketMessage>>,
)>,
new_client_tx: tokio::sync::mpsc::Sender<(
NewClient,
tokio::sync::oneshot::Sender<tokio::sync::mpsc::Receiver<SocketMessage>>,
)>,

client_msg_rx: tokio::sync::mpsc::Receiver<Context<ClientId, SocketMessage>>,
client_msg_tx: tokio::sync::mpsc::Sender<Context<ClientId, SocketMessage>>,
new_client_rx:
tokio::sync::mpsc::Receiver<(NewClient, tokio::sync::oneshot::Sender<ClientLink>)>,
new_client_tx: tokio::sync::mpsc::Sender<(NewClient, tokio::sync::oneshot::Sender<ClientLink>)>,

new_monolith_rx: tokio::sync::mpsc::Receiver<(
NewMonolith,
Expand All @@ -52,7 +45,6 @@ pub struct Balancer {
impl Balancer {
pub fn new(ctx: Arc<RwLock<BalancerContext>>) -> Self {
let (new_client_tx, new_client_rx) = tokio::sync::mpsc::channel(20);
let (client_msg_tx, client_msg_rx) = tokio::sync::mpsc::channel(100);

let (new_monolith_tx, new_monolith_rx) = tokio::sync::mpsc::channel(20);
let (monolith_msg_tx, monolith_msg_rx) = tokio::sync::mpsc::channel(100);
Expand All @@ -63,9 +55,6 @@ impl Balancer {
new_client_rx,
new_client_tx,

client_msg_rx,
client_msg_tx,

new_monolith_rx,
new_monolith_tx,

Expand All @@ -77,7 +66,6 @@ impl Balancer {
pub fn new_link(&self) -> BalancerLink {
BalancerLink {
new_client_tx: self.new_client_tx.clone(),
client_msg_tx: self.client_msg_tx.clone(),

new_monolith_tx: self.new_monolith_tx.clone(),
monolith_msg_tx: self.monolith_msg_tx.clone(),
Expand All @@ -88,10 +76,10 @@ impl Balancer {
loop {
tokio::select! {
new_client = self.new_client_rx.recv() => {
if let Some((new_client, receiver_tx)) = new_client {
if let Some((new_client, client_link_tx)) = new_client {
let ctx = self.ctx.clone();
let _ = tokio::task::Builder::new().name("join client").spawn(async move {
match join_client(ctx, new_client, receiver_tx).await {
match join_client(ctx, new_client, client_link_tx).await {
Ok(_) => {},
Err(err) => error!("failed to join client: {:?}", err)
};
Expand All @@ -100,19 +88,6 @@ impl Balancer {
warn!("new client channel closed")
}
}
msg = self.client_msg_rx.recv() => {
if let Some(msg) = msg {
let ctx = self.ctx.clone();
let _ = tokio::task::Builder::new().name("dispatch client message").spawn(async move {
match dispatch_client_message(ctx, msg).await {
Ok(_) => {},
Err(err) => error!("failed to dispatch client message: {:?}", err)
}
});
} else {
warn!("client message channel closed")
}
}
new_monolith = self.new_monolith_rx.recv() => {
if let Some((new_monolith, receiver_tx)) = new_monolith {
let ctx = self.ctx.clone();
Expand Down Expand Up @@ -155,11 +130,7 @@ pub fn start_dispatcher(mut balancer: Balancer) -> anyhow::Result<tokio::task::J

#[derive(Clone)]
pub struct BalancerLink {
new_client_tx: tokio::sync::mpsc::Sender<(
NewClient,
tokio::sync::oneshot::Sender<tokio::sync::mpsc::Receiver<SocketMessage>>,
)>,
client_msg_tx: tokio::sync::mpsc::Sender<Context<ClientId, SocketMessage>>,
new_client_tx: tokio::sync::mpsc::Sender<(NewClient, tokio::sync::oneshot::Sender<ClientLink>)>,

new_monolith_tx: tokio::sync::mpsc::Sender<(
NewMonolith,
Expand All @@ -169,26 +140,12 @@ pub struct BalancerLink {
}

impl BalancerLink {
pub async fn send_client(
&self,
client: NewClient,
) -> anyhow::Result<tokio::sync::mpsc::Receiver<SocketMessage>> {
pub async fn send_client(&self, client: NewClient) -> anyhow::Result<ClientLink> {
let (receiver_tx, receiver_rx) = tokio::sync::oneshot::channel();
self.new_client_tx.send((client, receiver_tx)).await?;
let receiver = receiver_rx.await?;
let client_link = receiver_rx.await?;

Ok(receiver)
}

pub async fn send_client_message(
&self,
client_id: ClientId,
message: SocketMessage,
) -> anyhow::Result<()> {
self.client_msg_tx
.send(Context::new(client_id, message))
.await?;
Ok(())
Ok(client_link)
}

pub async fn send_monolith(
Expand Down Expand Up @@ -304,16 +261,16 @@ impl BalancerContext {
Ok(())
}

#[instrument(skip(self, room), err, fields(room = %room.name(), load_epoch = %locator.load_epoch()))]
pub fn add_room(&mut self, room: Room, locator: RoomLocator) -> anyhow::Result<()> {
#[instrument(skip(self, room_name), fields(room = %room_name, load_epoch = %locator.load_epoch()))]
pub fn add_room(&mut self, room_name: RoomName, locator: RoomLocator) -> anyhow::Result<&Room> {
debug!("add_room");
let monolith = self
.monoliths
.get_mut(&locator.monolith_id())
.ok_or(anyhow::anyhow!("monolith not found"))?;
self.rooms_to_monoliths.insert(room.name().clone(), locator);
monolith.add_room(room);
Ok(())
self.rooms_to_monoliths.insert(room_name.clone(), locator);
let room = monolith.add_room(&room_name)?;
Ok(room)
}

pub async fn remove_room(
Expand Down Expand Up @@ -377,7 +334,7 @@ impl BalancerContext {
metadata.name.clone(),
RoomLocator::new(monolith_id, load_epoch),
);
monolith.add_or_sync_room(metadata);
monolith.add_or_sync_room(metadata)?;

Ok(())
}
Expand All @@ -394,15 +351,6 @@ impl BalancerContext {
Ok(locator.monolith_id())
}

/// Looks up the monolith associated with `client`.
///
/// The client is first resolved to a monolith id via `find_monolith_id`, and that
/// id is then looked up in `self.monoliths`.
///
/// # Errors
///
/// Propagates any error returned by `find_monolith_id`, and returns a
/// "monolith not found" error when the resolved id has no entry in `self.monoliths`.
pub fn find_monolith(&self, client: ClientId) -> anyhow::Result<&BalancerMonolith> {
    let monolith_id = self.find_monolith_id(client)?;
    // Build the error lazily so the success path never constructs an `anyhow::Error`.
    self.monoliths
        .get(&monolith_id)
        .ok_or_else(|| anyhow::anyhow!("monolith not found"))
}

pub fn find_monolith_mut(&mut self, client: ClientId) -> anyhow::Result<&mut BalancerMonolith> {
let monolith_id = self.find_monolith_id(client)?;
let monolith = self
Expand Down Expand Up @@ -471,22 +419,15 @@ impl BalancerContext {
pub async fn join_client(
ctx: Arc<RwLock<BalancerContext>>,
new_client: NewClient,
receiver_tx: tokio::sync::oneshot::Sender<tokio::sync::mpsc::Receiver<SocketMessage>>,
client_link_tx: tokio::sync::oneshot::Sender<ClientLink>,
) -> anyhow::Result<()> {
info!("new client");

// create the channel that the client socket will use to be notified of outbound messages to be sent to the client
// balancer -> client websocket
let (client_tx, client_rx) = tokio::sync::mpsc::channel(100);
let client = BalancerClient::new(new_client, client_tx);
receiver_tx
.send(client_rx)
.map_err(|_| anyhow::anyhow!("receiver closed"))?;

// since we're always going to be doing a write, we can just lock the context for the whole function so it doesn't change out from under us
let mut ctx_write = ctx.write().await;

let (monolith_id, should_create_room) = match ctx_write.rooms_to_monoliths.get(&client.room) {
let (monolith_id, should_create_room) = match ctx_write.rooms_to_monoliths.get(&new_client.room)
{
Some(locator) => {
debug!("room already loaded on {}", locator.monolith_id());
(locator.monolith_id(), false)
Expand All @@ -503,11 +444,35 @@ pub async fn join_client(
}
};

if should_create_room {
let room = Room::new(client.room.clone());
let room_broadcast_rx = if should_create_room {
// we assume the load epoch is u32::MAX since we're creating the room. this will be updated when the monolith sends us the loaded message, or when we receive the gossip message
ctx_write.add_room(room, RoomLocator::new(monolith_id, u32::MAX))?;
ctx_write.add_room(
new_client.room.clone(),
RoomLocator::new(monolith_id, u32::MAX),
)?
} else {
let monolith = ctx_write.monoliths.get(&monolith_id).unwrap();
let room = monolith.rooms().get(&new_client.room).unwrap();
room
}
.new_broadcast_rx();

let monolith = ctx_write.monoliths.get(&monolith_id).unwrap();
let client_inbound_tx = monolith.new_inbound_tx();

let (client_outbound_unicast_tx, client_outbound_unicast_rx) = tokio::sync::mpsc::channel(100);

let link = ClientLink::new(
new_client.id,
client_inbound_tx,
room_broadcast_rx,
client_outbound_unicast_rx,
);
let client = BalancerClient::new(new_client, client_outbound_unicast_tx);
client_link_tx
.send(link)
.map_err(|_| anyhow::anyhow!("receiver closed"))?;

ctx_write.add_client(client, monolith_id).await?;
Ok(())
}
Expand All @@ -520,27 +485,60 @@ pub async fn leave_client(ctx: Arc<RwLock<BalancerContext>>, id: ClientId) -> an
Ok(())
}

pub async fn dispatch_client_message(
#[instrument(skip_all, err, fields(monolith_id = %monolith.id))]
pub async fn join_monolith(
ctx: Arc<RwLock<BalancerContext>>,
msg: Context<ClientId, SocketMessage>,
monolith: NewMonolith,
receiver_tx: tokio::sync::oneshot::Sender<tokio::sync::mpsc::Receiver<SocketMessage>>,
) -> anyhow::Result<()> {
trace!("client message: {:?}", msg);
info!("new monolith");
let mut b = ctx.write().await;
let (client_inbound_tx, mut client_inbound_rx) = tokio::sync::mpsc::channel(100);
let (monolith_outbound_tx, monolith_outbound_rx) = tokio::sync::mpsc::channel(100);
let monolith_outbound_tx = Arc::new(monolith_outbound_tx);
let monolith = BalancerMonolith::new(monolith, monolith_outbound_tx.clone(), client_inbound_tx);
receiver_tx
.send(monolith_outbound_rx)
.map_err(|_| anyhow::anyhow!("receiver closed"))?;
let monolith_id = monolith.id();
b.add_monolith(monolith);
drop(b);

let ctx = ctx.clone();
let monolith_outbound_tx = monolith_outbound_tx.clone();
tokio::task::Builder::new()
.name(format!("monolith {}", monolith_id).as_ref())
.spawn(async move {
loop {
tokio::select! {
Some(msg) = client_inbound_rx.recv() => {
if let Err(e) = handle_client_inbound(ctx.clone(), msg, monolith_outbound_tx.clone()).await {
error!("failed to handle client inbound: {:?}", e);
}
}
}
}
})?;
Ok(())
}

async fn handle_client_inbound(
ctx: Arc<RwLock<BalancerContext>>,
msg: Context<ClientId, SocketMessage>,
monolith_outbound_tx: Arc<tokio::sync::mpsc::Sender<SocketMessage>>,
) -> anyhow::Result<()> {
match msg.message() {
SocketMessage::Message(Message::Text(_) | Message::Binary(_)) => {
let raw_value: Box<RawValue> = msg.message().deserialize()?;

let ctx_read = ctx.read().await;
let Ok(monolith) = ctx_read.find_monolith(*msg.id()) else {
anyhow::bail!("monolith not found");
};

monolith
.send(B2MClientMsg {
client_id: *msg.id(),
payload: raw_value,
})
.await?;
let built_msg: MsgB2M = B2MClientMsg {
client_id: *msg.id(),
payload: raw_value,
}
.into();
let text = serde_json::to_string(&built_msg).expect("failed to serialize message");
let socket_msg = Message::Text(text).into();
monolith_outbound_tx.send(socket_msg).await?;
}
#[allow(deprecated)]
SocketMessage::Message(Message::Close(_)) | SocketMessage::End => {
Expand All @@ -549,24 +547,6 @@ pub async fn dispatch_client_message(
SocketMessage::Message(Message::Frame(_)) => unreachable!(),
_ => {}
}

Ok(())
}

/// Registers a newly connected monolith with the balancer context.
///
/// While holding the context write lock, this creates a bounded mpsc channel
/// (capacity 100), gives the sending half to the new `BalancerMonolith`, hands the
/// receiving half back to the caller through `receiver_tx`, and finally adds the
/// monolith to the context.
///
/// # Errors
///
/// Returns a "receiver closed" error if the receiving half cannot be delivered
/// through `receiver_tx`; in that case the monolith is not added to the context.
#[instrument(skip_all, err, fields(monolith_id = %monolith.id))]
pub async fn join_monolith(
    ctx: Arc<RwLock<BalancerContext>>,
    monolith: NewMonolith,
    receiver_tx: tokio::sync::oneshot::Sender<tokio::sync::mpsc::Receiver<SocketMessage>>,
) -> anyhow::Result<()> {
    info!("new monolith");
    let mut ctx_write = ctx.write().await;

    let (outbound_tx, outbound_rx) = tokio::sync::mpsc::channel(100);
    let registered = BalancerMonolith::new(monolith, outbound_tx);

    // A failed oneshot send means the other side has already gone away.
    if receiver_tx.send(outbound_rx).is_err() {
        anyhow::bail!("receiver closed");
    }

    ctx_write.add_monolith(registered);
    Ok(())
}

Expand Down Expand Up @@ -692,13 +672,7 @@ pub async fn dispatch_monolith_message(
None => {
// broadcast to all clients
debug!("broadcasting to clients in room: {:?}", room.name());
// TODO: optimize this using a broadcast channel
for client in room.clients() {
let Some(client) = ctx_read.clients.get(client) else {
anyhow::bail!("client not found");
};
client.send(built_msg.clone()).await?;
}
room.broadcast(built_msg)?;
}
}
}
Expand Down
Loading

0 comments on commit 043a3d9

Please sign in to comment.