Skip to content

Commit

Permalink
Split WgDataIo into WgReadIo and WgWriteIo to avoid Drop (#53)
Browse files Browse the repository at this point in the history
* Split WgDataIo into WgReadIo and WgWriteIo to avoid Drop

* Make WgRead/WriteIo fields private
  • Loading branch information
moubctez authored Feb 7, 2024
1 parent d253efc commit 2ecb464
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 39 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 5 additions & 14 deletions src/bsd/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use self::{
nvlist::NvList,
sockaddr::{pack_sockaddr, unpack_sockaddr},
timespec::{pack_timespec, unpack_timespec},
wgio::WgDataIo,
wgio::{WgReadIo, WgWriteIo},
};
use crate::{
host::{Host, Peer},
Expand Down Expand Up @@ -253,7 +253,7 @@ impl<'a> Key {
}

pub fn get_host(if_name: &str) -> Result<Host, IoError> {
let mut wg_data = WgDataIo::new(if_name);
let mut wg_data = WgReadIo::new(if_name);
wg_data.read_data()?;

let mut nvlist = NvList::new();
Expand All @@ -266,40 +266,31 @@ pub fn get_host(if_name: &str) -> Result<Host, IoError> {
}

pub fn set_host(if_name: &str, host: &Host) -> Result<(), IoError> {
let mut wg_data = WgDataIo::new(if_name);

let nvlist = host.as_nvlist();
// FIXME: use proper error, here and above
let mut buf = nvlist.pack().map_err(|_| IoError::MemAlloc)?;

wg_data.wgd_data = buf.as_mut_ptr();
wg_data.wgd_size = buf.len();
let mut wg_data = WgWriteIo::new(if_name, &mut buf);
wg_data.write_data()
}

pub fn set_peer(if_name: &str, peer: &Peer) -> Result<(), IoError> {
let mut wg_data = WgDataIo::new(if_name);

let mut nvlist = NvList::new();
nvlist.append_nvlist_array(NV_PEERS, vec![peer.as_nvlist()]);
// FIXME: use proper error, here and above
let mut buf = nvlist.pack().map_err(|_| IoError::MemAlloc)?;

wg_data.wgd_data = buf.as_mut_ptr();
wg_data.wgd_size = buf.len();
let mut wg_data = WgWriteIo::new(if_name, &mut buf);
wg_data.write_data()
}

pub fn delete_peer(if_name: &str, public_key: &Key) -> Result<(), IoError> {
let mut wg_data = WgDataIo::new(if_name);

let mut nvlist = NvList::new();
nvlist.append_nvlist_array(NV_PEERS, vec![public_key.as_nvlist_for_removal()]);
// FIXME: use proper error, here and above
let mut buf = nvlist.pack().map_err(|_| IoError::MemAlloc)?;

wg_data.wgd_data = buf.as_mut_ptr();
wg_data.wgd_size = buf.len();
let mut wg_data = WgWriteIo::new(if_name, &mut buf);
wg_data.write_data()
}

Expand Down
72 changes: 49 additions & 23 deletions src/bsd/wgio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,21 @@ use nix::{ioctl_readwrite, sys::socket::AddressFamily};

use super::{create_socket, IoError};

// FIXME: `WgDataIo` has to be declared as public
ioctl_readwrite!(write_wireguard_data, b'i', 210, WgDataIo);
ioctl_readwrite!(read_wireguard_data, b'i', 211, WgDataIo);
// FIXME: `WgReadIo` and `WgWriteIo` have to be declared public.
ioctl_readwrite!(write_wireguard_data, b'i', 210, WgWriteIo);
ioctl_readwrite!(read_wireguard_data, b'i', 211, WgReadIo);

/// Represent `struct wg_data_io` defined in
/// https://github.com/freebsd/freebsd-src/blob/main/sys/dev/wg/if_wg.h
#[repr(C)]
pub struct WgDataIo {
pub(super) wgd_name: [u8; IF_NAMESIZE],
pub(super) wgd_data: *mut u8, // *void
pub(super) wgd_size: usize,
pub struct WgReadIo {
wgd_name: [u8; IF_NAMESIZE],
wgd_data: *mut u8, // *void
wgd_size: usize,
}

impl WgDataIo {
/// Create `WgDataIo` without data buffer.
impl WgReadIo {
/// Create `WgReadIo` without data buffer.
#[must_use]
pub fn new(if_name: &str) -> Self {
let mut wgd_name = [0u8; IF_NAMESIZE];
Expand Down Expand Up @@ -63,41 +63,67 @@ impl WgDataIo {
unsafe {
// First do ioctl with empty `wg_data` to obtain buffer size.
if let Err(err) = read_wireguard_data(socket.as_raw_fd(), self) {
error!("WgDataIo first read error {err}");
error!("WgReadIo first read error {err}");
return Err(IoError::ReadIo(err));
}
// Allocate buffer.
self.alloc_data()?;
// Second call to ioctl with allocated buffer.
if let Err(err) = read_wireguard_data(socket.as_raw_fd(), self) {
error!("WgDataIo second read error {err}");
error!("WgReadIo second read error {err}");
return Err(IoError::ReadIo(err));
}
}

Ok(())
}
}

impl Drop for WgReadIo {
fn drop(&mut self) {
if self.wgd_size != 0 {
let layout = Layout::array::<u8>(self.wgd_size).expect("Bad layout");
unsafe {
dealloc(self.wgd_data, layout);
}
}
}
}

/// Same data layout as `WgReadIo`, but avoid `Drop`.
#[repr(C)]
pub struct WgWriteIo {
wgd_name: [u8; IF_NAMESIZE],
wgd_data: *mut u8, // *void
wgd_size: usize,
}

impl WgWriteIo {
/// Create `WgWriteIo` from slice.
#[must_use]
pub fn new(if_name: &str, buf: &mut [u8]) -> Self {
let mut wgd_name = [0u8; IF_NAMESIZE];
if_name
.bytes()
.take(IF_NAMESIZE - 1)
.enumerate()
.for_each(|(i, b)| wgd_name[i] = b);
Self {
wgd_name,
wgd_data: buf.as_mut_ptr(),
wgd_size: buf.len(),
}
}

pub(super) fn write_data(&mut self) -> Result<(), IoError> {
let socket = create_socket(AddressFamily::Unix).map_err(IoError::WriteIo)?;
unsafe {
if let Err(err) = write_wireguard_data(socket.as_raw_fd(), self) {
error!("WgDataIo write error {err}");
error!("WgWriteIo write error {err}");
return Err(IoError::WriteIo(err));
}
}

Ok(())
}
}

impl Drop for WgDataIo {
fn drop(&mut self) {
if self.wgd_size != 0 {
let layout = Layout::array::<u8>(self.wgd_size).expect("Bad layout");
unsafe {
dealloc(self.wgd_data, layout);
}
}
}
}

0 comments on commit 2ecb464

Please sign in to comment.