diff --git a/examples/metering.rs b/examples/metering.rs index 785f6df3a86..d98e30a114f 100644 --- a/examples/metering.rs +++ b/examples/metering.rs @@ -21,7 +21,10 @@ use wasmer::CompilerConfig; use wasmer::{imports, wat2wasm, Instance, Module, Store}; use wasmer_compiler_cranelift::Cranelift; use wasmer_engine_jit::JIT; -use wasmer_middlewares::metering::{get_remaining_points, set_remaining_points, Metering}; +use wasmer_middlewares::{ + metering::{get_remaining_points, set_remaining_points, MeteringPoints}, + Metering, +}; fn main() -> anyhow::Result<()> { // Let's declare the Wasm module. @@ -55,11 +58,10 @@ fn main() -> anyhow::Result<()> { // Now let's create our metering middleware. // - // `Metering` needs to be configured with a limit (the gas limit) and - // a cost function. + // `Metering` needs to be configured with a limit and a cost function. // // For each `Operator`, the metering middleware will call the cost - // function and subtract the cost from the gas. + // function and subtract the cost from the remaining points. let metering = Arc::new(Metering::new(10, cost_function)); let mut compiler_config = Cranelift::default(); compiler_config.push_middleware(metering); @@ -93,14 +95,17 @@ fn main() -> anyhow::Result<()> { println!("Calling `add_one` function once..."); add_one.call(1)?; - // As you can see here, after the first call we have 6 remaining gas points. + // As you can see here, after the first call we have 6 remaining points. // // This is correct, here are the details of how it has been computed: // * `local.get $value` is a `Operator::LocalGet` which costs 1 point; // * `i32.const` is a `Operator::I32Const` which costs 1 point; // * `i32.add` is a `Operator::I32Add` which costs 2 points. let remaining_points_after_first_call = get_remaining_points(&instance); - assert_eq!(remaining_points_after_first_call, 6); + assert_eq!( + remaining_points_after_first_call, + MeteringPoints::Remaining(6) + ); println!( "Remaining points after the first call: {:?}", @@ -110,18 +115,21 @@ fn main() -> anyhow::Result<()> { println!("Calling `add_one` function twice..."); add_one.call(1)?; - // We spent 4 more gas points with the second call. + // We spent 4 more points with the second call. // We have 2 remaining points. let remaining_points_after_second_call = get_remaining_points(&instance); - assert_eq!(remaining_points_after_second_call, 2); + assert_eq!( + remaining_points_after_second_call, + MeteringPoints::Remaining(2) + ); println!( "Remaining points after the second call: {:?}", remaining_points_after_second_call ); - // Because calling our `add_one` function consumes 4 gas points, - // calling it a third time will fail: we already consume 8 gas + // Because calling our `add_one` function consumes 4 points, + // calling it a third time will fail: we already consume 8 // points, there are only two remaining. println!("Calling `add_one` function a third time..."); match add_one.call(1) { @@ -132,27 +140,27 @@ fn main() -> anyhow::Result<()> { ); } Err(_) => { - println!("Calling `add_one` failed: not enough gas points remaining."); - } - } + println!("Calling `add_one` failed."); - // Becasue the previous call failed, it did not consume any gas point. - // We still have 2 remaining points. - let remaining_points_after_third_call = get_remaining_points(&instance); - assert_eq!(remaining_points_after_third_call, 2); + // Because the last needed more than the remaining points, we should have an error. + let remaining_points = get_remaining_points(&instance); - println!( - "Remaining points after third call: {:?}", - remaining_points_after_third_call - ); + match remaining_points { + MeteringPoints::Remaining(..) => { + bail!("No metering error: there are remaining points") + } + MeteringPoints::Exhausted => println!("Not enough points remaining"), + } + } + } // Now let's see how we can set a new limit... - println!("Set new remaining points points to 10"); + println!("Set new remaining points to 10"); let new_limit = 10; set_remaining_points(&instance, new_limit); let remaining_points = get_remaining_points(&instance); - assert_eq!(remaining_points, new_limit); + assert_eq!(remaining_points, MeteringPoints::Remaining(new_limit)); println!("Remaining points: {:?}", remaining_points); diff --git a/lib/middlewares/src/metering.rs b/lib/middlewares/src/metering.rs index f02ac1348ba..4f968ed709b 100644 --- a/lib/middlewares/src/metering.rs +++ b/lib/middlewares/src/metering.rs @@ -12,6 +12,34 @@ use wasmer::{ use wasmer_types::GlobalIndex; use wasmer_vm::ModuleInfo; +#[derive(Clone)] +struct MeteringGlobalIndexes(GlobalIndex, GlobalIndex); + +impl MeteringGlobalIndexes { + /// The global index in the current module for remaining points. + fn remaining_points(&self) -> GlobalIndex { + self.0 + } + + /// The global index in the current module for a boolean indicating whether points are exhausted + /// or not. + /// This boolean is represented as a i32 global: + /// * 0: there are remaining points + /// * 1: points have been exhausted + fn points_exhausted(&self) -> GlobalIndex { + self.1 + } +} + +impl fmt::Debug for MeteringGlobalIndexes { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MeteringGlobalIndexes") + .field("remaining_points", &self.remaining_points()) + .field("points_exhausted", &self.points_exhausted()) + .finish() + } +} + /// The module-level metering middleware. /// /// # Panic @@ -26,8 +54,8 @@ pub struct Metering u64 + Copy + Clone + Send + Sync> { /// Function that maps each operator to a cost in "points". cost_function: F, - /// The global index in the current module for remaining points. - remaining_points_index: Mutex>, + /// The global indexes for metering points. + global_indexes: Mutex>, } /// The function-level metering middleware. @@ -35,20 +63,30 @@ pub struct FunctionMetering u64 + Copy + Clone + Send + Sync /// Function that maps each operator to a cost in "points". cost_function: F, - /// The global index in the current module for remaining points. - remaining_points_index: GlobalIndex, + /// The global indexes for metering points. + global_indexes: MeteringGlobalIndexes, /// Accumulated cost of the current basic block. accumulated_cost: u64, } +#[derive(Debug, PartialEq)] +pub enum MeteringPoints { + /// The given number of metering points is left for the execution. + /// If the value is 0, all points are consumed but the execution was not terminated. + Remaining(u64), + /// The execution was terminated because the metering points were exhausted. + /// You can recover from this state by setting the points via `set_remaining_points` and restart the execution. + Exhausted, +} + impl u64 + Copy + Clone + Send + Sync> Metering { /// Creates a `Metering` middleware. pub fn new(initial_limit: u64, cost_function: F) -> Self { Self { initial_limit, cost_function, - remaining_points_index: Mutex::new(None), + global_indexes: Mutex::new(None), } } } @@ -58,7 +96,7 @@ impl u64 + Copy + Clone + Send + Sync> fmt::Debug for Meteri f.debug_struct("Metering") .field("initial_limit", &self.initial_limit) .field("cost_function", &"") - .field("remaining_points_index", &self.remaining_points_index) + .field("global_indexes", &self.global_indexes) .finish() } } @@ -70,33 +108,51 @@ impl u64 + Copy + Clone + Send + Sync + 'static> ModuleMiddl fn generate_function_middleware(&self, _: LocalFunctionIndex) -> Box { Box::new(FunctionMetering { cost_function: self.cost_function, - remaining_points_index: self.remaining_points_index.lock().unwrap().expect( - "Metering::generate_function_middleware: Remaining points index not set up.", - ), + global_indexes: self.global_indexes.lock().unwrap().clone().unwrap(), accumulated_cost: 0, }) } /// Transforms a `ModuleInfo` struct in-place. This is called before application on functions begins. fn transform_module_info(&self, module_info: &mut ModuleInfo) { - let mut remaining_points_index = self.remaining_points_index.lock().unwrap(); - if remaining_points_index.is_some() { + let mut global_indexes = self.global_indexes.lock().unwrap(); + + if global_indexes.is_some() { panic!("Metering::transform_module_info: Attempting to use a `Metering` middleware from multiple modules."); } // Append a global for remaining points and initialize it. - let global_index = module_info + let remaining_points_global_index = module_info .globals .push(GlobalType::new(Type::I64, Mutability::Var)); - *remaining_points_index = Some(global_index.clone()); + module_info .global_initializers .push(GlobalInit::I64Const(self.initial_limit as i64)); module_info.exports.insert( - "remaining_points".to_string(), - ExportIndex::Global(global_index), + "wasmer_metering_remaining_points".to_string(), + ExportIndex::Global(remaining_points_global_index), + ); + + // Append a global for the exhausted points boolean and initialize it. + let points_exhausted_global_index = module_info + .globals + .push(GlobalType::new(Type::I32, Mutability::Var)); + + module_info + .global_initializers + .push(GlobalInit::I32Const(0)); + + module_info.exports.insert( + "wasmer_metering_points_exhausted".to_string(), + ExportIndex::Global(points_exhausted_global_index), ); + + *global_indexes = Some(MeteringGlobalIndexes( + remaining_points_global_index, + points_exhausted_global_index, + )) } } @@ -104,7 +160,7 @@ impl u64 + Copy + Clone + Send + Sync> fmt::Debug for Functi fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FunctionMetering") .field("cost_function", &"") - .field("remaining_points_index", &self.remaining_points_index) + .field("global_indexes", &self.global_indexes) .finish() } } @@ -137,18 +193,20 @@ impl u64 + Copy + Clone + Send + Sync> FunctionMiddleware if self.accumulated_cost > 0 { state.extend(&[ // if unsigned(globals[remaining_points_index]) < unsigned(self.accumulated_cost) { throw(); } - Operator::GlobalGet { global_index: self.remaining_points_index.as_u32() }, + Operator::GlobalGet { global_index: self.global_indexes.remaining_points().as_u32() }, Operator::I64Const { value: self.accumulated_cost as i64 }, Operator::I64LtU, Operator::If { ty: WpTypeOrFuncType::Type(WpType::EmptyBlockType) }, - Operator::Unreachable, // FIXME: Signal the error properly. + Operator::I32Const { value: 1 }, + Operator::GlobalSet { global_index: self.global_indexes.points_exhausted().as_u32() }, + Operator::Unreachable, Operator::End, // globals[remaining_points_index] -= self.accumulated_cost; - Operator::GlobalGet { global_index: self.remaining_points_index.as_u32() }, + Operator::GlobalGet { global_index: self.global_indexes.remaining_points().as_u32() }, Operator::I64Const { value: self.accumulated_cost as i64 }, Operator::I64Sub, - Operator::GlobalSet { global_index: self.remaining_points_index.as_u32() }, + Operator::GlobalSet { global_index: self.global_indexes.remaining_points().as_u32() }, ]); self.accumulated_cost = 0; @@ -171,14 +229,28 @@ impl u64 + Copy + Clone + Send + Sync> FunctionMiddleware /// /// The instance Module must have been processed with the [`Metering`] middleware /// at compile time, otherwise this will panic. -pub fn get_remaining_points(instance: &Instance) -> u64 { - instance +pub fn get_remaining_points(instance: &Instance) -> MeteringPoints { + let exhausted: i32 = instance .exports - .get_global("remaining_points") - .expect("Can't get `remaining_points` from Instance") + .get_global("wasmer_metering_points_exhausted") + .expect("Can't get `wasmer_metering_points_exhausted` from Instance") .get() .try_into() - .expect("`remaining_points` from Instance has wrong type") + .expect("`wasmer_metering_points_exhausted` from Instance has wrong type"); + + if exhausted > 0 { + return MeteringPoints::Exhausted; + } + + let points = instance + .exports + .get_global("wasmer_metering_remaining_points") + .expect("Can't get `wasmer_metering_remaining_points` from Instance") + .get() + .try_into() + .expect("`wasmer_metering_remaining_points` from Instance has wrong type"); + + MeteringPoints::Remaining(points) } /// Set the provided remaining points in an `Instance`. @@ -193,10 +265,17 @@ pub fn get_remaining_points(instance: &Instance) -> u64 { pub fn set_remaining_points(instance: &Instance, points: u64) { instance .exports - .get_global("remaining_points") - .expect("Can't get `remaining_points` from Instance") + .get_global("wasmer_metering_remaining_points") + .expect("Can't get `wasmer_metering_remaining_points` from Instance") .set(points.into()) - .expect("Can't set `remaining_points` in Instance"); + .expect("Can't set `wasmer_metering_remaining_points` in Instance"); + + instance + .exports + .get_global("wasmer_metering_points_exhausted") + .expect("Can't get `wasmer_metering_points_exhausted` from Instance") + .set(0i32.into()) + .expect("Can't set `wasmer_metering_points_exhausted` in Instance"); } #[cfg(test)] @@ -240,7 +319,10 @@ mod tests { // Instantiate let instance = Instance::new(&module, &imports! {}).unwrap(); - assert_eq!(get_remaining_points(&instance), 10); + assert_eq!( + get_remaining_points(&instance), + MeteringPoints::Remaining(10) + ); // First call // @@ -255,17 +337,21 @@ mod tests { .native::() .unwrap(); add_one.call(1).unwrap(); - assert_eq!(get_remaining_points(&instance), 6); + assert_eq!( + get_remaining_points(&instance), + MeteringPoints::Remaining(6) + ); // Second call add_one.call(1).unwrap(); - assert_eq!(get_remaining_points(&instance), 2); + assert_eq!( + get_remaining_points(&instance), + MeteringPoints::Remaining(2) + ); // Third call fails due to limit assert!(add_one.call(1).is_err()); - // TODO: what do we expect now? 0 or 2? See https://github.com/wasmerio/wasmer/issues/1931 - // assert_eq!(metering.get_remaining_points(&instance), 2); - // assert_eq!(metering.get_remaining_points(&instance), 0); + assert_eq!(get_remaining_points(&instance), MeteringPoints::Exhausted); } #[test] @@ -278,7 +364,10 @@ mod tests { // Instantiate let instance = Instance::new(&module, &imports! {}).unwrap(); - assert_eq!(get_remaining_points(&instance), 10); + assert_eq!( + get_remaining_points(&instance), + MeteringPoints::Remaining(10) + ); let add_one = instance .exports .get_function("add_one") @@ -291,10 +380,31 @@ mod tests { // Ensure we can use the new points now add_one.call(1).unwrap(); - assert_eq!(get_remaining_points(&instance), 8); + assert_eq!( + get_remaining_points(&instance), + MeteringPoints::Remaining(8) + ); + add_one.call(1).unwrap(); - assert_eq!(get_remaining_points(&instance), 4); + assert_eq!( + get_remaining_points(&instance), + MeteringPoints::Remaining(4) + ); + add_one.call(1).unwrap(); - assert_eq!(get_remaining_points(&instance), 0); + assert_eq!( + get_remaining_points(&instance), + MeteringPoints::Remaining(0) + ); + + assert!(add_one.call(1).is_err()); + assert_eq!(get_remaining_points(&instance), MeteringPoints::Exhausted); + + // Add some points for another call + set_remaining_points(&instance, 4); + assert_eq!( + get_remaining_points(&instance), + MeteringPoints::Remaining(4) + ); } }