Skip to content

Commit abb6471

Browse files
authored
refactor(client): use tokio's TcpSocket for more sockopts (#2335)
Signed-off-by: Eliza Weisman <[email protected]>
1 parent bdb5e5d commit abb6471

File tree

2 files changed

+30
-28
lines changed

2 files changed

+30
-28
lines changed

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ itoa = "0.4.1"
3535
tracing = { version = "0.1", default-features = false, features = ["log", "std"] }
3636
pin-project = "1.0"
3737
tower-service = "0.3"
38-
tokio = { version = "0.3", features = ["sync", "stream"] }
38+
tokio = { version = "0.3.4", features = ["sync", "stream"] }
3939
want = "0.3"
4040

4141
# Optional

src/client/connect/http.rs

+29-27
Original file line numberDiff line numberDiff line change
@@ -569,29 +569,30 @@ fn connect(
569569
connect_timeout: Option<Duration>,
570570
) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> {
571571
// TODO(eliza): if Tokio's `TcpSocket` gains support for setting the
572-
// keepalive timeout and send/recv buffer size, it would be nice to use that
573-
// instead of socket2, and avoid the unsafe `into_raw_fd`/`from_raw_fd`
574-
// dance...
572+
// keepalive timeout, it would be nice to use that instead of socket2,
573+
// and avoid the unsafe `into_raw_fd`/`from_raw_fd` dance...
575574
use socket2::{Domain, Protocol, Socket, Type};
575+
use std::convert::TryInto;
576+
576577
let domain = match *addr {
577578
SocketAddr::V4(_) => Domain::ipv4(),
578579
SocketAddr::V6(_) => Domain::ipv6(),
579580
};
580581
let socket = Socket::new(domain, Type::stream(), Some(Protocol::tcp()))
581582
.map_err(ConnectError::m("tcp open error"))?;
582583

583-
if config.reuse_address {
584-
socket
585-
.set_reuse_address(true)
586-
.map_err(ConnectError::m("tcp set_reuse_address error"))?;
587-
}
588-
589584
// When constructing a Tokio `TcpSocket` from a raw fd/socket, the user is
590585
// responsible for ensuring O_NONBLOCK is set.
591586
socket
592587
.set_nonblocking(true)
593588
.map_err(ConnectError::m("tcp set_nonblocking error"))?;
594589

590+
if let Some(dur) = config.keep_alive_timeout {
591+
socket
592+
.set_keepalive(Some(dur))
593+
.map_err(ConnectError::m("tcp set_keepalive error"))?;
594+
}
595+
595596
bind_local_address(
596597
&socket,
597598
addr,
@@ -600,24 +601,6 @@ fn connect(
600601
)
601602
.map_err(ConnectError::m("tcp bind local error"))?;
602603

603-
if let Some(dur) = config.keep_alive_timeout {
604-
socket
605-
.set_keepalive(Some(dur))
606-
.map_err(ConnectError::m("tcp set_keepalive error"))?;
607-
}
608-
609-
if let Some(size) = config.send_buffer_size {
610-
socket
611-
.set_send_buffer_size(size)
612-
.map_err(ConnectError::m("tcp set_send_buffer_size error"))?;
613-
}
614-
615-
if let Some(size) = config.recv_buffer_size {
616-
socket
617-
.set_recv_buffer_size(size)
618-
.map_err(ConnectError::m("tcp set_recv_buffer_size error"))?;
619-
}
620-
621604
#[cfg(unix)]
622605
let socket = unsafe {
623606
// Safety: `from_raw_fd` is only safe to call if ownership of the raw
@@ -636,6 +619,25 @@ fn connect(
636619
use std::os::windows::io::{FromRawSocket, IntoRawSocket};
637620
TcpSocket::from_raw_socket(socket.into_raw_socket())
638621
};
622+
623+
if config.reuse_address {
624+
socket
625+
.set_reuseaddr(true)
626+
.map_err(ConnectError::m("tcp set_reuse_address error"))?;
627+
}
628+
629+
if let Some(size) = config.send_buffer_size {
630+
socket
631+
.set_send_buffer_size(size.try_into().unwrap_or(std::u32::MAX))
632+
.map_err(ConnectError::m("tcp set_send_buffer_size error"))?;
633+
}
634+
635+
if let Some(size) = config.recv_buffer_size {
636+
socket
637+
.set_recv_buffer_size(size.try_into().unwrap_or(std::u32::MAX))
638+
.map_err(ConnectError::m("tcp set_recv_buffer_size error"))?;
639+
}
640+
639641
let connect = socket.connect(*addr);
640642
Ok(async move {
641643
match connect_timeout {

0 commit comments

Comments
 (0)