Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix error in metering middleware #5108

Merged
merged 1 commit into from
Sep 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 74 additions & 35 deletions lib/middlewares/src/metering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ pub enum MeteringPoints {

impl<F: Fn(&Operator) -> u64 + Send + Sync> Metering<F> {
/// Creates a `Metering` middleware.
///
/// When providing a cost function, you should consider that branching operations do
/// additional work to track the metering points and probably need to have a higher cost.
/// To find out which operations are affected by this, you can call [`is_accounting`].
pub fn new(initial_limit: u64, cost_function: F) -> Self {
Self {
initial_limit,
Expand Down Expand Up @@ -198,6 +202,42 @@ impl<F: Fn(&Operator) -> u64 + Send + Sync + 'static> ModuleMiddleware for Meter
}
}

/// Returns `true` if and only if the given operator is an accounting operator.
/// Accounting operators do additional work to track the metering points.
pub fn is_accounting(operator: &Operator) -> bool {
// Possible sources and targets of a branch.
matches!(
operator,
Operator::Loop { .. } // loop headers are branch targets
| Operator::End // block ends are branch targets
| Operator::If { .. } // branch source, "if" can branch to else branch
| Operator::Else // "else" is the "end" of an if branch
| Operator::Br { .. } // branch source
| Operator::BrTable { .. } // branch source
| Operator::BrIf { .. } // branch source
| Operator::Call { .. } // function call - branch source
| Operator::CallIndirect { .. } // function call - branch source
| Operator::Return // end of function - branch source
// exceptions proposal
| Operator::Throw { .. } // branch source
| Operator::ThrowRef // branch source
| Operator::Rethrow { .. } // branch source
| Operator::Delegate { .. } // branch source
| Operator::Catch { .. } // branch target
// tail_call proposal
| Operator::ReturnCall { .. } // branch source
| Operator::ReturnCallIndirect { .. } // branch source
// gc proposal
| Operator::BrOnCast { .. } // branch source
| Operator::BrOnCastFail { .. } // branch source
// function_references proposal
| Operator::CallRef { .. } // branch source
| Operator::ReturnCallRef { .. } // branch source
| Operator::BrOnNull { .. } // branch source
| Operator::BrOnNonNull { .. } // branch source
)
}

impl<F: Fn(&Operator) -> u64 + Send + Sync> fmt::Debug for FunctionMetering<F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FunctionMetering")
Expand All @@ -218,41 +258,40 @@ impl<F: Fn(&Operator) -> u64 + Send + Sync> FunctionMiddleware for FunctionMeter
// corner cases.
self.accumulated_cost += (self.cost_function)(&operator);

// Possible sources and targets of a branch. Finalize the cost of the previous basic block and perform necessary checks.
match operator {
Operator::Loop { .. } // loop headers are branch targets
| Operator::End // block ends are branch targets
| Operator::Else // "else" is the "end" of an if branch
| Operator::Br { .. } // branch source
| Operator::BrTable { .. } // branch source
| Operator::BrIf { .. } // branch source
| Operator::Call { .. } // function call - branch source
| Operator::CallIndirect { .. } // function call - branch source
| Operator::Return // end of function - branch source
=> {
if self.accumulated_cost > 0 {
state.extend(&[
// if unsigned(globals[remaining_points_index]) < unsigned(self.accumulated_cost) { throw(); }
Operator::GlobalGet { global_index: self.global_indexes.remaining_points().as_u32() },
Operator::I64Const { value: self.accumulated_cost as i64 },
Operator::I64LtU,
Operator::If { blockty: WpTypeOrFuncType::Empty },
Operator::I32Const { value: 1 },
Operator::GlobalSet { global_index: self.global_indexes.points_exhausted().as_u32() },
Operator::Unreachable,
Operator::End,

// globals[remaining_points_index] -= self.accumulated_cost;
Operator::GlobalGet { global_index: self.global_indexes.remaining_points().as_u32() },
Operator::I64Const { value: self.accumulated_cost as i64 },
Operator::I64Sub,
Operator::GlobalSet { global_index: self.global_indexes.remaining_points().as_u32() },
]);

self.accumulated_cost = 0;
}
}
_ => {}
// Finalize the cost of the previous basic block and perform necessary checks.
if is_accounting(&operator) && self.accumulated_cost > 0 {
state.extend(&[
// if unsigned(globals[remaining_points_index]) < unsigned(self.accumulated_cost) { throw(); }
Operator::GlobalGet {
global_index: self.global_indexes.remaining_points().as_u32(),
},
Operator::I64Const {
value: self.accumulated_cost as i64,
},
Operator::I64LtU,
Operator::If {
blockty: WpTypeOrFuncType::Empty,
},
Operator::I32Const { value: 1 },
Operator::GlobalSet {
global_index: self.global_indexes.points_exhausted().as_u32(),
},
Operator::Unreachable,
Operator::End,
// globals[remaining_points_index] -= self.accumulated_cost;
Operator::GlobalGet {
global_index: self.global_indexes.remaining_points().as_u32(),
},
Operator::I64Const {
value: self.accumulated_cost as i64,
},
Operator::I64Sub,
Operator::GlobalSet {
global_index: self.global_indexes.remaining_points().as_u32(),
},
]);

self.accumulated_cost = 0;
}
state.push_operator(operator);

Expand Down
Loading