diff --git a/talpid-core/src/split_tunnel/windows/mod.rs b/talpid-core/src/split_tunnel/windows/mod.rs index 40d7c340ad3c..8d5ae1c7f673 100644 --- a/talpid-core/src/split_tunnel/windows/mod.rs +++ b/talpid-core/src/split_tunnel/windows/mod.rs @@ -5,7 +5,10 @@ mod volume_monitor; mod windows; use crate::{tunnel::TunnelMetadata, tunnel_state_machine::TunnelCommand}; -use futures::channel::{mpsc, oneshot}; +use futures::{ + channel::{mpsc, oneshot}, + StreamExt, +}; use std::{ collections::HashMap, convert::TryFrom, @@ -108,7 +111,6 @@ pub struct SplitTunnel { quit_event: Arc, excluded_processes: Arc>>, _route_change_callback: Option, - daemon_tx: Weak>, async_path_update_in_progress: Arc, route_manager: RouteManagerHandle, } @@ -119,10 +121,20 @@ enum Request { Stop, } type RequestResponseTx = sync_mpsc::Sender>; -type RequestTx = sync_mpsc::Sender<(Request, RequestResponseTx)>; +type RequestTx = sync_mpsc::Sender<(Request, Option)>; const REQUEST_TIMEOUT: Duration = Duration::from_secs(5); +impl Request { + fn request_name(&self) -> &'static str { + match self { + Request::SetPaths(_) => "SetPaths", + Request::RegisterIps(_) => "RegisterIps", + Request::Stop => "Stop", + } + } +} + #[derive(Default, PartialEq, Clone)] struct InterfaceAddresses { tunnel_ipv4: Option, @@ -168,8 +180,12 @@ impl SplitTunnel { ) -> Result { let excluded_processes = Arc::new(RwLock::new(HashMap::new())); - let (request_tx, handle) = - Self::spawn_request_thread(resource_dir, volume_update_rx, excluded_processes.clone())?; + let (request_tx, handle) = Self::spawn_request_thread( + resource_dir, + daemon_tx, + volume_update_rx, + excluded_processes.clone(), + )?; let (event_thread, quit_event) = Self::spawn_event_listener(handle, excluded_processes.clone())?; @@ -180,7 +196,6 @@ impl SplitTunnel { event_thread: Some(event_thread), quit_event, _route_change_callback: None, - daemon_tx, async_path_update_in_progress: Arc::new(AtomicBool::new(false)), excluded_processes, route_manager, @@ -387,6 +402,7 @@ impl SplitTunnel { fn spawn_request_thread( resource_dir: PathBuf, + daemon_tx: Weak>, volume_update_rx: mpsc::UnboundedReceiver<()>, excluded_processes: Arc>>, ) -> Result<(RequestTx, Arc), Error> { @@ -429,6 +445,7 @@ impl SplitTunnel { let mut previous_addresses = InterfaceAddresses::default(); while let Ok((request, response_tx)) = rx.recv() { + let request_name = request.request_name(); let response = match request { Request::SetPaths(paths) => { let mut monitored_paths_guard = monitored_paths.lock().unwrap(); @@ -475,21 +492,65 @@ impl SplitTunnel { } Request::Stop => { if let Err(error) = handle.reset().map_err(Error::ResetError) { - let _ = response_tx.send(Err(error)); + if let Some(response_tx) = response_tx { + let _ = response_tx.send(Err(error)); + } else { + log::error!( + "{}", + error.display_chain_with_msg( + "Failed to deinitialize split tunneling" + ) + ); + } continue; } monitored_paths.lock().unwrap().clear(); excluded_processes.write().unwrap().clear(); - let _ = response_tx.send(Ok(())); + if let Some(response_tx) = response_tx { + let _ = response_tx.send(Ok(())); + } // Stop listening to commands break; } }; - if response_tx.send(response).is_err() { - log::error!("A response could not be sent for a completed request"); + + // Handle IOCTL result + + let mut log_response = None; + if let Some(response_tx) = response_tx { + if let Err(error) = response_tx.send(response) { + log::error!( + "A response could not be sent for completed request/ioctl: {}", + request_name + ); + log_response = Some(error.0); + } + } else { + log_response = Some(response); + + if let Some(daemon_tx) = daemon_tx.upgrade() { + log::debug!( + "Entering error state due to failed request/ioctl: {}", + request_name + ); + let _ = daemon_tx.unbounded_send(TunnelCommand::Block( + ErrorStateCause::SplitTunnelError, + )); + } else { + log::error!( + "Cannot handle failed request since tunnel state machine is down" + ); + } + } + if let Some(Err(error)) = log_response { + log::error!( + "Request/ioctl failed: {}\n{}", + request_name, + error.display_chain() + ); } } @@ -547,7 +608,7 @@ impl SplitTunnel { let (response_tx, response_rx) = sync_mpsc::channel(); request_tx - .send((request, response_tx)) + .send((request, Some(response_tx))) .map_err(|_| Error::SplitTunnelDown)?; response_rx @@ -589,7 +650,7 @@ impl SplitTunnel { let wait_task = move || { request_tx - .send((request, response_tx)) + .send((request, Some(response_tx))) .map_err(|_| Error::SplitTunnelDown)?; response_rx.recv().map_err(|_| Error::SplitTunnelDown)? }; @@ -600,7 +661,7 @@ impl SplitTunnel { }); } - /// Instructs the driver to redirect traffic from sockets bound to 0.0.0.0, ::, or the + /// Instructs the driver to redirect connections for sockets bound to 0.0.0.0, ::, or the /// tunnel addresses (if any) to the default route. pub fn set_tunnel_addresses(&mut self, metadata: Option<&TunnelMetadata>) -> Result<(), Error> { let mut tunnel_ipv4 = None; @@ -619,7 +680,6 @@ impl SplitTunnel { let context_mutex = Arc::new(Mutex::new( SplitTunnelDefaultRouteChangeHandlerContext::new( self.request_tx.clone(), - self.daemon_tx.clone(), tunnel_ipv4, tunnel_ipv6, ), @@ -645,7 +705,7 @@ impl SplitTunnel { // could deadlock if the dropped callback is invoked (see `init_context`). .map_err(|_| Error::RegisterRouteChangeCallback)?; - Self::init_context(context)?; + Self::init_context(context, &self.request_tx)?; self._route_change_callback = callback; Ok(()) @@ -653,6 +713,7 @@ impl SplitTunnel { fn init_context( mut context: MutexGuard<'_, SplitTunnelDefaultRouteChangeHandlerContext>, + request_tx: &RequestTx, ) -> Result<(), Error> { // NOTE: This should remain a separate function. Dropping the context after `callback` // causes a deadlock if `split_tunnel_default_route_change_handler` is called at the same @@ -661,15 +722,21 @@ impl SplitTunnel { // to complete. context.initialize_internet_addresses()?; - context.register_ips() + SplitTunnel::send_request_inner(request_tx, Request::RegisterIps(context.addresses.clone())) } - /// Instructs the driver to stop redirecting tunnel traffic and INADDR_ANY. + /// Instructs the driver to stop redirecting connections. pub fn clear_tunnel_addresses(&mut self) -> Result<(), Error> { self._route_change_callback = None; self.send_request(Request::RegisterIps(InterfaceAddresses::default())) } + /// Returns whether connections are being redirected. + pub fn has_tunnel_addresses(&self) -> bool { + // NOTE: Relying on assumption that `set_tunnel_addresses` was used here. + self._route_change_callback.is_some() + } + /// Returns a handle used for interacting with the split tunnel module. pub fn handle(&self) -> SplitTunnelHandle { SplitTunnelHandle { @@ -700,21 +767,27 @@ impl Drop for SplitTunnel { } struct SplitTunnelDefaultRouteChangeHandlerContext { - request_tx: RequestTx, - pub daemon_tx: Weak>, + tx: mpsc::UnboundedSender, + abort_handle: tokio::task::JoinHandle<()>, pub addresses: InterfaceAddresses, } +impl Drop for SplitTunnelDefaultRouteChangeHandlerContext { + fn drop(&mut self) { + self.abort_handle.abort(); + } +} + impl SplitTunnelDefaultRouteChangeHandlerContext { pub fn new( request_tx: RequestTx, - daemon_tx: Weak>, tunnel_ipv4: Option, tunnel_ipv6: Option, ) -> Self { + let (tx, abort_handle) = Self::create_burst_guard(request_tx); SplitTunnelDefaultRouteChangeHandlerContext { - request_tx, - daemon_tx, + tx, + abort_handle, addresses: InterfaceAddresses { tunnel_ipv4, tunnel_ipv6, @@ -724,11 +797,58 @@ impl SplitTunnelDefaultRouteChangeHandlerContext { } } - pub fn register_ips(&self) -> Result<(), Error> { - SplitTunnel::send_request_inner( - &self.request_tx, - Request::RegisterIps(self.addresses.clone()), - ) + fn create_burst_guard( + request_tx: RequestTx, + ) -> ( + mpsc::UnboundedSender, + tokio::task::JoinHandle<()>, + ) { + let (tx, mut rx) = mpsc::unbounded(); + + let send_request = move |addresses| { + if request_tx + .send((Request::RegisterIps(addresses), None)) + .is_err() + { + log::error!("Split tunnel request thread is down"); + } + }; + + let abort_handle = tokio::spawn(async move { + const GRACE_PERIOD: Duration = Duration::from_secs(5); + const MAX_PERIOD: Duration = Duration::from_secs(10); + + while let Some(mut addresses) = rx.next().await { + let initial_time = tokio::time::Instant::now(); + loop { + if initial_time.elapsed() >= MAX_PERIOD { + send_request(addresses); + break; + } + + let next = rx.next(); + let delay = tokio::time::sleep(GRACE_PERIOD); + futures::pin_mut!(delay); + + match futures::future::select(next, delay).await { + futures::future::Either::Left((Some(new_addresses), _)) => { + // TODO: combine? + addresses = new_addresses; + continue; + } + futures::future::Either::Left((None, _)) => { + // Return from function + return; + } + futures::future::Either::Right((..)) => { + send_request(addresses); + break; + } + } + } + } + }); + (tx, abort_handle) } pub fn initialize_internet_addresses(&mut self) -> Result<(), Error> { @@ -781,14 +901,9 @@ fn split_tunnel_default_route_change_handler( // Update the "internet interface" IP when best default route changes let mut ctx = ctx_mutex.lock().expect("ST route handler mutex poisoned"); - let daemon_tx = ctx.daemon_tx.upgrade(); - let maybe_send = move |content| { - if let Some(tx) = daemon_tx { - let _ = tx.unbounded_send(content); - } - }; + let prev_addrs = ctx.addresses.clone(); - let result = match event_type { + match event_type { Updated(default_route) | UpdatedDetails(default_route) => { match get_ip_address_for_interface(address_family, default_route.iface) { Ok(Some(ip)) => match ip { @@ -797,14 +912,6 @@ fn split_tunnel_default_route_change_handler( }, Ok(None) => { log::warn!("Failed to obtain default route interface address"); - match address_family { - AddressFamily::Ipv4 => { - ctx.addresses.internet_ipv4 = None; - } - AddressFamily::Ipv6 => { - ctx.addresses.internet_ipv6 = None; - } - } } Err(error) => { log::error!( @@ -813,32 +920,25 @@ fn split_tunnel_default_route_change_handler( "Failed to obtain default route interface address" ) ); - maybe_send(TunnelCommand::Block(ErrorStateCause::SplitTunnelError)); - return; } }; - - ctx.register_ips() } // no default route - Removed => { - match address_family { - AddressFamily::Ipv4 => { - ctx.addresses.internet_ipv4 = None; - } - AddressFamily::Ipv6 => { - ctx.addresses.internet_ipv6 = None; - } + Removed => match address_family { + AddressFamily::Ipv4 => { + ctx.addresses.internet_ipv4 = None; } - ctx.register_ips() - } - }; + AddressFamily::Ipv6 => { + ctx.addresses.internet_ipv6 = None; + } + }, + } - if let Err(error) = result { - log::error!( - "{}", - error.display_chain_with_msg("Failed to register new addresses in split tunnel driver") - ); - maybe_send(TunnelCommand::Block(ErrorStateCause::SplitTunnelError)); + if prev_addrs == ctx.addresses { + return; + } + + if ctx.tx.unbounded_send(ctx.addresses.clone()).is_err() { + log::error!("Split tunnel request thread is down"); } } diff --git a/talpid-core/src/tunnel_state_machine/connecting_state.rs b/talpid-core/src/tunnel_state_machine/connecting_state.rs index 2bfc83e55872..6007409b8729 100644 --- a/talpid-core/src/tunnel_state_machine/connecting_state.rs +++ b/talpid-core/src/tunnel_state_machine/connecting_state.rs @@ -406,6 +406,8 @@ impl ConnectingState { AfterDisconnect::Block(ErrorStateCause::AuthFailed(reason)), ), Some((TunnelEvent::InterfaceUp(metadata, allowed_tunnel_traffic), _done_tx)) => { + // NOTE: It is crucial to set the correct tunnel IP before allowing any traffic into + // the tunnel, as leaks into the tunnel are possible otherwise. #[cfg(windows)] if let Err(error) = shared_values .split_tunnel @@ -550,18 +552,6 @@ impl TunnelState for ConnectingState { ErrorState::enter(shared_values, ErrorStateCause::TunnelParameterError(err)) } Ok(tunnel_parameters) => { - #[cfg(windows)] - if let Err(error) = shared_values.split_tunnel.set_tunnel_addresses(None) { - log::error!( - "{}", - error.display_chain_with_msg( - "Failed to reset addresses in split tunnel driver" - ) - ); - - return ErrorState::enter(shared_values, ErrorStateCause::SplitTunnelError); - } - if let Err(error) = Self::set_firewall_policy( shared_values, &tunnel_parameters, diff --git a/talpid-core/src/tunnel_state_machine/disconnected_state.rs b/talpid-core/src/tunnel_state_machine/disconnected_state.rs index 92a05862451b..5cef198be274 100644 --- a/talpid-core/src/tunnel_state_machine/disconnected_state.rs +++ b/talpid-core/src/tunnel_state_machine/disconnected_state.rs @@ -51,7 +51,15 @@ impl DisconnectedState { shared_values: &mut SharedTunnelStateValues, should_reset_firewall: bool, ) { - if should_reset_firewall && !shared_values.block_when_disconnected { + if shared_values.block_when_disconnected { + if let Err(error) = shared_values.split_tunnel.set_tunnel_addresses(None) { + log::error!( + "{}", + error + .display_chain_with_msg("Failed to reset addresses in split tunnel driver") + ); + } + } else if should_reset_firewall { if let Err(error) = shared_values.split_tunnel.clear_tunnel_addresses() { log::error!( "{}", @@ -60,11 +68,6 @@ impl DisconnectedState { ) ); } - } else if let Err(error) = shared_values.split_tunnel.set_tunnel_addresses(None) { - log::error!( - "{}", - error.display_chain_with_msg("Failed to reset addresses in split tunnel driver") - ); } } diff --git a/talpid-core/src/tunnel_state_machine/error_state.rs b/talpid-core/src/tunnel_state_machine/error_state.rs index 7fe95c9f673e..706f67fc75aa 100644 --- a/talpid-core/src/tunnel_state_machine/error_state.rs +++ b/talpid-core/src/tunnel_state_machine/error_state.rs @@ -85,13 +85,17 @@ impl TunnelState for ErrorState { block_reason: Self::Bootstrap, ) -> (TunnelStateWrapper, TunnelStateTransition) { #[cfg(windows)] - if let Err(error) = shared_values.split_tunnel.set_tunnel_addresses(None) { - log::error!( - "{}", - error.display_chain_with_msg( - "Failed to register addresses with split tunnel driver" - ) - ); + if !block_reason.prevents_split_tunneling() + && !shared_values.split_tunnel.has_tunnel_addresses() + { + if let Err(error) = shared_values.split_tunnel.set_tunnel_addresses(None) { + log::error!( + "{}", + error.display_chain_with_msg( + "Failed to register addresses with split tunnel driver" + ) + ); + } } #[cfg(target_os = "macos")] diff --git a/talpid-types/src/tunnel.rs b/talpid-types/src/tunnel.rs index 3d34f4a96aec..fbea662fdc13 100644 --- a/talpid-types/src/tunnel.rs +++ b/talpid-types/src/tunnel.rs @@ -111,6 +111,11 @@ impl ErrorStateCause { pub fn prevents_filtering_resolver(&self) -> bool { matches!(self, Self::SetDnsError) } + + #[cfg(target_os = "windows")] + pub fn prevents_split_tunneling(&self) -> bool { + matches!(self, Self::SplitTunnelError | Self::IsOffline) + } } /// Errors that can occur when generating tunnel parameters.