Skip to content

Commit

Permalink
feat: Add an error code on the metering middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
jubianchi committed Dec 23, 2020
1 parent fd38d73 commit b382cfa
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 48 deletions.
54 changes: 31 additions & 23 deletions examples/metering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ use wasmer::CompilerConfig;
use wasmer::{imports, wat2wasm, Instance, Module, Store};
use wasmer_compiler_cranelift::Cranelift;
use wasmer_engine_jit::JIT;
use wasmer_middlewares::metering::{get_remaining_points, set_remaining_points, Metering};
use wasmer_middlewares::{
metering::{get_remaining_points, set_remaining_points, MeteringPoints},
Metering,
};

fn main() -> anyhow::Result<()> {
// Let's declare the Wasm module.
Expand Down Expand Up @@ -55,11 +58,10 @@ fn main() -> anyhow::Result<()> {

// Now let's create our metering middleware.
//
// `Metering` needs to be configured with a limit (the gas limit) and
// a cost function.
// `Metering` needs to be configured with a limit and a cost function.
//
// For each `Operator`, the metering middleware will call the cost
// function and subtract the cost from the gas.
// function and subtract the cost from the remaining points.
let metering = Arc::new(Metering::new(10, cost_function));
let mut compiler_config = Cranelift::default();
compiler_config.push_middleware(metering);
Expand Down Expand Up @@ -93,14 +95,17 @@ fn main() -> anyhow::Result<()> {
println!("Calling `add_one` function once...");
add_one.call(1)?;

// As you can see here, after the first call we have 6 remaining gas points.
// As you can see here, after the first call we have 6 remaining points.
//
// This is correct, here are the details of how it has been computed:
// * `local.get $value` is a `Operator::LocalGet` which costs 1 point;
// * `i32.const` is a `Operator::I32Const` which costs 1 point;
// * `i32.add` is a `Operator::I32Add` which costs 2 points.
let remaining_points_after_first_call = get_remaining_points(&instance);
assert_eq!(remaining_points_after_first_call, 6);
assert_eq!(
remaining_points_after_first_call,
MeteringPoints::Remaining(6)
);

println!(
"Remaining points after the first call: {:?}",
Expand All @@ -110,18 +115,21 @@ fn main() -> anyhow::Result<()> {
println!("Calling `add_one` function twice...");
add_one.call(1)?;

// We spent 4 more gas points with the second call.
// We spent 4 more points with the second call.
// We have 2 remaining points.
let remaining_points_after_second_call = get_remaining_points(&instance);
assert_eq!(remaining_points_after_second_call, 2);
assert_eq!(
remaining_points_after_second_call,
MeteringPoints::Remaining(2)
);

println!(
"Remaining points after the second call: {:?}",
remaining_points_after_second_call
);

// Because calling our `add_one` function consumes 4 gas points,
// calling it a third time will fail: we already consume 8 gas
// Because calling our `add_one` function consumes 4 points,
// calling it a third time will fail: we already consume 8
// points, there are only two remaining.
println!("Calling `add_one` function a third time...");
match add_one.call(1) {
Expand All @@ -132,27 +140,27 @@ fn main() -> anyhow::Result<()> {
);
}
Err(_) => {
println!("Calling `add_one` failed: not enough gas points remaining.");
}
}
println!("Calling `add_one` failed.");

// Becasue the previous call failed, it did not consume any gas point.
// We still have 2 remaining points.
let remaining_points_after_third_call = get_remaining_points(&instance);
assert_eq!(remaining_points_after_third_call, 2);
// Because the last needed more than the remaining points, we should have an error.
let remaining_points = get_remaining_points(&instance);

println!(
"Remaining points after third call: {:?}",
remaining_points_after_third_call
);
match remaining_points {
MeteringPoints::Remaining(..) => {
bail!("No metering error: there are remaining points")
}
MeteringPoints::Exhausted => println!("Not enough points remaining"),
}
}
}

// Now let's see how we can set a new limit...
println!("Set new remaining points points to 10");
println!("Set new remaining points to 10");
let new_limit = 10;
set_remaining_points(&instance, new_limit);

let remaining_points = get_remaining_points(&instance);
assert_eq!(remaining_points, new_limit);
assert_eq!(remaining_points, MeteringPoints::Remaining(new_limit));

println!("Remaining points: {:?}", remaining_points);

Expand Down
151 changes: 126 additions & 25 deletions lib/middlewares/src/metering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ pub struct Metering<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> {

/// The global index in the current module for remaining points.
remaining_points_index: Mutex<Option<GlobalIndex>>,

/// The global index in the current module for a boolean indicating whether points are exhausted
/// or not.
/// This boolean is represented as a i32 global:
/// * 0: there are remaining points
/// * 1: points have been exhausted
points_exhausted_index: Mutex<Option<GlobalIndex>>,
}

/// The function-level metering middleware.
Expand All @@ -38,17 +45,35 @@ pub struct FunctionMetering<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync
/// The global index in the current module for remaining points.
remaining_points_index: GlobalIndex,

/// The global index in the current module for a boolean indicating whether points are exhausted
/// or not.
/// This boolean is represented as a i32 global:
/// * 0: there are remaining points
/// * 1: points have been exhausted
points_exhausted_index: GlobalIndex,

/// Accumulated cost of the current basic block.
accumulated_cost: u64,
}

#[derive(Debug, PartialEq)]
pub enum MeteringPoints {
/// The given number of metering points is left for the execution.
/// If the value is 0, all points are consumed but the execution was not terminated.
Remaining(u64),
/// The execution was terminated because the metering points were exhausted.
/// You can recover from this state by setting the points via `set_remaining_points` and restart the execution.
Exhausted,
}

impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> Metering<F> {
/// Creates a `Metering` middleware.
pub fn new(initial_limit: u64, cost_function: F) -> Self {
Self {
initial_limit,
cost_function,
remaining_points_index: Mutex::new(None),
points_exhausted_index: Mutex::new(None),
}
}
}
Expand All @@ -59,6 +84,7 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> fmt::Debug for Meteri
.field("initial_limit", &self.initial_limit)
.field("cost_function", &"<function>")
.field("remaining_points_index", &self.remaining_points_index)
.field("points_exhausted_index", &self.points_exhausted_index)
.finish()
}
}
Expand All @@ -71,7 +97,10 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync + 'static> ModuleMiddl
Box::new(FunctionMetering {
cost_function: self.cost_function,
remaining_points_index: self.remaining_points_index.lock().unwrap().expect(
"Metering::generate_function_middleware: Remaining points index not set up.",
"Metering::generate_function_middleware: remaining_points_index not set up.",
),
points_exhausted_index: self.points_exhausted_index.lock().unwrap().expect(
"Metering::generate_function_middleware: points_exhausted_index not set up.",
),
accumulated_cost: 0,
})
Expand All @@ -80,22 +109,39 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync + 'static> ModuleMiddl
/// Transforms a `ModuleInfo` struct in-place. This is called before application on functions begins.
fn transform_module_info(&self, module_info: &mut ModuleInfo) {
let mut remaining_points_index = self.remaining_points_index.lock().unwrap();
if remaining_points_index.is_some() {
let mut points_exhausted_index = self.points_exhausted_index.lock().unwrap();

if remaining_points_index.is_some() || points_exhausted_index.is_some() {
panic!("Metering::transform_module_info: Attempting to use a `Metering` middleware from multiple modules.");
}

// Append a global for remaining points and initialize it.
let global_index = module_info
let remaining_points_global_index = module_info
.globals
.push(GlobalType::new(Type::I64, Mutability::Var));
*remaining_points_index = Some(global_index.clone());
*remaining_points_index = Some(remaining_points_global_index.clone());

module_info
.global_initializers
.push(GlobalInit::I64Const(self.initial_limit as i64));

module_info.exports.insert(
"remaining_points".to_string(),
ExportIndex::Global(global_index),
"wasmer_metering_remaining_points".to_string(),
ExportIndex::Global(remaining_points_global_index),
);

// Append a global for the exhausted points boolean and initialize it.
let points_exhausted_global_index = module_info
.globals
.push(GlobalType::new(Type::I32, Mutability::Var));
*points_exhausted_index = Some(points_exhausted_global_index.clone());
module_info
.global_initializers
.push(GlobalInit::I32Const(0));

module_info.exports.insert(
"wasmer_metering_points_exhausted".to_string(),
ExportIndex::Global(points_exhausted_global_index),
);
}
}
Expand All @@ -105,6 +151,7 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> fmt::Debug for Functi
f.debug_struct("FunctionMetering")
.field("cost_function", &"<function>")
.field("remaining_points_index", &self.remaining_points_index)
.field("points_exhausted_index", &self.points_exhausted_index)
.finish()
}
}
Expand Down Expand Up @@ -141,7 +188,9 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> FunctionMiddleware
Operator::I64Const { value: self.accumulated_cost as i64 },
Operator::I64LtU,
Operator::If { ty: WpTypeOrFuncType::Type(WpType::EmptyBlockType) },
Operator::Unreachable, // FIXME: Signal the error properly.
Operator::I32Const { value: 1 },
Operator::GlobalSet { global_index: self.points_exhausted_index.as_u32() },
Operator::Unreachable,
Operator::End,

// globals[remaining_points_index] -= self.accumulated_cost;
Expand Down Expand Up @@ -171,14 +220,28 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> FunctionMiddleware
///
/// The instance Module must have been processed with the [`Metering`] middleware
/// at compile time, otherwise this will panic.
pub fn get_remaining_points(instance: &Instance) -> u64 {
instance
pub fn get_remaining_points(instance: &Instance) -> MeteringPoints {
let exhausted: i32 = instance
.exports
.get_global("wasmer_metering_points_exhausted")
.expect("Can't get `wasmer_metering_points_exhausted` from Instance")
.get()
.try_into()
.expect("`wasmer_metering_points_exhausted` from Instance has wrong type");

if exhausted > 0 {
return MeteringPoints::Exhausted;
}

let points = instance
.exports
.get_global("remaining_points")
.expect("Can't get `remaining_points` from Instance")
.get_global("wasmer_metering_remaining_points")
.expect("Can't get `wasmer_metering_remaining_points` from Instance")
.get()
.try_into()
.expect("`remaining_points` from Instance has wrong type")
.expect("`wasmer_metering_remaining_points` from Instance has wrong type");

MeteringPoints::Remaining(points)
}

/// Set the provided remaining points in an `Instance`.
Expand All @@ -193,10 +256,17 @@ pub fn get_remaining_points(instance: &Instance) -> u64 {
pub fn set_remaining_points(instance: &Instance, points: u64) {
instance
.exports
.get_global("remaining_points")
.expect("Can't get `remaining_points` from Instance")
.get_global("wasmer_metering_remaining_points")
.expect("Can't get `wasmer_metering_remaining_points` from Instance")
.set(points.into())
.expect("Can't set `remaining_points` in Instance");
.expect("Can't set `wasmer_metering_remaining_points` in Instance");

instance
.exports
.get_global("wasmer_metering_points_exhausted")
.expect("Can't get `wasmer_metering_points_exhausted` from Instance")
.set(0i32.into())
.expect("Can't set `wasmer_metering_points_exhausted` in Instance");
}

#[cfg(test)]
Expand Down Expand Up @@ -240,7 +310,10 @@ mod tests {

// Instantiate
let instance = Instance::new(&module, &imports! {}).unwrap();
assert_eq!(get_remaining_points(&instance), 10);
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(10)
);

// First call
//
Expand All @@ -255,17 +328,21 @@ mod tests {
.native::<i32, i32>()
.unwrap();
add_one.call(1).unwrap();
assert_eq!(get_remaining_points(&instance), 6);
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(6)
);

// Second call
add_one.call(1).unwrap();
assert_eq!(get_remaining_points(&instance), 2);
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(2)
);

// Third call fails due to limit
assert!(add_one.call(1).is_err());
// TODO: what do we expect now? 0 or 2? See https://github.com/wasmerio/wasmer/issues/1931
// assert_eq!(metering.get_remaining_points(&instance), 2);
// assert_eq!(metering.get_remaining_points(&instance), 0);
assert_eq!(get_remaining_points(&instance), MeteringPoints::Exhausted);
}

#[test]
Expand All @@ -278,7 +355,10 @@ mod tests {

// Instantiate
let instance = Instance::new(&module, &imports! {}).unwrap();
assert_eq!(get_remaining_points(&instance), 10);
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(10)
);
let add_one = instance
.exports
.get_function("add_one")
Expand All @@ -291,10 +371,31 @@ mod tests {

// Ensure we can use the new points now
add_one.call(1).unwrap();
assert_eq!(get_remaining_points(&instance), 8);
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(8)
);

add_one.call(1).unwrap();
assert_eq!(get_remaining_points(&instance), 4);
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(4)
);

add_one.call(1).unwrap();
assert_eq!(get_remaining_points(&instance), 0);
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(0)
);

assert!(add_one.call(1).is_err());
assert_eq!(get_remaining_points(&instance), MeteringPoints::Exhausted);

// Add some points for another call
set_remaining_points(&instance, 4);
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(4)
);
}
}

0 comments on commit b382cfa

Please sign in to comment.