diff --git a/src/concurrency/data_race.rs b/src/concurrency/data_race.rs index c18b780998..6ae083e61f 100644 --- a/src/concurrency/data_race.rs +++ b/src/concurrency/data_race.rs @@ -1944,7 +1944,7 @@ impl GlobalState { callback: impl FnOnce(&VClock) -> R, ) -> R { let thread = threads.active_thread(); - let span = threads.active_thread_ref().current_user_relevant_span(); + let span = threads.active_thread_ref().current_fiber().current_user_relevant_span(); let (index, mut clocks) = self.thread_state_mut(thread); let r = callback(&clocks.clock); // Increment the clock, so that all following events cannot be confused with anything that diff --git a/src/concurrency/thread.rs b/src/concurrency/thread.rs index 1c404d419e..b29c0173e1 100644 --- a/src/concurrency/thread.rs +++ b/src/concurrency/thread.rs @@ -164,25 +164,51 @@ enum ThreadJoinStatus { Joined, } -/// A thread. -pub struct Thread<'tcx> { - state: ThreadState<'tcx>, +/// A fiber identifier. +#[derive(Clone, Copy, Debug, PartialOrd, Ord, PartialEq, Eq, Hash)] +pub struct FiberId(u32); - /// Name of the thread. - thread_name: Option>, +impl FiberId { + pub fn to_u32(self) -> u32 { + self.0 + } - /// The virtual call stack. - stack: Vec>>, + /// Create a new fiber id from a `u32` without checking if this fiber exists. + pub fn new_unchecked(id: u32) -> Self { + Self(id) + } +} - /// A span that explains where the thread (or more specifically, its current root - /// frame) "comes from". - pub(crate) origin_span: Span, +impl Idx for FiberId { + fn new(idx: usize) -> Self { + FiberId(u32::try_from(idx).unwrap()) + } + + fn index(self) -> usize { + usize::try_from(self.0).unwrap() + } +} + +impl From for u64 { + fn from(t: FiberId) -> Self { + t.0.into() + } +} + +/// A fiber +pub struct Fiber<'fcx> { + /// Id of the fiber + #[allow(dead_code)] // TODO: Remove + pub(crate) id: FiberId, + + /// The virtual call stack. + stack: Vec>>, /// The function to call when the stack ran empty, to figure out what to do next. /// Conceptually, this is the interpreter implementation of the things that happen 'after' the /// Rust language entry point for this thread returns (usually implemented by the C or OS runtime). /// (`None` is an error, it means the callback has not been set up yet or is actively running.) - pub(crate) on_stack_empty: Option>, + pub(crate) on_stack_empty: Option>, /// The index of the topmost user-relevant frame in `stack`. This field must contain /// the value produced by `get_top_user_relevant_frame`. @@ -190,9 +216,6 @@ pub struct Thread<'tcx> { /// maintained inside `MiriMachine::after_stack_push` and `MiriMachine::after_stack_pop`. top_user_relevant_frame: Option, - /// The join status. - join_status: ThreadJoinStatus, - /// Stack of active unwind payloads for the current thread. Used for storing /// the argument of the call to `miri_start_unwind` (the payload) when unwinding. /// This is pointer-sized, and matches the `Payload` type in `src/libpanic_unwind/miri.rs`. @@ -201,35 +224,14 @@ pub struct Thread<'tcx> { /// which then forwards it to 'Resume'. However this argument is implicit in MIR, /// so we have to store it out-of-band. When there are multiple active unwinds, /// the innermost one is always caught first, so we can store them as a stack. - pub(crate) unwind_payloads: Vec>, + pub(crate) unwind_payloads: Vec>, - /// Last OS error location in memory. It is a 32-bit integer. - pub(crate) last_error: Option>, + /// A span that explains where the fiber (or more specifically, its current root + /// frame) "comes from". + pub(crate) origin_span: Span, } -pub type StackEmptyCallback<'tcx> = - Box) -> InterpResult<'tcx, Poll<()>> + 'tcx>; - -impl<'tcx> Thread<'tcx> { - /// Get the name of the current thread if it was set. - fn thread_name(&self) -> Option<&[u8]> { - self.thread_name.as_deref() - } - - /// Return whether this thread is enabled or not. - pub fn is_enabled(&self) -> bool { - self.state.is_enabled() - } - - /// Get the name of the current thread for display purposes; will include thread ID if not set. - fn thread_display_name(&self, id: ThreadId) -> String { - if let Some(ref thread_name) = self.thread_name { - String::from_utf8_lossy(thread_name).into_owned() - } else { - format!("unnamed-{}", id.index()) - } - } - +impl<'fcx> Fiber<'fcx> { /// Return the top user-relevant frame, if there is one. `skip` indicates how many top frames /// should be skipped. /// Note that the choice to return `None` here when there is no user-relevant frame is part of @@ -287,6 +289,87 @@ impl<'tcx> Thread<'tcx> { .map(|frame_idx| self.stack[frame_idx].current_span()) .unwrap_or(rustc_span::DUMMY_SP) } + + fn new(on_stack_empty: Option>, id: FiberId) -> Self { + Self { + id, + stack: Vec::new(), + origin_span: DUMMY_SP, + top_user_relevant_frame: None, + unwind_payloads: Vec::new(), + on_stack_empty, + } + } +} + +impl VisitProvenance for Fiber<'_> { + fn visit_provenance(&self, visit: &mut VisitWith<'_>) { + let Fiber { + id: _, + stack, + origin_span: _, + top_user_relevant_frame: _, + unwind_payloads: panic_payload, + on_stack_empty: _, // we assume the closure captures no GC-relevant state + } = self; + + for payload in panic_payload { + payload.visit_provenance(visit); + } + + for frame in stack { + frame.visit_provenance(visit); + } + } +} + +/// A thread. +pub struct Thread<'tcx> { + state: ThreadState<'tcx>, + + /// Name of the thread. + thread_name: Option>, + + /// The fiber this thread currently runs + current_fiber: Fiber<'tcx>, + + /// The join status. + join_status: ThreadJoinStatus, + + /// Last OS error location in memory. It is a 32-bit integer. + pub(crate) last_error: Option>, +} + +pub type StackEmptyCallback<'tcx> = + Box) -> InterpResult<'tcx, Poll<()>> + 'tcx>; + +impl<'tcx> Thread<'tcx> { + /// Get the name of the current thread if it was set. + fn thread_name(&self) -> Option<&[u8]> { + self.thread_name.as_deref() + } + + /// Return whether this thread is enabled or not. + pub fn is_enabled(&self) -> bool { + self.state.is_enabled() + } + + /// Get the name of the current thread for display purposes; will include thread ID if not set. + fn thread_display_name(&self, id: ThreadId) -> String { + if let Some(ref thread_name) = self.thread_name { + String::from_utf8_lossy(thread_name).into_owned() + } else { + format!("unnamed-{}", id.index()) + } + } + + pub fn current_fiber(&self) -> &Fiber<'tcx> { + &self.current_fiber + } + + pub fn current_fiber_mut(&mut self) -> &mut Fiber<'tcx> { + &mut self.current_fiber + } } impl<'tcx> std::fmt::Debug for Thread<'tcx> { @@ -302,42 +385,27 @@ impl<'tcx> std::fmt::Debug for Thread<'tcx> { } impl<'tcx> Thread<'tcx> { - fn new(name: Option<&str>, on_stack_empty: Option>) -> Self { + fn new( + name: Option<&str>, + on_stack_empty: Option>, + fiber_id: FiberId, + ) -> Self { Self { state: ThreadState::Enabled, thread_name: name.map(|name| Vec::from(name.as_bytes())), - stack: Vec::new(), - origin_span: DUMMY_SP, - top_user_relevant_frame: None, + current_fiber: Fiber::new(on_stack_empty, fiber_id), join_status: ThreadJoinStatus::Joinable, - unwind_payloads: Vec::new(), last_error: None, - on_stack_empty, } } } impl VisitProvenance for Thread<'_> { fn visit_provenance(&self, visit: &mut VisitWith<'_>) { - let Thread { - unwind_payloads: panic_payload, - last_error, - stack, - origin_span: _, - top_user_relevant_frame: _, - state: _, - thread_name: _, - join_status: _, - on_stack_empty: _, // we assume the closure captures no GC-relevant state - } = self; + let Thread { current_fiber, last_error, state: _, thread_name: _, join_status: _ } = self; - for payload in panic_payload { - payload.visit_provenance(visit); - } + current_fiber.visit_provenance(visit); last_error.visit_provenance(visit); - for frame in stack { - frame.visit_provenance(visit) - } } } @@ -421,9 +489,30 @@ pub enum TimeoutAnchor { #[derive(Debug, Copy, Clone)] pub struct ThreadNotFound; +#[derive(Debug, Copy, Clone)] +struct FiberIdAllocator { + next_id: u32, +} + +impl FiberIdAllocator { + fn new(_seed: Option) -> Self { + Self { next_id: 0 } + } + + fn alloc(&mut self) -> FiberId { + let id = self.next_id; + self.next_id = self.next_id.checked_add(1).unwrap(); + FiberId::new_unchecked(id) + } + + #[allow(dead_code)] + fn dealloc(&mut self, _id: FiberId) {} +} + /// A set of threads. #[derive(Debug)] pub struct ThreadManager<'tcx> { + fiber_id_allocator: FiberIdAllocator, /// Identifier of the currently active thread. active_thread: ThreadId, /// Threads used in the program. @@ -442,6 +531,7 @@ pub struct ThreadManager<'tcx> { impl VisitProvenance for ThreadManager<'_> { fn visit_provenance(&self, visit: &mut VisitWith<'_>) { let ThreadManager { + fiber_id_allocator: _, threads, thread_local_allocs, active_thread: _, @@ -461,9 +551,11 @@ impl VisitProvenance for ThreadManager<'_> { impl<'tcx> ThreadManager<'tcx> { pub(crate) fn new(config: &MiriConfig) -> Self { let mut threads = IndexVec::new(); + let mut fiber_id_allocator = FiberIdAllocator::new(config.seed); // Create the main thread and add it to the list of threads. - threads.push(Thread::new(Some("main"), None)); + threads.push(Thread::new(Some("main"), None, fiber_id_allocator.alloc())); Self { + fiber_id_allocator, active_thread: ThreadId::MAIN_THREAD, threads, thread_local_allocs: Default::default(), @@ -476,7 +568,7 @@ impl<'tcx> ThreadManager<'tcx> { ecx: &mut MiriInterpCx<'tcx>, on_main_stack_empty: StackEmptyCallback<'tcx>, ) { - ecx.machine.threads.threads[ThreadId::MAIN_THREAD].on_stack_empty = + ecx.machine.threads.threads[ThreadId::MAIN_THREAD].current_fiber.on_stack_empty = Some(on_main_stack_empty); if ecx.tcx.sess.target.os != Os::Windows { // The main thread can *not* be joined on except on windows. @@ -511,14 +603,14 @@ impl<'tcx> ThreadManager<'tcx> { /// Borrow the stack of the active thread. pub fn active_thread_stack(&self) -> &[Frame<'tcx, Provenance, FrameExtra<'tcx>>] { - &self.threads[self.active_thread].stack + &self.threads[self.active_thread].current_fiber.stack } /// Mutably borrow the stack of the active thread. pub fn active_thread_stack_mut( &mut self, ) -> &mut Vec>> { - &mut self.threads[self.active_thread].stack + &mut self.threads[self.active_thread].current_fiber.stack } pub fn all_blocked_stacks( @@ -527,13 +619,14 @@ impl<'tcx> ThreadManager<'tcx> { self.threads .iter_enumerated() .filter(|(_id, t)| matches!(t.state, ThreadState::Blocked { .. })) - .map(|(id, t)| (id, &t.stack[..])) + .map(|(id, t)| (id, &t.current_fiber.stack[..])) } /// Create a new thread and returns its id. fn create_thread(&mut self, on_stack_empty: StackEmptyCallback<'tcx>) -> ThreadId { let new_thread_id = ThreadId::new(self.threads.len()); - self.threads.push(Thread::new(None, Some(on_stack_empty))); + let new_fiber_id = self.fiber_id_allocator.alloc(); + self.threads.push(Thread::new(None, Some(on_stack_empty), new_fiber_id)); new_thread_id } @@ -715,13 +808,14 @@ trait EvalContextPrivExt<'tcx>: MiriInterpCxExt<'tcx> { fn run_on_stack_empty(&mut self) -> InterpResult<'tcx, Poll<()>> { let this = self.eval_context_mut(); let active_thread = this.active_thread_mut(); - active_thread.origin_span = DUMMY_SP; // reset, the old value no longer applied + active_thread.current_fiber.origin_span = DUMMY_SP; // reset, the old value no longer applied let mut callback = active_thread + .current_fiber .on_stack_empty .take() .expect("`on_stack_empty` not set up, or already running"); let res = callback(this)?; - this.active_thread_mut().on_stack_empty = Some(callback); + this.active_thread_mut().current_fiber.on_stack_empty = Some(callback); interp_ok(res) } @@ -968,7 +1062,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { // Mark thread as terminated. let thread = this.active_thread_mut(); - assert!(thread.stack.is_empty(), "only threads with an empty stack can be terminated"); + assert!( + thread.current_fiber.stack.is_empty(), + "only threads with an empty stack can be terminated" + ); thread.state = ThreadState::Terminated; // Deallocate TLS. @@ -1217,6 +1314,18 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { this.machine.threads.active_thread_ref() } + #[inline] + fn active_fiber_ref(&self) -> &Fiber<'tcx> { + let this = self.eval_context_ref(); + this.machine.threads.active_thread_ref().current_fiber() + } + + #[inline] + fn active_fiber_mut(&mut self) -> &mut Fiber<'tcx> { + let this = self.eval_context_mut(); + this.machine.threads.active_thread_mut().current_fiber_mut() + } + #[inline] fn get_total_thread_count(&self) -> usize { let this = self.eval_context_ref(); diff --git a/src/diagnostics.rs b/src/diagnostics.rs index 64c7096fc5..6b85012e9f 100644 --- a/src/diagnostics.rs +++ b/src/diagnostics.rs @@ -554,7 +554,9 @@ fn report_msg<'tcx>( thread: Option, machine: &MiriMachine<'tcx>, ) { - let origin_span = thread.map(|t| machine.threads.thread_ref(t).origin_span).unwrap_or(DUMMY_SP); + let origin_span = thread + .map(|t| machine.threads.thread_ref(t).current_fiber().origin_span) + .unwrap_or(DUMMY_SP); let span = stacktrace.first().map(|fi| fi.span).unwrap_or(origin_span); // The only time we do not have an origin span is for `main`, and there we check the signature // upfront. So we should always have a span here. diff --git a/src/helpers.rs b/src/helpers.rs index f4fc478481..52bf118f67 100644 --- a/src/helpers.rs +++ b/src/helpers.rs @@ -482,8 +482,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { ) -> InterpResult<'tcx> { let this = self.eval_context_mut(); assert!(this.active_thread_stack().is_empty()); - assert!(this.active_thread_ref().origin_span.is_dummy()); - this.active_thread_mut().origin_span = span; + assert!(this.active_fiber_ref().origin_span.is_dummy()); + this.active_fiber_mut().origin_span = span; this.call_function(f, caller_abi, args, dest, ReturnContinuation::Stop { cleanup: true }) } @@ -1068,7 +1068,7 @@ impl<'tcx> MiriMachine<'tcx> { /// This function is backed by a cache, and can be assumed to be very fast. /// It will work even when the stack is empty. pub fn current_user_relevant_span(&self) -> Span { - self.threads.active_thread_ref().current_user_relevant_span() + self.threads.active_thread_ref().current_fiber().current_user_relevant_span() } /// Returns the span of the *caller* of the current operation, again @@ -1089,7 +1089,7 @@ impl<'tcx> MiriMachine<'tcx> { } fn top_user_relevant_frame(&self) -> Option { - self.threads.active_thread_ref().top_user_relevant_frame() + self.threads.active_thread_ref().current_fiber().top_user_relevant_frame() } /// This is the source of truth for the `user_relevance` flag in our `FrameExtra`. diff --git a/src/machine.rs b/src/machine.rs index f17bd5ac43..cc76e9930f 100644 --- a/src/machine.rs +++ b/src/machine.rs @@ -1773,11 +1773,11 @@ impl<'tcx> Machine<'tcx> for MiriMachine<'tcx> { #[inline(always)] fn after_stack_push(ecx: &mut InterpCx<'tcx, Self>) -> InterpResult<'tcx> { - if ecx.frame().extra.user_relevance >= ecx.active_thread_ref().current_user_relevance() { + if ecx.frame().extra.user_relevance >= ecx.active_fiber_ref().current_user_relevance() { // We just pushed a frame that's at least as relevant as the so-far most relevant frame. // That means we are now the most relevant frame. let stack_len = ecx.active_thread_stack().len(); - ecx.active_thread_mut().set_top_user_relevant_frame(stack_len - 1); + ecx.active_fiber_mut().set_top_user_relevant_frame(stack_len - 1); } interp_ok(()) } @@ -1790,7 +1790,7 @@ impl<'tcx> Machine<'tcx> for MiriMachine<'tcx> { ecx.on_stack_pop(frame)?; } if ecx - .active_thread_ref() + .active_fiber_ref() .top_user_relevant_frame() .expect("there should always be a most relevant frame for a non-empty stack") == ecx.frame_idx() @@ -1800,7 +1800,7 @@ impl<'tcx> Machine<'tcx> for MiriMachine<'tcx> { // (If this ever becomes a bottleneck, we could have `push` store the previous // user-relevant frame and restore that here.) // We have to skip the frame that is just being popped. - ecx.active_thread_mut().recompute_top_user_relevant_frame(/* skip */ 1); + ecx.active_fiber_mut().recompute_top_user_relevant_frame(/* skip */ 1); } // tracing-tree can autoamtically annotate scope changes, but it gets very confused by our // concurrency and what it prints is just plain wrong. So we print our own information diff --git a/src/shims/alloc.rs b/src/shims/alloc.rs index b4d53c36d1..ec24370e21 100644 --- a/src/shims/alloc.rs +++ b/src/shims/alloc.rs @@ -32,6 +32,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { | Arch::Mips32r6 | Arch::PowerPC | Arch::PowerPC64 + | Arch::PowerPC64LE | Arch::Sparc | Arch::Wasm32 | Arch::Hexagon diff --git a/src/shims/unwind.rs b/src/shims/unwind.rs index 0dd2b20487..7e6bd8b33e 100644 --- a/src/shims/unwind.rs +++ b/src/shims/unwind.rs @@ -50,8 +50,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { trace!("miri_start_unwind: {:?}", this.frame().instance()); let payload = this.read_immediate(payload)?; - let thread = this.active_thread_mut(); - thread.unwind_payloads.push(payload); + let fiber = this.active_fiber_mut(); + fiber.unwind_payloads.push(payload); interp_ok(()) } @@ -133,7 +133,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { // The Thread's `panic_payload` holds what was passed to `miri_start_unwind`. // This is exactly the second argument we need to pass to `catch_fn`. - let payload = this.active_thread_mut().unwind_payloads.pop().unwrap(); + let payload = this.active_fiber_mut().unwind_payloads.pop().unwrap(); // Push the `catch_fn` stackframe. let f_instance = this.get_ptr_fn(catch_unwind.catch_fn)?.as_instance()?; diff --git a/tests/utils/miri_extern.rs b/tests/utils/miri_extern.rs index e9cde20412..1bec6b61c1 100644 --- a/tests/utils/miri_extern.rs +++ b/tests/utils/miri_extern.rs @@ -171,4 +171,15 @@ extern "Rust" { /// /// As far as Miri is concerned, this is equivalent to `yield_now`. pub fn miri_spin_loop(); + + pub fn miri_fiber_create(body: extern "Rust" fn(*mut ()), data: *mut ()) -> usize; + + pub fn miri_fiber_current() -> usize; + + pub fn miri_fiber_switch(target: usize); + + /// Exit the current fiber to the fiber with the given ID. Acts as + /// `miri_fiber_switch` but indicates that the current fiber has exited + /// and will never be switched to again. + pub fn miri_fiber_exit_to(target: usize) -> !; }