Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 89 additions & 54 deletions compiler/noirc_evaluator/src/ssa/opt/check_u128_mul_overflow.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
//! An SSA pass that operates on ACIR functions that checks that multiplying two u128 doesn't
//! overflow because both operands are greater or equal than 2^64.
//! If both are, then the result is surely greater or equal than 2^128 so it would overflow.
//! The operands can still overflow if just one of them is less than 2^64, but in that case
//! the result will be less than 2^192 so it fits in a Field value, and acir will check that
//! it fits in a u128.
//!
//! In Brillig an overflow check is automatically performed on unsigned binary operations
//! so this SSA pass has no effect for Brillig functions.
use acvm::{AcirField, FieldElement};
use noirc_errors::call_stack::CallStackId;

use crate::ssa::{
ir::{
basic_block::BasicBlockId,
function::Function,
instruction::{Binary, BinaryOp, ConstrainError, Instruction},
types::NumericType,
Expand All @@ -15,11 +22,7 @@ use crate::ssa::{
use super::simple_optimization::SimpleOptimizationContext;

impl Ssa {
/// An SSA pass that checks that multiplying two u128 doesn't overflow because
/// both operands are greater or equal than 2^64.
/// If both are, then the result is surely greater or equal than 2^128 so it would overflow.
/// The operands can still overflow if just one of them is less than 2^64, but in that case the result
/// will be less than 2^192 so it fits in a Field value, and acir will check that it fits in a u128.
/// See [`check_u128_mul_overflow`][self] module for more information.
#[tracing::instrument(level = "trace", skip(self))]
pub(crate) fn check_u128_mul_overflow(mut self) -> Ssa {
for function in self.functions.values_mut() {
Expand All @@ -30,22 +33,19 @@ impl Ssa {
}

impl Function {
pub(crate) fn check_u128_mul_overflow(&mut self) {
fn check_u128_mul_overflow(&mut self) {
if !self.runtime().is_acir() {
return;
}

self.simple_optimization(|context| {
context.insert_current_instruction();

let block_id = context.block_id;
let instruction_id = context.instruction_id;
let instruction = context.instruction();
let Instruction::Binary(Binary {
lhs,
rhs,
operator: BinaryOp::Mul { unchecked: false },
}) = instruction
}) = context.instruction()
else {
return;
};
Expand All @@ -55,18 +55,15 @@ impl Function {
return;
};

let call_stack = context.dfg.get_instruction_call_stack_id(instruction_id);
check_u128_mul_overflow(*lhs, *rhs, block_id, context, call_stack);
check_u128_mul_overflow(*lhs, *rhs, context);
});
}
}

fn check_u128_mul_overflow(
lhs: ValueId,
rhs: ValueId,
block: BasicBlockId,
context: &mut SimpleOptimizationContext<'_, '_>,
call_stack: CallStackId,
) {
let dfg = &mut context.dfg;
let lhs_value = dfg.get_numeric_constant(lhs);
Expand All @@ -81,49 +78,59 @@ fn check_u128_mul_overflow(
return;
}

let block = context.block_id;
let call_stack = dfg.get_instruction_call_stack_id(context.instruction_id);

let u128 = NumericType::unsigned(128);
let two_pow_64 = 1_u128 << 64;
let two_pow_64 = dfg.make_constant(two_pow_64.into(), u128);
let mul = BinaryOp::Mul { unchecked: true };

let res = if lhs_value.is_some() && rhs_value.is_some() {
// If both values are known at compile time, at this point we know it overflows
dfg.make_constant(FieldElement::one(), u128)
} else if lhs_value.is_some() {
// If only the left-hand side is known we just need to check that the right-hand side
// isn't greater than 2^64
let instruction =
Instruction::Binary(Binary { lhs: rhs, rhs: two_pow_64, operator: BinaryOp::Div });
dfg.insert_instruction_and_results(instruction, block, None, call_stack).first()
} else if rhs_value.is_some() {
// Same goes for the other side
let instruction =
Instruction::Binary(Binary { lhs, rhs: two_pow_64, operator: BinaryOp::Div });
dfg.insert_instruction_and_results(instruction, block, None, call_stack).first()
} else {
// Check both sides
let instruction =
Instruction::Binary(Binary { lhs, rhs: two_pow_64, operator: BinaryOp::Div });
let divided_lhs =
dfg.insert_instruction_and_results(instruction, block, None, call_stack).first();
// To check if a value is less than 2^64 we divide it by 2^64 and expect the result to be zero.
let res = match (lhs_value, rhs_value) {
(Some(_), Some(_)) => {
// If both values are known at compile time, at this point we know it overflows
dfg.make_constant(FieldElement::one(), u128)
}
(Some(_), None) => {
// If only the left-hand side is known we just need to check that the right-hand side
// isn't greater than 2^64
let instruction =
Instruction::Binary(Binary { lhs: rhs, rhs: two_pow_64, operator: BinaryOp::Div });
dfg.insert_instruction_and_results(instruction, block, None, call_stack).first()
}
(None, Some(_)) => {
// Same goes for the other side
let instruction =
Instruction::Binary(Binary { lhs, rhs: two_pow_64, operator: BinaryOp::Div });
dfg.insert_instruction_and_results(instruction, block, None, call_stack).first()
}
(None, None) => {
// Check both sides
let instruction =
Instruction::Binary(Binary { lhs, rhs: two_pow_64, operator: BinaryOp::Div });
let divided_lhs =
dfg.insert_instruction_and_results(instruction, block, None, call_stack).first();

let instruction =
Instruction::Binary(Binary { lhs: rhs, rhs: two_pow_64, operator: BinaryOp::Div });
let divided_rhs =
dfg.insert_instruction_and_results(instruction, block, None, call_stack).first();
let instruction =
Instruction::Binary(Binary { lhs: rhs, rhs: two_pow_64, operator: BinaryOp::Div });
let divided_rhs =
dfg.insert_instruction_and_results(instruction, block, None, call_stack).first();

// Unchecked as operands are restricted to be less than 2^64 so multiplying them cannot overflow.
let instruction =
Instruction::Binary(Binary { lhs: divided_lhs, rhs: divided_rhs, operator: mul });
dfg.insert_instruction_and_results(instruction, block, None, call_stack).first()
// Unchecked as operands are restricted to be less than 2^64 so multiplying them cannot overflow.
let instruction =
Instruction::Binary(Binary { lhs: divided_lhs, rhs: divided_rhs, operator: mul });
dfg.insert_instruction_and_results(instruction, block, None, call_stack).first()
}
};

// We must only check for overflow if the side effects var is active
let predicate = Instruction::Cast(context.enable_side_effects, u128);
let predicate = dfg.insert_instruction_and_results(predicate, block, None, call_stack).first();
let res = Instruction::Binary(Binary { lhs: res, rhs: predicate, operator: mul });
let res = dfg.insert_instruction_and_results(res, block, None, call_stack).first();

let zero = dfg.make_constant(FieldElement::zero(), u128);
let instruction = Instruction::Cast(context.enable_side_effects, u128);
let predicate =
dfg.insert_instruction_and_results(instruction, block, None, call_stack).first();
let instruction = Instruction::Binary(Binary { lhs: res, rhs: predicate, operator: mul });
let res = dfg.insert_instruction_and_results(instruction, block, None, call_stack).first();
let instruction = Instruction::Constrain(
res,
zero,
Expand Down Expand Up @@ -278,16 +285,44 @@ mod tests {
}

#[test]
fn predicate_overflow() {
fn predicate_overflow_on_lhs_potentially_overflowing() {
// This code performs a u128 multiplication that overflows, under a condition.
let src = "
acir(inline) fn main f0 {
b0(v0: u128, v1: u1):
enable_side_effects v1
v2 = mul v0, u128 85070591730234615865843651857942052864
return v2
}
";
let ssa = Ssa::from_str(src).unwrap();
let ssa = ssa.flatten_cfg().check_u128_mul_overflow();
// Below, the overflow check takes the 'enable_side_effects' value into account
assert_ssa_snapshot!(ssa, @r#"
acir(inline) fn main f0 {
b0(v0: u128, v1: u1):
enable_side_effects v1
v3 = mul v0, u128 85070591730234615865843651857942052864
v5 = div v0, u128 18446744073709551616
v6 = cast v1 as u128
v7 = unchecked_mul v5, v6
constrain v7 == u128 0, "attempt to multiply with overflow"
return v3
}
"#);
}

#[test]
fn predicate_overflow_on_guaranteed_overflow() {
// This code performs a u128 multiplication that overflows, under a condition.
let src = "
acir(inline) fn main f0 {
b0(v0: u1):
b0(v0: u1):
jmpif v0 then: b1, else: b2
b1():
v2 = mul u128 340282366920938463463374607431768211455, u128 340282366920938463463374607431768211455 // src/main.nr:17:13
b1():
v2 = mul u128 340282366920938463463374607431768211455, u128 340282366920938463463374607431768211455
jmp b2()
b2():
b2():
return v0
}
";
Expand Down
Loading