Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow specifying custom interest for TcpStream #5796

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions tokio/src/io/interest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ impl Interest {
///
/// assert!(BOTH.is_readable());
/// assert!(BOTH.is_writable());
/// ```
pub const fn add(self, other: Interest) -> Interest {
Interest(self.0.add(other.0))
}
Expand All @@ -135,6 +136,12 @@ impl Interest {
}
}

impl Default for Interest {
fn default() -> Self {
Interest::READABLE.add(Interest::WRITABLE)
}
}

Comment on lines +139 to +144
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this truly a reasonable default for Interest? I think Interest::READABLE | Interest::WRITABLE written out in full is always clearer.

impl ops::BitOr for Interest {
type Output = Self;

Expand Down
47 changes: 46 additions & 1 deletion tokio/src/net/tcp/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,58 @@ impl TcpListener {
/// }
/// ```
pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
self.accept_with_interest(Default::default()).await
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure if this is a general practice, but I'd prefer Interest::default() here to make it clearer that we're passing an interest.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is also assuming that the Default instance is a good choice, which I'm not sure about.

}

/// Accepts a new incoming connection from this listener with custom
/// interest registration.
///
/// This function will yield once a new TCP connection is established. When
/// established, the corresponding [`TcpStream`] and the remote peer's
/// address will be returned.
///
/// # Cancel safety
///
/// This method is cancel safe. If the method is used as the event in a
/// [`tokio::select!`](crate::select) statement and some other branch
/// completes first, then it is guaranteed that no new connections were
/// accepted by this method.
///
/// [`TcpStream`]: struct@crate::net::TcpStream
///
/// # Examples
///
/// ```no_run
/// use tokio::{io::Interest, net::TcpListener};
///
/// use std::io;
///
/// #[tokio::main]
/// async fn main() -> io::Result<()> {
/// let listener = TcpListener::bind("127.0.0.1:8080").await?;
///
/// match listener
/// .accept_with_interest(Interest::PRIORITY.add(Default::default()))
/// .await
/// {
/// Ok((_socket, addr)) => println!("new client: {:?}", addr),
/// Err(e) => println!("couldn't get client: {:?}", e),
/// }
///
/// Ok(())
/// }
/// ```
pub async fn accept_with_interest(
&self,
interest: Interest,
) -> io::Result<(TcpStream, SocketAddr)> {
let (mio, addr) = self
.io
.registration()
.async_io(Interest::READABLE, || self.io.accept())
.await?;

let stream = TcpStream::new(mio)?;
let stream = TcpStream::new_with_interest(mio, interest)?;
Ok((stream, addr))
}

Expand Down
2 changes: 1 addition & 1 deletion tokio/src/net/tcp/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ impl TcpSocket {
unsafe { mio::net::TcpStream::from_raw_socket(raw_socket) }
};

TcpStream::connect_mio(mio).await
TcpStream::connect_mio(mio, Default::default()).await
}

/// Converts the socket into a `TcpListener`.
Expand Down
65 changes: 60 additions & 5 deletions tokio/src/net/tcp/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,59 @@ impl TcpStream {
/// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all
/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
pub async fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<TcpStream> {
Self::connect_with_interest(addr, Default::default()).await
}

/// Opens a TCP connection to a remote host with custom interest
/// registration..
///
/// `addr` is an address of the remote host. Anything which implements the
/// [`ToSocketAddrs`] trait can be supplied as the address. If `addr`
/// yields multiple addresses, connect will be attempted with each of the
/// addresses until a connection is successful. If none of the addresses
/// result in a successful connection, the error returned from the last
/// connection attempt (the last address) is returned.
///
/// To configure the socket before connecting, you can use the [`TcpSocket`]
/// type.
///
/// [`ToSocketAddrs`]: trait@crate::net::ToSocketAddrs
/// [`TcpSocket`]: struct@crate::net::TcpSocket
///
/// # Examples
///
/// ```no_run
/// use tokio::net::TcpStream;
/// use tokio::io::{AsyncWriteExt, Interest};
/// use std::error::Error;
///
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn Error>> {
/// // Connect to a peer
/// let mut stream = TcpStream::connect_with_interest(
/// "127.0.0.1:8080",
/// Interest::PRIORITY.add(Default::default()),
/// )
/// .await?;
///
/// // Write some data.
/// stream.write_all(b"hello world!").await?;
///
/// Ok(())
/// }
/// ```
///
/// The [`write_all`] method is defined on the [`AsyncWriteExt`] trait.
///
/// [`write_all`]: fn@crate::io::AsyncWriteExt::write_all
/// [`AsyncWriteExt`]: trait@crate::io::AsyncWriteExt
pub async fn connect_with_interest<A: ToSocketAddrs>(addr: A, interest: Interest) -> io::Result<TcpStream> {
let addrs = to_socket_addrs(addr).await?;

let mut last_err = None;

for addr in addrs {
match TcpStream::connect_addr(addr).await {
match TcpStream::connect_addr(addr, interest).await {
Ok(stream) => return Ok(stream),
Err(e) => last_err = Some(e),
}
Expand All @@ -132,13 +179,13 @@ impl TcpStream {
}

/// Establishes a connection to the specified `addr`.
async fn connect_addr(addr: SocketAddr) -> io::Result<TcpStream> {
async fn connect_addr(addr: SocketAddr, interest: Interest) -> io::Result<TcpStream> {
let sys = mio::net::TcpStream::connect(addr)?;
TcpStream::connect_mio(sys).await
TcpStream::connect_mio(sys, interest).await
}

pub(crate) async fn connect_mio(sys: mio::net::TcpStream) -> io::Result<TcpStream> {
let stream = TcpStream::new(sys)?;
pub(crate) async fn connect_mio(sys: mio::net::TcpStream, interest: Interest) -> io::Result<TcpStream> {
let stream = TcpStream::new_with_interest(sys, interest)?;

// Once we've connected, wait for the stream to be writable as
// that's when the actual connection has been initiated. Once we're
Expand All @@ -161,6 +208,14 @@ impl TcpStream {
Ok(TcpStream { io })
}

pub(crate) fn new_with_interest(
connected: mio::net::TcpStream,
interest: Interest,
) -> io::Result<TcpStream> {
let io = PollEvented::new_with_interest(connected, interest)?;
Ok(TcpStream { io })
}

/// Creates new `TcpStream` from a `std::net::TcpStream`.
///
/// This function is intended to be used to wrap a TCP stream from the
Expand Down