diff --git a/src/concurrency/thread.rs b/src/concurrency/thread.rs index 9cf301b78d..59e2fdd428 100644 --- a/src/concurrency/thread.rs +++ b/src/concurrency/thread.rs @@ -113,6 +113,11 @@ impl ThreadId { self.0 } + /// Create a new thread id from a `u32` without checking if this thread exists. + pub fn new_unchecked(id: u32) -> Self { + Self(id) + } + pub const MAIN_THREAD: ThreadId = ThreadId(0); } diff --git a/src/shims/windows/foreign_items.rs b/src/shims/windows/foreign_items.rs index 504efed3cf..c145cf3ceb 100644 --- a/src/shims/windows/foreign_items.rs +++ b/src/shims/windows/foreign_items.rs @@ -7,6 +7,7 @@ use rustc_span::Symbol; use self::shims::windows::handle::{Handle, PseudoHandle}; use crate::shims::os_str::bytes_to_os_str; +use crate::shims::windows::handle::HandleError; use crate::shims::windows::*; use crate::*; @@ -488,7 +489,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { let thread_id = this.CreateThread(security, stacksize, start, arg, flags, thread)?; - this.write_scalar(Handle::Thread(thread_id.to_u32()).to_scalar(this), dest)?; + this.write_scalar(Handle::Thread(thread_id).to_scalar(this), dest)?; } "WaitForSingleObject" => { let [handle, timeout] = @@ -513,10 +514,12 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { let handle = this.read_scalar(handle)?; let name = this.read_wide_str(this.read_pointer(name)?)?; - let thread = match Handle::from_scalar(handle, this)? { - Some(Handle::Thread(thread)) => this.thread_id_try_from(thread), - Some(Handle::Pseudo(PseudoHandle::CurrentThread)) => Ok(this.active_thread()), - _ => this.invalid_handle("SetThreadDescription")?, + let thread = match Handle::try_from_scalar(handle, this)? { + Ok(Handle::Thread(thread)) => Ok(thread), + Ok(Handle::Pseudo(PseudoHandle::CurrentThread)) => Ok(this.active_thread()), + Ok(_) | Err(HandleError::InvalidHandle) => + this.invalid_handle("SetThreadDescription")?, + Err(HandleError::ThreadNotFound(e)) => Err(e), }; let res = match thread { Ok(thread) => { @@ -536,10 +539,12 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { let handle = this.read_scalar(handle)?; let name_ptr = this.deref_pointer(name_ptr)?; // the pointer where we should store the ptr to the name - let thread = match Handle::from_scalar(handle, this)? { - Some(Handle::Thread(thread)) => this.thread_id_try_from(thread), - Some(Handle::Pseudo(PseudoHandle::CurrentThread)) => Ok(this.active_thread()), - _ => this.invalid_handle("GetThreadDescription")?, + let thread = match Handle::try_from_scalar(handle, this)? { + Ok(Handle::Thread(thread)) => Ok(thread), + Ok(Handle::Pseudo(PseudoHandle::CurrentThread)) => Ok(this.active_thread()), + Ok(_) | Err(HandleError::InvalidHandle) => + this.invalid_handle("GetThreadDescription")?, + Err(HandleError::ThreadNotFound(e)) => Err(e), }; let (name, res) = match thread { Ok(thread) => { diff --git a/src/shims/windows/handle.rs b/src/shims/windows/handle.rs index b40c00efed..3d872b65a6 100644 --- a/src/shims/windows/handle.rs +++ b/src/shims/windows/handle.rs @@ -2,6 +2,7 @@ use std::mem::variant_count; use rustc_abi::HasDataLayout; +use crate::concurrency::thread::ThreadNotFound; use crate::*; #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] @@ -14,7 +15,7 @@ pub enum PseudoHandle { pub enum Handle { Null, Pseudo(PseudoHandle), - Thread(u32), + Thread(ThreadId), } impl PseudoHandle { @@ -34,6 +35,14 @@ impl PseudoHandle { } } +/// Errors that can occur when constructing a [`Handle`] from a Scalar. +pub enum HandleError { + /// There is no thread with the given ID. + ThreadNotFound(ThreadNotFound), + /// Can't convert scalar to handle because it is structurally invalid. + InvalidHandle, +} + impl Handle { const NULL_DISCRIMINANT: u32 = 0; const PSEUDO_DISCRIMINANT: u32 = 1; @@ -51,7 +60,7 @@ impl Handle { match self { Self::Null => 0, Self::Pseudo(pseudo_handle) => pseudo_handle.value(), - Self::Thread(thread) => thread, + Self::Thread(thread) => thread.to_u32(), } } @@ -95,7 +104,7 @@ impl Handle { match discriminant { Self::NULL_DISCRIMINANT if data == 0 => Some(Self::Null), Self::PSEUDO_DISCRIMINANT => Some(Self::Pseudo(PseudoHandle::from_value(data)?)), - Self::THREAD_DISCRIMINANT => Some(Self::Thread(data)), + Self::THREAD_DISCRIMINANT => Some(Self::Thread(ThreadId::new_unchecked(data))), _ => None, } } @@ -126,10 +135,14 @@ impl Handle { Scalar::from_target_isize(signed_handle.into(), cx) } - pub fn from_scalar<'tcx>( + /// Convert a scalar into a structured `Handle`. + /// Structurally invalid handles return [`HandleError::InvalidHandle`]. + /// If the handle is structurally valid but semantically invalid, e.g. a for non-existent thread + /// ID, returns [`HandleError::ThreadNotFound`]. + pub fn try_from_scalar<'tcx>( handle: Scalar, - cx: &impl HasDataLayout, - ) -> InterpResult<'tcx, Option> { + cx: &MiriInterpCx<'tcx>, + ) -> InterpResult<'tcx, Result> { let sign_extended_handle = handle.to_target_isize(cx)?; #[expect(clippy::cast_sign_loss)] // we want to lose the sign @@ -137,10 +150,20 @@ impl Handle { signed_handle as u32 } else { // if a handle doesn't fit in an i32, it isn't valid. - return interp_ok(None); + return interp_ok(Err(HandleError::InvalidHandle)); }; - interp_ok(Self::from_packed(handle)) + match Self::from_packed(handle) { + Some(Self::Thread(thread)) => { + // validate the thread id + match cx.machine.threads.thread_id_try_from(thread.to_u32()) { + Ok(id) => interp_ok(Ok(Self::Thread(id))), + Err(e) => interp_ok(Err(HandleError::ThreadNotFound(e))), + } + } + Some(handle) => interp_ok(Ok(handle)), + None => interp_ok(Err(HandleError::InvalidHandle)), + } } } @@ -158,14 +181,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { let this = self.eval_context_mut(); let handle = this.read_scalar(handle_op)?; - let ret = match Handle::from_scalar(handle, this)? { - Some(Handle::Thread(thread)) => { - if let Ok(thread) = this.thread_id_try_from(thread) { - this.detach_thread(thread, /*allow_terminated_joined*/ true)?; - this.eval_windows("c", "TRUE") - } else { - this.invalid_handle("CloseHandle")? - } + let ret = match Handle::try_from_scalar(handle, this)? { + Ok(Handle::Thread(thread)) => { + this.detach_thread(thread, /*allow_terminated_joined*/ true)?; + this.eval_windows("c", "TRUE") } _ => this.invalid_handle("CloseHandle")?, }; diff --git a/src/shims/windows/thread.rs b/src/shims/windows/thread.rs index 7af15fc647..efc1c2286b 100644 --- a/src/shims/windows/thread.rs +++ b/src/shims/windows/thread.rs @@ -65,15 +65,11 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> { let handle = this.read_scalar(handle_op)?; let timeout = this.read_scalar(timeout_op)?.to_u32()?; - let thread = match Handle::from_scalar(handle, this)? { - Some(Handle::Thread(thread)) => - match this.thread_id_try_from(thread) { - Ok(thread) => thread, - Err(_) => this.invalid_handle("WaitForSingleObject")?, - }, + let thread = match Handle::try_from_scalar(handle, this)? { + Ok(Handle::Thread(thread)) => thread, // Unlike on posix, the outcome of joining the current thread is not documented. // On current Windows, it just deadlocks. - Some(Handle::Pseudo(PseudoHandle::CurrentThread)) => this.active_thread(), + Ok(Handle::Pseudo(PseudoHandle::CurrentThread)) => this.active_thread(), _ => this.invalid_handle("WaitForSingleObject")?, };