diff --git a/compiler/noirc_evaluator/src/ssa/ir/dfg/simplify.rs b/compiler/noirc_evaluator/src/ssa/ir/dfg/simplify.rs index 30f17d053a4..38873b27043 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/dfg/simplify.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/dfg/simplify.rs @@ -379,16 +379,20 @@ fn optimize_length_one_array_read( /// - Otherwise, we check the array value of the array set. /// - If the array value is constant, we use that array. /// - If the array value is from a previous array-set, we recur. +/// - If the array value is from an array parameter, we use that array. /// /// That is, we have multiple `array_set` instructions setting various constant indexes /// of the same array, returning a modified version. We want to go backwards until we /// find the last `array_set` for the index we are interested in, and return the value set. fn try_optimize_array_get_from_previous_set( - dfg: &DataFlowGraph, + dfg: &mut DataFlowGraph, mut array_id: ValueId, target_index: FieldElement, ) -> SimplifyResult { - let mut elements = None; + // The target index must be less than the maximum array length + let Some(target_index_u32) = target_index.try_to_u32() else { + return SimplifyResult::None; + }; // Arbitrary number of maximum tries just to prevent this optimization from taking too long. let max_tries = 5; @@ -402,27 +406,30 @@ fn try_optimize_array_get_from_previous_set( } array_id = *array; // recur - } else { - return SimplifyResult::None; + continue; } } Instruction::MakeArray { elements: array, typ: _ } => { - elements = Some(array.clone()); - break; + let index = target_index_u32 as usize; + if index < array.len() { + return SimplifyResult::SimplifiedTo(array[index]); + } } - _ => return SimplifyResult::None, + _ => (), + } + } else if let Value::Param { typ: Type::Array(_, length), .. } = &dfg[array_id] { + if target_index_u32 < *length { + let index = dfg.make_constant(target_index, NumericType::length_type()); + return SimplifyResult::SimplifiedToInstruction(Instruction::ArrayGet { + array: array_id, + index, + }); } - } else { - return SimplifyResult::None; } - } - if let (Some(array), Some(index)) = (elements, target_index.try_to_u64()) { - let index = index as usize; - if index < array.len() { - return SimplifyResult::SimplifiedTo(array[index]); - } + break; } + SimplifyResult::None } @@ -707,4 +714,65 @@ mod tests { assert_normalized_ssa_equals(ssa, src); } + + #[test] + fn simplifies_array_get_from_previous_array_set_with_make_array() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(): + v0 = make_array [Field 2, Field 3] : [Field; 2] + v1 = array_set mut v0, index u32 0, value Field 4 + v2 = array_get v1, index u32 0 -> Field + v3 = array_get v1, index u32 1 -> Field + return v2, v3 + } + "; + let ssa = Ssa::from_str_simplifying(src).unwrap(); + + assert_ssa_snapshot!(ssa, @r" + acir(inline) predicate_pure fn main f0 { + b0(): + v2 = make_array [Field 2, Field 3] : [Field; 2] + v4 = make_array [Field 4, Field 3] : [Field; 2] + return Field 4, Field 3 + } + "); + } + + #[test] + fn simplifies_array_get_from_previous_array_set_with_array_param_in_bounds() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: [Field; 2]): + v1 = array_set mut v0, index u32 0, value Field 4 + v2 = array_get v1, index u32 0 -> Field + v3 = array_get v1, index u32 1 -> Field + return v2, v3 + } + "; + let ssa = Ssa::from_str_simplifying(src).unwrap(); + + assert_ssa_snapshot!(ssa, @r" + acir(inline) predicate_pure fn main f0 { + b0(v0: [Field; 2]): + v3 = array_set mut v0, index u32 0, value Field 4 + v5 = array_get v0, index u32 1 -> Field + return Field 4, v5 + } + "); + } + + #[test] + fn does_not_simplify_array_get_from_previous_array_set_with_array_param_out_of_bounds() { + let src = " + acir(inline) predicate_pure fn main f0 { + b0(v0: [Field; 2]): + v3 = array_set mut v0, index u32 0, value Field 4 + v5 = array_get v3, index u32 2 -> Field + return v5 + } + "; + let ssa = Ssa::from_str_simplifying(src).unwrap(); + assert_normalized_ssa_equals(ssa, src); + } } diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding/mod.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding/mod.rs index 5f7f34ea182..35f1c662e1c 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding/mod.rs @@ -334,7 +334,7 @@ impl Context { }) }; // If the target_block is distinct than the original block - // that means that the current instruction is not added in the orignal block + // that means that the current instruction is not added in the original block // so it is deduplicated by the one in the target block. // In case it refers to an array that is mutated, we need to increment // its reference count. @@ -928,29 +928,28 @@ mod test { let instructions = main.dfg[main.entry_block()].instructions(); assert_eq!(instructions.len(), 15); + let ssa = ssa.fold_constants_using_constraints(MIN_ITER); + // The `array_get` instruction after `enable_side_effects v1` is deduplicated // with the one under `enable_side_effects v0` because it doesn't require a predicate, // but the `array_set` is not, because it does require a predicate, and the subsequent // `array_get` uses a different input, so it's not a duplicate of anything. - let expected = " - acir(inline) fn main f0 { - b0(v0: u1, v1: u1, v2: [Field; 2]): - enable_side_effects v0 - v4 = array_get v2, index u32 0 -> u32 - v7 = array_set v2, index u32 1, value u32 2 - v8 = array_get v7, index u32 0 -> u32 - constrain v4 == v8 - enable_side_effects v1 - v9 = array_set v2, index u32 1, value u32 2 - v10 = array_get v9, index u32 0 -> u32 - constrain v4 == v10 - enable_side_effects v0 - return - } - "; - - let ssa = ssa.fold_constants_using_constraints(MIN_ITER); - assert_normalized_ssa_equals(ssa, expected); + assert_ssa_snapshot!(ssa, @r" + acir(inline) fn main f0 { + b0(v0: u1, v1: u1, v2: [Field; 2]): + enable_side_effects v0 + v4 = array_get v2, index u32 0 -> u32 + v7 = array_set v2, index u32 1, value u32 2 + v8 = array_get v2, index u32 0 -> u32 + constrain v4 == v8 + enable_side_effects v1 + v9 = array_set v2, index u32 1, value u32 2 + v10 = array_get v2, index u32 0 -> u32 + constrain v4 == v10 + enable_side_effects v0 + return + } + "); } #[test] diff --git a/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs b/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs index c2474f27bcb..e76429b955b 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs @@ -551,7 +551,7 @@ mod tests { v10 = unchecked_mul v8, u32 2 v11 = unchecked_mul v9, v7 v12 = unchecked_add v10, v11 - v14 = array_get v5, index u32 1 -> u32 + v14 = array_get v1, index u32 1 -> u32 v15 = array_get v1, index u32 1 -> u32 v16 = cast v0 as u32 v17 = cast v6 as u32