Skip to content

Commit

Permalink
Partial fix for mio net refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
theduke committed Aug 8, 2024
1 parent f9959d3 commit 2335a82
Showing 1 changed file with 56 additions and 11 deletions.
67 changes: 56 additions & 11 deletions lib/virtual-net/src/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr};
use std::os::fd::AsRawFd;
#[cfg(not(target_os = "windows"))]
use std::os::fd::RawFd;
#[cfg(windows)]
use std::os::windows::io::AsRawSocket;

use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;
Expand Down Expand Up @@ -87,7 +90,21 @@ impl VirtualNetworking for LocalNetworking {
_reuse_addr: bool,
) -> Result<Box<dyn VirtualUdpSocket + Sync>> {
let socket = mio::net::UdpSocket::bind(addr).map_err(io_err_into_net_error)?;
socket2::SockRef::from(&socket).set_nonblocking(true).ok();

{
#[cfg(not(windows))]
let sockref = socket2::SockRef::from(&socket);
#[cfg(windows)]
let sockref = unsafe {
socket2::SockRef::from(&std::os::windows::io::BorrowedSocket::borrow_raw(
socket.stream.as_raw_socket(),
))
.set_nonblocking(true)
.ok();
};

sockref.set_nonblocking(true).ok();
}

#[allow(unused_mut)]
let mut ret = LocalUdpSocket {
Expand Down Expand Up @@ -118,6 +135,7 @@ impl VirtualNetworking for LocalNetworking {
mut peer: SocketAddr,
) -> Result<Box<dyn VirtualTcpSocket + Sync>> {
let stream = mio::net::TcpStream::connect(peer).map_err(io_err_into_net_error)?;

socket2::SockRef::from(&stream).set_nonblocking(true).ok();
if let Ok(p) = stream.peer_addr() {
peer = p;
Expand Down Expand Up @@ -333,6 +351,20 @@ impl LocalTcpStream {

ret
}

fn sock_ref<'a>(&'a self) -> socket2::SockRef<'a> {
#[cfg(not(windows))]
let r = socket2::SockRef::from(&self.stream);

#[cfg(windows)]
let r = unsafe {
socket2::SockRef::from(&std::os::windows::io::BorrowedSocket::borrow_raw(
self.stream.as_raw_socket(),
))
};

r
}
}

impl VirtualTcpSocket for LocalTcpStream {
Expand Down Expand Up @@ -363,16 +395,14 @@ impl VirtualTcpSocket for LocalTcpStream {
}

fn set_keepalive(&mut self, keepalive: bool) -> Result<()> {
socket2::SockRef::from(&self.stream)
self.sock_ref()
.set_keepalive(true)
.map_err(io_err_into_net_error)?;
Ok(())
}

fn keepalive(&self) -> Result<bool> {
let ret = socket2::SockRef::from(&self.stream)
.keepalive()
.map_err(io_err_into_net_error)?;
let ret = self.sock_ref().keepalive().map_err(io_err_into_net_error)?;
Ok(ret)
}

Expand Down Expand Up @@ -444,16 +474,15 @@ impl VirtualTcpSocket for LocalTcpStream {

impl VirtualConnectedSocket for LocalTcpStream {
fn set_linger(&mut self, linger: Option<Duration>) -> Result<()> {
socket2::SockRef::from(&self.stream)
self.sock_ref()
.set_linger(linger)
.map_err(io_err_into_net_error)?;
Ok(())
}

fn linger(&self) -> Result<Option<Duration>> {
socket2::SockRef::from(&self.stream)
.linger()
.map_err(io_err_into_net_error)
let sockref = self.sock_ref();
sockref.linger().map_err(io_err_into_net_error)
}

fn try_send(&mut self, data: &[u8]) -> Result<usize> {
Expand Down Expand Up @@ -654,6 +683,22 @@ pub struct LocalUdpSocket {
backlog: VecDeque<(BytesMut, SocketAddr)>,
}

impl LocalUdpSocket {
fn sock_ref<'a>(&'a self) -> socket2::SockRef<'a> {
#[cfg(not(windows))]
let r = socket2::SockRef::from(&self.socket);

#[cfg(windows)]
let r = unsafe {
socket2::SockRef::from(&std::os::windows::io::BorrowedSocket::borrow_raw(
self.stream.as_raw_socket(),
))
};

r
}
}

impl VirtualUdpSocket for LocalUdpSocket {
fn set_broadcast(&mut self, broadcast: bool) -> Result<()> {
self.socket
Expand Down Expand Up @@ -702,13 +747,13 @@ impl VirtualUdpSocket for LocalUdpSocket {
}

fn join_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
socket2::SockRef::from(&self.socket)
self.sock_ref()
.join_multicast_v4(&multiaddr, &iface)
.map_err(io_err_into_net_error)
}

fn leave_multicast_v4(&mut self, multiaddr: Ipv4Addr, iface: Ipv4Addr) -> Result<()> {
socket2::SockRef::from(&self.socket)
self.sock_ref()
.leave_multicast_v4(&multiaddr, &iface)
.map_err(io_err_into_net_error)
}
Expand Down

0 comments on commit 2335a82

Please sign in to comment.