Skip to content

Commit

Permalink
Merge pull request #5108 from wasmerio/fix-metering
Browse files Browse the repository at this point in the history
Fix error in metering middleware
  • Loading branch information
syrusakbary authored Sep 24, 2024
2 parents 065406e + 3eba516 commit 7909969
Showing 1 changed file with 74 additions and 35 deletions.
109 changes: 74 additions & 35 deletions lib/middlewares/src/metering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ pub enum MeteringPoints {

impl<F: Fn(&Operator) -> u64 + Send + Sync> Metering<F> {
/// Creates a `Metering` middleware.
///
/// When providing a cost function, you should consider that branching operations do
/// additional work to track the metering points and probably need to have a higher cost.
/// To find out which operations are affected by this, you can call [`is_accounting`].
pub fn new(initial_limit: u64, cost_function: F) -> Self {
Self {
initial_limit,
Expand Down Expand Up @@ -198,6 +202,42 @@ impl<F: Fn(&Operator) -> u64 + Send + Sync + 'static> ModuleMiddleware for Meter
}
}

/// Returns `true` if and only if the given operator is an accounting operator.
/// Accounting operators do additional work to track the metering points.
pub fn is_accounting(operator: &Operator) -> bool {
// Possible sources and targets of a branch.
matches!(
operator,
Operator::Loop { .. } // loop headers are branch targets
| Operator::End // block ends are branch targets
| Operator::If { .. } // branch source, "if" can branch to else branch
| Operator::Else // "else" is the "end" of an if branch
| Operator::Br { .. } // branch source
| Operator::BrTable { .. } // branch source
| Operator::BrIf { .. } // branch source
| Operator::Call { .. } // function call - branch source
| Operator::CallIndirect { .. } // function call - branch source
| Operator::Return // end of function - branch source
// exceptions proposal
| Operator::Throw { .. } // branch source
| Operator::ThrowRef // branch source
| Operator::Rethrow { .. } // branch source
| Operator::Delegate { .. } // branch source
| Operator::Catch { .. } // branch target
// tail_call proposal
| Operator::ReturnCall { .. } // branch source
| Operator::ReturnCallIndirect { .. } // branch source
// gc proposal
| Operator::BrOnCast { .. } // branch source
| Operator::BrOnCastFail { .. } // branch source
// function_references proposal
| Operator::CallRef { .. } // branch source
| Operator::ReturnCallRef { .. } // branch source
| Operator::BrOnNull { .. } // branch source
| Operator::BrOnNonNull { .. } // branch source
)
}

impl<F: Fn(&Operator) -> u64 + Send + Sync> fmt::Debug for FunctionMetering<F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FunctionMetering")
Expand All @@ -218,41 +258,40 @@ impl<F: Fn(&Operator) -> u64 + Send + Sync> FunctionMiddleware for FunctionMeter
// corner cases.
self.accumulated_cost += (self.cost_function)(&operator);

// Possible sources and targets of a branch. Finalize the cost of the previous basic block and perform necessary checks.
match operator {
Operator::Loop { .. } // loop headers are branch targets
| Operator::End // block ends are branch targets
| Operator::Else // "else" is the "end" of an if branch
| Operator::Br { .. } // branch source
| Operator::BrTable { .. } // branch source
| Operator::BrIf { .. } // branch source
| Operator::Call { .. } // function call - branch source
| Operator::CallIndirect { .. } // function call - branch source
| Operator::Return // end of function - branch source
=> {
if self.accumulated_cost > 0 {
state.extend(&[
// if unsigned(globals[remaining_points_index]) < unsigned(self.accumulated_cost) { throw(); }
Operator::GlobalGet { global_index: self.global_indexes.remaining_points().as_u32() },
Operator::I64Const { value: self.accumulated_cost as i64 },
Operator::I64LtU,
Operator::If { blockty: WpTypeOrFuncType::Empty },
Operator::I32Const { value: 1 },
Operator::GlobalSet { global_index: self.global_indexes.points_exhausted().as_u32() },
Operator::Unreachable,
Operator::End,

// globals[remaining_points_index] -= self.accumulated_cost;
Operator::GlobalGet { global_index: self.global_indexes.remaining_points().as_u32() },
Operator::I64Const { value: self.accumulated_cost as i64 },
Operator::I64Sub,
Operator::GlobalSet { global_index: self.global_indexes.remaining_points().as_u32() },
]);

self.accumulated_cost = 0;
}
}
_ => {}
// Finalize the cost of the previous basic block and perform necessary checks.
if is_accounting(&operator) && self.accumulated_cost > 0 {
state.extend(&[
// if unsigned(globals[remaining_points_index]) < unsigned(self.accumulated_cost) { throw(); }
Operator::GlobalGet {
global_index: self.global_indexes.remaining_points().as_u32(),
},
Operator::I64Const {
value: self.accumulated_cost as i64,
},
Operator::I64LtU,
Operator::If {
blockty: WpTypeOrFuncType::Empty,
},
Operator::I32Const { value: 1 },
Operator::GlobalSet {
global_index: self.global_indexes.points_exhausted().as_u32(),
},
Operator::Unreachable,
Operator::End,
// globals[remaining_points_index] -= self.accumulated_cost;
Operator::GlobalGet {
global_index: self.global_indexes.remaining_points().as_u32(),
},
Operator::I64Const {
value: self.accumulated_cost as i64,
},
Operator::I64Sub,
Operator::GlobalSet {
global_index: self.global_indexes.remaining_points().as_u32(),
},
]);

self.accumulated_cost = 0;
}
state.push_operator(operator);

Expand Down

0 comments on commit 7909969

Please sign in to comment.