diff --git a/lib/middlewares/src/metering.rs b/lib/middlewares/src/metering.rs index 18bc7997146..1207269a2da 100644 --- a/lib/middlewares/src/metering.rs +++ b/lib/middlewares/src/metering.rs @@ -124,6 +124,10 @@ pub enum MeteringPoints { impl u64 + Send + Sync> Metering { /// Creates a `Metering` middleware. + /// + /// When providing a cost function, you should consider that branching operations do + /// additional work to track the metering points and probably need to have a higher cost. + /// To find out which operations are affected by this, you can call [`is_accounting`]. pub fn new(initial_limit: u64, cost_function: F) -> Self { Self { initial_limit, @@ -198,6 +202,42 @@ impl u64 + Send + Sync + 'static> ModuleMiddleware for Meter } } +/// Returns `true` if and only if the given operator is an accounting operator. +/// Accounting operators do additional work to track the metering points. +pub fn is_accounting(operator: &Operator) -> bool { + // Possible sources and targets of a branch. + matches!( + operator, + Operator::Loop { .. } // loop headers are branch targets + | Operator::End // block ends are branch targets + | Operator::If { .. } // branch source, "if" can branch to else branch + | Operator::Else // "else" is the "end" of an if branch + | Operator::Br { .. } // branch source + | Operator::BrTable { .. } // branch source + | Operator::BrIf { .. } // branch source + | Operator::Call { .. } // function call - branch source + | Operator::CallIndirect { .. } // function call - branch source + | Operator::Return // end of function - branch source + // exceptions proposal + | Operator::Throw { .. } // branch source + | Operator::ThrowRef // branch source + | Operator::Rethrow { .. } // branch source + | Operator::Delegate { .. } // branch source + | Operator::Catch { .. } // branch target + // tail_call proposal + | Operator::ReturnCall { .. } // branch source + | Operator::ReturnCallIndirect { .. } // branch source + // gc proposal + | Operator::BrOnCast { .. } // branch source + | Operator::BrOnCastFail { .. } // branch source + // function_references proposal + | Operator::CallRef { .. } // branch source + | Operator::ReturnCallRef { .. } // branch source + | Operator::BrOnNull { .. } // branch source + | Operator::BrOnNonNull { .. } // branch source + ) +} + impl u64 + Send + Sync> fmt::Debug for FunctionMetering { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("FunctionMetering") @@ -218,41 +258,40 @@ impl u64 + Send + Sync> FunctionMiddleware for FunctionMeter // corner cases. self.accumulated_cost += (self.cost_function)(&operator); - // Possible sources and targets of a branch. Finalize the cost of the previous basic block and perform necessary checks. - match operator { - Operator::Loop { .. } // loop headers are branch targets - | Operator::End // block ends are branch targets - | Operator::Else // "else" is the "end" of an if branch - | Operator::Br { .. } // branch source - | Operator::BrTable { .. } // branch source - | Operator::BrIf { .. } // branch source - | Operator::Call { .. } // function call - branch source - | Operator::CallIndirect { .. } // function call - branch source - | Operator::Return // end of function - branch source - => { - if self.accumulated_cost > 0 { - state.extend(&[ - // if unsigned(globals[remaining_points_index]) < unsigned(self.accumulated_cost) { throw(); } - Operator::GlobalGet { global_index: self.global_indexes.remaining_points().as_u32() }, - Operator::I64Const { value: self.accumulated_cost as i64 }, - Operator::I64LtU, - Operator::If { blockty: WpTypeOrFuncType::Empty }, - Operator::I32Const { value: 1 }, - Operator::GlobalSet { global_index: self.global_indexes.points_exhausted().as_u32() }, - Operator::Unreachable, - Operator::End, - - // globals[remaining_points_index] -= self.accumulated_cost; - Operator::GlobalGet { global_index: self.global_indexes.remaining_points().as_u32() }, - Operator::I64Const { value: self.accumulated_cost as i64 }, - Operator::I64Sub, - Operator::GlobalSet { global_index: self.global_indexes.remaining_points().as_u32() }, - ]); - - self.accumulated_cost = 0; - } - } - _ => {} + // Finalize the cost of the previous basic block and perform necessary checks. + if is_accounting(&operator) && self.accumulated_cost > 0 { + state.extend(&[ + // if unsigned(globals[remaining_points_index]) < unsigned(self.accumulated_cost) { throw(); } + Operator::GlobalGet { + global_index: self.global_indexes.remaining_points().as_u32(), + }, + Operator::I64Const { + value: self.accumulated_cost as i64, + }, + Operator::I64LtU, + Operator::If { + blockty: WpTypeOrFuncType::Empty, + }, + Operator::I32Const { value: 1 }, + Operator::GlobalSet { + global_index: self.global_indexes.points_exhausted().as_u32(), + }, + Operator::Unreachable, + Operator::End, + // globals[remaining_points_index] -= self.accumulated_cost; + Operator::GlobalGet { + global_index: self.global_indexes.remaining_points().as_u32(), + }, + Operator::I64Const { + value: self.accumulated_cost as i64, + }, + Operator::I64Sub, + Operator::GlobalSet { + global_index: self.global_indexes.remaining_points().as_u32(), + }, + ]); + + self.accumulated_cost = 0; } state.push_operator(operator);