diff --git a/compiler/noirc_evaluator/src/ssa/opt/checked_to_unchecked.rs b/compiler/noirc_evaluator/src/ssa/opt/checked_to_unchecked.rs index b11acb4ab03..29850cd7064 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/checked_to_unchecked.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/checked_to_unchecked.rs @@ -1,8 +1,13 @@ +use acvm::AcirField as _; +use fxhash::FxHashMap as HashMap; + use crate::ssa::{ ir::{ + dfg::DataFlowGraph, function::Function, instruction::{Binary, BinaryOp, Instruction}, types::NumericType, + value::{Value, ValueId}, }, ssa_gen::Ssa, }; @@ -21,6 +26,8 @@ impl Ssa { impl Function { fn checked_to_unchecked(&mut self) { + let mut value_max_num_bits = HashMap::::default(); + self.simple_reachable_blocks_optimization(|context| { let instruction = context.instruction(); let Instruction::Binary(binary) = instruction else { @@ -39,10 +46,10 @@ impl Function { match binary.operator { BinaryOp::Add { unchecked: false } => { let bit_size = dfg.type_of_value(lhs).bit_size(); + let max_lhs_bits = get_max_num_bits(dfg, lhs, &mut value_max_num_bits); + let max_rhs_bits = get_max_num_bits(dfg, rhs, &mut value_max_num_bits); - if dfg.get_value_max_num_bits(lhs) < bit_size - && dfg.get_value_max_num_bits(rhs) < bit_size - { + if max_lhs_bits < bit_size && max_rhs_bits < bit_size { // `lhs` and `rhs` have both been casted up from smaller types and so cannot overflow. let operator = BinaryOp::Add { unchecked: true }; let binary = Binary { operator, ..*binary }; @@ -50,20 +57,22 @@ impl Function { } } BinaryOp::Sub { unchecked: false } => { - if dfg.is_constant(lhs) - && dfg.get_value_max_num_bits(lhs) > dfg.get_value_max_num_bits(rhs) - { - // `lhs` is a fixed constant and `rhs` is restricted such that `lhs - rhs > 0` - // Note strict inequality as `rhs > lhs` while `max_lhs_bits == max_rhs_bits` is possible. - let operator = BinaryOp::Sub { unchecked: true }; - let binary = Binary { operator, ..*binary }; - context.replace_current_instruction_with(Instruction::Binary(binary)); + if dfg.is_constant(lhs) { + let max_lhs_bits = get_max_num_bits(dfg, lhs, &mut value_max_num_bits); + let max_rhs_bits = get_max_num_bits(dfg, rhs, &mut value_max_num_bits); + if max_lhs_bits > max_rhs_bits { + // `lhs` is a fixed constant and `rhs` is restricted such that `lhs - rhs > 0` + // Note strict inequality as `rhs > lhs` while `max_lhs_bits == max_rhs_bits` is possible. + let operator = BinaryOp::Sub { unchecked: true }; + let binary = Binary { operator, ..*binary }; + context.replace_current_instruction_with(Instruction::Binary(binary)); + } } } BinaryOp::Mul { unchecked: false } => { let bit_size = dfg.type_of_value(lhs).bit_size(); - let max_lhs_bits = dfg.get_value_max_num_bits(lhs); - let max_rhs_bits = dfg.get_value_max_num_bits(rhs); + let max_lhs_bits = get_max_num_bits(dfg, lhs, &mut value_max_num_bits); + let max_rhs_bits = get_max_num_bits(dfg, rhs, &mut value_max_num_bits); if bit_size == 1 || max_lhs_bits + max_rhs_bits <= bit_size @@ -83,6 +92,51 @@ impl Function { } } +/// The logic here is almost the same as [`DataFlowGraph::get_value_max_num_bits`] except that +/// - it takes into account that the bitsize of multiplying two bools is 1 +/// - it recurses by memoizing the results in `value_max_num_bits` +fn get_max_num_bits( + dfg: &DataFlowGraph, + value: ValueId, + value_max_num_bits: &mut HashMap, +) -> u32 { + if let Some(bits) = value_max_num_bits.get(&value) { + return *bits; + } + + let value_bit_size = dfg.type_of_value(value).bit_size(); + + let bits = match dfg[value] { + Value::Instruction { instruction, .. } => { + match dfg[instruction] { + Instruction::Cast(original_value, _) => { + let original_bit_size = + get_max_num_bits(dfg, original_value, value_max_num_bits); + // We might have cast e.g. `u1` to `u8` to be able to do arithmetic, + // in which case we want to recover the original smaller bit size; + // OTOH if we cast down, then we don't need the higher original size. + value_bit_size.min(original_bit_size) + } + Instruction::Binary(Binary { lhs, operator: BinaryOp::Mul { .. }, rhs }) + if get_max_num_bits(dfg, lhs, value_max_num_bits) == 1 + && get_max_num_bits(dfg, rhs, value_max_num_bits) == 1 => + { + // When multiplying two values, if their bitsize is 1 then the result's bitsize will be 1 too + 1 + } + _ => value_bit_size, + } + } + Value::NumericConstant { constant, .. } => constant.num_bits(), + _ => value_bit_size, + }; + + assert!(bits <= value_bit_size); + value_max_num_bits.insert(value, bits); + + bits +} + #[cfg(test)] mod tests { use crate::{ @@ -223,4 +277,30 @@ mod tests { let ssa = ssa.checked_to_unchecked(); assert_normalized_ssa_equals(ssa, src); } + + #[test] + fn checked_to_unchecked_when_multiplying_two_upcasted_bools_to_u32_then_multiplying_again() { + let src = " + acir(inline) fn main f0 { + b0(v0: u1, v1: u1, v2: u32): + v3 = cast v0 as u32 + v4 = cast v1 as u32 + v5 = mul v3, v4 + v6 = mul v2, v5 + return v6 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.checked_to_unchecked(); + assert_ssa_snapshot!(ssa, @r" + acir(inline) fn main f0 { + b0(v0: u1, v1: u1, v2: u32): + v3 = cast v0 as u32 + v4 = cast v1 as u32 + v5 = unchecked_mul v3, v4 + v6 = unchecked_mul v2, v5 + return v6 + } + "); + } }