diff --git a/compiler/noirc_evaluator/src/ssa/opt/checked_to_unchecked.rs b/compiler/noirc_evaluator/src/ssa/opt/checked_to_unchecked.rs index 29850cd7064..a2804703cc7 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/checked_to_unchecked.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/checked_to_unchecked.rs @@ -57,16 +57,24 @@ impl Function { } } BinaryOp::Sub { unchecked: false } => { - if dfg.is_constant(lhs) { - let max_lhs_bits = get_max_num_bits(dfg, lhs, &mut value_max_num_bits); - let max_rhs_bits = get_max_num_bits(dfg, rhs, &mut value_max_num_bits); - if max_lhs_bits > max_rhs_bits { - // `lhs` is a fixed constant and `rhs` is restricted such that `lhs - rhs > 0` - // Note strict inequality as `rhs > lhs` while `max_lhs_bits == max_rhs_bits` is possible. - let operator = BinaryOp::Sub { unchecked: true }; - let binary = Binary { operator, ..*binary }; - context.replace_current_instruction_with(Instruction::Binary(binary)); - } + let Some(lhs_const) = dfg.get_numeric_constant(lhs) else { + return; + }; + + let max_lhs_bits = get_max_num_bits(dfg, lhs, &mut value_max_num_bits); + let max_rhs_bits = get_max_num_bits(dfg, rhs, &mut value_max_num_bits); + let max_rhs = + if max_rhs_bits == 128 { u128::MAX } else { (1 << max_rhs_bits) - 1 }; + + // 1. `lhs` is a fixed constant and `rhs` is restricted such that `lhs - rhs > 0` + // Note strict inequality as `rhs > lhs` while `max_lhs_bits == max_rhs_bits` is possible. + // 2. `lhs` is the maximum value for the maximum bitsize of `rhs`. + // For example: `lhs` is 1 and `rhs` max bitsize is 1, so at most it's `1 - 1` which cannot overflow. + // Another example: `lhs` is 255 and `rhs` max bitsize is 8, so at most it's `255 - 255` which cannot overflow, etc. + if max_lhs_bits > max_rhs_bits || (lhs_const == max_rhs.into()) { + let operator = BinaryOp::Sub { unchecked: true }; + let binary = Binary { operator, ..*binary }; + context.replace_current_instruction_with(Instruction::Binary(binary)); } } BinaryOp::Mul { unchecked: false } => { @@ -190,6 +198,50 @@ mod tests { "); } + #[test] + fn checked_to_unchecked_when_subtracting_from_1_a_value_that_has_1_bit() { + let src = " + acir(inline) fn main f0 { + b0(v0: u1): + v1 = cast v0 as u32 + v3 = sub u32 1, v1 + return v3 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.checked_to_unchecked(); + assert_ssa_snapshot!(ssa, @r" + acir(inline) fn main f0 { + b0(v0: u1): + v1 = cast v0 as u32 + v3 = unchecked_sub u32 1, v1 + return v3 + } + "); + } + + #[test] + fn checked_to_unchecked_when_subtracting_from_255_a_value_that_has_8_bits() { + let src = " + acir(inline) fn main f0 { + b0(v0: u8): + v1 = cast v0 as u32 + v3 = sub u32 255, v1 + return v3 + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.checked_to_unchecked(); + assert_ssa_snapshot!(ssa, @r" + acir(inline) fn main f0 { + b0(v0: u8): + v1 = cast v0 as u32 + v3 = unchecked_sub u32 255, v1 + return v3 + } + "); + } + #[test] fn checked_to_unchecked_when_multiplying_bools() { let src = "