diff --git a/Cargo.lock b/Cargo.lock
index b30ff431a8574d..7887dffe9d4651 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -7373,6 +7373,7 @@ dependencies = [
  "solana-version",
  "solana-vote-program",
  "tikv-jemallocator",
+ "tokio-util 0.7.16",
 ]

 [[package]]
@@ -9922,6 +9923,7 @@ dependencies = [
  "solana-transaction-error",
  "thiserror 2.0.16",
  "tokio",
+ "tokio-util 0.7.16",
 ]

 [[package]]
@@ -11757,6 +11759,7 @@ dependencies = [
  "solana-version",
  "thiserror 2.0.16",
  "tokio",
+ "tokio-util 0.7.16",
  "url 2.5.7",
  "x509-parser",
 ]
diff --git a/bench-vote/Cargo.toml b/bench-vote/Cargo.toml
index f6c73ae97a94b4..b0460558f6ba1f 100644
--- a/bench-vote/Cargo.toml
+++ b/bench-vote/Cargo.toml
@@ -29,6 +29,7 @@ solana-streamer = { workspace = true }
 solana-transaction = { workspace = true }
 solana-version = { workspace = true }
 solana-vote-program = { workspace = true }
+tokio-util = { workspace = true }

 [target.'cfg(not(any(target_env = "msvc", target_os = "freebsd")))'.dependencies]
 jemallocator = { workspace = true }
diff --git a/bench-vote/src/main.rs b/bench-vote/src/main.rs
index 2ea31a0b774a19..488c75fe22230d 100644
--- a/bench-vote/src/main.rs
+++ b/bench-vote/src/main.rs
@@ -18,7 +18,7 @@ use {
     solana_streamer::{
         packet::PacketBatchRecycler,
         quic::{
-            spawn_server, QuicServerParams, DEFAULT_MAX_QUIC_CONNECTIONS_PER_PEER,
+            spawn_server_with_cancel, QuicServerParams, DEFAULT_MAX_QUIC_CONNECTIONS_PER_PEER,
             DEFAULT_MAX_STAKED_CONNECTIONS,
         },
         streamer::{receiver, PacketBatchReceiver, StakedNodes, StreamerReceiveStats},
@@ -36,6 +36,7 @@ use {
         thread::{self, spawn, JoinHandle, Result},
         time::{Duration, Instant, SystemTime},
     },
+    tokio_util::sync::CancellationToken,
 };

 #[cfg(not(any(target_env = "msvc", target_os = "freebsd")))]
@@ -244,8 +245,9 @@ fn main() -> Result<()> {
         }
     });

-    let (exit, read_threads, sink_threads, destination) = if !client_only {
+    let (exit, cancel, read_threads, sink_threads, destination) = if !client_only {
         let exit = Arc::new(AtomicBool::new(false));
+        let cancel = CancellationToken::new();

         let mut read_channels = Vec::new();
         let mut read_threads = Vec::new();
@@ -273,15 +275,15 @@ fn main() -> Result<()> {
             let (s_reader, r_reader) = unbounded();
             read_channels.push(r_reader);

-            let server = spawn_server(
+            let server = spawn_server_with_cancel(
                 "solRcvrBenVote",
                 "bench_vote_metrics",
                 read_sockets,
                 &quic_params.identity_keypair,
                 s_reader,
-                exit.clone(),
                 quic_params.staked_nodes.clone(),
                 quic_server_params,
+                cancel.clone(),
             )
             .unwrap();
             read_threads.push(server.thread);
@@ -316,12 +318,13 @@ fn main() -> Result<()> {
         println!("Running server at {destination:?}");
         (
             Some(exit),
+            Some(cancel),
             Some(read_threads),
             Some(sink_threads),
             destination,
         )
     } else {
-        (None, None, None, destination.unwrap())
+        (None, None, None, None, destination.unwrap())
     };

     let start = SystemTime::now();
@@ -344,6 +347,7 @@ fn main() -> Result<()> {
     if !server_only {
         if let Some(exit) = exit {
             exit.store(true, Ordering::Relaxed);
+            cancel.unwrap().cancel();
         }
     } else {
         println!("To stop the server, please press ^C");
diff --git a/core/src/tpu.rs b/core/src/tpu.rs
index fd696398932cbd..f6958828c79020 100644
--- a/core/src/tpu.rs
+++ b/core/src/tpu.rs
@@ -49,7 +49,7 @@ use {
         vote_sender_types::{ReplayVoteReceiver, ReplayVoteSender},
     },
     solana_streamer::{
-        quic::{spawn_server, QuicServerParams, SpawnServerResult},
+        quic::{spawn_server_with_cancel, QuicServerParams, SpawnServerResult},
         streamer::StakedNodes,
     },
     solana_turbine::{
@@ -65,6 +65,7 @@ use {
         time::Duration,
     },
     tokio::sync::mpsc::Sender as AsyncSender,
+    tokio_util::sync::CancellationToken,
 };

 pub struct TpuSockets {
@@ -156,6 +157,7 @@ impl Tpu {
         enable_block_production_forwarding: bool,
         _generator_config: Option<GeneratorConfig>, /* vestigial code for replay invalidator */
         key_notifiers: Arc<RwLock<KeyUpdaters>>,
+        cancel: CancellationToken,
     ) -> Self {
         let TpuSockets {
             transactions: transactions_sockets,
@@ -208,15 +210,15 @@ impl Tpu {
             endpoints: _,
             thread: tpu_vote_quic_t,
             key_updater: vote_streamer_key_updater,
-        } = spawn_server(
+        } = spawn_server_with_cancel(
             "solQuicTVo",
             "quic_streamer_tpu_vote",
             tpu_vote_quic_sockets,
             keypair,
             vote_packet_sender.clone(),
-            exit.clone(),
             staked_nodes.clone(),
             vote_quic_server_config,
+            cancel.clone(),
         )
         .unwrap();

@@ -226,15 +228,15 @@ impl Tpu {
                 endpoints: _,
                 thread: tpu_quic_t,
                 key_updater,
-            } = spawn_server(
+            } = spawn_server_with_cancel(
                 "solQuicTpu",
                 "quic_streamer_tpu",
                 transactions_quic_sockets,
                 keypair,
                 packet_sender,
-                exit.clone(),
                 staked_nodes.clone(),
                 tpu_quic_server_config,
+                cancel.clone(),
             )
             .unwrap();
             (Some(tpu_quic_t), Some(key_updater))
@@ -248,15 +250,15 @@ impl Tpu {
                 endpoints: _,
                 thread: tpu_forwards_quic_t,
                 key_updater: forwards_key_updater,
-            } = spawn_server(
+            } = spawn_server_with_cancel(
                 "solQuicTpuFwd",
                 "quic_streamer_tpu_forwards",
                 transactions_forwards_quic_sockets,
                 keypair,
                 forwarded_packet_sender,
-                exit.clone(),
                 staked_nodes.clone(),
                 tpu_fwd_quic_server_config,
+                cancel,
             )
             .unwrap();
             (Some(tpu_forwards_quic_t), Some(forwards_key_updater))
diff --git a/core/src/validator.rs b/core/src/validator.rs
index 2ce56ccd46db02..96a524ce5aa9ba 100644
--- a/core/src/validator.rs
+++ b/core/src/validator.rs
@@ -735,8 +735,8 @@ impl Validator {
         timer.stop();
         info!("Cleaning orphaned account snapshot directories done. {timer}");

-        // token used to cancel tpu-client-next.
-        let cancel_tpu_client_next = CancellationToken::new();
+        // token used to cancel tpu-client-next and streamer.
+        let cancel = CancellationToken::new();
         {
             let exit = exit.clone();
             config
                 .validator_exit
                 .write()
                 .unwrap()
                 .register_exit(Box::new(move || exit.store(true, Ordering::Relaxed)));
-            let cancel_tpu_client_next = cancel_tpu_client_next.clone();
+            let cancel = cancel.clone();
             config
                 .validator_exit
                 .write()
                 .unwrap()
-                .register_exit(Box::new(move || cancel_tpu_client_next.cancel()));
+                .register_exit(Box::new(move || cancel.cancel()));
         }

         let (
@@ -1178,7 +1178,7 @@
                 Arc::as_ref(&identity_keypair),
                 node.sockets.rpc_sts_client,
                 runtime_handle.clone(),
-                cancel_tpu_client_next.clone(),
+                cancel.clone(),
             )
         } else {
             let Some(connection_cache) = &connection_cache else {
@@ -1633,7 +1633,7 @@
                 Arc::as_ref(&identity_keypair),
                 tpu_transactions_forwards_client_sockets.take().unwrap(),
                 runtime_handle.clone(),
-                cancel_tpu_client_next,
+                cancel.clone(),
                 node_multihoming.clone(),
             ))
         };
@@ -1690,6 +1690,7 @@
             config.enable_block_production_forwarding,
             config.generator_config.clone(),
             key_notifiers.clone(),
+            cancel,
         );

         datapoint_info!(
diff --git a/programs/sbf/Cargo.lock b/programs/sbf/Cargo.lock
index 1b7b82ba859114..fdf1215a2df5b7 100644
--- a/programs/sbf/Cargo.lock
+++ b/programs/sbf/Cargo.lock
@@ -10687,7 +10687,7 @@ version = "3.3.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "adcb7fd841cd518e279be3d5a3eb0636409487998a4aff22f3de87b81e88384f"
 dependencies = [
- "cfg-if 1.0.0",
+ "cfg-if 1.0.3",
  "proc-macro2",
  "quote",
  "syn 2.0.87",
diff --git a/quic-client/Cargo.toml b/quic-client/Cargo.toml
index 92b537ebd1713a..1427c56a35d235 100644
--- a/quic-client/Cargo.toml
+++ b/quic-client/Cargo.toml
@@ -40,3 +40,4 @@ solana-net-utils = { workspace = true, features = ["dev-context-only-utils"] }
 solana-packet = { workspace = true }
 solana-perf = { workspace = true }
 solana-streamer = { workspace = true, features = ["dev-context-only-utils"] }
+tokio-util = { workspace = true }
diff --git a/quic-client/tests/quic_client.rs b/quic-client/tests/quic_client.rs
index 63d54f2ec75b26..e086bc0ad58fe6 100644
--- a/quic-client/tests/quic_client.rs
+++ b/quic-client/tests/quic_client.rs
@@ -18,13 +18,11 @@ mod tests {
         solana_tls_utils::{new_dummy_x509_certificate, QuicClientCertificate},
         std::{
             net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket},
-            sync::{
-                atomic::{AtomicBool, Ordering},
-                Arc, RwLock,
-            },
+            sync::{Arc, RwLock},
             time::{Duration, Instant},
         },
         tokio::time::sleep,
+        tokio_util::sync::CancellationToken,
     };

     fn check_packets(
@@ -52,11 +50,11 @@ mod tests {
         assert!(total_packets > 0);
     }

-    fn server_args() -> (UdpSocket, Arc<AtomicBool>, Keypair) {
+    fn server_args() -> (UdpSocket, CancellationToken, Keypair) {
         let port_range = localhost_port_range_for_tests();
         (
             bind_to(IpAddr::V4(Ipv4Addr::LOCALHOST), port_range.0).expect("should bind"),
-            Arc::new(AtomicBool::new(false)),
+            CancellationToken::new(),
             Keypair::new(),
         )
     }
@@ -70,20 +68,20 @@
         solana_logger::setup();
         let (sender, receiver) = unbounded();
         let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
-        let (s, exit, keypair) = server_args();
+        let (s, cancel, keypair) = server_args();
         let SpawnServerResult {
             endpoints: _,
             thread: t,
             key_updater: _,
-        } = solana_streamer::quic::spawn_server(
+        } = solana_streamer::quic::spawn_server_with_cancel(
             "solQuicTest",
             "quic_streamer_test",
             vec![s.try_clone().unwrap()],
             &keypair,
             sender,
-            exit.clone(),
             staked_nodes,
             QuicServerParams::default_for_tests(),
+            cancel.clone(),
         )
         .unwrap();

@@ -105,7 +103,7 @@ mod tests {
         assert!(client.send_data_batch_async(packets).is_ok());

         check_packets(receiver, num_bytes, num_expected_packets);
-        exit.store(true, Ordering::Relaxed);
+        cancel.cancel();
         t.join().unwrap();
     }

@@ -150,20 +148,20 @@ mod tests {
         solana_logger::setup();
         let (sender, receiver) = unbounded();
         let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
-        let (s, exit, keypair) = server_args();
+        let (s, cancel, keypair) = server_args();
         let solana_streamer::nonblocking::quic::SpawnNonBlockingServerResult {
             endpoints: _,
             stats: _,
             thread: t,
             max_concurrent_connections: _,
-        } = solana_streamer::nonblocking::quic::spawn_server(
+        } = solana_streamer::nonblocking::quic::spawn_server_with_cancel(
             "quic_streamer_test",
             vec![s.try_clone().unwrap()],
             &keypair,
             sender,
-            exit.clone(),
             staked_nodes,
             QuicServerParams::default_for_tests(),
+            cancel.clone(),
         )
         .unwrap();

@@ -186,7 +184,7 @@ mod tests {
         }

         nonblocking_check_packets(receiver, num_bytes, num_expected_packets).await;
-        exit.store(true, Ordering::Relaxed);
+        cancel.cancel();
         t.await.unwrap();
     }

@@ -208,26 +206,26 @@ mod tests {
         // Request Receiver
         let (sender, receiver) = unbounded();
         let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
-        let (request_recv_socket, request_recv_exit, keypair) = server_args();
+        let (request_recv_socket, request_recv_cancel, keypair) = server_args();
         let SpawnServerResult {
             endpoints: request_recv_endpoints,
             thread: request_recv_thread,
             key_updater: _,
-        } = solana_streamer::quic::spawn_server(
+        } = solana_streamer::quic::spawn_server_with_cancel(
             "solQuicTest",
             "quic_streamer_test",
             [request_recv_socket.try_clone().unwrap()],
             &keypair,
             sender,
-            request_recv_exit.clone(),
             staked_nodes.clone(),
             QuicServerParams::default_for_tests(),
+            request_recv_cancel.clone(),
         )
         .unwrap();

         drop(request_recv_endpoints);
         // Response Receiver:
-        let (response_recv_socket, response_recv_exit, keypair2) = server_args();
+        let (response_recv_socket, response_recv_cancel, keypair2) = server_args();
         let (sender2, receiver2) = unbounded();

         let addr = response_recv_socket.local_addr().unwrap().ip();
@@ -237,15 +235,15 @@ mod tests {
             endpoints: mut response_recv_endpoints,
             thread: response_recv_thread,
             key_updater: _,
-        } = solana_streamer::quic::spawn_server(
+        } = solana_streamer::quic::spawn_server_with_cancel(
             "solQuicTest",
             "quic_streamer_test",
             [response_recv_socket],
             &keypair2,
             sender2,
-            response_recv_exit.clone(),
             staked_nodes,
             QuicServerParams::default_for_tests(),
+            response_recv_cancel.clone(),
         )
         .unwrap();

@@ -304,11 +302,11 @@ mod tests {
         drop(request_sender);
         drop(response_sender);

-        request_recv_exit.store(true, Ordering::Relaxed);
+        request_recv_cancel.cancel();
         request_recv_thread.join().unwrap();
         info!("Request receiver exited!");

-        response_recv_exit.store(true, Ordering::Relaxed);
+        response_recv_cancel.cancel();
         response_recv_thread.join().unwrap();
         info!("Response receiver exited!");
     }

@@ -318,20 +316,20 @@ mod tests {
         solana_logger::setup();
         let (sender, receiver) = unbounded();
         let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
-        let (s, exit, keypair) = server_args();
+        let (s, cancel, keypair) = server_args();
         let solana_streamer::nonblocking::quic::SpawnNonBlockingServerResult {
             endpoints: _,
             stats: _,
             thread: t,
             max_concurrent_connections: _,
-        } = solana_streamer::nonblocking::quic::spawn_server(
+        } = solana_streamer::nonblocking::quic::spawn_server_with_cancel(
             "quic_streamer_test",
             vec![s.try_clone().unwrap()],
             &keypair,
             sender,
-            exit.clone(),
             staked_nodes,
             QuicServerParams::default_for_tests(),
+            cancel.clone(),
         )
         .unwrap();

@@ -353,7 +351,7 @@ mod tests {
         }

         nonblocking_check_packets(receiver, num_bytes, num_expected_packets).await;
-        exit.store(true, Ordering::Relaxed);
+        cancel.cancel();
         t.await.unwrap();

         // We close the connection after the server is down, this should not block
diff --git a/streamer/examples/swqos.rs b/streamer/examples/swqos.rs
index 3fea3af9400a8d..eadfc4742d4cbd 100644
--- a/streamer/examples/swqos.rs
+++ b/streamer/examples/swqos.rs
@@ -24,13 +24,11 @@ use {
         net::SocketAddr,
         path::Path,
         str::FromStr as _,
-        sync::{
-            atomic::{AtomicBool, Ordering},
-            Arc, RwLock,
-        },
+        sync::{Arc, RwLock},
         time::Duration,
     },
     tokio::time::{sleep, Instant},
+    tokio_util::sync::CancellationToken,
 };

 fn parse_duration(arg: &str) -> Result {
@@ -96,7 +94,6 @@ async fn main() -> anyhow::Result<()> {
     )
     .expect("should bind");

-    let exit = Arc::new(AtomicBool::new(false));
     let (sender, receiver) = bounded(1024);
     let keypair = Keypair::new();

@@ -108,22 +105,23 @@ async fn main() -> anyhow::Result<()> {
         Arc::new(RwLock::new(nodes))
     };

+    let cancel = CancellationToken::new();
     let SpawnNonBlockingServerResult {
         endpoints,
         stats,
         thread: run_thread,
         max_concurrent_connections: _,
-    } = solana_streamer::nonblocking::quic::spawn_server(
+    } = solana_streamer::nonblocking::quic::spawn_server_with_cancel(
         "quic_streamer_test",
         [socket.try_clone()?],
         &keypair,
         sender,
-        exit.clone(),
         staked_nodes,
         QuicServerParams {
             max_connections_per_peer: cli.max_connections_per_peer,
             ..QuicServerParams::default()
         },
+        cancel.clone(),
     )?;

     info!("Server listening on {}", socket.local_addr()?);

@@ -158,7 +156,7 @@ async fn main() -> anyhow::Result<()> {
     sleep(cli.test_duration).await;

     info!("Server terminating");
-    exit.store(true, Ordering::Relaxed);
+    cancel.cancel();
     drop(endpoints);
     run_thread.await?;
     logger_thread.await??;
diff --git a/streamer/src/nonblocking/quic.rs b/streamer/src/nonblocking/quic.rs
index dab92349c92526..d6f9c11e81746b 100644
--- a/streamer/src/nonblocking/quic.rs
+++ b/streamer/src/nonblocking/quic.rs
@@ -46,7 +46,6 @@ use {
             Arc, RwLock,
         },
         task::Poll,
-        thread,
        time::{Duration, Instant},
     },
     tokio::{
         // introduce any other awaits while holding the RwLock.
         select,
         sync::{Mutex, MutexGuard},
-        task::JoinHandle,
+        task::{self, JoinHandle},
         time::{sleep, timeout},
     },
     tokio_util::sync::CancellationToken,
@@ -150,6 +149,7 @@ pub fn spawn_server_multi(
     staked_nodes: Arc<RwLock<StakedNodes>>,
     quic_server_params: QuicServerParams,
 ) -> Result<SpawnNonBlockingServerResult, QuicServerError> {
+    #[allow(deprecated)]
     spawn_server(
         name,
         sockets,
@@ -161,7 +161,7 @@
     )
 }

-/// Spawn a streamer instance in the current tokio runtime.
+#[deprecated(since = "3.1.0", note = "Use spawn_server_with_cancel instead")]
 pub fn spawn_server(
     name: &'static str,
     sockets: impl IntoIterator<Item = UdpSocket>,
     keypair: &Keypair,
     packet_sender: Sender<PacketBatch>,
     exit: Arc<AtomicBool>,
     staked_nodes: Arc<RwLock<StakedNodes>>,
     quic_server_params: QuicServerParams,
+) -> Result<SpawnNonBlockingServerResult, QuicServerError> {
+    let cancel = CancellationToken::new();
+    tokio::spawn({
+        let cancel = cancel.clone();
+        async move {
+            loop {
+                if exit.load(Ordering::Relaxed) {
+                    cancel.cancel();
+                    break;
+                }
+                sleep(Duration::from_millis(100)).await;
+            }
+        }
+    });
+
+    spawn_server_with_cancel(
+        name,
+        sockets,
+        keypair,
+        packet_sender,
+        staked_nodes,
+        quic_server_params,
+        cancel,
+    )
+}
+
+/// Spawn a streamer instance in the current tokio runtime.
+pub fn spawn_server_with_cancel(
+    name: &'static str,
+    sockets: impl IntoIterator<Item = UdpSocket>,
+    keypair: &Keypair,
+    packet_sender: Sender<PacketBatch>,
+    staked_nodes: Arc<RwLock<StakedNodes>>,
+    quic_server_params: QuicServerParams,
+    cancel: CancellationToken,
 ) -> Result<SpawnNonBlockingServerResult, QuicServerError> {
     let sockets: Vec<_> = sockets.into_iter().collect();
     info!("Start {name} quic server on {sockets:?}");
@@ -188,16 +223,33 @@ pub fn spawn_server(
         })
         .collect::<Result<Vec<_>, _>>()?;
     let stats = Arc::<StreamerStats>::default();
+    let (packet_batch_sender, packet_batch_receiver) =
+        bounded(quic_server_params.coalesce_channel_size);
+    task::spawn_blocking({
+        let cancel = cancel.clone();
+        let stats = stats.clone();
+        move || {
+            run_packet_batch_sender(
+                packet_sender,
+                packet_batch_receiver,
+                stats,
+                quic_server_params.coalesce,
+                cancel,
+            );
+        }
+    });
+
     let max_concurrent_connections = quic_server_params.max_concurrent_connections();
     let handle = tokio::spawn(run_server(
         name,
         endpoints.clone(),
-        packet_sender,
-        exit,
+        packet_batch_sender,
         staked_nodes,
         stats.clone(),
         quic_server_params,
+        cancel,
     ));
+
     Ok(SpawnNonBlockingServerResult {
         endpoints,
         stats,
@@ -255,11 +307,11 @@ impl ClientConnectionTracker {
 async fn run_server(
     name: &'static str,
     endpoints: Vec<Endpoint>,
-    packet_sender: Sender<PacketBatch>,
-    exit: Arc<AtomicBool>,
+    packet_batch_sender: Sender<PacketAccumulator>,
     staked_nodes: Arc<RwLock<StakedNodes>>,
     stats: Arc<StreamerStats>,
     quic_server_params: QuicServerParams,
+    cancel: CancellationToken,
 ) {
     let rate_limiter = Arc::new(ConnectionRateLimiter::new(
         quic_server_params.max_connections_per_ipaddr_per_min,
@@ -272,7 +324,7 @@ async fn run_server(
     debug!("spawn quic server");
     let mut last_datapoint = Instant::now();
     let unstaked_connection_table: Arc<Mutex<ConnectionTable>> =
-        Arc::new(Mutex::new(ConnectionTable::new(false)));
+        Arc::new(Mutex::new(ConnectionTable::new(false, cancel.clone())));
     let stream_load_ema = Arc::new(StakedStreamLoadEMA::new(
         stats.clone(),
         quic_server_params.max_unstaked_connections,
@@ -282,22 +334,7 @@ async fn run_server(
         .quic_endpoints_count
         .store(endpoints.len(), Ordering::Relaxed);
     let staked_connection_table: Arc<Mutex<ConnectionTable>> =
-        Arc::new(Mutex::new(ConnectionTable::new(true)));
-    let (sender, receiver) = bounded(quic_server_params.coalesce_channel_size);
-
-    thread::spawn({
-        let exit = exit.clone();
-        let stats = stats.clone();
-        move || {
-            packet_batch_sender(
-                packet_sender,
-                receiver,
-                exit,
-                stats,
-                quic_server_params.coalesce,
-            );
-        }
-    });
+        Arc::new(Mutex::new(ConnectionTable::new(true, cancel.clone())));

     let mut accepts = endpoints
         .iter()
         .enumerate()
         .map(|(i, incoming)| {
             Box::pin(EndpointAccept {
                 accept: incoming.accept(),
                 endpoint: i,
             })
         })
         .collect::<FuturesUnordered<_>>();

-    while !exit.load(Ordering::Relaxed) {
+    loop {
         let timeout_connection = select! {
             ready = accepts.next() => {
                 if let Some((connecting, i)) = ready {
@@ -329,6 +366,7 @@ async fn run_server(
             _ = tokio::time::sleep(WAIT_FOR_CONNECTION_TIMEOUT) => {
                 Err(())
             }
+            _ = cancel.cancelled() => break,
         };

         if last_datapoint.elapsed().as_secs() >= 5 {
@@ -375,7 +413,7 @@ async fn run_server(
                         client_connection_tracker,
                         unstaked_connection_table.clone(),
                         staked_connection_table.clone(),
-                        sender.clone(),
+                        packet_batch_sender.clone(),
                         staked_nodes.clone(),
                         stats.clone(),
                         stream_load_ema.clone(),
@@ -568,10 +606,10 @@ fn handle_and_cache_new_connection(
             remote_addr,
             last_update,
             connection_table,
-            cancel_connection,
             params.clone(),
             stream_load_ema,
             stream_counter,
+            cancel_connection,
         ));
         Ok(())
     } else {
@@ -880,12 +918,12 @@ fn handle_connection_error(e: quinn::ConnectionError, stats: &StreamerStats, fro
 // Holder(s) of the Sender on the other end should not
 // wait for this function to exit
-fn packet_batch_sender(
+fn run_packet_batch_sender(
     packet_sender: Sender<PacketBatch>,
     packet_receiver: Receiver<PacketAccumulator>,
-    exit: Arc<AtomicBool>,
     stats: Arc<StreamerStats>,
     coalesce: Duration,
+    cancel: CancellationToken,
 ) {
     trace!("enter packet_batch_sender");
     let mut batch_start_time = Instant::now();
@@ -902,7 +940,7 @@
             .fetch_add(PACKETS_PER_BATCH, Ordering::Relaxed);

         loop {
-            if exit.load(Ordering::Relaxed) {
+            if cancel.is_cancelled() {
                 return;
             }
             let elapsed = batch_start_time.elapsed();
@@ -920,7 +958,7 @@
                     // The downstream channel is disconnected, this error is not recoverable.
                     if matches!(e, TrySendError::Disconnected(_)) {
-                        exit.store(true, Ordering::Relaxed);
+                        cancel.cancel();
                         return;
                     }
                 } else {
@@ -1034,10 +1072,10 @@ async fn handle_connection(
     remote_addr: SocketAddr,
     last_update: Arc<AtomicU64>,
     connection_table: Arc<Mutex<ConnectionTable>>,
-    cancel: CancellationToken,
     params: NewConnectionHandlerParams,
     stream_load_ema: Arc<StakedStreamLoadEMA>,
     stream_counter: Arc<ConnectionStreamCounter>,
+    cancel: CancellationToken,
 ) {
     let NewConnectionHandlerParams {
         packet_sender,
@@ -1393,16 +1431,19 @@ struct ConnectionTable {
     table: IndexMap<ConnectionTableKey, Vec<ConnectionEntry>>,
     total_size: usize,
     is_staked: bool,
+    cancel: CancellationToken,
 }

-// Prune the connection which has the oldest update
-// Return number pruned
+/// Prune the connection which has the oldest update
+///
+/// Return number pruned
 impl ConnectionTable {
-    fn new(is_staked: bool) -> Self {
+    fn new(is_staked: bool, cancel: CancellationToken) -> Self {
         Self {
             table: IndexMap::default(),
             total_size: 0,
             is_staked,
+            cancel,
         }
     }

@@ -1479,7 +1520,7 @@ impl ConnectionTable {
             .map(|c| c <= max_connections_per_peer)
             .unwrap_or(false);
         if has_connection_capacity {
-            let cancel = CancellationToken::new();
+            let cancel = self.cancel.child_token();
             let last_update = Arc::new(AtomicU64::new(last_update));
             let stream_counter = connection_entry
                 .first()
@@ -1687,12 +1728,12 @@ pub mod test {
     async fn test_quic_server_exit() {
         let SpawnTestServerResult {
             join_handle,
-            exit,
             receiver: _,
             server_address: _,
             stats: _,
+            cancel,
         } = setup_quic_server(None, QuicServerParams::default_for_tests());
-        exit.store(true, Ordering::Relaxed);
+        cancel.cancel();
         join_handle.await.unwrap();
     }

@@ -1701,14 +1742,14 @@ pub mod test {
         solana_logger::setup();
         let SpawnTestServerResult {
             join_handle,
-            exit,
             receiver,
             server_address,
             stats: _,
+            cancel,
         } = setup_quic_server(None, QuicServerParams::default_for_tests());

         check_timeout(receiver, server_address).await;
-        exit.store(true, Ordering::Relaxed);
+        cancel.cancel();
         join_handle.await.unwrap();
     }

@@ -1717,18 +1758,18 @@ pub mod test {
         solana_logger::setup();
         let (pkt_batch_sender, pkt_batch_receiver) = unbounded();
         let (ptk_sender, pkt_receiver) = unbounded();
-        let exit = Arc::new(AtomicBool::new(false));
+        let cancel = CancellationToken::new();
         let stats = Arc::new(StreamerStats::default());

-        let handle = thread::spawn({
-            let exit = exit.clone();
+        let handle = task::spawn_blocking({
+            let cancel = cancel.clone();
             move || {
-                packet_batch_sender(
+                run_packet_batch_sender(
                     pkt_batch_sender,
                     pkt_receiver,
-                    exit,
                     stats,
                     DEFAULT_TPU_COALESCE,
+                    cancel,
                 );
             }
         });

@@ -1757,10 +1798,10 @@ pub mod test {
             }
         }
         assert_eq!(i, num_packets);
-        exit.store(true, Ordering::Relaxed);
+        cancel.cancel();
         // Explicit drop to wake up packet_batch_sender
         drop(ptk_sender);
-        handle.join().unwrap();
+        handle.await.unwrap();
     }

     #[tokio::test(flavor = "multi_thread")]
@@ -1768,10 +1809,10 @@ pub mod test {
         solana_logger::setup();
         let SpawnTestServerResult {
             join_handle,
-            exit,
             receiver: _,
             server_address,
             stats,
+            cancel,
         } = setup_quic_server(None, QuicServerParams::default_for_tests());

         let conn1 = make_client_endpoint(&server_address, None).await;
@@ -1794,7 +1835,7 @@ pub mod test {
         // after the timeouts)
         assert!(s1.write_all(&[0u8]).await.is_err());

-        exit.store(true, Ordering::Relaxed);
+        cancel.cancel();
         join_handle.await.unwrap();
     }

@@ -1803,13 +1844,13 @@ pub mod test {
         solana_logger::setup();
         let SpawnTestServerResult {
             join_handle,
-            exit,
             receiver: _,
             server_address,
             stats: _,
+            cancel,
         } = setup_quic_server(None, QuicServerParams::default_for_tests());
         check_block_multiple_connections(server_address).await;
-        exit.store(true, Ordering::Relaxed);
+        cancel.cancel();
         join_handle.await.unwrap();
     }

@@ -1819,10 +1860,10 @@ pub mod test {
         let SpawnTestServerResult {
             join_handle,
-            exit,
             receiver: _,
             server_address,
             stats,
+            cancel,
         } = setup_quic_server(
             None,
             QuicServerParams {
@@ -1885,7 +1926,7 @@ pub mod test {
         }
         assert!(start.elapsed().as_secs() < 1);

-        exit.store(true, Ordering::Relaxed);
+        cancel.cancel();
         join_handle.await.unwrap();
     }

@@ -1894,13 +1935,13 @@ pub mod test {
         solana_logger::setup();
         let SpawnTestServerResult {
             join_handle,
-            exit,
             receiver,
             server_address,
             stats: _,
+            cancel,
         } = setup_quic_server(None, QuicServerParams::default_for_tests());
         check_multiple_writes(receiver, server_address, None).await;
-        exit.store(true, Ordering::Relaxed);
+        cancel.cancel();
         join_handle.await.unwrap();
     }

@@ -1916,13 +1957,13 @@ pub mod test {
         );
         let SpawnTestServerResult {
             join_handle,
-            exit,
             receiver,
             server_address,
             stats,
+            cancel,
         } = setup_quic_server(Some(staked_nodes), QuicServerParams::default_for_tests());
         check_multiple_writes(receiver, server_address, Some(&client_keypair)).await;
-        exit.store(true, Ordering::Relaxed);
+        cancel.cancel();
         join_handle.await.unwrap();
         sleep(Duration::from_millis(100)).await;
         assert_eq!(
@@ -1948,13 +1989,13 @@ pub mod test {
         );
         let SpawnTestServerResult {
             join_handle,
-            exit,
             receiver,
             server_address,
             stats,
+            cancel,
         } = setup_quic_server(Some(staked_nodes), QuicServerParams::default_for_tests());
         check_multiple_writes(receiver, server_address, Some(&client_keypair)).await;
-        exit.store(true, Ordering::Relaxed);
+        cancel.cancel();
         join_handle.await.unwrap();
         sleep(Duration::from_millis(100)).await;
         assert_eq!(
@@ -1972,13 +2013,13 @@ pub mod test {
         solana_logger::setup();
         let SpawnTestServerResult {
             join_handle,
-            exit,
             receiver,
             server_address,
             stats,
+            cancel,
         } = setup_quic_server(None, QuicServerParams::default_for_tests());
         check_multiple_writes(receiver, server_address, None).await;
-        exit.store(true, Ordering::Relaxed);
+        cancel.cancel();
         join_handle.await.unwrap();
         sleep(Duration::from_millis(100)).await;
         assert_eq!(
@@ -1995,32 +2036,32 @@ pub mod test {
     async fn test_quic_server_unstaked_node_connect_failure() {
         solana_logger::setup();
         let s = bind_to_localhost_unique().expect("should bind");
-        let exit = Arc::new(AtomicBool::new(false));
         let (sender, _) = unbounded();
         let keypair = Keypair::new();
         let server_address = s.local_addr().unwrap();
         let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
+        let cancel = CancellationToken::new();
         let SpawnNonBlockingServerResult {
             endpoints: _,
             stats: _,
             thread: t,
             max_concurrent_connections: _,
-        } = spawn_server(
+        } = spawn_server_with_cancel(
             "quic_streamer_test",
             [s],
             &keypair,
             sender,
-            exit.clone(),
             staked_nodes,
             QuicServerParams {
                 max_unstaked_connections: 0, // Do not allow any connection from unstaked clients/nodes
                 ..QuicServerParams::default_for_tests()
             },
+            cancel.clone(),
         )
         .unwrap();

         check_unstaked_node_connect_failure(server_address).await;
-        exit.store(true, Ordering::Relaxed);
+        cancel.cancel();
         t.await.unwrap();
     }

@@ -2028,27 +2069,27 @@ pub mod test {
     async fn test_quic_server_multiple_streams() {
         solana_logger::setup();
         let s = bind_to_localhost_unique().expect("should bind");
-        let exit = Arc::new(AtomicBool::new(false));
         let (sender, receiver) = unbounded();
         let keypair = Keypair::new();
         let server_address = s.local_addr().unwrap();
         let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
+        let cancel = CancellationToken::new();
         let SpawnNonBlockingServerResult {
             endpoints: _,
             stats,
             thread: t,
             max_concurrent_connections: _,
-        } = spawn_server(
             "quic_streamer_test",
             [s],
             &keypair,
             sender,
-            exit.clone(),
             staked_nodes,
             QuicServerParams {
                 max_connections_per_peer: 2,
                 ..QuicServerParams::default_for_tests()
             },
+            cancel.clone(),
         )
         .unwrap();

@@ -2057,8 +2098,12 @@ pub mod test {
         assert_eq!(stats.total_new_streams.load(Ordering::Relaxed), 20);
         assert_eq!(stats.total_connections.load(Ordering::Relaxed), 2);
         assert_eq!(stats.total_new_connections.load(Ordering::Relaxed), 2);
-        exit.store(true, Ordering::Relaxed);
+        cancel.cancel();
         t.await.unwrap();
+        // The streamer handle doesn't wait for its child tasks to finish, so
+        // it is not deterministic whether the tasks handling connections exit
+        // before or after the assertion below.
+        sleep(Duration::from_millis(100)).await;
         assert_eq!(stats.total_connections.load(Ordering::Relaxed), 0);
         assert_eq!(stats.total_new_connections.load(Ordering::Relaxed), 2);
     }

@@ -2067,7 +2112,8 @@ pub mod test {
     fn test_prune_table_with_ip() {
         use std::net::Ipv4Addr;
         solana_logger::setup();
-        let mut table = ConnectionTable::new(false);
+        let cancel = CancellationToken::new();
+        let mut table = ConnectionTable::new(false, cancel);
         let mut num_entries = 5;
         let max_connections_per_peer = 10;
         let sockets: Vec<_> = (0..num_entries)
@@ -2120,7 +2166,8 @@ pub mod test {
     #[test]
     fn test_prune_table_with_unique_pubkeys() {
         solana_logger::setup();
-        let mut table = ConnectionTable::new(false);
+        let cancel = CancellationToken::new();
+        let mut table = ConnectionTable::new(false, cancel);

         // We should be able to add more entries than max_connections_per_peer, since each entry is
         // from a different peer pubkey.
@@ -2158,7 +2205,8 @@ pub mod test {
     #[test]
     fn test_prune_table_with_non_unique_pubkeys() {
         solana_logger::setup();
-        let mut table = ConnectionTable::new(false);
+        let cancel = CancellationToken::new();
+        let mut table = ConnectionTable::new(false, cancel);
         let max_connections_per_peer = 10;
         let pubkey = Pubkey::new_unique();
@@ -2224,7 +2272,9 @@ pub mod test {
     fn test_prune_table_random() {
         use std::net::Ipv4Addr;
         solana_logger::setup();
-        let mut table = ConnectionTable::new(true);
+        let cancel = CancellationToken::new();
+        let mut table = ConnectionTable::new(true, cancel);
+
         let num_entries = 5;
         let max_connections_per_peer = 10;
         let sockets: Vec<_> = (0..num_entries)
@@ -2266,7 +2316,9 @@ pub mod test {
     fn test_remove_connections() {
         use std::net::Ipv4Addr;
         solana_logger::setup();
-        let mut table = ConnectionTable::new(false);
+        let cancel = CancellationToken::new();
+        let mut table = ConnectionTable::new(false, cancel);
+
         let num_ips = 5;
         let max_connections_per_peer = 10;
         let mut sockets: Vec<_> = (0..num_ips)
@@ -2395,10 +2447,10 @@ pub mod test {
         let SpawnTestServerResult {
             join_handle,
-            exit,
             receiver,
             server_address,
             stats,
+            cancel,
         } = setup_quic_server(None, QuicServerParams::default_for_tests());

         let client_connection = make_client_endpoint(&server_address, None).await;
@@ -2428,7 +2480,7 @@
         assert_eq!(expected_num_txs, num_txs_received);

         // stop it
-        exit.store(true, Ordering::Relaxed);
+        cancel.cancel();
         join_handle.await.unwrap();

         assert_eq!(
@@ -2456,7 +2508,7 @@
             join_handle,
             server_address,
             stats,
-            exit,
+            cancel,
             ..
         } = setup_quic_server(None, QuicServerParams::default_for_tests());

@@ -2475,7 +2527,7 @@
             _ => panic!("unexpected close"),
         }
         assert_eq!(stats.invalid_stream_size.load(Ordering::Relaxed), 1);
-        exit.store(true, Ordering::Relaxed);
+        cancel.cancel();
         join_handle.await.unwrap();
     }
 }
diff --git a/streamer/src/nonblocking/testing_utilities.rs b/streamer/src/nonblocking/testing_utilities.rs
index 47003345cd3a3b..5d130e2cada878 100644
--- a/streamer/src/nonblocking/testing_utilities.rs
+++ b/streamer/src/nonblocking/testing_utilities.rs
@@ -1,7 +1,8 @@
 //! Contains utility functions to create server and client for test purposes.
 use {
-    super::quic::{spawn_server, SpawnNonBlockingServerResult, ALPN_TPU_PROTOCOL_ID},
+    super::quic::{SpawnNonBlockingServerResult, ALPN_TPU_PROTOCOL_ID},
     crate::{
+        nonblocking::quic::spawn_server_with_cancel,
         quic::{QuicServerParams, StreamerStats},
         streamer::StakedNodes,
     },
@@ -20,10 +21,11 @@ use {
     solana_tls_utils::{new_dummy_x509_certificate, tls_client_config_builder},
     std::{
         net::{IpAddr, Ipv4Addr, SocketAddr, UdpSocket},
-        sync::{atomic::AtomicBool, Arc, RwLock},
+        sync::{Arc, RwLock},
         time::{Duration, Instant},
     },
     tokio::{task::JoinHandle, time::sleep},
+    tokio_util::sync::CancellationToken,
 };

 pub fn get_client_config(keypair: &Keypair) -> ClientConfig {
@@ -50,10 +52,10 @@ pub fn get_client_config(keypair: &Keypair) -> ClientConfig {
 pub struct SpawnTestServerResult {
     pub join_handle: JoinHandle<()>,
-    pub exit: Arc<AtomicBool>,
     pub receiver: crossbeam_channel::Receiver<PacketBatch>,
     pub server_address: SocketAddr,
     pub stats: Arc<StreamerStats>,
+    pub cancel: CancellationToken,
 }

 pub fn create_quic_server_sockets() -> Vec<UdpSocket> {
@@ -86,33 +88,33 @@ pub fn setup_quic_server_with_sockets(
     option_staked_nodes: Option<StakedNodes>,
     quic_server_params: QuicServerParams,
 ) -> SpawnTestServerResult {
-    let exit = Arc::new(AtomicBool::new(false));
     let (sender, receiver) = unbounded();
     let keypair = Keypair::new();
     let server_address = sockets[0].local_addr().unwrap();
     let staked_nodes = Arc::new(RwLock::new(option_staked_nodes.unwrap_or_default()));
+    let cancel = CancellationToken::new();
     let SpawnNonBlockingServerResult {
         endpoints: _,
         stats,
         thread: handle,
         max_concurrent_connections: _,
-    } = spawn_server(
+    } = spawn_server_with_cancel(
         "quic_streamer_test",
         sockets,
         &keypair,
         sender,
-        exit.clone(),
         staked_nodes,
         quic_server_params,
+        cancel.clone(),
    )
    .unwrap();

     SpawnTestServerResult {
         join_handle: handle,
-        exit,
         receiver,
         server_address,
         stats,
+        cancel,
     }
 }
diff --git a/streamer/src/quic.rs b/streamer/src/quic.rs
index 6bb74a268e1af6..7e1301aec81011 100644
--- a/streamer/src/quic.rs
+++ b/streamer/src/quic.rs
@@ -24,10 +24,11 @@ use {
         atomic::{AtomicBool, AtomicU64, AtomicUsize, Ordering},
         Arc, Mutex, RwLock,
     },
-    thread,
+    thread::{self},
     time::Duration,
 },
 tokio::runtime::Runtime,
+    tokio_util::sync::CancellationToken,
 };

 // allow multiple connections for NAT and any open/close overlap
@@ -588,7 +589,7 @@ impl StreamerStats {
     }
 }

-#[deprecated(since = "3.0.0", note = "Use spawn_server instead")]
+#[deprecated(since = "3.0.0", note = "Use spawn_server_with_cancel instead")]
 pub fn spawn_server_multi(
     thread_name: &'static str,
     metrics_name: &'static str,
@@ -599,6 +600,7 @@ pub fn spawn_server_multi(
     staked_nodes: Arc<RwLock<StakedNodes>>,
     quic_server_params: QuicServerParams,
 ) -> Result<SpawnServerResult, QuicServerError> {
+    #[allow(deprecated)]
     spawn_server(
         thread_name,
         metrics_name,
@@ -660,7 +662,7 @@ impl QuicServerParams {
     }
 }

-/// Spawns a tokio runtime and a streamer instance inside it.
+#[deprecated(since = "3.1.0", note = "Use spawn_server_with_cancel instead")]
 pub fn spawn_server(
     thread_name: &'static str,
     metrics_name: &'static str,
     sockets: impl IntoIterator<Item = UdpSocket>,
     keypair: &Keypair,
     packet_sender: Sender<PacketBatch>,
     exit: Arc<AtomicBool>,
     staked_nodes: Arc<RwLock<StakedNodes>>,
     quic_server_params: QuicServerParams,
+) -> Result<SpawnServerResult, QuicServerError> {
+    let cancel = CancellationToken::new();
+    thread::spawn({
+        let cancel = cancel.clone();
+        move || loop {
+            if exit.load(Ordering::Relaxed) {
+                cancel.cancel();
+                break;
+            }
+            thread::sleep(Duration::from_millis(100));
+        }
+    });
+    spawn_server_with_cancel(
+        thread_name,
+        metrics_name,
+        sockets,
+        keypair,
+        packet_sender,
+        staked_nodes,
+        quic_server_params,
+        cancel,
+    )
+}
+
+/// Spawns a tokio runtime and a streamer instance inside it.
+pub fn spawn_server_with_cancel(
+    thread_name: &'static str,
+    metrics_name: &'static str,
+    sockets: impl IntoIterator<Item = UdpSocket>,
+    keypair: &Keypair,
+    packet_sender: Sender<PacketBatch>,
+    staked_nodes: Arc<RwLock<StakedNodes>>,
+    quic_server_params: QuicServerParams,
+    cancel: CancellationToken,
 ) -> Result<SpawnServerResult, QuicServerError> {
     let runtime = rt(format!("{thread_name}Rt"), quic_server_params.num_threads);
     let result = {
         let _guard = runtime.enter();
-        crate::nonblocking::quic::spawn_server(
+        crate::nonblocking::quic::spawn_server_with_cancel(
             metrics_name,
             sockets,
             keypair,
             packet_sender,
-            exit,
             staked_nodes,
             quic_server_params,
+            cancel,
         )
     }?;
     let handle = thread::Builder::new()
@@ -721,59 +757,59 @@ mod test {
     fn setup_quic_server() -> (
         std::thread::JoinHandle<()>,
-        Arc<AtomicBool>,
         crossbeam_channel::Receiver<PacketBatch>,
         SocketAddr,
+        CancellationToken,
     ) {
         let s = bind_to_localhost_unique().expect("should bind");
-        let exit = Arc::new(AtomicBool::new(false));
         let (sender, receiver) = unbounded();
         let keypair = Keypair::new();
         let server_address = s.local_addr().unwrap();
         let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
+        let cancel = CancellationToken::new();
         let SpawnServerResult {
             endpoints: _,
             thread: t,
             key_updater: _,
-        } = spawn_server(
+        } = spawn_server_with_cancel(
             "solQuicTest",
             "quic_streamer_test",
             [s],
             &keypair,
             sender,
-            exit.clone(),
             staked_nodes,
             QuicServerParams::default_for_tests(),
+            cancel.clone(),
         )
         .unwrap();
-        (t, exit, receiver, server_address)
+        (t, receiver, server_address, cancel)
     }

     #[test]
     fn test_quic_server_exit() {
-        let (t, exit, _receiver, _server_address) = setup_quic_server();
-        exit.store(true, Ordering::Relaxed);
+        let (t, _receiver, _server_address, cancel) = setup_quic_server();
+        cancel.cancel();
         t.join().unwrap();
     }

     #[test]
     fn test_quic_timeout() {
         solana_logger::setup();
-        let (t, exit, receiver, server_address) = setup_quic_server();
+        let (t, receiver, server_address, cancel) = setup_quic_server();
         let runtime = rt_for_test();
         runtime.block_on(check_timeout(receiver, server_address));
-        exit.store(true, Ordering::Relaxed);
+        cancel.cancel();
         t.join().unwrap();
     }

     #[test]
     fn test_quic_server_block_multiple_connections() {
         solana_logger::setup();
-        let (t, exit, _receiver, server_address) = setup_quic_server();
+        let (t, _receiver, server_address, cancel) = setup_quic_server();
         let runtime = rt_for_test();
         runtime.block_on(check_block_multiple_connections(server_address));
-        exit.store(true, Ordering::Relaxed);
+        cancel.cancel();
         t.join().unwrap();
     }

@@ -781,44 +817,44 @@ mod test {
     fn test_quic_server_multiple_streams() {
         solana_logger::setup();
         let s = bind_to_localhost_unique().expect("should bind");
-        let exit = Arc::new(AtomicBool::new(false));
         let (sender, receiver) = unbounded();
         let keypair = Keypair::new();
         let server_address = s.local_addr().unwrap();
         let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
+        let cancel = CancellationToken::new();
         let SpawnServerResult {
             endpoints: _,
             thread: t,
             key_updater: _,
-        } = spawn_server(
+        } = spawn_server_with_cancel(
             "solQuicTest",
             "quic_streamer_test",
             [s],
             &keypair,
             sender,
-            exit.clone(),
             staked_nodes,
             QuicServerParams {
                 max_connections_per_peer: 2,
                 ..QuicServerParams::default_for_tests()
             },
+            cancel.clone(),
         )
         .unwrap();

         let runtime = rt_for_test();
         runtime.block_on(check_multiple_streams(receiver, server_address, None));
-        exit.store(true, Ordering::Relaxed);
+        cancel.cancel();
         t.join().unwrap();
     }

     #[test]
     fn test_quic_server_multiple_writes() {
         solana_logger::setup();
-        let (t, exit, receiver, server_address) = setup_quic_server();
+        let (t, receiver, server_address, cancel) = setup_quic_server();
         let runtime = rt_for_test();
         runtime.block_on(check_multiple_writes(receiver, server_address, None));
-        exit.store(true, Ordering::Relaxed);
+        cancel.cancel();
         t.join().unwrap();
     }

@@ -826,33 +862,33 @@ mod test {
     fn test_quic_server_unstaked_node_connect_failure() {
         solana_logger::setup();
         let s = bind_to_localhost_unique().expect("should bind");
-        let exit = Arc::new(AtomicBool::new(false));
         let (sender, _) = unbounded();
         let keypair = Keypair::new();
         let server_address = s.local_addr().unwrap();
         let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));
+        let cancel = CancellationToken::new();
         let SpawnServerResult {
             endpoints: _,
             thread: t,
             key_updater: _,
-        } = spawn_server(
+        } = spawn_server_with_cancel(
             "solQuicTest",
             "quic_streamer_test",
             [s],
             &keypair,
             sender,
-            exit.clone(),
             staked_nodes,
             QuicServerParams {
                 max_unstaked_connections: 0,
                 ..QuicServerParams::default_for_tests()
             },
+            cancel.clone(),
         )
         .unwrap();

         let runtime = rt_for_test();
         runtime.block_on(check_unstaked_node_connect_failure(server_address));
-        exit.store(true, Ordering::Relaxed);
+        cancel.cancel();
         t.join().unwrap();
     }
 }
diff --git a/tpu-client-next/tests/connection_workers_scheduler_test.rs b/tpu-client-next/tests/connection_workers_scheduler_test.rs
index 8a876fb17baf4a..10465f9ecd5b19 100644
--- a/tpu-client-next/tests/connection_workers_scheduler_test.rs
+++ b/tpu-client-next/tests/connection_workers_scheduler_test.rs
@@ -29,7 +29,7 @@ use {
         collections::HashMap,
         net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
         num::Saturating,
-        sync::{atomic::Ordering, Arc},
+        sync::Arc,
         time::Duration,
     },
     tokio::{
@@ -196,10 +196,10 @@ fn spawn_tx_sender(
 async fn test_basic_transactions_sending() {
     let SpawnTestServerResult {
         join_handle: server_handle,
-        exit,
         receiver,
         server_address,
         stats: _stats,
+        cancel,
     } = setup_quic_server(None, QuicServerParams::default_for_tests());

     // Setup sending txs
@@ -255,7 +255,7 @@ async fn test_basic_transactions_sending() {
     assert_eq!(stats.successfully_sent, expected_num_txs as u64,);

     // Stop server
-    exit.store(true, Ordering::Relaxed);
+    cancel.cancel();
     server_handle.await.unwrap();
 }

@@ -286,10 +286,10 @@ async fn count_received_packets_for(
 async fn test_connection_denied_until_allowed() {
     let SpawnTestServerResult {
         join_handle: server_handle,
-        exit,
         receiver,
         server_address,
         stats: _stats,
+        cancel,
     } = setup_quic_server(None, QuicServerParams::default_for_tests());

     // To prevent server from accepting a new connection, we use the following observation.
@@ -336,7 +336,7 @@ async fn test_connection_denied_until_allowed() {
     drop(throttling_connection);

     // Exit server
-    exit.store(true, Ordering::Relaxed);
+    cancel.cancel();
     server_handle.await.unwrap();
 }

@@ -347,10 +347,10 @@ async fn test_connection_denied_until_allowed() {
 async fn test_connection_pruned_and_reopened() {
     let SpawnTestServerResult {
         join_handle: server_handle,
-        exit,
         receiver,
         server_address,
         stats: _stats,
+        cancel,
     } = setup_quic_server(
         None,
         QuicServerParams {
@@ -389,7 +389,7 @@ async fn test_connection_pruned_and_reopened() {
     );

     // Exit server
-    exit.store(true, Ordering::Relaxed);
+    cancel.cancel();
     server_handle.await.unwrap();
 }

@@ -403,10 +403,10 @@ async fn test_staked_connection() {

     let SpawnTestServerResult {
         join_handle: server_handle,
-        exit,
         receiver,
         server_address,
         stats: _stats,
+        cancel,
     } = setup_quic_server(
         Some(staked_nodes),
         QuicServerParams {
@@ -447,7 +447,7 @@ async fn test_staked_connection() {
     );

     // Exit server
-    exit.store(true, Ordering::Relaxed);
+    cancel.cancel();
     server_handle.await.unwrap();
 }

@@ -458,10 +458,10 @@ async fn test_staked_connection() {
 async fn test_connection_throttling() {
     let SpawnTestServerResult {
         join_handle: server_handle,
-        exit,
         receiver,
         server_address,
         stats: _stats,
+        cancel,
     } = setup_quic_server(None, QuicServerParams::default_for_tests());

     // Setup sending txs
@@ -494,7 +494,7 @@ async fn test_connection_throttling() {
     );

     // Exit server
-    exit.store(true, Ordering::Relaxed);
+    cancel.cancel();
     server_handle.await.unwrap();
 }

@@ -549,10 +549,10 @@ async fn test_no_host() {
 async fn test_rate_limiting() {
     let SpawnTestServerResult {
         join_handle: server_handle,
-        exit,
         receiver,
         server_address,
         stats: _stats,
+        cancel,
     } = setup_quic_server(
         None,
         QuicServerParams {
@@ -600,7 +600,7 @@ async fn test_rate_limiting() {
     );

     // Stop the server.
-    exit.store(true, Ordering::Relaxed);
+    cancel.cancel();
     server_handle.await.unwrap();
 }

@@ -613,10 +613,10 @@ async fn test_rate_limiting() {
 async fn test_rate_limiting_establish_connection() {
     let SpawnTestServerResult {
         join_handle: server_handle,
-        exit,
         receiver,
         server_address,
         stats: _stats,
+        cancel,
     } = setup_quic_server(
         None,
         QuicServerParams {
@@ -675,7 +675,7 @@ async fn test_rate_limiting_establish_connection() {
     assert_eq!(stats, SendTransactionStatsNonAtomic::default());

     // Stop the server.
-    exit.store(true, Ordering::Relaxed);
+    cancel.cancel();
     server_handle.await.unwrap();
 }

@@ -694,10 +694,10 @@ async fn test_update_identity() {

     let SpawnTestServerResult {
         join_handle: server_handle,
-        exit,
         receiver,
         server_address,
         stats: _stats,
+        cancel,
     } = setup_quic_server(
         Some(staked_nodes),
         QuicServerParams {
@@ -746,20 +746,20 @@ async fn test_update_identity() {
     assert!(stats.successfully_sent > 0);

     // Exit server
-    exit.store(true, Ordering::Relaxed);
+    cancel.cancel();
     server_handle.await.unwrap();
 }

-// Test that connection close events are detected immediately via connection.closed()
-// monitoring, not only when send operations fail.
+// Test that connection close events are detected immediately via
+// connection.closed() monitoring, not only when send operations fail.
 #[tokio::test]
 async fn test_proactive_connection_close_detection() {
     let SpawnTestServerResult {
         join_handle: server_handle,
-        exit,
         receiver,
         server_address,
         stats: _stats,
+        cancel,
     } = setup_quic_server(
         None,
         QuicServerParams {
@@ -773,23 +773,15 @@ async fn test_proactive_connection_close_detection() {
     let tx_size = 1;
     let (tx_sender, tx_receiver) = channel(10);

-    let sender_task = tokio::spawn(async move {
-        // Send first transaction to establish connection
-        tx_sender
-            .send(TransactionBatch::new(vec![vec![1u8; tx_size]]))
-            .await
-            .expect("Send first batch");
-
-        // Idle period where connection might be closed
-        sleep(Duration::from_millis(500)).await;
-
-        // Attempt another send
-        drop(tx_sender.send(TransactionBatch::new(vec![vec![2u8; tx_size]])));
-    });
-
     let (scheduler_handle, _update_identity_sender, scheduler_cancel) =
         setup_connection_worker_scheduler(server_address, tx_receiver, None).await;

+    // Send first transaction to establish connection
+    tx_sender
+        .send(TransactionBatch::new(vec![vec![1u8; tx_size]]))
+        .await
+        .expect("Send first batch");
+
     // Verify first packet received
     let mut first_packet_received = false;
     let start = Instant::now();
@@ -804,15 +796,21 @@ async fn test_proactive_connection_close_detection() {
     }
     assert!(first_packet_received, "First packet should be received");

-    // Force connection close by exceeding max_connections_per_peer
-    let _pruning_connection = make_client_endpoint(&server_address, None).await;
+    // Exit server
+    cancel.cancel();
+    server_handle.await.unwrap();

-    // Allow time for proactive detection
-    sleep(Duration::from_millis(200)).await;
+    tx_sender
+        .send(TransactionBatch::new(vec![vec![2u8; tx_size]]))
+        .await
+        .expect("Send second batch");
+    tx_sender
+        .send(TransactionBatch::new(vec![vec![3u8; tx_size]]))
+        .await
+        .expect("Send third batch");

     // Clean up
     scheduler_cancel.cancel();
-    let _ = sender_task.await;
     let stats = join_scheduler(scheduler_handle).await;

     // Verify proactive close detection
@@ -820,8 +818,4 @@ async fn test_proactive_connection_close_detection() {
         stats.connection_error_application_closed > 0 || stats.write_error_connection_lost > 0,
         "Should detect connection close proactively. Stats: {stats:?}"
     );
-
-    // Exit server
-    exit.store(true, Ordering::Relaxed);
-    server_handle.await.unwrap();
 }
diff --git a/vortexor/Cargo.toml b/vortexor/Cargo.toml
index 6ea58785513d47..e2afbc9339cd5c 100644
--- a/vortexor/Cargo.toml
+++ b/vortexor/Cargo.toml
@@ -64,6 +64,7 @@ solana-transaction-metrics-tracker = { workspace = true }
 solana-version = { workspace = true }
 thiserror = { workspace = true }
 tokio = { workspace = true, features = ["full"] }
+tokio-util = { workspace = true }
 url = { workspace = true }
 x509-parser = { workspace = true }
diff --git a/vortexor/src/main.rs b/vortexor/src/main.rs
index d35cd499937471..e266efd0d41bbb 100644
--- a/vortexor/src/main.rs
+++ b/vortexor/src/main.rs
@@ -26,6 +26,7 @@ use {
         sync::{atomic::AtomicBool, Arc, RwLock},
         time::Duration,
     },
+    tokio_util::sync::CancellationToken,
 };

 const DEFAULT_CHANNEL_SIZE: usize = 100_000;
@@ -83,6 +84,7 @@ pub fn main() {
     let tpu_forward_address = args.tpu_forward_address;
     let max_streams_per_ms = args.max_streams_per_ms;
     let exit = Arc::new(AtomicBool::new(false));
+    let cancel = CancellationToken::new();
     // To be linked with the Tpu sigverify and forwarder service
     let (tpu_sender, tpu_receiver) = bounded(DEFAULT_CHANNEL_SIZE);
     let (tpu_fwd_sender, _tpu_fwd_receiver) = bounded(DEFAULT_CHANNEL_SIZE);
@@ -202,7 +204,7 @@ pub fn main() {
         max_connections_per_ipaddr_per_min,
         tpu_coalesce,
         &identity_keypair,
-        exit,
+        cancel.clone(),
     );

     vortexor.join().unwrap();
     sigverify_stage.join().unwrap();
diff --git a/vortexor/src/vortexor.rs b/vortexor/src/vortexor.rs
index b39d8f38315912..9aa8fb8156f9c4 100644
--- a/vortexor/src/vortexor.rs
+++ b/vortexor/src/vortexor.rs
@@ -12,15 +12,16 @@ use {
     solana_quic_definitions::NotifyKeyUpdate,
     solana_streamer::{
         nonblocking::quic::DEFAULT_WAIT_FOR_CHUNK_TIMEOUT,
-        quic::{spawn_server, EndpointKeyUpdater, QuicServerParams},
+        quic::{spawn_server_with_cancel, EndpointKeyUpdater, QuicServerParams},
         streamer::StakedNodes,
     },
     std::{
         net::{SocketAddr, UdpSocket},
-        sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
+        sync::{Arc, Mutex, RwLock},
         thread::{self, JoinHandle},
         time::Duration,
     },
+    tokio_util::sync::CancellationToken,
 };

 pub struct TpuSockets {
@@ -115,7 +116,7 @@ impl Vortexor {
         max_connections_per_ipaddr_per_min: u64,
         tpu_coalesce: Duration,
         identity_keypair: &Keypair,
-        exit: Arc<AtomicBool>,
+        cancel: CancellationToken,
     ) -> Self {
         let mut quic_server_params = QuicServerParams {
             max_connections_per_peer,
@@ -133,15 +134,15 @@ impl Vortexor {
             tpu_quic_fwd,
         } = tpu_sockets;

-        let tpu_result = spawn_server(
+        let tpu_result = spawn_server_with_cancel(
             "solVtxTpu",
             "quic_vortexor_tpu",
             tpu_quic,
             identity_keypair,
             tpu_sender.clone(),
-            exit.clone(),
             staked_nodes.clone(),
             quic_server_params.clone(),
+            cancel.clone(),
         )
         .unwrap();

@@ -149,15 +150,15 @@ impl Vortexor {
         // for staked connections:
         quic_server_params.max_staked_connections = max_fwd_staked_connections;
         quic_server_params.max_unstaked_connections = max_fwd_unstaked_connections;
-        let tpu_fwd_result = spawn_server(
+        let tpu_fwd_result = spawn_server_with_cancel(
             "solVtxTpuFwd",
             "quic_vortexor_tpu_forwards",
             tpu_quic_fwd,
             identity_keypair,
             tpu_fwd_sender,
-            exit.clone(),
             staked_nodes.clone(),
             quic_server_params,
+            cancel.clone(),
         )
         .unwrap();
diff --git a/vortexor/tests/vortexor.rs b/vortexor/tests/vortexor.rs
index 6a18f4788e589e..0b5437f66b30a8 100644
--- a/vortexor/tests/vortexor.rs
+++ b/vortexor/tests/vortexor.rs
@@ -33,6 +33,7 @@ use {
         },
         time::Duration,
     },
+    tokio_util::sync::CancellationToken,
     url::Url,
 };

@@ -42,7 +43,7 @@ async fn test_vortexor() {
     let bind_address = solana_net_utils::parse_host("127.0.0.1").expect("invalid bind_address");
     let keypair = Keypair::new();
-    let exit = Arc::new(AtomicBool::new(false));
+    let cancel = CancellationToken::new();

     let (tpu_sender, tpu_receiver) = unbounded();
     let (tpu_fwd_sender, tpu_fwd_receiver) = unbounded();
@@ -77,13 +78,13 @@ async fn test_vortexor() {
         DEFAULT_MAX_CONNECTIONS_PER_IPADDR_PER_MINUTE,
         DEFAULT_TPU_COALESCE,
         &keypair,
-        exit.clone(),
+        cancel.clone(),
     );

     check_multiple_streams(tpu_receiver, tpu_address, Some(&keypair)).await;
     check_multiple_streams(tpu_fwd_receiver, tpu_fwd_address, Some(&keypair)).await;

-    exit.store(true, Ordering::Relaxed);
+    cancel.cancel();
     vortexor.join().unwrap();
 }
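Editor's note for downstream consumers of solana-streamer: the migration this diff implies is mechanical. Create a CancellationToken, pass it as the new trailing parameter of spawn_server_with_cancel in place of the removed exit: Arc<AtomicBool> argument, and call cancel() where you previously did exit.store(true, Ordering::Relaxed). The following is a minimal sketch derived only from the API shown above; the solana_keypair::Keypair import path and the discarded packet receiver are illustrative assumptions, not part of this diff.

use {
    crossbeam_channel::unbounded,
    solana_keypair::Keypair, // assumed import path; adjust to your crate
    solana_streamer::{
        quic::{spawn_server_with_cancel, QuicServerParams, SpawnServerResult},
        streamer::StakedNodes,
    },
    std::{
        net::UdpSocket,
        sync::{Arc, RwLock},
    },
    tokio_util::sync::CancellationToken,
};

fn main() {
    let socket = UdpSocket::bind("127.0.0.1:0").expect("should bind");
    // The receiver side would normally feed sigverify; it is unused here.
    let (packet_sender, _packet_receiver) = unbounded();
    let staked_nodes = Arc::new(RwLock::new(StakedNodes::default()));

    // One token can be shared across several servers (TPU, forwards, vote);
    // the streamer derives per-connection child tokens from it internally.
    let cancel = CancellationToken::new();

    let SpawnServerResult { thread, .. } = spawn_server_with_cancel(
        "solQuicDemo",
        "quic_streamer_demo",
        [socket],
        &Keypair::new(),
        packet_sender,
        staked_nodes,
        QuicServerParams::default(),
        cancel.clone(),
    )
    .unwrap();

    // Shutdown: replaces the old `exit.store(true, Ordering::Relaxed)`.
    cancel.cancel();
    thread.join().unwrap();
}

Callers that cannot switch yet can keep using the deprecated spawn_server, which now bridges the flag to a token by polling exit every 100 ms on a background thread/task; preferring the token API avoids that extra polling loop and gives prompt, hierarchical shutdown.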