diff --git a/src/proxy/inbound_passthrough.rs b/src/proxy/inbound_passthrough.rs index 67d4e90657..e2b8b41464 100644 --- a/src/proxy/inbound_passthrough.rs +++ b/src/proxy/inbound_passthrough.rs @@ -42,10 +42,11 @@ impl InboundPassthrough { pi: Arc, drain: DrainWatcher, ) -> Result { - let listener = pi + let mut listener = pi .socket_factory .tcp_bind(pi.cfg.inbound_plaintext_addr) .map_err(|e| Error::Bind(pi.cfg.inbound_plaintext_addr, e))?; + listener.set_socket_options(Some(pi.cfg.socket_config)); let enable_orig_src = super::maybe_set_transparent(&pi, &listener)?; diff --git a/src/socket.rs b/src/socket.rs index cf1a083ee6..77f2a8d583 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -20,12 +20,11 @@ use tokio::io; use tokio::net::TcpSocket; use tokio::net::{TcpListener, TcpStream}; +use crate::config::SocketConfig; +use socket2::{SockRef, TcpKeepalive}; + #[cfg(target_os = "linux")] -use { - socket2::{Domain, SockRef}, - std::io::ErrorKind, - tracing::warn, -}; +use {socket2::Domain, std::io::ErrorKind, tracing::warn}; #[cfg(target_os = "linux")] pub fn set_freebind_and_transparent(socket: &TcpSocket) -> io::Result<()> { @@ -145,21 +144,50 @@ mod linux { } /// Listener is a wrapper For TCPListener with sane defaults. Notably, setting NODELAY -pub struct Listener(TcpListener); +/// You pass also pass it additional socket options to set on accepted connections. +pub struct Listener { + listener: TcpListener, + cfg: Option, +} impl Listener { pub fn new(l: TcpListener) -> Self { - Self(l) + Listener { + listener: l, + cfg: None, + } } pub fn local_addr(&self) -> SocketAddr { - self.0.local_addr().expect("local_addr is available") + self.listener.local_addr().expect("local_addr is available") } pub fn inner(self) -> TcpListener { - self.0 + self.listener + } + pub fn set_socket_options(&mut self, cfg: Option) { + self.cfg = cfg; } pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> { - let (stream, remote) = self.0.accept().await?; + let (stream, remote) = self.listener.accept().await?; stream.set_nodelay(true)?; + if let Some(cfg) = self.cfg { + if cfg.keepalive_enabled { + let ka = TcpKeepalive::new() + .with_time(cfg.keepalive_time) + .with_retries(cfg.keepalive_retries) + .with_interval(cfg.keepalive_interval); + tracing::trace!( + "set keepalive: {:?}", + SockRef::from(&stream).set_tcp_keepalive(&ka) + ); + } + if cfg.user_timeout_enabled { + let ut = cfg.keepalive_time + cfg.keepalive_retries * cfg.keepalive_interval; + tracing::trace!( + "set user timeout: {:?}", + SockRef::from(&stream).set_tcp_user_timeout(Some(ut)) + ); + } + } Ok((stream, remote)) } } @@ -167,7 +195,7 @@ impl Listener { #[cfg(target_os = "linux")] impl Listener { pub fn set_transparent(&self) -> io::Result<()> { - SockRef::from(&self.0).set_ip_transparent_v4(true) + SockRef::from(&self.listener).set_ip_transparent_v4(true) } }