Skip to content

Commit afe849e

Browse files
committed
net: add handling for abstract socket name
1 parent 1077b0b commit afe849e

File tree

3 files changed

+51
-4
lines changed

3 files changed

+51
-4
lines changed

tokio/src/net/unix/listener.rs

+19-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@ use crate::net::unix::{SocketAddr, UnixStream};
33

44
use std::fmt;
55
use std::io;
6+
#[cfg(target_os = "linux")]
7+
use std::os::linux::net::SocketAddrExt;
8+
#[cfg(target_os = "linux")]
9+
use std::os::unix::ffi::OsStrExt;
610
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, RawFd};
7-
use std::os::unix::net;
11+
use std::os::unix::net::{self, SocketAddr as StdSocketAddr};
812
use std::path::Path;
913
use std::task::{Context, Poll};
1014

@@ -70,7 +74,20 @@ impl UnixListener {
7074
where
7175
P: AsRef<Path>,
7276
{
73-
let listener = mio::net::UnixListener::bind(path)?;
77+
// On linux, abstract socket paths need to be considered.
78+
#[cfg(target_os = "linux")]
79+
let addr = {
80+
let os_str_bytes = path.as_ref().as_os_str().as_bytes();
81+
if os_str_bytes.starts_with(b"\0") {
82+
StdSocketAddr::from_abstract_name(os_str_bytes)?
83+
} else {
84+
StdSocketAddr::from_pathname(path)?
85+
}
86+
};
87+
#[cfg(not(target_os = "linux"))]
88+
let addr = StdSocketAddr::from_pathname(path)?;
89+
90+
let listener = mio::net::UnixListener::bind_addr(&addr)?;
7491
let io = PollEvented::new(listener)?;
7592
Ok(UnixListener { io })
7693
}

tokio/src/net/unix/stream.rs

+19-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@ use crate::net::unix::SocketAddr;
88
use std::fmt;
99
use std::io::{self, Read, Write};
1010
use std::net::Shutdown;
11+
#[cfg(target_os = "linux")]
12+
use std::os::linux::net::SocketAddrExt;
13+
#[cfg(target_os = "linux")]
14+
use std::os::unix::ffi::OsStrExt;
1115
use std::os::unix::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, RawFd};
12-
use std::os::unix::net;
16+
use std::os::unix::net::{self, SocketAddr as StdSocketAddr};
1317
use std::path::Path;
1418
use std::pin::Pin;
1519
use std::task::{Context, Poll};
@@ -66,7 +70,20 @@ impl UnixStream {
6670
where
6771
P: AsRef<Path>,
6872
{
69-
let stream = mio::net::UnixStream::connect(path)?;
73+
// On linux, abstract socket paths need to be considered.
74+
#[cfg(target_os = "linux")]
75+
let addr = {
76+
let os_str_bytes = path.as_ref().as_os_str().as_bytes();
77+
if os_str_bytes.starts_with(b"\0") {
78+
StdSocketAddr::from_abstract_name(os_str_bytes)?
79+
} else {
80+
StdSocketAddr::from_pathname(path)?
81+
}
82+
};
83+
#[cfg(not(target_os = "linux"))]
84+
let addr = StdSocketAddr::from_pathname(path)?;
85+
86+
let stream = mio::net::UnixStream::connect_addr(&addr)?;
7087
let stream = UnixStream::new(stream)?;
7188

7289
poll_fn(|cx| stream.io.registration().poll_write_ready(cx)).await?;

tokio/tests/uds_stream.rs

+13
Original file line numberDiff line numberDiff line change
@@ -409,3 +409,16 @@ async fn epollhup() -> io::Result<()> {
409409
assert_eq!(err.kind(), io::ErrorKind::ConnectionReset);
410410
Ok(())
411411
}
412+
413+
// test for https://github.com/tokio-rs/tokio/issues/6767
414+
#[tokio::test]
415+
#[cfg(target_os = "linux")]
416+
async fn abstract_socket_name() {
417+
let socket_path = "\0aaa";
418+
let listener = UnixListener::bind(socket_path).unwrap();
419+
420+
let accept = listener.accept();
421+
let connect = UnixStream::connect(&socket_path);
422+
423+
try_join(accept, connect).await.unwrap();
424+
}

0 commit comments

Comments
 (0)