diff --git a/CHANGELOG.md b/CHANGELOG.md index 8acac077215..3a38c33ff55 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ ## **[Unreleased]** +- [#925](https://github.com/wasmerio/wasmer/pull/925) Host functions can be closures with a captured environment. +- [#917](https://github.com/wasmerio/wasmer/pull/917) Host functions (aka imported functions) may not have `&mut vm::Ctx` as first argument, i.e. the presence of the `&mut vm::Ctx` argument is optional. +- [#915](https://github.com/wasmerio/wasmer/pull/915) All backends share the same definition of `Trampoline` (defined in `wasmer-runtime-core`). - [#952](https://github.com/wasmerio/wasmer/pull/952) Use C preprocessor to properly hide trampoline functions on Windows and non-x86_64 targets. ## 0.10.0 - 2019-11-11 diff --git a/lib/clif-backend/src/code.rs b/lib/clif-backend/src/code.rs index f1d8489da13..1bd1958e3de 100644 --- a/lib/clif-backend/src/code.rs +++ b/lib/clif-backend/src/code.rs @@ -691,7 +691,9 @@ impl FuncEnvironment for FunctionEnvironment { } /// Generates a call IR with `callee` and `call_args` and inserts it at `pos` - /// TODO: add support for imported functions + /// + /// It's about generating code that calls a local or imported function; in + /// WebAssembly: `(call $foo)`. fn translate_call( &mut self, mut pos: FuncCursor, @@ -763,20 +765,31 @@ impl FuncEnvironment for FunctionEnvironment { readonly: true, }); - let imported_vmctx_addr = pos.func.create_global_value(ir::GlobalValueData::Load { - base: imported_func_struct_addr, - offset: (vm::ImportedFunc::offset_vmctx() as i32).into(), - global_type: ptr_type, - readonly: true, - }); + let imported_func_ctx_addr = + pos.func.create_global_value(ir::GlobalValueData::Load { + base: imported_func_struct_addr, + offset: (vm::ImportedFunc::offset_func_ctx() as i32).into(), + global_type: ptr_type, + readonly: true, + }); + + let imported_func_ctx_vmctx_addr = + pos.func.create_global_value(ir::GlobalValueData::Load { + base: imported_func_ctx_addr, + offset: (vm::FuncCtx::offset_vmctx() as i32).into(), + global_type: ptr_type, + readonly: true, + }); let imported_func_addr = pos.ins().global_value(ptr_type, imported_func_addr); - let imported_vmctx_addr = pos.ins().global_value(ptr_type, imported_vmctx_addr); + let imported_func_ctx_vmctx_addr = pos + .ins() + .global_value(ptr_type, imported_func_ctx_vmctx_addr); let sig_ref = pos.func.dfg.ext_funcs[callee].signature; let mut args = Vec::with_capacity(call_args.len() + 1); - args.push(imported_vmctx_addr); + args.push(imported_func_ctx_vmctx_addr); args.extend(call_args.iter().cloned()); Ok(pos diff --git a/lib/llvm-backend/src/intrinsics.rs b/lib/llvm-backend/src/intrinsics.rs index b23f1b13cb1..717ed043d70 100644 --- a/lib/llvm-backend/src/intrinsics.rs +++ b/lib/llvm-backend/src/intrinsics.rs @@ -207,8 +207,13 @@ impl Intrinsics { context.struct_type(&[i8_ptr_ty_basic, i64_ty_basic, i8_ptr_ty_basic], false); let local_table_ty = local_memory_ty; let local_global_ty = i64_ty; - let imported_func_ty = - context.struct_type(&[i8_ptr_ty_basic, ctx_ptr_ty.as_basic_type_enum()], false); + let func_ctx_ty = + context.struct_type(&[ctx_ptr_ty.as_basic_type_enum(), i8_ptr_ty_basic], false); + let func_ctx_ptr_ty = func_ctx_ty.ptr_type(AddressSpace::Generic); + let imported_func_ty = context.struct_type( + &[i8_ptr_ty_basic, func_ctx_ptr_ty.as_basic_type_enum()], + false, + ); let sigindex_ty = i32_ty; let rt_intrinsics_ty = i8_ty; let stack_lower_bound_ty = i8_ty; @@ -1066,16 +1071,20 @@ impl<'a> CtxType<'a> { "imported_func_ptr", ) }; - let (func_ptr_ptr, ctx_ptr_ptr) = unsafe { + let (func_ptr_ptr, func_ctx_ptr_ptr) = unsafe { ( cache_builder.build_struct_gep(imported_func_ptr, 0, "func_ptr_ptr"), - cache_builder.build_struct_gep(imported_func_ptr, 1, "ctx_ptr_ptr"), + cache_builder.build_struct_gep(imported_func_ptr, 1, "func_ctx_ptr_ptr"), ) }; let func_ptr = cache_builder .build_load(func_ptr_ptr, "func_ptr") .into_pointer_value(); + let func_ctx_ptr = cache_builder + .build_load(func_ctx_ptr_ptr, "func_ctx_ptr") + .into_pointer_value(); + let ctx_ptr_ptr = unsafe { cache_builder.build_struct_gep(func_ctx_ptr, 0, "ctx_ptr") }; let ctx_ptr = cache_builder .build_load(ctx_ptr_ptr, "ctx_ptr") .into_pointer_value(); diff --git a/lib/llvm-backend/src/stackmap.rs b/lib/llvm-backend/src/stackmap.rs index a56c3c6a38d..4a9dbf81633 100644 --- a/lib/llvm-backend/src/stackmap.rs +++ b/lib/llvm-backend/src/stackmap.rs @@ -161,7 +161,7 @@ impl StackmapEntry { ValueSemantic::ImportedFuncCtx(idx) => MachineValue::VmctxDeref(vec![ Ctx::offset_imported_funcs() as usize, vm::ImportedFunc::size() as usize * idx - + vm::ImportedFunc::offset_vmctx() as usize, + + vm::ImportedFunc::offset_func_ctx() as usize, 0, ]), ValueSemantic::DynamicSigindice(idx) => { diff --git a/lib/runtime-core-tests/tests/imports.rs b/lib/runtime-core-tests/tests/imports.rs index 69b9040cdc6..6b7223c6735 100644 --- a/lib/runtime-core-tests/tests/imports.rs +++ b/lib/runtime-core-tests/tests/imports.rs @@ -1,31 +1,131 @@ use wasmer_runtime_core::{ compile_with, error::RuntimeError, imports, memory::Memory, typed_func::Func, - types::MemoryDescriptor, units::Pages, vm, + types::MemoryDescriptor, units::Pages, vm, Instance, }; use wasmer_runtime_core_tests::{get_compiler, wat2wasm}; -#[test] -fn imported_functions_forms() { +macro_rules! call_and_assert { + ($instance:ident, $function:ident, $expected_value:expr) => { + let $function: Func = $instance.func(stringify!($function)).unwrap(); + + let result = $function.call(1); + + match (result, $expected_value) { + (Ok(value), expected_value) => assert_eq!( + Ok(value), + expected_value, + concat!("Expected right when calling `", stringify!($function), "`.") + ), + ( + Err(RuntimeError::Error { data }), + Err(RuntimeError::Error { + data: expected_data, + }), + ) => { + if let (Some(data), Some(expected_data)) = ( + data.downcast_ref::<&str>(), + expected_data.downcast_ref::<&str>(), + ) { + assert_eq!( + data, expected_data, + concat!("Expected right when calling `", stringify!($function), "`.") + ) + } else if let (Some(data), Some(expected_data)) = ( + data.downcast_ref::(), + expected_data.downcast_ref::(), + ) { + assert_eq!( + data, expected_data, + concat!("Expected right when calling `", stringify!($function), "`.") + ) + } else { + assert!(false, "Unexpected error, cannot compare it.") + } + } + (result, expected_value) => assert!( + false, + format!( + "Unexpected assertion for `{}`: left = `{:?}`, right = `{:?}`.", + stringify!($function), + result, + expected_value + ) + ), + } + }; +} + +/// The shift that is set in the instance memory. The value is part of +/// the result returned by the imported functions if the memory is +/// read properly. +const SHIFT: i32 = 10; + +/// The shift that is captured in the environment of a closure. The +/// value is part of the result returned by the imported function if +/// the closure captures its environment properly. +#[allow(non_upper_case_globals)] +const shift: i32 = 100; + +fn imported_functions_forms(test: &dyn Fn(&Instance)) { const MODULE: &str = r#" (module (type $type (func (param i32) (result i32))) (import "env" "memory" (memory 1 1)) (import "env" "callback_fn" (func $callback_fn (type $type))) + (import "env" "callback_closure" (func $callback_closure (type $type))) + (import "env" "callback_closure_with_env" (func $callback_closure_with_env (type $type))) (import "env" "callback_fn_with_vmctx" (func $callback_fn_with_vmctx (type $type))) + (import "env" "callback_closure_with_vmctx" (func $callback_closure_with_vmctx (type $type))) + (import "env" "callback_closure_with_vmctx_and_env" (func $callback_closure_with_vmctx_and_env (type $type))) (import "env" "callback_fn_trap" (func $callback_fn_trap (type $type))) + (import "env" "callback_closure_trap" (func $callback_closure_trap (type $type))) (import "env" "callback_fn_trap_with_vmctx" (func $callback_fn_trap_with_vmctx (type $type))) + (import "env" "callback_closure_trap_with_vmctx" (func $callback_closure_trap_with_vmctx (type $type))) + (import "env" "callback_closure_trap_with_vmctx_and_env" (func $callback_closure_trap_with_vmctx_and_env (type $type))) + (func (export "function_fn") (type $type) get_local 0 call $callback_fn) + + (func (export "function_closure") (type $type) + get_local 0 + call $callback_closure) + + (func (export "function_closure_with_env") (type $type) + get_local 0 + call $callback_closure_with_env) + (func (export "function_fn_with_vmctx") (type $type) get_local 0 call $callback_fn_with_vmctx) + + (func (export "function_closure_with_vmctx") (type $type) + get_local 0 + call $callback_closure_with_vmctx) + + (func (export "function_closure_with_vmctx_and_env") (type $type) + get_local 0 + call $callback_closure_with_vmctx_and_env) + (func (export "function_fn_trap") (type $type) get_local 0 call $callback_fn_trap) + + (func (export "function_closure_trap") (type $type) + get_local 0 + call $callback_closure_trap) + (func (export "function_fn_trap_with_vmctx") (type $type) get_local 0 - call $callback_fn_trap_with_vmctx)) + call $callback_fn_trap_with_vmctx) + + (func (export "function_closure_trap_with_vmctx") (type $type) + get_local 0 + call $callback_closure_trap_with_vmctx) + + (func (export "function_closure_trap_with_vmctx_and_env") (type $type) + get_local 0 + call $callback_closure_trap_with_vmctx_and_env)) "#; let wasm_binary = wat2wasm(MODULE.as_bytes()).expect("WAST not valid or malformed"); @@ -33,85 +133,77 @@ fn imported_functions_forms() { let memory_descriptor = MemoryDescriptor::new(Pages(1), Some(Pages(1)), false).unwrap(); let memory = Memory::new(memory_descriptor).unwrap(); - const SHIFT: i32 = 10; memory.view()[0].set(SHIFT); let import_object = imports! { "env" => { "memory" => memory.clone(), + + // Regular function. "callback_fn" => Func::new(callback_fn), + + // Closure without a captured environment. + "callback_closure" => Func::new(|n: i32| -> Result { + Ok(n + 1) + }), + + // Closure with a captured environment (a single variable + an instance of `Memory`). + "callback_closure_with_env" => Func::new(move |n: i32| -> Result { + let shift_ = shift + memory.view::()[0].get(); + + Ok(shift_ + n + 1) + }), + + // Regular function with an explicit `vmctx`. "callback_fn_with_vmctx" => Func::new(callback_fn_with_vmctx), + + // Closure without a captured environment but with an explicit `vmctx`. + "callback_closure_with_vmctx" => Func::new(|vmctx: &mut vm::Ctx, n: i32| -> Result { + let memory = vmctx.memory(0); + let shift_: i32 = memory.view()[0].get(); + + Ok(shift_ + n + 1) + }), + + // Closure with a captured environment (a single variable) and with an explicit `vmctx`. + "callback_closure_with_vmctx_and_env" => Func::new(move |vmctx: &mut vm::Ctx, n: i32| -> Result { + let memory = vmctx.memory(0); + let shift_ = shift + memory.view::()[0].get(); + + Ok(shift_ + n + 1) + }), + + // Trap a regular function. "callback_fn_trap" => Func::new(callback_fn_trap), + + // Trap a closure without a captured environment. + "callback_closure_trap" => Func::new(|n: i32| -> Result { + Err(format!("bar {}", n + 1)) + }), + + // Trap a regular function with an explicit `vmctx`. "callback_fn_trap_with_vmctx" => Func::new(callback_fn_trap_with_vmctx), + + // Trap a closure without a captured environment but with an explicit `vmctx`. + "callback_closure_trap_with_vmctx" => Func::new(|vmctx: &mut vm::Ctx, n: i32| -> Result { + let memory = vmctx.memory(0); + let shift_: i32 = memory.view()[0].get(); + + Err(format!("qux {}", shift_ + n + 1)) + }), + + // Trap a closure with a captured environment (a single variable) and with an explicit `vmctx`. + "callback_closure_trap_with_vmctx_and_env" => Func::new(move |vmctx: &mut vm::Ctx, n: i32| -> Result { + let memory = vmctx.memory(0); + let shift_ = shift + memory.view::()[0].get(); + + Err(format!("! {}", shift_ + n + 1)) + }), }, }; let instance = module.instantiate(&import_object).unwrap(); - macro_rules! call_and_assert { - ($function:ident, $expected_value:expr) => { - let $function: Func = instance.func(stringify!($function)).unwrap(); - - let result = $function.call(1); - - match (result, $expected_value) { - (Ok(value), expected_value) => assert_eq!( - Ok(value), - expected_value, - concat!("Expected right when calling `", stringify!($function), "`.") - ), - ( - Err(RuntimeError::Error { data }), - Err(RuntimeError::Error { - data: expected_data, - }), - ) => { - if let (Some(data), Some(expected_data)) = ( - data.downcast_ref::<&str>(), - expected_data.downcast_ref::<&str>(), - ) { - assert_eq!( - data, expected_data, - concat!("Expected right when calling `", stringify!($function), "`.") - ) - } else if let (Some(data), Some(expected_data)) = ( - data.downcast_ref::(), - expected_data.downcast_ref::(), - ) { - assert_eq!( - data, expected_data, - concat!("Expected right when calling `", stringify!($function), "`.") - ) - } else { - assert!(false, "Unexpected error, cannot compare it.") - } - } - (result, expected_value) => assert!( - false, - format!( - "Unexpected assertion for `{}`: left = `{:?}`, right = `{:?}`.", - stringify!($function), - result, - expected_value - ) - ), - } - }; - } - - call_and_assert!(function_fn, Ok(2)); - call_and_assert!(function_fn_with_vmctx, Ok(2 + SHIFT)); - call_and_assert!( - function_fn_trap, - Err(RuntimeError::Error { - data: Box::new(format!("foo {}", 1)) - }) - ); - call_and_assert!( - function_fn_trap_with_vmctx, - Err(RuntimeError::Error { - data: Box::new(format!("baz {}", 2 + SHIFT)) - }) - ); + test(&instance); } fn callback_fn(n: i32) -> Result { @@ -120,18 +212,83 @@ fn callback_fn(n: i32) -> Result { fn callback_fn_with_vmctx(vmctx: &mut vm::Ctx, n: i32) -> Result { let memory = vmctx.memory(0); - let shift: i32 = memory.view()[0].get(); + let shift_: i32 = memory.view()[0].get(); - Ok(shift + n + 1) + Ok(shift_ + n + 1) } fn callback_fn_trap(n: i32) -> Result { - Err(format!("foo {}", n)) + Err(format!("foo {}", n + 1)) } fn callback_fn_trap_with_vmctx(vmctx: &mut vm::Ctx, n: i32) -> Result { let memory = vmctx.memory(0); - let shift: i32 = memory.view()[0].get(); + let shift_: i32 = memory.view()[0].get(); + + Err(format!("baz {}", shift_ + n + 1)) +} - Err(format!("baz {}", shift + n + 1)) +macro_rules! test { + ($test_name:ident, $function:ident, $expected_value:expr) => { + #[test] + fn $test_name() { + imported_functions_forms(&|instance| { + call_and_assert!(instance, $function, $expected_value); + }); + } + }; } + +test!(test_fn, function_fn, Ok(2)); +test!(test_closure, function_closure, Ok(2)); +test!( + test_closure_with_env, + function_closure_with_env, + Ok(2 + shift + SHIFT) +); +test!(test_fn_with_vmctx, function_fn_with_vmctx, Ok(2 + SHIFT)); +test!( + test_closure_with_vmctx, + function_closure_with_vmctx, + Ok(2 + SHIFT) +); +test!( + test_closure_with_vmctx_and_env, + function_closure_with_vmctx_and_env, + Ok(2 + shift + SHIFT) +); +test!( + test_fn_trap, + function_fn_trap, + Err(RuntimeError::Error { + data: Box::new(format!("foo {}", 2)) + }) +); +test!( + test_closure_trap, + function_closure_trap, + Err(RuntimeError::Error { + data: Box::new(format!("bar {}", 2)) + }) +); +test!( + test_fn_trap_with_vmctx, + function_fn_trap_with_vmctx, + Err(RuntimeError::Error { + data: Box::new(format!("baz {}", 2 + SHIFT)) + }) +); +test!( + test_closure_trap_with_vmctx, + function_closure_trap_with_vmctx, + Err(RuntimeError::Error { + data: Box::new(format!("qux {}", 2 + SHIFT)) + }) +); +test!( + test_closure_trap_with_vmctx_and_env, + function_closure_trap_with_vmctx_and_env, + Err(RuntimeError::Error { + data: Box::new(format!("! {}", 2 + shift + SHIFT)) + }) +); diff --git a/lib/runtime-core/src/backing.rs b/lib/runtime-core/src/backing.rs index 09c5743b993..c88cb953dfb 100644 --- a/lib/runtime-core/src/backing.rs +++ b/lib/runtime-core/src/backing.rs @@ -15,7 +15,11 @@ use crate::{ }, vm, }; -use std::{fmt::Debug, slice}; +use std::{ + fmt::Debug, + ptr::{self, NonNull}, + slice, +}; /// Size of the array for internal instance usage pub const INTERNALS_SIZE: usize = 256; @@ -383,9 +387,9 @@ impl LocalBacking { vmctx, ), LocalOrImport::Import(imported_func_index) => { - let vm::ImportedFunc { func, vmctx } = + let vm::ImportedFunc { func, func_ctx } = imports.vm_functions[imported_func_index]; - (func, vmctx) + (func, unsafe { func_ctx.as_ref() }.vmctx.as_ptr()) } }; @@ -416,9 +420,9 @@ impl LocalBacking { vmctx, ), LocalOrImport::Import(imported_func_index) => { - let vm::ImportedFunc { func, vmctx } = + let vm::ImportedFunc { func, func_ctx } = imports.vm_functions[imported_func_index]; - (func, vmctx) + (func, unsafe { func_ctx.as_ref() }.vmctx.as_ptr()) } }; @@ -546,6 +550,15 @@ impl ImportBacking { } } +impl Drop for ImportBacking { + fn drop(&mut self) { + // Properly drop the `vm::FuncCtx` in `vm::ImportedFunc`. + for (_imported_func_index, imported_func) in (*self.vm_functions).iter_mut() { + let _: Box = unsafe { Box::from_raw(imported_func.func_ctx.as_ptr()) }; + } + } +} + fn import_functions( module: &ModuleInner, imports: &ImportObject, @@ -569,6 +582,7 @@ fn import_functions( let import = imports.maybe_with_namespace(namespace, |namespace| namespace.get_export(name)); + match import { Some(Export::Function { func, @@ -578,10 +592,28 @@ fn import_functions( if *expected_sig == *signature { functions.push(vm::ImportedFunc { func: func.inner(), - vmctx: match ctx { - Context::External(ctx) => ctx, - Context::Internal => vmctx, - }, + func_ctx: NonNull::new(Box::into_raw(Box::new(vm::FuncCtx { + // ^^^^^^^^ `vm::FuncCtx` is purposely leaked. + // It is dropped by the specific `Drop` + // implementation of `ImportBacking`. + vmctx: NonNull::new(match ctx { + Context::External(vmctx) => vmctx, + Context::ExternalWithEnv(vmctx_, _) => { + if vmctx_.is_null() { + vmctx + } else { + vmctx_ + } + } + Context::Internal => vmctx, + }) + .expect("`vmctx` must not be null."), + func_env: match ctx { + Context::ExternalWithEnv(_, func_env) => func_env, + _ => None, + }, + }))) + .unwrap(), }); } else { link_errors.push(LinkError::IncorrectImportSignature { @@ -610,8 +642,8 @@ fn import_functions( None => { if imports.allow_missing_functions { functions.push(vm::ImportedFunc { - func: ::std::ptr::null(), - vmctx: ::std::ptr::null_mut(), + func: ptr::null(), + func_ctx: unsafe { NonNull::new_unchecked(ptr::null_mut()) }, // TODO: Non-senseā€¦ }); } else { link_errors.push(LinkError::ImportNotFound { diff --git a/lib/runtime-core/src/export.rs b/lib/runtime-core/src/export.rs index 213ea06b82f..8729d979752 100644 --- a/lib/runtime-core/src/export.rs +++ b/lib/runtime-core/src/export.rs @@ -6,13 +6,18 @@ use crate::{ module::ModuleInner, table::Table, types::FuncSig, vm, }; use indexmap::map::Iter as IndexMapIter; -use std::sync::Arc; +use std::{ptr::NonNull, sync::Arc}; /// A kind of Context. #[derive(Debug, Copy, Clone)] pub enum Context { /// External context include a mutable pointer to `Ctx`. External(*mut vm::Ctx), + + /// External context with an environment include a mutable pointer + /// to `Ctx` and an optional non-null pointer to `FuncEnv`. + ExternalWithEnv(*mut vm::Ctx, Option>), + /// Internal context. Internal, } diff --git a/lib/runtime-core/src/instance.rs b/lib/runtime-core/src/instance.rs index 8a8ea457d57..eb47c98e191 100644 --- a/lib/runtime-core/src/instance.rs +++ b/lib/runtime-core/src/instance.rs @@ -113,9 +113,13 @@ impl Instance { let ctx_ptr = match start_index.local_or_import(&instance.module.info) { LocalOrImport::Local(_) => instance.inner.vmctx, - LocalOrImport::Import(imported_func_index) => { - instance.inner.import_backing.vm_functions[imported_func_index].vmctx + LocalOrImport::Import(imported_func_index) => unsafe { + instance.inner.import_backing.vm_functions[imported_func_index] + .func_ctx + .as_ref() } + .vmctx + .as_ptr(), }; let sig_index = *instance @@ -132,7 +136,7 @@ impl Instance { .expect("wasm trampoline"); let start_func: Func<(), (), Wasm> = - unsafe { Func::from_raw_parts(wasm_trampoline, func_ptr, ctx_ptr) }; + unsafe { Func::from_raw_parts(wasm_trampoline, func_ptr, None, ctx_ptr) }; start_func.call()?; } @@ -199,9 +203,13 @@ impl Instance { let ctx = match func_index.local_or_import(&self.module.info) { LocalOrImport::Local(_) => self.inner.vmctx, - LocalOrImport::Import(imported_func_index) => { - self.inner.import_backing.vm_functions[imported_func_index].vmctx + LocalOrImport::Import(imported_func_index) => unsafe { + self.inner.import_backing.vm_functions[imported_func_index] + .func_ctx + .as_ref() } + .vmctx + .as_ptr(), }; let func_wasm_inner = self @@ -210,20 +218,26 @@ impl Instance { .get_trampoline(&self.module.info, sig_index) .unwrap(); - let func_ptr = match func_index.local_or_import(&self.module.info) { - LocalOrImport::Local(local_func_index) => self - .module - .runnable_module - .get_func(&self.module.info, local_func_index) - .unwrap(), - LocalOrImport::Import(import_func_index) => NonNull::new( - self.inner.import_backing.vm_functions[import_func_index].func as *mut _, - ) - .unwrap(), + let (func_ptr, func_env) = match func_index.local_or_import(&self.module.info) { + LocalOrImport::Local(local_func_index) => ( + self.module + .runnable_module + .get_func(&self.module.info, local_func_index) + .unwrap(), + None, + ), + LocalOrImport::Import(import_func_index) => { + let imported_func = &self.inner.import_backing.vm_functions[import_func_index]; + + ( + NonNull::new(imported_func.func as *mut _).unwrap(), + unsafe { imported_func.func_ctx.as_ref() }.func_env, + ) + } }; let typed_func: Func = - unsafe { Func::from_raw_parts(func_wasm_inner, func_ptr, ctx) }; + unsafe { Func::from_raw_parts(func_wasm_inner, func_ptr, func_env, ctx) }; Ok(typed_func) } else { @@ -412,6 +426,7 @@ impl InstanceInner { ctx: match ctx { Context::Internal => Context::External(self.vmctx), ctx @ Context::External(_) => ctx, + ctx @ Context::ExternalWithEnv(_, _) => ctx, }, signature, } @@ -454,15 +469,16 @@ impl InstanceInner { ), LocalOrImport::Import(imported_func_index) => { let imported_func = &self.import_backing.vm_functions[imported_func_index]; + let func_ctx = unsafe { imported_func.func_ctx.as_ref() }; + ( imported_func.func as *const _, - Context::External(imported_func.vmctx), + Context::ExternalWithEnv(func_ctx.vmctx.as_ptr(), func_ctx.func_env), ) } }; let signature = SigRegistry.lookup_signature_ref(&module.info.signatures[sig_index]); - // let signature = &module.info.signatures[sig_index]; (unsafe { FuncPointer::new(func_ptr) }, ctx, signature) } @@ -581,9 +597,13 @@ fn call_func_with_index( let ctx_ptr = match func_index.local_or_import(info) { LocalOrImport::Local(_) => local_ctx, - LocalOrImport::Import(imported_func_index) => { - import_backing.vm_functions[imported_func_index].vmctx + LocalOrImport::Import(imported_func_index) => unsafe { + import_backing.vm_functions[imported_func_index] + .func_ctx + .as_ref() } + .vmctx + .as_ptr(), }; let wasm = runnable diff --git a/lib/runtime-core/src/typed_func.rs b/lib/runtime-core/src/typed_func.rs index 7904cb0b6d0..d2fe59a617d 100644 --- a/lib/runtime-core/src/typed_func.rs +++ b/lib/runtime-core/src/typed_func.rs @@ -192,7 +192,7 @@ where Rets: WasmTypeList, { /// Conver to function pointer. - fn to_raw(&self) -> NonNull; + fn to_raw(self) -> (NonNull, Option>); } /// Represents a TrapEarly type. @@ -230,8 +230,9 @@ where /// Represents a function that can be used by WebAssembly. pub struct Func<'a, Args = (), Rets = (), Inner: Kind = Wasm> { inner: Inner, - f: NonNull, - ctx: *mut vm::Ctx, + func: NonNull, + func_env: Option>, + vmctx: *mut vm::Ctx, _phantom: PhantomData<(&'a (), Args, Rets)>, } @@ -245,20 +246,22 @@ where { pub(crate) unsafe fn from_raw_parts( inner: Wasm, - f: NonNull, - ctx: *mut vm::Ctx, + func: NonNull, + func_env: Option>, + vmctx: *mut vm::Ctx, ) -> Func<'a, Args, Rets, Wasm> { Func { inner, - f, - ctx, + func, + func_env, + vmctx, _phantom: PhantomData, } } /// Get the underlying func pointer. pub fn get_vm_func(&self) -> NonNull { - self.f + self.func } } @@ -268,15 +271,18 @@ where Rets: WasmTypeList, { /// Creates a new `Func`. - pub fn new(f: F) -> Func<'a, Args, Rets, Host> + pub fn new(func: F) -> Func<'a, Args, Rets, Host> where Kind: ExternalFunctionKind, F: ExternalFunction, { + let (func, func_env) = func.to_raw(); + Func { inner: Host(()), - f: f.to_raw(), - ctx: ptr::null_mut(), + func, + func_env, + vmctx: ptr::null_mut(), _phantom: PhantomData, } } @@ -414,7 +420,7 @@ where { /// Call wasm function and return results. pub fn call(&self, a: A) -> Result { - unsafe { ::call(a, self.f, self.inner, self.ctx) } + unsafe { ::call(a, self.func, self.inner, self.vmctx) } } } @@ -506,56 +512,113 @@ macro_rules! impl_traits { $( $x: WasmExternType, )* Rets: WasmTypeList, Trap: TrapEarly, - FN: Fn(&mut vm::Ctx $( , $x )*) -> Trap, + FN: Fn(&mut vm::Ctx $( , $x )*) -> Trap + 'static, { #[allow(non_snake_case)] - fn to_raw(&self) -> NonNull { - if mem::size_of::() == 0 { - /// This is required for the llvm backend to be able to unwind through this function. - #[cfg_attr(nightly, unwind(allowed))] - extern fn wrap<$( $x, )* Rets, Trap, FN>( - vmctx: &mut vm::Ctx $( , $x: <$x as WasmExternType>::Native )* - ) -> Rets::CStruct - where - $( $x: WasmExternType, )* - Rets: WasmTypeList, - Trap: TrapEarly, - FN: Fn(&mut vm::Ctx, $( $x, )*) -> Trap, - { - let f: FN = unsafe { mem::transmute_copy(&()) }; - - let err = match panic::catch_unwind( - panic::AssertUnwindSafe( - || { - f(vmctx $( , WasmExternType::from_native($x) )* ).report() - } - ) - ) { - Ok(Ok(returns)) => return returns.into_c_struct(), - Ok(Err(err)) => { - let b: Box<_> = err.into(); - b as Box - }, - Err(err) => err, - }; - - unsafe { - (&*vmctx.module).runnable_module.do_early_trap(err) - } + fn to_raw(self) -> (NonNull, Option>) { + // The `wrap` function is a wrapper around the + // imported function. It manages the argument passed + // to the imported function (in this case, the + // `vmctx` along with the regular WebAssembly + // arguments), and it manages the trapping. + // + // It is also required for the LLVM backend to be + // able to unwind through this function. + #[cfg_attr(nightly, unwind(allowed))] + extern fn wrap<$( $x, )* Rets, Trap, FN>( + vmctx: &vm::Ctx $( , $x: <$x as WasmExternType>::Native )* + ) -> Rets::CStruct + where + $( $x: WasmExternType, )* + Rets: WasmTypeList, + Trap: TrapEarly, + FN: Fn(&mut vm::Ctx, $( $x, )*) -> Trap, + { + // Get the pointer to this `wrap` function. + let self_pointer = wrap::<$( $x, )* Rets, Trap, FN> as *const vm::Func; + + // Get the collection of imported functions. + let vm_imported_functions = unsafe { &(*vmctx.import_backing).vm_functions }; + + // Retrieve the `vm::FuncCtx`. + let mut func_ctx: NonNull = vm_imported_functions + .iter() + .find_map(|(_, imported_func)| { + if imported_func.func == self_pointer { + Some(imported_func.func_ctx) + } else { + None + } + }) + .expect("Import backing is not well-formed, cannot find `func_ctx`."); + let func_ctx = unsafe { func_ctx.as_mut() }; + + // Extract `vm::Ctx` from `vm::FuncCtx`. The + // pointer is always non-null. + let vmctx = unsafe { func_ctx.vmctx.as_mut() }; + + // Extract `vm::FuncEnv` from `vm::FuncCtx`. + let func_env = func_ctx.func_env; + + let func: &FN = match func_env { + // The imported function is a regular + // function, a closure without a captured + // environment, or a closure with a captured + // environment. + Some(func_env) => unsafe { + let func: NonNull = func_env.cast(); + + &*func.as_ptr() + }, + + // This branch is supposed to be unreachable. + None => unreachable!() + }; + + // Catch unwind in case of errors. + let err = match panic::catch_unwind( + panic::AssertUnwindSafe( + || { + func(vmctx $( , WasmExternType::from_native($x) )* ).report() + // ^^^^^ The imported function + // expects `vm::Ctx` as first + // argument; provide it. + } + ) + ) { + Ok(Ok(returns)) => return returns.into_c_struct(), + Ok(Err(err)) => { + let b: Box<_> = err.into(); + b as Box + }, + Err(err) => err, + }; + + // At this point, there is an error that needs to + // be trapped. + unsafe { + (&*vmctx.module).runnable_module.do_early_trap(err) } - - NonNull::new(wrap::<$( $x, )* Rets, Trap, Self> as *mut vm::Func).unwrap() - } else { - assert_eq!( - mem::size_of::(), - mem::size_of::(), - "you cannot use a closure that captures state for `Func`." - ); - - NonNull::new(unsafe { - mem::transmute_copy::<_, *mut vm::Func>(self) - }).unwrap() } + + // Extract the captured environment of the imported + // function if any. + let func_env: Option> = + // `FN` is a function pointer, or a closure + // _without_ a captured environment. + if mem::size_of::() == 0 { + NonNull::new(&self as *const _ as *mut vm::FuncEnv) + } + // `FN` is a closure _with_ a captured + // environment. + else { + NonNull::new(Box::into_raw(Box::new(self))).map(NonNull::cast) + }; + + ( + NonNull::new(wrap::<$( $x, )* Rets, Trap, Self> as *mut vm::Func).unwrap(), + func_env + ) } } @@ -564,56 +627,110 @@ macro_rules! impl_traits { $( $x: WasmExternType, )* Rets: WasmTypeList, Trap: TrapEarly, - FN: Fn($( $x, )*) -> Trap, + FN: Fn($( $x, )*) -> Trap + 'static, { #[allow(non_snake_case)] - fn to_raw(&self) -> NonNull { - if mem::size_of::() == 0 { - /// This is required for the llvm backend to be able to unwind through this function. - #[cfg_attr(nightly, unwind(allowed))] - extern fn wrap<$( $x, )* Rets, Trap, FN>( - vmctx: &mut vm::Ctx $( , $x: <$x as WasmExternType>::Native )* - ) -> Rets::CStruct - where - $( $x: WasmExternType, )* - Rets: WasmTypeList, - Trap: TrapEarly, - FN: Fn($( $x, )*) -> Trap, - { - let f: FN = unsafe { mem::transmute_copy(&()) }; - - let err = match panic::catch_unwind( - panic::AssertUnwindSafe( - || { - f($( WasmExternType::from_native($x), )* ).report() - } - ) - ) { - Ok(Ok(returns)) => return returns.into_c_struct(), - Ok(Err(err)) => { - let b: Box<_> = err.into(); - b as Box - }, - Err(err) => err, - }; - - unsafe { - (&*vmctx.module).runnable_module.do_early_trap(err) - } + fn to_raw(self) -> (NonNull, Option>) { + // The `wrap` function is a wrapper around the + // imported function. It manages the argument passed + // to the imported function (in this case, only the + // regular WebAssembly arguments), and it manages the + // trapping. + // + // It is also required for the LLVM backend to be + // able to unwind through this function. + #[cfg_attr(nightly, unwind(allowed))] + extern fn wrap<$( $x, )* Rets, Trap, FN>( + vmctx: &vm::Ctx $( , $x: <$x as WasmExternType>::Native )* + ) -> Rets::CStruct + where + $( $x: WasmExternType, )* + Rets: WasmTypeList, + Trap: TrapEarly, + FN: Fn($( $x, )*) -> Trap, + { + // Get the pointer to this `wrap` function. + let self_pointer = wrap::<$( $x, )* Rets, Trap, FN> as *const vm::Func; + + // Get the collection of imported functions. + let vm_imported_functions = unsafe { &(*vmctx.import_backing).vm_functions }; + + // Retrieve the `vm::FuncCtx`. + let mut func_ctx: NonNull = vm_imported_functions + .iter() + .find_map(|(_, imported_func)| { + if imported_func.func == self_pointer { + Some(imported_func.func_ctx) + } else { + None + } + }) + .expect("Import backing is not well-formed, cannot find `func_ctx`."); + let func_ctx = unsafe { func_ctx.as_mut() }; + + // Extract `vm::Ctx` from `vm::FuncCtx`. The + // pointer is always non-null. + let vmctx = unsafe { func_ctx.vmctx.as_mut() }; + + // Extract `vm::FuncEnv` from `vm::FuncCtx`. + let func_env = func_ctx.func_env; + + let func: &FN = match func_env { + // The imported function is a regular + // function, a closure without a captured + // environment, or a closure with a captured + // environment. + Some(func_env) => unsafe { + let func: NonNull = func_env.cast(); + + &*func.as_ptr() + }, + + // This branch is supposed to be unreachable. + None => unreachable!() + }; + + // Catch unwind in case of errors. + let err = match panic::catch_unwind( + panic::AssertUnwindSafe( + || { + func($( WasmExternType::from_native($x), )* ).report() + } + ) + ) { + Ok(Ok(returns)) => return returns.into_c_struct(), + Ok(Err(err)) => { + let b: Box<_> = err.into(); + b as Box + }, + Err(err) => err, + }; + + // At this point, there is an error that needs to + // be trapped. + unsafe { + (&*vmctx.module).runnable_module.do_early_trap(err) } - - NonNull::new(wrap::<$( $x, )* Rets, Trap, Self> as *mut vm::Func).unwrap() - } else { - assert_eq!( - mem::size_of::(), - mem::size_of::(), - "you cannot use a closure that captures state for `Func`." - ); - - NonNull::new(unsafe { - mem::transmute_copy::<_, *mut vm::Func>(self) - }).unwrap() } + + // Extract the captured environment of the imported + // function if any. + let func_env: Option> = + // `FN` is a function pointer, or a closure + // _without_ a captured environment. + if mem::size_of::() == 0 { + NonNull::new(&self as *const _ as *mut vm::FuncEnv) + } + // `FN` is a closure _with_ a captured + // environment. + else { + NonNull::new(Box::into_raw(Box::new(self))).map(NonNull::cast) + }; + + ( + NonNull::new(wrap::<$( $x, )* Rets, Trap, Self> as *mut vm::Func).unwrap(), + func_env + ) } } @@ -629,9 +746,9 @@ macro_rules! impl_traits { unsafe { <( $( $x ),* ) as WasmTypeList>::call( ( $( $x ),* ), - self.f, + self.func, self.inner, - self.ctx + self.vmctx ) } } @@ -669,8 +786,11 @@ where Inner: Kind, { fn to_export(&self) -> Export { - let func = unsafe { FuncPointer::new(self.f.as_ptr()) }; - let ctx = Context::Internal; + let func = unsafe { FuncPointer::new(self.func.as_ptr()) }; + let ctx = match self.func_env { + func_env @ Some(_) => Context::ExternalWithEnv(self.vmctx, func_env), + None => Context::Internal, + }; let signature = Arc::new(FuncSig::new(Args::types(), Rets::types())); Export::Function { diff --git a/lib/runtime-core/src/vm.rs b/lib/runtime-core/src/vm.rs index 0584633dac8..ea63c47fc97 100644 --- a/lib/runtime-core/src/vm.rs +++ b/lib/runtime-core/src/vm.rs @@ -523,24 +523,65 @@ impl Ctx { } } -enum InnerFunc {} -/// Used to provide type safety (ish) for passing around function pointers. -/// The typesystem ensures this cannot be dereferenced since an -/// empty enum cannot actually exist. +/// Represents a function pointer. It is mostly used in the +/// `typed_func` module within the `wrap` functions, to wrap imported +/// functions. +#[repr(transparent)] +pub struct Func(pub(self) *mut c_void); + +/// Represents a function environment pointer, like a captured +/// environment of a closure. It is mostly used in the `typed_func` +/// module within the `wrap` functions, to wrap imported functions. +#[repr(transparent)] +pub struct FuncEnv(pub(self) *mut c_void); + +/// Represents a function context. It is used by imported functions +/// only. +#[derive(Debug)] #[repr(C)] -pub struct Func(InnerFunc); +pub struct FuncCtx { + /// The `Ctx` pointer. + pub(crate) vmctx: NonNull, + + /// A pointer to the function environment. It is used by imported + /// functions only to store the pointer to the real host function, + /// whether it is a regular function, or a closure with or without + /// a captured environment. + pub(crate) func_env: Option>, +} + +impl FuncCtx { + /// Offset to `vmctx`. + pub fn offset_vmctx() -> u8 { + 0 * (mem::size_of::() as u8) + } + + /// Offset to `func_env`. + pub fn offset_func_env() -> u8 { + 1 * (mem::size_of::() as u8) + } + + /// Size of a `FuncCtx`. + pub fn size() -> u8 { + mem::size_of::() as u8 + } +} -/// An imported function, which contains the vmctx that owns this function. +/// An imported function is a function pointer associated to a +/// function context. #[derive(Debug, Clone)] #[repr(C)] pub struct ImportedFunc { /// Const pointer to `Func`. - pub func: *const Func, - /// Mutable pointer to `Ctx`. - pub vmctx: *mut Ctx, + pub(crate) func: *const Func, + + /// Mutable non-null pointer to `FuncCtx`. + pub(crate) func_ctx: NonNull, } -// manually implemented because ImportedFunc contains raw pointers directly; `Func` is marked Send (But `Ctx` actually isn't! (TODO: review this, shouldn't `Ctx` be Send?)) +// Manually implemented because ImportedFunc contains raw pointers +// directly; `Func` is marked Send (But `Ctx` actually isn't! (TODO: +// review this, shouldn't `Ctx` be Send?)) unsafe impl Send for ImportedFunc {} impl ImportedFunc { @@ -550,8 +591,8 @@ impl ImportedFunc { 0 * (mem::size_of::() as u8) } - /// Offset to vmctx. - pub fn offset_vmctx() -> u8 { + /// Offset to func_ctx. + pub fn offset_func_ctx() -> u8 { 1 * (mem::size_of::() as u8) } @@ -709,7 +750,9 @@ impl Anyfunc { #[cfg(test)] mod vm_offset_tests { - use super::{Anyfunc, Ctx, ImportedFunc, InternalCtx, LocalGlobal, LocalMemory, LocalTable}; + use super::{ + Anyfunc, Ctx, FuncCtx, ImportedFunc, InternalCtx, LocalGlobal, LocalMemory, LocalTable, + }; #[test] fn vmctx() { @@ -786,6 +829,19 @@ mod vm_offset_tests { ); } + #[test] + fn func_ctx() { + assert_eq!( + FuncCtx::offset_vmctx() as usize, + offset_of!(FuncCtx => vmctx).get_byte_offset(), + ); + + assert_eq!( + FuncCtx::offset_func_env() as usize, + offset_of!(FuncCtx => func_env).get_byte_offset(), + ); + } + #[test] fn imported_func() { assert_eq!( @@ -794,8 +850,8 @@ mod vm_offset_tests { ); assert_eq!( - ImportedFunc::offset_vmctx() as usize, - offset_of!(ImportedFunc => vmctx).get_byte_offset(), + ImportedFunc::offset_func_ctx() as usize, + offset_of!(ImportedFunc => func_ctx).get_byte_offset(), ); } diff --git a/lib/singlepass-backend/src/codegen_x64.rs b/lib/singlepass-backend/src/codegen_x64.rs index f4411a74ce4..7a71a289fe2 100644 --- a/lib/singlepass-backend/src/codegen_x64.rs +++ b/lib/singlepass-backend/src/codegen_x64.rs @@ -554,18 +554,30 @@ impl ModuleCodeGenerator // Emits a tail call trampoline that loads the address of the target import function // from Ctx and jumps to it. + let imported_funcs_addr = vm::Ctx::offset_imported_funcs(); + let imported_func = vm::ImportedFunc::size() as usize * id; + let imported_func_addr = imported_func + vm::ImportedFunc::offset_func() as usize; + let imported_func_ctx_addr = imported_func + vm::ImportedFunc::offset_func_ctx() as usize; + let imported_func_ctx_vmctx_addr = vm::FuncCtx::offset_vmctx() as usize; + a.emit_mov( Size::S64, - Location::Memory(GPR::RDI, vm::Ctx::offset_imported_funcs() as i32), + Location::Memory(GPR::RDI, imported_funcs_addr as i32), Location::GPR(GPR::RAX), ); a.emit_mov( Size::S64, - Location::Memory( - GPR::RAX, - (vm::ImportedFunc::size() as usize * id + vm::ImportedFunc::offset_func() as usize) - as i32, - ), + Location::Memory(GPR::RAX, imported_func_ctx_addr as i32), + Location::GPR(GPR::RDI), + ); + a.emit_mov( + Size::S64, + Location::Memory(GPR::RDI, imported_func_ctx_vmctx_addr as i32), + Location::GPR(GPR::RDI), + ); + a.emit_mov( + Size::S64, + Location::Memory(GPR::RAX, imported_func_addr as i32), Location::GPR(GPR::RAX), ); a.emit_jmp_location(Location::GPR(GPR::RAX)); diff --git a/lib/spectests/examples/test.rs b/lib/spectests/examples/test.rs index 8ce199695fc..006fc1397d2 100644 --- a/lib/spectests/examples/test.rs +++ b/lib/spectests/examples/test.rs @@ -1,5 +1,4 @@ use wabt::wat2wasm; -use wasmer_clif_backend::CraneliftCompiler; use wasmer_runtime_core::{backend::Compiler, import::ImportObject, Instance}; fn main() {