diff --git a/crates/harness-tests/src/state.rs b/crates/harness-tests/src/state.rs index 7577865e9..c9d537882 100644 --- a/crates/harness-tests/src/state.rs +++ b/crates/harness-tests/src/state.rs @@ -2,7 +2,7 @@ use std::time::Duration; -use harness::{Client, Monolith, TestRunner}; +use harness::{BehaviorLoadRooms, Client, Monolith, MonolithBuilder, TestRunner}; use ott_balancer_protocol::monolith::{B2MUnload, MsgB2M}; use test_context::test_context; @@ -37,14 +37,98 @@ async fn should_unload_duplicate_rooms_and_route_correctly(ctx: &TestRunner) { m1.load_room("foo").await; m2.load_room("foo").await; - tokio::time::timeout(Duration::from_millis(100), m2.wait_recv()) + tokio::time::timeout(Duration::from_millis(200), m2.wait_recv()) .await .expect("timed out waiting for unload"); let mut c = Client::new(ctx).unwrap(); c.join("foo").await; - tokio::time::timeout(Duration::from_millis(100), m1.wait_recv()) + tokio::time::timeout(Duration::from_millis(200), m1.wait_recv()) .await .expect("timed out waiting for client join"); } + +#[test_context(TestRunner)] +#[tokio::test] +async fn should_not_unload_rooms_when_balancer_restart(ctx: &mut TestRunner) { + let mut m = MonolithBuilder::new() + .behavior(BehaviorLoadRooms) + .build(ctx) + .await; + let mut c1 = Client::new(ctx).unwrap(); + + // increase the load epoch past the initial value + for _ in 0..10 { + m.load_room("foo").await; + m.unload_room("foo").await; + } + + m.show().await; + c1.join("foo").await; + m.wait_recv().await; + m.clear_recv(); + + ctx.restart_balancer().await; + + m.wait_for_balancer_connect().await; + c1.disconnect().await; + c1.join("foo").await; + m.wait_recv().await; + m.gossip().await; + let _ = tokio::time::timeout(Duration::from_millis(200), m.wait_recv()).await; + + let recv = m.collect_recv(); + for msg in recv { + if matches!(msg, MsgB2M::Unload(_)) { + panic!("expected no unload message from balancer, got {:?}", msg); + } + } +} + +#[test_context(TestRunner)] +#[tokio::test] +async fn should_update_load_epoch_when_balancer_restart_2_monoliths(ctx: &mut TestRunner) { + let mut m = MonolithBuilder::new() + .behavior(BehaviorLoadRooms) + .build(ctx) + .await; + let mut c1 = Client::new(ctx).unwrap(); + + // increase the load epoch past the initial value + for _ in 0..10 { + m.load_room("foo").await; + m.unload_room("foo").await; + } + + m.show().await; + c1.join("foo").await; + m.wait_recv().await; + m.clear_recv(); + + ctx.restart_balancer().await; + + m.wait_for_balancer_connect().await; + c1.disconnect().await; + c1.join("foo").await; + m.wait_recv().await; + m.gossip().await; + let _ = tokio::time::timeout(Duration::from_millis(200), m.wait_recv()).await; + + let mut m2 = MonolithBuilder::new() + .behavior(BehaviorLoadRooms) + .build(ctx) + .await; + m2.show().await; + m2.load_room("foo").await; + m2.wait_recv().await; + + let recv = m2.collect_recv(); + for msg in recv { + if matches!(msg, MsgB2M::Unload(_)) { + // This means that the load epoch was corrected when the gossip was received + return; + } + } + panic!("expected unload message from balancer"); +} diff --git a/crates/harness/src/monolith.rs b/crates/harness/src/monolith.rs index adba9a09d..f34b8f927 100644 --- a/crates/harness/src/monolith.rs +++ b/crates/harness/src/monolith.rs @@ -55,7 +55,7 @@ pub struct MonolithState { received_http: Vec, /// A mapping from request path to response body for mocking HTTP responses. response_mocks: HashMap, - rooms: HashMap, + rooms: HashMap, room_load_epoch: Arc, clients: HashSet, } @@ -279,8 +279,15 @@ impl Monolith { let meta = RoomMetadata::default_with_name(room.clone()); let load_epoch = { let mut state = self.state.lock().unwrap(); - state.rooms.insert(room, meta.clone()); - state.room_load_epoch.fetch_add(1, Ordering::Relaxed) + let load_epoch = state.room_load_epoch.fetch_add(1, Ordering::Relaxed); + state.rooms.insert( + room, + GossipRoom { + room: meta.clone(), + load_epoch, + }, + ); + load_epoch }; if connected { self.send(M2BLoaded { @@ -302,6 +309,12 @@ impl Monolith { self.send(M2BUnloaded { name: room }).await; } } + + /// Send a gossip message to the balancer. + pub async fn gossip(&mut self) { + let rooms = self.state.lock().unwrap().rooms.values().cloned().collect(); + self.send(M2BGossip { rooms }).await; + } } impl Drop for Monolith { diff --git a/crates/harness/src/monolith/behavior.rs b/crates/harness/src/monolith/behavior.rs index fdcfe3f3f..e61cc78e1 100644 --- a/crates/harness/src/monolith/behavior.rs +++ b/crates/harness/src/monolith/behavior.rs @@ -68,11 +68,15 @@ impl Behavior for BehaviorLoadRooms { match msg { MsgB2M::Load(msg) => { let room = RoomMetadata::default_with_name(msg.room.clone()); - state.rooms.insert(room.name.clone(), room.clone()); - let loaded = M2BLoaded { - room, - load_epoch: state.room_load_epoch.fetch_add(1, Ordering::Relaxed), - }; + let load_epoch = state.room_load_epoch.fetch_add(1, Ordering::Relaxed); + state.rooms.insert( + room.name.clone(), + GossipRoom { + room: room.clone(), + load_epoch, + }, + ); + let loaded = M2BLoaded { room, load_epoch }; return vec![loaded.into()]; } MsgB2M::Unload(msg) => { @@ -85,11 +89,15 @@ impl Behavior for BehaviorLoadRooms { MsgB2M::Join(msg) => { if !state.rooms.contains_key(&msg.room) { let room = RoomMetadata::default_with_name(msg.room.clone()); - state.rooms.insert(room.name.clone(), room.clone()); - let loaded = M2BLoaded { - room, - load_epoch: state.room_load_epoch.fetch_add(1, Ordering::Relaxed), - }; + let load_epoch = state.room_load_epoch.fetch_add(1, Ordering::Relaxed); + state.rooms.insert( + room.name.clone(), + GossipRoom { + room: room.clone(), + load_epoch, + }, + ); + let loaded = M2BLoaded { room, load_epoch }; return vec![loaded.into()]; } } @@ -122,9 +130,13 @@ mod tests { fn behavior_should_unload_rooms() { let b = BehaviorLoadRooms; let mut state = MonolithState::default(); - state - .rooms - .insert("foo".into(), RoomMetadata::default_with_name("foo")); + state.rooms.insert( + "foo".into(), + GossipRoom { + room: RoomMetadata::default_with_name("foo"), + load_epoch: 0, + }, + ); let msg = MsgB2M::Unload(B2MUnload { room: "foo".into() }); let msgs = b.on_msg(&msg, &mut state); diff --git a/crates/harness/src/provider.rs b/crates/harness/src/provider.rs index c786ebaa8..d9c36fbce 100644 --- a/crates/harness/src/provider.rs +++ b/crates/harness/src/provider.rs @@ -78,3 +78,23 @@ impl DiscoveryProvider { } } } + +#[cfg(test)] +mod test { + use std::time::Duration; + + use test_context::test_context; + + use crate::{MonolithBuilder, TestRunner}; + + #[test_context(TestRunner)] + #[tokio::test] + async fn should_reconnect_when_balancer_restarts(ctx: &mut TestRunner) { + let mut m = MonolithBuilder::new().build(ctx).await; + m.show().await; + ctx.restart_balancer().await; + tokio::time::timeout(Duration::from_secs(2), m.wait_for_balancer_connect()) + .await + .expect("timed out waiting for balancer to reconnect"); + } +} diff --git a/crates/ott-balancer-bin/src/balancer.rs b/crates/ott-balancer-bin/src/balancer.rs index f80b2620c..c60a90b60 100644 --- a/crates/ott-balancer-bin/src/balancer.rs +++ b/crates/ott-balancer-bin/src/balancer.rs @@ -340,20 +340,23 @@ impl BalancerContext { ) -> anyhow::Result<()> { debug!(func = "add_or_sync_room"); if let Some(locator) = self.rooms_to_monoliths.get(&metadata.name) { - match locator.load_epoch().cmp(&load_epoch) { - std::cmp::Ordering::Less => { - // we already have an older version of this room - self.unload_room(monolith_id, metadata.name.clone()).await?; - return Err(anyhow::anyhow!("room already loaded")); - } - std::cmp::Ordering::Greater => { - // we have an newer version of this room, remove it - self.unload_room(locator.monolith_id(), metadata.name.clone()) - .await?; - // self.remove_room(&metadata.name, locator.monolith_id()) - // .await?; + if locator.monolith_id() != monolith_id { + // this room is loaded on a different monolith than we were expecting + match locator.load_epoch().cmp(&load_epoch) { + std::cmp::Ordering::Less => { + // we already have an older version of this room + self.unload_room(monolith_id, metadata.name.clone()).await?; + return Err(anyhow::anyhow!("room already loaded")); + } + std::cmp::Ordering::Greater => { + // we have an newer version of this room, remove it + self.unload_room(locator.monolith_id(), metadata.name.clone()) + .await?; + // self.remove_room(&metadata.name, locator.monolith_id()) + // .await?; + } + _ => {} } - _ => {} } } let monolith = self.monoliths.get_mut(&monolith_id).unwrap(); @@ -461,8 +464,8 @@ pub async fn join_client( if should_create_room { let room = Room::new(client.room.clone()); - // we assume the load epoch is 0 since we're creating the room. this will be updated when the monolith sends us the loaded message - ctx_write.add_room(room, RoomLocator::new(monolith_id, 0))?; + // we assume the load epoch is u32::MAX since we're creating the room. this will be updated when the monolith sends us the loaded message, or when we receive the gossip message + ctx_write.add_room(room, RoomLocator::new(monolith_id, u32::MAX))?; } ctx_write.add_client(client, monolith_id).await?; Ok(())