Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 90 additions & 103 deletions net-utils/src/ip_echo_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use {
itertools::Itertools,
log::*,
std::{
collections::{BTreeMap, HashSet},
collections::{BTreeMap, HashMap, HashSet},
net::{IpAddr, SocketAddr, TcpListener, TcpStream, UdpSocket},
sync::{Arc, RwLock},
time::{Duration, Instant},
Expand Down Expand Up @@ -219,14 +219,13 @@ pub(crate) async fn verify_all_reachable_tcp(
ok
}

/// Checks if all of the provided UDP ports are reachable by the machine at
/// `ip_echo_server_addr`.
/// Checks if all of the provided UDP ports on all of the provided IPs are
/// reachable by the machine at `ip_echo_server_addr`.
/// This function will test a few ports at a time, retrying if necessary.
/// Tests must complete within timeout provided, so a longer timeout may be
/// necessary if checking many ports.
/// A given amount of retries will be made to accommodate packet loss.
/// This function may panic.
/// This function assumes that all sockets are bound to the same IP.
pub(crate) async fn verify_all_reachable_udp(
ip_echo_server_addr: SocketAddr,
sockets: &[&UdpSocket],
Expand All @@ -237,119 +236,107 @@ pub(crate) async fn verify_all_reachable_udp(
warn!("No ports provided for verify_all_reachable_udp to check");
return true;
}
// Extract the bind_address for requests from the first socket, it should be same for all others too
let bind_address = sockets[0]
.local_addr()
.expect("Sockets should be bound")
.ip();
// This function may get fed multiple sockets bound to the same port.
// In such case we need to know which sockets are bound to each port,
// as only one of them will receive a packet from echo server
let mut ports_to_socks_map: BTreeMap<_, _> = BTreeMap::new();
let mut ip_to_ports: HashMap<IpAddr, BTreeMap<u16, Vec<&UdpSocket>>> = HashMap::new();
for &socket in sockets.iter() {
let local_binding = socket.local_addr().expect("Sockets should be bound");
assert_eq!(
local_binding.ip(),
bind_address,
"All sockets should be bound to the same IP"
);
let port = local_binding.port();
ports_to_socks_map
.entry(port)
.or_insert_with(Vec::new)
let local_addr = socket.local_addr().expect("Socket must be bound");
ip_to_ports
.entry(local_addr.ip())
.or_default()
.entry(local_addr.port())
.or_default()
.push(socket);
}
for (bind_ip, ports_to_socks_map) in ip_to_ports {
let ports: Vec<u16> = ports_to_socks_map.keys().copied().collect();

let ports: Vec<_> = ports_to_socks_map.into_iter().collect();

info!(
"Checking that udp ports {:?} are reachable from {:?}",
ports.iter().map(|(port, _)| port).collect::<Vec<_>>(),
ip_echo_server_addr
);
info!(
"Checking that udp ports {:?} are reachable from bind IP {:?}",
ports, bind_ip
);

'outer: for chunk_to_check in ports.chunks(MAX_PORT_COUNT_PER_MESSAGE) {
let ports_to_check = chunk_to_check
.iter()
.map(|(port, _)| *port)
.collect::<Vec<_>>();
'outer: for chunk_to_check in ports.chunks(MAX_PORT_COUNT_PER_MESSAGE) {
let ports_to_check = chunk_to_check.to_vec();

for attempt in 0..retry_count {
if attempt > 0 {
error!("There are some udp ports with no response!! Retrying...");
}
// clone off the sockets that use ports within our chunk
let sockets_to_check = chunk_to_check.iter().flat_map(|(_, sockets)| {
sockets
for attempt in 0..retry_count {
if attempt > 0 {
error!("There are some udp ports with no response!! Retrying...");
}
// clone off the sockets that use ports within our chunk
let sockets_to_check: Vec<UdpSocket> = ports_to_check
.iter()
.map(|&s| s.try_clone().expect("Unable to clone udp socket"))
});

let _ = ip_echo_server_request_with_binding(
ip_echo_server_addr,
IpEchoServerMessage::new(&[], &ports_to_check),
bind_address,
)
.await
.map_err(|err| warn!("ip_echo_server request failed: {}", err));

let reachable_ports = Arc::new(RwLock::new(HashSet::new()));
// Spawn threads for each socket to check
let mut checkers = JoinSet::new();
for socket in sockets_to_check {
let port = socket.local_addr().expect("Socket should be bound").port();
let reachable_ports = reachable_ports.clone();

// Use blocking API since we have no idea if sockets given to us are nonblocking or not
checkers.spawn_blocking(move || {
let start = Instant::now();

let original_read_timeout = socket.read_timeout().unwrap();
socket
.set_read_timeout(Some(Duration::from_millis(250)))
.unwrap();

loop {
if reachable_ports.read().unwrap().contains(&port)
|| Instant::now().duration_since(start) >= timeout
{
break;
.flat_map(|port| ports_to_socks_map.get(port).unwrap())
.map(|&s| s.try_clone().expect("Unable to clone UDP socket"))
.collect();

let _ = ip_echo_server_request_with_binding(
ip_echo_server_addr,
IpEchoServerMessage::new(&[], &ports_to_check),
bind_ip,
)
.await
.map_err(|err| warn!("ip_echo_server request failed: {}", err));

let reachable_ports = Arc::new(RwLock::new(HashSet::new()));
// Spawn threads for each socket to check
let mut checkers = JoinSet::new();
for socket in sockets_to_check {
let port = socket.local_addr().expect("Socket should be bound").port();
let reachable_ports = reachable_ports.clone();

checkers.spawn_blocking(move || {
let start = Instant::now();

let original_read_timeout = socket.read_timeout().unwrap();
socket
.set_read_timeout(Some(Duration::from_millis(250)))
.unwrap();

loop {
if reachable_ports.read().unwrap().contains(&port)
|| Instant::now().duration_since(start) >= timeout
{
break;
}

let recv_result = socket.recv(&mut [0; 1]);
debug!(
"Waited for incoming datagram on udp/{}: {:?}",
port, recv_result
);

if recv_result.is_ok() {
reachable_ports.write().unwrap().insert(port);
break;
}
}

let recv_result = socket.recv(&mut [0; 1]);
debug!(
"Waited for incoming datagram on udp/{}: {:?}",
port, recv_result
);

if recv_result.is_ok() {
reachable_ports.write().unwrap().insert(port);
break;
}
}
socket.set_read_timeout(original_read_timeout).unwrap();
});
}
socket.set_read_timeout(original_read_timeout).unwrap();
});
}

while let Some(r) = checkers.join_next().await {
r.expect("Threads should exit cleanly");
while let Some(r) = checkers.join_next().await {
r.expect("Threads should exit cleanly");
}
// Might have lost a UDP packet, check that all ports were reached
let reachable_ports = Arc::into_inner(reachable_ports)
.expect("Single owner expected")
.into_inner()
.expect("No threads should hold the lock");
info!(
"checked udp ports: {:?}, reachable udp ports: {:?}",
ports_to_check, reachable_ports
);
if reachable_ports.len() == ports_to_check.len() {
continue 'outer; // starts checking next chunk of ports, if any
}
}

// Might have lost a UDP packet, check that all ports were reached
let reachable_ports = Arc::into_inner(reachable_ports)
.expect("Single owner expected")
.into_inner()
.expect("No threads should hold the lock");
info!(
"checked udp ports: {:?}, reachable udp ports: {:?}",
ports_to_check, reachable_ports
error!(
"Maximum retry count reached. Some ports for IP {} unreachable.",
bind_ip
);
if reachable_ports.len() == ports_to_check.len() {
continue 'outer; // starts checking next chunk of ports, if any
}
return false;
}
error!("Maximum retry count is reached....");
return false;
}
true
}
Loading