Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions crates/api/src/callable.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::runtime::Store;
use crate::trampoline::{generate_func_export, take_api_trap};
use crate::trampoline::generate_func_export;
use crate::trap::Trap;
use crate::types::FuncType;
use crate::values::Val;
Expand Down Expand Up @@ -157,8 +157,7 @@ impl WrappedCallable for WasmtimeFn {
)
})
} {
let trap = take_api_trap().unwrap_or_else(|| Trap::from_jit(error));
return Err(trap);
return Err(Trap::from_jit(error));
}

// Load the return values out of `values_vec`.
Expand Down
10 changes: 3 additions & 7 deletions crates/api/src/instance.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::externals::Extern;
use crate::module::Module;
use crate::runtime::Store;
use crate::trampoline::take_api_trap;
use crate::trap::Trap;
use crate::types::{ExportType, ExternType};
use anyhow::{Error, Result};
Expand Down Expand Up @@ -29,12 +28,9 @@ fn instantiate(
let instance = compiled_module
.instantiate(&mut resolver)
.map_err(|e| -> Error {
if let Some(trap) = take_api_trap() {
trap.into()
} else if let InstantiationError::StartTrap(trap) = e {
Trap::from_jit(trap).into()
} else {
e.into()
match e {
InstantiationError::StartTrap(trap) => Trap::from_jit(trap).into(),
other => other.into(),
}
})?;
Ok(instance)
Expand Down
107 changes: 65 additions & 42 deletions crates/api/src/trampoline/func.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
//! Support for a calling of an imported function.

use super::create_handle::create_handle;
use super::trap::{record_api_trap, TrapSink, API_TRAP_CODE};
use crate::{Callable, FuncType, Store, Val};
use super::trap::TrapSink;
use crate::{Callable, FuncType, Store, Trap, Val};
use anyhow::{bail, Result};
use std::cmp;
use std::convert::TryFrom;
use std::panic::{self, AssertUnwindSafe};
use std::rc::Rc;
use wasmtime_environ::entity::{EntityRef, PrimaryMap};
use wasmtime_environ::ir::types;
Expand Down Expand Up @@ -69,42 +70,70 @@ unsafe extern "C" fn stub_fn(
_caller_vmctx: *mut VMContext,
call_id: u32,
values_vec: *mut i128,
) -> u32 {
let instance = InstanceHandle::from_vmctx(vmctx);

let (args, returns_len) = {
let module = instance.module_ref();
let signature = &module.signatures[module.functions[FuncIndex::new(call_id as usize)]];

let mut args = Vec::new();
for i in 2..signature.params.len() {
args.push(Val::read_value_from(
values_vec.offset(i as isize - 2),
signature.params[i].value_type,
))
}
(args, signature.returns.len())
};
) {
// Here we are careful to use `catch_unwind` to ensure Rust panics don't
// unwind past us. The primary reason for this is that Rust considers it UB
// to unwind past an `extern "C"` function. Here we are in an `extern "C"`
// function and the cross into wasm was through an `extern "C"` function at
// the base of the stack as well. We'll need to wait for assorted RFCs and
// language features to enable this to be done in a sound and stable fashion
// before avoiding catching the panic here.
//
// Also note that there are intentionally no local variables on this stack
// frame. The reason for that is that some of the "raise" functions we have
// below will trigger a longjmp, which won't run local destructors if we
// have any. To prevent leaks we avoid having any local destructors by
// avoiding local variables.
let result = panic::catch_unwind(AssertUnwindSafe(|| call_stub(vmctx, call_id, values_vec)));

match result {
Ok(Ok(())) => {}

// If a trap was raised (an error returned from the imported function)
// then we smuggle the trap through `Box<dyn Error>` through to the
// call-site, which gets unwrapped in `Trap::from_jit` later on as we
// convert from the internal `Trap` type to our own `Trap` type in this
// crate.
Ok(Err(trap)) => wasmtime_runtime::raise_user_trap(Box::new(trap)),

// And finally if the imported function panicked, then we trigger the
// form of unwinding that's safe to jump over wasm code on all
// platforms.
Err(panic) => wasmtime_runtime::resume_panic(panic),
}

let mut returns = vec![Val::null(); returns_len];
let func = &instance
.host_state()
.downcast_ref::<TrampolineState>()
.expect("state")
.func;

match func.call(&args, &mut returns) {
Ok(()) => {
for (i, r#return) in returns.iter_mut().enumerate() {
// TODO check signature.returns[i].value_type ?
r#return.write_value_to(values_vec.add(i));
unsafe fn call_stub(
vmctx: *mut VMContext,
call_id: u32,
values_vec: *mut i128,
) -> Result<(), Trap> {
let instance = InstanceHandle::from_vmctx(vmctx);

let (args, returns_len) = {
let module = instance.module_ref();
let signature = &module.signatures[module.functions[FuncIndex::new(call_id as usize)]];

let mut args = Vec::new();
for i in 2..signature.params.len() {
args.push(Val::read_value_from(
values_vec.offset(i as isize - 2),
signature.params[i].value_type,
))
}
0
}
Err(trap) => {
record_api_trap(trap);
1
(args, signature.returns.len())
};

let mut returns = vec![Val::null(); returns_len];
let state = &instance
.host_state()
.downcast_ref::<TrampolineState>()
.expect("state");
state.func.call(&args, &mut returns)?;
for (i, ret) in returns.iter_mut().enumerate() {
// TODO check signature.returns[i].value_type ?
ret.write_value_to(values_vec.add(i));
}
Ok(())
}
}

Expand Down Expand Up @@ -136,9 +165,6 @@ fn make_trampoline(
// Add the `values_vec` parameter.
stub_sig.params.push(ir::AbiParam::new(pointer_type));

// Add error/trap return.
stub_sig.returns.push(ir::AbiParam::new(types::I32));

// Compute the size of the values vector. The vmctx and caller vmctx are passed separately.
let value_size = 16;
let values_vec_len = ((value_size as usize)
Expand Down Expand Up @@ -195,13 +221,10 @@ fn make_trampoline(
let callee_value = builder
.ins()
.iconst(pointer_type, stub_fn as *const VMFunctionBody as i64);
let call = builder
builder
.ins()
.call_indirect(new_sig, callee_value, &callee_args);

let call_result = builder.func.dfg.inst_results(call)[0];
builder.ins().trapnz(call_result, API_TRAP_CODE);

let mflags = MemFlags::trusted();
let mut results = Vec::new();
for (i, r) in signature.returns.iter().enumerate() {
Expand Down
1 change: 0 additions & 1 deletion crates/api/src/trampoline/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ use anyhow::Result;
use std::rc::Rc;

pub use self::global::GlobalState;
pub use self::trap::take_api_trap;

pub fn generate_func_export(
ft: &FuncType,
Expand Down
25 changes: 0 additions & 25 deletions crates/api/src/trampoline/trap.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,7 @@
use std::cell::Cell;

use crate::Trap;
use wasmtime_environ::ir::{SourceLoc, TrapCode};
use wasmtime_environ::TrapInformation;
use wasmtime_jit::trampoline::binemit;

// Randomly selected user TrapCode magic number 13.
pub const API_TRAP_CODE: TrapCode = TrapCode::User(13);

thread_local! {
static RECORDED_API_TRAP: Cell<Option<Trap>> = Cell::new(None);
}

pub fn record_api_trap(trap: Trap) {
RECORDED_API_TRAP.with(|data| {
let trap = Cell::new(Some(trap));
data.swap(&trap);
assert!(
trap.take().is_none(),
"Only one API trap per thread can be recorded at a moment!"
);
});
}

pub fn take_api_trap() -> Option<Trap> {
RECORDED_API_TRAP.with(|data| data.take())
}

pub(crate) struct TrapSink {
pub traps: Vec<TrapInformation>,
}
Expand Down
19 changes: 18 additions & 1 deletion crates/api/src/trap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,24 @@ impl Trap {
}

pub(crate) fn from_jit(jit: wasmtime_runtime::Trap) -> Self {
Trap::new_with_trace(jit.to_string(), jit.backtrace)
match jit {
wasmtime_runtime::Trap::User(error) => {
// Since we're the only one using the wasmtime internals (in
// theory) we should only see user errors which were originally
// created from our own `Trap` type (see the trampoline module
// with functions).
//
// If this unwrap trips for someone we'll need to tweak the
// return type of this function to probably be `anyhow::Error`
// or something like that.
*error
.downcast()
.expect("only `Trap` user errors are supported")
}
wasmtime_runtime::Trap::Wasm { desc, backtrace } => {
Trap::new_with_trace(desc.to_string(), backtrace)
}
}
}

fn new_with_trace(message: String, native_trace: Backtrace) -> Self {
Expand Down
93 changes: 93 additions & 0 deletions crates/api/tests/traps.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use anyhow::Result;
use std::panic::{self, AssertUnwindSafe};
use std::rc::Rc;
use wasmtime::*;

Expand Down Expand Up @@ -215,3 +216,95 @@ wasm backtrace:
);
Ok(())
}

#[test]
fn trap_start_function_import() -> Result<()> {
struct ReturnTrap;

impl Callable for ReturnTrap {
fn call(&self, _params: &[Val], _results: &mut [Val]) -> Result<(), Trap> {
Err(Trap::new("user trap"))
}
}

let store = Store::default();
let binary = wat::parse_str(
r#"
(module $a
(import "" "" (func $foo))
(start $foo)
)
"#,
)?;

let module = Module::new(&store, &binary)?;
let sig = FuncType::new(Box::new([]), Box::new([]));
let func = Func::new(&store, sig, Rc::new(ReturnTrap));
let err = Instance::new(&module, &[func.into()]).err().unwrap();
assert_eq!(err.downcast_ref::<Trap>().unwrap().message(), "user trap");
Ok(())
}

#[test]
fn rust_panic_import() -> Result<()> {
struct Panic;

impl Callable for Panic {
fn call(&self, _params: &[Val], _results: &mut [Val]) -> Result<(), Trap> {
panic!("this is a panic");
}
}

let store = Store::default();
let binary = wat::parse_str(
r#"
(module $a
(import "" "" (func $foo))
(func (export "foo") call $foo)
)
"#,
)?;

let module = Module::new(&store, &binary)?;
let sig = FuncType::new(Box::new([]), Box::new([]));
let func = Func::new(&store, sig, Rc::new(Panic));
let instance = Instance::new(&module, &[func.into()])?;
let func = instance.exports()[0].func().unwrap().clone();
let err = panic::catch_unwind(AssertUnwindSafe(|| {
drop(func.call(&[]));
}))
.unwrap_err();
assert_eq!(err.downcast_ref::<&'static str>(), Some(&"this is a panic"));
Ok(())
}

#[test]
fn rust_panic_start_function() -> Result<()> {
struct Panic;

impl Callable for Panic {
fn call(&self, _params: &[Val], _results: &mut [Val]) -> Result<(), Trap> {
panic!("this is a panic");
}
}

let store = Store::default();
let binary = wat::parse_str(
r#"
(module $a
(import "" "" (func $foo))
(start $foo)
)
"#,
)?;

let module = Module::new(&store, &binary)?;
let sig = FuncType::new(Box::new([]), Box::new([]));
let func = Func::new(&store, sig, Rc::new(Panic));
let err = panic::catch_unwind(AssertUnwindSafe(|| {
drop(Instance::new(&module, &[func.into()]));
}))
.unwrap_err();
assert_eq!(err.downcast_ref::<&'static str>(), Some(&"this is a panic"));
Ok(())
}
26 changes: 23 additions & 3 deletions crates/c-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
// TODO complete the C API

use std::cell::RefCell;
use std::panic::{self, AssertUnwindSafe};
use std::rc::Rc;
use std::{mem, ptr, slice};
use wasmtime::{
Expand Down Expand Up @@ -488,15 +489,34 @@ pub unsafe extern "C" fn wasm_func_call(
let val = &(*args.add(i));
params.push(val.val());
}
match func.call(&params) {
Ok(out) => {

// We're calling arbitrary code here most of the time, and we in general
// want to try to insulate callers against bugs in wasmtime/wasi/etc if we
// can. As a result we catch panics here and transform them to traps to
// allow the caller to have any insulation possible against Rust panics.
let result = panic::catch_unwind(AssertUnwindSafe(|| func.call(&params)));
match result {
Ok(Ok(out)) => {
for i in 0..func.result_arity() {
let val = &mut (*results.add(i));
*val = wasm_val_t::from_val(&out[i]);
}
ptr::null_mut()
}
Err(trap) => {
Ok(Err(trap)) => {
let trap = Box::new(wasm_trap_t {
trap: HostRef::new(trap),
});
Box::into_raw(trap)
}
Err(panic) => {
let trap = if let Some(msg) = panic.downcast_ref::<String>() {
Trap::new(msg)
} else if let Some(msg) = panic.downcast_ref::<&'static str>() {
Trap::new(*msg)
} else {
Trap::new("rust panic happened")
};
let trap = Box::new(wasm_trap_t {
trap: HostRef::new(trap),
});
Expand Down
Loading