diff --git a/Cargo.lock b/Cargo.lock index 7f35e26..4711812 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -27,9 +27,9 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.5" +version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2faccea4cc4ab4a667ce676a30e8ec13922a692c99bb8f5b11f1502c72e04220" +checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" [[package]] name = "anstyle-parse" diff --git a/src/bsd/mod.rs b/src/bsd/mod.rs index 160aa5b..749fb92 100644 --- a/src/bsd/mod.rs +++ b/src/bsd/mod.rs @@ -21,7 +21,7 @@ use self::{ nvlist::NvList, sockaddr::{pack_sockaddr, unpack_sockaddr}, timespec::{pack_timespec, unpack_timespec}, - wgio::WgDataIo, + wgio::{WgReadIo, WgWriteIo}, }; use crate::{ host::{Host, Peer}, @@ -253,7 +253,7 @@ impl<'a> Key { } pub fn get_host(if_name: &str) -> Result { - let mut wg_data = WgDataIo::new(if_name); + let mut wg_data = WgReadIo::new(if_name); wg_data.read_data()?; let mut nvlist = NvList::new(); @@ -266,40 +266,31 @@ pub fn get_host(if_name: &str) -> Result { } pub fn set_host(if_name: &str, host: &Host) -> Result<(), IoError> { - let mut wg_data = WgDataIo::new(if_name); - let nvlist = host.as_nvlist(); // FIXME: use proper error, here and above let mut buf = nvlist.pack().map_err(|_| IoError::MemAlloc)?; - wg_data.wgd_data = buf.as_mut_ptr(); - wg_data.wgd_size = buf.len(); + let mut wg_data = WgWriteIo::new(if_name, &mut buf); wg_data.write_data() } pub fn set_peer(if_name: &str, peer: &Peer) -> Result<(), IoError> { - let mut wg_data = WgDataIo::new(if_name); - let mut nvlist = NvList::new(); nvlist.append_nvlist_array(NV_PEERS, vec![peer.as_nvlist()]); // FIXME: use proper error, here and above let mut buf = nvlist.pack().map_err(|_| IoError::MemAlloc)?; - wg_data.wgd_data = buf.as_mut_ptr(); - wg_data.wgd_size = buf.len(); + let mut wg_data = WgWriteIo::new(if_name, &mut buf); wg_data.write_data() } pub fn delete_peer(if_name: &str, public_key: &Key) -> Result<(), IoError> { - let mut wg_data = WgDataIo::new(if_name); - let mut nvlist = NvList::new(); nvlist.append_nvlist_array(NV_PEERS, vec![public_key.as_nvlist_for_removal()]); // FIXME: use proper error, here and above let mut buf = nvlist.pack().map_err(|_| IoError::MemAlloc)?; - wg_data.wgd_data = buf.as_mut_ptr(); - wg_data.wgd_size = buf.len(); + let mut wg_data = WgWriteIo::new(if_name, &mut buf); wg_data.write_data() } diff --git a/src/bsd/wgio.rs b/src/bsd/wgio.rs index 88e4621..c495c7a 100644 --- a/src/bsd/wgio.rs +++ b/src/bsd/wgio.rs @@ -10,21 +10,21 @@ use nix::{ioctl_readwrite, sys::socket::AddressFamily}; use super::{create_socket, IoError}; -// FIXME: `WgDataIo` has to be declared as public -ioctl_readwrite!(write_wireguard_data, b'i', 210, WgDataIo); -ioctl_readwrite!(read_wireguard_data, b'i', 211, WgDataIo); +// FIXME: `WgReadIo` and `WgWriteIo` have to be declared public. +ioctl_readwrite!(write_wireguard_data, b'i', 210, WgWriteIo); +ioctl_readwrite!(read_wireguard_data, b'i', 211, WgReadIo); /// Represent `struct wg_data_io` defined in /// https://github.com/freebsd/freebsd-src/blob/main/sys/dev/wg/if_wg.h #[repr(C)] -pub struct WgDataIo { - pub(super) wgd_name: [u8; IF_NAMESIZE], - pub(super) wgd_data: *mut u8, // *void - pub(super) wgd_size: usize, +pub struct WgReadIo { + wgd_name: [u8; IF_NAMESIZE], + wgd_data: *mut u8, // *void + wgd_size: usize, } -impl WgDataIo { - /// Create `WgDataIo` without data buffer. +impl WgReadIo { + /// Create `WgReadIo` without data buffer. #[must_use] pub fn new(if_name: &str) -> Self { let mut wgd_name = [0u8; IF_NAMESIZE]; @@ -63,26 +63,63 @@ impl WgDataIo { unsafe { // First do ioctl with empty `wg_data` to obtain buffer size. if let Err(err) = read_wireguard_data(socket.as_raw_fd(), self) { - error!("WgDataIo first read error {err}"); + error!("WgReadIo first read error {err}"); return Err(IoError::ReadIo(err)); } // Allocate buffer. self.alloc_data()?; // Second call to ioctl with allocated buffer. if let Err(err) = read_wireguard_data(socket.as_raw_fd(), self) { - error!("WgDataIo second read error {err}"); + error!("WgReadIo second read error {err}"); return Err(IoError::ReadIo(err)); } } Ok(()) } +} + +impl Drop for WgReadIo { + fn drop(&mut self) { + if self.wgd_size != 0 { + let layout = Layout::array::(self.wgd_size).expect("Bad layout"); + unsafe { + dealloc(self.wgd_data, layout); + } + } + } +} + +/// Same data layout as `WgReadIo`, but avoid `Drop`. +#[repr(C)] +pub struct WgWriteIo { + wgd_name: [u8; IF_NAMESIZE], + wgd_data: *mut u8, // *void + wgd_size: usize, +} + +impl WgWriteIo { + /// Create `WgWriteIo` from slice. + #[must_use] + pub fn new(if_name: &str, buf: &mut [u8]) -> Self { + let mut wgd_name = [0u8; IF_NAMESIZE]; + if_name + .bytes() + .take(IF_NAMESIZE - 1) + .enumerate() + .for_each(|(i, b)| wgd_name[i] = b); + Self { + wgd_name, + wgd_data: buf.as_mut_ptr(), + wgd_size: buf.len(), + } + } pub(super) fn write_data(&mut self) -> Result<(), IoError> { let socket = create_socket(AddressFamily::Unix).map_err(IoError::WriteIo)?; unsafe { if let Err(err) = write_wireguard_data(socket.as_raw_fd(), self) { - error!("WgDataIo write error {err}"); + error!("WgWriteIo write error {err}"); return Err(IoError::WriteIo(err)); } } @@ -90,14 +127,3 @@ impl WgDataIo { Ok(()) } } - -impl Drop for WgDataIo { - fn drop(&mut self) { - if self.wgd_size != 0 { - let layout = Layout::array::(self.wgd_size).expect("Bad layout"); - unsafe { - dealloc(self.wgd_data, layout); - } - } - } -}