diff --git a/compiler/noirc_evaluator/src/ssa/ir/dfg/simplify/binary.rs b/compiler/noirc_evaluator/src/ssa/ir/dfg/simplify/binary.rs index a5688108b0a..262a4a761d2 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/dfg/simplify/binary.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/dfg/simplify/binary.rs @@ -293,8 +293,92 @@ pub(super) fn simplify_binary(binary: &Binary, dfg: &mut DataFlowGraph) -> Simpl } } BinaryOp::Shl | BinaryOp::Shr => { + if lhs_is_zero { + let zero = dfg.make_constant(FieldElement::zero(), lhs_type); + return SimplifyResult::SimplifiedTo(zero); + } + if rhs_is_zero { + return SimplifyResult::SimplifiedTo(lhs); + } return SimplifyResult::SimplifiedToInstruction(simplified); } }; SimplifyResult::SimplifiedToInstruction(simplified) } + +#[cfg(test)] +mod tests { + use crate::{assert_ssa_snapshot, ssa::ssa_gen::Ssa}; + + #[test] + fn replaces_shl_identity_with_lhs() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: u8): + v1 = shl v0, u8 0 + return v1 + } + "; + let ssa = Ssa::from_str_simplifying(src).unwrap(); + assert_ssa_snapshot!(ssa, @r" + acir(inline) predicate_pure fn main f0 { + b0(v0: u8): + return v0 + } + "); + } + + #[test] + fn replaces_shr_identity_with_lhs() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: u8): + v1 = shr v0, u8 0 + return v1 + } + "; + let ssa = Ssa::from_str_simplifying(src).unwrap(); + assert_ssa_snapshot!(ssa, @r" + acir(inline) predicate_pure fn main f0 { + b0(v0: u8): + return v0 + } + "); + } + + #[test] + fn replaces_shl_on_zero_lhs_with_zero() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: u8): + v1 = shl u8 0, v0 + return v1 + } + "; + let ssa = Ssa::from_str_simplifying(src).unwrap(); + assert_ssa_snapshot!(ssa, @r" + acir(inline) predicate_pure fn main f0 { + b0(v0: u8): + return u8 0 + } + "); + } + + #[test] + fn replaces_shr_on_zero_lhs_with_zero() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: u8): + v1 = shr u8 0, v0 + return v1 + } + "; + let ssa = Ssa::from_str_simplifying(src).unwrap(); + assert_ssa_snapshot!(ssa, @r" + acir(inline) predicate_pure fn main f0 { + b0(v0: u8): + return u8 0 + } + "); + } +}