diff --git a/book/src/connection.md b/book/src/connection.md index af9bf90ec..bee7d0034 100644 --- a/book/src/connection.md +++ b/book/src/connection.md @@ -59,7 +59,7 @@ let (client_conn, server_conn) = futures_util::try_join!( // Client Builder::unix_stream(p0).p2p().build(), // Server - Builder::unix_stream(p1).server(&guid).p2p().build(), + Builder::unix_stream(p1).server(guid)?.p2p().build(), )?; # } # @@ -72,7 +72,7 @@ let (client_conn, server_conn) = futures_util::try_join!( [PID1]: https://www.freedesktop.org/wiki/Software/systemd/dbus/ [`futures::stream::Stream`]: https://docs.rs/futures/latest/futures/stream/trait.Stream.html [`MessageStream`]: https://docs.rs/zbus/latest/zbus/struct.MessageStream.html -[`Builder::address`]: https://docs.rs/zbus/latest/zbus/connection/struct.ConnectionBuilder.html#method.address +[`connection::Builder::address`]: https://docs.rs/zbus/latest/zbus/connection/struct.ConnectionBuilder.html#method.address [dspec]: https://dbus.freedesktop.org/doc/dbus-specification.html#addresses [^bus-less] Unless you implemented them, none of the bus methods will exist. diff --git a/zbus/src/address.rs b/zbus/src/address.rs deleted file mode 100644 index c4558c083..000000000 --- a/zbus/src/address.rs +++ /dev/null @@ -1,1057 +0,0 @@ -//! D-Bus address handling. -//! -//! Server addresses consist of a transport name followed by a colon, and then an optional, -//! comma-separated list of keys and values in the form key=value. -//! -//! See also: -//! -//! * [Server addresses] in the D-Bus specification. -//! -//! [Server addresses]: https://dbus.freedesktop.org/doc/dbus-specification.html#addresses - -#[cfg(target_os = "macos")] -use crate::process::run; -#[cfg(windows)] -use crate::win32::windows_autolaunch_bus_address; -use crate::{Error, Result}; -#[cfg(not(feature = "tokio"))] -use async_io::Async; -#[cfg(all(unix, not(target_os = "macos")))] -use nix::unistd::Uid; -#[cfg(not(feature = "tokio"))] -use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; -#[cfg(all(unix, not(feature = "tokio")))] -use std::os::unix::net::UnixStream; -use std::{collections::HashMap, env, str::FromStr}; -#[cfg(feature = "tokio")] -use tokio::net::TcpStream; -#[cfg(all(unix, feature = "tokio"))] -use tokio::net::UnixStream; -#[cfg(feature = "tokio-vsock")] -use tokio_vsock::VsockStream; -#[cfg(all(windows, not(feature = "tokio")))] -use uds_windows::UnixStream; -#[cfg(all(feature = "vsock", not(feature = "tokio")))] -use vsock::VsockStream; - -use std::{ - ffi::OsString, - fmt::{Display, Formatter}, - str::from_utf8_unchecked, -}; - -/// A `tcp:` address family. -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum TcpAddressFamily { - Ipv4, - Ipv6, -} - -/// A `tcp:` D-Bus address. -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct TcpAddress { - pub(crate) host: String, - pub(crate) bind: Option, - pub(crate) port: u16, - pub(crate) family: Option, -} - -impl TcpAddress { - /// Returns the `tcp:` address `host` value. - pub fn host(&self) -> &str { - &self.host - } - - /// Returns the `tcp:` address `bind` value. - pub fn bind(&self) -> Option<&str> { - self.bind.as_deref() - } - - /// Returns the `tcp:` address `port` value. - pub fn port(&self) -> u16 { - self.port - } - - /// Returns the `tcp:` address `family` value. - pub fn family(&self) -> Option { - self.family - } - - // Helper for FromStr - fn from_tcp(opts: HashMap<&str, &str>) -> Result { - let bind = None; - if opts.contains_key("bind") { - return Err(Error::Address("`bind` isn't yet supported".into())); - } - - let host = opts - .get("host") - .ok_or_else(|| Error::Address("tcp address is missing `host`".into()))? - .to_string(); - let port = opts - .get("port") - .ok_or_else(|| Error::Address("tcp address is missing `port`".into()))?; - let port = port - .parse::() - .map_err(|_| Error::Address("invalid tcp `port`".into()))?; - let family = opts - .get("family") - .map(|f| TcpAddressFamily::from_str(f)) - .transpose()?; - - Ok(Self { - host, - bind, - port, - family, - }) - } - - fn write_options(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - f.write_str("host=")?; - - encode_percents(f, self.host.as_ref())?; - - write!(f, ",port={}", self.port)?; - - if let Some(bind) = &self.bind { - f.write_str(",bind=")?; - encode_percents(f, bind.as_ref())?; - } - - if let Some(family) = &self.family { - write!(f, ",family={family}")?; - } - - Ok(()) - } -} - -#[cfg(any( - all(feature = "vsock", not(feature = "tokio")), - feature = "tokio-vsock" -))] -/// A `tcp:` D-Bus address. -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct VsockAddress { - pub(crate) cid: u32, - pub(crate) port: u32, -} - -#[cfg(any( - all(feature = "vsock", not(feature = "tokio")), - feature = "tokio-vsock" -))] -impl VsockAddress { - /// Create a new VSOCK address. - pub fn new(cid: u32, port: u32) -> Self { - Self { cid, port } - } -} - -/// A bus address -#[derive(Clone, Debug, PartialEq, Eq)] -#[non_exhaustive] -pub enum Address { - /// A path on the filesystem - Unix(OsString), - /// TCP address details - Tcp(TcpAddress), - /// TCP address details with nonce file path - NonceTcp { - addr: TcpAddress, - nonce_file: Vec, - }, - /// Autolaunch address with optional scope - Autolaunch(Option), - /// Launchd address with a required env key - Launchd(String), - #[cfg(any( - all(feature = "vsock", not(feature = "tokio")), - feature = "tokio-vsock" - ))] - /// VSOCK address - /// - /// This variant is only available when either `vsock` or `tokio-vsock` feature is enabled. The - /// type of `stream` is `vsock::VsockStream` with `vsock` feature and - /// `tokio_vsock::VsockStream` with `tokio-vsock` feature. - Vsock(VsockAddress), - /// A listenable address using the specified path, in which a socket file with a random file - /// name starting with 'dbus-' will be created by the server. See [UNIX domain socket address] - /// reference documentation. - /// - /// This address is mostly relevant to server (typically bus broker) implementations. - /// - /// [UNIX domain socket address]: https://dbus.freedesktop.org/doc/dbus-specification.html#transports-unix-domain-sockets-addresses - UnixDir(OsString), - /// The same as UnixDir, except that on platforms with abstract sockets, the server may attempt - /// to create an abstract socket whose name starts with this directory instead of a path-based - /// socket. - /// - /// This address is mostly relevant to server (typically bus broker) implementations. - UnixTmpDir(OsString), -} - -#[cfg(not(feature = "tokio"))] -#[derive(Debug)] -pub(crate) enum Stream { - Unix(Async), - Tcp(Async), - #[cfg(feature = "vsock")] - Vsock(Async), -} - -#[cfg(feature = "tokio")] -#[derive(Debug)] -pub(crate) enum Stream { - #[cfg(unix)] - Unix(UnixStream), - Tcp(TcpStream), - #[cfg(feature = "tokio-vsock")] - Vsock(VsockStream), -} - -#[cfg(not(feature = "tokio"))] -async fn connect_tcp(addr: TcpAddress) -> Result> { - let addrs = crate::Task::spawn_blocking( - move || -> Result> { - let addrs = (addr.host(), addr.port()).to_socket_addrs()?.filter(|a| { - if let Some(family) = addr.family() { - if family == TcpAddressFamily::Ipv4 { - a.is_ipv4() - } else { - a.is_ipv6() - } - } else { - true - } - }); - Ok(addrs.collect()) - }, - "connect tcp", - ) - .await - .map_err(|e| Error::Address(format!("Failed to receive TCP addresses: {e}")))?; - - // we could attempt connections in parallel? - let mut last_err = Error::Address("Failed to connect".into()); - for addr in addrs { - match Async::::connect(addr).await { - Ok(stream) => return Ok(stream), - Err(e) => last_err = e.into(), - } - } - - Err(last_err) -} - -#[cfg(feature = "tokio")] -async fn connect_tcp(addr: TcpAddress) -> Result { - TcpStream::connect((addr.host(), addr.port())) - .await - .map_err(|e| Error::InputOutput(e.into())) -} - -#[cfg(target_os = "macos")] -pub(crate) async fn macos_launchd_bus_address(env_key: &str) -> Result
{ - let output = run("launchctl", ["getenv", env_key]) - .await - .expect("failed to wait on launchctl output"); - - if !output.status.success() { - return Err(crate::Error::Address(format!( - "launchctl terminated with code: {}", - output.status - ))); - } - - let addr = String::from_utf8(output.stdout).map_err(|e| { - crate::Error::Address(format!("Unable to parse launchctl output as UTF-8: {}", e)) - })?; - - format!("unix:path={}", addr.trim()).parse() -} - -impl Address { - #[cfg_attr(any(target_os = "macos", windows), async_recursion::async_recursion)] - pub(crate) async fn connect(self) -> Result { - match self { - Address::Unix(p) => { - #[cfg(not(feature = "tokio"))] - { - #[cfg(windows)] - { - let stream = crate::Task::spawn_blocking( - move || UnixStream::connect(p), - "unix stream connection", - ) - .await?; - Async::new(stream) - .map(Stream::Unix) - .map_err(|e| Error::InputOutput(e.into())) - } - - #[cfg(not(windows))] - { - Async::::connect(p) - .await - .map(Stream::Unix) - .map_err(|e| Error::InputOutput(e.into())) - } - } - - #[cfg(feature = "tokio")] - { - #[cfg(unix)] - { - UnixStream::connect(p) - .await - .map(Stream::Unix) - .map_err(|e| Error::InputOutput(e.into())) - } - - #[cfg(not(unix))] - { - let _ = p; - Err(Error::Unsupported) - } - } - } - - #[cfg(all(feature = "vsock", not(feature = "tokio")))] - Address::Vsock(addr) => { - let stream = VsockStream::connect_with_cid_port(addr.cid, addr.port)?; - Async::new(stream).map(Stream::Vsock).map_err(Into::into) - } - - #[cfg(feature = "tokio-vsock")] - Address::Vsock(addr) => VsockStream::connect(addr.cid, addr.port) - .await - .map(Stream::Vsock) - .map_err(Into::into), - - Address::Tcp(addr) => connect_tcp(addr).await.map(Stream::Tcp), - - Address::NonceTcp { addr, nonce_file } => { - #[allow(unused_mut)] - let mut stream = connect_tcp(addr).await?; - - #[cfg(unix)] - let nonce_file = { - use std::os::unix::ffi::OsStrExt; - std::ffi::OsStr::from_bytes(&nonce_file) - }; - - #[cfg(windows)] - let nonce_file = std::str::from_utf8(&nonce_file) - .map_err(|_| Error::Address("nonce file path is invalid UTF-8".to_owned()))?; - - #[cfg(not(feature = "tokio"))] - { - let nonce = std::fs::read(nonce_file)?; - let mut nonce = &nonce[..]; - - while !nonce.is_empty() { - let len = stream - .write_with(|mut s| std::io::Write::write(&mut s, nonce)) - .await?; - nonce = &nonce[len..]; - } - } - - #[cfg(feature = "tokio")] - { - let nonce = tokio::fs::read(nonce_file).await?; - tokio::io::AsyncWriteExt::write_all(&mut stream, &nonce).await?; - } - - Ok(Stream::Tcp(stream)) - } - - #[cfg(not(windows))] - Address::Autolaunch(_) => Err(Error::Address( - "Autolaunch addresses are only supported on Windows".to_owned(), - )), - - #[cfg(windows)] - Address::Autolaunch(Some(_)) => Err(Error::Address( - "Autolaunch scopes are currently unsupported".to_owned(), - )), - - #[cfg(windows)] - Address::Autolaunch(None) => { - let addr = windows_autolaunch_bus_address()?; - addr.connect().await - } - - #[cfg(not(target_os = "macos"))] - Address::Launchd(_) => Err(Error::Address( - "Launchd addresses are only supported on macOS".to_owned(), - )), - - #[cfg(target_os = "macos")] - Address::Launchd(env) => { - let addr = macos_launchd_bus_address(&env).await?; - addr.connect().await - } - Address::UnixDir(_) | Address::UnixTmpDir(_) => { - // you can't connect to a unix:dir - Err(Error::Unsupported) - } - } - } - - /// Get the address for session socket respecting the DBUS_SESSION_BUS_ADDRESS environment - /// variable. If we don't recognize the value (or it's not set) we fall back to - /// $XDG_RUNTIME_DIR/bus - pub fn session() -> Result { - match env::var("DBUS_SESSION_BUS_ADDRESS") { - Ok(val) => Self::from_str(&val), - _ => { - #[cfg(windows)] - { - #[cfg(feature = "windows-gdbus")] - return Self::from_str("autolaunch:"); - - #[cfg(not(feature = "windows-gdbus"))] - return Self::from_str("autolaunch:scope=*user"); - } - - #[cfg(all(unix, not(target_os = "macos")))] - { - let runtime_dir = env::var("XDG_RUNTIME_DIR") - .unwrap_or_else(|_| format!("/run/user/{}", Uid::effective())); - let path = format!("unix:path={runtime_dir}/bus"); - - Self::from_str(&path) - } - - #[cfg(target_os = "macos")] - return Self::from_str("launchd:env=DBUS_LAUNCHD_SESSION_BUS_SOCKET"); - } - } - } - - /// Get the address for system bus respecting the DBUS_SYSTEM_BUS_ADDRESS environment - /// variable. If we don't recognize the value (or it's not set) we fall back to - /// /var/run/dbus/system_bus_socket - pub fn system() -> Result { - match env::var("DBUS_SYSTEM_BUS_ADDRESS") { - Ok(val) => Self::from_str(&val), - _ => { - #[cfg(all(unix, not(target_os = "macos")))] - return Self::from_str("unix:path=/var/run/dbus/system_bus_socket"); - - #[cfg(windows)] - return Self::from_str("autolaunch:"); - - #[cfg(target_os = "macos")] - return Self::from_str("launchd:env=DBUS_LAUNCHD_SESSION_BUS_SOCKET"); - } - } - } - - // Helper for FromStr - #[cfg(any(unix, not(feature = "tokio")))] - fn from_unix(opts: HashMap<&str, &str>) -> Result { - let path = opts.get("path"); - let abs = opts.get("abstract"); - let dir = opts.get("dir"); - let tmpdir = opts.get("tmpdir"); - let addr = match (path, abs, dir, tmpdir) { - (Some(p), None, None, None) => Address::Unix(OsString::from(p)), - (None, Some(p), None, None) => { - let mut s = OsString::from("\0"); - s.push(p); - Address::Unix(s) - } - (None, None, Some(p), None) => Address::UnixDir(OsString::from(p)), - (None, None, None, Some(p)) => Address::UnixTmpDir(OsString::from(p)), - _ => { - return Err(Error::Address("unix: address is invalid".to_owned())); - } - }; - - Ok(addr) - } - - #[cfg(all(feature = "vsock", not(feature = "tokio")))] - fn from_vsock(opts: HashMap<&str, &str>) -> Result { - let cid = opts - .get("cid") - .ok_or_else(|| Error::Address("VSOCK address is missing cid=".into()))?; - let cid = cid - .parse::() - .map_err(|e| Error::Address(format!("Failed to parse VSOCK cid `{}`: {}", cid, e)))?; - let port = opts - .get("port") - .ok_or_else(|| Error::Address("VSOCK address is missing port=".into()))?; - let port = port - .parse::() - .map_err(|e| Error::Address(format!("Failed to parse VSOCK port `{}`: {}", port, e)))?; - - Ok(Address::Vsock(VsockAddress { cid, port })) - } -} - -impl FromStr for TcpAddressFamily { - type Err = Error; - - fn from_str(family: &str) -> Result { - match family { - "ipv4" => Ok(Self::Ipv4), - "ipv6" => Ok(Self::Ipv6), - _ => Err(Error::Address(format!( - "invalid tcp address `family`: {family}" - ))), - } - } -} - -impl Display for TcpAddressFamily { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Self::Ipv4 => write!(f, "ipv4"), - Self::Ipv6 => write!(f, "ipv6"), - } - } -} - -fn decode_hex(c: char) -> Result { - match c { - '0'..='9' => Ok(c as u8 - b'0'), - 'a'..='f' => Ok(c as u8 - b'a' + 10), - 'A'..='F' => Ok(c as u8 - b'A' + 10), - - _ => Err(Error::Address( - "invalid hexadecimal character in percent-encoded sequence".to_owned(), - )), - } -} - -fn decode_percents(value: &str) -> Result> { - let mut iter = value.chars(); - let mut decoded = Vec::new(); - - while let Some(c) = iter.next() { - if matches!(c, '-' | '0'..='9' | 'A'..='Z' | 'a'..='z' | '_' | '/' | '.' | '\\' | '*') { - decoded.push(c as u8) - } else if c == '%' { - decoded.push( - decode_hex(iter.next().ok_or_else(|| { - Error::Address("incomplete percent-encoded sequence".to_owned()) - })?)? - << 4 - | decode_hex(iter.next().ok_or_else(|| { - Error::Address("incomplete percent-encoded sequence".to_owned()) - })?)?, - ); - } else { - return Err(Error::Address("Invalid character in address".to_owned())); - } - } - - Ok(decoded) -} - -fn encode_percents(f: &mut Formatter<'_>, mut value: &[u8]) -> std::fmt::Result { - const LOOKUP: &str = "\ -%00%01%02%03%04%05%06%07%08%09%0a%0b%0c%0d%0e%0f\ -%10%11%12%13%14%15%16%17%18%19%1a%1b%1c%1d%1e%1f\ -%20%21%22%23%24%25%26%27%28%29%2a%2b%2c%2d%2e%2f\ -%30%31%32%33%34%35%36%37%38%39%3a%3b%3c%3d%3e%3f\ -%40%41%42%43%44%45%46%47%48%49%4a%4b%4c%4d%4e%4f\ -%50%51%52%53%54%55%56%57%58%59%5a%5b%5c%5d%5e%5f\ -%60%61%62%63%64%65%66%67%68%69%6a%6b%6c%6d%6e%6f\ -%70%71%72%73%74%75%76%77%78%79%7a%7b%7c%7d%7e%7f\ -%80%81%82%83%84%85%86%87%88%89%8a%8b%8c%8d%8e%8f\ -%90%91%92%93%94%95%96%97%98%99%9a%9b%9c%9d%9e%9f\ -%a0%a1%a2%a3%a4%a5%a6%a7%a8%a9%aa%ab%ac%ad%ae%af\ -%b0%b1%b2%b3%b4%b5%b6%b7%b8%b9%ba%bb%bc%bd%be%bf\ -%c0%c1%c2%c3%c4%c5%c6%c7%c8%c9%ca%cb%cc%cd%ce%cf\ -%d0%d1%d2%d3%d4%d5%d6%d7%d8%d9%da%db%dc%dd%de%df\ -%e0%e1%e2%e3%e4%e5%e6%e7%e8%e9%ea%eb%ec%ed%ee%ef\ -%f0%f1%f2%f3%f4%f5%f6%f7%f8%f9%fa%fb%fc%fd%fe%ff"; - - loop { - let pos = value.iter().position( - |c| !matches!(c, b'-' | b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z' | b'_' | b'/' | b'.' | b'\\' | b'*'), - ); - - if let Some(pos) = pos { - // SAFETY: The above `position()` call made sure that only ASCII chars are in the string - // up to `pos` - f.write_str(unsafe { from_utf8_unchecked(&value[..pos]) })?; - - let c = value[pos]; - value = &value[pos + 1..]; - - let pos = c as usize * 3; - f.write_str(&LOOKUP[pos..pos + 3])?; - } else { - // SAFETY: The above `position()` call made sure that only ASCII chars are in the rest - // of the string - f.write_str(unsafe { from_utf8_unchecked(value) })?; - return Ok(()); - } - } -} - -impl Display for Address { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - fn fmt_unix_path( - f: &mut Formatter<'_>, - path: &OsString, - _is_abstract: bool, - ) -> std::fmt::Result { - #[cfg(unix)] - { - use std::os::unix::ffi::OsStrExt; - - let bytes = if _is_abstract { - &path.as_bytes()[1..] - } else { - path.as_bytes() - }; - encode_percents(f, bytes)?; - } - - #[cfg(windows)] - write!(f, "{}", path.to_str().ok_or(std::fmt::Error)?)?; - - Ok(()) - } - - match self { - Self::Tcp(addr) => { - f.write_str("tcp:")?; - addr.write_options(f)?; - } - - Self::NonceTcp { addr, nonce_file } => { - f.write_str("nonce-tcp:noncefile=")?; - encode_percents(f, nonce_file)?; - f.write_str(",")?; - addr.write_options(f)?; - } - - Self::Unix(path) => { - let is_abstract = { - #[cfg(unix)] - { - use std::os::unix::ffi::OsStrExt; - - path.as_bytes().first() == Some(&b'\0') - } - #[cfg(not(unix))] - false - }; - - if is_abstract { - f.write_str("unix:abstract=")?; - } else { - f.write_str("unix:path=")?; - } - - fmt_unix_path(f, path, is_abstract)?; - } - - Self::UnixDir(path) => { - f.write_str("unix:dir=")?; - fmt_unix_path(f, path, false)?; - } - - Self::UnixTmpDir(path) => { - f.write_str("unix:tmpdir=")?; - fmt_unix_path(f, path, false)?; - } - - #[cfg(any( - all(feature = "vsock", not(feature = "tokio")), - feature = "tokio-vsock" - ))] - Self::Vsock(addr) => { - write!(f, "vsock:cid={},port={}", addr.cid, addr.port)?; - } - - Self::Autolaunch(scope) => { - write!(f, "autolaunch:")?; - if let Some(scope) = scope { - write!(f, "scope={scope}")?; - } - } - - Self::Launchd(env) => { - write!(f, "launchd:env={}", env)?; - } - } - - Ok(()) - } -} - -impl FromStr for Address { - type Err = Error; - - /// Parse a D-BUS address and return its path if we recognize it - fn from_str(address: &str) -> Result { - let col = address - .find(':') - .ok_or_else(|| Error::Address("address has no colon".to_owned()))?; - let transport = &address[..col]; - let mut options = HashMap::new(); - - if address.len() > col + 1 { - for kv in address[col + 1..].split(',') { - let (k, v) = match kv.find('=') { - Some(eq) => (&kv[..eq], &kv[eq + 1..]), - None => { - return Err(Error::Address( - "missing = when parsing key/value".to_owned(), - )) - } - }; - if options.insert(k, v).is_some() { - return Err(Error::Address(format!( - "Key `{k}` specified multiple times" - ))); - } - } - } - - match transport { - #[cfg(any(unix, not(feature = "tokio")))] - "unix" => Self::from_unix(options), - "tcp" => TcpAddress::from_tcp(options).map(Self::Tcp), - - "nonce-tcp" => Ok(Self::NonceTcp { - nonce_file: decode_percents( - options - .get("noncefile") - .ok_or_else(|| Error::Address("missing nonce file parameter".into()))?, - )?, - addr: TcpAddress::from_tcp(options)?, - }), - #[cfg(all(feature = "vsock", not(feature = "tokio")))] - "vsock" => Self::from_vsock(options), - "autolaunch" => Ok(Self::Autolaunch( - options - .get("scope") - .map(|scope| -> Result<_> { - String::from_utf8(decode_percents(scope)?).map_err(|_| { - Error::Address("autolaunch scope is not valid UTF-8".to_owned()) - }) - }) - .transpose()?, - )), - "launchd" => Ok(Self::Launchd( - options - .get("env") - .ok_or_else(|| Error::Address("missing env key".into()))? - .to_string(), - )), - - _ => Err(Error::Address(format!( - "unsupported transport '{transport}'" - ))), - } - } -} - -impl TryFrom<&str> for Address { - type Error = Error; - - fn try_from(value: &str) -> Result { - Self::from_str(value) - } -} - -#[cfg(test)] -mod tests { - use super::{Address, TcpAddress, TcpAddressFamily}; - use crate::Error; - use std::str::FromStr; - use test_log::test; - - #[test] - fn parse_dbus_addresses() { - match Address::from_str("").unwrap_err() { - Error::Address(e) => assert_eq!(e, "address has no colon"), - _ => panic!(), - } - match Address::from_str("foo").unwrap_err() { - Error::Address(e) => assert_eq!(e, "address has no colon"), - _ => panic!(), - } - match Address::from_str("foo:opt").unwrap_err() { - Error::Address(e) => assert_eq!(e, "missing = when parsing key/value"), - _ => panic!(), - } - match Address::from_str("foo:opt=1,opt=2").unwrap_err() { - Error::Address(e) => assert_eq!(e, "Key `opt` specified multiple times"), - _ => panic!(), - } - match Address::from_str("tcp:host=localhost").unwrap_err() { - Error::Address(e) => assert_eq!(e, "tcp address is missing `port`"), - _ => panic!(), - } - match Address::from_str("tcp:host=localhost,port=32f").unwrap_err() { - Error::Address(e) => assert_eq!(e, "invalid tcp `port`"), - _ => panic!(), - } - match Address::from_str("tcp:host=localhost,port=123,family=ipv7").unwrap_err() { - Error::Address(e) => assert_eq!(e, "invalid tcp address `family`: ipv7"), - _ => panic!(), - } - match Address::from_str("unix:foo=blah").unwrap_err() { - Error::Address(e) => assert_eq!(e, "unix: address is invalid"), - _ => panic!(), - } - match Address::from_str("unix:path=/tmp,abstract=foo").unwrap_err() { - Error::Address(e) => { - assert_eq!(e, "unix: address is invalid") - } - _ => panic!(), - } - assert_eq!( - Address::Unix("/tmp/dbus-foo".into()), - Address::from_str("unix:path=/tmp/dbus-foo").unwrap() - ); - assert_eq!( - Address::Unix("\0/tmp/dbus-foo".into()), - Address::from_str("unix:abstract=/tmp/dbus-foo").unwrap() - ); - assert_eq!( - Address::Unix("/tmp/dbus-foo".into()), - Address::from_str("unix:path=/tmp/dbus-foo,guid=123").unwrap() - ); - assert_eq!( - Address::Tcp(TcpAddress { - host: "localhost".into(), - port: 4142, - bind: None, - family: None - }), - Address::from_str("tcp:host=localhost,port=4142").unwrap() - ); - assert_eq!( - Address::Tcp(TcpAddress { - host: "localhost".into(), - port: 4142, - bind: None, - family: Some(TcpAddressFamily::Ipv4) - }), - Address::from_str("tcp:host=localhost,port=4142,family=ipv4").unwrap() - ); - assert_eq!( - Address::Tcp(TcpAddress { - host: "localhost".into(), - port: 4142, - bind: None, - family: Some(TcpAddressFamily::Ipv6) - }), - Address::from_str("tcp:host=localhost,port=4142,family=ipv6").unwrap() - ); - assert_eq!( - Address::Tcp(TcpAddress { - host: "localhost".into(), - port: 4142, - bind: None, - family: Some(TcpAddressFamily::Ipv6) - }), - Address::from_str("tcp:host=localhost,port=4142,family=ipv6,noncefile=/a/file/path") - .unwrap() - ); - assert_eq!( - Address::NonceTcp { - addr: TcpAddress { - host: "localhost".into(), - port: 4142, - bind: None, - family: Some(TcpAddressFamily::Ipv6), - }, - nonce_file: b"/a/file/path to file 1234".to_vec() - }, - Address::from_str( - "nonce-tcp:host=localhost,port=4142,family=ipv6,noncefile=/a/file/path%20to%20file%201234" - ) - .unwrap() - ); - assert_eq!( - Address::Autolaunch(None), - Address::from_str("autolaunch:").unwrap() - ); - assert_eq!( - Address::Autolaunch(Some("*my_cool_scope*".to_owned())), - Address::from_str("autolaunch:scope=*my_cool_scope*").unwrap() - ); - assert_eq!( - Address::Launchd("my_cool_env_key".to_owned()), - Address::from_str("launchd:env=my_cool_env_key").unwrap() - ); - - #[cfg(all(feature = "vsock", not(feature = "tokio")))] - assert_eq!( - Address::Vsock(crate::VsockAddress { - cid: 98, - port: 2934 - }), - Address::from_str("vsock:cid=98,port=2934,guid=123").unwrap() - ); - assert_eq!( - Address::UnixDir("/some/dir".into()), - Address::from_str("unix:dir=/some/dir").unwrap() - ); - assert_eq!( - Address::UnixTmpDir("/some/dir".into()), - Address::from_str("unix:tmpdir=/some/dir").unwrap() - ); - } - - #[test] - fn stringify_dbus_addresses() { - assert_eq!( - Address::Unix("/tmp/dbus-foo".into()).to_string(), - "unix:path=/tmp/dbus-foo" - ); - assert_eq!( - Address::UnixDir("/tmp/dbus-foo".into()).to_string(), - "unix:dir=/tmp/dbus-foo" - ); - assert_eq!( - Address::UnixTmpDir("/tmp/dbus-foo".into()).to_string(), - "unix:tmpdir=/tmp/dbus-foo" - ); - // FIXME: figure out how to handle abstract on Windows - #[cfg(unix)] - assert_eq!( - Address::Unix("\0/tmp/dbus-foo".into()).to_string(), - "unix:abstract=/tmp/dbus-foo" - ); - assert_eq!( - Address::Tcp(TcpAddress { - host: "localhost".into(), - port: 4142, - bind: None, - family: None - }) - .to_string(), - "tcp:host=localhost,port=4142" - ); - assert_eq!( - Address::Tcp(TcpAddress { - host: "localhost".into(), - port: 4142, - bind: None, - family: Some(TcpAddressFamily::Ipv4) - }) - .to_string(), - "tcp:host=localhost,port=4142,family=ipv4" - ); - assert_eq!( - Address::Tcp(TcpAddress { - host: "localhost".into(), - port: 4142, - bind: None, - family: Some(TcpAddressFamily::Ipv6) - }) - .to_string(), - "tcp:host=localhost,port=4142,family=ipv6" - ); - assert_eq!( - Address::NonceTcp { - addr: TcpAddress { - host: "localhost".into(), - port: 4142, - bind: None, - family: Some(TcpAddressFamily::Ipv6), - }, - nonce_file: b"/a/file/path to file 1234".to_vec() - } - .to_string(), - "nonce-tcp:noncefile=/a/file/path%20to%20file%201234,host=localhost,port=4142,family=ipv6" - ); - assert_eq!(Address::Autolaunch(None).to_string(), "autolaunch:"); - assert_eq!( - Address::Autolaunch(Some("*my_cool_scope*".to_owned())).to_string(), - "autolaunch:scope=*my_cool_scope*" - ); - assert_eq!( - Address::Launchd("my_cool_key".to_owned()).to_string(), - "launchd:env=my_cool_key" - ); - - #[cfg(all(feature = "vsock", not(feature = "tokio")))] - assert_eq!( - Address::Vsock(crate::VsockAddress { - cid: 98, - port: 2934 - }) - .to_string(), - "vsock:cid=98,port=2934", // no support for guid= yet.. - ); - } - - #[test] - fn connect_tcp() { - let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - let port = listener.local_addr().unwrap().port(); - let addr = Address::from_str(&format!("tcp:host=localhost,port={port}")).unwrap(); - crate::utils::block_on(async { addr.connect().await }).unwrap(); - } - - #[test] - fn connect_nonce_tcp() { - struct PercentEncoded<'a>(&'a [u8]); - - impl std::fmt::Display for PercentEncoded<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - super::encode_percents(f, self.0) - } - } - - use std::io::Write; - - const TEST_COOKIE: &[u8] = b"VERILY SECRETIVE"; - - let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - let port = listener.local_addr().unwrap().port(); - - let mut cookie = tempfile::NamedTempFile::new().unwrap(); - cookie.as_file_mut().write_all(TEST_COOKIE).unwrap(); - - let encoded_path = format!( - "{}", - PercentEncoded(cookie.path().to_str().unwrap().as_ref()) - ); - - let addr = Address::from_str(&format!( - "nonce-tcp:host=localhost,port={port},noncefile={encoded_path}" - )) - .unwrap(); - - let (sender, receiver) = std::sync::mpsc::sync_channel(1); - - std::thread::spawn(move || { - use std::io::Read; - - let mut client = listener.incoming().next().unwrap().unwrap(); - - let mut buf = [0u8; 16]; - client.read_exact(&mut buf).unwrap(); - - sender.send(buf == TEST_COOKIE).unwrap(); - }); - - crate::utils::block_on(addr.connect()).unwrap(); - - let saw_cookie = receiver - .recv_timeout(std::time::Duration::from_millis(100)) - .expect("nonce file content hasn't been received by server thread in time"); - - assert!( - saw_cookie, - "nonce file content has been received, but was invalid" - ); - } -} diff --git a/zbus/src/address/mod.rs b/zbus/src/address/mod.rs new file mode 100644 index 000000000..27175a910 --- /dev/null +++ b/zbus/src/address/mod.rs @@ -0,0 +1,483 @@ +//! D-Bus address handling. +//! +//! Server addresses consist of a transport name followed by a colon, and then an optional, +//! comma-separated list of keys and values in the form key=value. +//! +//! See also: +//! +//! * [Server addresses] in the D-Bus specification. +//! +//! [Server addresses]: https://dbus.freedesktop.org/doc/dbus-specification.html#addresses + +pub mod transport; + +use crate::{Error, Guid, OwnedGuid, Result}; +#[cfg(all(unix, not(target_os = "macos")))] +use nix::unistd::Uid; +use std::{collections::HashMap, env, str::FromStr}; + +use std::fmt::{Display, Formatter}; + +use self::transport::Stream; +pub use self::transport::Transport; + +/// A bus address +#[derive(Clone, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub struct Address { + guid: Option, + transport: Transport, +} + +impl Address { + /// Create a new `Address` from a `Transport`. + pub fn new(transport: Transport) -> Self { + Self { + transport, + guid: None, + } + } + + /// Set the GUID for this address. + pub fn set_guid(mut self, guid: G) -> Result + where + G: TryInto, + G::Error: Into, + { + self.guid = Some(guid.try_into().map_err(Into::into)?); + + Ok(self) + } + + /// The transport details for this address. + pub fn transport(&self) -> &Transport { + &self.transport + } + + #[cfg_attr(any(target_os = "macos", windows), async_recursion::async_recursion)] + pub(crate) async fn connect(self) -> Result { + self.transport.connect().await + } + + /// Get the address for session socket respecting the DBUS_SESSION_BUS_ADDRESS environment + /// variable. If we don't recognize the value (or it's not set) we fall back to + /// $XDG_RUNTIME_DIR/bus + pub fn session() -> Result { + match env::var("DBUS_SESSION_BUS_ADDRESS") { + Ok(val) => Self::from_str(&val), + _ => { + #[cfg(windows)] + { + #[cfg(feature = "windows-gdbus")] + return Self::from_str("autolaunch:"); + + #[cfg(not(feature = "windows-gdbus"))] + return Self::from_str("autolaunch:scope=*user"); + } + + #[cfg(all(unix, not(target_os = "macos")))] + { + let runtime_dir = env::var("XDG_RUNTIME_DIR") + .unwrap_or_else(|_| format!("/run/user/{}", Uid::effective())); + let path = format!("unix:path={runtime_dir}/bus"); + + Self::from_str(&path) + } + + #[cfg(target_os = "macos")] + return Self::from_str("launchd:env=DBUS_LAUNCHD_SESSION_BUS_SOCKET"); + } + } + } + + /// Get the address for system bus respecting the DBUS_SYSTEM_BUS_ADDRESS environment + /// variable. If we don't recognize the value (or it's not set) we fall back to + /// /var/run/dbus/system_bus_socket + pub fn system() -> Result { + match env::var("DBUS_SYSTEM_BUS_ADDRESS") { + Ok(val) => Self::from_str(&val), + _ => { + #[cfg(all(unix, not(target_os = "macos")))] + return Self::from_str("unix:path=/var/run/dbus/system_bus_socket"); + + #[cfg(windows)] + return Self::from_str("autolaunch:"); + + #[cfg(target_os = "macos")] + return Self::from_str("launchd:env=DBUS_LAUNCHD_SESSION_BUS_SOCKET"); + } + } + } + + /// The GUID for this address, if known. + pub fn guid(&self) -> Option<&Guid<'_>> { + self.guid.as_ref().map(|guid| guid.inner()) + } +} + +impl Display for Address { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.transport.fmt(f)?; + + if let Some(guid) = &self.guid { + write!(f, ",guid={}", guid)?; + } + + Ok(()) + } +} + +impl FromStr for Address { + type Err = Error; + + /// Parse the transport part of a D-Bus address into a `Transport`. + fn from_str(address: &str) -> Result { + let col = address + .find(':') + .ok_or_else(|| Error::Address("address has no colon".to_owned()))?; + let transport = &address[..col]; + let mut options = HashMap::new(); + + if address.len() > col + 1 { + for kv in address[col + 1..].split(',') { + let (k, v) = match kv.find('=') { + Some(eq) => (&kv[..eq], &kv[eq + 1..]), + None => { + return Err(Error::Address( + "missing = when parsing key/value".to_owned(), + )) + } + }; + if options.insert(k, v).is_some() { + return Err(Error::Address(format!( + "Key `{k}` specified multiple times" + ))); + } + } + } + + Ok(Self { + guid: options + .remove("guid") + .map(|s| Guid::from_str(s).map(|guid| OwnedGuid::from(guid).to_owned())) + .transpose()?, + transport: Transport::from_options(transport, options)?, + }) + } +} + +impl TryFrom<&str> for Address { + type Error = Error; + + fn try_from(value: &str) -> Result { + Self::from_str(value) + } +} + +impl From for Address { + fn from(transport: Transport) -> Self { + Self::new(transport) + } +} + +#[cfg(test)] +mod tests { + use super::{ + transport::{Tcp, TcpTransportFamily, Transport}, + Address, + }; + #[cfg(target_os = "macos")] + use crate::address::transport::Launchd; + #[cfg(windows)] + use crate::address::transport::{Autolaunch, AutolaunchScope}; + use crate::{ + address::transport::{Unix, UnixPath}, + Error, + }; + use std::str::FromStr; + use test_log::test; + + #[test] + fn parse_dbus_addresses() { + match Address::from_str("").unwrap_err() { + Error::Address(e) => assert_eq!(e, "address has no colon"), + _ => panic!(), + } + match Address::from_str("foo").unwrap_err() { + Error::Address(e) => assert_eq!(e, "address has no colon"), + _ => panic!(), + } + match Address::from_str("foo:opt").unwrap_err() { + Error::Address(e) => assert_eq!(e, "missing = when parsing key/value"), + _ => panic!(), + } + match Address::from_str("foo:opt=1,opt=2").unwrap_err() { + Error::Address(e) => assert_eq!(e, "Key `opt` specified multiple times"), + _ => panic!(), + } + match Address::from_str("tcp:host=localhost").unwrap_err() { + Error::Address(e) => assert_eq!(e, "tcp address is missing `port`"), + _ => panic!(), + } + match Address::from_str("tcp:host=localhost,port=32f").unwrap_err() { + Error::Address(e) => assert_eq!(e, "invalid tcp `port`"), + _ => panic!(), + } + match Address::from_str("tcp:host=localhost,port=123,family=ipv7").unwrap_err() { + Error::Address(e) => assert_eq!(e, "invalid tcp address `family`: ipv7"), + _ => panic!(), + } + match Address::from_str("unix:foo=blah").unwrap_err() { + Error::Address(e) => assert_eq!(e, "unix: address is invalid"), + _ => panic!(), + } + #[cfg(target_os = "linux")] + match Address::from_str("unix:path=/tmp,abstract=foo").unwrap_err() { + Error::Address(e) => { + assert_eq!(e, "unix: address is invalid") + } + _ => panic!(), + } + assert_eq!( + Address::from_str("unix:path=/tmp/dbus-foo").unwrap(), + Transport::Unix(Unix::new(UnixPath::File("/tmp/dbus-foo".into()))).into(), + ); + #[cfg(target_os = "linux")] + assert_eq!( + Address::from_str("unix:abstract=/tmp/dbus-foo").unwrap(), + Transport::Unix(Unix::new(UnixPath::Abstract("/tmp/dbus-foo".into()))).into(), + ); + let guid = crate::Guid::generate(); + assert_eq!( + Address::from_str(&format!("unix:path=/tmp/dbus-foo,guid={guid}")).unwrap(), + Address::from(Transport::Unix(Unix::new(UnixPath::File( + "/tmp/dbus-foo".into() + )))) + .set_guid(guid.clone()) + .unwrap(), + ); + assert_eq!( + Address::from_str("tcp:host=localhost,port=4142").unwrap(), + Transport::Tcp(Tcp::new("localhost", 4142)).into(), + ); + assert_eq!( + Address::from_str("tcp:host=localhost,port=4142,family=ipv4").unwrap(), + Transport::Tcp(Tcp::new("localhost", 4142).set_family(Some(TcpTransportFamily::Ipv4))) + .into(), + ); + assert_eq!( + Address::from_str("tcp:host=localhost,port=4142,family=ipv6").unwrap(), + Transport::Tcp(Tcp::new("localhost", 4142).set_family(Some(TcpTransportFamily::Ipv6))) + .into(), + ); + assert_eq!( + Address::from_str("tcp:host=localhost,port=4142,family=ipv6,noncefile=/a/file/path") + .unwrap(), + Transport::Tcp( + Tcp::new("localhost", 4142) + .set_family(Some(TcpTransportFamily::Ipv6)) + .set_nonce_file(Some(b"/a/file/path".to_vec())) + ) + .into(), + ); + assert_eq!( + Address::from_str( + "nonce-tcp:host=localhost,port=4142,family=ipv6,noncefile=/a/file/path%20to%20file%201234" + ) + .unwrap(), + Transport::Tcp( + Tcp::new("localhost", 4142) + .set_family(Some(TcpTransportFamily::Ipv6)) + .set_nonce_file(Some(b"/a/file/path to file 1234".to_vec())) + ).into() + ); + #[cfg(windows)] + assert_eq!( + Address::from_str("autolaunch:").unwrap(), + Transport::Autolaunch(Autolaunch::new()).into(), + ); + #[cfg(windows)] + assert_eq!( + Address::from_str("autolaunch:scope=*my_cool_scope*").unwrap(), + Transport::Autolaunch( + Autolaunch::new() + .set_scope(Some(AutolaunchScope::Other("*my_cool_scope*".to_string()))) + ) + .into(), + ); + #[cfg(target_os = "macos")] + assert_eq!( + Address::from_str("launchd:env=my_cool_env_key").unwrap(), + Transport::Launchd(Launchd::new("my_cool_env_key")).into(), + ); + + #[cfg(all(feature = "vsock", not(feature = "tokio")))] + assert_eq!( + Address::from_str(&format!("vsock:cid=98,port=2934,guid={guid}")).unwrap(), + Address::from(Transport::Vsock(super::transport::Vsock::new(98, 2934))) + .set_guid(guid) + .unwrap(), + ); + assert_eq!( + Address::from_str("unix:dir=/some/dir").unwrap(), + Transport::Unix(Unix::new(UnixPath::Dir("/some/dir".into()))).into(), + ); + assert_eq!( + Address::from_str("unix:tmpdir=/some/dir").unwrap(), + Transport::Unix(Unix::new(UnixPath::TmpDir("/some/dir".into()))).into(), + ); + } + + #[test] + fn stringify_dbus_addresses() { + assert_eq!( + Address::from(Transport::Unix(Unix::new(UnixPath::File( + "/tmp/dbus-foo".into() + )))) + .to_string(), + "unix:path=/tmp/dbus-foo", + ); + assert_eq!( + Address::from(Transport::Unix(Unix::new(UnixPath::Dir( + "/tmp/dbus-foo".into() + )))) + .to_string(), + "unix:dir=/tmp/dbus-foo", + ); + assert_eq!( + Address::from(Transport::Unix(Unix::new(UnixPath::TmpDir( + "/tmp/dbus-foo".into() + )))) + .to_string(), + "unix:tmpdir=/tmp/dbus-foo" + ); + // FIXME: figure out how to handle abstract on Windows + #[cfg(target_os = "linux")] + assert_eq!( + Address::from(Transport::Unix(Unix::new(UnixPath::Abstract( + "/tmp/dbus-foo".into() + )))) + .to_string(), + "unix:abstract=/tmp/dbus-foo" + ); + assert_eq!( + Address::from(Transport::Tcp(Tcp::new("localhost", 4142))).to_string(), + "tcp:host=localhost,port=4142" + ); + assert_eq!( + Address::from(Transport::Tcp( + Tcp::new("localhost", 4142).set_family(Some(TcpTransportFamily::Ipv4)) + )) + .to_string(), + "tcp:host=localhost,port=4142,family=ipv4" + ); + assert_eq!( + Address::from(Transport::Tcp( + Tcp::new("localhost", 4142).set_family(Some(TcpTransportFamily::Ipv6)) + )) + .to_string(), + "tcp:host=localhost,port=4142,family=ipv6" + ); + assert_eq!( + Address::from(Transport::Tcp(Tcp::new("localhost", 4142) + .set_family(Some(TcpTransportFamily::Ipv6)) + .set_nonce_file(Some(b"/a/file/path to file 1234".to_vec()) + ))) + .to_string(), + "nonce-tcp:noncefile=/a/file/path%20to%20file%201234,host=localhost,port=4142,family=ipv6" + ); + #[cfg(windows)] + assert_eq!( + Address::from(Transport::Autolaunch(Autolaunch::new())).to_string(), + "autolaunch:" + ); + #[cfg(windows)] + assert_eq!( + Address::from(Transport::Autolaunch(Autolaunch::new().set_scope(Some( + AutolaunchScope::Other("*my_cool_scope*".to_string()) + )))) + .to_string(), + "autolaunch:scope=*my_cool_scope*" + ); + #[cfg(target_os = "macos")] + assert_eq!( + Address::from(Transport::Launchd(Launchd::new("my_cool_key"))).to_string(), + "launchd:env=my_cool_key" + ); + + #[cfg(all(feature = "vsock", not(feature = "tokio")))] + { + let guid = crate::Guid::generate(); + assert_eq!( + Address::from(Transport::Vsock(super::transport::Vsock::new(98, 2934))) + .set_guid(guid.clone()) + .unwrap() + .to_string(), + format!("vsock:cid=98,port=2934,guid={guid}"), + ); + } + } + + #[test] + fn connect_tcp() { + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + let addr = Address::from_str(&format!("tcp:host=localhost,port={port}")).unwrap(); + crate::utils::block_on(async { addr.connect().await }).unwrap(); + } + + #[test] + fn connect_nonce_tcp() { + struct PercentEncoded<'a>(&'a [u8]); + + impl std::fmt::Display for PercentEncoded<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + super::transport::encode_percents(f, self.0) + } + } + + use std::io::Write; + + const TEST_COOKIE: &[u8] = b"VERILY SECRETIVE"; + + let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + + let mut cookie = tempfile::NamedTempFile::new().unwrap(); + cookie.as_file_mut().write_all(TEST_COOKIE).unwrap(); + + let encoded_path = format!( + "{}", + PercentEncoded(cookie.path().to_str().unwrap().as_ref()) + ); + + let addr = Address::from_str(&format!( + "nonce-tcp:host=localhost,port={port},noncefile={encoded_path}" + )) + .unwrap(); + + let (sender, receiver) = std::sync::mpsc::sync_channel(1); + + std::thread::spawn(move || { + use std::io::Read; + + let mut client = listener.incoming().next().unwrap().unwrap(); + + let mut buf = [0u8; 16]; + client.read_exact(&mut buf).unwrap(); + + sender.send(buf == TEST_COOKIE).unwrap(); + }); + + crate::utils::block_on(addr.connect()).unwrap(); + + let saw_cookie = receiver + .recv_timeout(std::time::Duration::from_millis(100)) + .expect("nonce file content hasn't been received by server thread in time"); + + assert!( + saw_cookie, + "nonce file content has been received, but was invalid" + ); + } +} diff --git a/zbus/src/address/transport/autolaunch.rs b/zbus/src/address/transport/autolaunch.rs new file mode 100644 index 000000000..d04245995 --- /dev/null +++ b/zbus/src/address/transport/autolaunch.rs @@ -0,0 +1,66 @@ +use crate::{Error, Result}; +use std::collections::HashMap; + +/// Transport properties of an autolaunch D-Bus address. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Autolaunch { + pub(super) scope: Option, +} + +impl Autolaunch { + /// Create a new autolaunch transport. + pub fn new() -> Self { + Self { scope: None } + } + + /// Set the `autolaunch:` address `scope` value. + pub fn set_scope(mut self, scope: Option) -> Self { + self.scope = scope; + + self + } + + /// The optional scope. + pub fn scope(&self) -> Option<&AutolaunchScope> { + self.scope.as_ref() + } + + pub(super) fn from_options(opts: HashMap<&str, &str>) -> Result { + opts.get("scope") + .map(|scope| -> Result<_> { + let decoded = super::decode_percents(scope)?; + match decoded.as_slice() { + b"install-path" => Ok(AutolaunchScope::InstallPath), + b"user" => Ok(AutolaunchScope::User), + _ => String::from_utf8(decoded) + .map(AutolaunchScope::Other) + .map_err(|_| { + Error::Address("autolaunch scope is not valid UTF-8".to_owned()) + }), + } + }) + .transpose() + .map(|scope| Self { scope }) + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum AutolaunchScope { + /// Limit session bus to dbus installation path. + InstallPath, + /// Limit session bus to the recent user. + User, + /// other values - specify dedicated session bus like "release", "debug" or other. + Other(String), +} + +impl std::fmt::Display for AutolaunchScope { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::InstallPath => write!(f, "*install-path"), + Self::User => write!(f, "*user"), + Self::Other(o) => write!(f, "{o}"), + } + } +} diff --git a/zbus/src/address/transport/launchd.rs b/zbus/src/address/transport/launchd.rs new file mode 100644 index 000000000..7475d7511 --- /dev/null +++ b/zbus/src/address/transport/launchd.rs @@ -0,0 +1,54 @@ +use super::{Transport, Unix, UnixPath}; +use crate::{process::run, Result}; +use std::collections::HashMap; + +#[derive(Clone, Debug, PartialEq, Eq)] +#[non_exhaustive] +/// The transport properties of a launchd D-Bus address. +pub struct Launchd { + pub(super) env: String, +} + +impl Launchd { + /// Create a new launchd D-Bus address. + pub fn new(env: &str) -> Self { + Self { + env: env.to_string(), + } + } + + /// The path of the unix domain socket for the launchd created dbus-daemon. + pub fn env(&self) -> &str { + &self.env + } + + /// Determine the actual transport details behin a launchd address. + pub(super) async fn bus_address(&self) -> Result { + let output = run("launchctl", ["getenv", self.env()]) + .await + .expect("failed to wait on launchctl output"); + + if !output.status.success() { + return Err(crate::Error::Address(format!( + "launchctl terminated with code: {}", + output.status + ))); + } + + let addr = String::from_utf8(output.stdout).map_err(|e| { + crate::Error::Address(format!("Unable to parse launchctl output as UTF-8: {}", e)) + })?; + + Ok(Transport::Unix(Unix::new(UnixPath::File( + addr.trim().into(), + )))) + } + + pub(super) fn from_options(opts: HashMap<&str, &str>) -> Result { + opts.get("env") + .ok_or_else(|| crate::Error::Address("missing env key".into())) + .map(|env| Self { + env: env.to_string(), + }) + } +} diff --git a/zbus/src/address/transport/mod.rs b/zbus/src/address/transport/mod.rs new file mode 100644 index 000000000..986dd22e2 --- /dev/null +++ b/zbus/src/address/transport/mod.rs @@ -0,0 +1,417 @@ +//! D-Bus transport Information module. +//! +//! This module provides the trasport information for D-Bus addresses. + +#[cfg(windows)] +use crate::win32::windows_autolaunch_bus_address; +use crate::{Error, Result}; +#[cfg(not(feature = "tokio"))] +use async_io::Async; +#[cfg(not(feature = "tokio"))] +use std::net::TcpStream; +#[cfg(unix)] +use std::os::unix::net::{SocketAddr, UnixStream}; +use std::{collections::HashMap, ffi::OsStr}; +#[cfg(feature = "tokio")] +use tokio::net::TcpStream; +#[cfg(feature = "tokio-vsock")] +use tokio_vsock::VsockStream; +#[cfg(windows)] +use uds_windows::UnixStream; +#[cfg(all(feature = "vsock", not(feature = "tokio")))] +use vsock::VsockStream; + +use std::{ + fmt::{Display, Formatter}, + str::from_utf8_unchecked, +}; + +mod unix; +pub use unix::{Unix, UnixPath}; +mod tcp; +pub use tcp::{Tcp, TcpTransportFamily}; +#[cfg(windows)] +mod autolaunch; +#[cfg(windows)] +pub use autolaunch::{Autolaunch, AutolaunchScope}; +#[cfg(target_os = "macos")] +mod launchd; +#[cfg(target_os = "macos")] +pub use launchd::Launchd; +#[cfg(any( + all(feature = "vsock", not(feature = "tokio")), + feature = "tokio-vsock" +))] +#[path = "vsock.rs"] +// Gotta rename to avoid name conflict with the `vsock` crate. +mod vsock_transport; +#[cfg(target_os = "linux")] +use std::os::linux::net::SocketAddrExt; +#[cfg(any( + all(feature = "vsock", not(feature = "tokio")), + feature = "tokio-vsock" +))] +pub use vsock_transport::Vsock; + +/// The transport properties of a D-Bus address. +#[derive(Clone, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum Transport { + /// A Unix Domain Socket address. + Unix(Unix), + /// TCP address details + Tcp(Tcp), + /// autolaunch D-Bus address. + #[cfg(windows)] + Autolaunch(Autolaunch), + /// launchd D-Bus address. + #[cfg(target_os = "macos")] + Launchd(Launchd), + #[cfg(any( + all(feature = "vsock", not(feature = "tokio")), + feature = "tokio-vsock" + ))] + /// VSOCK address + /// + /// This variant is only available when either `vsock` or `tokio-vsock` feature is enabled. The + /// type of `stream` is `vsock::VsockStream` with `vsock` feature and + /// `tokio_vsock::VsockStream` with `tokio-vsock` feature. + Vsock(Vsock), +} + +impl Transport { + #[cfg_attr(any(target_os = "macos", windows), async_recursion::async_recursion)] + pub(super) async fn connect(self) -> Result { + match self { + Transport::Unix(unix) => { + // This is a `path` in case of Windows until uds_windows provides the needed API: + // https://github.com/haraldh/rust_uds_windows/issues/14 + let addr = match unix.take_path() { + #[cfg(unix)] + UnixPath::File(path) => SocketAddr::from_pathname(path)?, + #[cfg(windows)] + UnixPath::File(path) => path, + #[cfg(target_os = "linux")] + UnixPath::Abstract(name) => SocketAddr::from_abstract_name(name)?, + UnixPath::Dir(_) | UnixPath::TmpDir(_) => { + // you can't connect to a unix:dir + return Err(Error::Unsupported); + } + }; + let stream = crate::Task::spawn_blocking( + move || -> Result<_> { + #[cfg(unix)] + let stream = UnixStream::connect_addr(&addr)?; + #[cfg(windows)] + let stream = UnixStream::connect(addr)?; + stream.set_nonblocking(true)?; + + Ok(stream) + }, + "unix stream connection", + ) + .await?; + #[cfg(not(feature = "tokio"))] + { + Async::new(stream) + .map(Stream::Unix) + .map_err(|e| Error::InputOutput(e.into())) + } + + #[cfg(feature = "tokio")] + { + #[cfg(unix)] + { + tokio::net::UnixStream::from_std(stream) + .map(Stream::Unix) + .map_err(|e| Error::InputOutput(e.into())) + } + + #[cfg(not(unix))] + { + let _ = path; + Err(Error::Unsupported) + } + } + } + #[cfg(all(feature = "vsock", not(feature = "tokio")))] + Transport::Vsock(addr) => { + let stream = VsockStream::connect_with_cid_port(addr.cid(), addr.port())?; + Async::new(stream).map(Stream::Vsock).map_err(Into::into) + } + + #[cfg(feature = "tokio-vsock")] + Transport::Vsock(addr) => VsockStream::connect(addr.cid(), addr.port()) + .await + .map(Stream::Vsock) + .map_err(Into::into), + + Transport::Tcp(mut addr) => match addr.take_nonce_file() { + Some(nonce_file) => { + #[allow(unused_mut)] + let mut stream = addr.connect().await?; + + #[cfg(unix)] + let nonce_file = { + use std::os::unix::ffi::OsStrExt; + std::ffi::OsStr::from_bytes(&nonce_file) + }; + + #[cfg(windows)] + let nonce_file = std::str::from_utf8(&nonce_file).map_err(|_| { + Error::Address("nonce file path is invalid UTF-8".to_owned()) + })?; + + #[cfg(not(feature = "tokio"))] + { + let nonce = std::fs::read(nonce_file)?; + let mut nonce = &nonce[..]; + + while !nonce.is_empty() { + let len = stream + .write_with(|mut s| std::io::Write::write(&mut s, nonce)) + .await?; + nonce = &nonce[len..]; + } + } + + #[cfg(feature = "tokio")] + { + let nonce = tokio::fs::read(nonce_file).await?; + tokio::io::AsyncWriteExt::write_all(&mut stream, &nonce).await?; + } + + Ok(Stream::Tcp(stream)) + } + None => addr.connect().await.map(Stream::Tcp), + }, + + #[cfg(windows)] + Transport::Autolaunch(Autolaunch { scope }) => match scope { + Some(_) => Err(Error::Address( + "Autolaunch scopes are currently unsupported".to_owned(), + )), + None => { + let addr = windows_autolaunch_bus_address()?; + addr.connect().await + } + }, + + #[cfg(target_os = "macos")] + Transport::Launchd(launchd) => { + let addr = launchd.bus_address().await?; + addr.connect().await + } + } + } + + // Helper for `FromStr` impl of `Address`. + pub(super) fn from_options(transport: &str, options: HashMap<&str, &str>) -> Result { + match transport { + #[cfg(any(unix, not(feature = "tokio")))] + "unix" => Unix::from_options(options).map(Self::Unix), + "tcp" => Tcp::from_options(options, false).map(Self::Tcp), + "nonce-tcp" => Tcp::from_options(options, true).map(Self::Tcp), + #[cfg(any( + all(feature = "vsock", not(feature = "tokio")), + feature = "tokio-vsock" + ))] + "vsock" => Vsock::from_options(options).map(Self::Vsock), + #[cfg(windows)] + "autolaunch" => Autolaunch::from_options(options).map(Self::Autolaunch), + #[cfg(target_os = "macos")] + "launchd" => Launchd::from_options(options).map(Self::Launchd), + + _ => Err(Error::Address(format!( + "unsupported transport '{transport}'" + ))), + } + } +} + +#[cfg(not(feature = "tokio"))] +#[derive(Debug)] +pub(crate) enum Stream { + Unix(Async), + Tcp(Async), + #[cfg(feature = "vsock")] + Vsock(Async), +} + +#[cfg(feature = "tokio")] +#[derive(Debug)] +pub(crate) enum Stream { + #[cfg(unix)] + Unix(tokio::net::UnixStream), + Tcp(TcpStream), + #[cfg(feature = "tokio-vsock")] + Vsock(VsockStream), +} + +fn decode_hex(c: char) -> Result { + match c { + '0'..='9' => Ok(c as u8 - b'0'), + 'a'..='f' => Ok(c as u8 - b'a' + 10), + 'A'..='F' => Ok(c as u8 - b'A' + 10), + + _ => Err(Error::Address( + "invalid hexadecimal character in percent-encoded sequence".to_owned(), + )), + } +} + +pub(crate) fn decode_percents(value: &str) -> Result> { + let mut iter = value.chars(); + let mut decoded = Vec::new(); + + while let Some(c) = iter.next() { + if matches!(c, '-' | '0'..='9' | 'A'..='Z' | 'a'..='z' | '_' | '/' | '.' | '\\' | '*') { + decoded.push(c as u8) + } else if c == '%' { + decoded.push( + decode_hex(iter.next().ok_or_else(|| { + Error::Address("incomplete percent-encoded sequence".to_owned()) + })?)? + << 4 + | decode_hex(iter.next().ok_or_else(|| { + Error::Address("incomplete percent-encoded sequence".to_owned()) + })?)?, + ); + } else { + return Err(Error::Address("Invalid character in address".to_owned())); + } + } + + Ok(decoded) +} + +pub(super) fn encode_percents(f: &mut Formatter<'_>, mut value: &[u8]) -> std::fmt::Result { + const LOOKUP: &str = "\ +%00%01%02%03%04%05%06%07%08%09%0a%0b%0c%0d%0e%0f\ +%10%11%12%13%14%15%16%17%18%19%1a%1b%1c%1d%1e%1f\ +%20%21%22%23%24%25%26%27%28%29%2a%2b%2c%2d%2e%2f\ +%30%31%32%33%34%35%36%37%38%39%3a%3b%3c%3d%3e%3f\ +%40%41%42%43%44%45%46%47%48%49%4a%4b%4c%4d%4e%4f\ +%50%51%52%53%54%55%56%57%58%59%5a%5b%5c%5d%5e%5f\ +%60%61%62%63%64%65%66%67%68%69%6a%6b%6c%6d%6e%6f\ +%70%71%72%73%74%75%76%77%78%79%7a%7b%7c%7d%7e%7f\ +%80%81%82%83%84%85%86%87%88%89%8a%8b%8c%8d%8e%8f\ +%90%91%92%93%94%95%96%97%98%99%9a%9b%9c%9d%9e%9f\ +%a0%a1%a2%a3%a4%a5%a6%a7%a8%a9%aa%ab%ac%ad%ae%af\ +%b0%b1%b2%b3%b4%b5%b6%b7%b8%b9%ba%bb%bc%bd%be%bf\ +%c0%c1%c2%c3%c4%c5%c6%c7%c8%c9%ca%cb%cc%cd%ce%cf\ +%d0%d1%d2%d3%d4%d5%d6%d7%d8%d9%da%db%dc%dd%de%df\ +%e0%e1%e2%e3%e4%e5%e6%e7%e8%e9%ea%eb%ec%ed%ee%ef\ +%f0%f1%f2%f3%f4%f5%f6%f7%f8%f9%fa%fb%fc%fd%fe%ff"; + + loop { + let pos = value.iter().position( + |c| !matches!(c, b'-' | b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z' | b'_' | b'/' | b'.' | b'\\' | b'*'), + ); + + if let Some(pos) = pos { + // SAFETY: The above `position()` call made sure that only ASCII chars are in the string + // up to `pos` + f.write_str(unsafe { from_utf8_unchecked(&value[..pos]) })?; + + let c = value[pos]; + value = &value[pos + 1..]; + + let pos = c as usize * 3; + f.write_str(&LOOKUP[pos..pos + 3])?; + } else { + // SAFETY: The above `position()` call made sure that only ASCII chars are in the rest + // of the string + f.write_str(unsafe { from_utf8_unchecked(value) })?; + return Ok(()); + } + } +} + +impl Display for Transport { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + fn fmt_unix_path(f: &mut Formatter<'_>, path: &OsStr) -> std::fmt::Result { + #[cfg(unix)] + { + use std::os::unix::ffi::OsStrExt; + + encode_percents(f, path.as_bytes())?; + } + + #[cfg(windows)] + write!(f, "{}", path.to_str().ok_or(std::fmt::Error)?)?; + + Ok(()) + } + + match self { + Self::Tcp(addr) => { + match addr.nonce_file() { + Some(nonce_file) => { + f.write_str("nonce-tcp:noncefile=")?; + encode_percents(f, nonce_file)?; + f.write_str(",")?; + } + None => f.write_str("tcp:")?, + } + f.write_str("host=")?; + + encode_percents(f, addr.host().as_bytes())?; + + write!(f, ",port={}", addr.port())?; + + if let Some(bind) = addr.bind() { + f.write_str(",bind=")?; + encode_percents(f, bind.as_bytes())?; + } + + if let Some(family) = addr.family() { + write!(f, ",family={family}")?; + } + } + + Self::Unix(unix) => match unix.path() { + UnixPath::File(path) => { + f.write_str("unix:path=")?; + fmt_unix_path(f, path)?; + } + #[cfg(target_os = "linux")] + UnixPath::Abstract(name) => { + f.write_str("unix:abstract=")?; + encode_percents(f, name)?; + } + UnixPath::Dir(path) => { + f.write_str("unix:dir=")?; + fmt_unix_path(f, path)?; + } + UnixPath::TmpDir(path) => { + f.write_str("unix:tmpdir=")?; + fmt_unix_path(f, path)?; + } + }, + + #[cfg(any( + all(feature = "vsock", not(feature = "tokio")), + feature = "tokio-vsock" + ))] + Self::Vsock(addr) => { + write!(f, "vsock:cid={},port={}", addr.cid(), addr.port())?; + } + + #[cfg(windows)] + Self::Autolaunch(autolaunch) => { + write!(f, "autolaunch:")?; + if let Some(scope) = autolaunch.scope() { + write!(f, "scope={scope}")?; + } + } + + #[cfg(target_os = "macos")] + Self::Launchd(launchd) => { + write!(f, "launchd:env={}", launchd.env())?; + } + } + + Ok(()) + } +} diff --git a/zbus/src/address/transport/tcp.rs b/zbus/src/address/transport/tcp.rs new file mode 100644 index 000000000..18610700a --- /dev/null +++ b/zbus/src/address/transport/tcp.rs @@ -0,0 +1,199 @@ +use crate::{Error, Result}; +#[cfg(not(feature = "tokio"))] +use async_io::Async; +#[cfg(not(feature = "tokio"))] +use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; +use std::{ + collections::HashMap, + fmt::{Display, Formatter}, + str::FromStr, +}; +#[cfg(feature = "tokio")] +use tokio::net::TcpStream; + +/// A TCP transport in a D-Bus address. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Tcp { + pub(super) host: String, + pub(super) bind: Option, + pub(super) port: u16, + pub(super) family: Option, + pub(super) nonce_file: Option>, +} + +impl Tcp { + /// Create a new TCP transport with the given host and port. + pub fn new(host: &str, port: u16) -> Self { + Self { + host: host.to_owned(), + port, + bind: None, + family: None, + nonce_file: None, + } + } + + /// Set the `tcp:` address `bind` value. + pub fn set_bind(mut self, bind: Option) -> Self { + self.bind = bind; + + self + } + + /// Set the `tcp:` address `family` value. + pub fn set_family(mut self, family: Option) -> Self { + self.family = family; + + self + } + + /// Set the `tcp:` address `noncefile` value. + pub fn set_nonce_file(mut self, nonce_file: Option>) -> Self { + self.nonce_file = nonce_file; + + self + } + + /// Returns the `tcp:` address `host` value. + pub fn host(&self) -> &str { + &self.host + } + + /// Returns the `tcp:` address `bind` value. + pub fn bind(&self) -> Option<&str> { + self.bind.as_deref() + } + + /// Returns the `tcp:` address `port` value. + pub fn port(&self) -> u16 { + self.port + } + + /// Returns the `tcp:` address `family` value. + pub fn family(&self) -> Option { + self.family + } + + /// The nonce file path, if any. + pub fn nonce_file(&self) -> Option<&[u8]> { + self.nonce_file.as_deref() + } + + /// Take ownership of the nonce file path, if any. + pub fn take_nonce_file(&mut self) -> Option> { + self.nonce_file.take() + } + + pub(super) fn from_options( + opts: HashMap<&str, &str>, + nonce_tcp_required: bool, + ) -> Result { + let bind = None; + if opts.contains_key("bind") { + return Err(Error::Address("`bind` isn't yet supported".into())); + } + + let host = opts + .get("host") + .ok_or_else(|| Error::Address("tcp address is missing `host`".into()))? + .to_string(); + let port = opts + .get("port") + .ok_or_else(|| Error::Address("tcp address is missing `port`".into()))?; + let port = port + .parse::() + .map_err(|_| Error::Address("invalid tcp `port`".into()))?; + let family = opts + .get("family") + .map(|f| TcpTransportFamily::from_str(f)) + .transpose()?; + let nonce_file = opts + .get("noncefile") + .map(|f| super::decode_percents(f)) + .transpose()?; + if nonce_tcp_required && nonce_file.is_none() { + return Err(Error::Address( + "nonce-tcp address is missing `noncefile`".into(), + )); + } + + Ok(Self { + host, + bind, + port, + family, + nonce_file, + }) + } + + #[cfg(not(feature = "tokio"))] + pub(super) async fn connect(self) -> Result> { + let addrs = crate::Task::spawn_blocking( + move || -> Result> { + let addrs = (self.host(), self.port()).to_socket_addrs()?.filter(|a| { + if let Some(family) = self.family() { + if family == TcpTransportFamily::Ipv4 { + a.is_ipv4() + } else { + a.is_ipv6() + } + } else { + true + } + }); + Ok(addrs.collect()) + }, + "connect tcp", + ) + .await + .map_err(|e| Error::Address(format!("Failed to receive TCP addresses: {e}")))?; + + // we could attempt connections in parallel? + let mut last_err = Error::Address("Failed to connect".into()); + for addr in addrs { + match Async::::connect(addr).await { + Ok(stream) => return Ok(stream), + Err(e) => last_err = e.into(), + } + } + + Err(last_err) + } + + #[cfg(feature = "tokio")] + pub(super) async fn connect(self) -> Result { + TcpStream::connect((self.host(), self.port())) + .await + .map_err(|e| Error::InputOutput(e.into())) + } +} + +/// A `tcp:` address family. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum TcpTransportFamily { + Ipv4, + Ipv6, +} + +impl FromStr for TcpTransportFamily { + type Err = Error; + + fn from_str(family: &str) -> Result { + match family { + "ipv4" => Ok(Self::Ipv4), + "ipv6" => Ok(Self::Ipv6), + _ => Err(Error::Address(format!( + "invalid tcp address `family`: {family}" + ))), + } + } +} + +impl Display for TcpTransportFamily { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::Ipv4 => write!(f, "ipv4"), + Self::Ipv6 => write!(f, "ipv6"), + } + } +} diff --git a/zbus/src/address/transport/unix.rs b/zbus/src/address/transport/unix.rs new file mode 100644 index 000000000..513b7c0ec --- /dev/null +++ b/zbus/src/address/transport/unix.rs @@ -0,0 +1,75 @@ +use std::{collections::HashMap, ffi::OsString}; + +/// A Unix domain socket transport in a D-Bus address. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Unix { + path: UnixPath, +} + +impl Unix { + /// Create a new Unix transport with the given path. + pub fn new(path: UnixPath) -> Self { + Self { path } + } + + /// The path. + pub fn path(&self) -> &UnixPath { + &self.path + } + + /// Take the path, consuming `self`. + pub fn take_path(self) -> UnixPath { + self.path + } + + #[cfg(any(unix, not(feature = "tokio")))] + pub(super) fn from_options(opts: HashMap<&str, &str>) -> crate::Result { + let path = opts.get("path"); + let abs = opts.get("abstract"); + let dir = opts.get("dir"); + let tmpdir = opts.get("tmpdir"); + let path = match (path, abs, dir, tmpdir) { + (Some(p), None, None, None) => UnixPath::File(OsString::from(p)), + #[cfg(target_os = "linux")] + (None, Some(p), None, None) => UnixPath::Abstract(p.as_bytes().to_owned()), + #[cfg(not(target_os = "linux"))] + (None, Some(_), None, None) => { + return Err(crate::Error::Address( + "abstract sockets currently Linux-only".to_owned(), + )); + } + (None, None, Some(p), None) => UnixPath::Dir(OsString::from(p)), + (None, None, None, Some(p)) => UnixPath::TmpDir(OsString::from(p)), + _ => { + return Err(crate::Error::Address("unix: address is invalid".to_owned())); + } + }; + + Ok(Self::new(path)) + } +} + +/// A Unix domain socket path in a D-Bus address. +#[derive(Clone, Debug, PartialEq, Eq)] +#[non_exhaustive] +pub enum UnixPath { + /// A path to a unix domain socket on the filesystem. + File(OsString), + /// A abstract unix domain socket name. + #[cfg(target_os = "linux")] + Abstract(Vec), + /// A listenable address using the specified path, in which a socket file with a random file + /// name starting with 'dbus-' will be created by the server. See [UNIX domain socket address] + /// reference documentation. + /// + /// This address is mostly relevant to server (typically bus broker) implementations. + /// + /// [UNIX domain socket address]: https://dbus.freedesktop.org/doc/dbus-specification.html#transports-unix-domain-sockets-addresses + Dir(OsString), + /// The same as UnixDir, except that on platforms with abstract sockets, the server may attempt + /// to create an abstract socket whose name starts with this directory instead of a path-based + /// socket. + /// + /// This address is mostly relevant to server (typically bus broker) implementations. + TmpDir(OsString), +} diff --git a/zbus/src/address/transport/vsock.rs b/zbus/src/address/transport/vsock.rs new file mode 100644 index 000000000..2d920e6bc --- /dev/null +++ b/zbus/src/address/transport/vsock.rs @@ -0,0 +1,43 @@ +use crate::{Error, Result}; +use std::collections::HashMap; + +/// A `tcp:` D-Bus address. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Vsock { + pub(super) cid: u32, + pub(super) port: u32, +} + +impl Vsock { + /// Create a new VSOCK address. + pub fn new(cid: u32, port: u32) -> Self { + Self { cid, port } + } + + /// The Client ID. + pub fn cid(&self) -> u32 { + self.cid + } + + /// The port. + pub fn port(&self) -> u32 { + self.port + } + + pub(super) fn from_options(opts: HashMap<&str, &str>) -> Result { + let cid = opts + .get("cid") + .ok_or_else(|| Error::Address("VSOCK address is missing cid=".into()))?; + let cid = cid + .parse::() + .map_err(|e| Error::Address(format!("Failed to parse VSOCK cid `{}`: {}", cid, e)))?; + let port = opts + .get("port") + .ok_or_else(|| Error::Address("VSOCK address is missing port=".into()))?; + let port = port + .parse::() + .map_err(|e| Error::Address(format!("Failed to parse VSOCK port `{}`: {}", port, e)))?; + + Ok(Self { cid, port }) + } +} diff --git a/zbus/src/blocking/connection/builder.rs b/zbus/src/blocking/connection/builder.rs index 62c9e052c..3e2ce9d6e 100644 --- a/zbus/src/blocking/connection/builder.rs +++ b/zbus/src/blocking/connection/builder.rs @@ -109,8 +109,12 @@ impl<'a> Builder<'a> { /// /// The to-be-created connection will wait for incoming client authentication handshake and /// negotiation messages, for peer-to-peer communications after successful creation. - pub fn server(self, guid: &'a Guid) -> Self { - Self(self.0.server(guid)) + pub fn server(self, guid: G) -> Result + where + G: TryInto>, + G::Error: Into, + { + self.0.server(guid).map(Self) } /// Set the capacity of the main (unfiltered) queue. diff --git a/zbus/src/blocking/connection/mod.rs b/zbus/src/blocking/connection/mod.rs index 94051a8c5..334807056 100644 --- a/zbus/src/blocking/connection/mod.rs +++ b/zbus/src/blocking/connection/mod.rs @@ -339,7 +339,8 @@ mod tests { let (tx, rx) = std::sync::mpsc::channel(); let server_thread = thread::spawn(move || { let c = Builder::unix_stream(p0) - .server(&guid) + .server(guid) + .unwrap() .p2p() .build() .unwrap(); diff --git a/zbus/src/blocking/fdo.rs b/zbus/src/blocking/fdo.rs index 75d427b33..3e08ca403 100644 --- a/zbus/src/blocking/fdo.rs +++ b/zbus/src/blocking/fdo.rs @@ -17,7 +17,7 @@ use crate::{ ConnectionCredentials, ManagedObjects, ReleaseNameReply, RequestNameFlags, RequestNameReply, Result, }, - Guid, + OwnedGuid, }; gen_introspectable_proxy!(false, true); diff --git a/zbus/src/connection/builder.rs b/zbus/src/connection/builder.rs index 70528a99b..ca3e8db51 100644 --- a/zbus/src/connection/builder.rs +++ b/zbus/src/connection/builder.rs @@ -61,7 +61,8 @@ type Interfaces<'a> = pub struct Builder<'a> { target: Option, max_queued: Option, - guid: Option<&'a Guid>, + // This is only set for server case. + guid: Option>, p2p: bool, internal_executor: bool, #[derivative(Debug = "ignore")] @@ -212,10 +213,14 @@ impl<'a> Builder<'a> { /// /// The to-be-created connection will wait for incoming client authentication handshake and /// negotiation messages, for peer-to-peer communications after successful creation. - pub fn server(mut self, guid: &'a Guid) -> Self { - self.guid = Some(guid); + pub fn server(mut self, guid: G) -> Result + where + G: TryInto>, + G::Error: Into, + { + self.guid = Some(guid.try_into().map_err(Into::into)?); - self + Ok(self) } /// Set the capacity of the main (unfiltered) queue. @@ -341,8 +346,14 @@ impl<'a> Builder<'a> { let mut stream = self.stream_for_target().await?; let mut auth = match self.guid { None => { + let guid = match self.target { + Some(Target::Address(ref addr)) => { + addr.guid().map(|guid| guid.to_owned().into()) + } + _ => None, + }; // SASL Handshake - Authenticated::client(stream, self.auth_mechanisms).await? + Authenticated::client(stream, guid, self.auth_mechanisms).await? } Some(guid) => { if !self.p2p { @@ -357,7 +368,7 @@ impl<'a> Builder<'a> { Authenticated::server( stream, - guid.clone(), + guid.to_owned().into(), #[cfg(unix)] client_uid, #[cfg(windows)] @@ -465,13 +476,13 @@ impl<'a> Builder<'a> { Target::VsockStream(stream) => Split::new_boxed(stream), Target::Address(address) => match address.connect().await? { #[cfg(any(unix, not(feature = "tokio")))] - address::Stream::Unix(stream) => Split::new_boxed(stream), - address::Stream::Tcp(stream) => Split::new_boxed(stream), + address::transport::Stream::Unix(stream) => Split::new_boxed(stream), + address::transport::Stream::Tcp(stream) => Split::new_boxed(stream), #[cfg(any( all(feature = "vsock", not(feature = "tokio")), feature = "tokio-vsock" ))] - address::Stream::Vsock(stream) => Split::new_boxed(stream), + address::transport::Stream::Vsock(stream) => Split::new_boxed(stream), }, Target::Socket(stream) => stream, }) diff --git a/zbus/src/connection/handshake.rs b/zbus/src/connection/handshake.rs index 6daf42f18..b3b33a3fa 100644 --- a/zbus/src/connection/handshake.rs +++ b/zbus/src/connection/handshake.rs @@ -17,7 +17,7 @@ use xdg_home::home_dir; #[cfg(windows)] use crate::win32; -use crate::{file::FileLines, guid::Guid, Error, Result}; +use crate::{file::FileLines, guid::Guid, Error, OwnedGuid, Result}; use super::socket::{BoxedSplit, ReadHalf, WriteHalf}; @@ -52,7 +52,7 @@ pub enum AuthMechanism { pub struct Authenticated { pub(crate) socket_write: Box, /// The server Guid - pub(crate) server_guid: Guid, + pub(crate) server_guid: OwnedGuid, /// Whether file descriptor passing has been accepted by both sides #[cfg(unix)] pub(crate) cap_unix_fd: bool, @@ -65,9 +65,12 @@ impl Authenticated { /// Create a client-side `Authenticated` for the given `socket`. pub async fn client( socket: BoxedSplit, + server_guid: Option, mechanisms: Option>, ) -> Result { - ClientHandshake::new(socket, mechanisms).perform().await + ClientHandshake::new(socket, mechanisms, server_guid) + .perform() + .await } /// Create a server-side `Authenticated` for the given `socket`. @@ -75,7 +78,7 @@ impl Authenticated { /// The function takes `client_uid` on Unix only. On Windows, it takes `client_sid` instead. pub async fn server( socket: BoxedSplit, - guid: Guid, + guid: OwnedGuid, #[cfg(unix)] client_uid: Option, #[cfg(windows)] client_sid: Option, auth_mechanisms: Option>, @@ -127,7 +130,7 @@ enum Command { Error(String), NegotiateUnixFD, Rejected(Vec), - Ok(Guid), + Ok(OwnedGuid), AgreeUnixFD, } @@ -147,6 +150,7 @@ enum Command { #[derive(Debug)] pub struct ClientHandshake { common: HandshakeCommon, + server_guid: Option, step: ClientHandshakeStep, } @@ -161,7 +165,11 @@ pub trait Handshake { impl ClientHandshake { /// Start a handshake on this client socket - pub fn new(socket: BoxedSplit, mechanisms: Option>) -> ClientHandshake { + pub fn new( + socket: BoxedSplit, + mechanisms: Option>, + server_guid: Option, + ) -> ClientHandshake { let mechanisms = mechanisms.unwrap_or_else(|| { let mut mechanisms = VecDeque::new(); mechanisms.push_back(AuthMechanism::External); @@ -171,8 +179,9 @@ impl ClientHandshake { }); ClientHandshake { - common: HandshakeCommon::new(socket, mechanisms, None), + common: HandshakeCommon::new(socket, mechanisms), step: ClientHandshakeStep::Init, + server_guid, } } @@ -453,7 +462,15 @@ impl Handshake for ClientHandshake { } (WaitingForOK, Command::Ok(guid)) => { trace!("Received OK from server"); - self.common.server_guid = Some(guid); + match self.server_guid { + Some(server_guid) if server_guid != guid => { + return Err(Error::Handshake(format!( + "Server GUID mismatch: expected {server_guid}, got {guid}", + ))); + } + Some(_) => (), + None => self.server_guid = Some(guid), + } if self.common.socket.read_mut().can_pass_unix_fd() { (WaitingForAgreeUnixFD, Command::NegotiateUnixFD) } else { @@ -493,7 +510,7 @@ impl Handshake for ClientHandshake { return Ok(Authenticated { socket_write: write, socket_read: Some(read), - server_guid: self.common.server_guid.unwrap(), + server_guid: self.server_guid.unwrap(), #[cfg(unix)] cap_unix_fd: self.common.cap_unix_fd, already_received_bytes: Some(self.common.recv_buffer), @@ -539,6 +556,7 @@ enum ServerHandshakeStep { pub struct ServerHandshake<'s> { common: HandshakeCommon, step: ServerHandshakeStep, + guid: OwnedGuid, #[cfg(unix)] client_uid: Option, #[cfg(windows)] @@ -550,7 +568,7 @@ pub struct ServerHandshake<'s> { impl<'s> ServerHandshake<'s> { pub fn new( socket: BoxedSplit, - guid: Guid, + guid: OwnedGuid, #[cfg(unix)] client_uid: Option, #[cfg(windows)] client_sid: Option, mechanisms: Option>, @@ -568,7 +586,7 @@ impl<'s> ServerHandshake<'s> { }; Ok(ServerHandshake { - common: HandshakeCommon::new(socket, mechanisms, Some(guid)), + common: HandshakeCommon::new(socket, mechanisms), step: ServerHandshakeStep::WaitingForNull, #[cfg(unix)] client_uid, @@ -576,11 +594,13 @@ impl<'s> ServerHandshake<'s> { client_sid, cookie_id, cookie_context, + guid, }) } async fn auth_ok(&mut self) -> Result<()> { - let cmd = Command::Ok(self.guid().clone()); + let guid = self.guid.clone(); + let cmd = Command::Ok(guid); trace!("Sending authentication OK"); self.common.write_command(cmd).await?; self.step = ServerHandshakeStep::WaitingForBegin; @@ -678,14 +698,6 @@ impl<'s> ServerHandshake<'s> { Ok(()) } - - fn guid(&self) -> &Guid { - // SAFETY: We know that the server GUID is set because we set it in the constructor. - self.common - .server_guid - .as_ref() - .expect("Server GUID not set") - } } #[async_trait] @@ -796,9 +808,7 @@ impl Handshake for ServerHandshake<'_> { return Ok(Authenticated { socket_write: write, socket_read: Some(read), - // SAFETY: We know that the server GUID is set because we set it in the - // constructor. - server_guid: self.common.server_guid.expect("Server GUID not set"), + server_guid: self.guid, #[cfg(unix)] cap_unix_fd: self.common.cap_unix_fd, already_received_bytes: Some(self.common.recv_buffer), @@ -917,7 +927,7 @@ impl FromStr for Command { let guid = words .next() .ok_or_else(|| Error::Handshake("Missing OK server GUID!".into()))?; - Command::Ok(guid.parse()?) + Command::Ok(Guid::from_str(guid)?.into()) } Some("AGREE_UNIX_FD") => Command::AgreeUnixFD, _ => return Err(Error::Handshake(format!("Unknown command: {s}"))), @@ -931,7 +941,6 @@ impl FromStr for Command { pub struct HandshakeCommon { socket: BoxedSplit, recv_buffer: Vec, - server_guid: Option, cap_unix_fd: bool, // the current AUTH mechanism is front, ordered by priority mechanisms: VecDeque, @@ -939,15 +948,10 @@ pub struct HandshakeCommon { impl HandshakeCommon { /// Start a handshake on this client socket - pub fn new( - socket: BoxedSplit, - mechanisms: VecDeque, - server_guid: Option, - ) -> Self { + pub fn new(socket: BoxedSplit, mechanisms: VecDeque) -> Self { Self { socket, recv_buffer: Vec::new(), - server_guid, cap_unix_fd: false, mechanisms, } @@ -1066,10 +1070,11 @@ mod tests { fn handshake() { let (p0, p1) = create_async_socket_pair(); - let client = ClientHandshake::new(Split::new_boxed(p0), None); + let guid = OwnedGuid::from(Guid::generate()); + let client = ClientHandshake::new(Split::new_boxed(p0), None, Some(guid.clone())); let server = ServerHandshake::new( Split::new_boxed(p1), - Guid::generate(), + guid, Some(Uid::effective().into()), None, None, @@ -1093,7 +1098,7 @@ mod tests { let (mut p0, p1) = create_async_socket_pair(); let server = ServerHandshake::new( Split::new_boxed(p1), - Guid::generate(), + Guid::generate().into(), Some(Uid::effective().into()), None, None, @@ -1122,7 +1127,7 @@ mod tests { let (mut p0, p1) = create_async_socket_pair(); let server = ServerHandshake::new( Split::new_boxed(p1), - Guid::generate(), + Guid::generate().into(), Some(Uid::effective().into()), None, None, @@ -1149,7 +1154,7 @@ mod tests { let (mut p0, p1) = create_async_socket_pair(); let server = ServerHandshake::new( Split::new_boxed(p1), - Guid::generate(), + Guid::generate().into(), Some(Uid::effective().into()), None, None, @@ -1167,7 +1172,7 @@ mod tests { let (mut p0, p1) = create_async_socket_pair(); let server = ServerHandshake::new( Split::new_boxed(p1), - Guid::generate(), + Guid::generate().into(), Some(Uid::effective().into()), Some(vec![AuthMechanism::Anonymous].into()), None, @@ -1185,7 +1190,7 @@ mod tests { let (mut p0, p1) = create_async_socket_pair(); let server = ServerHandshake::new( Split::new_boxed(p1), - Guid::generate(), + Guid::generate().into(), Some(Uid::effective().into()), Some(vec![AuthMechanism::Anonymous].into()), None, diff --git a/zbus/src/connection/mod.rs b/zbus/src/connection/mod.rs index d807cb7df..5e68132b6 100644 --- a/zbus/src/connection/mod.rs +++ b/zbus/src/connection/mod.rs @@ -29,7 +29,7 @@ use crate::{ fdo::{self, ConnectionCredentials, RequestNameFlags, RequestNameReply}, message::{Flags, Message, Type}, proxy::CacheProperties, - DBusError, Error, Executor, Guid, MatchRule, MessageStream, ObjectServer, OwnedMatchRule, + DBusError, Error, Executor, MatchRule, MessageStream, ObjectServer, OwnedGuid, OwnedMatchRule, Result, Task, }; @@ -51,7 +51,7 @@ const DEFAULT_MAX_METHOD_RETURN_QUEUED: usize = 8; /// Inner state shared by Connection and WeakConnection #[derive(Debug)] pub(crate) struct ConnectionInner { - server_guid: Guid, + server_guid: OwnedGuid, #[cfg(unix)] cap_unix_fd: bool, bus_conn: bool, @@ -824,8 +824,8 @@ impl Connection { } /// The server's GUID. - pub fn server_guid(&self) -> &str { - self.inner.server_guid.as_str() + pub fn server_guid(&self) -> &OwnedGuid { + &self.inner.server_guid } /// The underlying executor. @@ -1330,7 +1330,7 @@ mod tests { use test_log::test; use zvariant::{Endian, NATIVE_ENDIAN}; - use crate::{fdo::DBusProxy, AuthMechanism}; + use crate::{fdo::DBusProxy, AuthMechanism, Guid}; use super::*; @@ -1447,7 +1447,8 @@ mod tests { ( Builder::tcp_stream(p0) - .server(&guid) + .server(guid) + .unwrap() .p2p() .auth_mechanisms(&[AuthMechanism::Anonymous]), Builder::tcp_stream(p1).p2p(), @@ -1463,7 +1464,8 @@ mod tests { ( Builder::tcp_stream(p0) - .server(&guid) + .server(guid) + .unwrap() .p2p() .auth_mechanisms(&[AuthMechanism::Anonymous]), Builder::tcp_stream(p1).p2p(), @@ -1503,7 +1505,7 @@ mod tests { futures_util::try_join!( Builder::unix_stream(p1).p2p().build(), - Builder::unix_stream(p0).server(&guid).p2p().build(), + Builder::unix_stream(p0).server(guid).unwrap().p2p().build(), ) } @@ -1541,7 +1543,8 @@ mod tests { futures_util::try_join!( Builder::vsock_stream(server) - .server(&guid) + .server(guid) + .unwrap() .p2p() .auth_mechanisms(&[AuthMechanism::Anonymous]) .build(), @@ -1559,7 +1562,8 @@ mod tests { futures_util::try_join!( Builder::vsock_stream(server) - .server(&guid) + .server(guid) + .unwrap() .p2p() .auth_mechanisms(&[AuthMechanism::Anonymous]) .build(), @@ -1579,10 +1583,11 @@ mod tests { #[cfg(target_os = "macos")] #[test] fn connect_launchd_session_bus() { + use crate::address::{transport::Launchd, Address, Transport}; crate::block_on(async { - let addr = crate::address::macos_launchd_bus_address("DBUS_LAUNCHD_SESSION_BUS_SOCKET") - .await - .expect("Unable to get Launchd session bus address"); + let addr = Address::from(Transport::Launchd(Launchd::new( + "DBUS_LAUNCHD_SESSION_BUS_SOCKET", + ))); addr.connect().await }) .expect("Unable to connect to session bus"); @@ -1692,7 +1697,8 @@ mod tests { let (p0, p1) = UnixStream::pair().unwrap(); let mut server_builder = Builder::unix_stream(p0) - .server(&guid) + .server(guid) + .unwrap() .p2p() .auth_mechanisms(&[AuthMechanism::Cookie]) .cookie_context(cookie_context) diff --git a/zbus/src/fdo.rs b/zbus/src/fdo.rs index 20e99c024..556352cad 100644 --- a/zbus/src/fdo.rs +++ b/zbus/src/fdo.rs @@ -17,8 +17,8 @@ use zvariant::{ }; use crate::{ - dbus_interface, dbus_proxy, message::Header, object_server::SignalContext, DBusError, Guid, - ObjectServer, + dbus_interface, dbus_proxy, message::Header, object_server::SignalContext, DBusError, + ObjectServer, OwnedGuid, }; #[rustfmt::skip] @@ -731,7 +731,7 @@ macro_rules! gen_dbus_proxy { fn get_connection_unix_user(&self, bus_name: BusName<'_>) -> Result; /// Gets the unique ID of the bus. - fn get_id(&self) -> Result; + fn get_id(&self) -> Result; /// Returns the unique connection name of the primary owner of the name given. fn get_name_owner(&self, name: BusName<'_>) -> Result; diff --git a/zbus/src/guid.rs b/zbus/src/guid.rs index 5adb8d310..0ff5343c0 100644 --- a/zbus/src/guid.rs +++ b/zbus/src/guid.rs @@ -1,15 +1,15 @@ use std::{ - borrow::{Borrow, BorrowMut}, - fmt, + borrow::{Borrow, Cow}, + fmt::{self, Debug, Display, Formatter}, iter::repeat_with, - ops::{Deref, DerefMut}, + ops::Deref, str::FromStr, time::{SystemTime, UNIX_EPOCH}, }; -use serde::{Deserialize, Serialize}; +use serde::{de, Deserialize, Serialize}; use static_assertions::assert_impl_all; -use zvariant::Type; +use zvariant::{Str, Type}; /// A D-Bus server GUID. /// @@ -20,14 +20,14 @@ use zvariant::Type; /// [UUIDs chapter]: https://dbus.freedesktop.org/doc/dbus-specification.html#uuids /// [TryFrom]: #impl-TryFrom%3C%26%27_%20str%3E #[derive(Clone, Debug, PartialEq, Eq, Hash, Type, Serialize)] -pub struct Guid(String); +pub struct Guid<'g>(Str<'g>); -assert_impl_all!(Guid: Send, Sync, Unpin); +assert_impl_all!(Guid<'_>: Send, Sync, Unpin); -impl Guid { +impl Guid<'_> { /// Generate a D-Bus GUID that can be used with e.g. /// [`connection::Builder::server`](crate::connection::Builder::server). - pub fn generate() -> Self { + pub fn generate() -> Guid<'static> { let r: Vec = repeat_with(rand::random::).take(3).collect(); let r3 = match SystemTime::now().duration_since(UNIX_EPOCH) { Ok(n) => n.as_secs() as u32, @@ -35,22 +35,34 @@ impl Guid { }; let s = format!("{:08x}{:08x}{:08x}{:08x}", r[0], r[1], r[2], r3); - Self(s) + Guid(s.into()) } /// Returns a string slice for the GUID. pub fn as_str(&self) -> &str { self.0.as_str() } + + /// Same as `try_from`, except it takes a `&'static str`. + pub fn from_static_str(guid: &'static str) -> crate::Result { + validate_guid(guid)?; + + Ok(Self(Str::from_static(guid))) + } + + /// Create an owned copy of the GUID. + pub fn to_owned(&self) -> Guid<'static> { + Guid(self.0.to_owned()) + } } -impl fmt::Display for Guid { +impl fmt::Display for Guid<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{}", self.as_str()) } } -impl TryFrom<&str> for Guid { +impl<'g> TryFrom<&'g str> for Guid<'g> { type Error = crate::Error; /// Creates a GUID from a string with 32 hex digits. @@ -58,56 +70,81 @@ impl TryFrom<&str> for Guid { /// Returns `Err(`[`Error::InvalidGUID`]`)` if the provided string is not a well-formed GUID. /// /// [`Error::InvalidGUID`]: enum.Error.html#variant.InvalidGUID - fn try_from(value: &str) -> std::result::Result { - if !valid_guid(value) { - Err(crate::Error::InvalidGUID) - } else { - Ok(Guid(value.to_string())) - } + fn try_from(value: &'g str) -> std::result::Result { + validate_guid(value)?; + + Ok(Self(Str::from(value))) + } +} + +impl<'g> TryFrom> for Guid<'g> { + type Error = crate::Error; + + /// Creates a GUID from a string with 32 hex digits. + /// + /// Returns `Err(`[`Error::InvalidGUID`]`)` if the provided string is not a well-formed GUID. + /// + /// [`Error::InvalidGUID`]: enum.Error.html#variant.InvalidGUID + fn try_from(value: Str<'g>) -> std::result::Result { + validate_guid(&value)?; + + Ok(Guid(value)) } } -impl TryFrom for Guid { +impl TryFrom for Guid<'static> { type Error = crate::Error; fn try_from(value: String) -> std::result::Result { - if !valid_guid(&value) { - Err(crate::Error::InvalidGUID) - } else { - Ok(Guid(value)) - } + validate_guid(&value)?; + + Ok(Guid(value.into())) + } +} + +impl<'g> TryFrom> for Guid<'g> { + type Error = crate::Error; + + fn try_from(value: Cow<'g, str>) -> std::result::Result { + validate_guid(&value)?; + + Ok(Guid(value.into())) } } -impl FromStr for Guid { +impl FromStr for Guid<'static> { type Err = crate::Error; fn from_str(s: &str) -> Result { - s.try_into() + s.try_into().map(|guid: Guid<'_>| guid.to_owned()) } } -impl<'de> Deserialize<'de> for Guid { +impl<'de> Deserialize<'de> for Guid<'de> { fn deserialize(deserializer: D) -> std::result::Result where D: serde::Deserializer<'de>, { - String::deserialize(deserializer) + >::deserialize(deserializer) .and_then(|s| s.try_into().map_err(serde::de::Error::custom)) } } -fn valid_guid(value: &str) -> bool { - value.as_bytes().len() == 32 && value.chars().all(|c| char::is_ascii_hexdigit(&c)) +fn validate_guid(value: &str) -> crate::Result<()> { + if value.as_bytes().len() != 32 || value.chars().any(|c| !char::is_ascii_hexdigit(&c)) { + return Err(crate::Error::InvalidGUID); + } + + Ok(()) } -impl From for String { - fn from(guid: Guid) -> Self { - guid.0 +impl From> for String { + fn from(guid: Guid<'_>) -> Self { + guid.0.into() } } -impl Deref for Guid { +impl Deref for Guid<'_> { type Target = str; fn deref(&self) -> &Self::Target { @@ -115,33 +152,95 @@ impl Deref for Guid { } } -impl DerefMut for Guid { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 +impl AsRef for Guid<'_> { + fn as_ref(&self) -> &str { + self.as_str() } } -impl AsRef for Guid { - fn as_ref(&self) -> &str { +impl Borrow for Guid<'_> { + fn borrow(&self) -> &str { self.as_str() } } -impl AsMut for Guid { - fn as_mut(&mut self) -> &mut str { - &mut self.0 +/// Owned version of [`Guid`]. +#[derive(Clone, Debug, PartialEq, Eq, Hash, Type, Serialize)] +pub struct OwnedGuid(#[serde(borrow)] Guid<'static>); + +assert_impl_all!(OwnedGuid: Send, Sync, Unpin); + +impl OwnedGuid { + /// Get a reference to the inner [`Guid`]. + pub fn inner(&self) -> &Guid<'static> { + &self.0 + } +} + +impl Deref for OwnedGuid { + type Target = Guid<'static>; + + fn deref(&self) -> &Self::Target { + &self.0 } } -impl Borrow for Guid { +impl Borrow for OwnedGuid { fn borrow(&self) -> &str { - self.as_str() + self.0.as_str() + } +} + +impl From for Guid<'static> { + fn from(o: OwnedGuid) -> Self { + o.0 + } +} + +impl<'unowned, 'owned: 'unowned> From<&'owned OwnedGuid> for Guid<'unowned> { + fn from(guid: &'owned OwnedGuid) -> Self { + guid.0.clone() + } +} + +impl From> for OwnedGuid { + fn from(guid: Guid<'_>) -> Self { + OwnedGuid(guid.to_owned()) + } +} + +impl From for Str<'static> { + fn from(value: OwnedGuid) -> Self { + value.0 .0 + } +} + +impl<'de> Deserialize<'de> for OwnedGuid { + fn deserialize(deserializer: D) -> std::result::Result + where + D: de::Deserializer<'de>, + { + String::deserialize(deserializer) + .and_then(|n| Guid::try_from(n).map_err(|e| de::Error::custom(e.to_string()))) + .map(Self) + } +} + +impl PartialEq<&str> for OwnedGuid { + fn eq(&self, other: &&str) -> bool { + self.as_str() == *other + } +} + +impl PartialEq> for OwnedGuid { + fn eq(&self, other: &Guid<'_>) -> bool { + self.0 == *other } } -impl BorrowMut for Guid { - fn borrow_mut(&mut self) -> &mut str { - &mut self.0 +impl Display for OwnedGuid { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + Display::fmt(&Guid::from(self), f) } } diff --git a/zbus/src/lib.rs b/zbus/src/lib.rs index 6100c4936..d395d9491 100644 --- a/zbus/src/lib.rs +++ b/zbus/src/lib.rs @@ -44,20 +44,6 @@ pub use error::*; pub mod address; pub use address::Address; -#[deprecated(note = "Use `address::TcpAddress` instead")] -#[doc(hidden)] -pub use address::TcpAddress; -#[deprecated(note = "Use `address::TcpAddressFamily` instead")] -#[doc(hidden)] -pub use address::TcpAddressFamily; -#[cfg(any( - all(feature = "vsock", not(feature = "tokio")), - feature = "tokio-vsock" -))] -#[deprecated(note = "Use `address::VsockAddress` instead")] -#[doc(hidden)] -pub use address::VsockAddress; - mod guid; pub use guid::*; @@ -197,8 +183,6 @@ mod tests { collections::HashMap, sync::{mpsc::channel, Arc, Condvar, Mutex}, }; - #[cfg(unix)] - use std::{fs::File, os::unix::io::AsRawFd}; use crate::utils::block_on; use enumflags2::BitFlags; @@ -302,6 +286,7 @@ mod tests { #[test] #[timeout(15000)] fn fdpass_systemd() { + use std::{fs::File, os::unix::io::AsRawFd}; use zvariant::OwnedFd; let connection = blocking::Connection::system().unwrap(); @@ -953,7 +938,7 @@ mod tests { let guid = crate::Guid::generate(); let (p0, p1) = UnixStream::pair().unwrap(); - let server = Builder::unix_stream(p0).server(&guid).p2p().build(); + let server = Builder::unix_stream(p0).server(guid).unwrap().p2p().build(); let client = Builder::unix_stream(p1).p2p().build(); let (client, server) = try_join!(client, server).unwrap(); let mut stream = MessageStream::from(client); diff --git a/zbus/tests/e2e.rs b/zbus/tests/e2e.rs index 0faaad0c2..8300b7a8f 100644 --- a/zbus/tests/e2e.rs +++ b/zbus/tests/e2e.rs @@ -935,7 +935,10 @@ async fn iface_and_proxy_(p2p: bool) { let (p0, p1) = UnixStream::pair().unwrap(); ( - connection::Builder::unix_stream(p0).server(&guid).p2p(), + connection::Builder::unix_stream(p0) + .server(guid) + .unwrap() + .p2p(), connection::Builder::unix_stream(p1).p2p(), ) } @@ -950,7 +953,10 @@ async fn iface_and_proxy_(p2p: bool) { let p0 = listener.incoming().next().unwrap().unwrap(); ( - connection::Builder::tcp_stream(p0).server(&guid).p2p(), + connection::Builder::tcp_stream(p0) + .server(guid) + .unwrap() + .p2p(), connection::Builder::tcp_stream(p1).p2p(), ) } @@ -963,7 +969,10 @@ async fn iface_and_proxy_(p2p: bool) { let p0 = listener.accept().await.unwrap().0; ( - connection::Builder::tcp_stream(p0).server(&guid).p2p(), + connection::Builder::tcp_stream(p0) + .server(guid) + .unwrap() + .p2p(), connection::Builder::tcp_stream(p1).p2p(), ) } diff --git a/zbus_xmlgen/src/main.rs b/zbus_xmlgen/src/main.rs index 1137a48df..fe442492a 100644 --- a/zbus_xmlgen/src/main.rs +++ b/zbus_xmlgen/src/main.rs @@ -67,7 +67,11 @@ fn main() -> Result<(), Box> { let mut output_target = match args.output.as_deref() { Some("-") => OutputTarget::Stdout, Some(path) => { - let file = OpenOptions::new().create(true).write(true).open(path)?; + let file = OpenOptions::new() + .create(true) + .truncate(true) + .write(true) + .open(path)?; OutputTarget::SingleFile(file) } _ => OutputTarget::MultipleFiles,