From 643c990cb84c5a57380e139fe06c9535568a035f Mon Sep 17 00:00:00 2001 From: Chris Denton Date: Sat, 2 Aug 2025 22:13:18 +0000 Subject: [PATCH 1/2] windows-threading: Implement scope for Pool --- crates/libs/threading/readme.md | 18 +++--- crates/libs/threading/src/lib.rs | 10 ++-- crates/libs/threading/src/pool.rs | 71 +++++++++++++++++++++-- crates/tests/libs/threading/tests/pool.rs | 71 +++++++++++------------ 4 files changed, 116 insertions(+), 54 deletions(-) diff --git a/crates/libs/threading/readme.md b/crates/libs/threading/readme.md index 83406f2d7d..0d3effedf4 100644 --- a/crates/libs/threading/readme.md +++ b/crates/libs/threading/readme.md @@ -75,16 +75,16 @@ The `for_each` function uses a `Pool` object internally, which you can also use let set = std::sync::RwLock::>::default(); let pool = windows_threading::Pool::new(); pool.set_thread_limits(2, 10); +pool.scope(|pool| { + for _ in 0..10 { + pool.submit(|| { + windows_threading::sleep(10); + let mut writer = set.write().unwrap(); + *writer.entry(windows_threading::thread_id()).or_default() += 1; + }) + } +}); -for _ in 0..10 { - pool.submit(|| { - windows_threading::sleep(10); - let mut writer = set.write().unwrap(); - *writer.entry(windows_threading::thread_id()).or_default() += 1; - }) -} - -pool.join(); println!("{:#?}", set.read().unwrap()); ``` diff --git a/crates/libs/threading/src/lib.rs b/crates/libs/threading/src/lib.rs index 9d961f757a..0dc4893d12 100644 --- a/crates/libs/threading/src/lib.rs +++ b/crates/libs/threading/src/lib.rs @@ -35,11 +35,11 @@ where F: Fn(T) + Sync, T: Send, { - let pool = Pool::new(); - - for item in i { - pool.submit(|| f(item)); - } + Pool::with_scope(|pool| { + for item in i { + pool.submit(|| f(item)); + } + }); } /// The thread identifier of the calling thread. diff --git a/crates/libs/threading/src/pool.rs b/crates/libs/threading/src/pool.rs index 9352fd65d6..f36412c60c 100644 --- a/crates/libs/threading/src/pool.rs +++ b/crates/libs/threading/src/pool.rs @@ -1,4 +1,5 @@ use super::*; +use core::{marker::PhantomData, ops::Deref}; /// A `Pool` object represents a private thread pool with its own thread limits. /// @@ -25,6 +26,15 @@ impl Pool { Self(Box::new(e)) } + /// Convenience function for creating a new pool and calling [`scope`][Self::scope]. + pub fn with_scope<'env, F>(f: F) + where + F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>), + { + let pool = Pool::new(); + pool.scope(f); + } + /// Sets the thread limits for the `Pool` object. pub fn set_thread_limits(&self, min: u32, max: u32) { unsafe { @@ -33,16 +43,42 @@ impl Pool { } } - /// Submits the closure to run on the `Pool`. + /// Submit the closure to the thread pool. /// - /// The closure cannot outlive the `Pool` on which it runs. - pub fn submit<'a, F: FnOnce() + Send + 'a>(&'a self, f: F) { - // This is safe because the lifetime of the closure is bounded by the `Pool`. + /// * The closure must have `'static` lifetime as the thread may outlive the lifetime in which `submit` is called. + /// * The closure must be `Send` as it will be sent to another thread for execution. + pub fn submit(&self, f: F) { + // This is safe because the closure has a `'static` lifetime. unsafe { try_submit(&*self.0, f); } } + /// Create a scope for submitting closures. + /// + /// Within this scope local variables can be sent to the pool thread for execution. + /// This is possible because `scope` will wait for all submitted closures to finish before returning, + /// Note however that it will also wait for closures that were submitted from other threads. + pub fn scope<'env, F>(&self, f: F) + where + F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>), + { + struct DropGuard<'a>(&'a Pool); + impl Drop for DropGuard<'_> { + fn drop(&mut self) { + self.0.join(); + } + } + // Ensure that we always join the pool before returning. + let _guard = DropGuard(self); + let scope = Scope { + pool: self, + env: PhantomData, + scope: PhantomData, + }; + f(&scope); + } + /// Waits for all submissions to finish. /// /// Dropping the `Pool` will also wait for all submissions to finish. @@ -74,3 +110,30 @@ impl Drop for Pool { } } } + +/// A scope to submit closures in. +/// +/// See [`scope`][Pool::scope] for details. +pub struct Scope<'scope, 'env: 'scope> { + pool: &'scope Pool, + scope: PhantomData<&'scope mut &'scope ()>, + env: PhantomData<&'env mut &'env ()>, +} + +impl<'scope, 'env> Scope<'scope, 'env> { + /// Submits the closure to run on the `Pool`. + /// + /// The closure cannot outlive the `Scope` it's run in. + pub fn submit(&'scope self, f: F) { + unsafe { + try_submit(&*self.pool.0, f); + } + } +} + +impl Deref for Scope<'_, '_> { + type Target = Pool; + fn deref(&self) -> &Self::Target { + self.pool + } +} diff --git a/crates/tests/libs/threading/tests/pool.rs b/crates/tests/libs/threading/tests/pool.rs index 92db50448c..0596936fc3 100644 --- a/crates/tests/libs/threading/tests/pool.rs +++ b/crates/tests/libs/threading/tests/pool.rs @@ -3,23 +3,24 @@ fn join() { let pool = windows_threading::Pool::new(); let counter = std::sync::RwLock::::new(0); - for _ in 0..10 { - pool.submit(|| { - let mut writer = counter.write().unwrap(); - *writer += 1; - }); - } - - pool.join(); - assert_eq!(*counter.read().unwrap(), 10); - - for _ in 0..10 { - pool.submit(|| { - let mut writer = counter.write().unwrap(); - *writer += 1; - }); - } - + pool.scope(|pool| { + for _ in 0..10 { + pool.submit(|| { + let mut writer = counter.write().unwrap(); + *writer += 1; + }); + } + + pool.join(); + assert_eq!(*counter.read().unwrap(), 10); + + for _ in 0..10 { + pool.submit(|| { + let mut writer = counter.write().unwrap(); + *writer += 1; + }); + } + }); drop(pool); assert_eq!(*counter.read().unwrap(), 20); } @@ -42,16 +43,15 @@ fn multi() { let pool = windows_threading::Pool::new(); pool.set_thread_limits(2, 10); - - for _ in 0..10 { - pool.submit(|| { - windows_threading::sleep(10); - let mut writer = set.write().unwrap(); - writer.insert(windows_threading::thread_id()); - }) - } - - pool.join(); + pool.scope(|pool| { + for _ in 0..10 { + pool.submit(|| { + windows_threading::sleep(10); + let mut writer = set.write().unwrap(); + writer.insert(windows_threading::thread_id()); + }) + } + }); assert!(set.read().unwrap().len() > 1); } @@ -61,15 +61,14 @@ fn single() { let pool = windows_threading::Pool::new(); pool.set_thread_limits(1, 1); - - for _ in 0..10 { - pool.submit(|| { - let mut writer = set.write().unwrap(); - writer.insert(windows_threading::thread_id()); - }) - } - - pool.join(); + pool.scope(|pool| { + for _ in 0..10 { + pool.submit(|| { + let mut writer = set.write().unwrap(); + writer.insert(windows_threading::thread_id()); + }) + } + }); assert_eq!(set.read().unwrap().len(), 1); } From 68abc4674707adc7f62042f0792236c279e80df2 Mon Sep 17 00:00:00 2001 From: Chris Denton Date: Sun, 3 Aug 2025 15:04:57 +0000 Subject: [PATCH 2/2] Use Arc in service threading sample --- crates/samples/services/thread/src/main.rs | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/crates/samples/services/thread/src/main.rs b/crates/samples/services/thread/src/main.rs index bd9cc5dd06..0901393aa5 100644 --- a/crates/samples/services/thread/src/main.rs +++ b/crates/samples/services/thread/src/main.rs @@ -1,11 +1,17 @@ use windows_services::*; use windows_threading::*; +use std::sync::{Arc, RwLock}; + fn main() { - let pool = Pool::new(); + let pool: Pool = Pool::new(); pool.set_thread_limits(1, 1); - Service::new() + let service_original = Arc::new(RwLock::new(Service::new())); + let service = Arc::clone(&service_original); + service_original + .write() + .unwrap() .can_pause() .can_stop() .can_fallback(|_| { @@ -13,11 +19,13 @@ fn main() { use std::io::Read; _ = std::io::stdin().read(&mut [0]); }) - .run(|service, command| { + .run(move |_, command| { log(&format!("Command: {command:?}\n")); - match command { - Command::Start | Command::Resume => pool.submit(|| service_thread(service)), + Command::Start | Command::Resume => { + let service = Arc::clone(&service); + pool.submit(move || service_thread(service)) + } Command::Pause | Command::Stop => pool.join(), _ => {} } @@ -25,7 +33,7 @@ fn main() { .unwrap(); } -fn service_thread(service: &Service) { +fn service_thread(service: Arc>>) { for i in 0..10 { log(&format!("Thread:{}... iteration:{i}\n", thread_id())); @@ -33,12 +41,14 @@ fn service_thread(service: &Service) { sleep(1000); // Services can use the `state` function to query the current service state. + let service = service.read().unwrap(); if matches!(service.state(), State::StopPending | State::PausePending) { return; } } // Services can use the `set_state` function to update the service state. + let service = service.read().unwrap(); service.set_state(State::Stopped); }