Skip to content

Commit

Permalink
fix: remove gateways when empty and expired
Browse files Browse the repository at this point in the history
  • Loading branch information
ilbertt committed Dec 15, 2023
1 parent 1ad3916 commit 858ebbc
Show file tree
Hide file tree
Showing 8 changed files with 252 additions and 63 deletions.
83 changes: 54 additions & 29 deletions src/ic-websocket-cdk/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::{
cell::RefCell,
collections::{HashMap, HashSet},
rc::Rc,
time::Duration,
};

use candid::{encode_one, Principal};
Expand Down Expand Up @@ -32,6 +33,8 @@ thread_local! {
/* flexible */ pub(crate) static CERT_TREE: RefCell<RbTree<String, ICHash>> = RefCell::new(RbTree::new());
/// Keeps track of the principals of the WS Gateways that poll the canister
/* flexible */ pub(crate) static REGISTERED_GATEWAYS: RefCell<HashMap<GatewayPrincipal, RegisteredGateway>> = RefCell::new(HashMap::new());
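/// Keeps track of the gateways that have no clients connected anymore, together with the time at which they became empty, so that they can be removed from the list of registered gateways once more than an ack interval has elapsed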
/* flexible */ pub(crate) static GATEWAYS_TO_REMOVE: RefCell<HashMap<GatewayPrincipal, TimestampNs>> = RefCell::new(HashMap::new());
/// The parameters passed in the CDK initialization
/* flexible */ pub(crate) static PARAMS: RefCell<WsInitParams> = RefCell::new(WsInitParams::default());
}
Expand Down Expand Up @@ -72,6 +75,10 @@ pub(crate) fn reset_internal_state() {
/// Increments the clients connected count for the given gateway.
/// If the gateway is not registered, a new entry is created with a clients connected count of 1.
pub(crate) fn increment_gateway_clients_count(gateway_principal: GatewayPrincipal) {
GATEWAYS_TO_REMOVE.with(|state| {
state.borrow_mut().remove(&gateway_principal);
});

REGISTERED_GATEWAYS.with(|map| {
map.borrow_mut()
.entry(gateway_principal)
Expand All @@ -82,29 +89,54 @@ pub(crate) fn increment_gateway_clients_count(gateway_principal: GatewayPrincipa

/// Decrements the clients connected count for the given gateway, if it exists.
///
/// If `remove_if_empty` is true, the gateway is removed from the list of registered gateways
/// if it has no clients connected.
pub(crate) fn decrement_gateway_clients_count(
gateway_principal: &GatewayPrincipal,
remove_if_empty: bool,
) {
let messages_keys_to_delete = REGISTERED_GATEWAYS.with(|map| {
let mut map = map.borrow_mut();
if let Some(g) = map.get_mut(gateway_principal) {
let clients_count = g.decrement_clients_count();

if remove_if_empty && clients_count == 0 {
return map
.remove(gateway_principal)
.map(|g| g.messages_queue.iter().map(|m| m.key.clone()).collect());
}
}
/// When the decrement leaves the gateway with zero connected clients, the gateway is recorded
/// in the [GATEWAYS_TO_REMOVE] map together with the current time, so that
/// [remove_empty_expired_gateways] can drop it later on.
///
/// A gateway that is not registered is left untouched and nothing is recorded for it.
pub(crate) fn decrement_gateway_clients_count(gateway_principal: &GatewayPrincipal) {
    let gateway_is_empty = REGISTERED_GATEWAYS.with(|gateways| {
        match gateways.borrow_mut().get_mut(gateway_principal) {
            Some(gateway) => gateway.decrement_clients_count() == 0,
            // unknown gateway: there is no counter to decrement
            None => false,
        }
    });

    if !gateway_is_empty {
        return;
    }

    // remember when the gateway became empty
    GATEWAYS_TO_REMOVE.with(|pending| {
        let emptied_at = get_current_time();
        pending
            .borrow_mut()
            .insert(gateway_principal.clone(), emptied_at);
    });
}

/// Removes the gateways that were added to the [GATEWAYS_TO_REMOVE] map
/// more than the ack interval ms time ago from the list of registered gateways
pub(crate) fn remove_empty_expired_gateways() {
let ack_interval_ms = get_params().send_ack_interval_ms;
let time = get_current_time();

let mut gateway_principals_to_remove: Vec<GatewayPrincipal> = vec![];

None
GATEWAYS_TO_REMOVE.with(|state| {
state.borrow_mut().retain(|gp, added_at| {
if Duration::from_nanos(time - *added_at) > Duration::from_millis(ack_interval_ms) {
gateway_principals_to_remove.push(gp.clone());
false
} else {
true
}
})
});

if let Some(messages_keys_to_delete) = messages_keys_to_delete {
delete_keys_from_cert_tree(messages_keys_to_delete);
for gateway_principal in &gateway_principals_to_remove {
if let Some(messages_keys_to_delete) = REGISTERED_GATEWAYS.with(|map| {
map.borrow_mut()
.remove(gateway_principal)
.map(|g| g.messages_queue.iter().map(|m| m.key.clone()).collect())
}) {
delete_keys_from_cert_tree(messages_keys_to_delete);
}
}
}

Expand Down Expand Up @@ -278,15 +310,11 @@ pub(crate) fn add_client(client_key: ClientKey, new_client: RegisteredClient) {
increment_gateway_clients_count(new_client.gateway_principal);
}

/// Removes a client from the internal state
/// and call the on_close callback,
/// Removes a client from the internal state and call the on_close callback,
/// if the client was registered in the state.
///
/// If a `close_reason` is provided, it also sends a close message to the client,
/// so that the client can close the WS connection with the gateway.
///
/// If a `close_reason` is **not** provided, it also removes the gateway from the state
/// if it has no clients connected anymore.
pub(crate) fn remove_client(client_key: &ClientKey, close_reason: Option<CloseMessageReason>) {
if let Some(close_reason) = close_reason.clone() {
// ignore the error
Expand Down Expand Up @@ -314,10 +342,7 @@ pub(crate) fn remove_client(client_key: &ClientKey, close_reason: Option<CloseMe
if let Some(registered_client) =
REGISTERED_CLIENTS.with(|map| map.borrow_mut().remove(client_key))
{
decrement_gateway_clients_count(
&registered_client.gateway_principal,
close_reason.is_none(),
);
decrement_gateway_clients_count(&registered_client.gateway_principal);

let handlers = get_handlers_from_params();
handlers.call_on_close(OnCloseCallbackArgs {
Expand Down
128 changes: 125 additions & 3 deletions src/ic-websocket-cdk/src/tests/integration_tests/c_ws_get_messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ use proptest::prelude::*;
use std::ops::Deref;

use crate::{
tests::common, CanisterOutputCertifiedMessages, CanisterWsGetMessagesArguments,
MESSAGES_TO_DELETE_COUNT,
tests::common, CanisterOutputCertifiedMessages, CanisterWsCloseArguments,
CanisterWsGetMessagesArguments, CLIENT_KEEP_ALIVE_TIMEOUT_MS, MESSAGES_TO_DELETE_COUNT,
};

use super::utils::{
actor::{
send::call_send_with_panic, wipe::call_wipe,
send::call_send_with_panic, wipe::call_wipe, ws_close::call_ws_close_with_panic,
ws_get_messages::call_ws_get_messages_with_panic,
ws_open::call_ws_open_for_client_key_with_panic,
},
Expand Down Expand Up @@ -227,3 +227,125 @@ proptest! {
);
}
}

/// A gateway whose only client disconnected must still be able to fetch its queued
/// messages right after the disconnection and after the ack interval has elapsed;
/// the messages are gone only after the keep alive timeout has elapsed too.
#[test]
fn test_6_empty_gateway_can_get_messages_until_next_keep_alive_check() {
    let batch_size = 10;

    // start from a clean canister with client 1 connected
    call_wipe(None);
    let client_key = CLIENT_1_KEY.deref();
    call_ws_open_for_client_key_with_panic(client_key);

    // queue a batch of messages for the client
    let batch: Vec<AppMessage> = (1..=batch_size)
        .map(|index| AppMessage {
            text: format!("test{}", index),
        })
        .collect();
    call_send_with_panic(&client_key.client_principal, batch);

    // the gateway sees the whole batch while the client is connected
    helpers::assert_gateway_has_messages(batch_size);

    // close the client's connection through the gateway
    let close_args = CanisterWsCloseArguments {
        client_key: client_key.clone(),
    };
    call_ws_close_with_panic(&GATEWAY_1, close_args);

    // the gateway is now empty, but its messages are still available
    helpers::assert_gateway_has_messages(batch_size);

    // let the ack interval elapse
    get_test_env().advance_canister_time_ms(DEFAULT_TEST_SEND_ACK_INTERVAL_MS);

    // the messages are still there after the ack interval
    helpers::assert_gateway_has_messages(batch_size);

    // let the keep alive timeout elapse
    get_test_env().advance_canister_time_ms(CLIENT_KEEP_ALIVE_TIMEOUT_MS);

    // now the gateway has nothing left to fetch
    helpers::assert_gateway_has_no_messages();
}

/// Same scenario as test 6, but the client disconnects only after a first ack interval
/// has already elapsed: the gateway keeps its messages through one keep alive timeout and
/// one more ack interval, and loses them only after a second keep alive timeout.
#[test]
fn test_7_empty_gateway_can_get_messages_until_next_keep_alive_check_if_removed_before_ack_interval(
) {
    let batch_size = 10;

    // start from a clean canister with client 1 connected
    call_wipe(None);
    let client_key = CLIENT_1_KEY.deref();
    call_ws_open_for_client_key_with_panic(client_key);

    // queue a batch of messages for the client
    let batch: Vec<AppMessage> = (1..=batch_size)
        .map(|index| AppMessage {
            text: format!("test{}", index),
        })
        .collect();
    call_send_with_panic(&client_key.client_principal, batch);

    // the gateway sees the whole batch while the client is connected
    helpers::assert_gateway_has_messages(batch_size);

    // let the ack interval elapse before disconnecting
    get_test_env().advance_canister_time_ms(DEFAULT_TEST_SEND_ACK_INTERVAL_MS);

    // close the client's connection through the gateway
    let close_args = CanisterWsCloseArguments {
        client_key: client_key.clone(),
    };
    call_ws_close_with_panic(&GATEWAY_1, close_args);

    // +1 for the ack message
    let expected_count = batch_size + 1;

    // the gateway is now empty, but the batch and the ack message are still available
    helpers::assert_gateway_has_messages(expected_count);

    // let the keep alive timeout elapse
    get_test_env().advance_canister_time_ms(CLIENT_KEEP_ALIVE_TIMEOUT_MS);

    // the gateway was emptied less than an ack interval ago,
    // so its messages must still be available
    helpers::assert_gateway_has_messages(expected_count);

    // let another ack interval elapse
    get_test_env().advance_canister_time_ms(DEFAULT_TEST_SEND_ACK_INTERVAL_MS);

    // empty expired gateways are removed only in the keep alive timeout callback,
    // so the messages must still be available
    helpers::assert_gateway_has_messages(expected_count);

    // let a second keep alive timeout elapse
    get_test_env().advance_canister_time_ms(CLIENT_KEEP_ALIVE_TIMEOUT_MS);

    // now the gateway has nothing left to fetch
    helpers::assert_gateway_has_no_messages();
}

mod helpers {
use super::*;

pub(super) fn assert_gateway_has_messages(send_messages_count: usize) {
let CanisterOutputCertifiedMessages { messages, .. } = call_ws_get_messages_with_panic(
&GATEWAY_1,
CanisterWsGetMessagesArguments { nonce: 0 },
);
assert_eq!(
messages.len(),
// + 1 for the open service message
send_messages_count + 1,
);
}

pub(super) fn assert_gateway_has_no_messages() {
let CanisterOutputCertifiedMessages { messages, .. } = call_ws_get_messages_with_panic(
&GATEWAY_1,
CanisterWsGetMessagesArguments { nonce: 0 },
);
assert_eq!(messages.len(), 0);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ fn test_4_should_close_the_websocket_for_a_registered_client() {
assert_eq!(
res,
CanisterWsCloseResult::Err(
WsError::GatewayNotRegistered {
gateway_principal: GATEWAY_1.deref()
WsError::ClientKeyNotConnected {
client_key: &CLIENT_1_KEY,
}
.to_string()
)
Expand Down
Loading

0 comments on commit 858ebbc

Please sign in to comment.