diff --git a/compiler/noirc_evaluator/src/ssa/opt/check_u128_mul_overflow.rs b/compiler/noirc_evaluator/src/ssa/opt/check_u128_mul_overflow.rs index 2acc9e861c2..ed6cabeed32 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/check_u128_mul_overflow.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/check_u128_mul_overflow.rs @@ -1,9 +1,16 @@ +//! An SSA pass that operates on ACIR functions that checks that multiplying two u128 doesn't +//! overflow because both operands are greater or equal than 2^64. +//! If both are, then the result is surely greater or equal than 2^128 so it would overflow. +//! The operands can still overflow if just one of them is less than 2^64, but in that case +//! the result will be less than 2^192 so it fits in a Field value, and acir will check that +//! it fits in a u128. +//! +//! In Brillig an overflow check is automatically performed on unsigned binary operations +//! so this SSA pass has no effect for Brillig functions. use acvm::{AcirField, FieldElement}; -use noirc_errors::call_stack::CallStackId; use crate::ssa::{ ir::{ - basic_block::BasicBlockId, function::Function, instruction::{Binary, BinaryOp, ConstrainError, Instruction}, types::NumericType, @@ -15,11 +22,7 @@ use crate::ssa::{ use super::simple_optimization::SimpleOptimizationContext; impl Ssa { - /// An SSA pass that checks that multiplying two u128 doesn't overflow because - /// both operands are greater or equal than 2^64. - /// If both are, then the result is surely greater or equal than 2^128 so it would overflow. - /// The operands can still overflow if just one of them is less than 2^64, but in that case the result - /// will be less than 2^192 so it fits in a Field value, and acir will check that it fits in a u128. + /// See [`check_u128_mul_overflow`][self] module for more information. #[tracing::instrument(level = "trace", skip(self))] pub(crate) fn check_u128_mul_overflow(mut self) -> Ssa { for function in self.functions.values_mut() { @@ -30,7 +33,7 @@ impl Ssa { } impl Function { - pub(crate) fn check_u128_mul_overflow(&mut self) { + fn check_u128_mul_overflow(&mut self) { if !self.runtime().is_acir() { return; } @@ -38,14 +41,11 @@ impl Function { self.simple_optimization(|context| { context.insert_current_instruction(); - let block_id = context.block_id; - let instruction_id = context.instruction_id; - let instruction = context.instruction(); let Instruction::Binary(Binary { lhs, rhs, operator: BinaryOp::Mul { unchecked: false }, - }) = instruction + }) = context.instruction() else { return; }; @@ -55,8 +55,7 @@ impl Function { return; }; - let call_stack = context.dfg.get_instruction_call_stack_id(instruction_id); - check_u128_mul_overflow(*lhs, *rhs, block_id, context, call_stack); + check_u128_mul_overflow(*lhs, *rhs, context); }); } } @@ -64,9 +63,7 @@ impl Function { fn check_u128_mul_overflow( lhs: ValueId, rhs: ValueId, - block: BasicBlockId, context: &mut SimpleOptimizationContext<'_, '_>, - call_stack: CallStackId, ) { let dfg = &mut context.dfg; let lhs_value = dfg.get_numeric_constant(lhs); @@ -81,49 +78,59 @@ fn check_u128_mul_overflow( return; } + let block = context.block_id; + let call_stack = dfg.get_instruction_call_stack_id(context.instruction_id); + let u128 = NumericType::unsigned(128); let two_pow_64 = 1_u128 << 64; let two_pow_64 = dfg.make_constant(two_pow_64.into(), u128); let mul = BinaryOp::Mul { unchecked: true }; - let res = if lhs_value.is_some() && rhs_value.is_some() { - // If both values are known at compile time, at this point we know it overflows - dfg.make_constant(FieldElement::one(), u128) - } else if lhs_value.is_some() { - // If only the left-hand side is known we just need to check that the right-hand side - // isn't greater than 2^64 - let instruction = - Instruction::Binary(Binary { lhs: rhs, rhs: two_pow_64, operator: BinaryOp::Div }); - dfg.insert_instruction_and_results(instruction, block, None, call_stack).first() - } else if rhs_value.is_some() { - // Same goes for the other side - let instruction = - Instruction::Binary(Binary { lhs, rhs: two_pow_64, operator: BinaryOp::Div }); - dfg.insert_instruction_and_results(instruction, block, None, call_stack).first() - } else { - // Check both sides - let instruction = - Instruction::Binary(Binary { lhs, rhs: two_pow_64, operator: BinaryOp::Div }); - let divided_lhs = - dfg.insert_instruction_and_results(instruction, block, None, call_stack).first(); + // To check if a value is less than 2^64 we divide it by 2^64 and expect the result to be zero. + let res = match (lhs_value, rhs_value) { + (Some(_), Some(_)) => { + // If both values are known at compile time, at this point we know it overflows + dfg.make_constant(FieldElement::one(), u128) + } + (Some(_), None) => { + // If only the left-hand side is known we just need to check that the right-hand side + // isn't greater than 2^64 + let instruction = + Instruction::Binary(Binary { lhs: rhs, rhs: two_pow_64, operator: BinaryOp::Div }); + dfg.insert_instruction_and_results(instruction, block, None, call_stack).first() + } + (None, Some(_)) => { + // Same goes for the other side + let instruction = + Instruction::Binary(Binary { lhs, rhs: two_pow_64, operator: BinaryOp::Div }); + dfg.insert_instruction_and_results(instruction, block, None, call_stack).first() + } + (None, None) => { + // Check both sides + let instruction = + Instruction::Binary(Binary { lhs, rhs: two_pow_64, operator: BinaryOp::Div }); + let divided_lhs = + dfg.insert_instruction_and_results(instruction, block, None, call_stack).first(); - let instruction = - Instruction::Binary(Binary { lhs: rhs, rhs: two_pow_64, operator: BinaryOp::Div }); - let divided_rhs = - dfg.insert_instruction_and_results(instruction, block, None, call_stack).first(); + let instruction = + Instruction::Binary(Binary { lhs: rhs, rhs: two_pow_64, operator: BinaryOp::Div }); + let divided_rhs = + dfg.insert_instruction_and_results(instruction, block, None, call_stack).first(); - // Unchecked as operands are restricted to be less than 2^64 so multiplying them cannot overflow. - let instruction = - Instruction::Binary(Binary { lhs: divided_lhs, rhs: divided_rhs, operator: mul }); - dfg.insert_instruction_and_results(instruction, block, None, call_stack).first() + // Unchecked as operands are restricted to be less than 2^64 so multiplying them cannot overflow. + let instruction = + Instruction::Binary(Binary { lhs: divided_lhs, rhs: divided_rhs, operator: mul }); + dfg.insert_instruction_and_results(instruction, block, None, call_stack).first() + } }; + // We must only check for overflow if the side effects var is active + let predicate = Instruction::Cast(context.enable_side_effects, u128); + let predicate = dfg.insert_instruction_and_results(predicate, block, None, call_stack).first(); + let res = Instruction::Binary(Binary { lhs: res, rhs: predicate, operator: mul }); + let res = dfg.insert_instruction_and_results(res, block, None, call_stack).first(); + let zero = dfg.make_constant(FieldElement::zero(), u128); - let instruction = Instruction::Cast(context.enable_side_effects, u128); - let predicate = - dfg.insert_instruction_and_results(instruction, block, None, call_stack).first(); - let instruction = Instruction::Binary(Binary { lhs: res, rhs: predicate, operator: mul }); - let res = dfg.insert_instruction_and_results(instruction, block, None, call_stack).first(); let instruction = Instruction::Constrain( res, zero, @@ -278,16 +285,44 @@ mod tests { } #[test] - fn predicate_overflow() { + fn predicate_overflow_on_lhs_potentially_overflowing() { + // This code performs a u128 multiplication that overflows, under a condition. + let src = " + acir(inline) fn main f0 { + b0(v0: u128, v1: u1): + enable_side_effects v1 + v2 = mul v0, u128 85070591730234615865843651857942052864 + return v2 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.flatten_cfg().check_u128_mul_overflow(); + // Below, the overflow check takes the 'enable_side_effects' value into account + assert_ssa_snapshot!(ssa, @r#" + acir(inline) fn main f0 { + b0(v0: u128, v1: u1): + enable_side_effects v1 + v3 = mul v0, u128 85070591730234615865843651857942052864 + v5 = div v0, u128 18446744073709551616 + v6 = cast v1 as u128 + v7 = unchecked_mul v5, v6 + constrain v7 == u128 0, "attempt to multiply with overflow" + return v3 + } + "#); + } + + #[test] + fn predicate_overflow_on_guaranteed_overflow() { // This code performs a u128 multiplication that overflows, under a condition. let src = " acir(inline) fn main f0 { - b0(v0: u1): + b0(v0: u1): jmpif v0 then: b1, else: b2 - b1(): - v2 = mul u128 340282366920938463463374607431768211455, u128 340282366920938463463374607431768211455 // src/main.nr:17:13 + b1(): + v2 = mul u128 340282366920938463463374607431768211455, u128 340282366920938463463374607431768211455 jmp b2() - b2(): + b2(): return v0 } ";