From fb19f3a86997af1c8a31a7d5ce6f2b018c9b5a0d Mon Sep 17 00:00:00 2001 From: Ivan Nikulin Date: Wed, 14 Oct 2020 00:02:16 +0100 Subject: [PATCH] feat(client): add `HttpConnector::set_local_addresses` to set both IPv6 and IPv4 local addrs (#2172) Currently HttpConnector::set_local_address method accepts a single argument. Server might not support IPv6 or IPv4. Therefore, the only solution at the moment is to manually perform DNS resolution and pick appropriate local address family. This is inefficient, as leads to 2 DNS lookups per request. This commit allows specifying both IPv4 and IPv6, so connector can decide which one to use based on DNS resolution results. --- Cargo.toml | 3 + src/client/connect/dns.rs | 76 +++++++++------ src/client/connect/http.rs | 188 +++++++++++++++++++++++++++++++------ 3 files changed, 211 insertions(+), 56 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 245aa277f8..fcde338e71 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,6 +55,9 @@ tokio-util = { version = "0.3", features = ["codec"] } tower-util = "0.3" url = "1.0" +[target.'cfg(any(target_os = "linux", target_os = "macos"))'.dev-dependencies] +pnet = "0.25.0" + [features] default = [ "runtime", diff --git a/src/client/connect/dns.rs b/src/client/connect/dns.rs index 613579174a..39ffc71e34 100644 --- a/src/client/connect/dns.rs +++ b/src/client/connect/dns.rs @@ -200,27 +200,33 @@ impl IpAddrs { None } - pub(super) fn split_by_preference(self, local_addr: Option) -> (IpAddrs, IpAddrs) { - if let Some(local_addr) = local_addr { - let preferred = self - .iter - .filter(|addr| addr.is_ipv6() == local_addr.is_ipv6()) - .collect(); - - (IpAddrs::new(preferred), IpAddrs::new(vec![])) - } else { - let preferring_v6 = self - .iter - .as_slice() - .first() - .map(SocketAddr::is_ipv6) - .unwrap_or(false); - - let (preferred, fallback) = self - .iter - .partition::, _>(|addr| addr.is_ipv6() == preferring_v6); - - (IpAddrs::new(preferred), IpAddrs::new(fallback)) + #[inline] + fn filter(self, predicate: impl FnMut(&SocketAddr) -> bool) -> IpAddrs { + IpAddrs::new(self.iter.filter(predicate).collect()) + } + + pub(super) fn split_by_preference( + self, + local_addr_ipv4: Option, + local_addr_ipv6: Option, + ) -> (IpAddrs, IpAddrs) { + match (local_addr_ipv4, local_addr_ipv6) { + (Some(_), None) => (self.filter(SocketAddr::is_ipv4), IpAddrs::new(vec![])), + (None, Some(_)) => (self.filter(SocketAddr::is_ipv6), IpAddrs::new(vec![])), + _ => { + let preferring_v6 = self + .iter + .as_slice() + .first() + .map(SocketAddr::is_ipv6) + .unwrap_or(false); + + let (preferred, fallback) = self + .iter + .partition::, _>(|addr| addr.is_ipv6() == preferring_v6); + + (IpAddrs::new(preferred), IpAddrs::new(fallback)) + } } } @@ -355,34 +361,50 @@ mod tests { #[test] fn test_ip_addrs_split_by_preference() { - let v4_addr = (Ipv4Addr::new(127, 0, 0, 1), 80).into(); - let v6_addr = (Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1), 80).into(); + let ip_v4 = Ipv4Addr::new(127, 0, 0, 1); + let ip_v6 = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1); + let v4_addr = (ip_v4, 80).into(); + let v6_addr = (ip_v6, 80).into(); + + let (mut preferred, mut fallback) = IpAddrs { + iter: vec![v4_addr, v6_addr].into_iter(), + } + .split_by_preference(None, None); + assert!(preferred.next().unwrap().is_ipv4()); + assert!(fallback.next().unwrap().is_ipv6()); + + let (mut preferred, mut fallback) = IpAddrs { + iter: vec![v6_addr, v4_addr].into_iter(), + } + .split_by_preference(None, None); + assert!(preferred.next().unwrap().is_ipv6()); + assert!(fallback.next().unwrap().is_ipv4()); let (mut preferred, mut fallback) = IpAddrs { iter: vec![v4_addr, v6_addr].into_iter(), } - .split_by_preference(None); + .split_by_preference(Some(ip_v4), Some(ip_v6)); assert!(preferred.next().unwrap().is_ipv4()); assert!(fallback.next().unwrap().is_ipv6()); let (mut preferred, mut fallback) = IpAddrs { iter: vec![v6_addr, v4_addr].into_iter(), } - .split_by_preference(None); + .split_by_preference(Some(ip_v4), Some(ip_v6)); assert!(preferred.next().unwrap().is_ipv6()); assert!(fallback.next().unwrap().is_ipv4()); let (mut preferred, fallback) = IpAddrs { iter: vec![v4_addr, v6_addr].into_iter(), } - .split_by_preference(Some(v4_addr.ip())); + .split_by_preference(Some(ip_v4), None); assert!(preferred.next().unwrap().is_ipv4()); assert!(fallback.is_empty()); let (mut preferred, fallback) = IpAddrs { iter: vec![v4_addr, v6_addr].into_iter(), } - .split_by_preference(Some(v6_addr.ip())); + .split_by_preference(None, Some(ip_v6)); assert!(preferred.next().unwrap().is_ipv6()); assert!(fallback.is_empty()); } diff --git a/src/client/connect/http.rs b/src/client/connect/http.rs index d61dce3a6a..c1cdf4e129 100644 --- a/src/client/connect/http.rs +++ b/src/client/connect/http.rs @@ -3,7 +3,7 @@ use std::fmt; use std::future::Future; use std::io; use std::marker::PhantomData; -use std::net::{IpAddr, SocketAddr}; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; use std::pin::Pin; use std::sync::Arc; use std::task::{self, Poll}; @@ -72,7 +72,8 @@ struct Config { enforce_http: bool, happy_eyeballs_timeout: Option, keep_alive_timeout: Option, - local_address: Option, + local_address_ipv4: Option, + local_address_ipv6: Option, nodelay: bool, reuse_address: bool, send_buffer_size: Option, @@ -111,7 +112,8 @@ impl HttpConnector { enforce_http: true, happy_eyeballs_timeout: Some(Duration::from_millis(300)), keep_alive_timeout: None, - local_address: None, + local_address_ipv4: None, + local_address_ipv6: None, nodelay: false, reuse_address: false, send_buffer_size: None, @@ -166,7 +168,26 @@ impl HttpConnector { /// Default is `None`. #[inline] pub fn set_local_address(&mut self, addr: Option) { - self.config_mut().local_address = addr; + let (v4, v6) = match addr { + Some(IpAddr::V4(a)) => (Some(a), None), + Some(IpAddr::V6(a)) => (None, Some(a)), + _ => (None, None), + }; + + let cfg = self.config_mut(); + + cfg.local_address_ipv4 = v4; + cfg.local_address_ipv6 = v6; + } + + /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's + /// preferences) before connection. + #[inline] + pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) { + let cfg = self.config_mut(); + + cfg.local_address_ipv4 = Some(addr_ipv4); + cfg.local_address_ipv6 = Some(addr_ipv6); } /// Set the connect timeout. @@ -311,7 +332,8 @@ where }; let c = ConnectingTcp::new( - config.local_address, + config.local_address_ipv4, + config.local_address_ipv6, addrs, config.connect_timeout, config.happy_eyeballs_timeout, @@ -454,7 +476,8 @@ impl StdError for ConnectError { } struct ConnectingTcp { - local_addr: Option, + local_addr_ipv4: Option, + local_addr_ipv6: Option, preferred: ConnectingTcpRemote, fallback: Option, reuse_address: bool, @@ -462,17 +485,20 @@ struct ConnectingTcp { impl ConnectingTcp { fn new( - local_addr: Option, + local_addr_ipv4: Option, + local_addr_ipv6: Option, remote_addrs: dns::IpAddrs, connect_timeout: Option, fallback_timeout: Option, reuse_address: bool, ) -> ConnectingTcp { if let Some(fallback_timeout) = fallback_timeout { - let (preferred_addrs, fallback_addrs) = remote_addrs.split_by_preference(local_addr); + let (preferred_addrs, fallback_addrs) = + remote_addrs.split_by_preference(local_addr_ipv4, local_addr_ipv6); if fallback_addrs.is_empty() { return ConnectingTcp { - local_addr, + local_addr_ipv4, + local_addr_ipv6, preferred: ConnectingTcpRemote::new(preferred_addrs, connect_timeout), fallback: None, reuse_address, @@ -480,7 +506,8 @@ impl ConnectingTcp { } ConnectingTcp { - local_addr, + local_addr_ipv4, + local_addr_ipv6, preferred: ConnectingTcpRemote::new(preferred_addrs, connect_timeout), fallback: Some(ConnectingTcpFallback { delay: tokio::time::delay_for(fallback_timeout), @@ -490,7 +517,8 @@ impl ConnectingTcp { } } else { ConnectingTcp { - local_addr, + local_addr_ipv4, + local_addr_ipv6, preferred: ConnectingTcpRemote::new(remote_addrs, connect_timeout), fallback: None, reuse_address, @@ -523,13 +551,22 @@ impl ConnectingTcpRemote { impl ConnectingTcpRemote { async fn connect( &mut self, - local_addr: &Option, + local_addr_ipv4: &Option, + local_addr_ipv6: &Option, reuse_address: bool, ) -> io::Result { let mut err = None; for addr in &mut self.addrs { debug!("connecting to {}", addr); - match connect(&addr, local_addr, reuse_address, self.connect_timeout)?.await { + match connect( + &addr, + local_addr_ipv4, + local_addr_ipv6, + reuse_address, + self.connect_timeout, + )? + .await + { Ok(tcp) => { debug!("connected to {}", addr); return Ok(tcp); @@ -551,9 +588,38 @@ impl ConnectingTcpRemote { } } +fn bind_local_address( + socket: &socket2::Socket, + dst_addr: &SocketAddr, + local_addr_ipv4: &Option, + local_addr_ipv6: &Option, +) -> io::Result<()> { + match (*dst_addr, local_addr_ipv4, local_addr_ipv6) { + (SocketAddr::V4(_), Some(addr), _) => { + socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?; + } + (SocketAddr::V6(_), _, Some(addr)) => { + socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?; + } + _ => { + if cfg!(windows) { + // Windows requires a socket be bound before calling connect + let any: SocketAddr = match *dst_addr { + SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(), + SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(), + }; + socket.bind(&any.into())?; + } + } + } + + Ok(()) +} + fn connect( addr: &SocketAddr, - local_addr: &Option, + local_addr_ipv4: &Option, + local_addr_ipv6: &Option, reuse_address: bool, connect_timeout: Option, ) -> io::Result>> { @@ -568,17 +634,7 @@ fn connect( socket.set_reuse_address(true)?; } - if let Some(ref local_addr) = *local_addr { - // Caller has requested this socket be bound before calling connect - socket.bind(&SocketAddr::new(local_addr.clone(), 0).into())?; - } else if cfg!(windows) { - // Windows requires a socket be bound before calling connect - let any: SocketAddr = match *addr { - SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(), - SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(), - }; - socket.bind(&any.into())?; - } + bind_local_address(&socket, addr, local_addr_ipv4, local_addr_ipv6)?; let addr = *addr; @@ -600,17 +656,27 @@ fn connect( impl ConnectingTcp { async fn connect(mut self) -> io::Result { let Self { - ref local_addr, + ref local_addr_ipv4, + ref local_addr_ipv6, reuse_address, .. } = self; match self.fallback { - None => self.preferred.connect(local_addr, reuse_address).await, + None => { + self.preferred + .connect(local_addr_ipv4, local_addr_ipv6, reuse_address) + .await + } Some(mut fallback) => { - let preferred_fut = self.preferred.connect(local_addr, reuse_address); + let preferred_fut = + self.preferred + .connect(local_addr_ipv4, local_addr_ipv6, reuse_address); futures_util::pin_mut!(preferred_fut); - let fallback_fut = fallback.remote.connect(local_addr, reuse_address); + let fallback_fut = + fallback + .remote + .connect(local_addr_ipv4, local_addr_ipv6, reuse_address); futures_util::pin_mut!(fallback_fut); let (result, future) = @@ -666,6 +732,32 @@ mod tests { assert_eq!(&*err.msg, super::INVALID_NOT_HTTP); } + #[cfg(any(target_os = "linux", target_os = "macos"))] + fn get_local_ips() -> (Option, Option) { + use std::net::{IpAddr, TcpListener}; + + let mut ip_v4 = None; + let mut ip_v6 = None; + + let ips = pnet::datalink::interfaces() + .into_iter() + .flat_map(|i| i.ips.into_iter().map(|n| n.ip())); + + for ip in ips { + match ip { + IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip), + IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip), + _ => (), + } + + if ip_v4.is_some() && ip_v6.is_some() { + break; + } + } + + (ip_v4, ip_v6) + } + #[tokio::test] async fn test_errors_missing_scheme() { let dst = "example.domain".parse().unwrap(); @@ -676,6 +768,43 @@ mod tests { assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME); } + // NOTE: pnet crate that we use in this test doesn't compile on Windows + #[cfg(any(target_os = "linux", target_os = "macos"))] + #[tokio::test] + async fn local_address() { + use std::net::{IpAddr, TcpListener}; + + let (bind_ip_v4, bind_ip_v6) = get_local_ips(); + let server4 = TcpListener::bind("127.0.0.1:0").unwrap(); + let port = server4.local_addr().unwrap().port(); + let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap(); + + let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move { + let mut connector = HttpConnector::new(); + + match (bind_ip_v4, bind_ip_v6) { + (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6), + (Some(v4), None) => connector.set_local_address(Some(v4.into())), + (None, Some(v6)) => connector.set_local_address(Some(v6.into())), + _ => unreachable!(), + } + + connect(connector, dst.parse().unwrap()).await.unwrap(); + + let (_, client_addr) = server.accept().unwrap(); + + assert_eq!(client_addr.ip(), expected_ip); + }; + + if let Some(ip) = bind_ip_v4 { + assert_client_ip(format!("http://127.0.0.1:{}", port), server4, ip.into()).await; + } + + if let Some(ip) = bind_ip_v6 { + assert_client_ip(format!("http://[::1]:{}", port), server6, ip.into()).await; + } + } + #[test] #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)] fn client_happy_eyeballs() { @@ -797,6 +926,7 @@ mod tests { .map(|host| (host.clone(), addr.port()).into()) .collect(); let connecting_tcp = ConnectingTcp::new( + None, None, dns::IpAddrs::new(addrs), None,