diff --git a/tokio/src/net/unix/listener.rs b/tokio/src/net/unix/listener.rs index 79b554ee1ab..d2ede2c836b 100644 --- a/tokio/src/net/unix/listener.rs +++ b/tokio/src/net/unix/listener.rs @@ -3,8 +3,12 @@ use crate::net::unix::{SocketAddr, UnixStream}; use std::fmt; use std::io; +#[cfg(target_os = "linux")] +use std::os::linux::net::SocketAddrExt; +#[cfg(target_os = "linux")] +use std::os::unix::ffi::OsStrExt; use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, RawFd}; -use std::os::unix::net; +use std::os::unix::net::{self, SocketAddr as StdSocketAddr}; use std::path::Path; use std::task::{Context, Poll}; @@ -70,7 +74,20 @@ impl UnixListener { where P: AsRef, { - let listener = mio::net::UnixListener::bind(path)?; + // For now, we handle abstract socket paths on linux here. + #[cfg(target_os = "linux")] + let addr = { + let os_str_bytes = path.as_ref().as_os_str().as_bytes(); + if os_str_bytes.starts_with(b"\0") { + StdSocketAddr::from_abstract_name(os_str_bytes)? + } else { + StdSocketAddr::from_pathname(path)? + } + }; + #[cfg(not(target_os = "linux"))] + let addr = StdSocketAddr::from_pathname(path)?; + + let listener = mio::net::UnixListener::bind_addr(&addr)?; let io = PollEvented::new(listener)?; Ok(UnixListener { io }) } diff --git a/tokio/src/net/unix/stream.rs b/tokio/src/net/unix/stream.rs index 60d58139699..63a02f46777 100644 --- a/tokio/src/net/unix/stream.rs +++ b/tokio/src/net/unix/stream.rs @@ -8,8 +8,12 @@ use crate::net::unix::SocketAddr; use std::fmt; use std::io::{self, Read, Write}; use std::net::Shutdown; +#[cfg(target_os = "linux")] +use std::os::linux::net::SocketAddrExt; +#[cfg(target_os = "linux")] +use std::os::unix::ffi::OsStrExt; use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, RawFd}; -use std::os::unix::net; +use std::os::unix::net::{self, SocketAddr as StdSocketAddr}; use std::path::Path; use std::pin::Pin; use std::task::{Context, Poll}; @@ -66,7 +70,20 @@ impl UnixStream { where P: AsRef, { - let stream = mio::net::UnixStream::connect(path)?; + // On linux, abstract socket paths need to be considered. + #[cfg(target_os = "linux")] + let addr = { + let os_str_bytes = path.as_ref().as_os_str().as_bytes(); + if os_str_bytes.starts_with(b"\0") { + StdSocketAddr::from_abstract_name(os_str_bytes)? + } else { + StdSocketAddr::from_pathname(path)? + } + }; + #[cfg(not(target_os = "linux"))] + let addr = StdSocketAddr::from_pathname(path)?; + + let stream = mio::net::UnixStream::connect_addr(&addr)?; let stream = UnixStream::new(stream)?; poll_fn(|cx| stream.io.registration().poll_write_ready(cx)).await?; diff --git a/tokio/tests/uds_stream.rs b/tokio/tests/uds_stream.rs index b8c4e6a8eed..48d29287747 100644 --- a/tokio/tests/uds_stream.rs +++ b/tokio/tests/uds_stream.rs @@ -409,3 +409,16 @@ async fn epollhup() -> io::Result<()> { assert_eq!(err.kind(), io::ErrorKind::ConnectionReset); Ok(()) } + +// test for https://github.com/tokio-rs/tokio/issues/6767 +#[tokio::test] +#[cfg(target_os = "linux")] +async fn abstract_socket_name() { + let socket_path = "\0aaa"; + let listener = UnixListener::bind(socket_path).unwrap(); + + let accept = listener.accept(); + let connect = UnixStream::connect(&socket_path); + + try_join(accept, connect).await.unwrap(); +}