diff --git a/glommio/src/executor/mod.rs b/glommio/src/executor/mod.rs
index 5cc2d67ee..7c0184113 100644
--- a/glommio/src/executor/mod.rs
+++ b/glommio/src/executor/mod.rs
@@ -4133,6 +4133,19 @@ mod test {
         });
     }

+    #[test]
+    fn wake_refcount_overflow() {
+        LocalExecutor::default().run(async {
+            const NUM_CLONES: usize = u16::MAX as usize;
+
+            crate::spawn_local(poll_fn::<(), _>(move |cx| {
+                let _wakers = Vec::from_iter((0..NUM_CLONES).map(|_| cx.waker().clone()));
+                Poll::Ready(())
+            }))
+            .await;
+        })
+    }
+
     #[test]
     fn blocking_function() {
         LocalExecutor::default().run(async {
diff --git a/glommio/src/task/header.rs b/glommio/src/task/header.rs
index 33fef6070..ab10b4259 100644
--- a/glommio/src/task/header.rs
+++ b/glommio/src/task/header.rs
@@ -7,7 +7,7 @@ use core::{fmt, task::Waker};
 #[cfg(feature = "debugging")]
 use std::cell::Cell;
 use std::sync::{
-    atomic::{AtomicI16, Ordering},
+    atomic::{AtomicI32, Ordering},
     Arc,
 };

@@ -16,6 +16,9 @@ use crate::{
     task::{raw::TaskVTable, state::*, utils::abort_on_panic},
 };

+pub(crate) type RefCount = i32;
+pub(crate) type AtomicRefCount = AtomicI32;
+
 /// The header of a task.
 ///
 /// This header is stored right at the beginning of every heap-allocated task.
@@ -31,7 +34,7 @@ pub(crate) struct Header {
     pub(crate) latency_matters: bool,

     /// Current reference count of the task.
-    pub(crate) references: AtomicI16,
+    pub(crate) references: AtomicRefCount,

     /// The task that is blocked on the `JoinHandle`.
     ///
diff --git a/glommio/src/task/join_handle.rs b/glommio/src/task/join_handle.rs
index 07c84310e..1fdbc7200 100644
--- a/glommio/src/task/join_handle.rs
+++ b/glommio/src/task/join_handle.rs
@@ -16,7 +16,10 @@ use core::{
 use crate::task::debugging::TaskDebugger;
 use crate::{
     dbg_context,
-    task::{header::Header, state::*},
+    task::{
+        header::{Header, RefCount},
+        state::*,
+    },
 };
 use std::sync::atomic::Ordering;

@@ -70,7 +73,7 @@ impl<R> JoinHandle<R> {
                 // If we schedule it, need to bump the reference count, since after run() we
                 // decrement it.
                 let refs = (*header).references.fetch_add(1, Ordering::Relaxed);
-                assert_ne!(refs, i16::MAX);
+                assert_ne!(refs, RefCount::MAX);
                 ((*header).vtable.schedule)(ptr);
             }

@@ -132,7 +135,7 @@ impl<R> Drop for JoinHandle<R> {
             if refs == 0 {
                 if state & CLOSED == 0 {
                     let refs = (*header).references.fetch_add(1, Ordering::Relaxed);
-                    assert_ne!(refs, i16::MAX);
+                    assert_ne!(refs, RefCount::MAX);
                     ((*header).vtable.schedule)(ptr);
                 } else {
                     ((*header).vtable.destroy)(ptr);
diff --git a/glommio/src/task/raw.rs b/glommio/src/task/raw.rs
index a42a8e237..04b7bd787 100644
--- a/glommio/src/task/raw.rs
+++ b/glommio/src/task/raw.rs
@@ -13,14 +13,14 @@ use core::{
 };
 #[cfg(feature = "debugging")]
 use std::cell::Cell;
-use std::sync::atomic::{AtomicI16, Ordering};
+use std::sync::atomic::Ordering;

 #[cfg(feature = "debugging")]
 use crate::task::debugging::TaskDebugger;
 use crate::{
     dbg_context, sys,
     task::{
-        header::Header,
+        header::{AtomicRefCount, Header, RefCount},
         state::*,
         utils::{abort, abort_on_panic, extend},
         Task,
@@ -132,7 +132,7 @@ where
             notifier: sys::get_sleep_notifier_for(executor_id).unwrap(),
             state: SCHEDULED | HANDLE,
             latency_matters,
-            references: AtomicI16::new(0),
+            references: AtomicRefCount::new(0),
             awaiter: None,
             vtable: &TaskVTable {
                 schedule: Self::schedule,
@@ -276,12 +276,12 @@ where
     #[track_caller]
     fn increment_references(header: &Header) {
         let refs = header.references.fetch_add(1, Ordering::Relaxed);
-        assert_ne!(refs, i16::MAX, "Waker invariant broken: {header:?}");
+        assert_ne!(refs, RefCount::MAX, "Waker invariant broken: {header:?}");
     }

     #[inline]
     #[track_caller]
-    fn decrement_references(header: &Header) -> i16 {
+    fn decrement_references(header: &Header) -> RefCount {
         let refs = header.references.fetch_sub(1, Ordering::Relaxed);
         assert_ne!(refs, 0, "Waker invariant broken: {header:?}");
         refs - 1
diff --git a/glommio/src/task/task_impl.rs b/glommio/src/task/task_impl.rs
index 553ffd2f1..ade8ff1b0 100644
--- a/glommio/src/task/task_impl.rs
+++ b/glommio/src/task/task_impl.rs
@@ -9,7 +9,12 @@ use core::{fmt, future::Future, marker::PhantomData, mem, ptr::NonNull};
 use crate::task::debugging::TaskDebugger;
 use crate::{
     dbg_context,
-    task::{header::Header, raw::RawTask, state::*, JoinHandle},
+    task::{
+        header::{Header, RefCount},
+        raw::RawTask,
+        state::*,
+        JoinHandle,
+    },
 };
 use std::sync::atomic::Ordering;

@@ -133,7 +138,7 @@ impl Task {

         unsafe {
             let refs = (*header).references.fetch_add(1, Ordering::Relaxed);
-            assert_ne!(refs, i16::MAX);
+            assert_ne!(refs, RefCount::MAX);
             ((*header).vtable.run)(ptr)
         }
     }
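
Note (not part of the patch): a minimal standalone sketch of the arithmetic the new `wake_refcount_overflow` test exercises. Cloning a task's waker `u16::MAX` (65,535) times pushes the per-task reference count past `i16::MAX` (32,767), which is why `Header::references` is widened to `AtomicI32` behind the `RefCount`/`AtomicRefCount` aliases. The `main` function and the starting count of 1 below are illustrative assumptions, not glommio code.

```rust
use std::sync::atomic::{AtomicI32, Ordering};

fn main() {
    // 65_535 clones, matching NUM_CLONES in the regression test.
    let clones = u16::MAX as i64;

    // The old i16 counter tops out at 32_767, below the number of clones...
    assert!(clones > i16::MAX as i64);
    // ...while the widened i32 counter has ample headroom.
    assert!(clones < i32::MAX as i64);

    // Simulate the per-clone increments against an i32 counter, mirroring the
    // assert_ne!(refs, RefCount::MAX) guard the patch applies on every bump.
    let refs = AtomicI32::new(1); // assumed baseline: the task's own reference
    for _ in 0..clones {
        let prev = refs.fetch_add(1, Ordering::Relaxed);
        assert_ne!(prev, i32::MAX);
    }
    assert_eq!(refs.load(Ordering::Relaxed) as i64, clones + 1);
}
```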