Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved traphandler code #2318

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/api/src/externals/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ impl Function {
// Call the trampoline.
if let Err(error) = unsafe {
wasmer_call_trampoline(
&self.store,
self.store.trap_handler(),
self.exported.vm_function.vmctx,
trampoline,
self.exported.vm_function.address,
Expand Down
4 changes: 2 additions & 2 deletions lib/api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,8 @@ pub mod vm {
//! The vm module re-exports wasmer-vm types.

pub use wasmer_vm::{
Memory, MemoryError, MemoryStyle, Table, TableStyle, VMExtern, VMMemoryDefinition,
VMTableDefinition,
Memory, MemoryError, MemoryStyle, Table, TableStyle, TrapHandler, VMExtern,
VMMemoryDefinition, VMTableDefinition,
};
}

Expand Down
2 changes: 1 addition & 1 deletion lib/api/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ impl Module {
// as some of the Instance elements may have placed in other
// instance tables.
self.artifact
.finish_instantiation(&self.store, &instance_handle)?;
.finish_instantiation(self.store.trap_handler(), &instance_handle)?;

Ok(instance_handle)
}
Expand Down
2 changes: 1 addition & 1 deletion lib/api/src/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ macro_rules! impl_native_traits {
};
unsafe {
wasmer_vm::wasmer_call_trampoline(
&self.store,
self.store.trap_handler(),
self.vmctx(),
trampoline,
self.address(),
Expand Down
37 changes: 13 additions & 24 deletions lib/api/src/store.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use crate::tunables::BaseTunables;
use loupe::MemoryUsage;
use std::any::Any;
use std::fmt;
use std::sync::{Arc, RwLock};
use std::sync::Arc;
#[cfg(all(feature = "compiler", feature = "engine"))]
use wasmer_compiler::CompilerConfig;
use wasmer_engine::{is_wasm_pc, Engine, Tunables};
use wasmer_vm::{init_traps, TrapHandler, TrapHandlerFn};
use wasmer_vm::{init_traps, TrapHandler};

/// The store represents all global state that can be manipulated by
/// WebAssembly programs. It consists of the runtime representation
Expand All @@ -23,7 +22,7 @@ pub struct Store {
engine: Arc<dyn Engine + Send + Sync>,
tunables: Arc<dyn Tunables + Send + Sync>,
#[loupe(skip)]
syrusakbary marked this conversation as resolved.
Show resolved Hide resolved
trap_handler: Arc<RwLock<Option<Box<TrapHandlerFn>>>>,
pub(crate) trap_handler: Arc<Option<Box<TrapHandler<'static>>>>,
}

impl Store {
Expand All @@ -35,10 +34,15 @@ impl Store {
Self::new_with_tunables(engine, BaseTunables::for_target(engine.target()))
}

/// Set the trap handler in this store.
pub fn set_trap_handler(&self, handler: Option<Box<TrapHandlerFn>>) {
let mut m = self.trap_handler.write().unwrap();
*m = handler;
/// Sets a [`TrapHandler`] for the `Store`
pub fn set_trap_handler(&mut self, handler: Option<Box<TrapHandler<'static>>>) {
self.trap_handler = Arc::new(handler);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I like this, but lets make sure the new semantics are what we intended. In the old code, calling set_trap_handler would globally update the trap handler in all Stores. In the new code it only updates the trap handler for this specific store (and Stores Cloned from it). All the old trap handlers will continue to work here because they're kept alive by the Arc. (This is actually a great example of how using the raw pointer could have caused problems, the moment set is called again, the trap handler being used could have been invalid.)

If the move from Store's having 1 trap handler to Stores having N trap handlers in a tree-like structure is intentional then 👍 , if not then we probably need to rethink some of this.

And if this is intentional we may want to call it out more in the doc comments.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also side note but if you want to remove the extra Box in Arc<Option<Box<TrapHandler<'static>>> let me know, I think it should be possible to do so in many cases.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with the first comment :-).

}

/// Retrieves the [`TrapHandler`] pointer if is already set for the `Store`.
pub fn trap_handler(&self) -> Option<*const TrapHandler<'static>> {
syrusakbary marked this conversation as resolved.
Show resolved Hide resolved
let trap_handler = (*self.trap_handler).as_ref()?;
Some(&**trap_handler as *const _)
}

/// Creates a new `Store` with a specific [`Engine`] and [`Tunables`].
Expand All @@ -53,7 +57,7 @@ impl Store {
Self {
engine: engine.cloned(),
tunables: Arc::new(tunables),
trap_handler: Arc::new(RwLock::new(None)),
trap_handler: Arc::new(None),
}
}

Expand Down Expand Up @@ -81,21 +85,6 @@ impl PartialEq for Store {
}
}

unsafe impl TrapHandler for Store {
#[inline]
fn as_any(&self) -> &dyn Any {
self
}

fn custom_trap_handler(&self, call: &dyn Fn(&TrapHandlerFn) -> bool) -> bool {
if let Some(handler) = *&self.trap_handler.read().unwrap().as_ref() {
call(handler)
} else {
false
}
}
}

// This is required to be able to set the trap_handler in the
// Store.
unsafe impl Send for Store {}
Expand Down
2 changes: 1 addition & 1 deletion lib/engine/src/artifact.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ pub trait Artifact: Send + Sync + Upcastable + MemoryUsage {
/// See [`InstanceHandle::finish_instantiation`].
unsafe fn finish_instantiation(
&self,
trap_handler: &dyn TrapHandler,
trap_handler: Option<*const TrapHandler<'static>>,
syrusakbary marked this conversation as resolved.
Show resolved Hide resolved
handle: &InstanceHandle,
) -> Result<(), InstantiationError> {
let data_initializers = self
Expand Down
7 changes: 6 additions & 1 deletion lib/object/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,12 @@ pub fn emit_compilation(
section: SymbolSection::Section(section_id),
flags: SymbolFlags::None,
});
obj.add_symbol_data(symbol_id, section_id, custom_section.bytes.as_slice(), align);
obj.add_symbol_data(
symbol_id,
section_id,
custom_section.bytes.as_slice(),
align,
);
(section_id, symbol_id)
}
})
Expand Down
7 changes: 5 additions & 2 deletions lib/vm/src/instance/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,10 @@ impl Instance {
}

/// Invoke the WebAssembly start function of the instance, if one is present.
fn invoke_start_function(&self, trap_handler: &dyn TrapHandler) -> Result<(), Trap> {
fn invoke_start_function(
&self,
trap_handler: Option<*const TrapHandler<'static>>,
) -> Result<(), Trap> {
let start_index = match self.module.start_function {
Some(idx) => idx,
None => return Ok(()),
Expand Down Expand Up @@ -1015,7 +1018,7 @@ impl InstanceHandle {
/// Only safe to call immediately after instantiation.
pub unsafe fn finish_instantiation(
&self,
trap_handler: &dyn TrapHandler,
trap_handler: Option<*const TrapHandler<'static>>,
data_initializers: &[DataInitializer<'_>],
) -> Result<(), Trap> {
let instance = self.instance().as_ref();
Expand Down
2 changes: 1 addition & 1 deletion lib/vm/src/trap/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ mod traphandlers;
pub use trapcode::TrapCode;
pub use traphandlers::{
catch_traps, catch_traps_with_result, raise_lib_trap, raise_user_trap, wasmer_call_trampoline,
TlsRestore, Trap, TrapHandler, TrapHandlerFn,
TlsRestore, Trap, TrapHandler,
};
pub use traphandlers::{init_traps, resume_panic};
65 changes: 25 additions & 40 deletions lib/vm/src/trap/traphandlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ pub use tls::TlsRestore;
cfg_if::cfg_if! {
if #[cfg(unix)] {
/// Function which may handle custom signals while processing traps.
pub type TrapHandlerFn = dyn Fn(libc::c_int, *const libc::siginfo_t, *const libc::c_void) -> bool;
pub type TrapHandler<'a> = dyn Fn(libc::c_int, *const libc::siginfo_t, *const libc::c_void) -> bool + Send + Sync + 'a;
} else if #[cfg(target_os = "windows")] {
/// Function which may handle custom signals while processing traps.
pub type TrapHandlerFn = dyn Fn(winapi::um::winnt::PEXCEPTION_POINTERS) -> bool;
pub type TrapHandler<'a> = dyn Fn(winapi::um::winnt::PEXCEPTION_POINTERS) -> bool + Send + Sync + 'a;
}
}

Expand Down Expand Up @@ -525,7 +525,7 @@ impl Trap {
/// Wildly unsafe because it calls raw function pointers and reads/writes raw
/// function pointers.
pub unsafe fn wasmer_call_trampoline(
trap_handler: &impl TrapHandler,
trap_handler: Option<*const TrapHandler<'static>>,
vmctx: VMFunctionEnvironment,
trampoline: VMTrampoline,
callee: *const VMFunctionBody,
Expand All @@ -542,7 +542,10 @@ pub unsafe fn wasmer_call_trampoline(
/// returning them as a `Result`.
///
/// Highly unsafe since `closure` won't have any dtors run.
pub unsafe fn catch_traps<F>(trap_handler: &dyn TrapHandler, mut closure: F) -> Result<(), Trap>
pub unsafe fn catch_traps<F>(
trap_handler: Option<*const TrapHandler<'static>>,
mut closure: F,
) -> Result<(), Trap>
where
F: FnMut(),
{
Expand Down Expand Up @@ -572,7 +575,7 @@ where
///
/// Check [`catch_traps`].
pub unsafe fn catch_traps_with_result<F, R>(
trap_handler: &dyn TrapHandler,
trap_handler: Option<*const TrapHandler<'static>>,
mut closure: F,
) -> Result<R, Trap>
where
Expand All @@ -587,31 +590,15 @@ where

/// Temporary state stored on the stack which is registered in the `tls` module
/// below for calls into wasm.
pub struct CallThreadState<'a> {
pub struct CallThreadState {
unwind: UnsafeCell<MaybeUninit<UnwindReason>>,
trap_handler: Option<*const TrapHandler<'static>>,
jmp_buf: Cell<*const u8>,
reset_guard_page: Cell<bool>,
prev: Cell<tls::Ptr>,
trap_handler: &'a (dyn TrapHandler + 'a),
handling_trap: Cell<bool>,
}

/// A package of functionality needed by `catch_traps` to figure out what to do
/// when handling a trap.
///
/// Note that this is an `unsafe` trait at least because it's being run in the
/// context of a synchronous signal handler, so it needs to be careful to not
/// access too much state in answering these queries.
pub unsafe trait TrapHandler {
/// Converts this object into an `Any` to dynamically check its type.
fn as_any(&self) -> &dyn Any;

/// Uses `call` to call a custom signal handler, if one is specified.
///
/// Returns `true` if `call` returns true, otherwise returns `false`.
fn custom_trap_handler(&self, call: &dyn Fn(&TrapHandlerFn) -> bool) -> bool;
}

enum UnwindReason {
/// A panic caused by the host
Panic(Box<dyn Any + Send>),
Expand All @@ -627,9 +614,9 @@ enum UnwindReason {
},
}

impl<'a> CallThreadState<'a> {
impl CallThreadState {
#[inline]
fn new(trap_handler: &'a (dyn TrapHandler + 'a)) -> CallThreadState<'a> {
fn new(trap_handler: Option<*const TrapHandler<'static>>) -> CallThreadState {
Self {
unwind: UnsafeCell::new(MaybeUninit::uninit()),
jmp_buf: Cell::new(ptr::null()),
Expand Down Expand Up @@ -690,7 +677,7 @@ impl<'a> CallThreadState<'a> {
pc: *const u8,
reset_guard_page: bool,
signal_trap: Option<TrapCode>,
call_handler: impl Fn(&TrapHandlerFn) -> bool,
call_handler: impl Fn(&TrapHandler) -> bool,
) -> *const u8 {
// If we hit a fault while handling a previous trap, that's quite bad,
// so bail out and let the system handle this recursive segfault.
Expand All @@ -704,8 +691,10 @@ impl<'a> CallThreadState<'a> {
// First up see if we have a custom trap handler,
// in which case run it. If anything handles the trap then we
// return that the trap was handled.
if self.trap_handler.custom_trap_handler(&call_handler) {
return 1 as *const _;
if let Some(handler) = self.trap_handler {
if unsafe { call_handler(&*handler) } {
return 1 as *const _;
}
}

// TODO: stack overflow can happen at any random time (i.e. in malloc()
Expand Down Expand Up @@ -735,7 +724,7 @@ impl<'a> CallThreadState<'a> {
}
}

impl<'a> Drop for CallThreadState<'a> {
impl Drop for CallThreadState {
fn drop(&mut self) {
if self.reset_guard_page.get() {
reset_guard_page();
Expand All @@ -751,7 +740,6 @@ impl<'a> Drop for CallThreadState<'a> {
mod tls {
use super::CallThreadState;
use crate::Trap;
use std::mem;
use std::ptr;

pub use raw::Ptr;
Expand All @@ -773,7 +761,7 @@ mod tls {
use std::cell::Cell;
use std::ptr;

pub type Ptr = *const CallThreadState<'static>;
pub type Ptr = *const CallThreadState;

// The first entry here is the `Ptr` which is what's used as part of the
// public interface of this module. The second entry is a boolean which
Expand Down Expand Up @@ -849,30 +837,27 @@ mod tls {
/// Configures thread local state such that for the duration of the
/// execution of `closure` any call to `with` will yield `ptr`, unless this
/// is recursively called again.
pub fn set<R>(state: &CallThreadState<'_>, closure: impl FnOnce() -> R) -> Result<R, Trap> {
struct Reset<'a, 'b>(&'a CallThreadState<'b>);
#[inline]
pub fn set<R>(state: &CallThreadState, closure: impl FnOnce() -> R) -> Result<R, Trap> {
struct Reset<'a>(&'a CallThreadState);

impl Drop for Reset<'_, '_> {
impl Drop for Reset<'_> {
#[inline]
fn drop(&mut self) {
raw::replace(self.0.prev.replace(ptr::null()))
.expect("tls should be previously initialized");
}
}

// Note that this extension of the lifetime to `'static` should be
// safe because we only ever access it below with an anonymous
// lifetime, meaning `'static` never leaks out of this module.
let ptr = unsafe { mem::transmute::<*const CallThreadState<'_>, _>(state) };
let prev = raw::replace(ptr)?;
let prev = raw::replace(state)?;
state.prev.set(prev);
let _reset = Reset(state);
Ok(closure())
}

/// Returns the last pointer configured with `set` above. Panics if `set`
/// has not been previously called and not returned.
pub fn with<R>(closure: impl FnOnce(Option<&CallThreadState<'_>>) -> R) -> R {
pub fn with<R>(closure: impl FnOnce(Option<&CallThreadState>) -> R) -> R {
let p = raw::get();
unsafe { closure(if p.is_null() { None } else { Some(&*p) }) }
}
Expand Down