diff --git a/net-utils/src/ip_echo_client.rs b/net-utils/src/ip_echo_client.rs index 1ec2def4af7..43d0bf99b23 100644 --- a/net-utils/src/ip_echo_client.rs +++ b/net-utils/src/ip_echo_client.rs @@ -8,7 +8,7 @@ use { itertools::Itertools, log::*, std::{ - collections::{BTreeMap, HashSet}, + collections::{BTreeMap, HashMap, HashSet}, net::{IpAddr, SocketAddr, TcpListener, TcpStream, UdpSocket}, sync::{Arc, RwLock}, time::{Duration, Instant}, @@ -219,14 +219,13 @@ pub(crate) async fn verify_all_reachable_tcp( ok } -/// Checks if all of the provided UDP ports are reachable by the machine at -/// `ip_echo_server_addr`. +/// Checks if all of the provided UDP ports on all of the provided IPs are +/// reachable by the machine at `ip_echo_server_addr`. /// This function will test a few ports at a time, retrying if necessary. /// Tests must complete within timeout provided, so a longer timeout may be /// necessary if checking many ports. /// A given amount of retries will be made to accommodate packet loss. /// This function may panic. -/// This function assumes that all sockets are bound to the same IP. pub(crate) async fn verify_all_reachable_udp( ip_echo_server_addr: SocketAddr, sockets: &[&UdpSocket], @@ -237,119 +236,107 @@ pub(crate) async fn verify_all_reachable_udp( warn!("No ports provided for verify_all_reachable_udp to check"); return true; } - // Extract the bind_address for requests from the first socket, it should be same for all others too - let bind_address = sockets[0] - .local_addr() - .expect("Sockets should be bound") - .ip(); - // This function may get fed multiple sockets bound to the same port. - // In such case we need to know which sockets are bound to each port, - // as only one of them will receive a packet from echo server - let mut ports_to_socks_map: BTreeMap<_, _> = BTreeMap::new(); + let mut ip_to_ports: HashMap>> = HashMap::new(); for &socket in sockets.iter() { - let local_binding = socket.local_addr().expect("Sockets should be bound"); - assert_eq!( - local_binding.ip(), - bind_address, - "All sockets should be bound to the same IP" - ); - let port = local_binding.port(); - ports_to_socks_map - .entry(port) - .or_insert_with(Vec::new) + let local_addr = socket.local_addr().expect("Socket must be bound"); + ip_to_ports + .entry(local_addr.ip()) + .or_default() + .entry(local_addr.port()) + .or_default() .push(socket); } + for (bind_ip, ports_to_socks_map) in ip_to_ports { + let ports: Vec = ports_to_socks_map.keys().copied().collect(); - let ports: Vec<_> = ports_to_socks_map.into_iter().collect(); - - info!( - "Checking that udp ports {:?} are reachable from {:?}", - ports.iter().map(|(port, _)| port).collect::>(), - ip_echo_server_addr - ); + info!( + "Checking that udp ports {:?} are reachable from bind IP {:?}", + ports, bind_ip + ); - 'outer: for chunk_to_check in ports.chunks(MAX_PORT_COUNT_PER_MESSAGE) { - let ports_to_check = chunk_to_check - .iter() - .map(|(port, _)| *port) - .collect::>(); + 'outer: for chunk_to_check in ports.chunks(MAX_PORT_COUNT_PER_MESSAGE) { + let ports_to_check = chunk_to_check.to_vec(); - for attempt in 0..retry_count { - if attempt > 0 { - error!("There are some udp ports with no response!! Retrying..."); - } - // clone off the sockets that use ports within our chunk - let sockets_to_check = chunk_to_check.iter().flat_map(|(_, sockets)| { - sockets + for attempt in 0..retry_count { + if attempt > 0 { + error!("There are some udp ports with no response!! Retrying..."); + } + // clone off the sockets that use ports within our chunk + let sockets_to_check: Vec = ports_to_check .iter() - .map(|&s| s.try_clone().expect("Unable to clone udp socket")) - }); - - let _ = ip_echo_server_request_with_binding( - ip_echo_server_addr, - IpEchoServerMessage::new(&[], &ports_to_check), - bind_address, - ) - .await - .map_err(|err| warn!("ip_echo_server request failed: {}", err)); - - let reachable_ports = Arc::new(RwLock::new(HashSet::new())); - // Spawn threads for each socket to check - let mut checkers = JoinSet::new(); - for socket in sockets_to_check { - let port = socket.local_addr().expect("Socket should be bound").port(); - let reachable_ports = reachable_ports.clone(); - - // Use blocking API since we have no idea if sockets given to us are nonblocking or not - checkers.spawn_blocking(move || { - let start = Instant::now(); - - let original_read_timeout = socket.read_timeout().unwrap(); - socket - .set_read_timeout(Some(Duration::from_millis(250))) - .unwrap(); - - loop { - if reachable_ports.read().unwrap().contains(&port) - || Instant::now().duration_since(start) >= timeout - { - break; + .flat_map(|port| ports_to_socks_map.get(port).unwrap()) + .map(|&s| s.try_clone().expect("Unable to clone UDP socket")) + .collect(); + + let _ = ip_echo_server_request_with_binding( + ip_echo_server_addr, + IpEchoServerMessage::new(&[], &ports_to_check), + bind_ip, + ) + .await + .map_err(|err| warn!("ip_echo_server request failed: {}", err)); + + let reachable_ports = Arc::new(RwLock::new(HashSet::new())); + // Spawn threads for each socket to check + let mut checkers = JoinSet::new(); + for socket in sockets_to_check { + let port = socket.local_addr().expect("Socket should be bound").port(); + let reachable_ports = reachable_ports.clone(); + + checkers.spawn_blocking(move || { + let start = Instant::now(); + + let original_read_timeout = socket.read_timeout().unwrap(); + socket + .set_read_timeout(Some(Duration::from_millis(250))) + .unwrap(); + + loop { + if reachable_ports.read().unwrap().contains(&port) + || Instant::now().duration_since(start) >= timeout + { + break; + } + + let recv_result = socket.recv(&mut [0; 1]); + debug!( + "Waited for incoming datagram on udp/{}: {:?}", + port, recv_result + ); + + if recv_result.is_ok() { + reachable_ports.write().unwrap().insert(port); + break; + } } - let recv_result = socket.recv(&mut [0; 1]); - debug!( - "Waited for incoming datagram on udp/{}: {:?}", - port, recv_result - ); - - if recv_result.is_ok() { - reachable_ports.write().unwrap().insert(port); - break; - } - } - socket.set_read_timeout(original_read_timeout).unwrap(); - }); - } + socket.set_read_timeout(original_read_timeout).unwrap(); + }); + } - while let Some(r) = checkers.join_next().await { - r.expect("Threads should exit cleanly"); + while let Some(r) = checkers.join_next().await { + r.expect("Threads should exit cleanly"); + } + // Might have lost a UDP packet, check that all ports were reached + let reachable_ports = Arc::into_inner(reachable_ports) + .expect("Single owner expected") + .into_inner() + .expect("No threads should hold the lock"); + info!( + "checked udp ports: {:?}, reachable udp ports: {:?}", + ports_to_check, reachable_ports + ); + if reachable_ports.len() == ports_to_check.len() { + continue 'outer; // starts checking next chunk of ports, if any + } } - // Might have lost a UDP packet, check that all ports were reached - let reachable_ports = Arc::into_inner(reachable_ports) - .expect("Single owner expected") - .into_inner() - .expect("No threads should hold the lock"); - info!( - "checked udp ports: {:?}, reachable udp ports: {:?}", - ports_to_check, reachable_ports + error!( + "Maximum retry count reached. Some ports for IP {} unreachable.", + bind_ip ); - if reachable_ports.len() == ports_to_check.len() { - continue 'outer; // starts checking next chunk of ports, if any - } + return false; } - error!("Maximum retry count is reached...."); - return false; } true }