Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(runtime-core) Host function without a vm::Ctx argument #917

Merged
25 changes: 25 additions & 0 deletions lib/runtime-core-tests/tests/imports.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,19 @@ fn imported_functions_forms() {
(module
(type $type (func (param i32) (result i32)))
(import "env" "memory" (memory 1 1))
(import "env" "callback_fn" (func $callback_fn (type $type)))
(import "env" "callback_fn_with_vmctx" (func $callback_fn_with_vmctx (type $type)))
(import "env" "callback_fn_trap" (func $callback_fn_trap (type $type)))
(import "env" "callback_fn_trap_with_vmctx" (func $callback_fn_trap_with_vmctx (type $type)))
(func (export "function_fn") (type $type)
get_local 0
call $callback_fn)
(func (export "function_fn_with_vmctx") (type $type)
get_local 0
call $callback_fn_with_vmctx)
(func (export "function_fn_trap") (type $type)
get_local 0
call $callback_fn_trap)
(func (export "function_fn_trap_with_vmctx") (type $type)
get_local 0
call $callback_fn_trap_with_vmctx))
Expand All @@ -31,7 +39,9 @@ fn imported_functions_forms() {
let import_object = imports! {
"env" => {
"memory" => memory.clone(),
"callback_fn" => Func::new(callback_fn),
"callback_fn_with_vmctx" => Func::new(callback_fn_with_vmctx),
"callback_fn_trap" => Func::new(callback_fn_trap),
"callback_fn_trap_with_vmctx" => Func::new(callback_fn_trap_with_vmctx),
},
};
Expand Down Expand Up @@ -88,7 +98,14 @@ fn imported_functions_forms() {
};
}

call_and_assert!(function_fn, Ok(2));
call_and_assert!(function_fn_with_vmctx, Ok(2 + SHIFT));
call_and_assert!(
function_fn_trap,
Err(RuntimeError::Error {
data: Box::new(format!("foo {}", 1))
})
);
call_and_assert!(
function_fn_trap_with_vmctx,
Err(RuntimeError::Error {
Expand All @@ -97,13 +114,21 @@ fn imported_functions_forms() {
);
}

fn callback_fn(n: i32) -> Result<i32, ()> {
Ok(n + 1)
}

fn callback_fn_with_vmctx(vmctx: &mut vm::Ctx, n: i32) -> Result<i32, ()> {
let memory = vmctx.memory(0);
let shift: i32 = memory.view()[0].get();

Ok(shift + n + 1)
}

fn callback_fn_trap(n: i32) -> Result<i32, String> {
Err(format!("foo {}", n))
}

fn callback_fn_trap_with_vmctx(vmctx: &mut vm::Ctx, n: i32) -> Result<i32, String> {
let memory = vmctx.memory(0);
let shift: i32 = memory.view()[0].get();
Expand Down
182 changes: 161 additions & 21 deletions lib/runtime-core/src/typed_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{
export::{Context, Export, FuncPointer},
import::IsExport,
types::{FuncSig, NativeWasmType, Type, WasmExternType},
vm::{self, Ctx},
vm,
};
use std::{
any::Any,
Expand Down Expand Up @@ -52,10 +52,10 @@ impl fmt::Display for WasmTrapInfo {
/// of the `Func` struct.
pub trait Kind {}

pub type Trampoline = unsafe extern "C" fn(*mut Ctx, NonNull<vm::Func>, *const u64, *mut u64);
pub type Trampoline = unsafe extern "C" fn(*mut vm::Ctx, NonNull<vm::Func>, *const u64, *mut u64);
pub type Invoke = unsafe extern "C" fn(
Trampoline,
*mut Ctx,
*mut vm::Ctx,
NonNull<vm::Func>,
*const u64,
*mut u64,
Expand Down Expand Up @@ -124,16 +124,49 @@ pub trait WasmTypeList {
self,
f: NonNull<vm::Func>,
wasm: Wasm,
ctx: *mut Ctx,
ctx: *mut vm::Ctx,
) -> Result<Rets, RuntimeError>
where
Rets: WasmTypeList;
}

/// Empty trait to specify the kind of `ExternalFunction`: With or
/// without a `vm::Ctx` argument. See the `ExplicitVmCtx` and the
/// `ImplicitVmCtx` structures.
///
/// This type is never aimed to be used by a user. It is used by the
/// trait system to automatically generate an appropriate `wrap`
/// function.
pub trait ExternalFunctionKind {}
Hywan marked this conversation as resolved.
Show resolved Hide resolved

/// This empty structure indicates that an external function must
/// contain an explicit `vm::Ctx` argument (at first position).
///
/// ```rs,ignore
/// fn add_one(_: mut &vm::Ctx, x: i32) -> i32 {
/// x + 1
/// }
/// ```
pub struct ExplicitVmCtx {}

/// This empty structure indicates that an external function has no
/// `vm::Ctx` argument (at first position). Its signature is:
///
/// ```rs,ignore
/// fn add_one(x: i32) -> i32 {
/// x + 1
/// }
/// ```
pub struct ImplicitVmCtx {}

impl ExternalFunctionKind for ExplicitVmCtx {}
impl ExternalFunctionKind for ImplicitVmCtx {}

/// Represents a function that can be converted to a `vm::Func`
/// (function pointer) that can be called within WebAssembly.
pub trait ExternalFunction<Args, Rets>
pub trait ExternalFunction<Kind, Args, Rets>
where
Kind: ExternalFunctionKind,
Args: WasmTypeList,
Rets: WasmTypeList,
{
Expand Down Expand Up @@ -173,7 +206,7 @@ where
pub struct Func<'a, Args = (), Rets = (), Inner: Kind = Wasm> {
inner: Inner,
f: NonNull<vm::Func>,
ctx: *mut Ctx,
ctx: *mut vm::Ctx,
_phantom: PhantomData<(&'a (), Args, Rets)>,
}

Expand All @@ -188,7 +221,7 @@ where
pub(crate) unsafe fn from_raw_parts(
inner: Wasm,
f: NonNull<vm::Func>,
ctx: *mut Ctx,
ctx: *mut vm::Ctx,
) -> Func<'a, Args, Rets, Wasm> {
Func {
inner,
Expand All @@ -208,9 +241,10 @@ where
Args: WasmTypeList,
Rets: WasmTypeList,
{
pub fn new<F>(f: F) -> Func<'a, Args, Rets, Host>
pub fn new<F, Kind>(f: F) -> Func<'a, Args, Rets, Host>
where
F: ExternalFunction<Args, Rets>,
Kind: ExternalFunctionKind,
F: ExternalFunction<Kind, Args, Rets>,
{
Func {
inner: Host(()),
Expand Down Expand Up @@ -267,7 +301,7 @@ impl WasmTypeList for Infallible {
self,
_: NonNull<vm::Func>,
_: Wasm,
_: *mut Ctx,
_: *mut vm::Ctx,
) -> Result<Rets, RuntimeError>
where
Rets: WasmTypeList,
Expand Down Expand Up @@ -313,7 +347,7 @@ where
self,
f: NonNull<vm::Func>,
wasm: Wasm,
ctx: *mut Ctx,
ctx: *mut vm::Ctx,
) -> Result<Rets, RuntimeError>
where
Rets: WasmTypeList,
Expand Down Expand Up @@ -405,7 +439,7 @@ macro_rules! impl_traits {
self,
f: NonNull<vm::Func>,
wasm: Wasm,
ctx: *mut Ctx,
ctx: *mut vm::Ctx,
) -> Result<Rets, RuntimeError>
where
Rets: WasmTypeList
Expand Down Expand Up @@ -438,33 +472,91 @@ macro_rules! impl_traits {
}
}

impl< $( $x, )* Rets, Trap, FN > ExternalFunction<( $( $x ),* ), Rets> for FN
impl< $( $x, )* Rets, Trap, FN > ExternalFunction<ExplicitVmCtx, ( $( $x ),* ), Rets> for FN
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Implementation of ExternalFunction<ExplicitVmCtx, …>, where FN must be Fn(&mut vm::Ctx, …). The wrap function has a &mut vm::Ctx argument.

Note for the ones who aren't familiar with this detail: The wrap function will be dereferenced to *mut vm::Func (opaque function pointer), and will be used as FuncPointer by the Export API. So wrap acts as a “wrapper” around the user-given function.

where
$( $x: WasmExternType, )*
Rets: WasmTypeList,
Trap: TrapEarly<Rets>,
FN: Fn(&mut vm::Ctx $( , $x )*) -> Trap,
{
#[allow(non_snake_case)]
fn to_raw(&self) -> NonNull<vm::Func> {
if mem::size_of::<Self>() == 0 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Presumably Rust wasn't happy with you using the Sized bound?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Self refers to FN so to the Fn trait. Writing mem::size_of::<Self>() == 0 is a way to tell: FN is a function pointer, or a closure with no captured environment.

This code will change in the next PR to support closures as host functions. So far, this PR only copy-paste the existing code by definingvm::Ctx as an optional argument.

/// This is required for the llvm backend to be able to unwind through this function.
#[cfg_attr(nightly, unwind(allowed))]
extern fn wrap<$( $x, )* Rets, Trap, FN>(
vmctx: &mut vm::Ctx $( , $x: <$x as WasmExternType>::Native )*
) -> Rets::CStruct
where
$( $x: WasmExternType, )*
Rets: WasmTypeList,
Trap: TrapEarly<Rets>,
FN: Fn(&mut vm::Ctx, $( $x, )*) -> Trap,
{
let f: FN = unsafe { mem::transmute_copy(&()) };
Hywan marked this conversation as resolved.
Show resolved Hide resolved

let err = match panic::catch_unwind(
panic::AssertUnwindSafe(
|| {
f(vmctx $( , WasmExternType::from_native($x) )* ).report()
}
)
) {
Ok(Ok(returns)) => return returns.into_c_struct(),
Ok(Err(err)) => {
let b: Box<_> = err.into();
b as Box<dyn Any>
},
Err(err) => err,
};

unsafe {
(&*vmctx.module).runnable_module.do_early_trap(err)
}
}

NonNull::new(wrap::<$( $x, )* Rets, Trap, Self> as *mut vm::Func).unwrap()
} else {
assert_eq!(
mem::size_of::<Self>(),
mem::size_of::<usize>(),
"you cannot use a closure that captures state for `Func`."
);

NonNull::new(unsafe {
mem::transmute_copy::<_, *mut vm::Func>(self)
}).unwrap()
}
}
}

impl< $( $x, )* Rets, Trap, FN > ExternalFunction<ImplicitVmCtx, ( $( $x ),* ), Rets> for FN
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Implementation of ExternalFunction<ImplicitVmCtx, …>, where FN must be Fn(…) (without vm::Ctx). The wrap function still has a &mut vm::Ctx argument! Why? Because the backend always inserts a pointer to vm::Ctx in the stack, so wrap can read it.

Why grabbing vm::Ctx then? It is mandatory to trap error with do_early_trap.

where
$( $x: WasmExternType, )*
Rets: WasmTypeList,
Trap: TrapEarly<Rets>,
FN: Fn(&mut Ctx $( , $x )*) -> Trap,
FN: Fn($( $x, )*) -> Trap,
{
#[allow(non_snake_case)]
fn to_raw(&self) -> NonNull<vm::Func> {
if mem::size_of::<Self>() == 0 {
/// This is required for the llvm backend to be able to unwind through this function.
#[cfg_attr(nightly, unwind(allowed))]
extern fn wrap<$( $x, )* Rets, Trap, FN>(
ctx: &mut Ctx $( , $x: <$x as WasmExternType>::Native )*
vmctx: &mut vm::Ctx $( , $x: <$x as WasmExternType>::Native )*
) -> Rets::CStruct
where
$( $x: WasmExternType, )*
Rets: WasmTypeList,
Trap: TrapEarly<Rets>,
FN: Fn(&mut Ctx $( , $x )*) -> Trap,
FN: Fn($( $x, )*) -> Trap,
{
let f: FN = unsafe { mem::transmute_copy(&()) };

let err = match panic::catch_unwind(
panic::AssertUnwindSafe(
|| {
f(ctx $( , WasmExternType::from_native($x) )* ).report()
f($( WasmExternType::from_native($x), )* ).report()
}
)
) {
Expand All @@ -477,7 +569,7 @@ macro_rules! impl_traits {
};

unsafe {
(&*ctx.module).runnable_module.do_early_trap(err)
(&*vmctx.module).runnable_module.do_early_trap(err)
}
}

Expand All @@ -490,7 +582,7 @@ macro_rules! impl_traits {
);

NonNull::new(unsafe {
::std::mem::transmute_copy::<_, *mut vm::Func>(self)
mem::transmute_copy::<_, *mut vm::Func>(self)
}).unwrap()
}
}
Expand Down Expand Up @@ -562,9 +654,57 @@ where
#[cfg(test)]
mod tests {
use super::*;

macro_rules! test_func_arity_n {
($test_name:ident, $($x:ident),*) => {
#[test]
fn $test_name() {
use crate::vm;

fn with_vmctx(_: &mut vm::Ctx, $($x: i32),*) -> i32 {
vec![$($x),*].iter().sum()
}

fn without_vmctx($($x: i32),*) -> i32 {
vec![$($x),*].iter().sum()
}

let _func = Func::new(with_vmctx);
let _func = Func::new(without_vmctx);
}
}
}

#[test]
fn test_func_arity_0() {
fn foo(_: &mut vm::Ctx) -> i32 {
0
}

fn bar() -> i32 {
0
}

let _ = Func::new(foo);
let _ = Func::new(bar);
}

test_func_arity_n!(test_func_arity_1, a);
test_func_arity_n!(test_func_arity_2, a, b);
test_func_arity_n!(test_func_arity_3, a, b, c);
test_func_arity_n!(test_func_arity_4, a, b, c, d);
test_func_arity_n!(test_func_arity_5, a, b, c, d, e);
test_func_arity_n!(test_func_arity_6, a, b, c, d, e, f);
test_func_arity_n!(test_func_arity_7, a, b, c, d, e, f, g);
test_func_arity_n!(test_func_arity_8, a, b, c, d, e, f, g, h);
test_func_arity_n!(test_func_arity_9, a, b, c, d, e, f, g, h, i);
test_func_arity_n!(test_func_arity_10, a, b, c, d, e, f, g, h, i, j);
test_func_arity_n!(test_func_arity_11, a, b, c, d, e, f, g, h, i, j, k);
test_func_arity_n!(test_func_arity_12, a, b, c, d, e, f, g, h, i, j, k, l);

#[test]
fn test_call() {
fn foo(_ctx: &mut Ctx, a: i32, b: i32) -> (i32, i32) {
fn foo(_ctx: &mut vm::Ctx, a: i32, b: i32) -> (i32, i32) {
(a, b)
}

Expand All @@ -575,7 +715,7 @@ mod tests {
fn test_imports() {
use crate::{func, imports};

fn foo(_ctx: &mut Ctx, a: i32) -> i32 {
fn foo(_ctx: &mut vm::Ctx, a: i32) -> i32 {
a
}

Expand Down