diff --git a/compiler/noirc_evaluator/src/ssa/ir/dfg/simplify/constrain.rs b/compiler/noirc_evaluator/src/ssa/ir/dfg/simplify/constrain.rs index f88bdd183b0..d769f7e2aed 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/dfg/simplify/constrain.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/dfg/simplify/constrain.rs @@ -159,7 +159,7 @@ pub(super) fn decompose_constrain( Instruction::Cast(val, _) => { let original_typ = dfg.type_of_value(val).unwrap_numeric(); let original_typ_max_value = - original_typ.max_value().map(|max_value| *constant < max_value); + original_typ.max_value().map(|max_value| *constant <= max_value); match original_typ_max_value { Ok(true) => { @@ -280,4 +280,155 @@ mod tests { } "); } + + #[test] + fn decompose_not_condition() { + let src_template = " + acir(inline) fn main f0 { + b0(v0: u1): + v1 = not v0 + constrain v1 == u1 {} + return + } + "; + + let src = src_template.replace("{}", "1"); + let ssa = Ssa::from_str_simplifying(&src).unwrap(); + + assert_ssa_snapshot!(ssa, @r" + acir(inline) fn main f0 { + b0(v0: u1): + v1 = not v0 + constrain v0 == u1 0 + return + } + "); + + let src = src_template.replace("{}", "0"); + let ssa = Ssa::from_str_simplifying(&src).unwrap(); + + assert_ssa_snapshot!(ssa, @r" + acir(inline) fn main f0 { + b0(v0: u1): + v1 = not v0 + constrain v0 == u1 1 + return + } + "); + } + + #[test] + fn decompose_or_condition() { + let src = " + acir(inline) fn main f0 { + b0(v0: u8, v1: u8): + v2 = or v0, v1 + constrain v2 == u1 0 + return + } + "; + let ssa = Ssa::from_str_simplifying(src).unwrap(); + + assert_ssa_snapshot!(ssa, @r" + acir(inline) fn main f0 { + b0(v0: u8, v1: u8): + v2 = or v0, v1 + constrain v0 == u8 0 + constrain v1 == u8 0 + return + } + "); + } + + #[test] + fn remove_casts_from_same_type() { + let src = " + acir(inline) fn main f0 { + b0(v0: u1, v1: u1): + v2 = cast v0 as u8 + v3 = cast v1 as u8 + constrain v2 == v3 + return + } + "; + let ssa = Ssa::from_str_simplifying(src).unwrap(); + + assert_ssa_snapshot!(ssa, @r" + acir(inline) fn main f0 { + b0(v0: u1, v1: u1): + v2 = cast v0 as u8 + v3 = cast v1 as u8 + constrain v0 == v1 + return + } + "); + } + + #[test] + fn does_not_remove_casts_from_different_types() { + let src = " + acir(inline) fn main f0 { + b0(v0: u1, v1: u8): + v2 = cast v0 as u16 + v3 = cast v1 as u16 + constrain v2 == v3 + return + } + "; + let ssa = Ssa::from_str_simplifying(src).unwrap(); + + assert_ssa_snapshot!(ssa, @r" + acir(inline) fn main f0 { + b0(v0: u1, v1: u8): + v2 = cast v0 as u16 + v3 = cast v1 as u16 + constrain v2 == v3 + return + } + "); + } + + #[test] + fn replaces_constants_with_pre_cast_type() { + let src = " + acir(inline) fn main f0 { + b0(v0: u8): + v1 = cast v0 as u16 + constrain v1 == u16 255 + return + } + "; + let ssa = Ssa::from_str_simplifying(src).unwrap(); + + assert_ssa_snapshot!(ssa, @r" + acir(inline) fn main f0 { + b0(v0: u8): + v1 = cast v0 as u16 + constrain v0 == u8 255 + return + } + "); + } + + #[test] + fn does_not_cast_constant_to_incompatible_type() { + let src = " + acir(inline) fn main f0 { + b0(v0: u8): + v1 = cast v0 as u16 + constrain v1 == u16 256 + return + } + "; + let ssa = Ssa::from_str_simplifying(src).unwrap(); + + assert_ssa_snapshot!(ssa, @r" + acir(inline) fn main f0 { + b0(v0: u8): + v1 = cast v0 as u16 + constrain v1 == u16 256 + return + } + "); + } }