diff --git a/Cargo.lock b/Cargo.lock index 1a49b5a64584f1..e607294b666618 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2254,6 +2254,7 @@ dependencies = [ "serde", "serde_json", "tinytemplate", + "tokio", "walkdir", ] @@ -10979,6 +10980,7 @@ dependencies = [ "bytes", "chrono", "clap 4.5.31", + "criterion", "crossbeam-channel", "dashmap", "futures 0.3.31", @@ -11001,6 +11003,7 @@ dependencies = [ "solana-keypair", "solana-measure", "solana-metrics", + "solana-native-token", "solana-net-utils", "solana-packet 4.0.0", "solana-perf", @@ -11012,6 +11015,7 @@ dependencies = [ "solana-tls-utils", "solana-transaction-error", "solana-transaction-metrics-tracker", + "static_assertions", "thiserror 2.0.18", "tokio", "tokio-util 0.7.18", diff --git a/dev-bins/Cargo.lock b/dev-bins/Cargo.lock index 3eb163282c7f3d..e176b8f12c1121 100644 --- a/dev-bins/Cargo.lock +++ b/dev-bins/Cargo.lock @@ -9167,6 +9167,7 @@ dependencies = [ "solana-keypair", "solana-measure", "solana-metrics", + "solana-native-token", "solana-net-utils", "solana-packet 4.0.0", "solana-perf", @@ -9177,6 +9178,7 @@ dependencies = [ "solana-tls-utils", "solana-transaction-error", "solana-transaction-metrics-tracker", + "static_assertions", "thiserror 2.0.18", "tokio", "tokio-util 0.7.18", diff --git a/programs/sbf/Cargo.lock b/programs/sbf/Cargo.lock index f291d261704247..4050e165d41488 100644 --- a/programs/sbf/Cargo.lock +++ b/programs/sbf/Cargo.lock @@ -9665,6 +9665,7 @@ dependencies = [ "solana-keypair", "solana-measure", "solana-metrics", + "solana-native-token", "solana-net-utils", "solana-packet 4.0.0", "solana-perf", @@ -9675,6 +9676,7 @@ dependencies = [ "solana-tls-utils", "solana-transaction-error", "solana-transaction-metrics-tracker", + "static_assertions", "thiserror 2.0.18", "tokio", "tokio-util 0.7.18", diff --git a/streamer/Cargo.toml b/streamer/Cargo.toml index 53788eba6512f0..ef923c6ccd684b 100644 --- a/streamer/Cargo.toml +++ b/streamer/Cargo.toml @@ -45,7 +45,8 @@ socket2 = { workspace = true } solana-keypair = { workspace = true } solana-measure = { workspace = true } solana-metrics = { workspace = true } -solana-net-utils = { workspace = true } +solana-native-token = { workspace = true } +solana-net-utils = { workspace = true, features = ["agave-unstable-api"] } solana-packet = { workspace = true } solana-perf = { workspace = true } solana-pubkey = { workspace = true } @@ -55,6 +56,7 @@ solana-time-utils = { workspace = true } solana-tls-utils = { workspace = true } solana-transaction-error = { workspace = true } solana-transaction-metrics-tracker = { workspace = true } +static_assertions = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true, features = ["full"] } tokio-util = { workspace = true, features = ["rt"] } @@ -66,5 +68,10 @@ anyhow = { workspace = true } assert_matches = { workspace = true } chrono = { workspace = true, features = ["now"] } clap = { version = "4.5.31", features = ["cargo", "derive", "error-context"] } +criterion = { workspace = true, features = ["async", "async_tokio"] } solana-net-utils = { workspace = true, features = ["dev-context-only-utils"] } solana-streamer = { path = ".", features = ["agave-unstable-api", "dev-context-only-utils"] } + +[[bench]] +name = "bench_refiller" +harness = false diff --git a/streamer/benches/bench_refiller.rs b/streamer/benches/bench_refiller.rs new file mode 100644 index 00000000000000..7d12c91cdfeff4 --- /dev/null +++ b/streamer/benches/bench_refiller.rs @@ -0,0 +1,51 @@ +use { + criterion::{criterion_group, criterion_main, 
Criterion}, + solana_streamer::{ + nonblocking::{ + stream_throttle::{Refiller, StreamRateLimiter}, + testing_utilities::fill_connection_table, + }, + quic::StreamerStats, + }, + std::{ + net::{IpAddr, Ipv4Addr, SocketAddr}, + sync::Arc, + }, + tokio::sync::Mutex, +}; + +const NUM_CLIENTS: usize = 10000; +fn bench_refiller(c: &mut Criterion) { + let stats = Arc::new(StreamerStats::default()); + let sockets: Vec<_> = (0..NUM_CLIENTS as u32) + .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::from_bits(i)), 0)) + .collect(); + + let rate_limiters: Vec<_> = sockets + .iter() + .map(|_| Arc::new(StreamRateLimiter::new_unstaked())) + .collect(); + let connection_table1 = fill_connection_table(&sockets, &rate_limiters, stats.clone()); + let connection_table2 = fill_connection_table(&sockets, &rate_limiters, stats); + let rt = tokio::runtime::Builder::new_current_thread() + .build() + .unwrap(); + + // Using mutexes here is obviously not ideal, + // but we are looking to prove that refilling is not expensive, + // rather than to get super accurate data + let refiller = + Arc::new(Mutex::new(rt.block_on(async { + Refiller::new(connection_table1, connection_table2).await + }))); + + c.bench_function(&format!("do_refill_{NUM_CLIENTS}"), |b| { + b.to_async(&rt).iter(|| async { + refiller.lock().await.do_refill(100000).await; + }); + }); +} + +criterion_group!(benches, bench_refiller); + +criterion_main!(benches); diff --git a/streamer/examples/swqos.rs b/streamer/examples/swqos.rs index e617a471630c41..f0f922c4a2b59a 100644 --- a/streamer/examples/swqos.rs +++ b/streamer/examples/swqos.rs @@ -84,6 +84,9 @@ struct Cli { #[arg(short, long)] stake_amounts: String, + + #[arg(short, long, default_value_t = 100_000)] + max_tps: u64, } // number of threads as in fn default_num_tpu_transaction_forward_receive_threads @@ -108,7 +111,6 @@ async fn main() -> anyhow::Result<()> { ); Arc::new(RwLock::new(nodes)) }; - let cancel = CancellationToken::new(); let SpawnNonBlockingServerResult { endpoints, @@ -127,6 +129,7 @@ async fn main() -> anyhow::Result<()> { SwQosConfig { max_connections_per_staked_peer: cli.max_connections_per_staked_peer, max_connections_per_unstaked_peer: cli.max_connections_per_unstaked_peer, + max_streams_per_ms: cli.max_tps / 1000, ..Default::default() }, cancel.clone(), diff --git a/streamer/src/nonblocking/mod.rs b/streamer/src/nonblocking/mod.rs index 28390acf2cd172..fdf7451d5c4bbd 100644 --- a/streamer/src/nonblocking/mod.rs +++ b/streamer/src/nonblocking/mod.rs @@ -5,7 +5,7 @@ pub mod quic; pub mod recvmmsg; pub mod sendmmsg; pub mod simple_qos; -mod stream_throttle; +pub mod stream_throttle; pub mod swqos; #[cfg(feature = "dev-context-only-utils")] pub mod testing_utilities; diff --git a/streamer/src/nonblocking/qos.rs b/streamer/src/nonblocking/qos.rs index cda14f1f2066ff..ed8b0ea4b8d88f 100644 --- a/streamer/src/nonblocking/qos.rs +++ b/streamer/src/nonblocking/qos.rs @@ -16,7 +16,12 @@ pub(crate) trait ConnectionContext: Clone + Send + Sync { /// A trait to manage QoS for connections. 
This includes /// 1) deriving the ConnectionContext for a connection /// 2) managing connection caching and connection limits, stream limits -pub(crate) trait QosController { +pub(crate) trait QosController { + /// Initialize the controller's async logic (if any) + fn async_init(&mut self) -> impl std::future::Future + std::marker::Send { + async {} + } + /// Build the ConnectionContext for a connection fn build_connection_context(&self, connection: &Connection) -> C; @@ -57,7 +62,7 @@ pub(crate) trait QosController { } /// Marker trait to indicate what is the shared state for connections -pub(crate) trait OpaqueStreamerCounter: Send + Sync + 'static {} +pub trait OpaqueStreamerCounter: Send + Sync + 'static {} #[cfg(test)] pub(crate) struct NullStreamerCounter; diff --git a/streamer/src/nonblocking/quic.rs b/streamer/src/nonblocking/quic.rs index 7562927b61c9b3..4d5230e96b9c87 100644 --- a/streamer/src/nonblocking/quic.rs +++ b/streamer/src/nonblocking/quic.rs @@ -141,7 +141,7 @@ pub(crate) fn spawn_server( keypair: &Keypair, packet_sender: Sender, quic_server_params: QuicStreamerConfig, - qos: Arc, + qos: Q, cancel: CancellationToken, ) -> Result where @@ -223,7 +223,10 @@ impl Drop for ClientConnectionTracker { impl ClientConnectionTracker { /// Check the max_concurrent_connections limit and if it is within the limit /// create ClientConnectionTracker and increment open connection count. Otherwise returns Err - fn new(stats: Arc, max_concurrent_connections: usize) -> Result { + pub(crate) fn new( + stats: Arc, + max_concurrent_connections: usize, + ) -> Result { let open_connections = stats.open_connections.fetch_add(1, Ordering::Relaxed); if open_connections >= max_concurrent_connections { stats.open_connections.fetch_sub(1, Ordering::Relaxed); @@ -246,12 +249,14 @@ async fn run_server( stats: Arc, quic_server_params: QuicStreamerConfig, cancel: CancellationToken, - qos: Arc, + mut qos: Q, ) -> TaskTracker where Q: QosController + Send + Sync + 'static, C: ConnectionContext + Send + Sync + 'static, { + qos.async_init().await; + let qos = Arc::new(qos); let quic_server_params = Arc::new(quic_server_params); let rate_limiter = Arc::new(ConnectionRateLimiter::new( quic_server_params.max_connections_per_ipaddr_per_min, @@ -863,15 +868,15 @@ fn handle_chunks( Ok(StreamState::Finished) } -struct ConnectionEntry { +pub(crate) struct ConnectionEntry { cancel: CancellationToken, - peer_type: ConnectionPeerType, + pub(crate) peer_type: ConnectionPeerType, last_update: Arc, port: u16, // We do not explicitly use it, but its drop is triggered when ConnectionEntry is dropped. 
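    // (That drop is what decrements the open-connections count in `StreamerStats`.)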
_client_connection_tracker: ClientConnectionTracker, connection: Option, - stream_counter: Arc, + pub(crate) stream_counter: Arc, } impl ConnectionEntry { @@ -939,9 +944,10 @@ pub(crate) enum ConnectionTableType { } // Map of IP to list of connection entries -pub(crate) struct ConnectionTable { +pub struct ConnectionTable { table: IndexMap>>, - pub(crate) total_size: usize, + total_size: usize, + connected_stake: u64, table_type: ConnectionTableType, cancel: CancellationToken, } @@ -954,35 +960,51 @@ impl ConnectionTable { Self { table: IndexMap::default(), total_size: 0, + connected_stake: 0, table_type, cancel, } } - fn table_size(&self) -> usize { + pub(crate) fn table_size(&self) -> usize { self.total_size } + pub(crate) fn connected_stake(&self) -> u64 { + self.connected_stake + } + + fn register_connection(&mut self, stake: u64) { + self.connected_stake = self.connected_stake.saturating_add(stake); + self.total_size = self.total_size.saturating_add(1); + } + + fn deregister_connection(&mut self, connection: ConnectionEntry) { + self.connected_stake = self.connected_stake.saturating_sub(connection.stake()); + self.total_size = self.total_size.saturating_sub(1); + } + fn is_staked(&self) -> bool { matches!(self.table_type, ConnectionTableType::Staked) } pub(crate) fn prune_oldest(&mut self, max_size: usize) -> usize { - let mut num_pruned = 0; let key = |(_, connections): &(_, &Vec<_>)| { connections.iter().map(ConnectionEntry::last_update).min() }; - while self.total_size.saturating_sub(num_pruned) > max_size { + let old_size = self.table_size(); + while self.table_size() > max_size { match self.table.values().enumerate().min_by_key(key) { None => break, - Some((index, connections)) => { - num_pruned += connections.len(); - self.table.swap_remove_index(index); + Some((index, _connections)) => { + let (_, mut connections) = self.table.swap_remove_index(index).unwrap(); + connections + .drain(..) + .for_each(|c| self.deregister_connection(c)); } } } - self.total_size = self.total_size.saturating_sub(num_pruned); - num_pruned + old_size.saturating_sub(self.table_size()) } // Randomly selects sample_size many connections, evicts the one with the @@ -990,8 +1012,10 @@ impl ConnectionTable { // If the stakes of all the sampled connections are higher than the // threshold_stake, rejects the pruning attempt, and returns 0. pub(crate) fn prune_random(&mut self, sample_size: usize, threshold_stake: u64) -> usize { + if self.table.is_empty() { + return 0; + } let num_pruned = std::iter::once(self.table.len()) - .filter(|&size| size > 0) .flat_map(|size| { let mut rng = rng(); repeat_with(move || rng.random_range(0..size)) @@ -1005,9 +1029,13 @@ impl ConnectionTable { .min_by_key(|&(_, stake)| stake) .filter(|&(_, stake)| stake < Some(threshold_stake)) .and_then(|(index, _)| self.table.swap_remove_index(index)) - .map(|(_, connections)| connections.len()) + .map(|(_, mut connections)| { + connections + .drain(..) 
+ .map(|c| self.deregister_connection(c)) + .count() + }) .unwrap_or_default(); - self.total_size = self.total_size.saturating_sub(num_pruned); num_pruned } @@ -1034,7 +1062,7 @@ impl ConnectionTable { .first() .map(|entry| entry.stream_counter.clone()) .unwrap_or_else(stream_counter_factory); - connection_entry.push(ConnectionEntry::new( + let entry = ConnectionEntry::new( cancel.clone(), peer_type, last_update.clone(), @@ -1042,8 +1070,10 @@ impl ConnectionTable { client_connection_tracker, connection, stream_counter.clone(), - )); - self.total_size += 1; + ); + let entry_stake = entry.stake(); + connection_entry.push(entry); + self.register_connection(entry_stake); Some((last_update, cancel, stream_counter)) } else { if let Some(connection) = connection { @@ -1056,40 +1086,52 @@ impl ConnectionTable { } } - // Returns number of connections that were removed + /// Return an iterator over connection table entries. + pub(crate) fn iter( + &self, + ) -> indexmap::map::Iter<'_, ConnectionTableKey, Vec>> { + self.table.iter() + } + + /// Returns number of connections that were removed pub(crate) fn remove_connection( &mut self, key: ConnectionTableKey, port: u16, stable_id: usize, ) -> usize { + let mut removed = vec![]; if let Entry::Occupied(mut e) = self.table.entry(key) { let e_ref = e.get_mut(); - let old_size = e_ref.len(); - - e_ref.retain(|connection_entry| { + let mut index = 0; + while index < e_ref.len() { // Retain the connection entry if the port is different, or if the connection's // stable_id doesn't match the provided stable_id. // (Some unit tests do not fill in a valid connection in the table. To support that, // if the connection is none, the stable_id check is ignored. i.e. if the port matches, // the connection gets removed) - connection_entry.port != port + let connection_entry = &e_ref[index]; + if connection_entry.port != port || connection_entry .connection .as_ref() .and_then(|connection| (connection.stable_id() != stable_id).then_some(0)) .is_some() - }); - let new_size = e_ref.len(); + { + index += 1; + } else { + removed.push(e_ref.swap_remove(index)); + } + } + if e_ref.is_empty() { e.swap_remove_entry(); } - let connections_removed = old_size.saturating_sub(new_size); - self.total_size = self.total_size.saturating_sub(connections_removed); - connections_removed - } else { - 0 } + removed + .drain(..) + .map(|entry| self.deregister_connection(entry)) + .count() } } @@ -1732,8 +1774,7 @@ pub mod test { } let new_size = 3; - let pruned = table.prune_oldest(new_size); - assert_eq!(pruned, num_entries as usize - new_size); + table.prune_oldest(new_size); assert_eq!(table.table.len(), new_size); assert_eq!(table.total_size, new_size); for pubkey in pubkeys.iter().take(num_entries as usize).skip(new_size - 1) { @@ -1802,8 +1843,7 @@ pub mod test { assert_eq!(table.total_size, num_entries); let new_max_size = 3; - let pruned = table.prune_oldest(new_max_size); - assert!(pruned >= num_entries - new_max_size); + table.prune_oldest(new_max_size); assert!(table.table.len() <= new_max_size); assert!(table.total_size <= new_max_size); @@ -1928,7 +1968,7 @@ pub mod test { } #[tokio::test(flavor = "multi_thread")] - async fn test_throttling_check_no_packet_drop() { + async fn test_no_throttling_check() { agave_logger::setup_with_default_filter(); let SpawnTestServerResult { @@ -1945,7 +1985,7 @@ pub mod test { let client_connection = make_client_endpoint(&server_address, None).await; - // unstaked connection can handle up to 100tps, so we should send in ~1s. 
+ // unstaked connection should not get throttled since there is no other load let expected_num_txs = 100; let start_time = tokio::time::Instant::now(); for i in 0..expected_num_txs { @@ -1976,7 +2016,7 @@ pub mod test { stats.total_new_streams.load(Ordering::Relaxed), expected_num_txs ); - assert!(stats.throttled_unstaked_streams.load(Ordering::Relaxed) > 0); + assert!(stats.throttled_ms_unstaked.load(Ordering::Relaxed) == 0); } #[test] diff --git a/streamer/src/nonblocking/simple_qos.rs b/streamer/src/nonblocking/simple_qos.rs index 352f4d93ade434..153421a66cc689 100644 --- a/streamer/src/nonblocking/simple_qos.rs +++ b/streamer/src/nonblocking/simple_qos.rs @@ -9,8 +9,7 @@ use { }, }, quic::{ - StreamerStats, DEFAULT_MAX_QUIC_CONNECTIONS_PER_STAKED_PEER, - DEFAULT_MAX_STAKED_CONNECTIONS, DEFAULT_MAX_STREAMS_PER_MS, + StreamerStats, DEFAULT_MAX_QUIC_CONNECTIONS_PER_STAKED_PEER, DEFAULT_MAX_STREAMS_PER_MS, }, streamer::StakedNodes, }, @@ -43,11 +42,13 @@ pub struct SimpleQosConfig { pub max_connections_per_peer: usize, } +const DEFAULT_MAX_VOTING_CONNECTIONS: usize = 8000; + impl Default for SimpleQosConfig { fn default() -> Self { SimpleQosConfig { max_streams_per_second: DEFAULT_MAX_STREAMS_PER_MS * 1000, - max_staked_connections: DEFAULT_MAX_STAKED_CONNECTIONS, + max_staked_connections: DEFAULT_MAX_VOTING_CONNECTIONS, max_connections_per_peer: DEFAULT_MAX_QUIC_CONNECTIONS_PER_STAKED_PEER, } } @@ -186,7 +187,7 @@ impl QosController for SimpleQos { ConnectionPeerType::Staked(stake) => { let mut connection_table_l = self.staked_connection_table.lock().await; - if connection_table_l.total_size >= self.config.max_staked_connections { + if connection_table_l.table_size() >= self.config.max_staked_connections { let num_pruned = connection_table_l.prune_random(PRUNE_RANDOM_SAMPLE_SIZE, stake); @@ -202,7 +203,7 @@ impl QosController for SimpleQos { update_open_connections_stat(&self.stats, &connection_table_l); } - if connection_table_l.total_size < self.config.max_staked_connections { + if connection_table_l.table_size() < self.config.max_staked_connections { if let Ok((last_update, cancel_connection, stream_counter)) = self .cache_new_connection( client_connection_tracker, @@ -274,23 +275,26 @@ impl QosController for SimpleQos { while stream_counter.consume_tokens(1).is_err() { debug!("Throttling stream from {remote_addr:?}"); - self.stats.throttled_streams.fetch_add(1, Ordering::Relaxed); + let min_sleep = stream_counter.us_to_have_tokens(1).expect( + "Valid QoS configurations guarantee enough token bucket fits at least one \ + token", + ); + self.stats + .throttled_time_ms + .fetch_add(min_sleep / 1000, Ordering::Relaxed); match peer_type { ConnectionPeerType::Unstaked => { self.stats - .throttled_unstaked_streams - .fetch_add(1, Ordering::Relaxed); + .throttled_ms_unstaked + .fetch_add(min_sleep / 1000, Ordering::Relaxed); } ConnectionPeerType::Staked(_) => { self.stats - .throttled_staked_streams - .fetch_add(1, Ordering::Relaxed); + .throttled_ms_staked + .fetch_add(min_sleep / 1000, Ordering::Relaxed); } } - let min_sleep = stream_counter.us_to_have_tokens(1).expect( - "Valid QoS configurations guarantee enough token bucket fits at least one \ - token", - ); + sleep(Duration::from_micros(min_sleep)).await; } } diff --git a/streamer/src/nonblocking/stream_throttle.rs b/streamer/src/nonblocking/stream_throttle.rs index 3fd4917d7c4dc8..351a0ca4a990ed 100644 --- a/streamer/src/nonblocking/stream_throttle.rs +++ b/streamer/src/nonblocking/stream_throttle.rs @@ -1,418 +1,693 @@ +//! 
This module implements the stake-weighted read throttling logic:
+//! * Each connected client is assigned a fraction of the total TPS budget R
+//!   in proportion to its stake.
+//! * All unused capacity gets added to R for the next round of allocation
+//! * Unstaked nodes get granted fake stake as if they were staked
+//!
+//! To better utilize the bandwidth, each client's effective usage rate is
+//! estimated from how much it consumed in the last round, and refill
+//! proportions are based on that usage rate, not on true stake alone.
+//! This allows the system to efficiently reassign underutilized bandwidth to
+//! other users, as their effective stake in the overall allocation grows.
+
+#![deny(clippy::arithmetic_side_effects)]
 use {
     crate::{
-        nonblocking::{qos::OpaqueStreamerCounter, quic::ConnectionPeerType},
+        nonblocking::{qos::OpaqueStreamerCounter, quic::ConnectionTable},
         quic::StreamerStats,
     },
-    percentage::Percentage,
+    solana_native_token::LAMPORTS_PER_SOL,
+    solana_pubkey::Pubkey,
+    static_assertions::const_assert,
     std::{
+        cmp,
         sync::{
-            atomic::{AtomicBool, AtomicU64, Ordering},
-            Arc, RwLock,
+            atomic::{AtomicU64, Ordering},
+            Arc,
         },
         time::{Duration, Instant},
     },
-    tokio::time::sleep,
+    tokio::sync::{Mutex, Notify},
+    tokio_util::sync::CancellationToken,
 };

-/// Max TPS allowed for unstaked connection
-const MAX_UNSTAKED_TPS: u64 = 200;
-/// Expected % of max TPS to be consumed by unstaked connections
-const EXPECTED_UNSTAKED_STREAMS_PERCENT: u64 = 20;
-
-pub const STREAM_THROTTLING_INTERVAL_MS: u64 = 100;
-pub const STREAM_THROTTLING_INTERVAL: Duration =
-    Duration::from_millis(STREAM_THROTTLING_INTERVAL_MS);
-const STREAM_LOAD_EMA_INTERVAL_MS: u64 = 5;
-// EMA smoothing window to reduce sensitivity to short-lived load spikes at the start
-// of a leader slot. Throttling is only triggered when saturation is sustained.
-// The value 40 was chosen based on simulations: at a max target TPS of ~400K,
-// it allows the system to absorb a burst of ~50K transactions over ~40 ms
-// before throttling activates.
-const STREAM_LOAD_EMA_INTERVAL_COUNT: u64 = 40;
-
-const STAKED_THROTTLING_ON_LOAD_THRESHOLD_PERCENT: u64 = 95;
-
-pub(crate) struct StakedStreamLoadEMA {
-    current_load_ema: AtomicU64,
-    load_in_recent_interval: AtomicU64,
-    last_update: RwLock<Instant>,
-    stats: Arc<StreamerStats>,
-    max_staked_load_in_throttling_window: u64,
-    max_unstaked_load_in_throttling_window: u64,
-    max_streams_per_ms: u64,
-    staked_throttling_on_load_threshold: u64, // in streams/STREAM_LOAD_EMA_INTERVAL_MS
-    staked_throttling_enabled: AtomicBool,
+/// This will be added to the true stake amount in the
+/// calculations to ensure that unstaked nodes have non-zero throughput
+pub const BASE_STAKE_SOL: u64 = 1000;
+
+/// This defines the minimum amount of effective stake a connection
+/// can hold when it is completely inactive.
+const FULLY_DECAYED_STAKE_SOL: u64 = 1;
+
+// Fully decayed stake should be less than base stake
+// to make sure unstaked connections can be properly drained of their
+// bandwidth when they are not active.
+const_assert!(FULLY_DECAYED_STAKE_SOL < BASE_STAKE_SOL);
+
+/// Interval of refills for the QoS token buckets
+pub const REFILL_INTERVAL: Duration = Duration::from_millis(20);
+
+/// How many [`REFILL_INTERVAL`]s' worth of token refill rate we accumulate
+/// for idle connections (to handle bursts of arrivals).
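+///
+/// A minimal sketch of the capacity rule (illustrative; it mirrors the clamp
+/// applied in `do_refill` below):
+/// ```ignore
+/// let capacity = (my_token_share * MAX_BURST).clamp(token_capacity_min, MAX_BUCKET_SIZE);
+/// ```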
+/// +/// For example, given `REFILL_INTERVAL = 100ms` and `MAX_BURST = 3` we can +/// * sustain 4x rate for 100ms or +/// * sustain 2x rate for 200ms +/// +/// `MAX_BURST = 1` disables the feature +const MAX_BURST: u64 = 2; + +/// Amount of tokens to gift to new connections +/// This allows fresh connections to immediately start sending +const INITIAL_TOKENS: u16 = 512; + +/// Minimal size of the token bucket +const MIN_BUCKET_SIZE: u64 = 2; +const MAX_BUCKET_SIZE: u64 = 2000; +const_assert!(MAX_BUCKET_SIZE < u16::MAX as u64); + +/// An abstraction to pack all of the token bucket state into one +/// AtomicU64 variable. This makes the updates easier to reason about. +#[derive(Debug, Default)] +struct BucketState { + tokens: u16, + consumed: u16, + last_refill: u16, } -impl StakedStreamLoadEMA { - pub(crate) fn new( - stats: Arc, - max_unstaked_connections: usize, - max_streams_per_ms: u64, - ) -> Self { - let allow_unstaked_streams = max_unstaked_connections > 0; - let max_staked_load_in_ms = if allow_unstaked_streams { - max_streams_per_ms - - Percentage::from(EXPECTED_UNSTAKED_STREAMS_PERCENT).apply_to(max_streams_per_ms) - } else { - max_streams_per_ms - }; - - let max_staked_load_in_ema_interval = max_staked_load_in_ms * STREAM_LOAD_EMA_INTERVAL_MS; - let max_staked_load_in_throttling_window = - max_staked_load_in_ms * STREAM_THROTTLING_INTERVAL_MS; - - let max_unstaked_load_in_throttling_window = if allow_unstaked_streams { - MAX_UNSTAKED_TPS * STREAM_THROTTLING_INTERVAL_MS / 1000 - } else { - 0 - }; - - let staked_throttling_on_load_threshold = - Percentage::from(STAKED_THROTTLING_ON_LOAD_THRESHOLD_PERCENT) - .apply_to(max_staked_load_in_ema_interval); +#[derive(Debug)] +pub struct StreamRateLimiter { + pub true_stake_sol: u64, + pub effective_stake_sol: AtomicU64, + pub address: Pubkey, + pub bucket_state: AtomicU64, // actually stores BucketState + wake_notify: Notify, +} +impl StreamRateLimiter { + pub fn new(address: Pubkey, stake_lamports: u64) -> Self { + let stake_sol = (stake_lamports / LAMPORTS_PER_SOL).max(BASE_STAKE_SOL); Self { - current_load_ema: AtomicU64::default(), - load_in_recent_interval: AtomicU64::default(), - last_update: RwLock::new(Instant::now()), - stats, - max_staked_load_in_throttling_window, - max_unstaked_load_in_throttling_window, - max_streams_per_ms, - staked_throttling_on_load_threshold, - staked_throttling_enabled: AtomicBool::new(false), + true_stake_sol: stake_sol, + effective_stake_sol: AtomicU64::new(stake_sol), + bucket_state: AtomicU64::new( + BucketState { + tokens: INITIAL_TOKENS, + consumed: 0, + last_refill: INITIAL_TOKENS, + } + .into(), + ), + address, + wake_notify: Notify::new(), } } - fn ema_function(current_ema: u128, recent_load: u128) -> u128 { - // Using the EMA multiplier helps in avoiding the floating point math during EMA related calculations - const STREAM_LOAD_EMA_MULTIPLIER: u128 = 1024; - let multiplied_smoothing_factor: u128 = - 2 * STREAM_LOAD_EMA_MULTIPLIER / (u128::from(STREAM_LOAD_EMA_INTERVAL_COUNT) + 1); - - // The formula is - // updated_ema = recent_load * smoothing_factor + current_ema * (1 - smoothing_factor) - // To avoid floating point math, we are using STREAM_LOAD_EMA_MULTIPLIER - // updated_ema = (recent_load * multiplied_smoothing_factor - // + current_ema * (multiplier - multiplied_smoothing_factor)) / multiplier - (recent_load * multiplied_smoothing_factor - + current_ema * (STREAM_LOAD_EMA_MULTIPLIER - multiplied_smoothing_factor)) - / STREAM_LOAD_EMA_MULTIPLIER + pub fn new_unstaked() -> Self { + 
Self::new(Pubkey::new_unique(), 0)
+    }

-    fn update_ema(&self, time_since_last_update_ms: u128) {
-        // if time_since_last_update_ms > STREAM_LOAD_EMA_INTERVAL_MS, there might be intervals where ema was not updated.
-        // count how many updates (1 + missed intervals) are needed.
-        let num_extra_updates =
-            time_since_last_update_ms.saturating_sub(1) / u128::from(STREAM_LOAD_EMA_INTERVAL_MS);
-
-        let load_in_recent_interval =
-            u128::from(self.load_in_recent_interval.swap(0, Ordering::Relaxed));
-
-        let mut updated_load_ema = Self::ema_function(
-            u128::from(self.current_load_ema.load(Ordering::Relaxed)),
-            load_in_recent_interval,
-        );
-
-        for _ in 0..num_extra_updates {
-            updated_load_ema = Self::ema_function(updated_load_ema, 0);
-            if updated_load_ema == 0 {
-                break;
-            }
-        }
+    #[cfg(test)]
+    fn tokens(&self) -> u64 {
+        BucketState::from(self.bucket_state.load(Ordering::Relaxed)).tokens as u64
+    }

-        let Ok(updated_load_ema) = u64::try_from(updated_load_ema) else {
-            error!("Failed to convert EMA {updated_load_ema} to a u64. Not updating the load EMA");
-            self.stats
-                .stream_load_ema_overflow
-                .fetch_add(1, Ordering::Relaxed);
-            return;
-        };
-
-        if self.staked_throttling_on_load_threshold > 0 {
-            self.staked_throttling_enabled.store(
-                updated_load_ema >= self.staked_throttling_on_load_threshold,
-                Ordering::Relaxed,
+    /// Try to consume a token from the bucket; if none is available,
+    /// this blocks until the next refill.
+    pub async fn wait_for_token(&self, stats: &StreamerStats) {
+        while self
+            .bucket_state
+            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |v| {
+                let BucketState {
+                    tokens,
+                    consumed,
+                    last_refill,
+                } = BucketState::from(v);
+                if tokens == 0 {
+                    None
+                } else {
+                    Some(
+                        BucketState {
+                            tokens: tokens.saturating_sub(1),
+                            consumed: consumed.saturating_add(1),
+                            last_refill,
+                        }
+                        .into(),
+                    )
+                }
+            })
+            .is_err()
+        {
+            let t0 = Instant::now();
+            self.wake_notify.notified().await;
+            let throttled_ms = t0.elapsed().as_millis() as u64;
+            trace!(
+                "Throttled connection from {} for {} ms",
+                self.address,
+                throttled_ms
             );
+            if self.true_stake_sol <= BASE_STAKE_SOL {
+                stats
+                    .throttled_ms_unstaked
+                    .fetch_add(throttled_ms, Ordering::Relaxed);
+            } else {
+                stats
+                    .throttled_ms_staked
+                    .fetch_add(throttled_ms, Ordering::Relaxed);
+            }
+            stats
+                .throttled_time_ms
+                .fetch_add(throttled_ms, Ordering::Relaxed);
         }
-
-        self.current_load_ema
-            .store(updated_load_ema, Ordering::Relaxed);
-        self.stats
-            .stream_load_ema
-            .store(updated_load_ema as usize, Ordering::Relaxed);
     }

-    pub(crate) fn update_ema_if_needed(&self) {
-        const EMA_DURATION: Duration = Duration::from_millis(STREAM_LOAD_EMA_INTERVAL_MS);
-        // Read lock enables multiple connection handlers to run in parallel if interval is not expired
-        if Instant::now().duration_since(*self.last_update.read().unwrap()) >= EMA_DURATION {
-            let mut last_update_w = self.last_update.write().unwrap();
-            // Recheck as some other thread might have updated the ema since this thread tried to acquire the write lock.
-            let since_last_update = Instant::now().duration_since(*last_update_w);
-            if since_last_update >= EMA_DURATION {
-                *last_update_w = Instant::now();
-                self.update_ema(since_last_update.as_millis());
+    /// A helper to drain the bucket as if it was actually used via repeated calls
+    /// to `wait_for_token`. This is not safe to call in a multithreaded context.
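+    /// Returns the number of tokens that were drained.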
+    #[cfg(test)]
+    fn drain(&self) -> u64 {
+        let BucketState {
+            tokens,
+            consumed: _,
+            last_refill,
+        } = BucketState::from(self.bucket_state.load(Ordering::SeqCst));
+
+        self.bucket_state.store(
+            BucketState {
+                tokens: 0,
+                consumed: tokens,
+                last_refill,
             }
-        }
+            .into(),
+            Ordering::SeqCst,
+        );
+        tokens as u64
     }

-    pub(crate) fn increment_load(&self, peer_type: ConnectionPeerType) {
-        if peer_type.is_staked() {
-            self.load_in_recent_interval.fetch_add(1, Ordering::Relaxed);
-        }
-        self.update_ema_if_needed();
+    /// Refill the token bucket. Updates the effective stake internally
+    /// and returns the new value.
+    pub fn refill(&self, refill_amount: u16, my_max_tokens: u16) -> u64 {
+        debug_assert!(
+            refill_amount > 0,
+            "Refill should never be zero to ensure stake decay works"
+        );
+        let mut consumed_snapshot = 0;
+        let mut last_refill_snapshot = 0;
+        self.bucket_state
+            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |v| {
+                let BucketState {
+                    tokens,
+                    consumed,
+                    last_refill,
+                } = BucketState::from(v);
+
+                consumed_snapshot = consumed;
+                last_refill_snapshot = last_refill;
+
+                let new_tokens = (tokens.saturating_add(refill_amount)).min(my_max_tokens);
+
+                Some(
+                    BucketState {
+                        tokens: new_tokens,
+                        consumed: 0,
+                        last_refill: refill_amount,
+                    }
+                    .into(),
+                )
+            })
+            .unwrap();
+
+        self.wake_notify.notify_waiters();
+        // compute effective stake based on the utilization of last refill
+        let effective_stake = self
+            .true_stake_sol
+            .saturating_mul(consumed_snapshot as u64)
+            // last_refill would be zero immediately after creation of the connection
+            .checked_div(last_refill_snapshot as u64)
+            .unwrap_or(self.true_stake_sol)
+            .clamp(FULLY_DECAYED_STAKE_SOL, self.true_stake_sol);
+
+        self.effective_stake_sol
+            .store(effective_stake, Ordering::Relaxed);
+        effective_stake
     }

-    pub(crate) fn available_load_capacity_in_throttling_duration(
-        &self,
-        peer_type: ConnectionPeerType,
-        total_stake: u64,
-    ) -> u64 {
-        match peer_type {
-            ConnectionPeerType::Unstaked => self.max_unstaked_load_in_throttling_window,
-            ConnectionPeerType::Staked(stake) => {
-                if self.staked_throttling_enabled.load(Ordering::Relaxed) {
-                    // 1 is added to `max_unstaked_load_in_throttling_window` to guarantee that staked
-                    // clients get at least 1 more number of streams than unstaked connections.
-                    self.max_staked_load_in_throttling_window
-                        .saturating_mul(stake)
-                        .checked_div(total_stake)
-                        .unwrap_or(self.max_unstaked_load_in_throttling_window + 1)
-                        .max(self.max_unstaked_load_in_throttling_window + 1)
-                } else {
-                    self.max_staked_load_in_throttling_window
-                }
-            }
-        }
+    #[inline]
+    fn effective_stake(&self) -> u64 {
+        self.effective_stake_sol.load(Ordering::Relaxed)
     }
+}

-    pub(crate) fn max_streams_per_ms(&self) -> u64 {
-        self.max_streams_per_ms
+const fn token_fill_per_interval(max_tps: u64) -> u64 {
+    max_tps
+        .saturating_mul(REFILL_INTERVAL.as_millis() as u64)
+        .saturating_div(1000)
+}
+
+pub async fn refill_task(
+    staked_connection_table: Arc<Mutex<ConnectionTable<StreamRateLimiter>>>,
+    unstaked_connection_table: Arc<Mutex<ConnectionTable<StreamRateLimiter>>>,
+    max_tps: u64,
+    cancel: CancellationToken,
+) {
+    debug!("Spawning refill task with {max_tps} TPS");
+
+    let mut interval = tokio::time::interval(REFILL_INTERVAL);
+    interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Burst);
+    let mut refiller = Refiller::new(staked_connection_table, unstaked_connection_table).await;
+    while !cancel.is_cancelled() {
+        interval.tick().await; // first tick completes instantly
+        refiller.do_refill(max_tps).await;
    }
}

-#[derive(Debug)]
-pub struct ConnectionStreamCounter {
-    pub(crate) stream_count: AtomicU64,
-    last_throttling_instant: RwLock<tokio::time::Instant>,
+pub struct Refiller {
+    last_iter_total_stake: u64,
+    last_iter_effective_stake: u64,
+    staked_connection_table: Arc<Mutex<ConnectionTable<StreamRateLimiter>>>,
+    unstaked_connection_table: Arc<Mutex<ConnectionTable<StreamRateLimiter>>>,
 }

-impl OpaqueStreamerCounter for ConnectionStreamCounter {}
+impl Refiller {
+    pub async fn new(
+        staked_connection_table: Arc<Mutex<ConnectionTable<StreamRateLimiter>>>,
+        unstaked_connection_table: Arc<Mutex<ConnectionTable<StreamRateLimiter>>>,
+    ) -> Self {
+        let last_iter_total_stake = {
+            // initialize effective stake of unstaked connections based on their number
+            let guard = unstaked_connection_table.lock().await;
+            (guard.table_size() as u64).saturating_mul(BASE_STAKE_SOL)
+        }
+        .saturating_add({
+            // and for staked use actual stake
+            let guard = staked_connection_table.lock().await;
+            guard.connected_stake() / LAMPORTS_PER_SOL
+        });

-impl ConnectionStreamCounter {
-    pub fn new() -> Self {
         Self {
-            stream_count: AtomicU64::default(),
-            last_throttling_instant: RwLock::new(tokio::time::Instant::now()),
+            last_iter_total_stake,
+            last_iter_effective_stake: last_iter_total_stake,
+            staked_connection_table,
+            unstaked_connection_table,
         }
     }

-    /// Reset the counter and last throttling instant and
-    /// return last_throttling_instant regardless it is reset or not.
-    pub(crate) fn reset_throttling_params_if_needed(&self) -> tokio::time::Instant {
-        let last_throttling_instant = *self.last_throttling_instant.read().unwrap();
-        if tokio::time::Instant::now().duration_since(last_throttling_instant)
-            > STREAM_THROTTLING_INTERVAL
-        {
-            let mut last_throttling_instant = self.last_throttling_instant.write().unwrap();
-            // Recheck as some other thread might have done throttling since this thread tried to acquire the write lock.
-            if tokio::time::Instant::now().duration_since(*last_throttling_instant)
-                > STREAM_THROTTLING_INTERVAL
-            {
-                *last_throttling_instant = tokio::time::Instant::now();
-                self.stream_count.store(0, Ordering::Relaxed);
+    /// Refill the rate limiters in the connection tables
+    #[allow(clippy::arithmetic_side_effects)]
+    pub async fn do_refill(&mut self, max_tps: u64) -> u64 {
+        // counts allocated tokens for debugging
+        let mut allocated_tokens = 0;
+        // retrieve stats from last iteration, make sure we do not get zero in
+        // the total counters to avoid division by zero.
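+        // The totals are zeroed here and re-accumulated over this pass, so
+        // each refill distributes tokens against the stakes observed during
+        // the previous pass.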
+ let total_effective_stake = self.last_iter_effective_stake.max(BASE_STAKE_SOL); + self.last_iter_effective_stake = 0; + let total_stake = self.last_iter_total_stake.max(BASE_STAKE_SOL); + self.last_iter_total_stake = 0; + let token_fill_per_interval = token_fill_per_interval(max_tps); + + for conn_table in [ + &self.staked_connection_table, + &self.unstaked_connection_table, + ] { + let guard = conn_table.lock().await; + for (_key, connection_entry_vec) in guard.iter() { + let Some(connection_entry) = connection_entry_vec.first() else { + continue; + }; + let entry = connection_entry.stream_counter.as_ref(); + let entry_effective_stake = entry.effective_stake(); + + // minimal amount this bucket should be able to hold (proportional to stake) + // this allows high-staked connections to ramp up faster + let token_capacity_min = + (token_fill_per_interval * MAX_BURST * entry.true_stake_sol / total_stake) + .clamp(MIN_BUCKET_SIZE, MAX_BUCKET_SIZE); + // share of total amount to deposit in this token bucket + let my_token_share = + token_fill_per_interval * entry_effective_stake / total_effective_stake; + // make sure we always deposit some tokens + let my_token_share = my_token_share.clamp(1, u16::MAX as u64); + // maximal amount this bucket should be able to hold (proportional to effective stake) + let tokens_capacity_max = + (my_token_share * MAX_BURST).clamp(token_capacity_min, MAX_BUCKET_SIZE); + trace!( + "Grant {my_token_share} (max {tokens_capacity_max}) TXs to {} based on \ + {}/{total_effective_stake} sol of stake.", + entry.address, + entry.effective_stake() + ); + allocated_tokens += my_token_share; + // fill the bucket with all available tokens + // record effective stake of the entry after refill to keep track + // of state as connections come and go + self.last_iter_effective_stake += + entry.refill(my_token_share as u16, tokens_capacity_max as u16); + + self.last_iter_total_stake += entry.true_stake_sol; } - *last_throttling_instant - } else { - last_throttling_instant } + trace!( + "Allocated {allocated_tokens} tokens out of {token_fill_per_interval} to users. \ + total_effective_stake={total_effective_stake}, total_stake={total_stake}" + ); + allocated_tokens } } -pub(crate) async fn throttle_stream( - stats: &StreamerStats, - peer_type: ConnectionPeerType, - remote_addr: std::net::SocketAddr, - stream_counter: &Arc, - max_streams_per_throttling_interval: u64, -) { - let throttle_interval_start = stream_counter.reset_throttling_params_if_needed(); - let streams_read_in_throttle_interval = stream_counter.stream_count.load(Ordering::Relaxed); - if streams_read_in_throttle_interval >= max_streams_per_throttling_interval { - // The peer is sending faster than we're willing to read. Sleep for what's - // left of this read interval so the peer backs off. 
-        let throttle_duration =
-            STREAM_THROTTLING_INTERVAL.saturating_sub(throttle_interval_start.elapsed());
-
-        if !throttle_duration.is_zero() {
-            debug!(
-                "Throttling stream from {remote_addr:?}, peer type: {peer_type:?}, \
-                 max_streams_per_interval: {max_streams_per_throttling_interval}, \
-                 read_interval_streams: {streams_read_in_throttle_interval} throttle_duration: \
-                 {throttle_duration:?}"
-            );
-            stats.throttled_streams.fetch_add(1, Ordering::Relaxed);
-            match peer_type {
-                ConnectionPeerType::Unstaked => {
-                    stats
-                        .throttled_unstaked_streams
-                        .fetch_add(1, Ordering::Relaxed);
-                }
-                ConnectionPeerType::Staked(_) => {
-                    stats
-                        .throttled_staked_streams
-                        .fetch_add(1, Ordering::Relaxed);
-                }
-            }
-            sleep(throttle_duration).await;
-        }
-    }
-}

+impl OpaqueStreamerCounter for StreamRateLimiter {}
+
+impl Ord for StreamRateLimiter {
+    fn cmp(&self, other: &Self) -> cmp::Ordering {
+        // high stake comes first
+        other.true_stake_sol.cmp(&self.true_stake_sol)
+    }
+}
+impl PartialOrd for StreamRateLimiter {
+    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
+        Some(self.cmp(other))
+    }
+}
+impl PartialEq for StreamRateLimiter {
+    fn eq(&self, other: &Self) -> bool {
+        self.true_stake_sol == other.true_stake_sol
+    }
+}
+impl Eq for StreamRateLimiter {}
+
+// These conversion helpers optimize to
+// the same assembly as if we used transmute, minus
+// the unsafe block.
+impl From<u64> for BucketState {
+    #[inline]
+    fn from(v: u64) -> Self {
+        Self {
+            tokens: v as u16,
+            consumed: (v >> 16) as u16,
+            last_refill: (v >> 32) as u16,
+        }
+    }
+}
+impl From<BucketState> for u64 {
+    #[inline]
+    fn from(val: BucketState) -> Self {
+        (val.tokens as u64) | ((val.consumed as u64) << 16) | ((val.last_refill as u64) << 32)
+    }
+}

 #[cfg(test)]
 pub mod test {
+    #![allow(clippy::arithmetic_side_effects)]
     use {
-        super::*,
-        crate::quic::{
-            StreamerStats, DEFAULT_MAX_STREAMS_PER_MS, DEFAULT_MAX_UNSTAKED_CONNECTIONS,
+        crate::{
+            nonblocking::{
+                stream_throttle::{
+                    token_fill_per_interval, Refiller, StreamRateLimiter, BASE_STAKE_SOL,
+                    FULLY_DECAYED_STAKE_SOL, INITIAL_TOKENS, REFILL_INTERVAL,
+                },
+                testing_utilities::fill_connection_table,
+            },
+            quic::StreamerStats,
         },
-        std::sync::{atomic::Ordering, Arc},
+        solana_native_token::LAMPORTS_PER_SOL,
+        solana_pubkey::Pubkey,
+        std::{
+            future::Future,
+            net::{IpAddr, Ipv4Addr, SocketAddr},
+            sync::Arc,
+            time::{Duration, Instant},
+        },
+        tokio::time::timeout,
+        tokio_util::sync::CancellationToken,
     };

-    #[test]
-    fn test_max_streams_for_unstaked_connection() {
-        let load_ema = Arc::new(StakedStreamLoadEMA::new(
-            Arc::new(StreamerStats::default()),
-            DEFAULT_MAX_UNSTAKED_CONNECTIONS,
-            DEFAULT_MAX_STREAMS_PER_MS,
-        ));
+    #[tokio::test]
+    async fn test_stream_rate_limiter() {
+        let max_tokens = 100;
+        let entry = StreamRateLimiter::new_unstaked();
         assert_eq!(
-            load_ema.available_load_capacity_in_throttling_duration(
-                ConnectionPeerType::Unstaked,
-                10000,
-            ),
-            20
+            entry.effective_stake(),
+            BASE_STAKE_SOL,
+            "Unstaked nodes should be given BASE_STAKE_SOL worth of stake"
+        );
+        entry.drain();
+        assert_eq!(
+            entry.refill(max_tokens, max_tokens),
+            BASE_STAKE_SOL,
+            "Should have full stake applied"
+        );
+        assert_eq!(
+            entry.refill(max_tokens, max_tokens),
+            FULLY_DECAYED_STAKE_SOL,
+            "Should have lost all possible stake due to no usage"
         );
-    }

-    #[test]
-    fn test_staked_throttling_on_off() {
-        let mut load_ema = StakedStreamLoadEMA::new(
-            Arc::new(StreamerStats::default()),
-            DEFAULT_MAX_UNSTAKED_CONNECTIONS,
-            DEFAULT_MAX_STREAMS_PER_MS,
+        let stats = StreamerStats::default();
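+        // drain() marks every remaining token as consumed, so the next
+        // refill() sees 100% utilization and restores the full effective stake.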
+        entry.drain();
+        assert_eq!(
+            entry.refill(max_tokens, max_tokens),
+            BASE_STAKE_SOL,
+            "Should have full stake reapplied"
         );
+        let t0 = Instant::now();
+        for _ in 0..100 {
+            entry.wait_for_token(&stats).await;
+        }
+        assert!(t0.elapsed() < REFILL_INTERVAL, "should not be blocked");
+        let to = tokio::time::timeout(REFILL_INTERVAL, entry.wait_for_token(&stats)).await;
+        assert!(to.is_err(), "Must block waiting on token arrival");
+    }

-        load_ema.staked_throttling_on_load_threshold = 10;
+    const MAX_TPS: u64 = 10000;
+    const NUM_CLIENTS: usize = 10;
+    const TOKEN_FILL_PER_INTERVAL: u64 = token_fill_per_interval(MAX_TPS);

-        load_ema.current_load_ema.store(12, Ordering::Relaxed);
-        load_ema
-            .load_in_recent_interval
-            .store(12, Ordering::Relaxed);
-        load_ema.update_ema(u128::from(STREAM_LOAD_EMA_INTERVAL_MS));
-        assert!(load_ema.staked_throttling_enabled.load(Ordering::Relaxed));
+    /// Test that the Refiller is working correctly:
+    /// * if we use the tokens, they get refilled
+    /// * if we do not use the tokens, effective stake decays
+    /// * once we start consuming again, the effective stake returns
+    #[tokio::test]
+    async fn test_refiller_stake_decay() {
+        agave_logger::setup();
+        let stats = Arc::new(StreamerStats::default());
+        let cancel = CancellationToken::new();

-        load_ema.current_load_ema.store(4, Ordering::Relaxed);
-        load_ema.load_in_recent_interval.store(0, Ordering::Relaxed);
-        load_ema.update_ema(u128::from(STREAM_LOAD_EMA_INTERVAL_MS));
-        assert!(!load_ema.staked_throttling_enabled.load(Ordering::Relaxed));
-    }
+        let staked_connection_table = fill_connection_table(&[], &[], stats.clone());

-    #[test]
-    fn test_staked_capacity_shares_when_throttled() {
-        let mut load_ema = StakedStreamLoadEMA::new(
-            Arc::new(StreamerStats::default()),
-            DEFAULT_MAX_UNSTAKED_CONNECTIONS,
-            DEFAULT_MAX_STREAMS_PER_MS,
-        );
+        let sockets: Vec<_> = (0..NUM_CLIENTS as u32)
+            .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::from_bits(i)), 0))
+            .collect();

-        load_ema
-            .staked_throttling_enabled
-            .store(true, Ordering::Relaxed);
-        load_ema.max_staked_load_in_throttling_window = 100;
-        load_ema.max_unstaked_load_in_throttling_window = 20;
+        let rate_limiters: Vec<_> = sockets
+            .iter()
+            .map(|_| Arc::new(StreamRateLimiter::new_unstaked()))
+            .collect();
+        let unstaked_connection_table =
+            fill_connection_table(&sockets, &rate_limiters, stats.clone());
+
+        let mut refiller = Refiller::new(staked_connection_table, unstaked_connection_table).await;

         assert_eq!(
-            load_ema.available_load_capacity_in_throttling_duration(
-                ConnectionPeerType::Staked(10),
-                100
-            ),
-            load_ema.max_unstaked_load_in_throttling_window + 1
+            rate_limiters[0].tokens(),
+            INITIAL_TOKENS as u64,
+            "Should have INITIAL_TOKENS initial tokens"
         );
+
+        for rl in rate_limiters.iter() {
+            rl.drain();
+        }
+        assert_eq!(rate_limiters[0].tokens(), 0, "Should have no tokens");
+        assert_blocks(
+            rate_limiters[0].wait_for_token(&stats),
+            "wait on empty buckets",
+        )
+        .await;
+
+        info!("wait for token after refill");
+        refiller.do_refill(MAX_TPS).await;
+        rate_limiters[0].wait_for_token(&stats).await;
         assert_eq!(
-            load_ema.available_load_capacity_in_throttling_duration(
-                ConnectionPeerType::Staked(50),
-                100
-            ),
-            50
+            rate_limiters[0].tokens(),
+            (TOKEN_FILL_PER_INTERVAL / NUM_CLIENTS as u64 - 1),
+            "Should have spent 1 token for the transaction we just consumed"
         );
-    }
-
-    #[test]
-    fn test_no_throttle_below_threshold() {
-        let mut load_ema = StakedStreamLoadEMA::new(
-            Arc::new(StreamerStats::default()),
-            DEFAULT_MAX_UNSTAKED_CONNECTIONS,
-
DEFAULT_MAX_STREAMS_PER_MS, + assert_eq!( + rate_limiters[0].effective_stake(), + BASE_STAKE_SOL, + "Should have all stake applied" ); - load_ema - .staked_throttling_enabled - .store(false, Ordering::Relaxed); - load_ema.max_staked_load_in_throttling_window = 100; - load_ema.max_unstaked_load_in_throttling_window = 20; + info!("refill twice without consuming"); + refiller.do_refill(MAX_TPS).await; + refiller.do_refill(MAX_TPS).await; + assert_eq!( + rate_limiters[0].effective_stake(), + FULLY_DECAYED_STAKE_SOL, + "Should have all stake retracted due to no use during last refill" + ); + info!("test under max consumption"); + rate_limiters[0].drain(); + refiller.do_refill(MAX_TPS).await; + rate_limiters[0].drain(); + refiller.do_refill(MAX_TPS).await; + + let effective_stake = (NUM_CLIENTS as u64 - 1) * FULLY_DECAYED_STAKE_SOL + BASE_STAKE_SOL; + let expected_tokens = TOKEN_FILL_PER_INTERVAL * BASE_STAKE_SOL / effective_stake; + assert_eq!( + rate_limiters[0].tokens(), + expected_tokens, + "Bucket should be filled now" + ); + info!("test under no consumption"); + refiller.do_refill(MAX_TPS).await; + refiller.do_refill(MAX_TPS).await; assert_eq!( - load_ema.available_load_capacity_in_throttling_duration( - ConnectionPeerType::Staked(10), - 100 - ), - load_ema.max_staked_load_in_throttling_window + rate_limiters[0].effective_stake(), + FULLY_DECAYED_STAKE_SOL, + "Should have reduced effective stake due to lack of consumption" ); + cancel.cancel(); } - #[test] - fn test_ema_decay_handles_missing_intervals() { - let load_ema = StakedStreamLoadEMA::new( - Arc::new(StreamerStats::default()), - DEFAULT_MAX_UNSTAKED_CONNECTIONS, - DEFAULT_MAX_STREAMS_PER_MS, + /// Test the SWQOS - that token allocations are stake proportional under load + #[tokio::test] + async fn test_refiller_swqos() { + agave_logger::setup(); + let stats = Arc::new(StreamerStats::default()); + + const NUM_UNSTAKED_CLIENTS: usize = NUM_CLIENTS / 2; + let sockets: Vec<_> = (0..NUM_CLIENTS as u32) + .map(|i| SocketAddr::new(IpAddr::V4(Ipv4Addr::from_bits(i)), 0)) + .collect(); + let (unstaked_sockets, staked_sockets) = sockets.split_at(NUM_UNSTAKED_CLIENTS); + let unstaked_rate_limiters: Vec<_> = unstaked_sockets + .iter() + .map(|_| Arc::new(StreamRateLimiter::new_unstaked())) + .collect(); + let staked_rate_limiters: Vec<_> = staked_sockets + .iter() + .enumerate() + .map(|(index, _)| { + Arc::new(StreamRateLimiter::new( + Pubkey::new_unique(), + BASE_STAKE_SOL * LAMPORTS_PER_SOL * (index + 1) as u64, + )) + }) + .collect(); + + let unstaked_connection_table = + fill_connection_table(unstaked_sockets, &unstaked_rate_limiters, stats.clone()); + let staked_connection_table = + fill_connection_table(staked_sockets, &staked_rate_limiters, stats.clone()); + + let all_rate_limiters: Vec<_> = unstaked_rate_limiters + .iter() + .cloned() + .chain(staked_rate_limiters.iter().cloned()) + .collect(); + + let mut refiller = Refiller::new(staked_connection_table, unstaked_connection_table).await; + assert_eq!( + all_rate_limiters[0].tokens(), + INITIAL_TOKENS as u64, + "Should have INITIAL_TOKENS initial tokens" ); - load_ema.current_load_ema.store(100, Ordering::Relaxed); - load_ema - .load_in_recent_interval - .store(100, Ordering::Relaxed); - - load_ema.update_ema(u128::from(STREAM_LOAD_EMA_INTERVAL_MS * 3)); - - let expected = StakedStreamLoadEMA::ema_function( - StakedStreamLoadEMA::ema_function(StakedStreamLoadEMA::ema_function(100, 100), 0), - 0, - ); + for rl in all_rate_limiters.iter() { + rl.drain(); + } + 
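+        // With every bucket empty, the refiller is the only source of tokens,
+        // so the next do_refill() fully determines each client's allocation.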
+        assert_eq!(all_rate_limiters[0].tokens(), 0, "Should have no tokens");
+
+        let fake_stake = NUM_UNSTAKED_CLIENTS as u64 * BASE_STAKE_SOL;
+        let real_stake = refiller
+            .staked_connection_table
+            .lock()
+            .await
+            .connected_stake()
+            / LAMPORTS_PER_SOL;
+        let total_stake = fake_stake + real_stake;
         assert_eq!(
-            load_ema.current_load_ema.load(Ordering::Relaxed),
-            u64::try_from(expected).unwrap()
+            refiller.last_iter_total_stake, total_stake,
+            "total stake observed by refill should match actual"
         );
-    }

-    #[test]
-    fn test_total_stake_zero_safety() {
-        let load_ema = StakedStreamLoadEMA::new(
-            Arc::new(StreamerStats::default()),
-            DEFAULT_MAX_UNSTAKED_CONNECTIONS,
-            DEFAULT_MAX_STREAMS_PER_MS,
-        );
-        load_ema
-            .staked_throttling_enabled
-            .store(true, Ordering::Relaxed);
+        info!("fill everything initially");
+        refiller.do_refill(MAX_TPS).await;
+        let allocations: Vec<_> = all_rate_limiters.iter().map(|rl| rl.drain()).collect();
+        debug!("{allocations:?}");
         assert_eq!(
-            load_ema
-                .available_load_capacity_in_throttling_duration(ConnectionPeerType::Staked(10), 0),
-            load_ema.max_unstaked_load_in_throttling_window + 1
+            allocations.iter().sum::<u64>(),
+            TOKEN_FILL_PER_INTERVAL,
+            "Should have allocated all tokens"
         );
+
+        for alloc in &allocations[..NUM_UNSTAKED_CLIENTS] {
+            assert_eq!(
+                *alloc,
+                TOKEN_FILL_PER_INTERVAL * fake_stake / total_stake / NUM_UNSTAKED_CLIENTS as u64
+            )
+        }
+        let staked_refill = TOKEN_FILL_PER_INTERVAL * real_stake / total_stake;
+        for (alloc, rl) in allocations[NUM_UNSTAKED_CLIENTS..]
+            .iter()
+            .zip(staked_rate_limiters.iter())
+        {
+            assert_eq!(
+                *alloc,
+                staked_refill * rl.true_stake_sol / real_stake as u64
+            )
+        }
+        info!("Staked nodes do not use bandwidth");
+        // let the staked nodes' effective stake decay by draining (i.e. using)
+        // only the unstaked stream limiters
+        refiller.do_refill(MAX_TPS).await;
+        let _allocations: Vec<_> = unstaked_rate_limiters.iter().map(|rl| rl.drain()).collect();
+        refiller.do_refill(MAX_TPS).await;
+        let _allocations: Vec<_> = unstaked_rate_limiters.iter().map(|rl| rl.drain()).collect();
+        refiller.do_refill(MAX_TPS).await;
+        let total_effective_stake = refiller.last_iter_effective_stake;
+        let allocations: Vec<_> = unstaked_rate_limiters.iter().map(|rl| rl.drain()).collect();
+        debug!("{allocations:?}");
+        for alloc in allocations {
+            assert_eq!(
+                alloc,
+                TOKEN_FILL_PER_INTERVAL * fake_stake
+                    / total_effective_stake
+                    / NUM_UNSTAKED_CLIENTS as u64,
+                "Unstaked should be getting more tokens now that staked are idle"
+            )
+        }
+        info!("Unstaked nodes do not use bandwidth");
+
+        // let the unstaked nodes' effective stake decay by draining only the
+        // staked stream limiters
+        refiller.do_refill(MAX_TPS).await;
+        let _allocations: Vec<_> = staked_rate_limiters.iter().map(|rl| rl.drain()).collect();
+        refiller.do_refill(MAX_TPS).await;
+        let _allocations: Vec<_> = staked_rate_limiters.iter().map(|rl| rl.drain()).collect();
+        refiller.do_refill(MAX_TPS).await;
+        let total_effective_stake = refiller.last_iter_effective_stake;
+        let allocations: Vec<_> = staked_rate_limiters.iter().map(|rl| rl.drain()).collect();
+        debug!("{allocations:?}");
+        let staked_refill = TOKEN_FILL_PER_INTERVAL * real_stake / total_effective_stake;
+        for (alloc, rl) in allocations.iter().zip(staked_rate_limiters.iter()) {
+            assert_eq!(
+                *alloc,
+                staked_refill * rl.true_stake_sol / real_stake as u64
+            )
+        }
+    }
+
+    /// Awaits `fut` for 100ms.
+    /// - If the future **completes within 100ms**, this function panics.
+    /// - If the future **does not complete within 100ms**, this function returns normally.
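+    /// The 100ms bound is arbitrary; these tests drive refills manually, so a
+    /// starved `wait_for_token` future will never complete on its own.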
+ pub async fn assert_blocks(fut: F, msg: &'static str) + where + F: Future, + { + match timeout(Duration::from_millis(100), fut).await { + Ok(_) => panic!("{msg} did not block"), + Err(_) => { + // timed out as expected; do nothing + } + } } } diff --git a/streamer/src/nonblocking/swqos.rs b/streamer/src/nonblocking/swqos.rs index cc5003b05f944d..ec6fe094a1a3fe 100644 --- a/streamer/src/nonblocking/swqos.rs +++ b/streamer/src/nonblocking/swqos.rs @@ -9,10 +9,7 @@ use { CONNECTION_CLOSE_CODE_EXCEED_MAX_STREAM_COUNT, CONNECTION_CLOSE_REASON_DISALLOWED, CONNECTION_CLOSE_REASON_EXCEED_MAX_STREAM_COUNT, }, - stream_throttle::{ - throttle_stream, ConnectionStreamCounter, StakedStreamLoadEMA, - STREAM_THROTTLING_INTERVAL_MS, - }, + stream_throttle::{refill_task, StreamRateLimiter, BASE_STAKE_SOL}, }, quic::{ StreamerStats, DEFAULT_MAX_QUIC_CONNECTIONS_PER_STAKED_PEER, @@ -89,11 +86,11 @@ impl SwQosConfig { pub struct SwQos { config: SwQosConfig, - staked_stream_load_ema: Arc, stats: Arc, staked_nodes: Arc>, - unstaked_connection_table: Arc>>, - staked_connection_table: Arc>>, + unstaked_connection_table: Arc>>, + staked_connection_table: Arc>>, + cancel: CancellationToken, } // QoS Params for Stake weighted QoS @@ -104,8 +101,7 @@ pub struct SwQosConnectionContext { total_stake: u64, in_staked_table: bool, last_update: Arc, - remote_address: std::net::SocketAddr, - stream_counter: Option>, + stream_counter: Option>, } impl ConnectionContext for SwQosConnectionContext { @@ -127,13 +123,9 @@ impl SwQos { ) -> Self { Self { config: config.clone(), - staked_stream_load_ema: Arc::new(StakedStreamLoadEMA::new( - stats.clone(), - config.max_unstaked_connections, - config.max_streams_per_ms, - )), stats, staked_nodes, + cancel: cancel.clone(), unstaked_connection_table: Arc::new(Mutex::new(ConnectionTable::new( ConnectionTableType::Unstaked, cancel.clone(), @@ -185,16 +177,10 @@ impl SwQos { &self, client_connection_tracker: ClientConnectionTracker, connection: &Connection, - mut connection_table_l: MutexGuard>, + mut connection_table_l: MutexGuard>, conn_context: &SwQosConnectionContext, - ) -> Result< - ( - Arc, - CancellationToken, - Arc, - ), - ConnectionHandlerError, - > { + ) -> Result<(Arc, CancellationToken, Arc), ConnectionHandlerError> + { // get current RTT and limit it to MAX_RTT_MS let rtt_millis = connection.rtt().as_millis() as u64; if let Ok(max_uni_streams) = VarInt::from_u64(compute_max_allowed_uni_streams_with_rtt( @@ -213,20 +199,36 @@ impl SwQos { remote_addr, ); - let max_connections_per_peer = match conn_context.peer_type() { - ConnectionPeerType::Unstaked => self.config.max_connections_per_unstaked_peer, - ConnectionPeerType::Staked(_) => self.config.max_connections_per_staked_peer, - }; + connection.set_max_concurrent_uni_streams(max_uni_streams); + let key = ConnectionTableKey::new(remote_addr.ip(), conn_context.remote_pubkey); + let (max_connections_per_peer, stream_counter_prototype) = + match conn_context.peer_type() { + ConnectionPeerType::Unstaked => { + (self.config.max_connections_per_unstaked_peer, None) + } + ConnectionPeerType::Staked(peer_stake) => ( + self.config.max_connections_per_staked_peer, + Some(StreamRateLimiter::new( + conn_context.remote_pubkey.unwrap_or_default(), + peer_stake, + )), + ), + }; if let Some((last_update, cancel_connection, stream_counter)) = connection_table_l .try_add_connection( - ConnectionTableKey::new(remote_addr.ip(), conn_context.remote_pubkey), + key, remote_addr.port(), client_connection_tracker, Some(connection.clone()), 
conn_context.peer_type(), conn_context.last_update.clone(), max_connections_per_peer, - || Arc::new(ConnectionStreamCounter::new()), + move || { + Arc::new( + stream_counter_prototype + .unwrap_or_else(StreamRateLimiter::new_unstaked), + ) + }, ) { update_open_connections_stat(&self.stats, &connection_table_l); @@ -255,11 +257,11 @@ impl SwQos { fn prune_unstaked_connection_table( &self, - unstaked_connection_table: &mut ConnectionTable, + unstaked_connection_table: &mut ConnectionTable, max_unstaked_connections: usize, stats: Arc, ) { - if unstaked_connection_table.total_size >= max_unstaked_connections { + if unstaked_connection_table.table_size() >= max_unstaked_connections { const PRUNE_TABLE_TO_PERCENTAGE: u8 = 90; let max_percentage_full = Percentage::from(PRUNE_TABLE_TO_PERCENTAGE); @@ -275,17 +277,11 @@ impl SwQos { &self, client_connection_tracker: ClientConnectionTracker, connection: &Connection, - connection_table: Arc>>, + connection_table: Arc>>, max_connections: usize, conn_context: &SwQosConnectionContext, - ) -> Result< - ( - Arc, - CancellationToken, - Arc, - ), - ConnectionHandlerError, - > { + ) -> Result<(Arc, CancellationToken, Arc), ConnectionHandlerError> + { let stats = self.stats.clone(); if max_connections > 0 { let mut connection_table = connection_table.lock().await; @@ -304,17 +300,18 @@ impl SwQos { Err(ConnectionHandlerError::ConnectionAddError) } } - - fn max_streams_per_throttling_interval(&self, conn_context: &SwQosConnectionContext) -> u64 { - self.staked_stream_load_ema - .available_load_capacity_in_throttling_duration( - conn_context.peer_type, - conn_context.total_stake, - ) - } } impl QosController for SwQos { + async fn async_init(&mut self) { + tokio::spawn(refill_task( + self.staked_connection_table.clone(), + self.unstaked_connection_table.clone(), + self.config.max_streams_per_ms * 1000, + self.cancel.child_token(), + )); + } + fn build_connection_context(&self, connection: &Connection) -> SwQosConnectionContext { get_connection_stake(connection, &self.staked_nodes).map_or( SwQosConnectionContext { @@ -322,7 +319,6 @@ impl QosController for SwQos { total_stake: 0, remote_pubkey: None, in_staked_table: false, - remote_address: connection.remote_address(), stream_counter: None, last_update: Arc::new(AtomicU64::new(timing::timestamp())), }, @@ -331,12 +327,9 @@ impl QosController for SwQos { // interval during which we allow max (MAX_STREAMS_PER_MS * STREAM_THROTTLING_INTERVAL_MS) streams. let peer_type = { - let max_streams_per_ms = self.staked_stream_load_ema.max_streams_per_ms(); - let min_stake_ratio = - 1_f64 / (max_streams_per_ms * STREAM_THROTTLING_INTERVAL_MS) as f64; - let stake_ratio = stake as f64 / total_stake as f64; - if stake_ratio < min_stake_ratio { - // If it is a staked connection with ultra low stake ratio, treat it as unstaked. + // If it is a staked connection with ultra low stake, treat it as unstaked. + // This prevents 1-SOL nodes from polluting the staked nodes table. 
+                if stake < BASE_STAKE_SOL {
                     ConnectionPeerType::Unstaked
                 } else {
                     ConnectionPeerType::Staked(stake)
@@ -348,7 +341,6 @@ impl QosController<SwQosConnectionContext> for SwQos {
                 total_stake,
                 remote_pubkey: Some(pubkey),
                 in_staked_table: false,
-                remote_address: connection.remote_address(),
                 last_update: Arc::new(AtomicU64::new(timing::timestamp())),
                 stream_counter: None,
             }
@@ -370,7 +362,7 @@ impl QosController<SwQosConnectionContext> for SwQos {
             ConnectionPeerType::Staked(stake) => {
                 let mut connection_table_l = self.staked_connection_table.lock().await;
-                if connection_table_l.total_size >= self.config.max_staked_connections {
+                if connection_table_l.table_size() >= self.config.max_staked_connections {
                     let num_pruned =
                         connection_table_l.prune_random(PRUNE_RANDOM_SAMPLE_SIZE, stake);
                     self.stats
                         .num_evictions
                         .fetch_add(num_pruned, Ordering::Relaxed);
                     update_open_connections_stat(&self.stats, &connection_table_l);
                 }
@@ -379,7 +371,7 @@ impl QosController<SwQosConnectionContext> for SwQos {
-                if connection_table_l.total_size < self.config.max_staked_connections {
+                if connection_table_l.table_size() < self.config.max_staked_connections {
                     if let Ok((last_update, cancel_connection, stream_counter)) = self
                         .cache_new_connection(
                             client_connection_tracker,
@@ -457,24 +449,11 @@ impl QosController<SwQosConnectionContext> for SwQos {
         }
     }
 
-    fn on_stream_accepted(&self, conn_context: &SwQosConnectionContext) {
-        self.staked_stream_load_ema
-            .increment_load(conn_context.peer_type);
-        conn_context
-            .stream_counter
-            .as_ref()
-            .unwrap()
-            .stream_count
-            .fetch_add(1, Ordering::Relaxed);
-    }
+    fn on_stream_accepted(&self, _conn_context: &SwQosConnectionContext) {}
 
-    fn on_stream_error(&self, _conn_context: &SwQosConnectionContext) {
-        self.staked_stream_load_ema.update_ema_if_needed();
-    }
+    fn on_stream_error(&self, _conn_context: &SwQosConnectionContext) {}
 
-    fn on_stream_closed(&self, _conn_context: &SwQosConnectionContext) {
-        self.staked_stream_load_ema.update_ema_if_needed();
-    }
+    fn on_stream_closed(&self, _conn_context: &SwQosConnectionContext) {}
 
     #[allow(clippy::manual_async_fn)]
     fn remove_connection(
@@ -511,22 +490,8 @@ impl QosController<SwQosConnectionContext> for SwQos {
     #[allow(clippy::manual_async_fn)]
     fn on_new_stream(&self, context: &SwQosConnectionContext) -> impl Future<Output = ()> + Send {
         async move {
-            let peer_type = context.peer_type();
-            let remote_addr = context.remote_address;
-            let stream_counter: &Arc<ConnectionStreamCounter> =
-                context.stream_counter.as_ref().unwrap();
-
-            let max_streams_per_throttling_interval =
-                self.max_streams_per_throttling_interval(context);
-
-            throttle_stream(
-                &self.stats,
-                peer_type,
-                remote_addr,
-                stream_counter,
-                max_streams_per_throttling_interval,
-            )
-            .await;
+            let stream_counter: &Arc<StreamRateLimiter> = context.stream_counter.as_ref().unwrap();
+            stream_counter.wait_for_token(&self.stats).await;
         }
     }
diff --git a/streamer/src/nonblocking/testing_utilities.rs b/streamer/src/nonblocking/testing_utilities.rs
index b9f5b7fe37fc37..ec124d99cb4735 100644
--- a/streamer/src/nonblocking/testing_utilities.rs
+++ b/streamer/src/nonblocking/testing_utilities.rs
@@ -3,7 +3,11 @@ use {
     super::quic::{SpawnNonBlockingServerResult, ALPN_TPU_PROTOCOL_ID},
     crate::{
         nonblocking::{
-            quic::spawn_server,
+            quic::{
+                spawn_server, ClientConnectionTracker, ConnectionPeerType, ConnectionTable,
+                ConnectionTableKey, ConnectionTableType,
+            },
+            stream_throttle::StreamRateLimiter,
             swqos::{SwQos, SwQosConfig},
         },
         quic::{QuicServerError, QuicStreamerConfig, StreamerStats, QUIC_MAX_TIMEOUT},
@@ -15,6 +19,7 @@ use {
         TokioRuntime, TransportConfig,
     },
     solana_keypair::Keypair,
+    solana_native_token::LAMPORTS_PER_SOL,
     solana_net_utils::sockets::{
         bind_to_localhost_unique,
         localhost_port_range_for_tests, multi_bind_in_range_with_config,
         SocketConfiguration as SocketConfig,
     },
@@ -23,10 +28,10 @@ use {
     solana_tls_utils::{new_dummy_x509_certificate, tls_client_config_builder},
     std::{
         net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket},
-        sync::{Arc, RwLock},
+        sync::{atomic::AtomicU64, Arc, RwLock},
         time::{Duration, Instant},
     },
-    tokio::{task::JoinHandle, time::sleep},
+    tokio::{sync::Mutex, task::JoinHandle, time::sleep},
     tokio_util::sync::CancellationToken,
 };
 
@@ -50,12 +55,7 @@ where
 {
     let stats = Arc::<StreamerStats>::default();
 
-    let swqos = Arc::new(SwQos::new(
-        qos_config,
-        stats.clone(),
-        staked_nodes,
-        cancel.clone(),
-    ));
+    let swqos = SwQos::new(qos_config, stats.clone(), staked_nodes, cancel.clone());
 
     spawn_server(
         name,
@@ -126,6 +126,7 @@ pub fn setup_quic_server(
     let keypair = Keypair::new();
     let server_address = sockets[0].local_addr().unwrap();
     let staked_nodes = Arc::new(RwLock::new(option_staked_nodes.unwrap_or_default()));
+
     let cancel = CancellationToken::new();
 
     let SpawnNonBlockingServerResult {
@@ -218,3 +219,36 @@ pub async fn check_multiple_streams(
     }
     assert_eq!(total_packets, num_expected_packets);
 }
+
+/// Generates a fake connection table filled with provided sockets and rate_limiters.
+/// Useful to test stream throttling and nothing else.
+pub fn fill_connection_table(
+    socket_addrs: &[SocketAddr],
+    rate_limiters: &[Arc<StreamRateLimiter>],
+    stats: Arc<StreamerStats>,
+) -> Arc<Mutex<ConnectionTable<StreamRateLimiter>>> {
+    let max_connections_per_peer = 1;
+    let cancel = CancellationToken::new();
+    let mut connection_table: ConnectionTable<StreamRateLimiter> =
+        ConnectionTable::new(ConnectionTableType::Unstaked, cancel.clone());
+
+    for (socket, limiter) in socket_addrs.iter().zip(rate_limiters.iter().cloned()) {
+        let peer_type = match limiter.true_stake_sol {
+            0 => ConnectionPeerType::Unstaked,
+            x => ConnectionPeerType::Staked(x * LAMPORTS_PER_SOL),
+        };
+        connection_table
+            .try_add_connection(
+                ConnectionTableKey::IP(socket.ip()),
+                socket.port(),
+                ClientConnectionTracker::new(stats.clone(), 100000).unwrap(),
+                None,
+                peer_type,
+                Arc::new(AtomicU64::new(10)),
+                max_connections_per_peer,
+                || limiter,
+            )
+            .unwrap();
+    }
+    Arc::new(Mutex::new(connection_table))
+}
diff --git a/streamer/src/quic.rs b/streamer/src/quic.rs
index 2616e467d82621..61de2d6bce3128 100644
--- a/streamer/src/quic.rs
+++ b/streamer/src/quic.rs
@@ -210,16 +210,13 @@ pub struct StreamerStats {
     // Per IP rate-limiting is triggered each time when there are too many connections
     // opened from a particular IP address.
     pub(crate) connection_rate_limited_per_ipaddr: AtomicUsize,
-    pub(crate) throttled_streams: AtomicUsize,
-    pub(crate) stream_load_ema: AtomicUsize,
-    pub(crate) stream_load_ema_overflow: AtomicUsize,
-    pub(crate) stream_load_capacity_overflow: AtomicUsize,
     pub(crate) process_sampled_packets_us_hist: Mutex<histogram::Histogram>,
     pub(crate) perf_track_overhead_us: AtomicU64,
     pub(crate) total_staked_packets_sent_for_batching: AtomicUsize,
     pub(crate) total_unstaked_packets_sent_for_batching: AtomicUsize,
-    pub(crate) throttled_staked_streams: AtomicUsize,
-    pub(crate) throttled_unstaked_streams: AtomicUsize,
+    pub(crate) throttled_time_ms: AtomicU64,
+    pub(crate) throttled_ms_staked: AtomicU64,
+    pub(crate) throttled_ms_unstaked: AtomicU64,
     // All connections in various states such as Incoming, Connecting, Connection
     pub(crate) open_connections: AtomicUsize,
     pub(crate) open_staked_connections: AtomicUsize,
@@ -460,33 +457,18 @@ impl StreamerStats {
                 i64
             ),
             (
-                "throttled_streams",
-                self.throttled_streams.swap(0, Ordering::Relaxed),
+                "throttled_time_ms",
+                self.throttled_time_ms.swap(0, Ordering::Relaxed),
                 i64
             ),
             (
-                "stream_load_ema",
-                self.stream_load_ema.load(Ordering::Relaxed),
+                "throttled_ms_unstaked",
+                self.throttled_ms_unstaked.swap(0, Ordering::Relaxed),
                 i64
             ),
             (
-                "stream_load_ema_overflow",
-                self.stream_load_ema_overflow.load(Ordering::Relaxed),
-                i64
-            ),
-            (
-                "stream_load_capacity_overflow",
-                self.stream_load_capacity_overflow.load(Ordering::Relaxed),
-                i64
-            ),
-            (
-                "throttled_unstaked_streams",
-                self.throttled_unstaked_streams.swap(0, Ordering::Relaxed),
-                i64
-            ),
-            (
-                "throttled_staked_streams",
-                self.throttled_staked_streams.swap(0, Ordering::Relaxed),
+                "throttled_ms_staked",
+                self.throttled_ms_staked.swap(0, Ordering::Relaxed),
                 i64
             ),
             (
@@ -621,7 +603,7 @@ fn spawn_runtime_and_server(
     keypair: &Keypair,
     packet_sender: Sender<PacketBatch>,
     quic_server_params: QuicStreamerConfig,
-    qos: Arc<Q>,
+    qos: Q,
     cancel: CancellationToken,
 ) -> Result<SpawnServerResult, QuicServerError>
 where
@@ -674,12 +656,7 @@ pub fn spawn_stake_wighted_qos_server(
     cancel: CancellationToken,
 ) -> Result<SpawnServerResult, QuicServerError> {
     let stats = Arc::<StreamerStats>::default();
-    let swqos = Arc::new(SwQos::new(
-        qos_config,
-        stats.clone(),
-        staked_nodes,
-        cancel.clone(),
-    ));
+    let swqos = SwQos::new(qos_config, stats.clone(), staked_nodes, cancel.clone());
     spawn_runtime_and_server(
         thread_name,
         metrics_name,
@@ -706,12 +683,8 @@ pub fn spawn_simple_qos_server(
     cancel: CancellationToken,
 ) -> Result<SpawnServerResult, QuicServerError> {
     let stats = Arc::<StreamerStats>::default();
-    let simple_qos = Arc::new(SimpleQos::new(
-        qos_config,
-        stats.clone(),
-        staked_nodes,
-        cancel.clone(),
-    ));
+
+    let simple_qos = SimpleQos::new(qos_config, stats.clone(), staked_nodes, cancel.clone());
 
     spawn_runtime_and_server(
         thread_name,
@@ -730,15 +703,18 @@ mod test {
     use {
         super::*,
-        crate::nonblocking::{
-            quic::test::*,
-            testing_utilities::{check_multiple_streams, make_client_endpoint},
+        crate::{
+            nonblocking::{
+                quic::test::*,
+                testing_utilities::{check_multiple_streams, make_client_endpoint},
+            },
+            streamer::StakedNodes,
         },
         crossbeam_channel::{unbounded, Receiver},
         solana_net_utils::sockets::bind_to_localhost_unique,
         solana_pubkey::Pubkey,
         solana_signer::Signer,
-        std::{collections::HashMap, net::SocketAddr, time::Instant},
+        std::{collections::HashMap, net::SocketAddr, sync::RwLock, time::Instant},
         tokio::time::sleep,
     };
 
@@ -903,10 +879,10 @@
             (client_keypair.pubkey(), 1_000), // very small staked node
             (rich_node_keypair.pubkey(), 1_000_000_000),
         ]);
-        let staked_nodes = StakedNodes::new(
+        let staked_nodes = Arc::new(RwLock::new(StakedNodes::new(
             Arc::new(stakes),
             HashMap::<Pubkey, u64>::default(), // overrides
-        );
+        )));
 
         let server_params = QuicStreamerConfig::default_for_tests();
         let qos_config = SimpleQosConfig {
@@ -919,7 +895,7 @@
             qos_config,
         };
         let (t, receiver, server_address, cancel) =
-            setup_simple_qos_quic_server(server_params, Arc::new(RwLock::new(staked_nodes)));
+            setup_simple_qos_quic_server(server_params, staked_nodes);
 
         let runtime = rt_for_test();
         let num_expected_packets = 20;
diff --git a/streamer/src/streamer.rs b/streamer/src/streamer.rs
index 0fb370db480245..46012a418d53d1 100644
--- a/streamer/src/streamer.rs
+++ b/streamer/src/streamer.rs
@@ -79,8 +79,8 @@ pub(crate) const SOCKET_READ_TIMEOUT: Duration = Duration::from_secs(1);
 
 // Total stake and nodes => stake map
 #[derive(Default)]
 pub struct StakedNodes {
-    stakes: Arc<HashMap<Pubkey, u64>>,
-    overrides: HashMap<Pubkey, u64>,
+    pub stakes: Arc<HashMap<Pubkey, u64>>,
+    pub overrides: HashMap<Pubkey, u64>,
     total_stake: u64,
 }
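Reviewer note (not part of the patch): `streamer/src/nonblocking/stream_throttle.rs` itself is not included in this excerpt, so the bodies of `StreamRateLimiter::wait_for_token`, `refill_task`, and `Refiller::do_refill` are not visible here. Judging from the call sites above (a per-peer limiter built from pubkey and stake, one `refill_task` spawned in `async_init` with a budget of `max_streams_per_ms * 1000` tokens per second, and `wait_for_token` awaited on every new stream), the design appears to replace the old `StakedStreamLoadEMA` accounting with per-peer token buckets refilled by a single central loop. The sketch below is an assumption about that shape, not the PR's code: `TokenBucket`, `refill_loop`, `tokens_per_tick`, and the demo `main` are all invented names, and the real `wait_for_token(&stats)` additionally records time spent blocked into the new `throttled_ms_*` counters.

```rust
// Sketch only: every item here is a stand-in, not code from this PR.
use std::sync::{
    atomic::{AtomicU64, Ordering},
    Arc,
};
use tokio::{
    sync::Notify,
    time::{sleep, Duration},
};

/// Hypothetical per-peer token bucket standing in for StreamRateLimiter.
pub struct TokenBucket {
    tokens: AtomicU64,
    capacity: u64,
    refilled: Notify,
}

impl TokenBucket {
    pub fn new(capacity: u64) -> Self {
        Self {
            tokens: AtomicU64::new(capacity),
            capacity,
            refilled: Notify::new(),
        }
    }

    /// Take one token; if the bucket is empty, park until the refill loop
    /// deposits more. The real wait_for_token presumably also measures the
    /// time spent blocked here and adds it to throttled_ms_{staked,unstaked}.
    pub async fn wait_for_token(&self) {
        loop {
            let mut current = self.tokens.load(Ordering::Acquire);
            while current > 0 {
                match self.tokens.compare_exchange_weak(
                    current,
                    current - 1,
                    Ordering::AcqRel,
                    Ordering::Acquire,
                ) {
                    Ok(_) => return,
                    Err(observed) => current = observed,
                }
            }
            // A refill landing between the check above and this await only
            // delays us until the next tick, which is fine for throttling.
            self.refilled.notified().await;
        }
    }

    /// Deposit tokens up to capacity and wake every parked stream.
    pub fn refill(&self, amount: u64) {
        self.tokens
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |tokens| {
                Some(tokens.saturating_add(amount).min(self.capacity))
            })
            .ok();
        self.refilled.notify_waiters();
    }
}

/// One central timer amortized over all buckets, in the spirit of
/// refill_task / Refiller::do_refill (which walk the connection tables).
pub async fn refill_loop(buckets: Vec<Arc<TokenBucket>>, tokens_per_tick: u64) {
    loop {
        sleep(Duration::from_millis(100)).await;
        for bucket in &buckets {
            bucket.refill(tokens_per_tick);
        }
    }
}

#[tokio::main(flavor = "current_thread")]
async fn main() {
    let bucket = Arc::new(TokenBucket::new(2));
    tokio::spawn(refill_loop(vec![bucket.clone()], 1));
    for i in 0..5 {
        bucket.wait_for_token().await;
        println!("stream {i} admitted");
    }
}
```

If this reading is right, it also explains the two observable changes in this diff: throttling cost collapses into one amortized pass over the connection tables per tick (exactly what `bench_refiller` measures through `do_refill`), and the metrics shift from counting throttled streams to summing milliseconds spent blocked in `wait_for_token`.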