Skip to content

Commit 3eba516

Browse files
committed
fix(middlewares): Fix error in metering middleware
1 parent d11fccd commit 3eba516

File tree

1 file changed

+74
-35
lines changed

1 file changed

+74
-35
lines changed

lib/middlewares/src/metering.rs

+74-35
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,10 @@ pub enum MeteringPoints {
124124

125125
impl<F: Fn(&Operator) -> u64 + Send + Sync> Metering<F> {
126126
/// Creates a `Metering` middleware.
127+
///
128+
/// When providing a cost function, you should consider that branching operations do
129+
/// additional work to track the metering points and probably need to have a higher cost.
130+
/// To find out which operations are affected by this, you can call [`is_accounting`].
127131
pub fn new(initial_limit: u64, cost_function: F) -> Self {
128132
Self {
129133
initial_limit,
@@ -198,6 +202,42 @@ impl<F: Fn(&Operator) -> u64 + Send + Sync + 'static> ModuleMiddleware for Meter
198202
}
199203
}
200204

205+
/// Returns `true` if and only if the given operator is an accounting operator.
206+
/// Accounting operators do additional work to track the metering points.
207+
pub fn is_accounting(operator: &Operator) -> bool {
208+
// Possible sources and targets of a branch.
209+
matches!(
210+
operator,
211+
Operator::Loop { .. } // loop headers are branch targets
212+
| Operator::End // block ends are branch targets
213+
| Operator::If { .. } // branch source, "if" can branch to else branch
214+
| Operator::Else // "else" is the "end" of an if branch
215+
| Operator::Br { .. } // branch source
216+
| Operator::BrTable { .. } // branch source
217+
| Operator::BrIf { .. } // branch source
218+
| Operator::Call { .. } // function call - branch source
219+
| Operator::CallIndirect { .. } // function call - branch source
220+
| Operator::Return // end of function - branch source
221+
// exceptions proposal
222+
| Operator::Throw { .. } // branch source
223+
| Operator::ThrowRef // branch source
224+
| Operator::Rethrow { .. } // branch source
225+
| Operator::Delegate { .. } // branch source
226+
| Operator::Catch { .. } // branch target
227+
// tail_call proposal
228+
| Operator::ReturnCall { .. } // branch source
229+
| Operator::ReturnCallIndirect { .. } // branch source
230+
// gc proposal
231+
| Operator::BrOnCast { .. } // branch source
232+
| Operator::BrOnCastFail { .. } // branch source
233+
// function_references proposal
234+
| Operator::CallRef { .. } // branch source
235+
| Operator::ReturnCallRef { .. } // branch source
236+
| Operator::BrOnNull { .. } // branch source
237+
| Operator::BrOnNonNull { .. } // branch source
238+
)
239+
}
240+
201241
impl<F: Fn(&Operator) -> u64 + Send + Sync> fmt::Debug for FunctionMetering<F> {
202242
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203243
f.debug_struct("FunctionMetering")
@@ -218,41 +258,40 @@ impl<F: Fn(&Operator) -> u64 + Send + Sync> FunctionMiddleware for FunctionMeter
218258
// corner cases.
219259
self.accumulated_cost += (self.cost_function)(&operator);
220260

221-
// Possible sources and targets of a branch. Finalize the cost of the previous basic block and perform necessary checks.
222-
match operator {
223-
Operator::Loop { .. } // loop headers are branch targets
224-
| Operator::End // block ends are branch targets
225-
| Operator::Else // "else" is the "end" of an if branch
226-
| Operator::Br { .. } // branch source
227-
| Operator::BrTable { .. } // branch source
228-
| Operator::BrIf { .. } // branch source
229-
| Operator::Call { .. } // function call - branch source
230-
| Operator::CallIndirect { .. } // function call - branch source
231-
| Operator::Return // end of function - branch source
232-
=> {
233-
if self.accumulated_cost > 0 {
234-
state.extend(&[
235-
// if unsigned(globals[remaining_points_index]) < unsigned(self.accumulated_cost) { throw(); }
236-
Operator::GlobalGet { global_index: self.global_indexes.remaining_points().as_u32() },
237-
Operator::I64Const { value: self.accumulated_cost as i64 },
238-
Operator::I64LtU,
239-
Operator::If { blockty: WpTypeOrFuncType::Empty },
240-
Operator::I32Const { value: 1 },
241-
Operator::GlobalSet { global_index: self.global_indexes.points_exhausted().as_u32() },
242-
Operator::Unreachable,
243-
Operator::End,
244-
245-
// globals[remaining_points_index] -= self.accumulated_cost;
246-
Operator::GlobalGet { global_index: self.global_indexes.remaining_points().as_u32() },
247-
Operator::I64Const { value: self.accumulated_cost as i64 },
248-
Operator::I64Sub,
249-
Operator::GlobalSet { global_index: self.global_indexes.remaining_points().as_u32() },
250-
]);
251-
252-
self.accumulated_cost = 0;
253-
}
254-
}
255-
_ => {}
261+
// Finalize the cost of the previous basic block and perform necessary checks.
262+
if is_accounting(&operator) && self.accumulated_cost > 0 {
263+
state.extend(&[
264+
// if unsigned(globals[remaining_points_index]) < unsigned(self.accumulated_cost) { throw(); }
265+
Operator::GlobalGet {
266+
global_index: self.global_indexes.remaining_points().as_u32(),
267+
},
268+
Operator::I64Const {
269+
value: self.accumulated_cost as i64,
270+
},
271+
Operator::I64LtU,
272+
Operator::If {
273+
blockty: WpTypeOrFuncType::Empty,
274+
},
275+
Operator::I32Const { value: 1 },
276+
Operator::GlobalSet {
277+
global_index: self.global_indexes.points_exhausted().as_u32(),
278+
},
279+
Operator::Unreachable,
280+
Operator::End,
281+
// globals[remaining_points_index] -= self.accumulated_cost;
282+
Operator::GlobalGet {
283+
global_index: self.global_indexes.remaining_points().as_u32(),
284+
},
285+
Operator::I64Const {
286+
value: self.accumulated_cost as i64,
287+
},
288+
Operator::I64Sub,
289+
Operator::GlobalSet {
290+
global_index: self.global_indexes.remaining_points().as_u32(),
291+
},
292+
]);
293+
294+
self.accumulated_cost = 0;
256295
}
257296
state.push_operator(operator);
258297

0 commit comments

Comments
 (0)