diff --git a/dc/s2n-quic-dc/src/path/secret/map.rs b/dc/s2n-quic-dc/src/path/secret/map.rs index d8567bf8cb..7693524921 100644 --- a/dc/s2n-quic-dc/src/path/secret/map.rs +++ b/dc/s2n-quic-dc/src/path/secret/map.rs @@ -298,6 +298,12 @@ impl Map { self.store.test_stop_cleaner(); } + #[doc(hidden)] + #[cfg(test)] + pub fn reset_all_senders(&self) { + self.store.reset_all_senders(); + } + #[doc(hidden)] #[cfg(any(test, feature = "testing"))] pub fn test_insert(&self, peer: SocketAddr) { diff --git a/dc/s2n-quic-dc/src/path/secret/map/entry.rs b/dc/s2n-quic-dc/src/path/secret/map/entry.rs index 541a5362e9..176101d19c 100644 --- a/dc/s2n-quic-dc/src/path/secret/map/entry.rs +++ b/dc/s2n-quic-dc/src/path/secret/map/entry.rs @@ -300,6 +300,11 @@ impl Entry { pub fn application_data(&self) -> &Option { &self.application_data } + + #[cfg(test)] + pub fn reset_sender_counter(&self) { + self.sender.reset_counter(); + } } impl receiver::Error { diff --git a/dc/s2n-quic-dc/src/path/secret/map/state.rs b/dc/s2n-quic-dc/src/path/secret/map/state.rs index 5ead0f0b4d..296b8a2642 100644 --- a/dc/s2n-quic-dc/src/path/secret/map/state.rs +++ b/dc/s2n-quic-dc/src/path/secret/map/state.rs @@ -1029,6 +1029,14 @@ where Ok(None) } } + + #[cfg(test)] + fn reset_all_senders(&self) { + let peer_map = self.peers.0.read(); + for entry in peer_map.iter() { + entry.reset_sender_counter(); + } + } } impl Drop for State diff --git a/dc/s2n-quic-dc/src/path/secret/map/store.rs b/dc/s2n-quic-dc/src/path/secret/map/store.rs index c0a6710fcd..18ac6fc1ef 100644 --- a/dc/s2n-quic-dc/src/path/secret/map/store.rs +++ b/dc/s2n-quic-dc/src/path/secret/map/store.rs @@ -145,4 +145,7 @@ pub trait Store: 'static + Send + Sync { &self, session: &dyn s2n_quic_core::crypto::tls::TlsSession, ) -> Result, ApplicationDataError>; + + #[cfg(test)] + fn reset_all_senders(&self); } diff --git a/dc/s2n-quic-dc/src/path/secret/sender.rs b/dc/s2n-quic-dc/src/path/secret/sender.rs index 99785c88b7..fe26cbf861 100644 --- a/dc/s2n-quic-dc/src/path/secret/sender.rs +++ b/dc/s2n-quic-dc/src/path/secret/sender.rs @@ -73,6 +73,11 @@ impl State { // Update the key to the new minimum to start at. self.current_id.fetch_max(*min_key_id, Ordering::Relaxed); } + + #[cfg(test)] + pub fn reset_counter(&self) { + self.current_id.store(0, Ordering::Relaxed); + } } #[test] diff --git a/dc/s2n-quic-dc/src/stream/server/tokio/tcp/worker.rs b/dc/s2n-quic-dc/src/stream/server/tokio/tcp/worker.rs index a64a4e6ae1..d2d8977bfa 100644 --- a/dc/s2n-quic-dc/src/stream/server/tokio/tcp/worker.rs +++ b/dc/s2n-quic-dc/src/stream/server/tokio/tcp/worker.rs @@ -781,6 +781,7 @@ where .on_decrypt_success(recv_buffer.into()) .is_err() { + // we just close the stream return Ok(ControlFlow::Continue(())).into(); }; @@ -813,16 +814,12 @@ where recv_buffer, ); - let sender = uds::sender::Sender::new()?; - let dest_path = self.dest_path.clone(); + let sender = uds::sender::Sender::new(&self.dest_path)?; let tcp_stream = socket.into_std()?; // FIXME make this a manual Future impl instead of Box - let send_future = Box::pin(async move { - sender - .send_msg(&buffer, &dest_path, tcp_stream.as_fd()) - .await - }); + let send_future = + Box::pin(async move { sender.send_msg(&buffer, tcp_stream.as_fd()).await }); let event_data = SocketEventData { credential_id: credentials.id.to_vec(), diff --git a/dc/s2n-quic-dc/src/stream/tests/shared_cache.rs b/dc/s2n-quic-dc/src/stream/tests/shared_cache.rs index 1e27bd5114..0c2559ed04 100644 --- a/dc/s2n-quic-dc/src/stream/tests/shared_cache.rs +++ b/dc/s2n-quic-dc/src/stream/tests/shared_cache.rs @@ -6,37 +6,35 @@ use crate::{ path::secret::{stateless_reset::Signer, Map}, psk::{client::Provider as ClientProvider, server::Provider as ServerProvider}, stream::{ - client::tokio::Client as ClientTokio, server::manager::Server as ServerTokio, Protocol, + client::tokio::Client as ClientTokio, + server::{application, manager}, + Protocol, }, testing::{init_tracing, query_event, server_name, NoopSubscriber, TestTlsProvider}, }; use s2n_quic_core::time::StdClock; use std::{ - num::NonZeroUsize, + num::{NonZero, NonZeroUsize}, path::{Path, PathBuf}, time::Duration, }; use tracing::info; -#[tokio::test] -async fn setup_servers() { - init_tracing(); - +fn create_stream_client() -> (ClientTokio, Map) { let tls_materials_provider = TestTlsProvider {}; let test_event_subscriber = NoopSubscriber {}; - let unix_socket_path1 = PathBuf::from("/tmp/shared1.sock"); - let unix_socket_path2 = PathBuf::from("/tmp/shared2.sock"); - // Create client + let client_map = Map::new( + Signer::new(b"default"), + 100, + StdClock::default(), + test_event_subscriber.clone(), + ); + let handshake_client = ClientProvider::builder() .start( "127.0.0.1:0".parse().unwrap(), - Map::new( - Signer::new(b"default"), - 100, - StdClock::default(), - test_event_subscriber.clone(), - ), + client_map.clone(), tls_materials_provider.clone(), test_event_subscriber.clone(), query_event, @@ -53,8 +51,13 @@ async fn setup_servers() { .unwrap(); info!("Client created"); + (stream_client, client_map) +} + +async fn create_handshake_server() -> ServerProvider { + let tls_materials_provider = TestTlsProvider {}; + let test_event_subscriber = NoopSubscriber {}; - // Create manager handshake server let manager_handshake_map = Map::new( Signer::new(b"default"), 1, @@ -76,6 +79,33 @@ async fn setup_servers() { "Manager handshake server: {}", handshake_server.local_addr() ); + handshake_server +} + +fn create_application_server( + unix_socket_path: &Path, + test_event_subscriber: NoopSubscriber, +) -> application::Server { + let app_server = application::Server::::builder() + .with_protocol(Protocol::Tcp) + .with_udp(false) + .with_socket_path(unix_socket_path) + .build(test_event_subscriber.clone()) + .unwrap(); + info!("Application server created"); + app_server +} + +#[tokio::test] +async fn setup_servers() { + init_tracing(); + + let test_event_subscriber = NoopSubscriber {}; + let unix_socket_path1 = PathBuf::from("/tmp/shared1.sock"); + let unix_socket_path2 = PathBuf::from("/tmp/shared2.sock"); + + let (stream_client, _) = create_stream_client(); + let handshake_server = create_handshake_server().await; let handshake_addr = handshake_server.local_addr(); stream_client @@ -109,7 +139,7 @@ async fn test_connection( test_event_subscriber: NoopSubscriber, stream_client: &ClientTokio, ) { - let manager_server = ServerTokio::::builder() + let manager_server = manager::Server::::builder() .with_address("127.0.0.1:0".parse().unwrap()) .with_protocol(Protocol::Tcp) .with_udp(false) @@ -123,13 +153,7 @@ async fn test_connection( manager_server.acceptor_addr() ); - let app_server = crate::stream::server::application::Server::::builder() - .with_protocol(Protocol::Tcp) - .with_udp(false) - .with_socket_path(unix_socket_path) - .build(test_event_subscriber.clone()) - .unwrap(); - info!("Application server created"); + let app_server = create_application_server(unix_socket_path, test_event_subscriber); info!("All servers setup completed successfully"); @@ -189,3 +213,240 @@ async fn test_connection( ); info!("Data exchange completed successfully"); } + +#[cfg(not(target_os = "macos"))] +#[tokio::test] +async fn test_kernel_queue_full() { + init_tracing(); + let test_event_subscriber = NoopSubscriber {}; + let unix_socket_path = PathBuf::from("/tmp/kernel_queue_test.sock"); + + let (stream_client, _) = create_stream_client(); + let handshake_server = create_handshake_server().await; + + let handshake_addr = handshake_server.local_addr(); + stream_client + .handshake_with(handshake_addr, server_name()) + .await + .unwrap(); + info!("Handshake completed"); + + let manager_server = manager::Server::::builder() + .with_address("127.0.0.1:0".parse().unwrap()) + .with_protocol(Protocol::Tcp) + .with_udp(false) + .with_workers(NonZeroUsize::new(1).unwrap()) + .with_socket_path(&unix_socket_path) + .with_backlog(NonZero::new(10000).unwrap()) // configuring backlog so that streams are not dropped + .build(handshake_server.clone(), test_event_subscriber.clone()) + .unwrap(); + + info!( + "Manager server created at: {:?}", + manager_server.acceptor_addr() + ); + let acceptor_addr = manager_server.acceptor_addr().unwrap(); + + let app_server = create_application_server(&unix_socket_path, test_event_subscriber); + + let mut clients = Vec::new(); + let stream_count = 10000; + let mut buffer: Vec = Vec::new(); + + for _ in 0..stream_count { + let mut client_stream = stream_client + .connect(handshake_addr, acceptor_addr, server_name()) + .await + .unwrap(); + + // read from stream times out + let read_result = tokio::time::timeout( + Duration::from_millis(2), + client_stream.read_into(&mut buffer), + ) + .await; + assert!(matches!( + read_result.unwrap_err(), + tokio::time::error::Elapsed { .. } + )); + + clients.push(client_stream); + } + + let mut servers = Vec::new(); + for _ in 0..stream_count { + let (stream, _addr) = app_server.accept().await.unwrap(); + servers.push(stream); + } + + let test_message = b"Hello from server!"; + for mut stream in servers { + let mut message_slice = &test_message[..]; + stream.write_from(&mut message_slice).await.unwrap(); + } + + for mut stream in clients { + let mut buffer: Vec = Vec::new(); + let bytes_read = stream.read_into(&mut buffer).await.unwrap(); + assert_eq!( + &buffer[..bytes_read], + test_message, + "Client should receive the correct message" + ); + } +} + +#[cfg(not(target_os = "macos"))] +#[tokio::test] +async fn test_kernel_queue_full_application_crash() { + init_tracing(); + let test_event_subscriber = NoopSubscriber {}; + let unix_socket_path = PathBuf::from("/tmp/kernel_queue_crash.sock"); + + let (stream_client, _) = create_stream_client(); + let handshake_server = create_handshake_server().await; + + let handshake_addr = handshake_server.local_addr(); + stream_client + .handshake_with(handshake_addr, server_name()) + .await + .unwrap(); + info!("Handshake completed"); + + let manager_server = manager::Server::::builder() + .with_address("127.0.0.1:0".parse().unwrap()) + .with_protocol(Protocol::Tcp) + .with_udp(false) + .with_workers(NonZeroUsize::new(1).unwrap()) + .with_socket_path(&unix_socket_path) + .with_backlog(NonZero::new(5000).unwrap()) + .build(handshake_server.clone(), test_event_subscriber.clone()) + .unwrap(); + + info!( + "Manager server created at: {:?}", + manager_server.acceptor_addr() + ); + let acceptor_addr = manager_server.acceptor_addr().unwrap(); + + let app_server = create_application_server(&unix_socket_path, test_event_subscriber); + + let mut clients = Vec::new(); + let stream_count = 5000; + + for _ in 0..stream_count { + let mut client_stream = stream_client + .connect(handshake_addr, acceptor_addr, server_name()) + .await + .unwrap(); + + let mut buffer: Vec = Vec::new(); + let read_result = tokio::time::timeout( + Duration::from_millis(5), + client_stream.read_into(&mut buffer), + ) + .await; + assert!(matches!( + read_result.unwrap_err(), + tokio::time::error::Elapsed { .. } + )); + clients.push(client_stream); + } + + drop(app_server); + + for mut stream in clients { + let mut buffer: Vec = Vec::new(); + let read_result = stream.read_into(&mut buffer).await; + let error = read_result.unwrap_err(); + assert_eq!(error.kind(), std::io::ErrorKind::UnexpectedEof); + } +} + +#[tokio::test] +async fn test_dedup_check() { + init_tracing(); + let test_event_subscriber = NoopSubscriber {}; + let unix_socket_path1 = PathBuf::from("/tmp/dedup1.sock"); + let unix_socket_path2 = PathBuf::from("/tmp/dedup2.sock"); + + let (client, client_map) = create_stream_client(); + + let handshake_server = create_handshake_server().await; + let handshake_addr = handshake_server.local_addr(); + let res = client + .handshake_with(handshake_addr, server_name()) + .await + .unwrap(); + info!("Handshake completed, {:?}", res); + + let manager_server1 = manager::Server::::builder() + .with_address("127.0.0.1:0".parse().unwrap()) + .with_protocol(Protocol::Tcp) + .with_udp(false) + .with_workers(NonZeroUsize::new(1).unwrap()) + .with_socket_path(&unix_socket_path1) + .build(handshake_server.clone(), test_event_subscriber.clone()) + .unwrap(); + + info!( + "Manager server created at: {:?}", + manager_server1.acceptor_addr() + ); + + let manager_server2 = manager::Server::::builder() + .with_address("127.0.0.1:0".parse().unwrap()) + .with_protocol(Protocol::Tcp) + .with_udp(false) + .with_workers(NonZeroUsize::new(1).unwrap()) + .with_socket_path(&unix_socket_path2) + .build(handshake_server.clone(), test_event_subscriber.clone()) + .unwrap(); + + info!( + "Manager server created at: {:?}", + manager_server2.acceptor_addr() + ); + + let app_server1 = create_application_server(&unix_socket_path1, test_event_subscriber.clone()); + let _app_server2 = create_application_server(&unix_socket_path2, test_event_subscriber); + + let acceptor_addr1 = manager_server1.acceptor_addr().unwrap(); + let mut client_stream = client + .connect(handshake_addr, acceptor_addr1, server_name()) + .await + .unwrap(); + let (mut server_stream, _addr) = app_server1.accept().await.unwrap(); + + let test_message = b"Hello from server!"; + let data_exchange_result = tokio::try_join!( + async { + let mut buffer = Vec::::new(); + let bytes_read = client_stream.read_into(&mut buffer).await?; + assert_eq!(&buffer[..bytes_read], test_message); + Ok::<(), Box>(()) + }, + async { + let mut message_slice = &test_message[..]; + server_stream.write_from(&mut message_slice).await?; + Ok::<(), Box>(()) + } + ); + + assert!(data_exchange_result.is_ok()); + + client_map.reset_all_senders(); + + let acceptor_addr2 = manager_server2.acceptor_addr().unwrap(); + let mut client_stream2 = client + .connect(handshake_addr, acceptor_addr2, server_name()) + .await + .unwrap(); + + let mut buffer: Vec = Vec::new(); + let read_result = client_stream2.read_into(&mut buffer).await; + let error = read_result.unwrap_err(); + info!("Read error {:?}", error); + // FIXME should the server be sending a control packet on ReplayDefinitelyDetected? + assert_eq!(error.kind(), std::io::ErrorKind::UnexpectedEof); +} diff --git a/dc/s2n-quic-dc/src/uds/receiver.rs b/dc/s2n-quic-dc/src/uds/receiver.rs index dca9de2c79..9ea1b3a9a3 100644 --- a/dc/s2n-quic-dc/src/uds/receiver.rs +++ b/dc/s2n-quic-dc/src/uds/receiver.rs @@ -16,49 +16,39 @@ use std::{ path::{Path, PathBuf}, sync::Arc, }; -use tokio::io::{unix::AsyncFd, Interest, Ready}; +use tokio::io::{unix::AsyncFd, Interest}; const BUFFER_SIZE: usize = u16::MAX as usize; #[derive(Clone)] pub struct Receiver { - async_fd: Arc>, + socket_fd: Arc>, socket_path: PathBuf, } impl Receiver { pub fn new(socket_path: &Path) -> Result { + let _ = unlink(socket_path); // Required in case drop did not run previously let socket = UnixDatagram::bind(socket_path)?; socket.set_nonblocking(true)?; let async_fd = Arc::new(AsyncFd::new(OwnedFd::from(socket))?); Ok(Self { - async_fd, + socket_fd: async_fd, socket_path: socket_path.to_path_buf(), }) } pub async fn receive_msg(&self) -> Result<(Vec, OwnedFd), std::io::Error> { - loop { - let mut guard = self.async_fd.ready(Interest::READABLE).await?; - - match self.try_receive_nonblocking() { - Ok(result) => { - return Ok(result); - } - Err(nix::Error::EAGAIN) => { - guard.clear_ready_matching(Ready::READABLE); - continue; - } - Err(e) => { - return Err(std::io::Error::from(e)); - } - } - } + let res = self + .socket_fd + .async_io(Interest::READABLE, |_inner| self.try_receive_nonblocking()) + .await?; + Ok(res) } - fn try_receive_nonblocking(&self) -> Result<(Vec, OwnedFd), nix::Error> { + fn try_receive_nonblocking(&self) -> Result<(Vec, OwnedFd), std::io::Error> { let mut buffer = [0u8; BUFFER_SIZE]; let mut cmsg_buffer = nix::cmsg_space!([RawFd; 1]); let mut iov = [std::io::IoSliceMut::new(&mut buffer)]; @@ -70,7 +60,7 @@ impl Receiver { let recv_flags = MsgFlags::empty(); let msg = recvmsg::( - self.async_fd.as_raw_fd(), + self.socket_fd.as_raw_fd(), &mut iov, Some(&mut cmsg_buffer), recv_flags, @@ -98,8 +88,7 @@ impl Receiver { } } } - - Err(nix::Error::EINVAL) // No file descriptor found + Err(std::io::Error::from(nix::Error::EINVAL)) // No file descriptor found } } @@ -125,7 +114,7 @@ mod tests { let receiver_path = Path::new("/tmp/receiver.sock"); let receiver = Receiver::new(receiver_path).unwrap(); - let sender = Sender::new().unwrap(); + let sender = Sender::new(receiver_path).unwrap(); let file_path = "/tmp/test.txt"; let mut file = File::create(file_path).await.unwrap(); @@ -144,7 +133,7 @@ mod tests { .await .unwrap() }, - sender.send_msg(packet_data, receiver_path, fd_to_send) + sender.send_msg(packet_data, fd_to_send) ); match result { diff --git a/dc/s2n-quic-dc/src/uds/sender.rs b/dc/s2n-quic-dc/src/uds/sender.rs index 7bbf763ea1..787165d850 100644 --- a/dc/s2n-quic-dc/src/uds/sender.rs +++ b/dc/s2n-quic-dc/src/uds/sender.rs @@ -1,7 +1,7 @@ // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use nix::sys::socket::{sendmsg, ControlMessage, MsgFlags, UnixAddr}; +use nix::sys::socket::{sendmsg, ControlMessage, MsgFlags}; use std::{ os::{ fd::{BorrowedFd, OwnedFd}, @@ -9,16 +9,17 @@ use std::{ }, path::Path, }; -use tokio::io::{unix::AsyncFd, Interest, Ready}; +use tokio::io::{unix::AsyncFd, Interest}; pub struct Sender { socket_fd: AsyncFd, } impl Sender { - pub fn new() -> Result { + pub fn new(connect_path: &Path) -> Result { let socket = UnixDatagram::unbound()?; socket.set_nonblocking(true)?; + socket.connect(connect_path)?; // without this the socket is always writable let async_fd = AsyncFd::new(OwnedFd::from(socket))?; @@ -30,36 +31,23 @@ impl Sender { pub async fn send_msg( &self, packet: &[u8], - dest_path: &Path, fd_to_send: BorrowedFd<'_>, ) -> Result<(), std::io::Error> { - loop { - let mut guard = self.socket_fd.ready(Interest::WRITABLE).await?; - - match self.try_send_nonblocking(packet, dest_path, fd_to_send) { - Ok(()) => { - return Ok(()); - } - Err(nix::Error::EAGAIN) => { - guard.clear_ready_matching(Ready::WRITABLE); - continue; - } - Err(e) => { - return Err(std::io::Error::from(e)); - } - } - } + self.socket_fd + .async_io(Interest::WRITABLE, |_inner| { + self.try_send_nonblocking(packet, fd_to_send) + }) + .await?; + Ok(()) } fn try_send_nonblocking( &self, packet: &[u8], - dest_path: &Path, fd_to_send: BorrowedFd, - ) -> Result<(), nix::Error> { + ) -> Result<(), std::io::Error> { let fds = [fd_to_send.as_raw_fd()]; let cmsg = ControlMessage::ScmRights(&fds); - let dest_addr = UnixAddr::new(dest_path)?; #[cfg(target_os = "linux")] let send_flags = MsgFlags::MSG_NOSIGNAL; @@ -67,12 +55,12 @@ impl Sender { #[cfg(not(target_os = "linux"))] let send_flags = MsgFlags::empty(); - sendmsg::( + sendmsg::<()>( self.socket_fd.as_raw_fd(), &[std::io::IoSlice::new(packet)], &[cmsg], send_flags, - Some(&dest_addr), + None, )?; Ok(())