diff --git a/src/client/connect/dns.rs b/src/client/connect/dns.rs index 17d3c7df67..e50382ab3a 100644 --- a/src/client/connect/dns.rs +++ b/src/client/connect/dns.rs @@ -194,17 +194,25 @@ impl IpAddrs { None } - pub(super) fn split_by_preference(self) -> (IpAddrs, IpAddrs) { - let preferring_v6 = self.iter - .as_slice() - .first() - .map(SocketAddr::is_ipv6) - .unwrap_or(false); - - let (preferred, fallback) = self.iter - .partition::, _>(|addr| addr.is_ipv6() == preferring_v6); - - (IpAddrs::new(preferred), IpAddrs::new(fallback)) + pub(super) fn split_by_preference(self, local_addr: Option) -> (IpAddrs, IpAddrs) { + if let Some(local_addr) = local_addr { + let preferred = self.iter + .filter(|addr| addr.is_ipv6() == local_addr.is_ipv6()) + .collect(); + + (IpAddrs::new(preferred), IpAddrs::new(vec![])) + } else { + let preferring_v6 = self.iter + .as_slice() + .first() + .map(SocketAddr::is_ipv6) + .unwrap_or(false); + + let (preferred, fallback) = self.iter + .partition::, _>(|addr| addr.is_ipv6() == preferring_v6); + + (IpAddrs::new(preferred), IpAddrs::new(fallback)) + } } pub(super) fn is_empty(&self) -> bool { @@ -325,14 +333,24 @@ mod tests { let v6_addr = (Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 80).into(); let (mut preferred, mut fallback) = - IpAddrs { iter: vec![v4_addr, v6_addr].into_iter() }.split_by_preference(); + IpAddrs { iter: vec![v4_addr, v6_addr].into_iter() }.split_by_preference(None); assert!(preferred.next().unwrap().is_ipv4()); assert!(fallback.next().unwrap().is_ipv6()); let (mut preferred, mut fallback) = - IpAddrs { iter: vec![v6_addr, v4_addr].into_iter() }.split_by_preference(); + IpAddrs { iter: vec![v6_addr, v4_addr].into_iter() }.split_by_preference(None); assert!(preferred.next().unwrap().is_ipv6()); assert!(fallback.next().unwrap().is_ipv4()); + + let (mut preferred, fallback) = + IpAddrs { iter: vec![v4_addr, v6_addr].into_iter() }.split_by_preference(Some(v4_addr.ip())); + assert!(preferred.next().unwrap().is_ipv4()); + assert!(fallback.is_empty()); + + let (mut preferred, fallback) = + IpAddrs { iter: vec![v4_addr, v6_addr].into_iter() }.split_by_preference(Some(v6_addr.ip())); + assert!(preferred.next().unwrap().is_ipv6()); + assert!(fallback.is_empty()); } #[test] diff --git a/src/client/connect/http.rs b/src/client/connect/http.rs index 7a4c9cccc7..336945eb26 100644 --- a/src/client/connect/http.rs +++ b/src/client/connect/http.rs @@ -501,7 +501,7 @@ impl ConnectingTcp { reuse_address: bool, ) -> ConnectingTcp { if let Some(fallback_timeout) = fallback_timeout { - let (preferred_addrs, fallback_addrs) = remote_addrs.split_by_preference(); + let (preferred_addrs, fallback_addrs) = remote_addrs.split_by_preference(local_addr); if fallback_addrs.is_empty() { return ConnectingTcp { local_addr,