Skip to content

Commit 6b21f17

Browse files
committed
feat: Add an error code on the metering middleware
1 parent 44fb48e commit 6b21f17

File tree

3 files changed

+129
-28
lines changed

3 files changed

+129
-28
lines changed

examples/metering.rs

+22-12
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@ use wasmer::CompilerConfig;
2121
use wasmer::{imports, wat2wasm, Instance, Module, Store};
2222
use wasmer_compiler_cranelift::Cranelift;
2323
use wasmer_engine_jit::JIT;
24-
use wasmer_middlewares::metering::{get_remaining_points, set_remaining_points, Metering};
24+
use wasmer_middlewares::{
25+
metering::{get_error, get_remaining_points, set_remaining_points},
26+
Metering, MeteringError,
27+
};
2528

2629
fn main() -> anyhow::Result<()> {
2730
// Let's declare the Wasm module.
@@ -132,19 +135,26 @@ fn main() -> anyhow::Result<()> {
132135
);
133136
}
134137
Err(_) => {
135-
println!("Calling `add_one` failed: not enough gas points remaining.");
136-
}
137-
}
138+
println!("Calling `add_one` failed.");
138139

139-
// Becasue the previous call failed, it did not consume any gas point.
140-
// We still have 2 remaining points.
141-
let remaining_points_after_third_call = get_remaining_points(&instance);
142-
assert_eq!(remaining_points_after_third_call, 2);
140+
// Because the last needed more than the remaining gas points, we got an error.
141+
let error = get_error(&instance);
143142

144-
println!(
145-
"Remaining points after third call: {:?}",
146-
remaining_points_after_third_call
147-
);
143+
match error {
144+
MeteringError::NoError => bail!("No error."),
145+
MeteringError::OutOfGas => println!("Not enough gas points remaining."),
146+
}
147+
148+
// There is now 0 remaining points (we can't go below 0).
149+
let remaining_points_after_third_call = get_remaining_points(&instance);
150+
assert_eq!(remaining_points_after_third_call, 0);
151+
152+
println!(
153+
"Remaining points after third call: {:?}",
154+
remaining_points_after_third_call
155+
);
156+
}
157+
}
148158

149159
// Now let's see how we can set a new limit...
150160
println!("Set new remaining points points to 10");

lib/middlewares/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ pub mod metering;
33
// The most commonly used symbol are exported at top level of the module. Others are available
44
// via modules, e.g. `wasmer_middlewares::metering::get_remaining_points`
55
pub use metering::Metering;
6+
pub use metering::{get_error, get_remaining_points, set_remaining_points, MeteringError};

lib/middlewares/src/metering.rs

+106-16
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! `metering` is a middleware for tracking how many operators are executed in total
22
//! and putting a limit on the total number of operators executed.
33
4+
use core::convert::TryFrom;
45
use std::convert::TryInto;
56
use std::fmt;
67
use std::sync::Mutex;
@@ -11,7 +12,7 @@ use wasmer::{
1112
ExportIndex, FunctionMiddleware, GlobalInit, GlobalType, Instance, LocalFunctionIndex,
1213
MiddlewareReaderState, ModuleMiddleware, Mutability, Type,
1314
};
14-
use wasmer_types::GlobalIndex;
15+
use wasmer_types::{GlobalIndex, Value};
1516
use wasmer_vm::ModuleInfo;
1617

1718
/// The module-level metering middleware.
@@ -28,6 +29,9 @@ pub struct Metering<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> {
2829
/// Function that maps each operator to a cost in "points".
2930
cost_function: F,
3031

32+
/// The global index in the current module for the error code.
33+
error_code_index: Mutex<Option<GlobalIndex>>,
34+
3135
/// The global index in the current module for remaining points.
3236
remaining_points_index: Mutex<Option<GlobalIndex>>,
3337
}
@@ -37,19 +41,52 @@ pub struct FunctionMetering<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync
3741
/// Function that maps each operator to a cost in "points".
3842
cost_function: F,
3943

44+
/// The global index in the current module for the error code.
45+
error_code_index: GlobalIndex,
46+
4047
/// The global index in the current module for remaining points.
4148
remaining_points_index: GlobalIndex,
4249

4350
/// Accumulated cost of the current basic block.
4451
accumulated_cost: u64,
4552
}
4653

54+
#[derive(Debug, PartialEq)]
55+
pub enum MeteringError {
56+
NoError,
57+
OutOfGas,
58+
}
59+
60+
impl TryFrom<i32> for MeteringError {
61+
type Error = ();
62+
63+
fn try_from(value: i32) -> Result<Self, Self::Error> {
64+
match value {
65+
value if value == MeteringError::NoError as _ => Ok(MeteringError::NoError),
66+
value if value == MeteringError::OutOfGas as _ => Ok(MeteringError::OutOfGas),
67+
_ => Err(()),
68+
}
69+
}
70+
}
71+
72+
impl<T> TryFrom<Value<T>> for MeteringError {
73+
type Error = ();
74+
75+
fn try_from(v: Value<T>) -> Result<Self, Self::Error> {
76+
match v {
77+
Value::I32(value) => value.try_into(),
78+
_ => Err(()),
79+
}
80+
}
81+
}
82+
4783
impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> Metering<F> {
4884
/// Creates a `Metering` middleware.
4985
pub fn new(initial_limit: u64, cost_function: F) -> Self {
5086
Self {
5187
initial_limit,
5288
cost_function,
89+
error_code_index: Mutex::new(None),
5390
remaining_points_index: Mutex::new(None),
5491
}
5592
}
@@ -60,6 +97,7 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> fmt::Debug for Meteri
6097
f.debug_struct("Metering")
6198
.field("initial_limit", &self.initial_limit)
6299
.field("cost_function", &"<function>")
100+
.field("error_code_index", &self.error_code_index)
63101
.field("remaining_points_index", &self.remaining_points_index)
64102
.finish()
65103
}
@@ -72,6 +110,11 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync + 'static> ModuleMiddl
72110
fn generate_function_middleware(&self, _: LocalFunctionIndex) -> Box<dyn FunctionMiddleware> {
73111
Box::new(FunctionMetering {
74112
cost_function: self.cost_function,
113+
error_code_index: self
114+
.error_code_index
115+
.lock()
116+
.unwrap()
117+
.expect("Metering::generate_function_middleware: Error code index not set up."),
75118
remaining_points_index: self.remaining_points_index.lock().unwrap().expect(
76119
"Metering::generate_function_middleware: Remaining points index not set up.",
77120
),
@@ -81,23 +124,37 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync + 'static> ModuleMiddl
81124

82125
/// Transforms a `ModuleInfo` struct in-place. This is called before application on functions begins.
83126
fn transform_module_info(&self, module_info: &mut ModuleInfo) {
127+
let mut error_code_index = self.error_code_index.lock().unwrap();
84128
let mut remaining_points_index = self.remaining_points_index.lock().unwrap();
85-
if remaining_points_index.is_some() {
129+
130+
if error_code_index.is_some() || remaining_points_index.is_some(){
86131
panic!("Metering::transform_module_info: Attempting to use a `Metering` middleware from multiple modules.");
87132
}
88133

134+
// Append a global for the error code and initialize it.
135+
let error_code_global_index = module_info
136+
.globals
137+
.push(GlobalType::new(Type::I32, Mutability::Var));
138+
*error_code_index = Some(error_code_global_index.clone());
139+
89140
// Append a global for remaining points and initialize it.
90-
let global_index = module_info
141+
let remaining_points_global_index = module_info
91142
.globals
92143
.push(GlobalType::new(Type::I64, Mutability::Var));
93-
*remaining_points_index = Some(global_index.clone());
144+
*remaining_points_index = Some(remaining_points_global_index.clone());
145+
94146
module_info
95147
.global_initializers
96148
.push(GlobalInit::I64Const(self.initial_limit as i64));
97149

98150
module_info.exports.insert(
99-
"remaining_points".to_string(),
100-
ExportIndex::Global(global_index),
151+
"wasmer_metering_error_code".to_string(),
152+
ExportIndex::Global(error_code_global_index),
153+
);
154+
155+
module_info.exports.insert(
156+
"wasmer_metering_remaining_points".to_string(),
157+
ExportIndex::Global(remaining_points_global_index),
101158
);
102159
}
103160
}
@@ -106,6 +163,7 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> fmt::Debug for Functi
106163
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
107164
f.debug_struct("FunctionMetering")
108165
.field("cost_function", &"<function>")
166+
.field("error_code_index", &self.error_code_index)
109167
.field("remaining_points_index", &self.remaining_points_index)
110168
.finish()
111169
}
@@ -143,7 +201,11 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> FunctionMiddleware
143201
Operator::I64Const { value: self.accumulated_cost as i64 },
144202
Operator::I64LtU,
145203
Operator::If { ty: WpTypeOrFuncType::Type(WpType::EmptyBlockType) },
146-
Operator::Unreachable, // FIXME: Signal the error properly.
204+
Operator::I32Const { value: MeteringError::OutOfGas as i32 },
205+
Operator::GlobalSet { global_index: self.error_code_index.as_u32() },
206+
Operator::I64Const { value: 0 },
207+
Operator::GlobalSet { global_index: self.remaining_points_index.as_u32() },
208+
Operator::Unreachable,
147209
Operator::End,
148210

149211
// globals[remaining_points_index] -= self.accumulated_cost;
@@ -164,6 +226,28 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> FunctionMiddleware
164226
}
165227
}
166228

229+
/// Get the error code in an `Instance`.
230+
///
231+
/// When instance execution traps, this error code will help know if it was caused by
232+
/// remaining points being exhausted.
233+
///
234+
/// This can be used in a headless engine after an ahead-of-time compilation
235+
/// as all required state lives in the instance.
236+
///
237+
/// # Panic
238+
///
239+
/// The instance Module must have been processed with the [`Metering`] middleware
240+
/// at compile time, otherwise this will panic.
241+
pub fn get_error(instance: &Instance) -> MeteringError {
242+
instance
243+
.exports
244+
.get_global("wasmer_metering_error_code")
245+
.expect("Can't get `wasmer_metering_error_code` from Instance")
246+
.get()
247+
.try_into()
248+
.expect("`wasmer_metering_error_code` from Instance has wrong type")
249+
}
250+
167251
/// Get the remaining points in an `Instance`.
168252
///
169253
/// This can be used in a headless engine after an ahead-of-time compilation
@@ -176,11 +260,11 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> FunctionMiddleware
176260
pub fn get_remaining_points(instance: &Instance) -> u64 {
177261
instance
178262
.exports
179-
.get_global("remaining_points")
180-
.expect("Can't get `remaining_points` from Instance")
263+
.get_global("wasmer_metering_remaining_points")
264+
.expect("Can't get `wasmer_metering_remaining_points` from Instance")
181265
.get()
182266
.try_into()
183-
.expect("`remaining_points` from Instance has wrong type")
267+
.expect("`wasmer_metering_remaining_points` from Instance has wrong type")
184268
}
185269

186270
/// Set the provided remaining points in an `Instance`.
@@ -195,10 +279,10 @@ pub fn get_remaining_points(instance: &Instance) -> u64 {
195279
pub fn set_remaining_points(instance: &Instance, points: u64) {
196280
instance
197281
.exports
198-
.get_global("remaining_points")
199-
.expect("Can't get `remaining_points` from Instance")
282+
.get_global("wasmer_metering_remaining_points")
283+
.expect("Can't get `wasmer_metering_remaining_points` from Instance")
200284
.set(points.into())
201-
.expect("Can't set `remaining_points` in Instance");
285+
.expect("Can't set `wasmer_metering_remaining_points` in Instance");
202286
}
203287

204288
#[cfg(test)]
@@ -258,16 +342,17 @@ mod tests {
258342
.unwrap();
259343
add_one.call(1).unwrap();
260344
assert_eq!(get_remaining_points(&instance), 6);
345+
assert_eq!(get_error(&instance), MeteringError::NoError);
261346

262347
// Second call
263348
add_one.call(1).unwrap();
264349
assert_eq!(get_remaining_points(&instance), 2);
350+
assert_eq!(get_error(&instance), MeteringError::NoError);
265351

266352
// Third call fails due to limit
267353
assert!(add_one.call(1).is_err());
268-
// TODO: what do we expect now? 0 or 2? See https://github.com/wasmerio/wasmer/issues/1931
269-
// assert_eq!(metering.get_remaining_points(&instance), 2);
270-
// assert_eq!(metering.get_remaining_points(&instance), 0);
354+
assert_eq!(get_remaining_points(&instance), 0);
355+
assert_eq!(get_error(&instance), MeteringError::OutOfGas);
271356
}
272357

273358
#[test]
@@ -294,9 +379,14 @@ mod tests {
294379
// Ensure we can use the new points now
295380
add_one.call(1).unwrap();
296381
assert_eq!(get_remaining_points(&instance), 8);
382+
assert_eq!(get_error(&instance), MeteringError::NoError);
383+
297384
add_one.call(1).unwrap();
298385
assert_eq!(get_remaining_points(&instance), 4);
386+
assert_eq!(get_error(&instance), MeteringError::NoError);
387+
299388
add_one.call(1).unwrap();
300389
assert_eq!(get_remaining_points(&instance), 0);
390+
assert_eq!(get_error(&instance), MeteringError::NoError);
301391
}
302392
}

0 commit comments

Comments
 (0)