diff --git a/.vscode/settings.json b/.vscode/settings.json index 46c2c916942..c0024f29c73 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -3,5 +3,10 @@ "editor.formatOnSave": false, "editor.formatOnPaste": false, "editor.formatOnType": false + }, + "[rust]": { + "editor.formatOnSave": true, + "editor.formatOnPaste": false, + "editor.formatOnType": false } } \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index bd7b1995242..aecbcefcb47 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -234,7 +234,7 @@ checksum = "bc00ceb34980c03614e35a3a4e218276a0a824e911d07651cd0d858a51e8c0f0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.32", ] [[package]] @@ -269,9 +269,9 @@ dependencies = [ [[package]] name = "base64" -version = "0.21.3" +version = "0.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "414dcefbc63d77c526a76b3afcf6fbb9b5e2791c19c3aa2297733208750c6e53" +checksum = "9ba43ea6f343b788c8764558649e08df62f86c6ef251fdaeb1ffd010a9ae50a2" [[package]] name = "basic-toml" @@ -411,9 +411,9 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "bytes" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" dependencies = [ "serde", ] @@ -501,9 +501,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.29" +version = "0.4.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d87d9d13be47a5b7c3907137f1290b0459a7f80efb26be8c52afb11963bccb02" +checksum = "defd4e7873dbddba6c7c91e199c7fcb946abc4a6a4ac3195400bcfb01b5de877" dependencies = [ "android-tzdata", "iana-time-zone", @@ -592,7 +592,7 @@ dependencies = [ "heck 0.4.1", "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.32", ] [[package]] @@ -1125,7 +1125,7 @@ dependencies = [ "ident_case", "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.32", ] [[package]] @@ -1147,7 +1147,7 @@ checksum = "836a9bbc7ad63342d6d6e7b815ccab164bc77a2d95d84bc3117a8c0d5c98e2d5" dependencies = [ "darling_core 0.20.3", "quote", - "syn 2.0.31", + "syn 2.0.32", ] [[package]] @@ -1197,7 +1197,7 @@ checksum = "53e0efad4403bfc52dc201159c4b842a246a14b98c64b55dfd0f2d89729dfeb8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.32", ] [[package]] @@ -1368,9 +1368,9 @@ dependencies = [ [[package]] name = "educe" -version = "0.4.22" +version = "0.4.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "079044df30bb07de7d846d41a184c4b00e66ebdac93ee459253474f3a47e50ae" +checksum = "0f0042ff8246a363dbe77d2ceedb073339e85a804b9a47636c6e016a9a32c05f" dependencies = [ "enum-ordinalize", "proc-macro2", @@ -1435,7 +1435,7 @@ dependencies = [ "num-traits", "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.32", ] [[package]] @@ -1456,7 +1456,7 @@ dependencies = [ "darling 0.20.3", "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.32", ] [[package]] @@ -1668,7 +1668,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.32", ] [[package]] @@ -1750,7 +1750,7 @@ checksum = "ba330b70a5341d3bc730b8e205aaee97ddab5d9c448c4f51a7c2d924266fa8f9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.32", ] [[package]] @@ -2421,9 +2421,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.4.5" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57bcfdad1b858c2db7c38303a6d2ad4dfaf5eb53dfeb0910128b2c26d6158503" +checksum = "1a9bad9f94746442c783ca431b22403b519cd7fbeed0533fdd6328b2f2212128" [[package]] name = "llvm-sys" @@ -2890,7 +2890,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.32", ] [[package]] @@ -3072,7 +3072,7 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.32", ] [[package]] @@ -3400,7 +3400,7 @@ checksum = "7f7473c2cfcf90008193dd0e3e16599455cb601a9fce322b5bb55de799664925" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.32", ] [[package]] @@ -3677,9 +3677,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.11" +version = "0.38.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0c3dde1fc030af041adc40e79c0e7fbcf431dd24870053d187d7c66e4b87453" +checksum = "d7db8590df6dfcd144d22afd1b83b36c21a18d7cbc1dc4bb5295a8712e9eb662" dependencies = [ "bitflags 2.4.0", "errno", @@ -3979,7 +3979,7 @@ checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.32", ] [[package]] @@ -3995,9 +3995,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.105" +version = "1.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360" +checksum = "2cc66a619ed80bf7a0f6b17dd063a84b88f6dea1813737cf469aef1d081142c2" dependencies = [ "itoa", "ryu", @@ -4204,9 +4204,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2538b18701741680e0322a2302176d3253a35388e2e62f172f64f4f16605f877" +checksum = "4031e820eb552adee9295814c0ced9e5cf38ddf1e8b7d566d6de8e2538ea989e" dependencies = [ "libc", "windows-sys 0.48.0", @@ -4307,9 +4307,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.31" +version = "2.0.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "718fa2415bcb8d8bd775917a1bf12a7931b6dfa890753378538118181e0cb398" +checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2" dependencies = [ "proc-macro2", "quote", @@ -4452,7 +4452,7 @@ checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.32", ] [[package]] @@ -4544,7 +4544,7 @@ dependencies = [ "mio", "num_cpus", "pin-project-lite", - "socket2 0.5.3", + "socket2 0.5.4", "tokio-macros", "windows-sys 0.48.0", ] @@ -4557,7 +4557,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.32", ] [[package]] @@ -4649,9 +4649,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.7.6" +version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c17e963a819c331dcacd7ab957d80bc2b9a9c1e71c804826d2f283dd65306542" +checksum = "dd79e69d3b627db300ff956027cc6c3798cef26d22526befdfcd12feeb6d2257" dependencies = [ "serde", "serde_spanned", @@ -4670,9 +4670,9 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.19.14" +version = "0.19.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8123f27e969974a3dfba720fdb560be359f57b44302d280ba72e76a74480e8a" +checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" dependencies = [ "indexmap 2.0.0", "serde", @@ -4750,7 +4750,7 @@ checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.32", ] [[package]] @@ -4834,9 +4834,9 @@ checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" [[package]] name = "trybuild" -version = "1.0.83" +version = "1.0.84" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6df60d81823ed9c520ee897489573da4b1d79ffbe006b8134f46de1a1aa03555" +checksum = "a5c89fd17b7536f2cf66c97cff6e811e89e728ca0ed13caeed610c779360d8b4" dependencies = [ "basic-toml", "glob", @@ -5349,7 +5349,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.32", "wasm-bindgen-shared", ] @@ -5406,7 +5406,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.31", + "syn 2.0.32", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -5866,7 +5866,7 @@ dependencies = [ "serde_yaml 0.8.26", "time", "tokio", - "toml 0.7.6", + "toml 0.7.8", "tracing", "tracing-subscriber", "url", @@ -6596,7 +6596,7 @@ dependencies = [ "tar", "tempfile", "thiserror", - "toml 0.7.6", + "toml 0.7.8", "url", "walkdir", "wasmer-toml 0.7.0", @@ -6898,9 +6898,9 @@ dependencies = [ [[package]] name = "xml-rs" -version = "0.8.17" +version = "0.8.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1eee6bf5926be7cf998d7381a9a23d833fd493f6a8034658a9505a4dc4b20444" +checksum = "bab77e97b50aee93da431f2cee7cd0f43b4d1da3c408042f2d7d164187774f0a" [[package]] name = "yaml-rust" diff --git a/lib/virtual-io/src/guard.rs b/lib/virtual-io/src/guard.rs index 96a085f34e4..c74a2f1c49e 100644 --- a/lib/virtual-io/src/guard.rs +++ b/lib/virtual-io/src/guard.rs @@ -1,16 +1,23 @@ -use std::{io, sync::Arc}; +use std::{ + io, + sync::{Arc, Weak}, +}; use mio::Token; -use crate::{InterestHandler, Selector}; - -pub(crate) struct HandlerWrapper(pub Box); +use crate::{InterestHandler, InterestType, InterestWakerMap, Selector}; #[derive(Debug)] #[must_use = "Leaking token guards will break the IO subsystem"] pub struct InterestGuard { + selector: Weak, pub(crate) token: Token, } +impl Drop for InterestGuard { + fn drop(&mut self) { + self.drop_internal(); + } +} impl InterestGuard { pub fn new( selector: &Arc, @@ -18,17 +25,71 @@ impl InterestGuard { source: &mut dyn mio::event::Source, interest: mio::Interest, ) -> io::Result { - let raw = Box::into_raw(Box::new(HandlerWrapper(handler))) as *const HandlerWrapper; - let new_token = Token(raw as usize); - selector.registry.register(source, new_token, interest)?; - Ok(Self { token: new_token }) + let token = selector.add(handler, source, interest)?; + Ok(Self { + selector: Arc::downgrade(selector), + token, + }) } - pub fn unregister( - guard: InterestGuard, - selector: &Selector, - source: &mut dyn mio::event::Source, - ) { - selector.tx_drop.lock().unwrap().send(guard.token).ok(); - selector.registry.deregister(source).unwrap(); + + pub fn unregister(&mut self, source: &mut dyn mio::event::Source) -> io::Result<()> { + if let Some(selector) = self.selector.upgrade() { + selector.remove(self.token, Some(source))?; + } + Ok(()) + } + + pub fn replace_handler( + &mut self, + handler: Box, + ) -> Result<(), Box> { + if let Some(selector) = self.selector.upgrade() { + selector.replace(self.token, handler); + Ok(()) + } else { + Err(handler) + } + } + + pub fn interest(&mut self, interest: InterestType) { + if let Some(selector) = self.selector.upgrade() { + selector.handle(self.token, |h| h.push_interest(interest)); + } + } + + fn drop_internal(&mut self) { + if let Some(selector) = self.selector.upgrade() { + selector.remove(self.token, None).ok(); + } + } +} + +#[derive(Debug)] +pub enum HandlerGuardState { + None, + ExternalHandler(InterestGuard), + WakerMap(InterestGuard, InterestWakerMap), +} + +pub fn state_as_waker_map<'a>( + state: &'a mut HandlerGuardState, + selector: &'_ Arc, + source: &'_ mut dyn mio::event::Source, +) -> io::Result<&'a mut InterestWakerMap> { + if !matches!(state, HandlerGuardState::WakerMap(_, _)) { + let waker_map = InterestWakerMap::default(); + *state = HandlerGuardState::WakerMap( + InterestGuard::new( + selector, + Box::new(waker_map.clone()), + source, + mio::Interest::READABLE | mio::Interest::WRITABLE, + )?, + waker_map, + ); } + Ok(match state { + HandlerGuardState::WakerMap(_, map) => map, + _ => unreachable!(), + }) } diff --git a/lib/virtual-io/src/interest.rs b/lib/virtual-io/src/interest.rs index 8da03be08b9..61763e99e08 100644 --- a/lib/virtual-io/src/interest.rs +++ b/lib/virtual-io/src/interest.rs @@ -15,23 +15,73 @@ pub enum InterestType { Error, } +#[derive(Debug)] +pub struct WakerInterestHandler { + set: HashSet, + waker: Waker, +} +impl WakerInterestHandler { + pub fn new(waker: &Waker) -> Box { + Box::new(WakerInterestHandler { + set: Default::default(), + waker: waker.clone(), + }) + } +} +impl InterestHandler for WakerInterestHandler { + fn push_interest(&mut self, interest: InterestType) { + self.set.insert(interest); + self.waker.wake_by_ref(); + } + + fn pop_interest(&mut self, interest: InterestType) -> bool { + self.set.remove(&interest) + } + + fn has_interest(&self, interest: InterestType) -> bool { + self.set.contains(&interest) + } +} + +#[derive(Debug, Clone)] +pub struct SharedWakerInterestHandler { + inner: Arc>>, +} +impl SharedWakerInterestHandler { + pub fn new(waker: &Waker) -> Box { + Box::new(Self { + inner: Arc::new(Mutex::new(WakerInterestHandler::new(waker))), + }) + } +} +impl InterestHandler for SharedWakerInterestHandler { + fn push_interest(&mut self, interest: InterestType) { + let mut inner = self.inner.lock().unwrap(); + inner.push_interest(interest); + } + + fn pop_interest(&mut self, interest: InterestType) -> bool { + let mut inner = self.inner.lock().unwrap(); + inner.pop_interest(interest) + } + + fn has_interest(&self, interest: InterestType) -> bool { + let inner = self.inner.lock().unwrap(); + inner.has_interest(interest) + } +} + pub trait InterestHandler: Send + Sync { - fn interest(&mut self, interest: InterestType); + fn push_interest(&mut self, interest: InterestType); + + fn pop_interest(&mut self, interest: InterestType) -> bool; + + fn has_interest(&self, interest: InterestType) -> bool; } impl From<&Waker> for Box { fn from(waker: &Waker) -> Self { - struct WakerHandler { - waker: Waker, - } - impl InterestHandler for WakerHandler { - fn interest(&mut self, _interest: InterestType) { - self.waker.wake_by_ref(); - } - } - Box::new(WakerHandler { - waker: waker.clone(), - }) + WakerInterestHandler::new(waker) } } @@ -61,7 +111,7 @@ pub struct InterestHandlerWaker { impl InterestHandlerWaker { pub fn wake_now(&self) { let mut handler = self.handler.lock().unwrap(); - handler.interest(self.interest); + handler.push_interest(self.interest); } pub fn set_interest(self: &Arc, interest: InterestType) -> Arc { let mut next = self.as_ref().clone(); @@ -95,101 +145,56 @@ const VTABLE: RawWakerVTable = unsafe { ) }; -#[derive(Derivative, Default)] -#[derivative(Debug)] -struct FilteredHandlerSubscriptionsInner { - #[derivative(Debug = "ignore")] - mappings: HashMap>, +#[derive(Debug, Clone, Default)] +struct InterestWakerMapState { + wakers: HashMap>, triggered: HashSet, } -#[derive(Derivative, Default, Clone)] -#[derivative(Debug)] -pub struct FilteredHandlerSubscriptions { - #[derivative(Debug = "ignore")] - inner: Arc>, -} -impl FilteredHandlerSubscriptions { - pub fn add_interest( - &self, - interest: InterestType, - mut handler: Box, - ) { - let mut inner = self.inner.lock().unwrap(); - if inner.triggered.take(&interest).is_some() { - handler.interest(interest) - } - inner.mappings.insert(interest, handler); - } -} - -pub struct FilteredHandler { - subs: FilteredHandlerSubscriptions, +#[derive(Debug, Clone, Default)] +pub struct InterestWakerMap { + state: Arc>, } -impl FilteredHandler { - pub fn new() -> Box { - Box::new(Self { - subs: Default::default(), - }) +impl InterestWakerMap { + pub fn add(&self, interest: InterestType, waker: &Waker) { + let mut state = self.state.lock().unwrap(); + let entries = state.wakers.entry(interest).or_default(); + if !entries.iter().any(|w| w.will_wake(waker)) { + entries.push(waker.clone()); + } } - pub fn add_interest( - self: Box, - interest: InterestType, - handler: Box, - ) -> Box { - self.subs.add_interest(interest, handler); - self + + pub fn pop(&self, interest: InterestType) -> bool { + let mut state = self.state.lock().unwrap(); + state.triggered.remove(&interest) } - pub fn subscriptions(&self) -> &FilteredHandlerSubscriptions { - &self.subs + + pub fn push(&self, interest: InterestType) -> bool { + let mut state = self.state.lock().unwrap(); + state.triggered.insert(interest) } } -impl InterestHandler for FilteredHandler { - fn interest(&mut self, interest: InterestType) { - let mut inner = self.subs.inner.lock().unwrap(); - if let Some(handler) = inner.mappings.get_mut(&interest) { - handler.interest(interest); +impl InterestHandler for InterestWakerMap { + fn push_interest(&mut self, interest: InterestType) { + let mut state = self.state.lock().unwrap(); + if let Some(wakers) = state.wakers.remove(&interest) { + for waker in wakers { + waker.wake(); + } } else { - inner.triggered.insert(interest); + state.triggered.insert(interest); } } -} - -#[derive(Debug, Default, Clone)] -pub struct StatefulHandlerState { - interest: Arc>>, -} - -impl StatefulHandlerState { - pub fn take(&self, interest: InterestType) -> bool { - let mut guard = self.interest.lock().unwrap(); - guard.remove(&interest) - } - pub fn set(&self, interest: InterestType) { - let mut guard = self.interest.lock().unwrap(); - guard.insert(interest); - } -} -pub struct StatefulHandler { - handler: Box, - state: StatefulHandlerState, -} - -impl StatefulHandler { - pub fn new( - handler: Box, - state: StatefulHandlerState, - ) -> Box { - Box::new(Self { handler, state }) + fn pop_interest(&mut self, interest: InterestType) -> bool { + let mut state = self.state.lock().unwrap(); + state.triggered.remove(&interest) } -} -impl InterestHandler for StatefulHandler { - fn interest(&mut self, interest: InterestType) { - self.state.set(interest); - self.handler.interest(interest) + fn has_interest(&self, interest: InterestType) -> bool { + let state = self.state.lock().unwrap(); + state.triggered.contains(&interest) } } diff --git a/lib/virtual-io/src/selector.rs b/lib/virtual-io/src/selector.rs index 1fae9637a4c..3ea1b2f3993 100644 --- a/lib/virtual-io/src/selector.rs +++ b/lib/virtual-io/src/selector.rs @@ -1,51 +1,53 @@ use std::{ - collections::HashSet, - mem::ManuallyDrop, - sync::{ - mpsc::{Receiver, Sender}, - Arc, Mutex, - }, + collections::HashMap, + io, + sync::{Arc, Mutex}, }; use derivative::Derivative; -use mio::Token; +use mio::{Registry, Token}; -use crate::{HandlerWrapper, InterestType}; +use crate::{InterestHandler, InterestType}; #[derive(Derivative)] #[derivative(Debug)] pub(crate) struct EngineInner { + seed: usize, + registry: Registry, #[derivative(Debug = "ignore")] - selector: mio::Poll, - rx_drop: Receiver, + lookup: HashMap>, } #[derive(Derivative)] #[derivative(Debug)] pub struct Selector { + token_close: Token, inner: Mutex, - #[derivative(Debug = "ignore")] - pub(crate) registry: mio::Registry, - pub(crate) tx_drop: Mutex>, closer: mio::Waker, } impl Selector { pub fn new() -> Arc { - let (tx_drop, rx_drop) = std::sync::mpsc::channel(); + let poll = mio::Poll::new().unwrap(); + let registry = poll + .registry() + .try_clone() + .expect("the selector registry failed to clone"); - let selector = mio::Poll::new().unwrap(); let engine = Arc::new(Selector { - closer: mio::Waker::new(selector.registry(), Token(0)).unwrap(), - registry: selector.registry().try_clone().unwrap(), - inner: Mutex::new(EngineInner { selector, rx_drop }), - tx_drop: Mutex::new(tx_drop), + closer: mio::Waker::new(poll.registry(), Token(0)).unwrap(), + token_close: Token(1), + inner: Mutex::new(EngineInner { + seed: 10, + lookup: Default::default(), + registry, + }), }); { let engine = engine.clone(); std::thread::spawn(move || { - Self::run(engine); + Self::run(engine, poll); }); } @@ -56,57 +58,124 @@ impl Selector { self.closer.wake().ok(); } - fn run(engine: Arc) { + #[must_use = "the token must be consumed"] + pub fn add( + &self, + handler: Box, + source: &mut dyn mio::event::Source, + interests: mio::Interest, + ) -> io::Result { + let mut guard = self.inner.lock().unwrap(); + + guard.seed = guard + .seed + .checked_add(1) + .expect("selector has ran out of token seeds"); + let token = guard.seed; + let token = Token(token); + guard.lookup.insert(token, handler); + + match source.register(&guard.registry, token, interests) { + Ok(()) => {} + Err(err) if err.kind() == io::ErrorKind::AlreadyExists => { + source.deregister(&guard.registry).ok(); + source.register(&guard.registry, token, interests)?; + } + Err(err) => return Err(err), + }; + + Ok(token) + } + + pub fn remove( + &self, + token: Token, + source: Option<&mut dyn mio::event::Source>, + ) -> io::Result<()> { + let mut guard = self.inner.lock().unwrap(); + guard.lookup.remove(&token); + + if let Some(source) = source { + guard.registry.deregister(source)?; + } + Ok(()) + } + + pub fn handle(&self, token: Token, f: F) + where + F: Fn(&mut Box), + { + let mut guard = self.inner.lock().unwrap(); + if let Some(handler) = guard.lookup.get_mut(&token) { + f(handler) + } + } + + pub fn replace(&self, token: Token, mut handler: Box) { + let mut guard = self.inner.lock().unwrap(); + + let last = guard.lookup.remove(&token); + if let Some(last) = last { + let interests = vec![ + InterestType::Readable, + InterestType::Writable, + InterestType::Closed, + InterestType::Error, + ]; + for interest in interests { + if last.has_interest(interest) && !handler.has_interest(interest) { + handler.push_interest(interest); + } + } + } + + guard.lookup.insert(token, handler); + } + + fn run(engine: Arc, mut poll: mio::Poll) { // The outer loop is used to release the scope of the // read lock whenever it needs to do so let mut events = mio::Events::with_capacity(128); loop { - let mut dropped = HashSet::new(); - - { - // Wait for an event to trigger - let mut guard = engine.inner.lock().unwrap(); - guard.selector.poll(&mut events, None).unwrap(); - - // Read all the tokens that have been destroyed - while let Ok(token) = guard.rx_drop.try_recv() { - let s = token.0 as *mut HandlerWrapper; - drop(unsafe { Box::from_raw(s) }); - dropped.insert(token); - } - } + // Wait for an event to trigger + poll.poll(&mut events, None).unwrap(); - // Loop through all the events + // Loop through all the events while under a guard lock + let mut guard = engine.inner.lock().unwrap(); for event in events.iter() { // If the event is already dropped then ignore it let token = event.token(); - if dropped.contains(&token) { - continue; - } // If its the close event then exit - if token.0 == 0 { + if token == engine.token_close { return; } + // Get the handler + let handler = match guard.lookup.get_mut(&token) { + Some(h) => h, + None => { + tracing::debug!(token = token.0, "orphaned event"); + continue; + } + }; + // Otherwise this is a waker we need to wake - let s = event.token().0 as *mut HandlerWrapper; - let mut handler = ManuallyDrop::new(unsafe { Box::from_raw(s) }); if event.is_readable() { tracing::trace!(token = ?token, interest = ?InterestType::Readable, "host epoll"); - handler.0.interest(InterestType::Readable); + handler.push_interest(InterestType::Readable); } if event.is_writable() { tracing::trace!(token = ?token, interest = ?InterestType::Writable, "host epoll"); - handler.0.interest(InterestType::Writable); + handler.push_interest(InterestType::Writable); } if event.is_read_closed() || event.is_write_closed() { tracing::trace!(token = ?token, interest = ?InterestType::Closed, "host epoll"); - handler.0.interest(InterestType::Closed); + handler.push_interest(InterestType::Closed); } if event.is_error() { tracing::trace!(token = ?token, interest = ?InterestType::Error, "host epoll"); - handler.0.interest(InterestType::Error); + handler.push_interest(InterestType::Error); } } } diff --git a/lib/virtual-net/src/client.rs b/lib/virtual-net/src/client.rs index 29ee9cdf1f4..5777c338810 100644 --- a/lib/virtual-net/src/client.rs +++ b/lib/virtual-net/src/client.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::collections::VecDeque; use std::future::Future; use std::net::IpAddr; use std::net::SocketAddr; @@ -87,6 +88,7 @@ impl RemoteNetworkingClient { recv_tx: Default::default(), recv_with_addr_tx: Default::default(), accept_tx: Default::default(), + sent_tx: Default::default(), handlers: Default::default(), stall: Default::default(), }; @@ -257,6 +259,9 @@ impl RemoteNetworkingClient { let (tx, rx_accept) = tokio::sync::mpsc::channel(100); self.common.accept_tx.lock().unwrap().insert(id, tx); + let (tx, rx_sent) = tokio::sync::mpsc::channel(100); + self.common.sent_tx.lock().unwrap().insert(id, tx); + RemoteSocket { socket_id: id, common: self.common.clone(), @@ -264,8 +269,12 @@ impl RemoteNetworkingClient { rx_recv, rx_recv_with_addr, rx_accept, + rx_sent, tx_waker: TxWaker::new(&self.common).as_waker(), pending_accept: None, + buffer_accept: Default::default(), + buffer_recv_with_addr: Default::default(), + send_available: 0, } } } @@ -335,7 +344,7 @@ impl Future for RemoteNetworkingClientDriver { if let Some(h) = common.handlers.lock().unwrap().get_mut(&socket_id) { - h.interest(InterestType::Readable) + h.push_interest(InterestType::Readable) } })); } @@ -357,15 +366,27 @@ impl Future for RemoteNetworkingClientDriver { if let Some(h) = common.handlers.lock().unwrap().get_mut(&socket_id) { - h.interest(InterestType::Readable) + h.push_interest(InterestType::Readable) } })); } - MessageResponse::Sent { socket_id, .. } => { + MessageResponse::Sent { + socket_id, amount, .. + } => { + let tx = { + let guard = self.common.sent_tx.lock().unwrap(); + match guard.get(&socket_id) { + Some(tx) => tx.clone(), + None => continue, + } + }; + self.tasks.push_back(Box::pin(async move { + tx.send(amount).await.ok(); + })); if let Some(h) = self.common.handlers.lock().unwrap().get_mut(&socket_id) { - h.interest(InterestType::Writable) + h.push_interest(InterestType::Writable) } } MessageResponse::SendError { @@ -377,14 +398,14 @@ impl Future for RemoteNetworkingClientDriver { if let Some(h) = self.common.handlers.lock().unwrap().get_mut(&socket_id) { - h.interest(InterestType::Closed) + h.push_interest(InterestType::Closed) } } _ => { if let Some(h) = self.common.handlers.lock().unwrap().get_mut(&socket_id) { - h.interest(InterestType::Writable) + h.push_interest(InterestType::Writable) } } }, @@ -408,7 +429,7 @@ impl Future for RemoteNetworkingClientDriver { if let Some(h) = common.handlers.lock().unwrap().get_mut(&socket_id) { - h.interest(InterestType::Readable) + h.push_interest(InterestType::Readable) } })); } @@ -416,7 +437,7 @@ impl Future for RemoteNetworkingClientDriver { if let Some(h) = self.common.handlers.lock().unwrap().get_mut(&socket_id) { - h.interest(InterestType::Closed) + h.push_interest(InterestType::Closed) } } MessageResponse::ResponseToRequest { req_id, res } => { @@ -449,7 +470,7 @@ impl TxWaker { fn wake_now(&self) { let mut guard = self.common.handlers.lock().unwrap(); for (_, handler) in guard.iter_mut() { - handler.interest(InterestType::Writable); + handler.push_interest(InterestType::Writable); } } @@ -519,6 +540,7 @@ struct RemoteCommon { recv_tx: Mutex>>>, recv_with_addr_tx: Mutex>>, accept_tx: Mutex>>, + sent_tx: Mutex>>, #[derivative(Debug = "ignore")] handlers: Mutex>>, @@ -849,7 +871,11 @@ struct RemoteSocket { rx_recv_with_addr: mpsc::Receiver, tx_waker: Waker, rx_accept: mpsc::Receiver, + rx_sent: mpsc::Receiver, pending_accept: Option, + buffer_recv_with_addr: VecDeque, + buffer_accept: VecDeque, + send_available: u64, } impl Drop for RemoteSocket { fn drop(&mut self) { @@ -913,6 +939,57 @@ impl VirtualIoSource for RemoteSocket { fn remove_handler(&mut self) { self.common.handlers.lock().unwrap().remove(&self.socket_id); } + + fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + if !self.rx_buffer.is_empty() { + return Poll::Ready(Ok(self.rx_buffer.len())); + } + match self.rx_recv.poll_recv(cx) { + Poll::Ready(Some(data)) => { + self.rx_buffer.extend_from_slice(&data); + return Poll::Ready(Ok(self.rx_buffer.len())); + } + Poll::Ready(None) => return Poll::Ready(Ok(0)), + Poll::Pending => {} + } + if !self.buffer_recv_with_addr.is_empty() { + let total = self + .buffer_recv_with_addr + .iter() + .map(|a| a.data.len()) + .sum(); + return Poll::Ready(Ok(total)); + } + match self.rx_recv_with_addr.poll_recv(cx) { + Poll::Ready(Some(data)) => self.buffer_recv_with_addr.push_back(data), + Poll::Ready(None) => return Poll::Ready(Ok(0)), + Poll::Pending => {} + } + if !self.buffer_accept.is_empty() { + return Poll::Ready(Ok(self.buffer_accept.len())); + } + match self.rx_accept.poll_recv(cx) { + Poll::Ready(Some(data)) => self.buffer_accept.push_back(data), + Poll::Ready(None) => {} + Poll::Pending => {} + } + Poll::Pending + } + + fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + if self.send_available > 0 { + return Poll::Ready(Ok(self.send_available as usize)); + } + match self.rx_sent.poll_recv(cx) { + Poll::Ready(Some(amt)) => { + self.send_available += amt; + return Poll::Ready(Ok(self.send_available as usize)); + } + Poll::Ready(None) => return Poll::Ready(Ok(0)), + Poll::Pending => {} + } + Poll::Pending + } } impl VirtualSocket for RemoteSocket { @@ -968,12 +1045,16 @@ impl VirtualSocket for RemoteSocket { impl VirtualTcpListener for RemoteSocket { fn try_accept(&mut self) -> Result<(Box, SocketAddr)> { + // We may already have accepted a connection in the `poll_read_ready` method self.touch_begin_accept()?; - - let accepted = self.rx_accept.try_recv().map_err(|err| match err { - TryRecvError::Empty => NetworkError::WouldBlock, - TryRecvError::Disconnected => NetworkError::ConnectionAborted, - })?; + let accepted = if let Some(child) = self.buffer_accept.pop_front() { + child + } else { + self.rx_accept.try_recv().map_err(|err| match err { + TryRecvError::Empty => NetworkError::WouldBlock, + TryRecvError::Disconnected => NetworkError::ConnectionAborted, + })? + }; // This placed here will mean there is always an accept request pending at the // server as the constructor invokes this method and we invoke it here after @@ -1003,6 +1084,13 @@ impl VirtualTcpListener for RemoteSocket { .unwrap() .insert(accepted.socket, tx); + let (tx, rx_sent) = tokio::sync::mpsc::channel(100); + self.common + .sent_tx + .lock() + .unwrap() + .insert(accepted.socket, tx); + let socket = RemoteSocket { socket_id: accepted.socket, common: self.common.clone(), @@ -1010,8 +1098,12 @@ impl VirtualTcpListener for RemoteSocket { rx_recv, rx_recv_with_addr, rx_accept, + rx_sent, pending_accept: None, tx_waker: TxWaker::new(&self.common).as_waker(), + buffer_accept: Default::default(), + buffer_recv_with_addr: Default::default(), + send_available: 0, }; Ok((Box::new(socket), accepted.addr)) } @@ -1062,8 +1154,11 @@ impl VirtualRawSocket for RemoteSocket { }, ) { Poll::Ready(Ok(())) => Ok(data.len()), + Poll::Ready(Err(NetworkError::WouldBlock)) | Poll::Pending => { + self.send_available = 0; + Err(NetworkError::WouldBlock) + } Poll::Ready(Err(err)) => Err(err), - Poll::Pending => Err(NetworkError::WouldBlock), } } @@ -1078,8 +1173,11 @@ impl VirtualRawSocket for RemoteSocket { }, ) { Poll::Ready(Ok(())) => Ok(()), + Poll::Ready(Err(NetworkError::WouldBlock)) | Poll::Pending => { + self.send_available = 0; + Err(NetworkError::WouldBlock) + } Poll::Ready(Err(err)) => Err(err), - Poll::Pending => Err(NetworkError::WouldBlock), } } @@ -1130,8 +1228,11 @@ impl VirtualConnectionlessSocket for RemoteSocket { }, ) { Poll::Ready(Ok(())) => Ok(data.len()), + Poll::Ready(Err(NetworkError::WouldBlock)) | Poll::Pending => { + self.send_available = 0; + Err(NetworkError::WouldBlock) + } Poll::Ready(Err(err)) => Err(err), - Poll::Pending => Err(NetworkError::WouldBlock), } } diff --git a/lib/virtual-net/src/host.rs b/lib/virtual-net/src/host.rs index 1b9b07c2e7b..0d14a1e2675 100644 --- a/lib/virtual-net/src/host.rs +++ b/lib/virtual-net/src/host.rs @@ -6,18 +6,25 @@ use crate::{ VirtualConnectionlessSocket, VirtualIcmpSocket, VirtualNetworking, VirtualRawSocket, VirtualSocket, VirtualTcpListener, VirtualTcpSocket, VirtualUdpSocket, }; +use bytes::{Buf, BytesMut}; use derivative::Derivative; -use std::io::{Read, Write}; +use std::collections::VecDeque; +use std::io::{self, Read, Write}; use std::mem::MaybeUninit; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, Shutdown, SocketAddr}; #[cfg(not(target_os = "windows"))] use std::os::fd::AsRawFd; +#[cfg(not(target_os = "windows"))] +use std::os::fd::RawFd; use std::sync::Arc; +use std::task::Poll; use std::time::Duration; use tokio::runtime::Handle; #[allow(unused_imports, dead_code)] use tracing::{debug, error, info, trace, warn}; -use virtual_mio::{InterestGuard, InterestHandler, Selector}; +use virtual_mio::{ + state_as_waker_map, HandlerGuardState, InterestGuard, InterestHandler, InterestType, Selector, +}; #[derive(Derivative)] #[derivative(Debug)] @@ -63,9 +70,10 @@ impl VirtualNetworking for LocalNetworking { Box::new(LocalTcpListener { stream: mio::net::TcpListener::from_std(sock), selector: self.selector.clone(), - handler_guard: None, + handler_guard: HandlerGuardState::None, no_delay: None, keep_alive: None, + backlog: Default::default(), }) }) .map_err(io_err_into_net_error)?; @@ -80,12 +88,28 @@ impl VirtualNetworking for LocalNetworking { ) -> Result> { let socket = mio::net::UdpSocket::bind(addr).map_err(io_err_into_net_error)?; socket2::SockRef::from(&socket).set_nonblocking(true).ok(); - Ok(Box::new(LocalUdpSocket { + + #[allow(unused_mut)] + let mut ret = LocalUdpSocket { selector: self.selector.clone(), socket, addr, - handler_guard: None, - })) + handler_guard: HandlerGuardState::None, + backlog: Default::default(), + }; + + // In windows we can not poll the socket as it is not supported and hence + // what we do is immediately set the writable flag and relay on `mio` to + // refresh that flag when the state changes. In Linux what we do is actually + // make a non-blocking `poll` call to determine this state + #[cfg(target_os = "windows")] + { + let (state, selector, socket) = ret.split_borrow(); + let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?; + map.push(InterestType::Writable); + } + + Ok(Box::new(ret)) } async fn connect_tcp( @@ -127,20 +151,20 @@ impl VirtualNetworking for LocalNetworking { pub struct LocalTcpListener { stream: mio::net::TcpListener, selector: Arc, - handler_guard: Option, + handler_guard: HandlerGuardState, no_delay: Option, keep_alive: Option, + backlog: VecDeque<(Box, SocketAddr)>, } -impl VirtualTcpListener for LocalTcpListener { - fn try_accept(&mut self) -> Result<(Box, SocketAddr)> { +impl LocalTcpListener { + fn try_accept_internal(&mut self) -> Result<(Box, SocketAddr)> { match self.stream.accept().map_err(io_err_into_net_error) { Ok((stream, addr)) => { socket2::SockRef::from(&self.stream) .set_nonblocking(true) .ok(); let mut socket = LocalTcpStream::new(self.selector.clone(), stream, addr); - socket.set_first_handler_writeable(); if let Some(no_delay) = self.no_delay { socket.set_nodelay(no_delay).ok(); } @@ -149,25 +173,48 @@ impl VirtualTcpListener for LocalTcpListener { } Ok((Box::new(socket), addr)) } - Err(NetworkError::WouldBlock) => Err(NetworkError::WouldBlock), + Err(NetworkError::WouldBlock) => { + if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard { + map.pop(InterestType::Readable); + map.pop(InterestType::Writable); + } + Err(NetworkError::WouldBlock) + } Err(err) => Err(err), } } +} - fn set_handler(&mut self, handler: Box) -> Result<()> { - if let Some(guard) = self.handler_guard.take() { - InterestGuard::unregister(guard, &self.selector, &mut self.stream); +impl VirtualTcpListener for LocalTcpListener { + fn try_accept(&mut self) -> Result<(Box, SocketAddr)> { + if let Some(child) = self.backlog.pop_front() { + return Ok(child); + } + self.try_accept_internal() + } + + fn set_handler(&mut self, mut handler: Box) -> Result<()> { + if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard { + match guard.replace_handler(handler) { + Ok(()) => return Ok(()), + Err(h) => handler = h, + } + + // the handler could not be replaced so we need to build a new handler instead + if let Err(err) = guard.unregister(&mut self.stream) { + tracing::debug!("failed to unregister previous token - {}", err); + } } let guard = InterestGuard::new( &self.selector, handler, &mut self.stream, - mio::Interest::READABLE, + mio::Interest::READABLE.add(mio::Interest::WRITABLE), ) .map_err(io_err_into_net_error)?; - self.handler_guard.replace(guard); + self.handler_guard = HandlerGuardState::ExternalHandler(guard); Ok(()) } @@ -190,11 +237,63 @@ impl VirtualTcpListener for LocalTcpListener { } } +impl LocalTcpListener { + fn split_borrow( + &mut self, + ) -> ( + &mut HandlerGuardState, + &Arc, + &mut mio::net::TcpListener, + ) { + (&mut self.handler_guard, &self.selector, &mut self.stream) + } +} + impl VirtualIoSource for LocalTcpListener { fn remove_handler(&mut self) { - if let Some(guard) = self.handler_guard.take() { - InterestGuard::unregister(guard, &self.selector, &mut self.stream); + let mut guard = HandlerGuardState::None; + std::mem::swap(&mut guard, &mut self.handler_guard); + match guard { + HandlerGuardState::ExternalHandler(mut guard) => { + guard.unregister(&mut self.stream).ok(); + } + HandlerGuardState::WakerMap(mut guard, _) => { + guard.unregister(&mut self.stream).ok(); + } + HandlerGuardState::None => {} + } + } + + fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + if !self.backlog.is_empty() { + return Poll::Ready(Ok(self.backlog.len())); + } + + let (state, selector, source) = self.split_borrow(); + let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?; + map.add(InterestType::Readable, cx.waker()); + + if let Ok(child) = self.try_accept_internal() { + self.backlog.push_back(child); + return Poll::Ready(Ok(1)); } + Poll::Pending + } + + fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + if !self.backlog.is_empty() { + return Poll::Ready(Ok(self.backlog.len())); + } + + let (state, selector, source) = self.split_borrow(); + let map = state_as_waker_map(state, selector, source).map_err(io_err_into_net_error)?; + map.add(InterestType::Writable, cx.waker()); + + if let Ok(child) = self.try_accept_internal() { + self.backlog.push_back(child); + return Poll::Ready(Ok(1)); + } + Poll::Pending } } @@ -204,23 +303,35 @@ pub struct LocalTcpStream { addr: SocketAddr, shutdown: Option, selector: Arc, - handler_guard: Option, - first_handler_writeable: bool, + handler_guard: HandlerGuardState, + buffer: BytesMut, } impl LocalTcpStream { fn new(selector: Arc, stream: mio::net::TcpStream, addr: SocketAddr) -> Self { - Self { + #[allow(unused_mut)] + let mut ret = Self { stream, addr, shutdown: None, selector, - handler_guard: None, - first_handler_writeable: false, + handler_guard: HandlerGuardState::None, + buffer: BytesMut::new(), + }; + + // In windows we can not poll the socket as it is not supported and hence + // what we do is immediately set the writable flag and relay on `mio` to + // refresh that flag when the state changes. In Linux what we do is actually + // make a non-blocking `poll` call to determine this state + #[cfg(target_os = "windows")] + { + let (state, selector, socket, _) = ret.split_borrow(); + if let Ok(map) = state_as_waker_map(state, selector, socket) { + map.push(InterestType::Writable); + } } - } - fn set_first_handler_writeable(&mut self) { - self.first_handler_writeable = true; + + ret } } @@ -267,6 +378,11 @@ impl VirtualTcpSocket for LocalTcpStream { #[cfg(not(target_os = "windows"))] fn set_dontroute(&mut self, val: bool) -> Result<()> { + // TODO: + // Don't route is being set by WASIX which breaks networking + // Why this is being set is unknown but we need to disable + // the functionality for now as it breaks everything + let val = val as libc::c_int; let payload = &val as *const libc::c_int as *const libc::c_void; let err = unsafe { @@ -341,7 +457,16 @@ impl VirtualConnectedSocket for LocalTcpStream { } fn try_send(&mut self, data: &[u8]) -> Result { - self.stream.write(data).map_err(io_err_into_net_error) + let ret = self.stream.write(data).map_err(io_err_into_net_error); + match &ret { + Ok(0) | Err(NetworkError::WouldBlock) => { + if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard { + map.pop(InterestType::Writable); + } + } + _ => {} + } + ret } fn try_flush(&mut self) -> Result<()> { @@ -354,6 +479,13 @@ impl VirtualConnectedSocket for LocalTcpStream { fn try_recv(&mut self, buf: &mut [MaybeUninit]) -> Result { let buf: &mut [u8] = unsafe { std::mem::transmute(buf) }; + if !self.buffer.is_empty() { + let amt = buf.len().min(self.buffer.len()); + buf[..amt].copy_from_slice(&self.buffer[..amt]); + self.buffer.advance(amt); + return Ok(amt); + } + self.stream.read(buf).map_err(io_err_into_net_error) } } @@ -376,13 +508,16 @@ impl VirtualSocket for LocalTcpStream { } fn set_handler(&mut self, mut handler: Box) -> Result<()> { - if let Some(guard) = self.handler_guard.take() { - InterestGuard::unregister(guard, &self.selector, &mut self.stream); - } + if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard { + match guard.replace_handler(handler) { + Ok(()) => return Ok(()), + Err(h) => handler = h, + } - if self.first_handler_writeable { - self.first_handler_writeable = false; - handler.interest(virtual_mio::InterestType::Writable); + // the handler could not be replaced so we need to build a new handler instead + if let Err(err) = guard.unregister(&mut self.stream) { + tracing::debug!("failed to unregister previous token - {}", err); + } } let guard = InterestGuard::new( @@ -393,18 +528,120 @@ impl VirtualSocket for LocalTcpStream { ) .map_err(io_err_into_net_error)?; - self.handler_guard.replace(guard); + self.handler_guard = HandlerGuardState::ExternalHandler(guard); Ok(()) } } +impl LocalTcpStream { + fn split_borrow( + &mut self, + ) -> ( + &mut HandlerGuardState, + &Arc, + &mut mio::net::TcpStream, + &mut BytesMut, + ) { + ( + &mut self.handler_guard, + &self.selector, + &mut self.stream, + &mut self.buffer, + ) + } +} + impl VirtualIoSource for LocalTcpStream { fn remove_handler(&mut self) { - if let Some(guard) = self.handler_guard.take() { - InterestGuard::unregister(guard, &self.selector, &mut self.stream); + let mut guard = HandlerGuardState::None; + std::mem::swap(&mut guard, &mut self.handler_guard); + match guard { + HandlerGuardState::ExternalHandler(mut guard) => { + guard.unregister(&mut self.stream).ok(); + } + HandlerGuardState::WakerMap(mut guard, _) => { + guard.unregister(&mut self.stream).ok(); + } + HandlerGuardState::None => {} } } + + fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + if !self.buffer.is_empty() { + return Poll::Ready(Ok(self.buffer.len())); + } + + let (state, selector, stream, buffer) = self.split_borrow(); + let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?; + map.pop(InterestType::Readable); + map.add(InterestType::Readable, cx.waker()); + + buffer.reserve(buffer.len() + 10240); + let uninit: &mut [MaybeUninit] = buffer.spare_capacity_mut(); + let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) }; + + match stream.read(uninit_unsafe) { + Ok(0) => Poll::Ready(Ok(0)), + Ok(amt) => { + unsafe { + buffer.set_len(buffer.len() + amt); + } + Poll::Ready(Ok(amt)) + } + Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)), + Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)), + Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + Err(err) => Poll::Ready(Err(io_err_into_net_error(err))), + } + } + + fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + let (state, selector, stream, _) = self.split_borrow(); + let map = state_as_waker_map(state, selector, stream).map_err(io_err_into_net_error)?; + #[cfg(not(target_os = "windows"))] + map.pop(InterestType::Writable); + map.add(InterestType::Writable, cx.waker()); + map.add(InterestType::Closed, cx.waker()); + if map.has_interest(InterestType::Closed) { + return Poll::Ready(Ok(0)); + } + + #[cfg(not(target_os = "windows"))] + match libc_poll(stream.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) { + Some(val) if (val & libc::POLLHUP) != 0 => { + return Poll::Ready(Ok(0)); + } + Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)), + _ => {} + } + + // In windows we can not poll the socket as it is not supported and hence + // what we do is immediately set the writable flag and relay on `mio` to + // refresh that flag when the state changes. In Linux what we do is actually + // make a non-blocking `poll` call to determine this state + #[cfg(target_os = "windows")] + if map.has_interest(InterestType::Writable) { + return Poll::Ready(Ok(10240)); + } + + Poll::Pending + } +} + +#[cfg(not(target_os = "windows"))] +fn libc_poll(fd: RawFd, events: libc::c_short) -> Option { + let mut fds: [libc::pollfd; 1] = [libc::pollfd { + fd, + events, + revents: 0, + }]; + let fds_mut = &mut fds[..]; + let ret = unsafe { libc::poll(fds_mut.as_mut_ptr(), 1, 0) }; + match ret == 1 { + true => Some(fds[0].revents), + false => None, + } } #[derive(Debug)] @@ -413,7 +650,8 @@ pub struct LocalUdpSocket { #[allow(dead_code)] addr: SocketAddr, selector: Arc, - handler_guard: Option, + handler_guard: HandlerGuardState, + backlog: VecDeque<(BytesMut, SocketAddr)>, } impl VirtualUdpSocket for LocalUdpSocket { @@ -497,9 +735,19 @@ impl VirtualUdpSocket for LocalUdpSocket { impl VirtualConnectionlessSocket for LocalUdpSocket { fn try_send_to(&mut self, data: &[u8], addr: SocketAddr) -> Result { - self.socket + let ret = self + .socket .send_to(data, addr) - .map_err(io_err_into_net_error) + .map_err(io_err_into_net_error); + match &ret { + Ok(0) | Err(NetworkError::WouldBlock) => { + if let HandlerGuardState::WakerMap(_, map) = &mut self.handler_guard { + map.pop(InterestType::Writable); + } + } + _ => {} + } + ret } fn try_recv_from(&mut self, buf: &mut [MaybeUninit]) -> Result<(usize, SocketAddr)> { @@ -525,9 +773,19 @@ impl VirtualSocket for LocalUdpSocket { Ok(SocketStatus::Opened) } - fn set_handler(&mut self, handler: Box) -> Result<()> { - if let Some(guard) = self.handler_guard.take() { - InterestGuard::unregister(guard, &self.selector, &mut self.socket); + fn set_handler(&mut self, mut handler: Box) -> Result<()> { + if let HandlerGuardState::ExternalHandler(guard) = &mut self.handler_guard { + match guard.replace_handler(handler) { + Ok(()) => { + return Ok(()); + } + Err(h) => handler = h, + } + + // the handler could not be replaced so we need to build a new handler instead + if let Err(err) = guard.unregister(&mut self.socket) { + tracing::debug!("failed to unregister previous token - {}", err); + } } let guard = InterestGuard::new( @@ -538,16 +796,96 @@ impl VirtualSocket for LocalUdpSocket { ) .map_err(io_err_into_net_error)?; - self.handler_guard.replace(guard); + self.handler_guard = HandlerGuardState::ExternalHandler(guard); Ok(()) } } +impl LocalUdpSocket { + fn split_borrow( + &mut self, + ) -> ( + &mut HandlerGuardState, + &Arc, + &mut mio::net::UdpSocket, + ) { + (&mut self.handler_guard, &self.selector, &mut self.socket) + } +} + impl VirtualIoSource for LocalUdpSocket { fn remove_handler(&mut self) { - if let Some(guard) = self.handler_guard.take() { - InterestGuard::unregister(guard, &self.selector, &mut self.socket); + let mut guard = HandlerGuardState::None; + std::mem::swap(&mut guard, &mut self.handler_guard); + match guard { + HandlerGuardState::ExternalHandler(mut guard) => { + guard.unregister(&mut self.socket).ok(); + } + HandlerGuardState::WakerMap(mut guard, _) => { + guard.unregister(&mut self.socket).ok(); + } + HandlerGuardState::None => {} + } + } + + fn poll_read_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + if !self.backlog.is_empty() { + let total = self.backlog.iter().map(|a| a.0.len()).sum(); + return Poll::Ready(Ok(total)); + } + + let (state, selector, socket) = self.split_borrow(); + let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?; + map.pop(InterestType::Readable); + map.add(InterestType::Readable, cx.waker()); + + let mut buffer = BytesMut::default(); + buffer.reserve(10240); + let uninit: &mut [MaybeUninit] = buffer.spare_capacity_mut(); + let uninit_unsafe: &mut [u8] = unsafe { std::mem::transmute(uninit) }; + + match self.socket.recv_from(uninit_unsafe) { + Ok((0, _)) => Poll::Ready(Ok(0)), + Ok((amt, peer)) => { + unsafe { + buffer.set_len(amt); + } + self.backlog.push_back((buffer, peer)); + Poll::Ready(Ok(amt)) + } + Err(err) if err.kind() == io::ErrorKind::ConnectionAborted => Poll::Ready(Ok(0)), + Err(err) if err.kind() == io::ErrorKind::ConnectionReset => Poll::Ready(Ok(0)), + Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + Err(err) => Poll::Ready(Err(io_err_into_net_error(err))), + } + } + + fn poll_write_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + let (state, selector, socket) = self.split_borrow(); + let map = state_as_waker_map(state, selector, socket).map_err(io_err_into_net_error)?; + #[cfg(not(target_os = "windows"))] + map.pop(InterestType::Writable); + map.add(InterestType::Writable, cx.waker()); + + #[cfg(not(target_os = "windows"))] + match libc_poll(socket.as_raw_fd(), libc::POLLOUT | libc::POLLHUP) { + Some(val) if (val & libc::POLLHUP) != 0 => { + return Poll::Ready(Ok(0)); + } + Some(val) if (val & libc::POLLOUT) != 0 => return Poll::Ready(Ok(10240)), + _ => {} } + + // In windows we can not poll the socket as it is not supported and hence + // what we do is immediately set the writable flag and relay on `mio` to + // refresh that flag when the state changes. In Linux what we do is actually + // make a non-blocking `poll` call to determine this state + #[cfg(target_os = "windows")] + if map.has_interest(InterestType::Writable) { + return Poll::Ready(Ok(10240)); + } + + Poll::Pending } } diff --git a/lib/virtual-net/src/lib.rs b/lib/virtual-net/src/lib.rs index 4d5b2e3220b..e8ba8d97e75 100644 --- a/lib/virtual-net/src/lib.rs +++ b/lib/virtual-net/src/lib.rs @@ -65,6 +65,12 @@ pub struct IpRoute { pub trait VirtualIoSource: fmt::Debug + Send + Sync + 'static { /// Removes a previously registered waker using a token fn remove_handler(&mut self); + + /// Polls the source to see if there is data waiting + fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll>; + + /// Polls the source to see if data can be sent + fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll>; } /// An implementation of virtual networking diff --git a/lib/virtual-net/src/server.rs b/lib/virtual-net/src/server.rs index 7168f80ebf5..bc33c5a404d 100644 --- a/lib/virtual-net/src/server.rs +++ b/lib/virtual-net/src/server.rs @@ -617,7 +617,7 @@ impl RemoteNetworkingServerDriver { // Now we attach the handler to the main listening socket let mut handler = Box::new(self.common.handler.clone().for_socket(socket_id)); - handler.interest(virtual_mio::InterestType::Readable); + handler.push_interest(virtual_mio::InterestType::Readable); self.process_inner_noop( move |socket| match socket { RemoteAdapterSocket::TcpListener { @@ -1574,7 +1574,7 @@ impl RemoteAdapterHandler { } } impl InterestHandler for RemoteAdapterHandler { - fn interest(&mut self, interest: virtual_mio::InterestType) { + fn push_interest(&mut self, interest: virtual_mio::InterestType) { let mut guard = self.state.lock().unwrap(); guard.driver_wakers.drain(..).for_each(|w| w.wake()); let socket_id = match self.socket_id { @@ -1585,6 +1585,30 @@ impl InterestHandler for RemoteAdapterHandler { guard.readable.insert(socket_id); } } + + fn pop_interest(&mut self, interest: virtual_mio::InterestType) -> bool { + let mut guard = self.state.lock().unwrap(); + let socket_id = match self.socket_id { + Some(s) => s, + None => return false, + }; + if interest == virtual_mio::InterestType::Readable { + return guard.readable.remove(&socket_id); + } + false + } + + fn has_interest(&self, interest: virtual_mio::InterestType) -> bool { + let guard = self.state.lock().unwrap(); + let socket_id = match self.socket_id { + Some(s) => s, + None => return false, + }; + if interest == virtual_mio::InterestType::Readable { + return guard.readable.contains(&socket_id); + } + false + } } type SocketMap = HashMap; diff --git a/lib/virtual-net/src/tests.rs b/lib/virtual-net/src/tests.rs index 7a59a6988fd..960a4d2cf8b 100644 --- a/lib/virtual-net/src/tests.rs +++ b/lib/virtual-net/src/tests.rs @@ -210,3 +210,241 @@ async fn test_tcp_with_large_pipe_json_using_cbor() { let (client, server) = setup_pipe(1024000, FrameSerializationFormat::Cbor).await; test_tcp(client, server).await } + +#[cfg(target_os = "linux")] +#[traced_test] +#[tokio::test] +async fn test_google_poll() { + use futures_util::Future; + + // Resolve the address + tracing::info!("resolving www.google.com"); + let networking = LocalNetworking::new(); + let peer_addr = networking + .resolve("www.google.com", None, None) + .await + .unwrap() + .into_iter() + .next() + .expect("IP address should be returned"); + tracing::info!("www.google.com = {}", peer_addr); + + // Start the connection + tracing::info!("connecting to {}:80", peer_addr); + let mut socket = networking + .connect_tcp( + SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), + SocketAddr::new(peer_addr, 80), + ) + .await + .unwrap(); + tracing::info!("setting nodelay"); + socket.set_nodelay(true).unwrap(); + tracing::info!("setting keepalive"); + socket.set_keepalive(true).unwrap(); + + // Wait for it to be ready to send packets + tracing::info!("waiting for write_ready"); + struct Poller<'a> { + socket: &'a mut Box, + } + impl<'a> Future for Poller<'a> { + type Output = Result; + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.socket.poll_write_ready(cx) + } + } + Poller { + socket: &mut socket, + } + .await; + + // Send the data (GET http request) + let data = + b"GET / HTTP/1.1\r\nHost: www.google.com\r\nUser-Agent: curl/7.81.0\r\nAccept: */*\r\nConnection: Close\r\n\r\n"; + tracing::info!("sending {} bytes", data.len()); + let sent = socket.send(data).await.unwrap(); + assert_eq!(sent, data.len()); + + // Enter a loop that will return all the data + loop { + // Wait for the next bit of data + tracing::info!("waiting for read ready"); + struct Poller<'a> { + socket: &'a mut Box, + } + impl<'a> Future for Poller<'a> { + type Output = Result; + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.socket.poll_read_ready(cx) + } + } + Poller { + socket: &mut socket, + } + .await; + + // Now read the data + let mut buf = [0u8; 4096]; + match socket.read(&mut buf).await { + Ok(0) => break, + Ok(amt) => { + tracing::info!("received {amt} bytes"); + continue; + } + Err(err) => { + tracing::info!("failed - {}", err); + panic!("failed to receive data"); + } + } + } + + tracing::info!("done"); +} + +#[cfg(target_os = "linux")] +#[traced_test] +#[tokio::test] +async fn test_google_epoll() { + use futures_util::Future; + use virtual_mio::SharedWakerInterestHandler; + + // Resolve the address + tracing::info!("resolving www.google.com"); + let networking = LocalNetworking::new(); + let peer_addr = networking + .resolve("www.google.com", None, None) + .await + .unwrap() + .into_iter() + .next() + .expect("IP address should be returned"); + tracing::info!("www.google.com = {}", peer_addr); + + // Start the connection + tracing::info!("connecting to {}:80", peer_addr); + let mut socket = networking + .connect_tcp( + SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), + SocketAddr::new(peer_addr, 80), + ) + .await + .unwrap(); + tracing::info!("setting nodelay"); + socket.set_nodelay(true).unwrap(); + tracing::info!("setting keepalive"); + socket.set_keepalive(true).unwrap(); + + // Wait for it to be ready to send packets + tracing::info!("waiting for writability"); + struct Poller<'a> { + handler: Option>, + socket: &'a mut Box, + } + impl<'a> Future for Poller<'a> { + type Output = Result<()>; + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.handler.is_none() { + self.handler + .replace(SharedWakerInterestHandler::new(cx.waker())); + let handler = self.handler.as_ref().unwrap().clone(); + self.socket.set_handler(handler); + } + if self + .handler + .as_mut() + .unwrap() + .pop_interest(InterestType::Writable) + { + return Poll::Ready(Ok(())); + } + Poll::Pending + } + } + Poller { + handler: None, + socket: &mut socket, + } + .await; + + // Send the data (GET http request) + let data = + b"GET / HTTP/1.1\r\nHost: www.google.com\r\nUser-Agent: curl/7.81.0\r\nAccept: */*\r\nConnection: Close\r\n\r\n"; + tracing::info!("sending {} bytes", data.len()); + let sent = socket.try_send(data).unwrap(); + assert_eq!(sent, data.len()); + + // We detect if there are lots of false positives, that means something has gone + // wrong with the epoll implementation + let mut false_interest = 0usize; + + // Enter a loop that will return all the data + loop { + // Wait for the next bit of data + tracing::info!("waiting for readability"); + struct Poller<'a> { + handler: Option>, + socket: &'a mut Box, + } + impl<'a> Future for Poller<'a> { + type Output = Result<()>; + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.handler.is_none() { + self.handler + .replace(SharedWakerInterestHandler::new(cx.waker())); + let handler = self.handler.as_ref().unwrap().clone(); + self.socket.set_handler(handler); + } + if self + .handler + .as_mut() + .unwrap() + .pop_interest(InterestType::Readable) + { + return Poll::Ready(Ok(())); + } + Poll::Pending + } + } + Poller { + handler: None, + socket: &mut socket, + } + .await; + + // Now read the data until we block + let mut done = false; + for n in 0.. { + let mut buf: [MaybeUninit; 4096] = [MaybeUninit::uninit(); 4096]; + match socket.try_recv(&mut buf) { + Ok(0) => { + done = true; + break; + } + Ok(amt) => { + tracing::info!("received {amt} bytes"); + continue; + } + Err(NetworkError::WouldBlock) => { + if n == 0 { + false_interest += 1; + } + break; + } + Err(err) => { + tracing::info!("failed - {}", err); + panic!("failed to receive data"); + } + } + } + if done { + break; + } + } + + if false_interest > 20 { + panic!("too many false positives on the epoll ({false_interest}), something has likely gone wrong") + } + + tracing::info!("done"); +} diff --git a/lib/wasi-web/Cargo.lock b/lib/wasi-web/Cargo.lock index 9eb55fa06b0..65019adefea 100644 --- a/lib/wasi-web/Cargo.lock +++ b/lib/wasi-web/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "addr2line" -version = "0.20.0" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4fa78e18c64fce05e902adecd7a5eed15a5e0a3439f7b0e169f0252214865e3" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" dependencies = [ "gimli", ] @@ -30,9 +30,9 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "1.0.4" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6748e8def348ed4d14996fa801f4122cd763fff530258cdc03f64b25f89d3a5a" +checksum = "0c378d78423fdad8089616f827526ee33c19f2fddbd5de1629152c9593ba4783" dependencies = [ "memchr", ] @@ -72,7 +72,7 @@ checksum = "bc00ceb34980c03614e35a3a4e218276a0a824e911d07651cd0d858a51e8c0f0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.32", ] [[package]] @@ -92,9 +92,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" [[package]] name = "backtrace" -version = "0.3.68" +version = "0.3.69" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4319208da049c43661739c5fade2ba182f09d1dc2299b32298d3a31692b17e12" +checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" dependencies = [ "addr2line", "cc", @@ -113,9 +113,9 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" [[package]] name = "base64" -version = "0.21.2" +version = "0.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +checksum = "9ba43ea6f343b788c8764558649e08df62f86c6ef251fdaeb1ffd010a9ae50a2" [[package]] name = "bincode" @@ -213,9 +213,9 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "bytes" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" dependencies = [ "serde", ] @@ -237,16 +237,16 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.26" +version = "0.4.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec837a71355b28f6556dbd569b37b3f363091c0bd4b2e735674521b4c5fd9bc5" +checksum = "defd4e7873dbddba6c7c91e199c7fcb946abc4a6a4ac3195400bcfb01b5de877" dependencies = [ "android-tzdata", "iana-time-zone", "js-sys", "num-traits", "wasm-bindgen", - "winapi", + "windows-targets", ] [[package]] @@ -273,9 +273,9 @@ checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" [[package]] name = "corosensei" -version = "0.1.3" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9847f90f32a50b0dcbd68bc23ff242798b13080b97b0569f6ed96a45ce4cf2cd" +checksum = "80128832c58ea9cbd041d2a759ec449224487b2c1e400453d99d244eead87a8e" dependencies = [ "autocfg", "cfg-if", @@ -381,7 +381,7 @@ dependencies = [ "ident_case", "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.32", ] [[package]] @@ -403,14 +403,14 @@ checksum = "836a9bbc7ad63342d6d6e7b815ccab164bc77a2d95d84bc3117a8c0d5c98e2d5" dependencies = [ "darling_core 0.20.3", "quote", - "syn 2.0.29", + "syn 2.0.32", ] [[package]] name = "dashmap" -version = "5.5.1" +version = "5.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edd72493923899c6f10c641bdbdeddc7183d6396641d99c1a0d1597f37f92e28" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" dependencies = [ "cfg-if", "hashbrown 0.14.0", @@ -494,9 +494,9 @@ checksum = "8ea6672d73216c05740850c789368d371ca226dc8104d5f2e30c74252d5d6e5e" [[package]] name = "educe" -version = "0.4.22" +version = "0.4.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "079044df30bb07de7d846d41a184c4b00e66ebdac93ee459253474f3a47e50ae" +checksum = "0f0042ff8246a363dbe77d2ceedb073339e85a804b9a47636c6e016a9a32c05f" dependencies = [ "enum-ordinalize", "proc-macro2", @@ -534,7 +534,7 @@ dependencies = [ "num-traits", "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.32", ] [[package]] @@ -555,7 +555,7 @@ dependencies = [ "darling 0.20.3", "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.32", ] [[package]] @@ -566,9 +566,9 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b30f669a7961ef1631673d2766cc92f52d64f7ef354d4fe0ddfd30ed52f0f4f" +checksum = "136526188508e25c6fef639d7927dfb3e0e3084488bf202267829cf7fc23dbdd" dependencies = [ "errno-dragonfly", "libc", @@ -705,7 +705,7 @@ checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.32", ] [[package]] @@ -763,9 +763,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.27.3" +version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6c80984affa11d98d1b88b66ac8853f143217b399d3c74116778ff8fdb4ed2e" +checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" [[package]] name = "glob" @@ -979,9 +979,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.4.5" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57bcfdad1b858c2db7c38303a6d2ad4dfaf5eb53dfeb0910128b2c26d6158503" +checksum = "1a9bad9f94746442c783ca431b22403b519cd7fbeed0533fdd6328b2f2212128" [[package]] name = "lock_api" @@ -1010,9 +1010,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.5.0" +version = "2.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +checksum = "8f232d6ef707e1956a43342693d2a31e72989554d58299d7a88738cc95b0d35c" [[package]] name = "memmap2" @@ -1068,9 +1068,9 @@ dependencies = [ [[package]] name = "num-bigint" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f93ab6289c7b344a8a9f60f88d80aa20032336fe78da341afc91c8a2341fc75f" +checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" dependencies = [ "autocfg", "num-integer", @@ -1119,9 +1119,9 @@ dependencies = [ [[package]] name = "object" -version = "0.31.1" +version = "0.32.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8bda667d9f2b5051b8833f59f3bf748b28ef54f850f4fcb389a252aa383866d1" +checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" dependencies = [ "memchr", ] @@ -1221,14 +1221,14 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.32", ] [[package]] name = "pin-project-lite" -version = "0.2.12" +version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12cc1b0bf1727a77a54b6654e7b5f1af8604923edc8b81885f8ec92f9e3f0a05" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" [[package]] name = "pin-utils" @@ -1381,9 +1381,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.9.3" +version = "1.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81bc1d4caf89fac26a70747fe603c130093b53c773888797a6329091246d651a" +checksum = "697061221ea1b4a94a624f67d0ae2bfe4e22b8a17b6a192afb11046542cc8c47" dependencies = [ "aho-corasick", "memchr", @@ -1393,9 +1393,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.6" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fed1ceff11a1dddaee50c9dc8e4938bd106e9d89ae372f192311e7da498e3b69" +checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795" dependencies = [ "aho-corasick", "memchr", @@ -1404,9 +1404,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2" +checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" [[package]] name = "region" @@ -1481,9 +1481,9 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.8" +version = "0.38.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19ed4fa021d81c8392ce04db050a3da9a60299050b7ae1cf482d862b54a7218f" +checksum = "d7db8590df6dfcd144d22afd1b83b36c21a18d7cbc1dc4bb5295a8712e9eb662" dependencies = [ "bitflags 2.4.0", "errno", @@ -1542,9 +1542,9 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.171" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30e27d1e4fd7659406c492fd6cfaf2066ba8773de45ca75e855590f856dc34a9" +checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" dependencies = [ "serde_derive", ] @@ -1572,20 +1572,20 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.171" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "389894603bd18c46fa56231694f8d827779c0951a667087194cf9de94ed24682" +checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.32", ] [[package]] name = "serde_json" -version = "1.0.105" +version = "1.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360" +checksum = "2cc66a619ed80bf7a0f6b17dd063a84b88f6dea1813737cf469aef1d081142c2" dependencies = [ "itoa", "ryu", @@ -1677,9 +1677,9 @@ checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a" [[package]] name = "slab" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" dependencies = [ "autocfg", ] @@ -1724,9 +1724,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.29" +version = "2.0.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c324c494eba9d92503e6f1ef2e6df781e78f6a7705a0202d9801b198807d518a" +checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2" dependencies = [ "proc-macro2", "quote", @@ -1790,22 +1790,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.47" +version = "1.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97a802ec30afc17eee47b2855fc72e0c4cd62be9b4efe6591edde0ec5bd68d8f" +checksum = "9d6d7a740b8a666a7e828dd00da9c0dc290dff53154ea77ac109281de90589b7" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.47" +version = "1.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bb623b56e39ab7dcd4b1b98bb6c8f8d907ed255b18de254088016b27a8ee19b" +checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.32", ] [[package]] @@ -1820,9 +1820,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.26" +version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a79d09ac6b08c1ab3906a2f7cc2e81a0e27c7ae89c63812df75e52bef0751e07" +checksum = "17f6bb557fd245c28e6411aa56b6403c689ad95061f50e4be16c274e70a17e48" dependencies = [ "deranged", "itoa", @@ -1839,9 +1839,9 @@ checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" [[package]] name = "time-macros" -version = "0.2.12" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75c65469ed6b3a4809d987a41eb1dc918e9bc1d92211cbad7ae82931846f7451" +checksum = "1a942f44339478ef67935ab2bbaec2fb0322496cf3cbe84b261e06ac3814c572" dependencies = [ "time-core", ] @@ -1881,7 +1881,7 @@ checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.32", ] [[package]] @@ -1924,9 +1924,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.7.6" +version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c17e963a819c331dcacd7ab957d80bc2b9a9c1e71c804826d2f283dd65306542" +checksum = "dd79e69d3b627db300ff956027cc6c3798cef26d22526befdfcd12feeb6d2257" dependencies = [ "serde", "serde_spanned", @@ -1945,9 +1945,9 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.19.14" +version = "0.19.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8123f27e969974a3dfba720fdb560be359f57b44302d280ba72e76a74480e8a" +checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" dependencies = [ "indexmap 2.0.0", "serde", @@ -1977,7 +1977,7 @@ checksum = "5f4f31f56159e98206da9efd823404b79b6ef3143b4a7ab76e67b1751b25a4ab" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.32", ] [[package]] @@ -2098,9 +2098,9 @@ checksum = "f28467d3e1d3c6586d8f25fa243f544f5800fec42d97032474e17222c2b75cfa" [[package]] name = "url" -version = "2.4.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50bff7831e19200a85b17131d085c25d7811bc4e186efdaf54bbd132994a88cb" +checksum = "143b538f18257fac9cad154828a57c6bf5157e1aa604d4816b5995bf6de87ae5" dependencies = [ "form_urlencoded", "idna", @@ -2172,7 +2172,7 @@ version = "0.5.0" dependencies = [ "anyhow", "async-trait", - "base64 0.21.2", + "base64 0.21.4", "bincode", "bytes", "derivative", @@ -2298,9 +2298,9 @@ checksum = "9d5b2c62b4012a3e1eca5a7e077d13b3bf498c4073e33ccd58626607748ceeca" [[package]] name = "walkdir" -version = "2.3.3" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36df944cda56c7d8d8b7496af378e6b16de9284591917d307c9b4d313c44e698" +checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee" dependencies = [ "same-file", "winapi-util", @@ -2335,7 +2335,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.32", "wasm-bindgen-shared", ] @@ -2392,7 +2392,7 @@ checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.32", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2429,9 +2429,9 @@ dependencies = [ [[package]] name = "wasm-encoder" -version = "0.31.1" +version = "0.32.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41763f20eafed1399fff1afb466496d3a959f58241436cfdc17e3f5ca954de16" +checksum = "1ba64e81215916eaeb48fee292f29401d69235d62d8b8fd92a7b2844ec5ae5f7" dependencies = [ "leb128", ] @@ -2696,9 +2696,9 @@ dependencies = [ [[package]] name = "wast" -version = "63.0.0" +version = "64.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2560471f60a48b77fccefaf40796fda61c97ce1e790b59dfcec9dc3995c9f63a" +checksum = "a259b226fd6910225aa7baeba82f9d9933b6d00f2ce1b49b80fa4214328237cc" dependencies = [ "leb128", "memchr", @@ -2708,9 +2708,9 @@ dependencies = [ [[package]] name = "wat" -version = "1.0.70" +version = "1.0.71" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bdc306c2c4c2f2bf2ba69e083731d0d2a77437fc6a350a19db139636e7e416c" +checksum = "53253d920ab413fca1c7dc2161d601c79b4fdf631d0ba51dd4343bf9b556c3f6" dependencies = [ "wast", ] @@ -2732,7 +2732,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5c35d27cb4c7898571b5f25036ead587736ffb371261f9e928a28edee7abf9d" dependencies = [ "anyhow", - "base64 0.21.2", + "base64 0.21.4", "byteorder", "bytes", "flate2", @@ -2751,7 +2751,7 @@ dependencies = [ "tar", "tempfile", "thiserror", - "toml 0.7.6", + "toml 0.7.8", "url", "walkdir", "wasmer-toml", @@ -2914,9 +2914,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "winnow" -version = "0.5.14" +version = "0.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d09770118a7eb1ccaf4a594a221334119a44a814fcb0d31c5b85e83e97227a97" +checksum = "7c2e3184b9c4e92ad5167ca73039d0c42476302ab603e2fec4487511f38ccefc" dependencies = [ "memchr", ] diff --git a/lib/wasix/src/fs/inode_guard.rs b/lib/wasix/src/fs/inode_guard.rs index c6d571c9613..113c21cd369 100644 --- a/lib/wasix/src/fs/inode_guard.rs +++ b/lib/wasix/src/fs/inode_guard.rs @@ -1,7 +1,6 @@ use std::{ future::Future, io::{IoSlice, SeekFrom}, - mem::replace, ops::{Deref, DerefMut}, pin::Pin, sync::{Arc, RwLock}, @@ -11,8 +10,6 @@ use std::{ use futures::future::BoxFuture; use tokio::io::{AsyncRead, AsyncSeek, AsyncWrite}; use virtual_fs::{FsError, Pipe as VirtualPipe, VirtualFile}; -use virtual_mio::{InterestType, StatefulHandler}; -use virtual_net::net_error_into_io_err; use wasmer_wasix_types::{ types::Eventtype, wasi::{self, EpollType}, @@ -166,7 +163,6 @@ impl Future for InodeValFilePollGuardJoin { fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { // Otherwise we need to register for the event - let fd = self.fd(); let waker = cx.waker(); let mut has_read = false; let mut has_write = false; @@ -203,38 +199,7 @@ impl Future for InodeValFilePollGuardJoin { InodeValFilePollGuardMode::EventNotifications(inner) => inner.poll(waker).map(Ok), InodeValFilePollGuardMode::Socket { ref inner } => { let mut guard = inner.protected.write().unwrap(); - if guard.handler_state.take(InterestType::Readable) { - Poll::Ready(Ok(8192)) - } else { - let handler = - StatefulHandler::new(cx.waker().into(), guard.handler_state.clone()); - - let res = guard - .add_handler(handler, InterestType::Readable) - .map_err(net_error_into_io_err); - match res { - Err(err) if is_err_closed(&err) => { - tracing::trace!("socket read ready error (fd={}) - {}", fd, err); - if !replace(&mut guard.notifications.closed, true) { - Poll::Ready(Ok(0)) - } else { - Poll::Pending - } - } - Err(err) => { - tracing::debug!("poll socket error - {}", err); - if !replace(&mut guard.notifications.failed, true) { - Poll::Ready(Ok(0)) - } else { - Poll::Pending - } - } - Ok(()) => { - drop(guard); - Poll::Pending - } - } - } + guard.poll_read_ready(cx) } InodeValFilePollGuardMode::Pipe { pipe } => { let mut guard = pipe.write().unwrap(); @@ -324,42 +289,7 @@ impl Future for InodeValFilePollGuardJoin { InodeValFilePollGuardMode::EventNotifications(inner) => inner.poll(waker).map(Ok), InodeValFilePollGuardMode::Socket { ref inner } => { let mut guard = inner.protected.write().unwrap(); - if guard.handler_state.take(InterestType::Writable) { - Poll::Ready(Ok(8192)) - } else { - let handler = - StatefulHandler::new(cx.waker().into(), guard.handler_state.clone()); - - let res = guard - .add_handler(handler, InterestType::Writable) - .map_err(net_error_into_io_err); - match res { - Err(err) if is_err_closed(&err) => { - tracing::trace!( - "socket write ready error (fd={}) - err={}", - fd, - err - ); - if !replace(&mut guard.notifications.closed, true) { - Poll::Ready(Ok(0)) - } else { - Poll::Pending - } - } - Err(err) => { - tracing::debug!("poll socket error - {}", err); - if !replace(&mut guard.notifications.failed, true) { - Poll::Ready(Ok(0)) - } else { - Poll::Pending - } - } - Ok(()) => { - drop(guard); - Poll::Pending - } - } - } + guard.poll_write_ready(cx) } InodeValFilePollGuardMode::Pipe { pipe } => { let mut guard = pipe.write().unwrap(); diff --git a/lib/wasix/src/net/socket.rs b/lib/wasix/src/net/socket.rs index c828914594e..ed7fb32819e 100644 --- a/lib/wasix/src/net/socket.rs +++ b/lib/wasix/src/net/socket.rs @@ -1,22 +1,21 @@ use std::{ future::Future, + io, mem::MaybeUninit, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, pin::Pin, sync::{Arc, RwLock}, - task::Poll, + task::{Context, Poll}, time::Duration, }; +use derivative::Derivative; #[cfg(feature = "enable-serde")] use serde_derive::{Deserialize, Serialize}; -use virtual_mio::{ - FilteredHandler, FilteredHandlerSubscriptions, InterestHandler, InterestType, - StatefulHandlerState, -}; +use virtual_mio::InterestHandler; use virtual_net::{ - NetworkError, VirtualIcmpSocket, VirtualNetworking, VirtualRawSocket, VirtualTcpListener, - VirtualTcpSocket, VirtualUdpSocket, + net_error_into_io_err, NetworkError, VirtualIcmpSocket, VirtualNetworking, VirtualRawSocket, + VirtualTcpListener, VirtualTcpSocket, VirtualUdpSocket, }; use wasmer_types::MemorySize; use wasmer_wasix_types::wasi::{Addressfamily, Errno, Rights, SockProto, Sockoption, Socktype}; @@ -34,7 +33,8 @@ pub enum InodeHttpSocketType { Headers, } -#[derive(Debug)] +#[derive(Derivative)] +#[derivative(Debug)] //#[cfg_attr(feature = "enable-serde", derive(Serialize, Deserialize))] pub enum InodeSocketKind { PreSocket { @@ -54,6 +54,8 @@ pub enum InodeSocketKind { read_timeout: Option, accept_timeout: Option, connect_timeout: Option, + #[derivative(Debug = "ignore")] + handler: Option>, }, Icmp(Box), Raw(Box), @@ -159,16 +161,6 @@ pub enum TimeType { //#[cfg_attr(feature = "enable-serde", derive(Serialize, Deserialize))] pub(crate) struct InodeSocketProtected { pub kind: InodeSocketKind, - pub notifications: InodeSocketNotifications, - pub aggregate_handler: Option, - pub handler_state: StatefulHandlerState, -} - -#[derive(Debug, Default)] -//#[cfg_attr(feature = "enable-serde", derive(Serialize, Deserialize))] -pub(crate) struct InodeSocketNotifications { - pub closed: bool, - pub failed: bool, } #[derive(Debug)] @@ -184,21 +176,23 @@ pub struct InodeSocket { } impl InodeSocket { - pub fn new(kind: InodeSocketKind) -> Self { - let handler_state: StatefulHandlerState = Default::default(); - if let InodeSocketKind::TcpStream { .. } = &kind { - handler_state.set(InterestType::Writable); - } - Self { + pub fn new(kind: InodeSocketKind) -> virtual_net::Result { + let protected = InodeSocketProtected { kind }; + Ok(Self { inner: Arc::new(InodeSocketInner { - protected: RwLock::new(InodeSocketProtected { - kind, - notifications: Default::default(), - aggregate_handler: None, - handler_state, - }), + protected: RwLock::new(protected), }), - } + }) + } + + pub fn poll_read_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut inner = self.inner.protected.write().unwrap(); + inner.poll_read_ready(cx) + } + + pub fn poll_write_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut inner = self.inner.protected.write().unwrap(); + inner.poll_read_ready(cx) } pub async fn bind( @@ -272,7 +266,7 @@ impl InodeSocket { tokio::select! { socket = socket => { let socket = socket.map_err(net_error_into_wasi_err)?; - Ok(Some(InodeSocket::new(InodeSocketKind::UdpSocket { socket, peer: None }))) + Ok(Some(InodeSocket::new(InodeSocketKind::UdpSocket { socket, peer: None }).map_err(net_error_into_wasi_err)?)) }, _ = tasks.sleep_now(timeout) => Err(Errno::Timedout) } @@ -332,7 +326,7 @@ impl InodeSocket { Ok(Some(InodeSocket::new(InodeSocketKind::TcpListener { socket, accept_timeout: Some(timeout), - }))) + }).map_err(net_error_into_wasi_err)?)) }, _ = tasks.sleep_now(timeout) => Err(Errno::Timedout) } @@ -432,6 +426,7 @@ impl InodeSocket { let timeout = timeout.unwrap_or(Duration::from_secs(30)); + let handler; let connect = { let mut inner = self.inner.protected.write().unwrap(); match &mut inner.kind { @@ -443,8 +438,10 @@ impl InodeSocket { no_delay, keep_alive, dont_route, + handler: h, .. } => { + handler = h.take(); new_write_timeout = *write_timeout; new_read_timeout = *read_timeout; match *ty { @@ -490,15 +487,25 @@ impl InodeSocket { } }; - let socket = tokio::select! { + let mut socket = tokio::select! { res = connect => res.map_err(net_error_into_wasi_err)?, _ = tasks.sleep_now(timeout) => return Err(Errno::Timedout) }; - Ok(Some(InodeSocket::new(InodeSocketKind::TcpStream { + + if let Some(handler) = handler { + socket + .set_handler(handler) + .map_err(net_error_into_wasi_err)?; + } + + let socket = InodeSocket::new(InodeSocketKind::TcpStream { socket, write_timeout: new_write_timeout, read_timeout: new_read_timeout, - }))) + }) + .map_err(net_error_into_wasi_err)?; + + Ok(Some(socket)) } pub fn status(&self) -> Result { @@ -990,10 +997,9 @@ impl InodeSocket { Poll::Ready(Err(Errno::Again)) } Err(NetworkError::WouldBlock) if !self.handler_registered => { - let res = inner.set_handler(cx.waker().into()); - if let Err(err) = res { - return Poll::Ready(Err(net_error_into_wasi_err(err))); - } + inner + .set_handler(cx.waker().into()) + .map_err(net_error_into_wasi_err)?; drop(inner); self.handler_registered = true; continue; @@ -1068,10 +1074,9 @@ impl InodeSocket { Poll::Ready(Err(Errno::Again)) } Err(NetworkError::WouldBlock) if !self.handler_registered => { - let res = inner.set_handler(cx.waker().into()); - if let Err(err) = res { - return Poll::Ready(Err(net_error_into_wasi_err(err))); - } + inner + .set_handler(cx.waker().into()) + .map_err(net_error_into_wasi_err)?; self.handler_registered = true; drop(inner); continue; @@ -1157,10 +1162,9 @@ impl InodeSocket { Poll::Ready(Err(Errno::Again)) } Err(NetworkError::WouldBlock) if !self.handler_registered => { - let res = inner.set_handler(cx.waker().into()); - if let Err(err) = res { - return Poll::Ready(Err(net_error_into_wasi_err(err))); - } + inner + .set_handler(cx.waker().into()) + .map_err(net_error_into_wasi_err)?; self.handler_registered = true; drop(inner); continue; @@ -1234,10 +1238,9 @@ impl InodeSocket { Poll::Ready(Err(Errno::Again)) } Err(NetworkError::WouldBlock) if !self.handler_registered => { - let res = inner.set_handler(cx.waker().into()); - if let Err(err) = res { - return Poll::Ready(Err(net_error_into_wasi_err(err))); - } + inner + .set_handler(cx.waker().into()) + .map_err(net_error_into_wasi_err)?; self.handler_registered = true; continue; } @@ -1299,8 +1302,34 @@ impl InodeSocketProtected { InodeSocketKind::UdpSocket { socket, .. } => socket.remove_handler(), InodeSocketKind::Raw(socket) => socket.remove_handler(), InodeSocketKind::Icmp(socket) => socket.remove_handler(), - InodeSocketKind::PreSocket { .. } => {} + InodeSocketKind::PreSocket { handler, .. } => { + handler.take(); + } + } + } + + pub fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + match &mut self.kind { + InodeSocketKind::TcpListener { socket, .. } => socket.poll_read_ready(cx), + InodeSocketKind::TcpStream { socket, .. } => socket.poll_read_ready(cx), + InodeSocketKind::UdpSocket { socket, .. } => socket.poll_read_ready(cx), + InodeSocketKind::Raw(socket) => socket.poll_read_ready(cx), + InodeSocketKind::Icmp(socket) => socket.poll_read_ready(cx), + InodeSocketKind::PreSocket { .. } => Poll::Pending, + } + .map_err(net_error_into_io_err) + } + + pub fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + match &mut self.kind { + InodeSocketKind::TcpListener { socket, .. } => socket.poll_write_ready(cx), + InodeSocketKind::TcpStream { socket, .. } => socket.poll_write_ready(cx), + InodeSocketKind::UdpSocket { socket, .. } => socket.poll_write_ready(cx), + InodeSocketKind::Raw(socket) => socket.poll_write_ready(cx), + InodeSocketKind::Icmp(socket) => socket.poll_write_ready(cx), + InodeSocketKind::PreSocket { .. } => Poll::Pending, } + .map_err(net_error_into_io_err) } pub fn set_handler( @@ -1313,25 +1342,11 @@ impl InodeSocketProtected { InodeSocketKind::UdpSocket { socket, .. } => socket.set_handler(handler), InodeSocketKind::Raw(socket) => socket.set_handler(handler), InodeSocketKind::Icmp(socket) => socket.set_handler(handler), - InodeSocketKind::PreSocket { .. } => Err(virtual_net::NetworkError::NotConnected), - } - } - - pub fn add_handler( - &mut self, - handler: Box, - interest: InterestType, - ) -> virtual_net::Result<()> { - if self.aggregate_handler.is_none() { - let upper = FilteredHandler::new(); - let subs = upper.subscriptions().clone(); - - self.set_handler(upper)?; - self.aggregate_handler.replace(subs); + InodeSocketKind::PreSocket { handler: h, .. } => { + h.replace(handler); + Ok(()) + } } - let upper = self.aggregate_handler.as_mut().unwrap(); - upper.add_interest(interest, handler); - Ok(()) } } diff --git a/lib/wasix/src/syscalls/wasix/epoll_ctl.rs b/lib/wasix/src/syscalls/wasix/epoll_ctl.rs index 4df7aebe9af..16946c5c5cc 100644 --- a/lib/wasix/src/syscalls/wasix/epoll_ctl.rs +++ b/lib/wasix/src/syscalls/wasix/epoll_ctl.rs @@ -138,7 +138,7 @@ impl EpollHandler { } } impl InterestHandler for EpollHandler { - fn interest(&mut self, interest: InterestType) { + fn push_interest(&mut self, interest: InterestType) { let readiness = match interest { InterestType::Readable => EpollType::EPOLLIN, InterestType::Writable => EpollType::EPOLLOUT, @@ -149,6 +149,35 @@ impl InterestHandler for EpollHandler { i.interest.insert((self.fd, readiness)); }); } + + fn pop_interest(&mut self, interest: InterestType) -> bool { + let readiness = match interest { + InterestType::Readable => EpollType::EPOLLIN, + InterestType::Writable => EpollType::EPOLLOUT, + InterestType::Closed => EpollType::EPOLLHUP, + InterestType::Error => EpollType::EPOLLERR, + }; + let mut ret = false; + self.tx.send_modify(move |i| { + ret = i.interest.iter().any(|(_, b)| *b == readiness); + i.interest.retain(|(_, b)| *b != readiness); + }); + ret + } + + fn has_interest(&self, interest: InterestType) -> bool { + let readiness = match interest { + InterestType::Readable => EpollType::EPOLLIN, + InterestType::Writable => EpollType::EPOLLOUT, + InterestType::Closed => EpollType::EPOLLHUP, + InterestType::Error => EpollType::EPOLLERR, + }; + let mut ret = false; + self.tx.send_modify(move |i| { + ret = i.interest.iter().any(|(_, b)| *b == readiness); + }); + ret + } } fn inline_waker_wake(s: &EpollJoinWaker) { diff --git a/lib/wasix/src/syscalls/wasix/sock_accept.rs b/lib/wasix/src/syscalls/wasix/sock_accept.rs index 3a9910fa3d4..ce17a0f6c54 100644 --- a/lib/wasix/src/syscalls/wasix/sock_accept.rs +++ b/lib/wasix/src/syscalls/wasix/sock_accept.rs @@ -113,7 +113,8 @@ pub fn sock_accept_internal( socket: child, write_timeout: None, read_timeout: None, - }), + }) + .map_err(net_error_into_wasi_err)?, }; let inode = state .fs diff --git a/lib/wasix/src/syscalls/wasix/sock_open.rs b/lib/wasix/src/syscalls/wasix/sock_open.rs index 508e9706d33..52848d68317 100644 --- a/lib/wasix/src/syscalls/wasix/sock_open.rs +++ b/lib/wasix/src/syscalls/wasix/sock_open.rs @@ -48,7 +48,7 @@ pub fn sock_open( let kind = match ty { Socktype::Stream | Socktype::Dgram => Kind::Socket { - socket: InodeSocket::new(InodeSocketKind::PreSocket { + socket: wasi_try!(InodeSocket::new(InodeSocketKind::PreSocket { family: af, ty, pt, @@ -65,7 +65,9 @@ pub fn sock_open( read_timeout: None, accept_timeout: None, connect_timeout: None, - }), + handler: None, + }) + .map_err(net_error_into_wasi_err)), }, _ => return Errno::Notsup, };