diff --git a/lib/clif-backend/src/signal/mod.rs b/lib/clif-backend/src/signal/mod.rs index 04b97e8325c..766ad75fbfb 100644 --- a/lib/clif-backend/src/signal/mod.rs +++ b/lib/clif-backend/src/signal/mod.rs @@ -2,12 +2,13 @@ use crate::relocation::{TrapData, TrapSink}; use crate::trampoline::Trampolines; use hashbrown::HashSet; use libc::c_void; -use std::{any::Any, cell::Cell, sync::Arc}; +use std::{any::Any, cell::Cell, ptr::NonNull, sync::Arc}; use wasmer_runtime_core::{ backend::{ProtectedCaller, Token, UserTrapper}, error::RuntimeResult, export::Context, module::{ExportIndex, ModuleInfo, ModuleInner}, + typed_func::{Wasm, WasmTrapInfo}, types::{FuncIndex, FuncSig, LocalOrImport, SigIndex, Type, Value}, vm::{self, ImportBacking}, }; @@ -148,6 +149,46 @@ impl ProtectedCaller for Caller { .collect()) } + fn get_wasm_trampoline(&self, module: &ModuleInner, sig_index: SigIndex) -> Option { + unsafe extern "C" fn invoke( + trampoline: unsafe extern "C" fn(*mut vm::Ctx, NonNull, *const u64, *mut u64), + ctx: *mut vm::Ctx, + func: NonNull, + args: *const u64, + rets: *mut u64, + trap_info: *mut WasmTrapInfo, + invoke_env: Option>, + ) -> bool { + let handler_data = &*invoke_env.unwrap().cast().as_ptr(); + + #[cfg(not(target_os = "windows"))] + let res = call_protected(handler_data, || unsafe { + // Leap of faith. + trampoline(ctx, func, args, rets); + }) + .is_ok(); + + // the trampoline is called from C on windows + #[cfg(target_os = "windows")] + let res = call_protected(handler_data, trampoline, ctx, func, args, rets).is_ok(); + + res + } + + let trampoline = self + .trampolines + .lookup(sig_index) + .expect("that trampoline doesn't exist"); + + Some(unsafe { + Wasm::from_raw_parts( + trampoline, + invoke, + Some(NonNull::from(&self.handler_data).cast()), + ) + }) + } + fn get_early_trapper(&self) -> Box { Box::new(Trapper) } @@ -157,7 +198,7 @@ fn get_func_from_index<'a>( module: &'a ModuleInner, import_backing: &ImportBacking, func_index: FuncIndex, -) -> (*const vm::Func, Context, &'a FuncSig, SigIndex) { +) -> (NonNull, Context, &'a FuncSig, SigIndex) { let sig_index = *module .info .func_assoc @@ -170,14 +211,13 @@ fn get_func_from_index<'a>( .func_resolver .get(&module, local_func_index) .expect("broken invariant, func resolver not synced with module.exports") - .cast() - .as_ptr() as *const _, + .cast(), Context::Internal, ), LocalOrImport::Import(imported_func_index) => { let imported_func = import_backing.imported_func(imported_func_index); ( - imported_func.func as *const _, + NonNull::new(imported_func.func as *mut _).unwrap(), Context::External(imported_func.vmctx), ) } diff --git a/lib/clif-backend/src/signal/windows.rs b/lib/clif-backend/src/signal/windows.rs index 99145869e37..358b7323663 100644 --- a/lib/clif-backend/src/signal/windows.rs +++ b/lib/clif-backend/src/signal/windows.rs @@ -3,7 +3,7 @@ use crate::signal::HandlerData; use crate::trampoline::Trampoline; use std::cell::Cell; use std::ffi::c_void; -use std::ptr; +use std::ptr::{self, NonNull}; use wasmer_runtime_core::error::{RuntimeError, RuntimeResult}; use wasmer_runtime_core::vm::Ctx; use wasmer_runtime_core::vm::Func; @@ -25,7 +25,7 @@ pub fn call_protected( handler_data: &HandlerData, trampoline: Trampoline, ctx: *mut Ctx, - func: *const Func, + func: NonNull, param_vec: *const u64, return_vec: *mut u64, ) -> RuntimeResult<()> { diff --git a/lib/clif-backend/src/trampoline.rs b/lib/clif-backend/src/trampoline.rs index b20109ef04e..696a0b624da 100644 --- a/lib/clif-backend/src/trampoline.rs +++ b/lib/clif-backend/src/trampoline.rs @@ -7,7 +7,7 @@ use cranelift_codegen::{ }; use hashbrown::HashMap; use std::ffi::c_void; -use std::{iter, mem}; +use std::{iter, mem, ptr::NonNull}; use wasmer_runtime_core::{ backend::sys::{Memory, Protect}, module::{ExportIndex, ModuleInfo}, @@ -23,8 +23,7 @@ impl RelocSink for NullRelocSink { fn reloc_jt(&mut self, _: u32, _: Reloc, _: ir::JumpTable) {} } -pub type Trampoline = - unsafe extern "C" fn(*mut vm::Ctx, *const vm::Func, *const u64, *mut u64) -> c_void; +pub type Trampoline = unsafe extern "C" fn(*mut vm::Ctx, NonNull, *const u64, *mut u64); pub struct Trampolines { memory: Memory, diff --git a/lib/dynasm-backend/src/codegen_x64.rs b/lib/dynasm-backend/src/codegen_x64.rs index ef056a69564..30b5ea2bcb9 100644 --- a/lib/dynasm-backend/src/codegen_x64.rs +++ b/lib/dynasm-backend/src/codegen_x64.rs @@ -18,6 +18,7 @@ use wasmer_runtime_core::{ memory::MemoryType, module::{ModuleInfo, ModuleInner}, structures::{Map, TypedIndex}, + typed_func::Wasm, types::{ FuncIndex, FuncSig, ImportedMemoryIndex, LocalFuncIndex, LocalGlobalIndex, LocalMemoryIndex, LocalOrImport, MemoryIndex, SigIndex, Type, Value, @@ -459,6 +460,10 @@ impl ProtectedCaller for X64ExecutionContext { }) } + fn get_wasm_trampoline(&self, _module: &ModuleInner, _sig_index: SigIndex) -> Option { + unimplemented!() + } + fn get_early_trapper(&self) -> Box { pub struct Trapper; diff --git a/lib/emscripten/src/varargs.rs b/lib/emscripten/src/varargs.rs index cd9073cb953..3775d102cdc 100644 --- a/lib/emscripten/src/varargs.rs +++ b/lib/emscripten/src/varargs.rs @@ -20,4 +20,11 @@ impl VarArgs { unsafe impl WasmExternType for VarArgs { const TYPE: Type = Type::I32; + + fn to_bits(self) -> u64 { + self.pointer as u64 + } + fn from_bits(n: u64) -> Self { + Self { pointer: n as u32 } + } } diff --git a/lib/llvm-backend/cpp/object_loader.hh b/lib/llvm-backend/cpp/object_loader.hh index d22acb919b0..0d6bcc603e4 100644 --- a/lib/llvm-backend/cpp/object_loader.hh +++ b/lib/llvm-backend/cpp/object_loader.hh @@ -5,14 +5,16 @@ #include #include -typedef enum { +typedef enum +{ PROTECT_NONE, PROTECT_READ, PROTECT_READ_WRITE, PROTECT_READ_EXECUTE, } mem_protect_t; -typedef enum { +typedef enum +{ RESULT_OK, RESULT_ALLOCATE_FAILURE, RESULT_PROTECT_FAILURE, @@ -20,16 +22,17 @@ typedef enum { RESULT_OBJECT_LOAD_FAILURE, } result_t; -typedef result_t (*alloc_memory_t)(size_t size, mem_protect_t protect, uint8_t** ptr_out, size_t* size_out); -typedef result_t (*protect_memory_t)(uint8_t* ptr, size_t size, mem_protect_t protect); -typedef result_t (*dealloc_memory_t)(uint8_t* ptr, size_t size); -typedef uintptr_t (*lookup_vm_symbol_t)(const char* name_ptr, size_t length); +typedef result_t (*alloc_memory_t)(size_t size, mem_protect_t protect, uint8_t **ptr_out, size_t *size_out); +typedef result_t (*protect_memory_t)(uint8_t *ptr, size_t size, mem_protect_t protect); +typedef result_t (*dealloc_memory_t)(uint8_t *ptr, size_t size); +typedef uintptr_t (*lookup_vm_symbol_t)(const char *name_ptr, size_t length); typedef void (*fde_visitor_t)(uint8_t *fde); typedef result_t (*visit_fde_t)(uint8_t *fde, size_t size, fde_visitor_t visitor); -typedef void (*trampoline_t)(void*, void*, void*, void*); +typedef void (*trampoline_t)(void *, void *, void *, void *); -typedef struct { +typedef struct +{ /* Memory management. */ alloc_memory_t alloc_memory; protect_memory_t protect_memory; @@ -40,32 +43,40 @@ typedef struct { visit_fde_t visit_fde; } callbacks_t; -struct WasmException { -public: +struct WasmException +{ + public: virtual std::string description() const noexcept = 0; }; -struct UncatchableException : WasmException { -public: - virtual std::string description() const noexcept override { +struct UncatchableException : WasmException +{ + public: + virtual std::string description() const noexcept override + { return "Uncatchable exception"; } }; -struct UserException : UncatchableException { -public: +struct UserException : UncatchableException +{ + public: UserException(std::string msg) : msg(msg) {} - virtual std::string description() const noexcept override { + virtual std::string description() const noexcept override + { return std::string("user exception: ") + msg; } -private: + + private: std::string msg; }; -struct WasmTrap : UncatchableException { -public: - enum Type { +struct WasmTrap : UncatchableException +{ + public: + enum Type + { Unreachable = 0, IncorrectCallIndirectSignature = 1, MemoryOutOfBounds = 2, @@ -76,49 +87,54 @@ public: WasmTrap(Type type) : type(type) {} - virtual std::string description() const noexcept override { + virtual std::string description() const noexcept override + { std::ostringstream ss; ss << "WebAssembly trap:" << '\n' << " - type: " << type << '\n'; - + return ss.str(); } Type type; -private: - friend std::ostream& operator<<(std::ostream& out, const Type& ty) { - switch (ty) { - case Type::Unreachable: - out << "unreachable"; - break; - case Type::IncorrectCallIndirectSignature: - out << "incorrect call_indirect signature"; - break; - case Type::MemoryOutOfBounds: - out << "memory access out-of-bounds"; - break; - case Type::CallIndirectOOB: - out << "call_indirect out-of-bounds"; - break; - case Type::IllegalArithmetic: - out << "illegal arithmetic operation"; - break; - case Type::Unknown: - default: - out << "unknown"; - break; + private: + friend std::ostream &operator<<(std::ostream &out, const Type &ty) + { + switch (ty) + { + case Type::Unreachable: + out << "unreachable"; + break; + case Type::IncorrectCallIndirectSignature: + out << "incorrect call_indirect signature"; + break; + case Type::MemoryOutOfBounds: + out << "memory access out-of-bounds"; + break; + case Type::CallIndirectOOB: + out << "call_indirect out-of-bounds"; + break; + case Type::IllegalArithmetic: + out << "illegal arithmetic operation"; + break; + case Type::Unknown: + default: + out << "unknown"; + break; } return out; } }; -struct CatchableException : WasmException { -public: +struct CatchableException : WasmException +{ + public: CatchableException(uint32_t type_id, uint32_t value_num) : type_id(type_id), value_num(value_num) {} - virtual std::string description() const noexcept override { + virtual std::string description() const noexcept override + { return "catchable exception"; } @@ -126,23 +142,26 @@ public: uint64_t values[1]; }; -struct WasmModule { -public: +struct WasmModule +{ + public: WasmModule( const uint8_t *object_start, size_t object_size, - callbacks_t callbacks - ); + callbacks_t callbacks); void *get_func(llvm::StringRef name) const; -private: + + private: std::unique_ptr memory_manager; std::unique_ptr object_file; std::unique_ptr runtime_dyld; }; -extern "C" { - result_t module_load(const uint8_t* mem_ptr, size_t mem_size, callbacks_t callbacks, WasmModule** module_out) { +extern "C" +{ + result_t module_load(const uint8_t *mem_ptr, size_t mem_size, callbacks_t callbacks, WasmModule **module_out) + { *module_out = new WasmModule(mem_ptr, mem_size, callbacks); return RESULT_OK; @@ -152,34 +171,44 @@ extern "C" { throw WasmTrap(ty); } - void module_delete(WasmModule* module) { + void module_delete(WasmModule *module) + { delete module; } bool invoke_trampoline( trampoline_t trampoline, - void* ctx, - void* func, - void* params, - void* results, - WasmTrap::Type* trap_out - ) throw() { - try { + void *ctx, + void *func, + void *params, + void *results, + WasmTrap::Type *trap_out, + void *invoke_env) throw() + { + try + { trampoline(ctx, func, params, results); return true; - } catch(const WasmTrap& e) { + } + catch (const WasmTrap &e) + { *trap_out = e.type; return false; - } catch(const WasmException& e) { + } + catch (const WasmException &e) + { *trap_out = WasmTrap::Type::Unknown; return false; - } catch (...) { + } + catch (...) + { *trap_out = WasmTrap::Type::Unknown; return false; } } - void* get_func_symbol(WasmModule* module, const char* name) { + void *get_func_symbol(WasmModule *module, const char *name) + { return module->get_func(llvm::StringRef(name)); } } \ No newline at end of file diff --git a/lib/llvm-backend/src/backend.rs b/lib/llvm-backend/src/backend.rs index 5a692c40e52..1a82cfaf266 100644 --- a/lib/llvm-backend/src/backend.rs +++ b/lib/llvm-backend/src/backend.rs @@ -11,7 +11,7 @@ use libc::{ }; use std::{ any::Any, - ffi::CString, + ffi::{c_void, CString}, mem, ptr::{self, NonNull}, slice, str, @@ -23,6 +23,7 @@ use wasmer_runtime_core::{ export::Context, module::{ModuleInfo, ModuleInner}, structures::TypedIndex, + typed_func::{Wasm, WasmTrapInfo}, types::{FuncIndex, FuncSig, LocalFuncIndex, LocalOrImport, SigIndex, Type, Value}, vm::{self, ImportBacking}, vmcalls, @@ -54,17 +55,6 @@ enum LLVMResult { OBJECT_LOAD_FAILURE, } -#[allow(dead_code)] -#[repr(C)] -enum WasmTrapType { - Unreachable = 0, - IncorrectCallIndirectSignature = 1, - MemoryOutOfBounds = 2, - CallIndirectOOB = 3, - IllegalArithmetic = 4, - Unknown, -} - #[repr(C)] struct Callbacks { alloc_memory: extern "C" fn(usize, MemProtect, &mut *mut u8, &mut usize) -> LLVMResult, @@ -87,13 +77,15 @@ extern "C" { fn throw_trap(ty: i32); + #[allow(improper_ctypes)] fn invoke_trampoline( - trampoline: unsafe extern "C" fn(*mut vm::Ctx, *const vm::Func, *const u64, *mut u64), + trampoline: unsafe extern "C" fn(*mut vm::Ctx, NonNull, *const u64, *mut u64), vmctx_ptr: *mut vm::Ctx, - func_ptr: *const vm::Func, + func_ptr: NonNull, params: *const u64, results: *mut u64, - trap_out: *mut WasmTrapType, + trap_out: *mut WasmTrapInfo, + invoke_env: Option>, ) -> bool; } @@ -360,7 +352,12 @@ impl ProtectedCaller for LLVMProtectedCaller { let mut return_vec = vec![0; signature.returns().len()]; - let trampoline: unsafe extern "C" fn(*mut vm::Ctx, *const vm::Func, *const u64, *mut u64) = unsafe { + let trampoline: unsafe extern "C" fn( + *mut vm::Ctx, + NonNull, + *const u64, + *mut u64, + ) = unsafe { let name = if cfg!(target_os = "macos") { format!("_trmp{}", sig_index.index()) } else { @@ -374,7 +371,7 @@ impl ProtectedCaller for LLVMProtectedCaller { mem::transmute(symbol) }; - let mut trap_out = WasmTrapType::Unknown; + let mut trap_out = WasmTrapInfo::Unknown; // Here we go. let success = unsafe { @@ -385,6 +382,7 @@ impl ProtectedCaller for LLVMProtectedCaller { param_vec.as_ptr(), return_vec.as_mut_ptr(), &mut trap_out, + None, ) }; @@ -400,29 +398,35 @@ impl ProtectedCaller for LLVMProtectedCaller { }) .collect()) } else { - Err(match trap_out { - WasmTrapType::Unreachable => RuntimeError::Trap { - msg: "unreachable".into(), - }, - WasmTrapType::IncorrectCallIndirectSignature => RuntimeError::Trap { - msg: "uncorrect call_indirect signature".into(), - }, - WasmTrapType::MemoryOutOfBounds => RuntimeError::Trap { - msg: "memory out-of-bounds access".into(), - }, - WasmTrapType::CallIndirectOOB => RuntimeError::Trap { - msg: "call_indirect out-of-bounds".into(), - }, - WasmTrapType::IllegalArithmetic => RuntimeError::Trap { - msg: "illegal arithmetic operation".into(), - }, - WasmTrapType::Unknown => RuntimeError::Trap { - msg: "unknown trap".into(), - }, + Err(RuntimeError::Trap { + msg: trap_out.to_string().into(), }) } } + fn get_wasm_trampoline(&self, _module: &ModuleInner, sig_index: SigIndex) -> Option { + let trampoline: unsafe extern "C" fn( + *mut vm::Ctx, + NonNull, + *const u64, + *mut u64, + ) = unsafe { + let name = if cfg!(target_os = "macos") { + format!("_trmp{}", sig_index.index()) + } else { + format!("trmp{}", sig_index.index()) + }; + + let c_str = CString::new(name).unwrap(); + let symbol = get_func_symbol(self.module, c_str.as_ptr()); + assert!(!symbol.is_null()); + + mem::transmute(symbol) + }; + + Some(unsafe { Wasm::from_raw_parts(trampoline, invoke_trampoline, None) }) + } + fn get_early_trapper(&self) -> Box { Box::new(Placeholder) } @@ -438,7 +442,7 @@ fn get_func_from_index<'a>( module: &'a ModuleInner, import_backing: &ImportBacking, func_index: FuncIndex, -) -> (*const vm::Func, Context, &'a FuncSig, SigIndex) { +) -> (NonNull, Context, &'a FuncSig, SigIndex) { let sig_index = *module .info .func_assoc @@ -451,14 +455,13 @@ fn get_func_from_index<'a>( .func_resolver .get(&module, local_func_index) .expect("broken invariant, func resolver not synced with module.exports") - .cast() - .as_ptr() as *const _, + .cast(), Context::Internal, ), LocalOrImport::Import(imported_func_index) => { let imported_func = import_backing.imported_func(imported_func_index); ( - imported_func.func as *const _, + NonNull::new(imported_func.func as *mut _).unwrap(), Context::External(imported_func.vmctx), ) } diff --git a/lib/runtime-core/src/backend.rs b/lib/runtime-core/src/backend.rs index 94c5c87e3ec..1e528651d92 100644 --- a/lib/runtime-core/src/backend.rs +++ b/lib/runtime-core/src/backend.rs @@ -3,7 +3,8 @@ use crate::{ error::CompileResult, error::RuntimeResult, module::ModuleInner, - types::{FuncIndex, LocalFuncIndex, Value}, + typed_func::Wasm, + types::{FuncIndex, LocalFuncIndex, SigIndex, Value}, vm, }; @@ -85,6 +86,10 @@ pub trait ProtectedCaller: Send + Sync { /// /// The existance of the Token parameter ensures that this can only be called from /// within the runtime crate. + /// + /// TODO(lachlan): Now that `get_wasm_trampoline` exists, `ProtectedCaller::call` + /// can be removed. That should speed up calls a little bit, since sanity checks + /// would only occur once. fn call( &self, module: &ModuleInner, @@ -95,6 +100,11 @@ pub trait ProtectedCaller: Send + Sync { _: Token, ) -> RuntimeResult>; + /// A wasm trampoline contains the necesarry data to dynamically call an exported wasm function. + /// Given a particular signature index, we are returned a trampoline that is matched with that + /// signature and an invoke function that can call the trampoline. + fn get_wasm_trampoline(&self, module: &ModuleInner, sig_index: SigIndex) -> Option; + fn get_early_trapper(&self) -> Box; } diff --git a/lib/runtime-core/src/instance.rs b/lib/runtime-core/src/instance.rs index 956b95d7b5b..a88e51e98af 100644 --- a/lib/runtime-core/src/instance.rs +++ b/lib/runtime-core/src/instance.rs @@ -9,11 +9,11 @@ use crate::{ module::{ExportIndex, Module, ModuleInner}, sig_registry::SigRegistry, table::Table, - typed_func::{Func, Safe, WasmTypeList}, + typed_func::{Func, Wasm, WasmTypeList}, types::{FuncIndex, FuncSig, GlobalIndex, LocalOrImport, MemoryIndex, TableIndex, Value}, vm, }; -use std::{mem, sync::Arc}; +use std::{mem, ptr::NonNull, sync::Arc}; pub(crate) struct InstanceInner { #[allow(dead_code)] @@ -107,7 +107,7 @@ impl Instance { /// # Ok(()) /// # } /// ``` - pub fn func(&self, name: &str) -> ResolveResult> + pub fn func(&self, name: &str) -> ResolveResult> where Args: WasmTypeList, Rets: WasmTypeList, @@ -145,20 +145,26 @@ impl Instance { } }; + let func_wasm_inner = self + .module + .protected_caller + .get_wasm_trampoline(&self.module, sig_index) + .unwrap(); + let func_ptr = match func_index.local_or_import(&self.module.info) { LocalOrImport::Local(local_func_index) => self .module .func_resolver .get(&self.module, local_func_index) - .unwrap() - .as_ptr(), - LocalOrImport::Import(import_func_index) => { - self.inner.import_backing.vm_functions[import_func_index].func - } + .unwrap(), + LocalOrImport::Import(import_func_index) => NonNull::new( + self.inner.import_backing.vm_functions[import_func_index].func as *mut _, + ) + .unwrap(), }; - let typed_func: Func = - unsafe { Func::new_from_ptr(func_ptr as _, ctx) }; + let typed_func: Func = + unsafe { Func::from_raw_parts(func_wasm_inner, func_ptr, ctx) }; Ok(typed_func) } else { diff --git a/lib/runtime-core/src/lib.rs b/lib/runtime-core/src/lib.rs index 7042983d213..ef79f6cc8e2 100644 --- a/lib/runtime-core/src/lib.rs +++ b/lib/runtime-core/src/lib.rs @@ -23,7 +23,7 @@ mod sig_registry; pub mod structures; mod sys; pub mod table; -mod typed_func; +pub mod typed_func; pub mod types; pub mod units; pub mod vm; diff --git a/lib/runtime-core/src/typed_func.rs b/lib/runtime-core/src/typed_func.rs index 4191f347eec..2ca67dba535 100644 --- a/lib/runtime-core/src/typed_func.rs +++ b/lib/runtime-core/src/typed_func.rs @@ -4,26 +4,112 @@ use crate::{ export::{Context, Export, FuncPointer}, import::IsExport, types::{FuncSig, Type, WasmExternType}, - vm::Ctx, + vm::{self, Ctx}, +}; +use std::{ + any::Any, + cell::UnsafeCell, + ffi::c_void, + fmt, + marker::PhantomData, + mem, panic, + ptr::{self, NonNull}, + sync::Arc, }; -use std::{any::Any, cell::UnsafeCell, marker::PhantomData, mem, panic, ptr, sync::Arc}; thread_local! { pub static EARLY_TRAPPER: UnsafeCell>> = UnsafeCell::new(None); } -pub trait Safeness {} -pub struct Safe; -pub struct Unsafe; -impl Safeness for Safe {} -impl Safeness for Unsafe {} +#[repr(C)] +pub enum WasmTrapInfo { + Unreachable = 0, + IncorrectCallIndirectSignature = 1, + MemoryOutOfBounds = 2, + CallIndirectOOB = 3, + IllegalArithmetic = 4, + Unknown, +} + +impl fmt::Display for WasmTrapInfo { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{}", + match self { + WasmTrapInfo::Unreachable => "unreachable", + WasmTrapInfo::IncorrectCallIndirectSignature => { + "incorrect `call_indirect` signature" + } + WasmTrapInfo::MemoryOutOfBounds => "memory out-of-bounds access", + WasmTrapInfo::CallIndirectOOB => "`call_indirect` out-of-bounds", + WasmTrapInfo::IllegalArithmetic => "illegal arithmetic operation", + WasmTrapInfo::Unknown => "unknown", + } + ) + } +} + +/// This is just an empty trait to constrict that types that +/// can be put into the third/fourth (depending if you include lifetimes) +/// of the `Func` struct. +pub trait Kind {} + +pub type Trampoline = unsafe extern "C" fn(*mut Ctx, NonNull, *const u64, *mut u64); +pub type Invoke = unsafe extern "C" fn( + Trampoline, + *mut Ctx, + NonNull, + *const u64, + *mut u64, + *mut WasmTrapInfo, + Option>, +) -> bool; + +/// TODO(lachlan): Naming TBD. +/// This contains the trampoline and invoke functions for a specific signature, +/// as well as the environment that the invoke function may or may not require. +#[derive(Copy, Clone)] +pub struct Wasm { + trampoline: Trampoline, + invoke: Invoke, + invoke_env: Option>, +} + +impl Wasm { + pub unsafe fn from_raw_parts( + trampoline: Trampoline, + invoke: Invoke, + invoke_env: Option>, + ) -> Self { + Self { + trampoline, + invoke, + invoke_env, + } + } +} + +/// This type, as part of the `Func` type signature, represents a function that is created +/// by the host. +pub struct Host(()); +impl Kind for Wasm {} +impl Kind for Host {} pub trait WasmTypeList { type CStruct; + type RetArray: AsMut<[u64]>; + fn from_ret_array(array: Self::RetArray) -> Self; + fn empty_ret_array() -> Self::RetArray; fn from_c_struct(c_struct: Self::CStruct) -> Self; fn into_c_struct(self) -> Self::CStruct; fn types() -> &'static [Type]; - unsafe fn call(self, f: *const (), ctx: *mut Ctx) -> Rets + unsafe fn call( + self, + f: NonNull, + wasm: Wasm, + ctx: *mut Ctx, + ) -> Result where Rets: WasmTypeList; } @@ -33,7 +119,7 @@ where Args: WasmTypeList, Rets: WasmTypeList, { - fn to_raw(&self) -> *const (); + fn to_raw(&self) -> NonNull; } pub trait TrapEarly @@ -71,19 +157,25 @@ where // Func::new(f) // } -pub struct Func<'a, Args = (), Rets = (), Safety: Safeness = Safe> { - f: *const (), +pub struct Func<'a, Args = (), Rets = (), Inner: Kind = Wasm> { + inner: Inner, + f: NonNull, ctx: *mut Ctx, - _phantom: PhantomData<(&'a (), Safety, Args, Rets)>, + _phantom: PhantomData<(&'a (), Args, Rets)>, } -impl<'a, Args, Rets> Func<'a, Args, Rets, Safe> +impl<'a, Args, Rets> Func<'a, Args, Rets, Wasm> where Args: WasmTypeList, Rets: WasmTypeList, { - pub(crate) unsafe fn new_from_ptr(f: *const (), ctx: *mut Ctx) -> Func<'a, Args, Rets, Safe> { + pub(crate) unsafe fn from_raw_parts( + inner: Wasm, + f: NonNull, + ctx: *mut Ctx, + ) -> Func<'a, Args, Rets, Wasm> { Func { + inner, f, ctx, _phantom: PhantomData, @@ -91,16 +183,17 @@ where } } -impl<'a, Args, Rets> Func<'a, Args, Rets, Unsafe> +impl<'a, Args, Rets> Func<'a, Args, Rets, Host> where Args: WasmTypeList, Rets: WasmTypeList, { - pub fn new(f: F) -> Func<'a, Args, Rets, Unsafe> + pub fn new(f: F) -> Func<'a, Args, Rets, Host> where F: ExternalFunction, { Func { + inner: Host(()), f: f.to_raw(), ctx: ptr::null_mut(), _phantom: PhantomData, @@ -108,11 +201,11 @@ where } } -impl<'a, Args, Rets, Safety> Func<'a, Args, Rets, Safety> +impl<'a, Args, Rets, Inner> Func<'a, Args, Rets, Inner> where Args: WasmTypeList, Rets: WasmTypeList, - Safety: Safeness, + Inner: Kind, { pub fn params(&self) -> &'static [Type] { Args::types() @@ -124,6 +217,13 @@ where impl WasmTypeList for (A,) { type CStruct = S1; + type RetArray = [u64; 1]; + fn from_ret_array(array: Self::RetArray) -> Self { + (WasmExternType::from_bits(array[0]),) + } + fn empty_ret_array() -> Self::RetArray { + [0u64] + } fn from_c_struct(c_struct: Self::CStruct) -> Self { let S1(a) = c_struct; (a,) @@ -137,19 +237,46 @@ impl WasmTypeList for (A,) { &[A::TYPE] } #[allow(non_snake_case)] - unsafe fn call(self, f: *const (), ctx: *mut Ctx) -> Rets { - let f: extern "C" fn(*mut Ctx, A) -> Rets = mem::transmute(f); + unsafe fn call( + self, + f: NonNull, + wasm: Wasm, + ctx: *mut Ctx, + ) -> Result { + // type Trampoline = extern "C" fn(*mut Ctx, *const c_void, *const u64, *mut u64); + // type Invoke = extern "C" fn(Trampoline, *mut Ctx, *const c_void, *const u64, *mut u64, &mut WasmTrapInfo) -> bool; + let (a,) = self; - f(ctx, a) + let args = [a.to_bits()]; + let mut rets = Rets::empty_ret_array(); + let mut trap = WasmTrapInfo::Unknown; + + if (wasm.invoke)( + wasm.trampoline, + ctx, + f, + args.as_ptr(), + rets.as_mut().as_mut_ptr(), + &mut trap, + wasm.invoke_env, + ) { + Ok(Rets::from_ret_array(rets)) + } else { + Err(trap) + } } } -impl<'a, A: WasmExternType, Rets> Func<'a, (A,), Rets, Safe> +impl<'a, A: WasmExternType, Rets> Func<'a, (A,), Rets, Wasm> where Rets: WasmTypeList, { pub fn call(&self, a: A) -> Result { - Ok(unsafe { ::call(a, self.f, self.ctx) }) + unsafe { ::call(a, self.f, self.inner, self.ctx) }.map_err(|e| { + RuntimeError::Trap { + msg: e.to_string().into(), + } + }) } } @@ -160,6 +287,15 @@ macro_rules! impl_traits { impl< $( $x: WasmExternType, )* > WasmTypeList for ( $( $x ),* ) { type CStruct = $struct_name<$( $x ),*>; + type RetArray = [u64; count_idents!( $( $x ),* )]; + fn from_ret_array(array: Self::RetArray) -> Self { + #[allow(non_snake_case)] + let [ $( $x ),* ] = array; + ( $( WasmExternType::from_bits($x) ),* ) + } + fn empty_ret_array() -> Self::RetArray { + [0; count_idents!( $( $x ),* )] + } fn from_c_struct(c_struct: Self::CStruct) -> Self { #[allow(non_snake_case)] let $struct_name ( $( $x ),* ) = c_struct; @@ -174,18 +310,33 @@ macro_rules! impl_traits { &[$( $x::TYPE, )*] } #[allow(non_snake_case)] - unsafe fn call(self, f: *const (), ctx: *mut Ctx) -> Rets { - let f: extern fn(*mut Ctx $( ,$x )*) -> Rets::CStruct = mem::transmute(f); + unsafe fn call(self, f: NonNull, wasm: Wasm, ctx: *mut Ctx) -> Result { + // type Trampoline = extern "C" fn(*mut Ctx, *const c_void, *const u64, *mut u64); + // type Invoke = extern "C" fn(Trampoline, *mut Ctx, *const c_void, *const u64, *mut u64, &mut WasmTrapInfo) -> bool; + #[allow(unused_parens)] let ( $( $x ),* ) = self; - let c_struct = f(ctx $( ,$x )*); - Rets::from_c_struct(c_struct) + let args = [ $( $x.to_bits() ),* ]; + let mut rets = Rets::empty_ret_array(); + let mut trap = WasmTrapInfo::Unknown; + + if (wasm.invoke)(wasm.trampoline, ctx, f, args.as_ptr(), rets.as_mut().as_mut_ptr(), &mut trap, wasm.invoke_env) { + Ok(Rets::from_ret_array(rets)) + } else { + Err(trap) + } + + // let f: extern fn(*mut Ctx $( ,$x )*) -> Rets::CStruct = mem::transmute(f); + // #[allow(unused_parens)] + // let ( $( $x ),* ) = self; + // let c_struct = f(ctx $( ,$x )*); + // Rets::from_c_struct(c_struct) } } impl< $( $x: WasmExternType, )* Rets: WasmTypeList, Trap: TrapEarly, FN: Fn( &mut Ctx $( ,$x )* ) -> Trap> ExternalFunction<($( $x ),*), Rets> for FN { #[allow(non_snake_case)] - fn to_raw(&self) -> *const () { + fn to_raw(&self) -> NonNull { assert_eq!(mem::size_of::(), 0, "you cannot use a closure that captures state for `Func`."); extern fn wrap<$( $x: WasmExternType, )* Rets: WasmTypeList, Trap: TrapEarly, FN: Fn( &mut Ctx $( ,$x )* ) -> Trap>( ctx: &mut Ctx $( ,$x: $x )* ) -> Rets::CStruct { @@ -209,23 +360,36 @@ macro_rules! impl_traits { } } - wrap::<$( $x, )* Rets, Trap, Self> as *const () + NonNull::new(wrap::<$( $x, )* Rets, Trap, Self> as *mut vm::Func).unwrap() } } - impl<'a, $( $x: WasmExternType, )* Rets> Func<'a, ( $( $x ),* ), Rets, Safe> + impl<'a, $( $x: WasmExternType, )* Rets> Func<'a, ( $( $x ),* ), Rets, Wasm> where Rets: WasmTypeList, { #[allow(non_snake_case)] pub fn call(&self, $( $x: $x, )* ) -> Result { #[allow(unused_parens)] - Ok(unsafe { <( $( $x ),* ) as WasmTypeList>::call(( $($x),* ), self.f, self.ctx) }) + unsafe { <( $( $x ),* ) as WasmTypeList>::call(( $($x),* ), self.f, self.inner, self.ctx) }.map_err(|e| { + RuntimeError::Trap { + msg: e.to_string().into(), + } + }) } } }; } +macro_rules! count_idents { + ( $($idents:ident),* ) => {{ + #[allow(dead_code, non_camel_case_types)] + enum Idents { $($idents,)* __CountIdentsLast } + const COUNT: usize = Idents::__CountIdentsLast as usize; + COUNT + }}; +} + impl_traits!([C] S0,); impl_traits!([transparent] S1, A); impl_traits!([C] S2, A, B); @@ -240,14 +404,14 @@ impl_traits!([C] S10, A, B, C, D, E, F, G, H, I, J); impl_traits!([C] S11, A, B, C, D, E, F, G, H, I, J, K); impl_traits!([C] S12, A, B, C, D, E, F, G, H, I, J, K, L); -impl<'a, Args, Rets, Safety> IsExport for Func<'a, Args, Rets, Safety> +impl<'a, Args, Rets, Inner> IsExport for Func<'a, Args, Rets, Inner> where Args: WasmTypeList, Rets: WasmTypeList, - Safety: Safeness, + Inner: Kind, { fn to_export(&self) -> Export { - let func = unsafe { FuncPointer::new(self.f as _) }; + let func = unsafe { FuncPointer::new(self.f.as_ptr()) }; let ctx = Context::Internal; let signature = Arc::new(FuncSig::new(Args::types(), Rets::types())); diff --git a/lib/runtime-core/src/types.rs b/lib/runtime-core/src/types.rs index 966bd4e3927..79ef5da7f9c 100644 --- a/lib/runtime-core/src/types.rs +++ b/lib/runtime-core/src/types.rs @@ -76,37 +76,99 @@ where Self: Sized, { const TYPE: Type; + fn to_bits(self) -> u64; + fn from_bits(n: u64) -> Self; } unsafe impl WasmExternType for i8 { const TYPE: Type = Type::I32; + fn to_bits(self) -> u64 { + self as u64 + } + fn from_bits(n: u64) -> Self { + n as _ + } } unsafe impl WasmExternType for u8 { const TYPE: Type = Type::I32; + fn to_bits(self) -> u64 { + self as u64 + } + fn from_bits(n: u64) -> Self { + n as _ + } } unsafe impl WasmExternType for i16 { const TYPE: Type = Type::I32; + fn to_bits(self) -> u64 { + self as u64 + } + fn from_bits(n: u64) -> Self { + n as _ + } } unsafe impl WasmExternType for u16 { const TYPE: Type = Type::I32; + fn to_bits(self) -> u64 { + self as u64 + } + fn from_bits(n: u64) -> Self { + n as _ + } } unsafe impl WasmExternType for i32 { const TYPE: Type = Type::I32; + fn to_bits(self) -> u64 { + self as u64 + } + fn from_bits(n: u64) -> Self { + n as _ + } } unsafe impl WasmExternType for u32 { const TYPE: Type = Type::I32; + fn to_bits(self) -> u64 { + self as u64 + } + fn from_bits(n: u64) -> Self { + n as _ + } } unsafe impl WasmExternType for i64 { const TYPE: Type = Type::I64; + fn to_bits(self) -> u64 { + self as u64 + } + fn from_bits(n: u64) -> Self { + n as _ + } } unsafe impl WasmExternType for u64 { const TYPE: Type = Type::I64; + fn to_bits(self) -> u64 { + self + } + fn from_bits(n: u64) -> Self { + n + } } unsafe impl WasmExternType for f32 { const TYPE: Type = Type::F32; + fn to_bits(self) -> u64 { + self.to_bits() as u64 + } + fn from_bits(n: u64) -> Self { + f32::from_bits(n as u32) + } } unsafe impl WasmExternType for f64 { const TYPE: Type = Type::F64; + fn to_bits(self) -> u64 { + self.to_bits() + } + fn from_bits(n: u64) -> Self { + f64::from_bits(n) + } } // pub trait IntegerAtomic diff --git a/lib/runtime-core/src/vm.rs b/lib/runtime-core/src/vm.rs index 2802123f160..3b2ca9938dd 100644 --- a/lib/runtime-core/src/vm.rs +++ b/lib/runtime-core/src/vm.rs @@ -549,7 +549,8 @@ mod vm_ctx_tests { }; use crate::cache::Error as CacheError; use crate::error::RuntimeResult; - use crate::types::{FuncIndex, LocalFuncIndex, Value}; + use crate::typed_func::Wasm; + use crate::types::{FuncIndex, LocalFuncIndex, SigIndex, Value}; use hashbrown::HashMap; use std::ptr::NonNull; struct Placeholder; @@ -574,6 +575,13 @@ mod vm_ctx_tests { ) -> RuntimeResult> { Ok(vec![]) } + fn get_wasm_trampoline( + &self, + _module: &ModuleInner, + _sig_index: SigIndex, + ) -> Option { + unimplemented!() + } fn get_early_trapper(&self) -> Box { unimplemented!() } diff --git a/lib/runtime/examples/call.rs b/lib/runtime/examples/call.rs index dbd29c7e851..104fa9692f7 100644 --- a/lib/runtime/examples/call.rs +++ b/lib/runtime/examples/call.rs @@ -5,7 +5,10 @@ use wabt::wat2wasm; static WAT: &'static str = r#" (module (type (;0;) (func (result i32))) + (import "env" "do_panic" (func $do_panic (type 0))) (func $dbz (result i32) + call $do_panic + drop i32.const 42 i32.const 0 i32.div_u @@ -33,6 +36,10 @@ fn foobar(ctx: &mut Ctx) -> i32 { 42 } +fn do_panic(ctx: &mut Ctx) -> Result { + Err("error".to_string()) +} + fn main() -> Result<(), error::Error> { let wasm = get_wasm(); @@ -46,11 +53,15 @@ fn main() -> Result<(), error::Error> { // }; println!("instantiating"); - let instance = module.instantiate(&imports! {})?; + let instance = module.instantiate(&imports! { + "env" => { + "do_panic" => Func::new(do_panic), + }, + })?; - let foo = instance.dyn_func("dbz")?; + let foo: Func<(), i32> = instance.func("dbz")?; - let result = foo.call(&[]); + let result = foo.call(); println!("result: {:?}", result); diff --git a/lib/wasi/src/ptr.rs b/lib/wasi/src/ptr.rs index 56abb3b2d1f..da892d1bf69 100644 --- a/lib/wasi/src/ptr.rs +++ b/lib/wasi/src/ptr.rs @@ -72,6 +72,16 @@ impl WasmPtr { unsafe impl WasmExternType for WasmPtr { const TYPE: Type = Type::I32; + + fn to_bits(self) -> u64 { + self.offset as u64 + } + fn from_bits(n: u64) -> Self { + Self { + offset: n as u32, + _phantom: PhantomData, + } + } } unsafe impl ValueType for WasmPtr {} diff --git a/lib/win-exception-handler/src/exception_handling.rs b/lib/win-exception-handler/src/exception_handling.rs index 966432a7011..ea36333ab0c 100644 --- a/lib/win-exception-handler/src/exception_handling.rs +++ b/lib/win-exception-handler/src/exception_handling.rs @@ -1,7 +1,8 @@ use std::ffi::c_void; +use std::ptr::NonNull; use wasmer_runtime_core::vm::{Ctx, Func}; -type Trampoline = unsafe extern "C" fn(*mut Ctx, *const Func, *const u64, *mut u64) -> c_void; +type Trampoline = unsafe extern "C" fn(*mut Ctx, NonNull, *const u64, *mut u64); type CallProtectedResult = Result<(), CallProtectedData>; #[repr(C)] @@ -16,7 +17,7 @@ extern "C" { pub fn __call_protected( trampoline: Trampoline, ctx: *mut Ctx, - func: *const Func, + func: NonNull, param_vec: *const u64, return_vec: *mut u64, out_result: *mut CallProtectedData, @@ -26,7 +27,7 @@ extern "C" { pub fn _call_protected( trampoline: Trampoline, ctx: *mut Ctx, - func: *const Func, + func: NonNull, param_vec: *const u64, return_vec: *mut u64, ) -> CallProtectedResult {