Skip to content

Commit

Permalink
refactor(client): use tokio's TcpSocket for more sockopts (#2335)
Browse files Browse the repository at this point in the history
Signed-off-by: Eliza Weisman <[email protected]>
  • Loading branch information
hawkw authored Nov 18, 2020
1 parent bdb5e5d commit abb6471
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 28 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ itoa = "0.4.1"
tracing = { version = "0.1", default-features = false, features = ["log", "std"] }
pin-project = "1.0"
tower-service = "0.3"
tokio = { version = "0.3", features = ["sync", "stream"] }
tokio = { version = "0.3.4", features = ["sync", "stream"] }
want = "0.3"

# Optional
Expand Down
56 changes: 29 additions & 27 deletions src/client/connect/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -569,29 +569,30 @@ fn connect(
connect_timeout: Option<Duration>,
) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> {
// TODO(eliza): if Tokio's `TcpSocket` gains support for setting the
// keepalive timeout and send/recv buffer size, it would be nice to use that
// instead of socket2, and avoid the unsafe `into_raw_fd`/`from_raw_fd`
// dance...
// keepalive timeout, it would be nice to use that instead of socket2,
// and avoid the unsafe `into_raw_fd`/`from_raw_fd` dance...
use socket2::{Domain, Protocol, Socket, Type};
use std::convert::TryInto;

let domain = match *addr {
SocketAddr::V4(_) => Domain::ipv4(),
SocketAddr::V6(_) => Domain::ipv6(),
};
let socket = Socket::new(domain, Type::stream(), Some(Protocol::tcp()))
.map_err(ConnectError::m("tcp open error"))?;

if config.reuse_address {
socket
.set_reuse_address(true)
.map_err(ConnectError::m("tcp set_reuse_address error"))?;
}

// When constructing a Tokio `TcpSocket` from a raw fd/socket, the user is
// responsible for ensuring O_NONBLOCK is set.
socket
.set_nonblocking(true)
.map_err(ConnectError::m("tcp set_nonblocking error"))?;

if let Some(dur) = config.keep_alive_timeout {
socket
.set_keepalive(Some(dur))
.map_err(ConnectError::m("tcp set_keepalive error"))?;
}

bind_local_address(
&socket,
addr,
Expand All @@ -600,24 +601,6 @@ fn connect(
)
.map_err(ConnectError::m("tcp bind local error"))?;

if let Some(dur) = config.keep_alive_timeout {
socket
.set_keepalive(Some(dur))
.map_err(ConnectError::m("tcp set_keepalive error"))?;
}

if let Some(size) = config.send_buffer_size {
socket
.set_send_buffer_size(size)
.map_err(ConnectError::m("tcp set_send_buffer_size error"))?;
}

if let Some(size) = config.recv_buffer_size {
socket
.set_recv_buffer_size(size)
.map_err(ConnectError::m("tcp set_recv_buffer_size error"))?;
}

#[cfg(unix)]
let socket = unsafe {
// Safety: `from_raw_fd` is only safe to call if ownership of the raw
Expand All @@ -636,6 +619,25 @@ fn connect(
use std::os::windows::io::{FromRawSocket, IntoRawSocket};
TcpSocket::from_raw_socket(socket.into_raw_socket())
};

if config.reuse_address {
socket
.set_reuseaddr(true)
.map_err(ConnectError::m("tcp set_reuse_address error"))?;

This comment has been minimized.

Copy link
@Sh4rK

Sh4rK Nov 25, 2020

set_reuse_address in the error message is not updated to set_reuseaddr which is the actual method name :)

}

if let Some(size) = config.send_buffer_size {
socket
.set_send_buffer_size(size.try_into().unwrap_or(std::u32::MAX))
.map_err(ConnectError::m("tcp set_send_buffer_size error"))?;
}

if let Some(size) = config.recv_buffer_size {
socket
.set_recv_buffer_size(size.try_into().unwrap_or(std::u32::MAX))
.map_err(ConnectError::m("tcp set_recv_buffer_size error"))?;
}

let connect = socket.connect(*addr);
Ok(async move {
match connect_timeout {
Expand Down

0 comments on commit abb6471

Please sign in to comment.