diff --git a/quic/s2n-quic-platform/src/io/tokio.rs b/quic/s2n-quic-platform/src/io/tokio.rs index d8b304aa9c..dd1c9f874c 100644 --- a/quic/s2n-quic-platform/src/io/tokio.rs +++ b/quic/s2n-quic-platform/src/io/tokio.rs @@ -19,6 +19,7 @@ use tokio::{net::UdpSocket, runtime::Handle}; mod builder; mod clock; +mod task; #[cfg(test)] mod tests; diff --git a/quic/s2n-quic-platform/src/io/tokio/task.rs b/quic/s2n-quic-platform/src/io/tokio/task.rs new file mode 100644 index 0000000000..ab82fb65f0 --- /dev/null +++ b/quic/s2n-quic-platform/src/io/tokio/task.rs @@ -0,0 +1,47 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +// depending on the platform, some of these implementations aren't used +#![allow(dead_code)] + +mod simple; +#[cfg(unix)] +mod unix; + +cfg_if::cfg_if! { + if #[cfg(s2n_quic_platform_socket_mmsg)] { + pub use mmsg::{rx, tx}; + } else if #[cfg(s2n_quic_platform_socket_msg)] { + pub use msg::{rx, tx}; + } else { + pub use simple::{rx, tx}; + } +} + +macro_rules! libc_msg { + ($message:ident, $cfg:ident) => { + #[cfg($cfg)] + mod $message { + use super::unix; + use crate::{features::Gso, message::$message::Message, socket::ring}; + + pub async fn rx>( + socket: S, + producer: ring::Producer, + ) -> std::io::Result<()> { + unix::rx(socket, producer).await + } + + pub async fn tx>( + socket: S, + consumer: ring::Consumer, + gso: Gso, + ) -> std::io::Result<()> { + unix::tx(socket, consumer, gso).await + } + } + }; +} + +libc_msg!(msg, s2n_quic_platform_socket_msg); +libc_msg!(mmsg, s2n_quic_platform_socket_mmsg); diff --git a/quic/s2n-quic-platform/src/io/tokio/task/simple.rs b/quic/s2n-quic-platform/src/io/tokio/task/simple.rs new file mode 100644 index 0000000000..43ea714043 --- /dev/null +++ b/quic/s2n-quic-platform/src/io/tokio/task/simple.rs @@ -0,0 +1,123 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + features::Gso, + message::{simple::Message, Message as _}, + socket::{ + ring, task, + task::{rx, tx}, + }, + syscall::SocketEvents, +}; +use core::task::{Context, Poll}; +use tokio::{io, net::UdpSocket}; + +pub async fn rx>( + socket: S, + producer: ring::Producer, +) -> io::Result<()> { + let socket = socket.into(); + socket.set_nonblocking(true).unwrap(); + + let socket = UdpSocket::from_std(socket).unwrap(); + let result = task::Receiver::new(producer, socket).await; + if let Some(err) = result { + Err(err) + } else { + Ok(()) + } +} + +pub async fn tx>( + socket: S, + consumer: ring::Consumer, + gso: Gso, +) -> io::Result<()> { + let socket = socket.into(); + socket.set_nonblocking(true).unwrap(); + + let socket = UdpSocket::from_std(socket).unwrap(); + let result = task::Sender::new(consumer, socket, gso).await; + if let Some(err) = result { + Err(err) + } else { + Ok(()) + } +} + +impl tx::Socket for UdpSocket { + type Error = io::Error; + + #[inline] + fn send( + &mut self, + cx: &mut Context, + entries: &mut [Message], + events: &mut tx::Events, + ) -> io::Result<()> { + for entry in entries { + let target = (*entry.remote_address()).into(); + let payload = entry.payload_mut(); + match self.poll_send_to(cx, payload, target) { + Poll::Ready(Ok(_)) => { + if events.on_complete(1).is_break() { + return Ok(()); + } + } + Poll::Ready(Err(err)) => { + if events.on_error(err).is_break() { + return Ok(()); + } + } + Poll::Pending => { + events.blocked(); + break; + } + } + } + + Ok(()) + } +} + +impl rx::Socket for UdpSocket { + type Error = io::Error; + + #[inline] + fn recv( + &mut self, + cx: &mut Context, + entries: &mut [Message], + events: &mut rx::Events, + ) -> io::Result<()> { + for entry in entries { + let payload = entry.payload_mut(); + let mut buf = io::ReadBuf::new(payload); + match self.poll_recv_from(cx, &mut buf) { + Poll::Ready(Ok(addr)) => { + unsafe { + let len = buf.filled().len(); + entry.set_payload_len(len); + } + entry.set_remote_address(&(addr.into())); + + if events.on_complete(1).is_break() { + return Ok(()); + } + } + Poll::Ready(Err(err)) => { + if events.on_error(err).is_break() { + return Ok(()); + } + } + Poll::Pending => { + events.blocked(); + break; + } + } + } + + Ok(()) + } +} diff --git a/quic/s2n-quic-platform/src/io/tokio/task/unix.rs b/quic/s2n-quic-platform/src/io/tokio/task/unix.rs new file mode 100644 index 0000000000..ef649cd40e --- /dev/null +++ b/quic/s2n-quic-platform/src/io/tokio/task/unix.rs @@ -0,0 +1,150 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use crate::{ + features::Gso, + socket::{ + ring, + task::{rx, tx}, + }, + syscall::{SocketType, UnixMessage}, +}; +use core::task::{Context, Poll}; +use std::{io, os::unix::io::AsRawFd}; +use tokio::io::unix::AsyncFd; + +pub async fn rx, M: UnixMessage + Unpin>( + socket: S, + producer: ring::Producer, +) -> io::Result<()> { + let socket = socket.into(); + socket.set_nonblocking(true).unwrap(); + + let socket = AsyncFd::new(socket).unwrap(); + let result = rx::Receiver::new(producer, socket).await; + if let Some(err) = result { + Err(err) + } else { + Ok(()) + } +} + +pub async fn tx, M: UnixMessage + Unpin>( + socket: S, + consumer: ring::Consumer, + gso: Gso, +) -> io::Result<()> { + let socket = socket.into(); + socket.set_nonblocking(true).unwrap(); + + let socket = AsyncFd::new(socket).unwrap(); + let result = tx::Sender::new(consumer, socket, gso).await; + if let Some(err) = result { + Err(err) + } else { + Ok(()) + } +} + +impl tx::Socket for AsyncFd { + type Error = io::Error; + + #[inline] + fn send( + &mut self, + cx: &mut Context, + entries: &mut [M], + events: &mut tx::Events, + ) -> io::Result<()> { + // Call the syscall for the socket + // + // NOTE: we usually wrap this in a `AsyncFdReadyGuard::try_io`. However, here we just + // assume the socket is ready in the general case and then fall back to querying + // socket readiness if it's not. This can avoid some things like having to construct + // a `std::io::Error` with `WouldBlock` and dereferencing the registration. + M::send(self.get_ref().as_raw_fd(), entries, events); + + // yield back if we weren't blocked + if !events.is_blocked() { + return Ok(()); + } + + // * First iteration we need to clear socket readiness since the `send` call returned a + // `WouldBlock`. + // * Second iteration we need to register the waker, assuming the socket readiness was + // cleared. + // * If we got a `Ready` anyway, then clear the blocked status and have the caller try + // again. + for i in 0..2 { + match self.poll_write_ready(cx) { + Poll::Ready(guard) => { + let mut guard = guard?; + if i == 0 { + guard.clear_ready(); + } else { + events.take_blocked(); + } + } + Poll::Pending => { + return Ok(()); + } + } + } + + Ok(()) + } +} + +impl rx::Socket for AsyncFd { + type Error = io::Error; + + #[inline] + fn recv( + &mut self, + cx: &mut Context, + entries: &mut [M], + events: &mut rx::Events, + ) -> io::Result<()> { + // Call the syscall for the socket + // + // NOTE: we usually wrap this in a `AsyncFdReadyGuard::try_io`. However, here we just + // assume the socket is ready in the general case and then fall back to querying + // socket readiness if it's not. This can avoid some things like having to construct + // a `std::io::Error` with `WouldBlock` and dereferencing the registration. + M::recv( + self.get_ref().as_raw_fd(), + SocketType::NonBlocking, + entries, + events, + ); + + // yield back if we weren't blocked + if !events.is_blocked() { + return Ok(()); + } + + // * First iteration we need to clear socket readiness since the `recv` call returned a + // `WouldBlock`. + // * Second iteration we need to register the waker, assuming the socket readiness was + // cleared. + // * If we got a `Ready` anyway, then clear the blocked status and have the caller try + // again. + for i in 0..2 { + match self.poll_read_ready(cx) { + Poll::Ready(guard) => { + let mut guard = guard?; + if i == 0 { + guard.clear_ready(); + } else { + events.take_blocked(); + } + } + Poll::Pending => { + return Ok(()); + } + } + } + + Ok(()) + } +} diff --git a/quic/s2n-quic-platform/src/message/simple.rs b/quic/s2n-quic-platform/src/message/simple.rs index 10cd4bbfe1..2cda41ff41 100644 --- a/quic/s2n-quic-platform/src/message/simple.rs +++ b/quic/s2n-quic-platform/src/message/simple.rs @@ -25,8 +25,8 @@ impl Message { ExplicitCongestionNotification::default() } - pub(crate) fn remote_address(&self) -> Option { - Some(self.address) + pub(crate) fn remote_address(&self) -> &SocketAddress { + &self.address } pub(crate) fn set_remote_address(&mut self, remote_address: &SocketAddress) { diff --git a/quic/s2n-quic-platform/src/socket/std.rs b/quic/s2n-quic-platform/src/socket/std.rs index a5822ddeba..6affa00330 100644 --- a/quic/s2n-quic-platform/src/socket/std.rs +++ b/quic/s2n-quic-platform/src/socket/std.rs @@ -98,28 +98,26 @@ impl Queue { let mut entries = self.0.occupied_mut(); for entry in entries.as_mut() { - if let Some(remote_address) = entry.remote_address() { - match socket.send_to(entry.payload_mut(), &remote_address) { - Ok(_) => { - count += 1; + let remote_address = *entry.remote_address(); + match socket.send_to(entry.payload_mut(), &remote_address) { + Ok(_) => { + count += 1; - publisher.on_platform_tx(event::builder::PlatformTx { count: 1 }); - } - Err(err) if count > 0 && err.would_block() => { - break; - } - Err(err) if err.was_interrupted() || err.permission_denied() => { - break; - } - Err(err) => { - entries.finish(count); + publisher.on_platform_tx(event::builder::PlatformTx { count: 1 }); + } + Err(err) if count > 0 && err.would_block() => { + break; + } + Err(err) if err.was_interrupted() || err.permission_denied() => { + break; + } + Err(err) => { + entries.finish(count); - publisher.on_platform_tx_error(event::builder::PlatformTxError { - errno: errno().0, - }); + publisher + .on_platform_tx_error(event::builder::PlatformTxError { errno: errno().0 }); - return Err(err); - } + return Err(err); } } } diff --git a/quic/s2n-quic-platform/src/socket/task.rs b/quic/s2n-quic-platform/src/socket/task.rs index 336c294ea7..4539549dbc 100644 --- a/quic/s2n-quic-platform/src/socket/task.rs +++ b/quic/s2n-quic-platform/src/socket/task.rs @@ -1,7 +1,7 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -pub mod events; +mod events; pub mod rx; pub mod tx; diff --git a/quic/s2n-quic-platform/src/socket/task/rx.rs b/quic/s2n-quic-platform/src/socket/task/rx.rs index fd9c259010..2e6477988a 100644 --- a/quic/s2n-quic-platform/src/socket/task/rx.rs +++ b/quic/s2n-quic-platform/src/socket/task/rx.rs @@ -12,6 +12,8 @@ use core::{ }; use futures::ready; +pub use events::RxEvents as Events; + pub trait Socket { type Error; @@ -19,7 +21,7 @@ pub trait Socket { &mut self, cx: &mut Context, entries: &mut [T], - events: &mut events::RxEvents, + events: &mut Events, ) -> Result<(), Self::Error>; } @@ -92,7 +94,7 @@ where fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { let this = self.get_mut(); - let mut events = events::RxEvents::default(); + let mut events = Events::default(); while !events.take_blocked() { if ready!(this.poll_ring(u32::MAX, cx)).is_err() { diff --git a/quic/s2n-quic-platform/src/socket/task/tx.rs b/quic/s2n-quic-platform/src/socket/task/tx.rs index 72ad65a368..6fcdbf1792 100644 --- a/quic/s2n-quic-platform/src/socket/task/tx.rs +++ b/quic/s2n-quic-platform/src/socket/task/tx.rs @@ -13,6 +13,8 @@ use core::{ }; use futures::ready; +pub use events::TxEvents as Events; + pub trait Socket { type Error; @@ -20,7 +22,7 @@ pub trait Socket { &mut self, cx: &mut Context, entries: &mut [T], - events: &mut events::TxEvents, + events: &mut Events, ) -> Result<(), Self::Error>; } @@ -32,7 +34,7 @@ pub struct Sender> { /// /// This value is to avoid calling `release` too much and excessively waking up the producer. pending: u32, - events: events::TxEvents, + events: Events, } impl Sender @@ -46,7 +48,7 @@ where ring, tx, pending: 0, - events: events::TxEvents::new(gso), + events: Events::new(gso), } } diff --git a/quic/s2n-quic-platform/src/syscall.rs b/quic/s2n-quic-platform/src/syscall.rs index d49891250d..1b514d5ff9 100644 --- a/quic/s2n-quic-platform/src/syscall.rs +++ b/quic/s2n-quic-platform/src/syscall.rs @@ -1,9 +1,8 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -#![allow(unused_variables, unused_mut, clippy::let_and_return)] // some platforms contain empty - // implementations so disable any - // warnings from those +// some platforms contain empty implementations so disable any warnings from those +#![allow(unused_variables, unused_macros, unused_mut, clippy::let_and_return)] use core::ops::ControlFlow; use socket2::{Domain, Protocol, Socket, Type}; @@ -15,7 +14,7 @@ pub mod mmsg; pub mod msg; #[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[allow(dead_code)] // TODO remove once used +#[cfg_attr(not(unix), allow(dead_code))] pub enum SocketType { Blocking, NonBlocking, @@ -37,6 +36,17 @@ pub trait SocketEvents { fn on_error(&mut self, error: io::Error) -> ControlFlow<(), ()>; } +#[cfg(unix)] +pub trait UnixMessage: crate::message::Message { + fn send(fd: std::os::unix::io::RawFd, entries: &mut [Self], events: &mut E); + fn recv( + fd: std::os::unix::io::RawFd, + ty: SocketType, + entries: &mut [Self], + events: &mut E, + ); +} + pub fn udp_socket(addr: std::net::SocketAddr) -> io::Result { let domain = Domain::for_address(addr); let socket_type = Type::DGRAM; diff --git a/quic/s2n-quic-platform/src/syscall/mmsg.rs b/quic/s2n-quic-platform/src/syscall/mmsg.rs index 4ea4bf0d08..5b31f9b39c 100644 --- a/quic/s2n-quic-platform/src/syscall/mmsg.rs +++ b/quic/s2n-quic-platform/src/syscall/mmsg.rs @@ -1,11 +1,21 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -#![allow(dead_code)] // TODO remove once used - -use super::{SocketEvents, SocketType}; +use super::{SocketEvents, SocketType, UnixMessage}; use libc::mmsghdr; -use std::os::unix::io::AsRawFd; +use std::os::unix::io::{AsRawFd, RawFd}; + +impl UnixMessage for mmsghdr { + #[inline] + fn send(fd: RawFd, entries: &mut [Self], events: &mut E) { + send(&fd, entries, events) + } + + #[inline] + fn recv(fd: RawFd, ty: SocketType, entries: &mut [Self], events: &mut E) { + recv(&fd, ty, entries, events) + } +} #[inline] pub fn send( diff --git a/quic/s2n-quic-platform/src/syscall/msg.rs b/quic/s2n-quic-platform/src/syscall/msg.rs index 2efa19d732..3b11435bbe 100644 --- a/quic/s2n-quic-platform/src/syscall/msg.rs +++ b/quic/s2n-quic-platform/src/syscall/msg.rs @@ -1,11 +1,22 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -#![allow(dead_code)] // TODO remove once used - -use super::{SocketEvents, SocketType}; +use super::{SocketEvents, SocketType, UnixMessage}; +use crate::message::Message as _; use libc::msghdr; -use std::os::unix::io::AsRawFd; +use std::os::unix::io::{AsRawFd, RawFd}; + +impl UnixMessage for msghdr { + #[inline] + fn send(fd: RawFd, entries: &mut [Self], events: &mut E) { + send(&fd, entries, events) + } + + #[inline] + fn recv(fd: RawFd, ty: SocketType, entries: &mut [Self], events: &mut E) { + recv(&fd, ty, entries, events) + } +} #[inline] pub fn send<'a, Sock: AsRawFd, P: IntoIterator, E: SocketEvents>( @@ -117,7 +128,13 @@ pub fn recv<'a, Sock: AsRawFd, P: IntoIterator, E: Socket let result = libc!(recvmsg(sockfd, msg, flags)); let cf = match result { - Ok(_) => events.on_complete(1), + Ok(payload_len) => { + // update the message based on the return size of the syscall + unsafe { + msg.set_payload_len(payload_len.min(u16::MAX as _).max(0) as _); + } + events.on_complete(1) + } Err(err) => events.on_error(err), };