diff --git a/CHANGELOG.md b/CHANGELOG.md index 4306f4afe76..3b1eba184e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ ### Changed +* [#1941](https://github.com/wasmerio/wasmer/pull/1941) Turn `get_remaining_points`/`set_remaining_points` of the `Metering` middleware into free functions to allow using them in an ahead-of-time compilation setup * [#1955](https://github.com/wasmerio/wasmer/pull/1955) Set `jit` as a default feature of the `wasmer-wasm-c-api` crate * [#1944](https://github.com/wasmerio/wasmer/pull/1944) Require `WasmerEnv` to be `Send + Sync` even in dynamic functions. diff --git a/examples/metering.rs b/examples/metering.rs index ed588dda372..785f6df3a86 100644 --- a/examples/metering.rs +++ b/examples/metering.rs @@ -21,7 +21,7 @@ use wasmer::CompilerConfig; use wasmer::{imports, wat2wasm, Instance, Module, Store}; use wasmer_compiler_cranelift::Cranelift; use wasmer_engine_jit::JIT; -use wasmer_middlewares::Metering; +use wasmer_middlewares::metering::{get_remaining_points, set_remaining_points, Metering}; fn main() -> anyhow::Result<()> { // Let's declare the Wasm module. @@ -62,7 +62,7 @@ fn main() -> anyhow::Result<()> { // function and subtract the cost from the gas. let metering = Arc::new(Metering::new(10, cost_function)); let mut compiler_config = Cranelift::default(); - compiler_config.push_middleware(metering.clone()); + compiler_config.push_middleware(metering); // Create a Store. // @@ -99,7 +99,7 @@ fn main() -> anyhow::Result<()> { // * `local.get $value` is a `Operator::LocalGet` which costs 1 point; // * `i32.const` is a `Operator::I32Const` which costs 1 point; // * `i32.add` is a `Operator::I32Add` which costs 2 points. - let remaining_points_after_first_call = metering.get_remaining_points(&instance); + let remaining_points_after_first_call = get_remaining_points(&instance); assert_eq!(remaining_points_after_first_call, 6); println!( @@ -112,7 +112,7 @@ fn main() -> anyhow::Result<()> { // We spent 4 more gas points with the second call. // We have 2 remaining points. - let remaining_points_after_second_call = metering.get_remaining_points(&instance); + let remaining_points_after_second_call = get_remaining_points(&instance); assert_eq!(remaining_points_after_second_call, 2); println!( @@ -138,7 +138,7 @@ fn main() -> anyhow::Result<()> { // Becasue the previous call failed, it did not consume any gas point. // We still have 2 remaining points. - let remaining_points_after_third_call = metering.get_remaining_points(&instance); + let remaining_points_after_third_call = get_remaining_points(&instance); assert_eq!(remaining_points_after_third_call, 2); println!( @@ -149,9 +149,9 @@ fn main() -> anyhow::Result<()> { // Now let's see how we can set a new limit... println!("Set new remaining points points to 10"); let new_limit = 10; - metering.set_remaining_points(&instance, new_limit); + set_remaining_points(&instance, new_limit); - let remaining_points = metering.get_remaining_points(&instance); + let remaining_points = get_remaining_points(&instance); assert_eq!(remaining_points, new_limit); println!("Remaining points: {:?}", remaining_points); diff --git a/lib/middlewares/src/lib.rs b/lib/middlewares/src/lib.rs index 884fcf82242..561366c22c7 100644 --- a/lib/middlewares/src/lib.rs +++ b/lib/middlewares/src/lib.rs @@ -1,3 +1,5 @@ pub mod metering; +// The most commonly used symbol are exported at top level of the module. Others are available +// via modules, e.g. `wasmer_middlewares::metering::get_remaining_points` pub use metering::Metering; diff --git a/lib/middlewares/src/metering.rs b/lib/middlewares/src/metering.rs index 335026c20f6..a0a2919d31f 100644 --- a/lib/middlewares/src/metering.rs +++ b/lib/middlewares/src/metering.rs @@ -1,6 +1,7 @@ //! `metering` is a middleware for tracking how many operators are executed in total //! and putting a limit on the total number of operators executed. +use std::convert::TryInto; use std::fmt; use std::sync::Mutex; use wasmer::wasmparser::{ @@ -8,7 +9,7 @@ use wasmer::wasmparser::{ }; use wasmer::{ ExportIndex, FunctionMiddleware, GlobalInit, GlobalType, Instance, LocalFunctionIndex, - MiddlewareReaderState, ModuleMiddleware, Mutability, Type, Value, + MiddlewareReaderState, ModuleMiddleware, Mutability, Type, }; use wasmer_types::GlobalIndex; use wasmer_vm::ModuleInfo; @@ -52,30 +53,6 @@ impl u64 + Copy + Clone + Send + Sync> Metering { remaining_points_index: Mutex::new(None), } } - - /// Get the remaining points in an Instance. - /// - /// Important: the instance Module must been processed with the `Metering` middleware. - pub fn get_remaining_points(&self, instance: &Instance) -> u64 { - instance - .exports - .get_global("remaining_points") - .expect("Can't get `remaining_points` from Instance") - .get() - .unwrap_i64() as _ - } - - /// Set the provided remaining points in an Instance. - /// - /// Important: the instance Module must been processed with the `Metering` middleware. - pub fn set_remaining_points(&self, instance: &Instance, points: u64) { - instance - .exports - .get_global("remaining_points") - .expect("Can't get `remaining_points` from Instance") - .set(Value::I64(points as _)) - .expect("Can't set `remaining_points` in Instance"); - } } impl u64 + Copy + Clone + Send + Sync> fmt::Debug for Metering { @@ -186,3 +163,140 @@ impl u64 + Copy + Clone + Send + Sync> FunctionMiddleware Ok(()) } } + +/// Get the remaining points in an `Instance`. +/// +/// This can be used in a headless engine after an ahead-of-time compilation +/// as all required state lives in the instance. +/// +/// # Panic +/// +/// The instance Module must have been processed with the [`Metering`] middleware +/// at compile time, otherwise this will panic. +pub fn get_remaining_points(instance: &Instance) -> u64 { + instance + .exports + .get_global("remaining_points") + .expect("Can't get `remaining_points` from Instance") + .get() + .try_into() + .expect("`remaining_points` from Instance has wrong type") +} + +/// Set the provided remaining points in an `Instance`. +/// +/// This can be used in a headless engine after an ahead-of-time compilation +/// as all required state lives in the instance. +/// +/// # Panic +/// +/// The instance Module must have been processed with the [`Metering`] middleware +/// at compile time, otherwise this will panic. +pub fn set_remaining_points(instance: &Instance, points: u64) { + instance + .exports + .get_global("remaining_points") + .expect("Can't get `remaining_points` from Instance") + .set(points.into()) + .expect("Can't set `remaining_points` in Instance"); +} + +#[cfg(test)] +mod tests { + use super::*; + + use std::sync::Arc; + use wasmer::{imports, wat2wasm, CompilerConfig, Cranelift, Module, Store, JIT}; + + fn cost_function(operator: &Operator) -> u64 { + match operator { + Operator::LocalGet { .. } | Operator::I32Const { .. } => 1, + Operator::I32Add { .. } => 2, + _ => 0, + } + } + + fn bytecode() -> Vec { + wat2wasm( + br#" + (module + (type $add_t (func (param i32) (result i32))) + (func $add_one_f (type $add_t) (param $value i32) (result i32) + local.get $value + i32.const 1 + i32.add) + (export "add_one" (func $add_one_f))) + "#, + ) + .unwrap() + .into() + } + + #[test] + fn get_remaining_points_works() { + let metering = Arc::new(Metering::new(10, cost_function)); + let mut compiler_config = Cranelift::default(); + compiler_config.push_middleware(metering.clone()); + let store = Store::new(&JIT::new(compiler_config).engine()); + let module = Module::new(&store, bytecode()).unwrap(); + + // Instantiate + let instance = Instance::new(&module, &imports! {}).unwrap(); + assert_eq!(get_remaining_points(&instance), 10); + + // First call + // + // Calling add_one costs 4 points. Here are the details of how it has been computed: + // * `local.get $value` is a `Operator::LocalGet` which costs 1 point; + // * `i32.const` is a `Operator::I32Const` which costs 1 point; + // * `i32.add` is a `Operator::I32Add` which costs 2 points. + let add_one = instance + .exports + .get_function("add_one") + .unwrap() + .native::() + .unwrap(); + add_one.call(1).unwrap(); + assert_eq!(get_remaining_points(&instance), 6); + + // Second call + add_one.call(1).unwrap(); + assert_eq!(get_remaining_points(&instance), 2); + + // Third call fails due to limit + assert!(add_one.call(1).is_err()); + // TODO: what do we expect now? 0 or 2? See https://github.com/wasmerio/wasmer/issues/1931 + // assert_eq!(metering.get_remaining_points(&instance), 2); + // assert_eq!(metering.get_remaining_points(&instance), 0); + } + + #[test] + fn set_remaining_points_works() { + let metering = Arc::new(Metering::new(10, cost_function)); + let mut compiler_config = Cranelift::default(); + compiler_config.push_middleware(metering.clone()); + let store = Store::new(&JIT::new(compiler_config).engine()); + let module = Module::new(&store, bytecode()).unwrap(); + + // Instantiate + let instance = Instance::new(&module, &imports! {}).unwrap(); + assert_eq!(get_remaining_points(&instance), 10); + let add_one = instance + .exports + .get_function("add_one") + .unwrap() + .native::() + .unwrap(); + + // Increase a bit to have enough for 3 calls + set_remaining_points(&instance, 12); + + // Ensure we can use the new points now + add_one.call(1).unwrap(); + assert_eq!(get_remaining_points(&instance), 8); + add_one.call(1).unwrap(); + assert_eq!(get_remaining_points(&instance), 4); + add_one.call(1).unwrap(); + assert_eq!(get_remaining_points(&instance), 0); + } +}